Source code for supar.models.const.tt.model

# -*- coding: utf-8 -*-

from typing import List, Tuple

import torch
import torch.nn as nn
from supar.config import Config
from supar.model import Model
from supar.utils.common import INF


class TetraTaggingConstituencyModel(Model):
    r"""
    The implementation of TetraTagging Constituency Parser :cite:`kitaev-klein-2020-tetra`.

    Args:
        n_words (int):
            The size of the word vocabulary.
        n_tags (int):
            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
        n_chars (int):
            The number of characters, required if character-level representations are used. Default: ``None``.
        encoder (str):
            Encoder to use.
            ``'lstm'``: BiLSTM encoder.
            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
            Default: ``'lstm'``.
        feat (List[str]):
            Additional features to use, required if ``encoder='lstm'``.
            ``'tag'``: POS tag embeddings.
            ``'char'``: Character-level representations extracted by CharLSTM.
            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
            Default: [``'char'``].
        n_embed (int):
            The size of word embeddings. Default: 100.
        n_pretrained (int):
            The size of pretrained word embeddings. Default: 100.
        n_feat_embed (int):
            The size of feature representations. Default: 100.
        n_char_embed (int):
            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
        n_char_hidden (int):
            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
        char_pad_index (int):
            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
        elmo (str):
            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
        elmo_bos_eos (Tuple[bool]):
            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
            Default: ``(True, True)``.
        bert (str):
            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
            Default: ``None``.
        n_bert_layers (int):
            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
            The final outputs would be weighted sum of the hidden states of these layers.
            Default: 4.
        mix_dropout (float):
            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
        bert_pooling (str):
            Pooling way to get token embeddings.
            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
            Default: ``mean``.
        bert_pad_index (int):
            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
            Default: 0.
        finetune (bool):
            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
        n_plm_embed (int):
            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
        embed_dropout (float):
            The dropout ratio of input embeddings. Default: .33.
        n_encoder_hidden (int):
            The size of encoder hidden states. Default: 800.
        n_encoder_layers (int):
            The number of encoder layers. Default: 3.
        encoder_dropout (float):
            The dropout ratio of encoder layers. Default: .33.
        n_gnn_layers (int):
            The number of GNN layers. Default: 3.
        gnn_dropout (float):
            The dropout ratio of GNN layers. Default: .33.
        pad_index (int):
            The index of the padding token in the word vocabulary. Default: 0.
        unk_index (int):
            The index of the unknown token in the word vocabulary. Default: 1.

    .. _transformers:
        https://github.com/huggingface/transformers
    """

    def __init__(self,
                 n_words,
                 n_tags=None,
                 n_chars=None,
                 encoder='lstm',
                 feat=['char'],
                 n_embed=100,
                 n_pretrained=100,
                 n_feat_embed=100,
                 n_char_embed=50,
                 n_char_hidden=100,
                 char_pad_index=0,
                 elmo='original_5b',
                 elmo_bos_eos=(True, True),
                 bert=None,
                 n_bert_layers=4,
                 mix_dropout=.0,
                 bert_pooling='mean',
                 bert_pad_index=0,
                 finetune=False,
                 n_plm_embed=0,
                 embed_dropout=.33,
                 n_encoder_hidden=800,
                 n_encoder_layers=3,
                 encoder_dropout=.33,
                 n_gnn_layers=3,
                 gnn_dropout=.33,
                 pad_index=0,
                 unk_index=1,
                 **kwargs):
        super().__init__(**Config().update(locals()))

        self.proj = nn.Linear(self.args.n_encoder_hidden, self.args.n_leaves + self.args.n_nodes)
        self.criterion = nn.CrossEntropyLoss()
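    # Minimal construction sketch (illustrative; the vocabulary sizes and the
    # ``n_leaves``/``n_nodes`` values below are assumptions, both normally supplied
    # through **kwargs by the tetra-tagging transform).  A single linear projection
    # scores the leaf-action and node-action inventories jointly:
    #
    #   model = TetraTaggingConstituencyModel(n_words=10000, n_chars=500,
    #                                         n_leaves=4, n_nodes=64)
    #   assert model.proj.out_features == model.args.n_leaves + model.args.n_nodes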
    def forward(
        self,
        words: torch.LongTensor,
        feats: List[torch.LongTensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            words (~torch.LongTensor): ``[batch_size, seq_len]``.
                Word indices.
            feats (List[~torch.LongTensor]):
                A list of feat indices.
                The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
                or ``[batch_size, seq_len]`` otherwise.
                Default: ``None``.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                Scores for all leaves (``[batch_size, seq_len, n_leaves]``) and nodes (``[batch_size, seq_len, n_nodes]``).
        """

        s = self.proj(self.encode(words, feats)[:, 1:-1])
        s_leaf, s_node = s[..., :self.args.n_leaves], s[..., self.args.n_leaves:]
        return s_leaf, s_node
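    # Shape sketch for a forward pass (illustrative).  The slice ``[:, 1:-1]`` above
    # drops the first and last encoder positions, presumably the BOS/EOS tokens added
    # by the transform, so both score tensors align with the tokens that receive
    # tagging actions:
    #
    #   s_leaf, s_node = model(words, feats)
    #   # s_leaf: [batch_size, seq_len, n_leaves]   per-token scores of leaf actions
    #   # s_node: [batch_size, seq_len, n_nodes]    per-token scores of node actions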
    def loss(
        self,
        s_leaf: torch.Tensor,
        s_node: torch.Tensor,
        leaves: torch.LongTensor,
        nodes: torch.LongTensor,
        mask: torch.BoolTensor
    ) -> torch.Tensor:
        r"""
        Args:
            s_leaf (~torch.Tensor): ``[batch_size, seq_len, n_leaves]``.
                Leaf scores.
            s_node (~torch.Tensor): ``[batch_size, seq_len, n_nodes]``.
                Non-terminal scores.
            leaves (~torch.LongTensor): ``[batch_size, seq_len]``.
                Actions for leaves.
            nodes (~torch.LongTensor): ``[batch_size, seq_len]``.
                Actions for non-terminals.
            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
                The mask for covering the unpadded tokens in each chart.

        Returns:
            ~torch.Tensor:
                The training loss.
        """

        leaf_mask, node_mask = mask, mask[:, 1:]
        leaf_loss = self.criterion(s_leaf[leaf_mask], leaves[leaf_mask])
        node_loss = self.criterion(s_node[:, :-1][node_mask], nodes[node_mask]) if nodes.shape[1] > 0 else 0
        return leaf_loss + node_loss
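    # Worked masking example (assuming a batch containing one 3-token sentence):
    # a sentence with n leaves has only n - 1 node actions, so node scores and
    # targets are truncated via ``s_node[:, :-1]`` and ``mask[:, 1:]``; a batch whose
    # longest sentence has a single token yields no node loss, which the
    # ``nodes.shape[1] > 0`` guard handles.
    #
    #   mask          # tensor([[True, True, True]])  -> 3 leaf positions
    #   mask[:, 1:]   # tensor([[True, True]])        -> 2 node positions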
    def decode(
        self,
        s_leaf: torch.Tensor,
        s_node: torch.Tensor,
        mask: torch.BoolTensor,
        left_mask: torch.BoolTensor,
        depth: int = 8
    ) -> List[List[Tuple]]:
        r"""
        Args:
            s_leaf (~torch.Tensor): ``[batch_size, seq_len, n_leaves]``.
                Leaf scores.
            s_node (~torch.Tensor): ``[batch_size, seq_len, n_nodes]``.
                Non-terminal scores.
            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
                The mask for covering the unpadded tokens in each chart.
            left_mask (~torch.BoolTensor): ``[n_leaves + n_nodes]``.
                The mask for distinguishing left/rightward actions.
            depth (int):
                Stack depth. Default: 8.

        Returns:
            List[List[Tuple]]:
                Sequences of factorized labeled trees.
        """

        from torch_scatter import scatter_max

        lens = mask.sum(-1)
        batch_size, seq_len, n_leaves = s_leaf.shape
        leaf_left_mask, node_left_mask = left_mask[:n_leaves], left_mask[n_leaves:]
        # [n_leaves], [n_nodes]
        changes = (torch.where(leaf_left_mask, 1, 0), torch.where(node_left_mask, 0, -1))
        # [batch_size, depth]
        depths = lens.new_full((depth,), -2).index_fill_(-1, lens.new_tensor(0), -1).repeat(batch_size, 1)
        # [2, batch_size, depth, seq_len]
        labels, paths = lens.new_zeros(2, batch_size, depth, seq_len), lens.new_zeros(2, batch_size, depth, seq_len)
        # [batch_size, depth]
        s = s_leaf.new_zeros(batch_size, depth)

        def advance(s, s_t, depths, changes):
            batch_size, n_labels = s_t.shape
            # [batch_size, depth * n_labels]
            depths = (depths.unsqueeze(-1) + changes).view(batch_size, -1)
            # [batch_size, depth, n_labels]
            s_t = s.unsqueeze(-1) + s_t.unsqueeze(1)
            # [batch_size, depth * n_labels]
            # fill scores of invalid depths with -INF
            s_t = s_t.view(batch_size, -1).masked_fill_((depths < 0).logical_or_(depths >= depth), -INF)
            # [batch_size, depth]
            # for each depth, we use the `scatter_max` trick to obtain the 1-best label
            s, ls = scatter_max(s_t, depths.clamp(0, depth - 1), -1, s_t.new_full((batch_size, depth), -INF))
            # [batch_size, depth]
            depths = depths.gather(-1, ls.clamp(0, depths.shape[1] - 1)).masked_fill_(s.eq(-INF), -1)
            ll = ls % n_labels
            lp = depths - changes[ll]
            return s, ll, lp, depths

        for t in range(seq_len):
            m = lens.gt(t)
            s[m], labels[0, m, :, t], paths[0, m, :, t], depths[m] = advance(s[m], s_leaf[m, t], depths[m], changes[0])
            if t == seq_len - 1:
                break
            m = lens.gt(t + 1)
            s[m], labels[1, m, :, t], paths[1, m, :, t], depths[m] = advance(s[m], s_node[m, t], depths[m], changes[1])
        lens = lens.tolist()
        labels, paths = labels.movedim((0, 2), (2, 3))[mask].split(lens), paths.movedim((0, 2), (2, 3))[mask].split(lens)

        leaves, nodes = [], []
        for i, length in enumerate(lens):
            leaf_labels, node_labels = labels[i].transpose(0, 1).tolist()
            leaf_paths, node_paths = paths[i].transpose(0, 1).tolist()
            leaf_pred, node_pred, prev = [leaf_labels[-1][0]], [], leaf_paths[-1][0]
            for j in reversed(range(length - 1)):
                node_pred.append(node_labels[j][prev])
                prev = node_paths[j][prev]
                leaf_pred.append(leaf_labels[j][prev])
                prev = leaf_paths[j][prev]
            leaves.append(list(reversed(leaf_pred)))
            nodes.append(list(reversed(node_pred)))
        return leaves, nodes
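# End-to-end usage sketch (argument names mirror the methods above; the concrete
# tensors, and the `left_mask` separating left- from rightward actions, are assumed
# to be produced by the tetra-tagging transform and the data loader):
#
#   s_leaf, s_node = model(words, feats)
#   loss = model.loss(s_leaf, s_node, leaves, nodes, mask)          # training
#   leaves, nodes = model.decode(s_leaf, s_node, mask, left_mask)   # inference
#   # leaves[i]/nodes[i] hold the 1-best action ids for sentence i, tracked with a
#   # bounded stack depth (default 8); the transform converts them back into a
#   # labeled constituency tree.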