Source code for supar.utils.fn

# -*- coding: utf-8 -*-

import gzip
import mmap
import os
import pickle
import shutil
import struct
import sys
import tarfile
import unicodedata
import urllib
import zipfile
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from omegaconf import DictConfig, OmegaConf

from supar.utils.common import CACHE
from supar.utils.parallel import wait


def ispunct(token: str, pos: Optional[str] = None, puncts: Set = {'``', "''", ':', ',', '.', 'PU', 'PUNCT'}) -> bool:
    return all(unicodedata.category(char).startswith('P') for char in token) if pos is None else pos in puncts


def isfullwidth(token: str) -> bool:
    return all(unicodedata.east_asian_width(char) in ['W', 'F', 'A'] for char in token)


def islatin(token: str) -> bool:
    return all('LATIN' in unicodedata.name(char) for char in token)


def isdigit(token: str) -> bool:
    return all('DIGIT' in unicodedata.name(char) for char in token)


def tohalfwidth(token: str) -> str:
    return unicodedata.normalize('NFKC', token)

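
# A minimal usage sketch for the Unicode helpers above (illustrative only, not
# part of the original module); the sample tokens are made up for demonstration:
#
#   >>> ispunct(',')
#   True
#   >>> ispunct('Apple', pos='NN')      # POS-based check when a tag is given
#   False
#   >>> isfullwidth('２０２０') and isdigit('２０２０')
#   True
#   >>> islatin('Latin')
#   True
#   >>> tohalfwidth('２０２０')          # NFKC maps fullwidth forms to ASCII
#   '2020'
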

def kmeans(x: List[int], k: int, max_it: int = 32) -> Tuple[List[float], List[List[int]]]:
    r"""
    KMeans algorithm for clustering the sentences by length.

    Args:
        x (List[int]):
            The list of sentence lengths.
        k (int):
            The number of clusters, which is an approximate value.
            The final number of clusters can be less than or equal to `k`.
        max_it (int):
            Maximum number of iterations.
            If the centroids do not converge after several iterations, the algorithm will be stopped early.

    Returns:
        List[float], List[List[int]]:
            The first list contains average lengths of sentences in each cluster.
            The second is the list of clusters holding indices of data points.

    Examples:
        >>> x = torch.randint(10, 20, (10,)).tolist()
        >>> x
        [15, 10, 17, 11, 18, 13, 17, 19, 18, 14]
        >>> centroids, clusters = kmeans(x, 3)
        >>> centroids
        [10.5, 14.0, 17.799999237060547]
        >>> clusters
        [[1, 3], [0, 5, 9], [2, 4, 6, 7, 8]]
    """

    x = torch.tensor(x, dtype=torch.float)
    # collect unique datapoints
    datapoints, indices, freqs = x.unique(return_inverse=True, return_counts=True)
    # the number of clusters must not be greater than the number of datapoints
    k = min(len(datapoints), k)
    # initialize k centroids randomly
    centroids = datapoints[torch.randperm(len(datapoints))[:k]]
    # assign each datapoint to the cluster with the closest centroid
    dists, y = torch.abs_(datapoints.unsqueeze(-1) - centroids).min(-1)

    for _ in range(max_it):
        # if an empty cluster is encountered,
        # choose the farthest datapoint from the biggest cluster and move it to the empty one
        mask = torch.arange(k).unsqueeze(-1).eq(y)
        none = torch.where(~mask.any(-1))[0].tolist()
        for i in none:
            # the biggest cluster
            biggest = torch.where(mask[mask.sum(-1).argmax()])[0]
            # the datapoint farthest from the centroid of the biggest cluster
            farthest = dists[biggest].argmax()
            # update the assigned cluster of the farthest datapoint
            y[biggest[farthest]] = i
            # re-calculate the mask
            mask = torch.arange(k).unsqueeze(-1).eq(y)
        # update the centroids
        centroids, old = (datapoints * freqs * mask).sum(-1) / (freqs * mask).sum(-1), centroids
        # re-assign all datapoints to clusters
        dists, y = torch.abs_(datapoints.unsqueeze(-1) - centroids).min(-1)
        # stop iteration early if the centroids converge
        if centroids.equal(old):
            break
    # assign all datapoints to the newly-generated clusters
    # the empty ones are discarded
    assigned = y.unique().tolist()
    # get the centroids of the assigned clusters
    centroids = centroids[assigned].tolist()
    # map all values of datapoints to buckets
    clusters = [torch.where(indices.unsqueeze(-1).eq(torch.where(y.eq(i))[0]).any(-1))[0].tolist() for i in assigned]

    return centroids, clusters

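
# Illustrative note (not part of the original module): clustering sentences by
# length is typically done so that each bucket can be batched with little padding.
# Reusing the docstring example above, and assuming the same clustering it shows,
# each cluster gathers sentences whose lengths sit close to one centroid:
#
#   >>> lengths = [15, 10, 17, 11, 18, 13, 17, 19, 18, 14]
#   >>> centroids, clusters = kmeans(lengths, 3)
#   >>> [[lengths[i] for i in cluster] for cluster in clusters]
#   [[10, 11], [15, 13, 14], [17, 18, 17, 19, 18]]
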
def stripe(x: torch.Tensor, n: int, w: int, offset: Tuple = (0, 0), horizontal: bool = True) -> torch.Tensor:
    r"""
    Returns a parallelogram stripe of the tensor.

    Args:
        x (~torch.Tensor): the input tensor with 2 or more dims.
        n (int): the length of the stripe.
        w (int): the width of the stripe.
        offset (tuple): the offset of the first two dims.
        horizontal (bool): `True` if a horizontal stripe is returned; `False` otherwise.

    Returns:
        A parallelogram stripe of the tensor.

    Examples:
        >>> x = torch.arange(25).view(5, 5)
        >>> x
        tensor([[ 0,  1,  2,  3,  4],
                [ 5,  6,  7,  8,  9],
                [10, 11, 12, 13, 14],
                [15, 16, 17, 18, 19],
                [20, 21, 22, 23, 24]])
        >>> stripe(x, 2, 3)
        tensor([[0, 1, 2],
                [6, 7, 8]])
        >>> stripe(x, 2, 3, (1, 1))
        tensor([[ 6,  7,  8],
                [12, 13, 14]])
        >>> stripe(x, 2, 3, (1, 1), 0)
        tensor([[ 6, 11, 16],
                [12, 17, 22]])
    """

    x = x.contiguous()
    seq_len, stride = x.size(1), list(x.stride())
    numel = stride[1]
    return x.as_strided(size=(n, w, *x.shape[2:]),
                        stride=[(seq_len + 1) * numel, (1 if horizontal else seq_len) * numel] + stride[2:],
                        storage_offset=(offset[0]*seq_len+offset[1])*numel)

def diagonal_stripe(x: torch.Tensor, offset: int = 1) -> torch.Tensor:
    r"""
    Returns a diagonal parallelogram stripe of the tensor.

    Args:
        x (~torch.Tensor): the input tensor with 3 or more dims.
        offset (int): which diagonal to consider. Default: 1.

    Returns:
        A diagonal parallelogram stripe of the tensor.

    Examples:
        >>> x = torch.arange(125).view(5, 5, 5)
        >>> diagonal_stripe(x)
        tensor([[ 5],
                [36],
                [67],
                [98]])
        >>> diagonal_stripe(x, 2)
        tensor([[10, 11],
                [41, 42],
                [72, 73]])
        >>> diagonal_stripe(x, -2)
        tensor([[ 50,  51],
                [ 81,  82],
                [112, 113]])
    """

    x = x.contiguous()
    seq_len, stride = x.size(1), list(x.stride())
    n, w, numel = seq_len - abs(offset), abs(offset), stride[2]
    return x.as_strided(size=(n, w, *x.shape[3:]),
                        stride=[((seq_len + 1) * x.size(2) + 1) * numel] + stride[2:],
                        storage_offset=offset*stride[1] if offset > 0 else abs(offset)*stride[0])


def expanded_stripe(x: torch.Tensor, n: int, w: int, offset: Tuple = (0, 0)) -> torch.Tensor:
    r"""
    Returns an expanded parallelogram stripe of the tensor.

    Args:
        x (~torch.Tensor): the input tensor with 2 or more dims.
        n (int): the length of the stripe.
        w (int): the width of the stripe.
        offset (tuple): the offset of the first two dims.

    Returns:
        An expanded parallelogram stripe of the tensor.

    Examples:
        >>> x = torch.arange(25).view(5, 5)
        >>> x
        tensor([[ 0,  1,  2,  3,  4],
                [ 5,  6,  7,  8,  9],
                [10, 11, 12, 13, 14],
                [15, 16, 17, 18, 19],
                [20, 21, 22, 23, 24]])
        >>> expanded_stripe(x, 2, 3)
        tensor([[[ 0,  1,  2,  3,  4],
                 [ 5,  6,  7,  8,  9],
                 [10, 11, 12, 13, 14]],

                [[ 5,  6,  7,  8,  9],
                 [10, 11, 12, 13, 14],
                 [15, 16, 17, 18, 19]]])
        >>> expanded_stripe(x, 2, 3, (1, 1))
        tensor([[[ 5,  6,  7,  8,  9],
                 [10, 11, 12, 13, 14],
                 [15, 16, 17, 18, 19]],

                [[10, 11, 12, 13, 14],
                 [15, 16, 17, 18, 19],
                 [20, 21, 22, 23, 24]]])
    """

    x = x.contiguous()
    stride = list(x.stride())
    return x.as_strided(size=(n, w, *list(x.shape[1:])),
                        stride=stride[:1] + [stride[0]] + stride[1:],
                        storage_offset=(offset[1])*stride[0])


def binarize(
    data: Union[List[str], Dict[str, Iterable]],
    fbin: str = None,
    merge: bool = False
) -> Tuple[str, torch.Tensor]:
    start, meta = 0, defaultdict(list)
    # the binarized file is organized as:
    # `data`: pickled objects
    # `meta`: a dict containing the pointers of each kind of data
    # `index`: fixed size integers representing the storage positions of the meta data
    with open(fbin, 'wb') as f:
        # in this case, data should be a list of binarized files
        if merge:
            for file in data:
                if not os.path.exists(file):
                    raise RuntimeError("Some files are missing. Please check the paths")
                mi = debinarize(file, meta=True)
                for key, val in mi.items():
                    val[:, 0] += start
                    meta[key].append(val)
                with open(file, 'rb') as fi:
                    length = int(sum(val[:, 1].sum() for val in mi.values()))
                    f.write(fi.read(length))
                start = start + length
            meta = {key: torch.cat(val) for key, val in meta.items()}
        else:
            for key, val in data.items():
                for i in val:
                    buf = i if isinstance(i, (bytes, bytearray)) else pickle.dumps(i)
                    f.write(buf)
                    meta[key].append((start, len(buf)))
                    start = start + len(buf)
            meta = {key: torch.tensor(val) for key, val in meta.items()}
        pickled = pickle.dumps(meta)
        # append the meta data to the end of the bin file
        f.write(pickled)
        # record the positions of the meta data
        f.write(struct.pack('LL', start, len(pickled)))
    return fbin, meta


def debinarize(
    fbin: str,
    pos_or_key: Optional[Union[Tuple[int, int], str]] = (0, 0),
    meta: bool = False,
    unpickle: bool = False
) -> Union[Any, Iterable[Any]]:
    with open(fbin, 'rb') as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        if meta or isinstance(pos_or_key, str):
            length = len(struct.pack('LL', 0, 0))
            mm.seek(-length, os.SEEK_END)
            offset, length = struct.unpack('LL', mm.read(length))
            mm.seek(offset)
            if meta:
                return pickle.loads(mm.read(length))
            # fetch by key
            objs, meta = [], pickle.loads(mm.read(length))[pos_or_key]
            for offset, length in meta.tolist():
                mm.seek(offset)
                objs.append(mm.read(length) if unpickle else pickle.loads(mm.read(length)))
            return objs
        # fetch by positions
        offset, length = pos_or_key
        mm.seek(offset)
        return mm.read(length) if unpickle else pickle.loads(mm.read(length))


def pad(
    tensors: List[torch.Tensor],
    padding_value: int = 0,
    total_length: int = None,
    padding_side: str = 'right'
) -> torch.Tensor:
    size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors)
                             for i in range(len(tensors[0].size()))]
    if total_length is not None:
        assert total_length >= size[1]
        size[1] = total_length
    out_tensor = tensors[0].data.new(*size).fill_(padding_value)
    for i, tensor in enumerate(tensors):
        out_tensor[i][[slice(-i, None) if padding_side == 'left' else slice(0, i) for i in tensor.size()]] = tensor
    return out_tensor


@wait
def download(url: str, path: Optional[str] = None, reload: bool = False, clean: bool = False) -> str:
    filename = os.path.basename(urllib.parse.urlparse(url).path)
    if path is None:
        path = CACHE
    os.makedirs(path, exist_ok=True)
    path = os.path.join(path, filename)
    if reload and os.path.exists(path):
        os.remove(path)
    if not os.path.exists(path):
        sys.stderr.write(f"Downloading {url} to {path}\n")
        try:
            torch.hub.download_url_to_file(url, path, progress=True)
        except (ValueError, urllib.error.URLError):
            raise RuntimeError(f"File {url} unavailable. Please try other sources.")
    return extract(path, reload, clean)


def extract(path: str, reload: bool = False, clean: bool = False) -> str:
    extracted = path
    if zipfile.is_zipfile(path):
        with zipfile.ZipFile(path) as f:
            extracted = os.path.join(os.path.dirname(path), f.infolist()[0].filename)
            if reload or not os.path.exists(extracted):
                f.extractall(os.path.dirname(path))
    elif tarfile.is_tarfile(path):
        with tarfile.open(path) as f:
            extracted = os.path.join(os.path.dirname(path), f.getnames()[0])
            if reload or not os.path.exists(extracted):
                f.extractall(os.path.dirname(path))
    elif path.endswith('.gz'):
        extracted = path[:-3]
        with gzip.open(path) as fgz:
            with open(extracted, 'wb') as f:
                shutil.copyfileobj(fgz, f)
    if clean:
        os.remove(path)
    return extracted


def resolve_config(args: Union[Dict, DictConfig]) -> DictConfig:
    OmegaConf.register_new_resolver("eval", eval)
    return DictConfig(OmegaConf.to_container(args, resolve=True))


def collect_args(args: Union[Dict, DictConfig]) -> DictConfig:
    for key in ('self', 'cls', '__class__'):
        args.pop(key, None)
    args.update(args.pop('kwargs', dict()))
    return DictConfig(args)


def get_rng_state() -> Dict[str, torch.Tensor]:
    state = {'rng_state': torch.get_rng_state()}
    if torch.cuda.is_available():
        state['cuda_rng_state'] = torch.cuda.get_rng_state()
    return state


def set_rng_state(state: Dict) -> None:
    torch.set_rng_state(state['rng_state'])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(state['cuda_rng_state'])