Source code for supar.utils.embed

# -*- coding: utf-8 -*-

from __future__ import annotations

import os
from typing import Iterable, Optional, Union

import torch
from supar.utils.common import CACHE
from supar.utils.fn import download
from supar.utils.logging import progress_bar
from torch.distributions.utils import lazy_property


class Embedding(object):
    r"""
    Defines a container object for holding pretrained embeddings.
    This object is callable and behaves like :class:`torch.nn.Embedding`.
    For huge files, this object supports lazy loading, seeking to retrieve vectors from the disk on the fly if necessary.

    Currently available embeddings:
        - `GloVe`_
        - `Fasttext`_
        - `Giga`_
        - `Tencent`_

    Args:
        path (str):
            Path to the embedding file or short name registered in ``supar.utils.embed.PRETRAINED``.
        unk (Optional[str]):
            The string token used to represent OOV tokens. Default: ``None``.
        skip_first (bool):
            If ``True``, skips the first line of the embedding file. Default: ``False``.
        cache (bool):
            If ``True``, instead of loading entire embeddings into memory, seeks to load vectors from the disk once called.
            Default: ``True``.
        sep (str):
            Separator used by the embedding file. Default: ``' '``.

    Examples:
        >>> import torch.nn as nn
        >>> from supar.utils.embed import Embedding
        >>> glove = Embedding.load('glove-6b-100')
        >>> glove
        GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True)
        >>> fasttext = Embedding.load('fasttext-en')
        >>> fasttext
        FasttextEmbedding(n_tokens=2000000, dim=300, skip_first=True, cache=True)
        >>> giga = Embedding.load('giga-100')
        >>> giga
        GigaEmbedding(n_tokens=372846, dim=100, cache=True)
        >>> indices = torch.tensor([glove.vocab[i.lower()] for i in ['She', 'enjoys', 'playing', 'tennis', '.']])
        >>> indices
        tensor([  67, 8371,  697, 2140,    2])
        >>> glove(indices).shape
        torch.Size([5, 100])
        >>> glove(indices).equal(nn.Embedding.from_pretrained(glove.vectors)(indices))
        True

    .. _GloVe:
        https://nlp.stanford.edu/projects/glove/
    .. _Fasttext:
        https://fasttext.cc/docs/en/crawl-vectors.html
    .. _Giga:
        https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip
    .. _Tencent:
        https://ai.tencent.com/ailab/nlp/zh/download.html
    """

    CACHE = os.path.join(CACHE, 'data/embeds')

    def __init__(
        self,
        path: str,
        unk: Optional[str] = None,
        skip_first: bool = False,
        cache: bool = True,
        sep: str = ' ',
        **kwargs
    ) -> Embedding:
        super().__init__()

        self.path = path
        self.unk = unk
        self.skip_first = skip_first
        self.cache = cache
        self.sep = sep
        self.kwargs = kwargs

        self.vocab = {token: i for i, token in enumerate(self.tokens)}

    def __len__(self):
        return len(self.vocab)

    def __repr__(self):
        s = f"{self.__class__.__name__}("
        s += f"n_tokens={len(self)}, dim={self.dim}"
        if self.unk is not None:
            s += f", unk={self.unk}"
        if self.skip_first:
            s += f", skip_first={self.skip_first}"
        if self.cache:
            s += f", cache={self.cache}"
        s += ")"
        return s

    def __contains__(self, token):
        return token in self.vocab

    def __getitem__(self, key: Union[int, Iterable[int], torch.Tensor]) -> torch.Tensor:
        indices = key
        if not isinstance(indices, torch.Tensor):
            indices = torch.tensor(key)
        if self.cache:
            elems, indices = indices.unique(return_inverse=True)
            with open(self.path) as f:
                vectors = []
                for index in elems.tolist():
                    f.seek(self.positions[index])
                    vectors.append(list(map(float, f.readline().strip().split(self.sep)[1:])))
            vectors = torch.tensor(vectors)
        else:
            vectors = self.vectors
        return torch.embedding(vectors, indices)

    def __call__(self, key: Union[int, Iterable[int], torch.Tensor]) -> torch.Tensor:
        return self[key]

    @lazy_property
    def dim(self):
        return len(self[0])

    @lazy_property
    def unk_index(self):
        if self.unk is not None:
            return self.vocab[self.unk]
        raise AttributeError

    @lazy_property
    def tokens(self):
        with open(self.path) as f:
            if self.skip_first:
                f.readline()
            return [line.strip().split(self.sep)[0] for line in progress_bar(f)]

    @lazy_property
    def vectors(self):
        with open(self.path) as f:
            if self.skip_first:
                f.readline()
            return torch.tensor([list(map(float, line.strip().split(self.sep)[1:])) for line in progress_bar(f)])

    @lazy_property
    def positions(self):
        with open(self.path) as f:
            if self.skip_first:
                f.readline()
            positions = [f.tell()]
            while True:
                line = f.readline()
                if line:
                    positions.append(f.tell())
                else:
                    break
        return positions

    @classmethod
    def load(cls, path: str, unk: Optional[str] = None, **kwargs) -> Embedding:
        if path in PRETRAINED:
            cfg = dict(**PRETRAINED[path])
            embed = cfg.pop('_target_')
            return embed(**cfg, **kwargs)
        return cls(path, unk, **kwargs)
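A minimal sketch (not part of the module) of how the lazy ``cache`` mode can be exercised with a local embedding file; the file name ``toy.vec`` and its contents are hypothetical, and each line follows the ``<token><sep><value> <value> ...`` format that ``tokens``, ``vectors``, and ``positions`` assume:

    >>> with open('toy.vec', 'w') as f:
    ...     _ = f.write('the 0.1 0.2 0.3\n')
    ...     _ = f.write('cat 0.4 0.5 0.6\n')
    ...     _ = f.write('<unk> 0.0 0.0 0.0\n')
    >>> embed = Embedding('toy.vec', unk='<unk>', cache=True)
    >>> embed.dim  # inferred from the first vector on disk
    3
    >>> embed([embed.vocab['the'], embed.vocab['cat']]).shape  # seeks only the requested lines
    torch.Size([2, 3])
    >>> embed.unk_index
    2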
class GloVeEmbedding(Embedding):
    r"""
    `GloVe`_: Global Vectors for Word Representation.
    Training is performed on aggregated global word-word co-occurrence statistics from a corpus,
    and the resulting representations showcase interesting linear substructures of the word vector space.

    Args:
        src (str):
            Size of the source data for training. Default: ``6B``.
        dim (int):
            Which dimension of the embeddings to use. Default: 100.
        reload (bool):
            If ``True``, forces a fresh download. Default: ``False``.

    Examples:
        >>> from supar.utils.embed import Embedding
        >>> Embedding.load('glove-6b-100')
        GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True)

    .. _GloVe:
        https://nlp.stanford.edu/projects/glove/
    """

    def __init__(self, src: str = '6B', dim: int = 100, reload=False, *args, **kwargs) -> GloVeEmbedding:
        if src == '6B' or src == 'twitter.27B':
            url = f'https://nlp.stanford.edu/data/glove.{src}.zip'
        else:
            url = f'https://nlp.stanford.edu/data/glove.{src}.{dim}d.zip'
        path = os.path.join(os.path.join(self.CACHE, 'glove'), f'glove.{src}.{dim}d.txt')
        if not os.path.exists(path) or reload:
            download(url, os.path.join(self.CACHE, 'glove'), clean=True)
        super().__init__(path=path, unk='unk', *args, **kwargs)
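For reference, the URL and cache layout the constructor derives from ``src`` and ``dim`` (the strings below are just the formatted values; nothing is downloaded):

    >>> src, dim = '840B', 300
    >>> f'https://nlp.stanford.edu/data/glove.{src}.{dim}d.zip'
    'https://nlp.stanford.edu/data/glove.840B.300d.zip'
    >>> src, dim = '6B', 100
    >>> f'https://nlp.stanford.edu/data/glove.{src}.zip'  # the 6B and twitter.27B archives bundle all dimensions
    'https://nlp.stanford.edu/data/glove.6B.zip'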
class FasttextEmbedding(Embedding):
    r"""
    `Fasttext`_ word embeddings for 157 languages, trained using CBOW, in dimension 300,
    with character n-grams of length 5, a window of size 5 and 10 negatives.

    Args:
        lang (str):
            Language code. Default: ``en``.
        reload (bool):
            If ``True``, forces a fresh download. Default: ``False``.

    Examples:
        >>> from supar.utils.embed import Embedding
        >>> Embedding.load('fasttext-en')
        FasttextEmbedding(n_tokens=2000000, dim=300, skip_first=True, cache=True)

    .. _Fasttext:
        https://fasttext.cc/docs/en/crawl-vectors.html
    """

    def __init__(self, lang: str = 'en', reload=False, *args, **kwargs) -> FasttextEmbedding:
        url = f'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{lang}.300.vec.gz'
        path = os.path.join(self.CACHE, 'fasttext', f'cc.{lang}.300.vec')
        if not os.path.exists(path) or reload:
            download(url, os.path.join(self.CACHE, 'fasttext'), clean=True)
        super().__init__(path=path, skip_first=True, *args, **kwargs)
class GigaEmbedding(Embedding):
    r"""
    `Giga`_ word embeddings for Chinese, trained on the Chinese Gigaword Third Edition with word2vec,
    used by :cite:`zhang-etal-2020-efficient` and :cite:`zhang-etal-2020-fast`.

    Args:
        reload (bool):
            If ``True``, forces a fresh download. Default: ``False``.

    Examples:
        >>> from supar.utils.embed import Embedding
        >>> Embedding.load('giga-100')
        GigaEmbedding(n_tokens=372846, dim=100, cache=True)

    .. _Giga:
        https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip
    """

    def __init__(self, reload=False, *args, **kwargs) -> GigaEmbedding:
        url = 'https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip'
        path = os.path.join(self.CACHE, 'giga', 'giga.100.txt')
        if not os.path.exists(path) or reload:
            download(url, os.path.join(self.CACHE, 'giga'), clean=True)
        super().__init__(path=path, *args, **kwargs)
class TencentEmbedding(Embedding):
    r"""
    `Tencent`_ word embeddings.
    The embeddings are trained on large-scale text collected from news, webpages, and novels with Directional Skip-Gram.
    100-dimension and 200-dimension embeddings for over 12 million Chinese words are provided.

    Args:
        dim (int):
            Which dimension of the embeddings to use. Currently 100 and 200 are available. Default: 100.
        large (bool):
            If ``True``, uses the large version with a larger vocab size (12,287,933); 2,000,000 otherwise.
            Default: ``False``.
        reload (bool):
            If ``True``, forces a fresh download. Default: ``False``.

    Examples:
        >>> from supar.utils.embed import Embedding
        >>> Embedding.load('tencent-100')
        TencentEmbedding(n_tokens=2000000, dim=100, skip_first=True, cache=True)
        >>> Embedding.load('tencent-100-large')
        TencentEmbedding(n_tokens=12287933, dim=100, skip_first=True, cache=True)

    .. _Tencent:
        https://ai.tencent.com/ailab/nlp/zh/download.html
    """

    def __init__(self, dim: int = 100, large: bool = False, reload=False, *args, **kwargs) -> TencentEmbedding:
        url = f'https://ai.tencent.com/ailab/nlp/zh/data/tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}.tar.gz'  # noqa
        name = f'tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}'
        path = os.path.join(os.path.join(self.CACHE, 'tencent'), name, f'{name}.txt')
        if not os.path.exists(path) or reload:
            download(url, os.path.join(self.CACHE, 'tencent'), clean=True)
        super().__init__(path=path, skip_first=True, *args, **kwargs)
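The archive name encodes the dimension and whether the small (``-s``) or large variant is requested; a quick check of the formatted name as the constructor builds it (no download involved):

    >>> dim, large = 100, False
    >>> f'tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}'
    'tencent-ailab-embedding-zh-d100-v0.2.0-s'
    >>> dim, large = 200, True
    >>> f'tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}'
    'tencent-ailab-embedding-zh-d200-v0.2.0'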
PRETRAINED = {
    'glove-6b-50': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 50},
    'glove-6b-100': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 100},
    'glove-6b-200': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 200},
    'glove-6b-300': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 300},
    'glove-42b-300': {'_target_': GloVeEmbedding, 'src': '42B', 'dim': 300},
    'glove-840b-300': {'_target_': GloVeEmbedding, 'src': '840B', 'dim': 300},
    'glove-twitter-27b-25': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 25},
    'glove-twitter-27b-50': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 50},
    'glove-twitter-27b-100': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 100},
    'glove-twitter-27b-200': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 200},
    'fasttext-bg': {'_target_': FasttextEmbedding, 'lang': 'bg'},
    'fasttext-ca': {'_target_': FasttextEmbedding, 'lang': 'ca'},
    'fasttext-cs': {'_target_': FasttextEmbedding, 'lang': 'cs'},
    'fasttext-de': {'_target_': FasttextEmbedding, 'lang': 'de'},
    'fasttext-en': {'_target_': FasttextEmbedding, 'lang': 'en'},
    'fasttext-es': {'_target_': FasttextEmbedding, 'lang': 'es'},
    'fasttext-fr': {'_target_': FasttextEmbedding, 'lang': 'fr'},
    'fasttext-it': {'_target_': FasttextEmbedding, 'lang': 'it'},
    'fasttext-nl': {'_target_': FasttextEmbedding, 'lang': 'nl'},
    'fasttext-no': {'_target_': FasttextEmbedding, 'lang': 'no'},
    'fasttext-ro': {'_target_': FasttextEmbedding, 'lang': 'ro'},
    'fasttext-ru': {'_target_': FasttextEmbedding, 'lang': 'ru'},
    'giga-100': {'_target_': GigaEmbedding},
    'tencent-100': {'_target_': TencentEmbedding, 'dim': 100},
    'tencent-100-large': {'_target_': TencentEmbedding, 'dim': 100, 'large': True},
    'tencent-200': {'_target_': TencentEmbedding, 'dim': 200},
    'tencent-200-large': {'_target_': TencentEmbedding, 'dim': 200, 'large': True},
}
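``Embedding.load`` consults this registry: a registered short name is dispatched to its ``_target_`` class with the preset arguments, while any other string is treated as a path to a local embedding file. A brief sketch (the local path ``./my.vec`` is hypothetical):

    >>> from supar.utils.embed import Embedding
    >>> Embedding.load('glove-6b-100')  # registered alias -> GloVeEmbedding(src='6B', dim=100)
    GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True)
    >>> embed = Embedding.load('./my.vec', unk='<unk>')  # unregistered name -> plain Embedding(path, unk)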