# -*- coding: utf-8 -*-
from __future__ import annotations

import os
import pickle
import struct
from io import BytesIO
from typing import TYPE_CHECKING, Any, Iterable, Optional

import torch
from torch.distributions.utils import lazy_property

from supar.utils.logging import get_logger, progress_bar

if TYPE_CHECKING:
    # `Transform` is referenced only in type annotations; it is assumed to live in supar.utils.transform
    from supar.utils.transform import Transform

logger = get_logger(__name__)

class Batch(object):
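    """
    Batches together a group of :class:`Sentence` objects.

    The batch only keeps references to the sentences until :meth:`compose` is called,
    which builds one batched input per field and registers it under the field's name.
    """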
    def __init__(self, sentences: Iterable[Sentence]) -> None:
self.sentences = sentences
self.names, self.fields = [], {}
def __repr__(self):
return f'{self.__class__.__name__}({", ".join([f"{name}" for name in self.names])})'
def __len__(self):
return len(self.sentences)
def __getitem__(self, index):
return self.fields[self.names[index]]
def __getattr__(self, name):
return [s.fields[name] for s in self.sentences]
def __setattr__(self, name: str, value: Iterable[Any]):
if name not in ('sentences', 'fields', 'names'):
for s, v in zip(self.sentences, value):
setattr(s, name, v)
else:
self.__dict__[name] = value
def __getstate__(self):
return self.__dict__
def __setstate__(self, state):
self.__dict__.update(state)
@property
def device(self):
return 'cuda' if torch.cuda.is_available() else 'cpu'
@lazy_property
def lens(self):
return torch.tensor([len(i) for i in self.sentences]).to(self.device, non_blocking=True)
@lazy_property
def mask(self):
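        # mask[i, j] = lens[i] > j, i.e. True for real tokens and False for padding positions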
return self.lens.unsqueeze(-1).gt(self.lens.new_tensor(range(self.lens.max())))
def compose(self, transform: Transform) -> Batch:
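        """
        Collects the numericalized values of each flattened field of `transform` across all
        sentences, batches them with the field's own `compose` method, and stores the result
        in `self.fields` under the field's name.
        """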
for f in transform.flattened_fields:
self.names.append(f.name)
self.fields[f.name] = f.compose([s.fields[f.name] for s in self.sentences])
return self
def shrink(self, batch_size: Optional[int] = None) -> Batch:
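        """
        Returns a new batch containing a random subset of the sentences, half of the current
        batch by default, e.g. for retrying a step after running out of memory.
        """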
if batch_size is None:
batch_size = len(self) // 2
if batch_size <= 0:
            raise RuntimeError(f"The batch has only {len(self)} sentences and can't be shrunk!")
return Batch([self.sentences[i] for i in torch.randperm(len(self))[:batch_size].tolist()])
    def pin_memory(self):
        for s in self.sentences:
            for name, value in s.fields.items():
                if isinstance(value, torch.Tensor):
                    # Tensor.pin_memory() returns a pinned copy rather than pinning in place,
                    # so the result has to be written back into the field dict
                    s.fields[name] = value.pin_memory()
        return self


class Sentence(object):
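    """
    A single sentence produced by a :class:`Transform`.

    `values` holds the raw annotations (one entry per field position of the transform),
    `fields` their numericalized counterparts keyed by field name, and `maps` records
    which entry of `values` each named field is read from.
    """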
    def __init__(self, transform: Transform, index: Optional[int] = None) -> None:
self.index = index
# mapping from each nested field to their proper position
self.maps = dict()
# original values and numericalized values of each position
self.values, self.fields = [], {}
for i, field in enumerate(transform):
if not isinstance(field, Iterable):
field = [field]
for f in field:
if f is not None:
self.maps[f.name] = i
self.fields[f.name] = None
def __contains__(self, name):
return name in self.fields
    def __getattr__(self, name):
        # look `fields` up via __dict__ so that a missing attribute (e.g. mid-unpickling) cannot recurse
        if name in self.__dict__.get('fields', ()):
            return self.values[self.maps[name]]
        raise AttributeError(f"`{name}` not found")
def __setattr__(self, name, value):
if 'fields' in self.__dict__ and name in self:
index = self.maps[name]
if index >= len(self.values):
self.__dict__[name] = value
else:
self.values[index] = value
else:
self.__dict__[name] = value
def __getstate__(self):
        # work on a copy so that pickling does not rewrite the tensors of the live sentence
        state = dict(self.__dict__)
if 'fields' in state:
state['fields'] = {
name: ((value.dtype, value.tolist())
if isinstance(value, torch.Tensor)
else value)
for name, value in state['fields'].items()
}
return state
def __setstate__(self, state):
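        # tensors were serialized as (dtype, list) pairs by __getstate__; rebuild them here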
if 'fields' in state:
state['fields'] = {
name: (torch.tensor(value[1], dtype=value[0])
if isinstance(value, tuple) and isinstance(value[0], torch.dtype)
else value)
for name, value in state['fields'].items()
}
self.__dict__.update(state)
def __len__(self):
try:
return len(next(iter(self.fields.values())))
except Exception:
            raise AttributeError("Cannot get the length of a sentence with no numericalized fields")
@lazy_property
def size(self):
return len(self)
def numericalize(self, fields):
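        """
        Numericalizes the raw values of the given fields with each field's `transform` and
        stores the results in `self.fields`; the padding index of the first field is kept
        as `self.pad_index`.
        """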
for f in fields:
self.fields[f.name] = next(f.transform([getattr(self, f.name)]))
self.pad_index = fields[0].pad_index
return self
def tobytes(self) -> bytes:
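        """
        Serializes the sentence into bytes laid out as
        ``[raw tensor buffers][pickled sentence][struct 'LL' footer]``, where the footer holds
        the total buffer length and the pickle length. Tensor fields are temporarily swapped
        for ``(num_bytes, dtype)`` placeholders while pickling and restored afterwards.
        """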
bufs, fields = [], {}
for name, value in self.fields.items():
if isinstance(value, torch.Tensor):
fields[name] = value
buf, dtype = value.numpy().tobytes(), value.dtype
self.fields[name] = (len(buf), dtype)
bufs.append(buf)
buf, sentence = b''.join(bufs), pickle.dumps(self)
for name, value in fields.items():
self.fields[name] = value
return buf + sentence + struct.pack('LL', len(buf), len(sentence))
@classmethod
def frombuffer(cls, buf: bytes) -> Sentence:
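        """
        Rebuilds a sentence serialized by :meth:`tobytes`: the footer locates the pickled
        sentence, and each ``(num_bytes, dtype)`` placeholder is then replaced by a tensor
        read back from the raw buffers (shapes are not stored, so tensors come back flat).

        A minimal round-trip sketch, assuming ``sent`` is an already numericalized sentence::

            buf = sent.tobytes()
            restored = Sentence.frombuffer(buf)
        """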
mm = BytesIO(buf)
mm.seek(-len(struct.pack('LL', 0, 0)), os.SEEK_END)
offset, length = struct.unpack('LL', mm.read())
mm.seek(offset)
sentence = pickle.loads(mm.read(length))
mm.seek(0)
for name, value in sentence.fields.items():
if isinstance(value, tuple) and isinstance(value[1], torch.dtype):
length, dtype = value
sentence.fields[name] = torch.frombuffer(bytearray(mm.read(length)), dtype=dtype)
return sentence
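# A minimal usage sketch of how the two classes fit together. The names below (`transform`,
# `sents`, the 'words' field) are illustrative assumptions, not part of this module; the
# Transform/Field machinery that loads raw sentences lives elsewhere in supar.
#
#     sents = [s.numericalize(transform.flattened_fields) for s in sents]
#     batch = Batch(sents).compose(transform)   # one batched input per flattened field
#     lens, mask = batch.lens, batch.mask       # sentence lengths and the padding mask
#     words = batch.fields['words']             # assuming the transform defines a 'words' field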