# -*- coding: utf-8 -*-
import torch
from supar.models.dep.biaffine.model import BiaffineDependencyModel
from supar.structs import DependencyCRF, MatrixTree
class CRFDependencyModel(BiaffineDependencyModel):
    r"""
    The implementation of first-order CRF Dependency Parser
    (:cite:`zhang-etal-2020-efficient,ma-hovy-2017-neural,koo-etal-2007-structured`).

    Args:
        n_words (int):
            The size of the word vocabulary.
        n_rels (int):
            The number of labels in the treebank.
        n_tags (int):
            The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
        n_chars (int):
            The number of characters, required if character-level representations are used. Default: ``None``.
        encoder (str):
            Encoder to use.
            ``'lstm'``: BiLSTM encoder.
            ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
            Default: ``'lstm'``.
        feat (List[str]):
            Additional features to use, required if ``encoder='lstm'``.
            ``'tag'``: POS tag embeddings.
            ``'char'``: Character-level representations extracted by CharLSTM.
            ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible.
            Default: [``'char'``].
        n_embed (int):
            The size of word embeddings. Default: 100.
        n_pretrained (int):
            The size of pretrained word embeddings. Default: 100.
        n_feat_embed (int):
            The size of feature representations. Default: 100.
        n_char_embed (int):
            The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
        n_char_hidden (int):
            The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
        char_pad_index (int):
            The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
        elmo (str):
            Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
        elmo_bos_eos (Tuple[bool]):
            A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
            Default: ``(True, False)``.
        bert (str):
            Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
            This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
            Default: ``None``.
        n_bert_layers (int):
            Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
            The final outputs would be weighted sum of the hidden states of these layers.
            Default: 4.
        mix_dropout (float):
            The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
        bert_pooling (str):
            Pooling way to get token embeddings.
            ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
            Default: ``mean``.
        bert_pad_index (int):
            The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
            Default: 0.
        finetune (bool):
            If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``.
        n_plm_embed (int):
            The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0.
        embed_dropout (float):
            The dropout ratio of input embeddings. Default: .33.
        n_encoder_hidden (int):
            The size of encoder hidden states. Default: 800.
        n_encoder_layers (int):
            The number of encoder layers. Default: 3.
        encoder_dropout (float):
            The dropout ratio of encoder layer. Default: .33.
        n_arc_mlp (int):
            Arc MLP size. Default: 500.
        n_rel_mlp (int):
            Label MLP size. Default: 100.
        mlp_dropout (float):
            The dropout ratio of MLP layers. Default: .33.
        scale (float):
            Scaling factor for affine scores. Default: 0.
        pad_index (int):
            The index of the padding token in the word vocabulary. Default: 0.
        unk_index (int):
            The index of the unknown token in the word vocabulary. Default: 1.
        proj (bool):
            If ``True``, takes :class:`DependencyCRF` as inference layer, :class:`MatrixTree` otherwise.
            Default: ``True``.

    .. _transformers:
        https://github.com/huggingface/transformers
    """

    def loss(self, s_arc, s_rel, arcs, rels, mask, mbr=True, partial=False):
        r"""
        Computes the tree-CRF arc loss plus the label loss on gold heads.

        Args:
            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
                Scores of all possible arcs.
            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
                Scores of all possible labels on each arc.
            arcs (~torch.LongTensor): ``[batch_size, seq_len]``.
                The tensor of gold-standard arcs. ``-1`` marks un-annotated arcs when ``partial=True``.
            rels (~torch.LongTensor): ``[batch_size, seq_len]``.
                The tensor of gold-standard labels.
            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
                The mask for covering the unpadded tokens.
            mbr (bool):
                If ``True``, returns marginals for MBR decoding. Default: ``True``.
            partial (bool):
                ``True`` denotes the trees are partially annotated. Default: ``False``.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                The training loss and
                original arc scores of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise.
        """
        # ``self.args.proj`` selects the tree structure used as the inference layer.
        CRF = DependencyCRF if self.args.proj else MatrixTree
        # Sentence lengths are the number of unmasked tokens per row.
        arc_dist = CRF(s_arc, mask.sum(-1))
        # Negative log-likelihood of the gold trees, normalised by the number of unmasked tokens.
        arc_loss = -arc_dist.log_prob(arcs, partial=partial).sum() / mask.sum()
        arc_probs = arc_dist.marginals if mbr else s_arc
        # -1 denotes un-annotated arcs; they carry no label supervision.
        if partial:
            mask = mask & arcs.ge(0)
        s_rel, rels = s_rel[mask], rels[mask]
        # Keep only the label scores on each remaining token's gold head.
        s_rel = s_rel[torch.arange(len(rels), device=rels.device), arcs[mask]]
        if rels.numel() == 0:
            # Nothing left to supervise (e.g. a fully un-annotated partial batch).
            # A mean-reduced criterion (such as the default torch.nn.CrossEntropyLoss)
            # returns NaN on empty input; summing the empty scores instead yields a
            # zero that stays attached to the autograd graph.
            rel_loss = s_rel.sum()
        else:
            rel_loss = self.criterion(s_rel, rels)
        loss = arc_loss + rel_loss
        return loss, arc_probs