Source code for supar.models.const.vi.model

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from supar.config import Config
from supar.models.const.crf.model import CRFConstituencyModel
from supar.modules import MLP, Biaffine, Triaffine
from supar.structs import ConstituencyCRF, ConstituencyLBP, ConstituencyMFVI


class VIConstituencyModel(CRFConstituencyModel):
    r"""
    The implementation of Constituency Parser using variational inference.

    Args:
        n_words (int):
            The size of the word vocabulary.
        n_labels (int):
            The number of labels in the treebank.
        n_tags (int):
            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
        n_chars (int):
            The number of characters, required if character-level representations are used. Default: ``None``.
        encoder (str):
            Encoder to use.
            ``'lstm'``: BiLSTM encoder.
            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
            Default: ``'lstm'``.
        feat (List[str]):
            Additional features to use, required if ``encoder='lstm'``.
            ``'tag'``: POS tag embeddings.
            ``'char'``: Character-level representations extracted by CharLSTM.
            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
            Default: [``'char'``].
        n_embed (int):
            The size of word embeddings. Default: 100.
        n_pretrained (int):
            The size of pretrained word embeddings. Default: 100.
        n_feat_embed (int):
            The size of feature representations. Default: 100.
        n_char_embed (int):
            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
        n_char_hidden (int):
            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
        char_pad_index (int):
            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
        elmo (str):
            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
        elmo_bos_eos (Tuple[bool]):
            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
            Default: ``(True, True)``.
        bert (str):
            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
            Default: ``None``.
        n_bert_layers (int):
            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
            The final outputs would be weighted sum of the hidden states of these layers.
            Default: 4.
        mix_dropout (float):
            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
        bert_pooling (str):
            Pooling way to get token embeddings.
            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
            Default: ``mean``.
        bert_pad_index (int):
            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
            Default: 0.
        finetune (bool):
            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
        n_plm_embed (int):
            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
        embed_dropout (float):
            The dropout ratio of input embeddings. Default: .33.
        n_encoder_hidden (int):
            The size of encoder hidden states. Default: 800.
        n_encoder_layers (int):
            The number of encoder layers. Default: 3.
        encoder_dropout (float):
            The dropout ratio of encoder layer. Default: .33.
        n_span_mlp (int):
            Span MLP size. Default: 500.
        n_pair_mlp (int):
            Binary factor MLP size. Default: 100.
        n_label_mlp (int):
            Label MLP size. Default: 100.
        mlp_dropout (float):
            The dropout ratio of MLP layers. Default: .33.
        inference (str):
            Approximate inference method: ``'mfvi'`` (:class:`ConstituencyMFVI`) or ``'lbp'`` (:class:`ConstituencyLBP`).
            Any other value raises :class:`ValueError`. Default: ``'mfvi'``.
        max_iter (int):
            Max iteration times for inference. Default: 3.
        interpolation (float):
            Constant to even out the label/edge loss: the label loss is weighted by ``interpolation``
            and the span loss by ``1 - interpolation``. Default: .1.
        pad_index (int):
            The index of the padding token in the word vocabulary. Default: 0.
        unk_index (int):
            The index of the unknown token in the word vocabulary. Default: 1.

    .. _transformers:
        https://github.com/huggingface/transformers
    """

    def __init__(self,
                 n_words,
                 n_labels,
                 n_tags=None,
                 n_chars=None,
                 encoder='lstm',
                 feat=['char'],
                 n_embed=100,
                 n_pretrained=100,
                 n_feat_embed=100,
                 n_char_embed=50,
                 n_char_hidden=100,
                 char_pad_index=0,
                 elmo='original_5b',
                 elmo_bos_eos=(True, True),
                 bert=None,
                 n_bert_layers=4,
                 mix_dropout=.0,
                 bert_pooling='mean',
                 bert_pad_index=0,
                 finetune=False,
                 n_plm_embed=0,
                 embed_dropout=.33,
                 n_encoder_hidden=800,
                 n_encoder_layers=3,
                 encoder_dropout=.33,
                 n_span_mlp=500,
                 n_pair_mlp=100,
                 n_label_mlp=100,
                 mlp_dropout=.33,
                 inference='mfvi',
                 max_iter=3,
                 interpolation=0.1,
                 pad_index=0,
                 unk_index=1,
                 **kwargs):
        # Fail fast on an unknown inference method instead of silently falling back to LBP.
        # Deliberately binds no new local names: ``locals()`` is forwarded into the config below.
        if inference not in ('mfvi', 'lbp'):
            raise ValueError(f"Unsupported inference method: {inference!r}, expected 'mfvi' or 'lbp'")
        super().__init__(**Config().update(locals()))

        # Separate left/right projections for first-order span scores.
        self.span_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_span_mlp, dropout=mlp_dropout)
        self.span_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_span_mlp, dropout=mlp_dropout)
        # Three projections feeding the triaffine second-order (pair) scorer.
        self.pair_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=mlp_dropout)
        self.pair_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=mlp_dropout)
        self.pair_mlp_b = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=mlp_dropout)
        # Left/right projections for per-span label scores.
        self.label_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=mlp_dropout)
        self.label_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=mlp_dropout)

        self.span_attn = Biaffine(n_in=n_span_mlp, bias_x=True, bias_y=False)
        self.pair_attn = Triaffine(n_in=n_pair_mlp, bias_x=True, bias_y=False)
        self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True)
        self.inference = (ConstituencyMFVI if inference == 'mfvi' else ConstituencyLBP)(max_iter)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, words, feats):
        r"""
        Args:
            words (~torch.LongTensor): ``[batch_size, seq_len]``.
                Word indices.
            feats (List[~torch.LongTensor]):
                A list of feat indices.
                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
                or ``[batch_size, seq_len]`` otherwise.

        Returns:
            ~torch.Tensor, ~torch.Tensor, ~torch.Tensor:
                Scores of all possible constituents (``[batch_size, seq_len, seq_len]``),
                second-order triples (``[batch_size, seq_len, seq_len, seq_len]``) and
                all possible labels on each constituent (``[batch_size, seq_len, seq_len, n_labels]``).
        """

        x = self.encode(words, feats)

        # Build boundary (fencepost) representations: position i pairs the first half of
        # token i's state with the second half of token i+1's state (presumably the
        # forward/backward encoder directions, given the names).
        x_f, x_b = x.chunk(2, -1)
        x = torch.cat((x_f[:, :-1], x_b[:, 1:]), -1)

        span_l = self.span_mlp_l(x)
        span_r = self.span_mlp_r(x)
        pair_l = self.pair_mlp_l(x)
        pair_r = self.pair_mlp_r(x)
        pair_b = self.pair_mlp_b(x)
        label_l = self.label_mlp_l(x)
        label_r = self.label_mlp_r(x)

        # [batch_size, seq_len, seq_len]
        s_span = self.span_attn(span_l, span_r)
        # [batch_size, seq_len, seq_len, seq_len]; axes reordered to the layout the inference module consumes
        s_pair = self.pair_attn(pair_l, pair_r, pair_b).permute(0, 3, 1, 2)
        # [batch_size, seq_len, seq_len, n_labels]
        s_label = self.label_attn(label_l, label_r).permute(0, 2, 3, 1)

        return s_span, s_pair, s_label

    def loss(self, s_span, s_pair, s_label, charts, mask):
        r"""
        Args:
            s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
                Scores of all constituents.
            s_pair (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
                Scores of second-order triples.
            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
                Scores of all constituent labels.
            charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
                The tensor of gold-standard labels. Positions without labels are filled with -1.
            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
                The mask for covering the unpadded tokens in each chart.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                The training loss (``interpolation * label_loss + (1 - interpolation) * span_loss``)
                and marginals of shape ``[batch_size, seq_len, seq_len]``.
        """

        # Gold constituents: labelled cells (charts >= 0) that also lie inside the valid chart area.
        span_mask = charts.ge(0) & mask
        span_loss, span_probs = self.inference((s_span, s_pair), mask, span_mask)
        # Label loss is only computed over gold constituents.
        label_loss = self.criterion(s_label[span_mask], charts[span_mask])
        loss = self.args.interpolation * label_loss + (1 - self.args.interpolation) * span_loss

        return loss, span_probs

    def decode(self, s_span, s_label, mask):
        r"""
        Args:
            s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
                Scores of all constituents.
            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
                Scores of all constituent labels.
            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
                The mask for covering the unpadded tokens in each chart.

        Returns:
            List[List[Tuple]]:
                Sequences of factorized labeled trees.
        """

        # Lengths come from the first row of the chart mask; the best unlabeled tree comes from
        # the CRF argmax, and each predicted span is then tagged with its highest-scoring label.
        span_preds = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).argmax
        label_preds = s_label.argmax(-1).tolist()
        return [[(i, j, labels[i][j]) for i, j in spans] for spans, labels in zip(span_preds, label_preds)]