Source code for supar.models.dep.vi.parser

# -*- coding: utf-8 -*-

from typing import Iterable, Union

import torch
from supar.config import Config
from supar.models.dep.biaffine.parser import BiaffineDependencyParser
from supar.models.dep.vi.model import VIDependencyModel
from supar.utils.fn import ispunct
from supar.utils.logging import get_logger
from supar.utils.metric import AttachmentMetric
from supar.utils.transform import Batch

logger = get_logger(__name__)


class VIDependencyParser(BiaffineDependencyParser):
    r"""
    The implementation of Dependency Parser using Variational Inference :cite:`wang-tu-2020-second`.
    """

    NAME = 'vi-dependency'
    MODEL = VIDependencyModel

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
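
    # A minimal loading sketch: `Parser.load` (inherited from the base parser) accepts a local
    # checkpoint path or a registered model name; the path below is a placeholder, not a
    # released checkpoint.
    # >>> from supar.models.dep.vi.parser import VIDependencyParser
    # >>> parser = VIDependencyParser.load('path/to/vi-dep.model')  # hypothetical checkpoint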

    def train(
        self,
        train: Union[str, Iterable],
        dev: Union[str, Iterable],
        test: Union[str, Iterable],
        epochs: int = 1000,
        patience: int = 100,
        batch_size: int = 5000,
        update_steps: int = 1,
        buckets: int = 32,
        workers: int = 0,
        amp: bool = False,
        cache: bool = False,
        punct: bool = False,
        tree: bool = False,
        proj: bool = False,
        partial: bool = False,
        verbose: bool = True,
        **kwargs
    ):
        return super().train(**Config().update(locals()))
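
    # Training sketch, assuming CoNLL-style files on disk (the file names are placeholders);
    # the keyword arguments simply echo the defaults exposed by this method.
    # >>> parser.train(train='train.conllx',
    # ...              dev='dev.conllx',
    # ...              test='test.conllx',
    # ...              epochs=1000, patience=100, batch_size=5000)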

    def evaluate(
        self,
        data: Union[str, Iterable],
        batch_size: int = 5000,
        buckets: int = 8,
        workers: int = 0,
        amp: bool = False,
        cache: bool = False,
        punct: bool = False,
        tree: bool = True,
        proj: bool = True,
        partial: bool = False,
        verbose: bool = True,
        **kwargs
    ):
        return super().evaluate(**Config().update(locals()))
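
    # Evaluation sketch (placeholder path); with `punct=False`, punctuation is excluded from
    # the attachment metric, matching the masking done in `eval_step` below.
    # >>> parser.evaluate('test.conllx', punct=False, tree=True, proj=True)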

    def predict(
        self,
        data: Union[str, Iterable],
        pred: str = None,
        lang: str = None,
        prob: bool = False,
        batch_size: int = 5000,
        buckets: int = 8,
        workers: int = 0,
        amp: bool = False,
        cache: bool = False,
        tree: bool = True,
        proj: bool = True,
        verbose: bool = True,
        **kwargs
    ):
        return super().predict(**Config().update(locals()))
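
    # Prediction sketch (placeholder paths); the returned dataset carries the predicted
    # arcs and relation labels per sentence, plus per-token head probabilities when `prob=True`.
    # >>> dataset = parser.predict('input.conllx', pred='output.conllx', prob=True)
    # >>> print(dataset[0])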

    def train_step(self, batch: Batch) -> torch.Tensor:
        words, _, *feats, arcs, rels = batch
        mask = batch.mask
        # ignore the first token of each sentence
        mask[:, 0] = 0
        s_arc, s_sib, s_rel = self.model(words, feats)
        loss, *_ = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask)
        return loss

    @torch.no_grad()
    def eval_step(self, batch: Batch) -> AttachmentMetric:
        words, _, *feats, arcs, rels = batch
        mask = batch.mask
        # ignore the first token of each sentence
        mask[:, 0] = 0
        s_arc, s_sib, s_rel = self.model(words, feats)
        loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask)
        arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
        if self.args.partial:
            mask &= arcs.ge(0)
        # ignore all punctuation if not specified
        if not self.args.punct:
            mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
        return AttachmentMetric(loss, (arc_preds, rel_preds), (arcs, rels), mask)

    @torch.no_grad()
    def pred_step(self, batch: Batch) -> Batch:
        words, _, *feats = batch
        mask, lens = batch.mask, (batch.lens - 1).tolist()
        # ignore the first token of each sentence
        mask[:, 0] = 0
        s_arc, s_sib, s_rel = self.model(words, feats)
        s_arc = self.model.inference((s_arc, s_sib), mask)
        arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
        batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)]
        batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
        if self.args.prob:
            batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.unbind())]
        return batch
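

# The call `self.model.inference((s_arc, s_sib), mask)` in `pred_step` turns first-order arc
# scores and second-order sibling scores into arc marginals via mean-field variational
# inference (Wang & Tu, 2020). The function below is a rough, self-contained sketch of those
# updates for a single unbatched sentence, not the model's actual routine: it assumes
# `s_arc[d, h]` scores head `h` for dependent `d`, `s_sib[d, h, s]` scores dependents `d` and
# `s` sharing head `h`, and it ignores padding/root masking for brevity.
def _mfvi_arc_marginals_sketch(s_arc: torch.Tensor, s_sib: torch.Tensor, n_iter: int = 3) -> torch.Tensor:
    # q[d, h]: current posterior over candidate heads h for each dependent d
    q = s_arc.softmax(-1)
    for _ in range(n_iter):
        # sibling message: expected second-order score under the current posteriors
        m = torch.einsum('sh,dhs->dh', q, s_sib)
        # add the message to the unary scores and renormalise over candidate heads
        q = (s_arc + m).softmax(-1)
    return q

# Toy check with random scores for a 6-position sentence (index 0 acting as the root):
# >>> q = _mfvi_arc_marginals_sketch(torch.randn(6, 6), torch.randn(6, 6, 6))
# >>> q.sum(-1)  # each row is a distribution over heads and sums to 1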