# -*- coding: utf-8 -*-
from typing import List, Tuple
import torch
import torch.nn as nn
from supar.config import Config
from supar.model import Model
from supar.models.const.aj.transform import AttachJuxtaposeTree
from supar.modules import GraphConvolutionalNetwork
from supar.utils.common import INF
from supar.utils.fn import pad
class AttachJuxtaposeConstituencyModel(Model):
r"""
The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`.
Args:
n_words (int):
The size of the word vocabulary.
n_labels (int):
The number of labels in the treebank.
n_tags (int):
The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
n_chars (int):
The number of characters, required if character-level representations are used. Default: ``None``.
encoder (str):
Encoder to use.
``'lstm'``: BiLSTM encoder.
``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
Default: ``'lstm'``.
feat (List[str]):
Additional features to use, required if ``encoder='lstm'``.
``'tag'``: POS tag embeddings.
``'char'``: Character-level representations extracted by CharLSTM.
            ``'bert'``: BERT representations; other pretrained language models like RoBERTa are also feasible.
Default: [``'char'``].
n_embed (int):
The size of word embeddings. Default: 100.
n_pretrained (int):
The size of pretrained word embeddings. Default: 100.
n_feat_embed (int):
The size of feature representations. Default: 100.
n_char_embed (int):
The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
n_char_hidden (int):
The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
char_pad_index (int):
The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
elmo (str):
Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
elmo_bos_eos (Tuple[bool]):
            A tuple of two boolean values indicating whether to keep start/end boundaries of ELMo outputs.
            Default: ``(True, True)``.
bert (str):
Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
Default: ``None``.
n_bert_layers (int):
Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
            The final outputs would be a weighted sum of the hidden states of these layers.
Default: 4.
mix_dropout (float):
The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
bert_pooling (str):
            Pooling method for deriving token embeddings from subtokens.
``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
Default: ``mean``.
bert_pad_index (int):
The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
Default: 0.
finetune (bool):
            If ``False``, freezes all parameters of the pretrained model, required if using pretrained layers. Default: ``False``.
n_plm_embed (int):
The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
embed_dropout (float):
The dropout ratio of input embeddings. Default: .33.
n_encoder_hidden (int):
The size of encoder hidden states. Default: 800.
n_encoder_layers (int):
The number of encoder layers. Default: 3.
encoder_dropout (float):
The dropout ratio of encoder layers. Default: .33.
n_gnn_layers (int):
The number of GNN layers. Default: 3.
gnn_dropout (float):
The dropout ratio of GNN layers. Default: .33.
pad_index (int):
The index of the padding token in the word vocabulary. Default: 0.
unk_index (int):
The index of the unknown token in the word vocabulary. Default: 1.
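
    Examples:
        A construction sketch for illustration only; the vocabulary sizes below are made up and
        ``feat`` is left empty so that no character or tag vocabularies are required.

        >>> model = AttachJuxtaposeConstituencyModel(n_words=10000, n_labels=30, feat=[])
        >>> model.label_embed.num_embeddings  # n_labels plus one dummy node
        31
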
.. _transformers:
https://github.com/huggingface/transformers
"""
def __init__(self,
n_words,
n_labels,
n_tags=None,
n_chars=None,
encoder='lstm',
feat=['char'],
n_embed=100,
n_pretrained=100,
n_feat_embed=100,
n_char_embed=50,
n_char_hidden=100,
char_pad_index=0,
elmo='original_5b',
elmo_bos_eos=(True, True),
bert=None,
n_bert_layers=4,
mix_dropout=.0,
bert_pooling='mean',
bert_pad_index=0,
finetune=False,
n_plm_embed=0,
embed_dropout=.33,
n_encoder_hidden=800,
n_encoder_layers=3,
encoder_dropout=.33,
n_gnn_layers=3,
gnn_dropout=.33,
pad_index=0,
unk_index=1,
**kwargs):
super().__init__(**Config().update(locals()))
# the last one represents the dummy node in the initial states
self.label_embed = nn.Embedding(n_labels+1, self.args.n_encoder_hidden)
self.gnn_layers = GraphConvolutionalNetwork(n_model=self.args.n_encoder_hidden,
n_layers=self.args.n_gnn_layers,
dropout=self.args.gnn_dropout)
self.node_classifier = nn.Sequential(
nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2),
nn.LayerNorm(self.args.n_encoder_hidden // 2),
nn.ReLU(),
nn.Linear(self.args.n_encoder_hidden // 2, 1),
)
self.label_classifier = nn.Sequential(
nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2),
nn.LayerNorm(self.args.n_encoder_hidden // 2),
nn.ReLU(),
nn.Linear(self.args.n_encoder_hidden // 2, 2 * n_labels),
)
self.criterion = nn.CrossEntropyLoss()
    def forward(
self,
words: torch.LongTensor,
feats: List[torch.LongTensor] = None
) -> torch.Tensor:
r"""
Args:
words (~torch.LongTensor): ``[batch_size, seq_len]``.
Word indices.
feats (List[~torch.LongTensor]):
A list of feat indices.
The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``,
or ``[batch_size, seq_len]`` otherwise.
Default: ``None``.
Returns:
~torch.Tensor:
Contextualized output hidden states of shape ``[batch_size, seq_len, n_model]`` of the input.
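
        Example:
            A minimal usage sketch for illustration only; ``model`` is assumed to be an already
            instantiated model whose configuration requires no extra ``feats``, and ``words``
            holds pre-indexed word ids.

            >>> words = torch.randint(2, 100, (2, 10))  # [batch_size=2, seq_len=10]
            >>> x = model(words)                        # [2, 10, n_model]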
"""
return self.encode(words, feats)
    def loss(
self,
x: torch.Tensor,
nodes: torch.LongTensor,
parents: torch.LongTensor,
news: torch.LongTensor,
mask: torch.BoolTensor
) -> torch.Tensor:
r"""
Args:
x (~torch.Tensor): ``[batch_size, seq_len, n_model]``.
Contextualized output hidden states.
nodes (~torch.LongTensor): ``[batch_size, seq_len]``.
The target node positions on rightmost chains.
parents (~torch.LongTensor): ``[batch_size, seq_len]``.
The parent node labels of terminals.
news (~torch.LongTensor): ``[batch_size, seq_len]``.
The parent node labels of juxtaposed targets and terminals.
mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
The mask for covering the unpadded tokens in each chart.
Returns:
~torch.Tensor:
The training loss.
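
        Example:
            A sketch of the training call; ``x`` comes from :meth:`forward`, and the oracle tensors
            ``nodes``, ``parents`` and ``news`` are assumed to be produced from gold trees by
            :class:`AttachJuxtaposeTree`.

            >>> x = model(words, feats)
            >>> loss = model.loss(x, nodes, parents, news, mask)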
"""
spans, s_node, x_node = None, [], []
actions = torch.stack((nodes, parents, news))
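        # teacher forcing: at each step t, score the nodes on the rightmost chain of the
        # partial trees built from the gold actions of the previous steps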
for t, action in enumerate(actions.unbind(-1)):
if t == 0:
x_span = self.label_embed(actions.new_full((x.shape[0], 1), self.args.n_labels))
span_mask = mask[:, :1]
else:
x_span = self.rightmost_chain(x, spans, mask, t)
span_lens = spans[:, :-1, -1].ge(0).sum(-1)
span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
x_rightmost = torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)
s_node.append(self.node_classifier(x_rightmost).squeeze(-1))
            # we found that softmax works slightly better here than the sigmoid used in the original paper
s_node[-1] = s_node[-1].masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0)
x_node.append(torch.bmm(s_node[-1].softmax(-1).unsqueeze(1), x_span).squeeze(1))
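            # advance the partial trees by applying the gold action of step t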
spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t])
attach_mask = x.new_tensor(range(self.args.n_labels)).eq(self.args.nul_index)
s_node, x_node = pad(s_node, -INF).transpose(0, 1), torch.stack(x_node, 1)
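        # the label classifier jointly scores parent labels and new labels;
        # its output is split into two halves of size n_labels each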
s_parent, s_new = self.label_classifier(torch.cat((x, x_node), -1)).chunk(2, -1)
s_parent = torch.cat((s_parent[:, :1].masked_fill(attach_mask, -INF), s_parent[:, 1:]), 1)
s_new = torch.cat((s_new[:, :1].masked_fill(~attach_mask, -INF), s_new[:, 1:]), 1)
node_loss = self.criterion(s_node[mask], nodes[mask])
label_loss = self.criterion(s_parent[mask], parents[mask]) + self.criterion(s_new[mask], news[mask])
return node_loss + label_loss
    def decode(
self,
x: torch.Tensor,
mask: torch.BoolTensor,
beam_size: int = 1
) -> List[List[Tuple]]:
r"""
Args:
x (~torch.Tensor): ``[batch_size, seq_len, n_model]``.
Contextualized output hidden states.
mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
The mask for covering the unpadded tokens in each chart.
beam_size (int):
Beam size for decoding. Default: 1.
Returns:
List[List[Tuple]]:
Sequences of factorized labeled trees.
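
        Example:
            A sketch of the decoding call; each item of a returned chart is a ``[i, j, label]``
            triple denoting a labeled span of the predicted tree.

            >>> x = model(words, feats)
            >>> chart_preds = model.decode(x, mask, beam_size=4)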
"""
spans = None
batch_size, *_ = x.shape
n_labels = self.args.n_labels
# [batch_size * beam_size, ...]
x = x.unsqueeze(1).repeat(1, beam_size, 1, 1).view(-1, *x.shape[1:])
mask = mask.unsqueeze(1).repeat(1, beam_size, 1).view(-1, *mask.shape[1:])
# [batch_size]
batches = x.new_tensor(range(batch_size)).long() * beam_size
# accumulated scores
scores = x.new_full((batch_size, beam_size), -INF).index_fill_(-1, x.new_tensor(0).long(), 0).view(-1)
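        # beam search: at each step, every surviving hypothesis is expanded with its top-scoring
        # (node, parent label, new label) actions and the best beam_size hypotheses are kept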
for t in range(x.shape[1]):
if t == 0:
x_span = self.label_embed(batches.new_full((x.shape[0], 1), n_labels))
span_mask = mask[:, :1]
else:
x_span = self.rightmost_chain(x, spans, mask, t)
span_lens = spans[:, :-1, -1].ge(0).sum(-1)
span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
s_node = self.node_classifier(torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)).squeeze(-1)
s_node = s_node.masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0).log_softmax(-1)
            # we found that softmax works slightly better here than the sigmoid used in the original paper
x_node = torch.bmm(s_node.exp().unsqueeze(1), x_span).squeeze(1)
s_parent, s_new = self.label_classifier(torch.cat((x[:, t], x_node), -1)).chunk(2, -1)
s_parent, s_new = s_parent.log_softmax(-1), s_new.log_softmax(-1)
if t == 0:
s_parent[:, self.args.nul_index] = -INF
s_new[:, s_new.new_tensor(range(self.args.n_labels)).ne(self.args.nul_index)] = -INF
s_node, nodes = s_node.topk(min(s_node.shape[-1], beam_size), -1)
s_parent, parents = s_parent.topk(min(n_labels, beam_size), -1)
s_new, news = s_new.topk(min(n_labels, beam_size), -1)
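            # combine the pruned node, parent-label and new-label scores into joint action scores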
s_action = s_node.unsqueeze(2) + (s_parent.unsqueeze(2) + s_new.unsqueeze(1)).view(x.shape[0], 1, -1)
s_action = s_action.view(x.shape[0], -1)
k_beam, k_node, k_parent = s_action.shape[-1], parents.shape[-1] * news.shape[-1], news.shape[-1]
# [batch_size * beam_size, k_beam]
scores = scores.unsqueeze(-1) + s_action
# [batch_size, beam_size]
scores, cands = scores.view(batch_size, -1).topk(beam_size, -1)
# [batch_size * beam_size]
scores = scores.view(-1)
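            # decompose each flat candidate index into its source beam and its node/parent/new-label parts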
beams = cands.div(k_beam, rounding_mode='floor')
nodes = nodes.view(batch_size, -1).gather(-1, cands.div(k_node, rounding_mode='floor'))
indices = (batches.unsqueeze(-1) + beams).view(-1)
parents = parents[indices].view(batch_size, -1).gather(-1, cands.div(k_parent, rounding_mode='floor') % k_parent)
news = news[indices].view(batch_size, -1).gather(-1, cands % k_parent)
action = torch.stack((nodes, parents, news)).view(3, -1)
spans = spans[indices] if spans is not None else None
spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t])
mask = mask.view(batch_size, beam_size, -1)[:, 0]
        # select the 1-best tree for each sentence
spans = spans[batches + scores.view(batch_size, -1).argmax(-1)]
span_mask = spans.ge(0)
span_indices = torch.where(span_mask)
span_labels = spans[span_indices]
        chart_preds = [[] for _ in range(batch_size)]
for i, *span in zip(*[s.tolist() for s in span_indices], span_labels.tolist()):
chart_preds[i].append(span)
return chart_preds
def rightmost_chain(
self,
x: torch.Tensor,
spans: torch.LongTensor,
mask: torch.BoolTensor,
t: int
) -> torch.Tensor:
x_p, mask_p = x[:, :t], mask[:, :t]
lens = mask_p.sum(-1)
span_mask = spans[:, :-1, 1:].ge(0)
span_lens = span_mask.sum((-1, -2))
span_indices = torch.where(span_mask)
span_labels = spans[:, :-1, 1:][span_indices]
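        # represent each span node by its label embedding plus the hidden states of its boundary tokens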
x_span = self.label_embed(span_labels)
x_span += x[span_indices[0], span_indices[1]] + x[span_indices[0], span_indices[2]]
node_lens = lens + span_lens
adj_mask = node_lens.unsqueeze(-1).gt(x.new_tensor(range(node_lens.max())))
x_mask = lens.unsqueeze(-1).gt(x.new_tensor(range(adj_mask.shape[-1])))
span_mask = ~x_mask & adj_mask
# concatenate terminals and spans
x_tree = x.new_zeros(*adj_mask.shape, x.shape[-1]).masked_scatter_(x_mask.unsqueeze(-1), x_p[mask_p])
x_tree = x_tree.masked_scatter_(span_mask.unsqueeze(-1), x_span)
adj = mask.new_zeros(*x_tree.shape[:-1], x_tree.shape[1])
adj_spans = lens.new_tensor(range(x_tree.shape[1])).view(1, 1, -1).repeat(2, x.shape[0], 1)
adj_spans = adj_spans.masked_scatter_(span_mask.unsqueeze(0), torch.stack(span_indices[1:]))
adj_l, adj_r, adj_w = *adj_spans.unbind(), adj_spans[1] - adj_spans[0]
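        # entry (i, j) is True if span j covers span i, i.e., j is a candidate ancestor of i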
adj_parent = adj_l.unsqueeze(-1).ge(adj_l.unsqueeze(-2)) & adj_r.unsqueeze(-1).le(adj_r.unsqueeze(-2))
# set the parent of root as itself
adj_parent.diagonal(0, 1, 2).copy_(adj_w.eq(t - 1))
adj_parent = adj_parent & span_mask.unsqueeze(1)
        # take the closest ancestor spans as parents
adj_parent = (adj_w.unsqueeze(-2) - adj_w.unsqueeze(-1)).masked_fill_(~adj_parent, t).argmin(-1)
adj.scatter_(-1, adj_parent.unsqueeze(-1), 1)
adj = (adj | adj.transpose(-1, -2)).float()
x_tree = self.gnn_layers(x_tree, adj, adj_mask)
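        # keep only the nodes on the rightmost chain, i.e., spans whose right boundary is the current frontier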
span_mask = span_mask.masked_scatter(span_mask, span_indices[2].eq(t - 1))
span_lens = span_mask.sum(-1)
x_tree, span_mask = x_tree[span_mask], span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max())))
x_span = x.new_zeros(*span_mask.shape, x.shape[-1]).masked_scatter_(span_mask.unsqueeze(-1), x_tree)
return x_span