Source code for supar.models.dep.crf.model

# -*- coding: utf-8 -*-

import torch
from supar.models.dep.biaffine.model import BiaffineDependencyModel
from supar.structs import DependencyCRF, MatrixTree


class CRFDependencyModel(BiaffineDependencyModel):
    r"""
    First-order CRF dependency parser
    :cite:`zhang-etal-2020-efficient,ma-hovy-2017-neural,koo-etal-2007-structured`.

    Args:
        n_words (int):
            The size of the word vocabulary.
        n_rels (int):
            The number of labels in the treebank.
        n_tags (int):
            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
        n_chars (int):
            The number of characters, required if character-level representations are used. Default: ``None``.
        encoder (str):
            Encoder to use.
            ``'lstm'``: BiLSTM encoder.
            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
            Default: ``'lstm'``.
        feat (List[str]):
            Additional features to use, required if ``encoder='lstm'``.
            ``'tag'``: POS tag embeddings.
            ``'char'``: Character-level representations extracted by CharLSTM.
            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
            Default: [``'char'``].
        n_embed (int):
            The size of word embeddings. Default: 100.
        n_pretrained (int):
            The size of pretrained word embeddings. Default: 100.
        n_feat_embed (int):
            The size of feature representations. Default: 100.
        n_char_embed (int):
            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
        n_char_hidden (int):
            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
        char_pad_index (int):
            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
        elmo (str):
            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
        elmo_bos_eos (Tuple[bool]):
            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
            Default: ``(True, False)``.
        bert (str):
            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
            Default: ``None``.
        n_bert_layers (int):
            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
            The final outputs would be weighted sum of the hidden states of these layers.
            Default: 4.
        mix_dropout (float):
            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
        bert_pooling (str):
            Pooling way to get token embeddings.
            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
            Default: ``mean``.
        bert_pad_index (int):
            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
            Default: 0.
        finetune (bool):
            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
        n_plm_embed (int):
            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
        embed_dropout (float):
            The dropout ratio of input embeddings. Default: .33.
        n_encoder_hidden (int):
            The size of encoder hidden states. Default: 800.
        n_encoder_layers (int):
            The number of encoder layers. Default: 3.
        encoder_dropout (float):
            The dropout ratio of encoder layer. Default: .33.
        n_arc_mlp (int):
            Arc MLP size. Default: 500.
        n_rel_mlp (int):
            Label MLP size. Default: 100.
        mlp_dropout (float):
            The dropout ratio of MLP layers. Default: .33.
        scale (float):
            Scaling factor for affine scores. Default: 0.
        pad_index (int):
            The index of the padding token in the word vocabulary. Default: 0.
        unk_index (int):
            The index of the unknown token in the word vocabulary. Default: 1.
        proj (bool):
            If ``True``, takes :class:`DependencyCRF` as inference layer, :class:`MatrixTree` otherwise.
            Default: ``True``.

    .. _transformers:
        https://github.com/huggingface/transformers
    """

    def loss(self, s_arc, s_rel, arcs, rels, mask, mbr=True, partial=False):
        r"""
        Computes the training loss as the sum of a structured arc loss and a label loss.

        Args:
            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
                Scores of all possible arcs.
            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
                Scores of all possible labels on each arc.
            arcs (~torch.LongTensor): ``[batch_size, seq_len]``.
                The tensor of gold-standard arcs.
            rels (~torch.LongTensor): ``[batch_size, seq_len]``.
                The tensor of gold-standard labels.
            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
                The mask for covering the unpadded tokens.
            mbr (bool):
                If ``True``, returns marginals for MBR decoding. Default: ``True``.
            partial (bool):
                ``True`` denotes the trees are partially annotated. Default: ``False``.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                The training loss and
                original arc scores of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise.
        """

        # the inference layer is chosen by the `proj` option
        if self.args.proj:
            struct_cls = DependencyCRF
        else:
            struct_cls = MatrixTree
        lens = mask.sum(-1)
        dist = struct_cls(s_arc, lens)

        # negated summed log-probability of the gold arcs, divided by the number of unmasked tokens
        log_likelihood = dist.log_prob(arcs, partial=partial).sum()
        arc_loss = -log_likelihood / mask.sum()

        if mbr:
            arc_probs = dist.marginals
        else:
            arc_probs = s_arc

        # un-annotated arcs are marked with -1, so they are left out of the label loss
        if partial:
            rel_mask = mask & arcs.ge(0)
        else:
            rel_mask = mask
        gold_heads = arcs[rel_mask]
        gold_rels = rels[rel_mask]
        # for every kept token, pick the label scores that belong to its gold head
        token_scores = s_rel[rel_mask]
        head_scores = token_scores[torch.arange(len(gold_rels)), gold_heads]
        rel_loss = self.criterion(head_scores, gold_rels)

        return arc_loss + rel_loss, arc_probs