# Source code for supar.utils.vocab

# -*- coding: utf-8 -*-

from __future__ import annotations

from collections import Counter, defaultdict
from typing import Iterable, Tuple, Union


class Vocab(object):
    r"""
    Defines a vocabulary object that will be used to numericalize a field.

    Args:
        counter (~collections.Counter):
            :class:`~collections.Counter` object holding the frequencies of each value found in the data.
        min_freq (int):
            The minimum frequency needed to include a token in the vocabulary. Default: 1.
        specials (Tuple[str]):
            The list of special tokens (e.g., pad, unk, bos and eos) that will be prepended to the vocabulary.
            Default: ``()``.
        unk_index (int):
            The index of unk token. Default: 0.

    Attributes:
        itos:
            A list of token strings indexed by their numerical identifiers.
        stoi:
            A :class:`~collections.defaultdict` object mapping token strings to numerical identifiers.
        unk_index:
            The index returned when an unknown token is looked up.
        n_init:
            The size of the vocabulary right after construction.
    """

    def __init__(self, counter: Counter, min_freq: int = 1, specials: Tuple = tuple(), unk_index: int = 0) -> None:
        self.itos = list(specials)
        self.stoi = defaultdict(lambda: unk_index)
        self.stoi.update({token: i for i, token in enumerate(self.itos)})
        # only tokens that are frequent enough make it into the vocabulary
        self.update([token for token, freq in counter.items() if freq >= min_freq])
        self.unk_index = unk_index
        self.n_init = len(self)

    def __len__(self):
        return len(self.itos)

    def __getitem__(self, key: Union[int, str, Iterable]) -> Union[str, int, Iterable]:
        r"""
        Maps tokens to indices or indices to tokens.

        A string is mapped to its index (``unk_index`` if unknown), a non-iterable key is treated as an index
        and mapped to its token. A non-empty sequence whose first element is a string is mapped token-wise to
        a list of indices; any other iterable is treated as a sequence of indices and mapped to a list of tokens.
        """
        # NOTE: lookups go through `dict.get` rather than `self.stoi[...]` on purpose.
        # Indexing a defaultdict with a missing key *inserts* that key, which would silently
        # register unknown tokens in `stoi` (breaking `in`, `items()` and later `update()` calls).
        if isinstance(key, str):
            return self.stoi.get(key, self.unk_index)
        elif not isinstance(key, Iterable):
            return self.itos[key]
        elif len(key) > 0 and isinstance(key[0], str):
            return [self.stoi.get(i, self.unk_index) for i in key]
        else:
            return [self.itos[i] for i in key]

    def __contains__(self, token):
        return token in self.stoi

    def __getstate__(self):
        # avoid pickling the defaultdict, whose lambda factory cannot be pickled
        attrs = dict(self.__dict__)
        # cast to regular dict
        attrs['stoi'] = dict(self.stoi)
        return attrs

    def __setstate__(self, state):
        # restore all attributes first, then rebuild the defaultdict on the instance
        # so that the dict handed in by the caller is left untouched
        self.__dict__.update(state)
        stoi = defaultdict(lambda: self.unk_index)
        stoi.update(state['stoi'])
        self.stoi = stoi

    def items(self):
        r"""
        Returns a view over the ``(token, index)`` pairs of the vocabulary.
        """
        return self.stoi.items()

    def update(self, vocab: Union[Iterable[str], Vocab, Counter]) -> Vocab:
        r"""
        Appends the tokens of ``vocab`` that are not yet known to the vocabulary.

        Args:
            vocab (Union[Iterable[str], Vocab, Counter]):
                The tokens to add. For a :class:`Vocab`, its ``itos`` list is used.

        Returns:
            The vocabulary itself, so that calls can be chained.
        """
        if isinstance(vocab, Vocab):
            vocab = vocab.itos
        # NOTE: PAY CAREFUL ATTENTION TO DICT ORDER UNDER DISTRIBUTED TRAINING!
        # new tokens are sorted so that they are appended in a deterministic order
        vocab = sorted(set(vocab).difference(self.stoi))
        # number the new tokens by their position in `itos`, so that both mappings stay aligned
        # even if `stoi` and `itos` differ in size (e.g., duplicated special tokens)
        start = len(self.itos)
        self.itos.extend(vocab)
        self.stoi.update({token: i for i, token in enumerate(vocab, start)})
        return self