Source code for supar.structs.chain

# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import List, Optional

import torch
from supar.structs.dist import StructuredDistribution
from supar.structs.semiring import LogSemiring, Semiring
from torch.distributions.utils import lazy_property


class LinearChainCRF(StructuredDistribution):
    r"""
    Linear-chain CRFs :cite:`lafferty-etal-2001-crf`.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, n_tags]``.
            Log potentials.
        trans (~torch.Tensor): ``[n_tags+1, n_tags+1]``.
            Transition scores.
            ``trans[-1, :-1]``/``trans[:-1, -1]`` represent transitions for start/end positions respectively.
        lens (~torch.LongTensor): ``[batch_size]``.
            Sentence lengths for masking. Default: ``None``.

    Examples:
        >>> from supar import LinearChainCRF
        >>> batch_size, seq_len, n_tags = 2, 5, 4
        >>> lens = torch.tensor([3, 4])
        >>> value = torch.randint(n_tags, (batch_size, seq_len))
        >>> s1 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags), torch.randn(n_tags+1, n_tags+1), lens)
        >>> s2 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags), torch.randn(n_tags+1, n_tags+1), lens)
        >>> s1.max
        tensor([4.4120, 8.9672], grad_fn=<MaxBackward0>)
        >>> s1.argmax
        tensor([[2, 0, 3, 0, 0],
                [3, 3, 3, 2, 0]])
        >>> s1.log_partition
        tensor([ 6.3486, 10.9106], grad_fn=<LogsumexpBackward>)
        >>> s1.log_prob(value)
        tensor([ -8.1515, -10.5572], grad_fn=<SubBackward0>)
        >>> s1.entropy
        tensor([3.4150, 3.6549], grad_fn=<SelectBackward>)
        >>> s1.kl(s2)
        tensor([4.0333, 4.3807], grad_fn=<SelectBackward>)
    """

    def __init__(
        self,
        scores: torch.Tensor,
        trans: Optional[torch.Tensor] = None,
        lens: Optional[torch.LongTensor] = None
    ) -> LinearChainCRF:
        super().__init__(scores, lens=lens)

        batch_size, seq_len, self.n_tags = scores.shape[:3]
        self.lens = scores.new_full((batch_size,), seq_len).long() if lens is None else lens
        self.mask = self.lens.unsqueeze(-1).gt(self.lens.new_tensor(range(seq_len)))

        self.trans = self.scores.new_full((self.n_tags+1, self.n_tags+1), LogSemiring.one) if trans is None else trans

    def __repr__(self):
        return f"{self.__class__.__name__}(n_tags={self.n_tags})"

    def __add__(self, other):
        return LinearChainCRF(torch.stack((self.scores, other.scores), -1),
                              torch.stack((self.trans, other.trans), -1),
                              self.lens)

    @lazy_property
    def argmax(self):
        return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask, torch.where(self.backward(self.max.sum()))[2])
    def topk(self, k: int) -> torch.LongTensor:
        preds = torch.stack([torch.where(self.backward(i))[2] for i in self.kmax(k).sum(0)], -1)
        return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds)
    def score(self, value: torch.LongTensor) -> torch.Tensor:
        # score of the given tag sequences: emission scores plus transition scores,
        # including the start/end transitions stored in the last row/column of trans
        scores, mask, value = self.scores.transpose(0, 1), self.mask.t(), value.t()
        prev, succ = torch.cat((torch.full_like(value[:1], -1), value[:-1]), 0), value
        # [seq_len, batch_size]
        alpha = scores.gather(-1, value.unsqueeze(-1)).squeeze(-1)
        # [batch_size]
        alpha = LogSemiring.prod(LogSemiring.one_mask(LogSemiring.mul(alpha, self.trans[prev, succ]), ~mask), 0)
        alpha = alpha + self.trans[value.gather(0, self.lens.unsqueeze(0) - 1).squeeze(0), torch.full_like(value[0], -1)]
        return alpha

    def forward(self, semiring: Semiring) -> torch.Tensor:
        # forward algorithm in the given semiring, e.g., the log semiring yields
        # the partition function and the max semiring the Viterbi score
        # [seq_len, batch_size, n_tags, ...]
        scores = semiring.convert(self.scores.transpose(0, 1))
        trans = semiring.convert(self.trans)
        mask = self.mask.t()

        # [batch_size, n_tags]
        alpha = semiring.mul(trans[-1, :-1], scores[0])
        for i in range(1, len(mask)):
            alpha[mask[i]] = semiring.mul(semiring.dot(alpha.unsqueeze(2), trans[:-1, :-1], 1), scores[i])[mask[i]]
        alpha = semiring.dot(alpha, trans[:-1, -1], 1)
        return semiring.unconvert(alpha)
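
# A minimal usage sketch (not part of the original module), assuming `emits`
# holds emission scores from some encoder and `tags` holds gold tag indices
# over a padded batch:
#
#     emits = torch.randn(2, 5, 4, requires_grad=True)  # [batch_size, seq_len, n_tags]
#     trans = torch.randn(5, 5, requires_grad=True)     # [n_tags+1, n_tags+1]
#     tags = torch.randint(4, (2, 5))
#     dist = LinearChainCRF(emits, trans, lens=torch.tensor([3, 4]))
#     loss = -dist.log_prob(tags).mean()                # NLL: log_partition minus gold score
#     best = dist.argmax                                # Viterbi paths, zero-padded past each length
#     kbest = dist.topk(3)                              # [batch_size, seq_len, k] k-best paths
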
class SemiMarkovCRF(StructuredDistribution):
    r"""
    Semi-Markov CRFs :cite:`sarawagi-cohen-2004-semicrf`.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_tags]``.
            Log potentials.
        trans (~torch.Tensor): ``[n_tags, n_tags]``.
            Transition scores.
        lens (~torch.LongTensor): ``[batch_size]``.
            Sentence lengths for masking. Default: ``None``.

    Examples:
        >>> from supar import SemiMarkovCRF
        >>> batch_size, seq_len, n_tags = 2, 5, 4
        >>> lens = torch.tensor([3, 4])
        >>> value = torch.tensor([[[ 0, -1, -1, -1, -1],
                                   [-1, -1,  2, -1, -1],
                                   [-1, -1, -1, -1, -1],
                                   [-1, -1, -1, -1, -1],
                                   [-1, -1, -1, -1, -1]],
                                  [[-1,  1, -1, -1, -1],
                                   [-1, -1,  3, -1, -1],
                                   [-1, -1, -1,  0, -1],
                                   [-1, -1, -1, -1, -1],
                                   [-1, -1, -1, -1, -1]]])
        >>> s1 = SemiMarkovCRF(torch.randn(batch_size, seq_len, seq_len, n_tags), torch.randn(n_tags, n_tags), lens)
        >>> s2 = SemiMarkovCRF(torch.randn(batch_size, seq_len, seq_len, n_tags), torch.randn(n_tags, n_tags), lens)
        >>> s1.max
        tensor([4.1971, 5.5746], grad_fn=<MaxBackward0>)
        >>> s1.argmax
        [[[0, 0, 1], [1, 1, 0], [2, 2, 1]], [[0, 0, 1], [1, 1, 3], [2, 2, 0], [3, 3, 1]]]
        >>> s1.log_partition
        tensor([6.3641, 8.4384], grad_fn=<LogsumexpBackward0>)
        >>> s1.log_prob(value)
        tensor([-5.7982, -7.4534], grad_fn=<SubBackward0>)
        >>> s1.entropy
        tensor([3.7520, 5.1609], grad_fn=<SelectBackward0>)
        >>> s1.kl(s2)
        tensor([3.5348, 2.2826], grad_fn=<SelectBackward0>)
    """

    def __init__(
        self,
        scores: torch.Tensor,
        trans: Optional[torch.Tensor] = None,
        lens: Optional[torch.LongTensor] = None
    ) -> SemiMarkovCRF:
        super().__init__(scores, lens=lens)

        batch_size, seq_len, _, self.n_tags = scores.shape[:4]
        self.lens = scores.new_full((batch_size,), seq_len).long() if lens is None else lens
        self.mask = self.lens.unsqueeze(-1).gt(self.lens.new_tensor(range(seq_len)))
        self.mask = self.mask.unsqueeze(1) & self.mask.unsqueeze(2)

        self.trans = self.scores.new_full((self.n_tags, self.n_tags), LogSemiring.one) if trans is None else trans

    def __repr__(self):
        return f"{self.__class__.__name__}(n_tags={self.n_tags})"

    def __add__(self, other):
        return SemiMarkovCRF(torch.stack((self.scores, other.scores), -1),
                             torch.stack((self.trans, other.trans), -1),
                             self.lens)

    @lazy_property
    def argmax(self) -> List:
        return [torch.nonzero(i).tolist() for i in self.backward(self.max.sum())]
    def topk(self, k: int) -> List:
        return list(zip(*[[torch.nonzero(j).tolist() for j in self.backward(i)] for i in self.kmax(k).sum(0)]))
    def score(self, value: torch.LongTensor) -> torch.Tensor:
        # score of the given segmentations: span scores plus label transition
        # scores between adjacent segments; value[b, i, j] is the label of
        # span (i, j) and -1 elsewhere
        mask = self.mask & value.ge(0)
        lens = mask.sum((1, 2))
        indices = torch.where(mask)
        batch_size, seq_len = lens.shape[0], lens.max()
        span_mask = lens.unsqueeze(-1).gt(lens.new_tensor(range(seq_len)))
        scores = self.scores.new_full((batch_size, seq_len), LogSemiring.one)
        scores = scores.masked_scatter_(span_mask, self.scores[(*indices, value[indices])])
        scores = LogSemiring.prod(LogSemiring.one_mask(scores, ~span_mask), -1)
        value = value.new_zeros(batch_size, seq_len).masked_scatter_(span_mask, value[indices])
        trans = LogSemiring.prod(LogSemiring.one_mask(self.trans[value[:, :-1], value[:, 1:]], ~span_mask[:, 1:]), -1)
        return LogSemiring.mul(scores, trans)

    def forward(self, semiring: Semiring) -> torch.Tensor:
        # forward algorithm over segments: alpha[t] aggregates all
        # segmentations of the prefix whose last segment ends at position t
        # [seq_len, seq_len, batch_size, n_tags, ...]
        scores = semiring.convert(self.scores.movedim((1, 2), (0, 1)))
        trans = semiring.convert(self.trans)
        # [seq_len, batch_size, n_tags, ...]
        alpha = semiring.zeros_like(scores[0])

        alpha[0] = scores[0, 0]
        # [batch_size, n_tags]
        for t in range(1, len(scores)):
            # [batch_size, n_tags, ...]
            s = semiring.dot(semiring.dot(alpha[:t].unsqueeze(3), trans, 2), scores[1:t+1, t], 0)
            alpha[t] = semiring.sum(torch.stack((s, scores[0, t])), 0)
        return semiring.unconvert(semiring.sum(alpha[self.lens - 1, range(len(self.lens))], 1))
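
# A minimal usage sketch (not part of the original module): decoding labeled
# segments with SemiMarkovCRF. Each entry of `argmax` is a list of
# [start, end, label] triples, as in the docstring example:
#
#     scores = torch.randn(2, 5, 5, 4)  # [batch_size, seq_len, seq_len, n_tags]
#     dist = SemiMarkovCRF(scores, torch.randn(4, 4), lens=torch.tensor([3, 4]))
#     for segments in dist.argmax:      # e.g., [[0, 0, 1], [1, 1, 0], [2, 2, 1]]
#         spans = [(start, end, label) for start, end, label in segments]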