Source code for supar.structs.vi

# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from supar.structs import DependencyCRF
from supar.utils.common import MIN


class DependencyMFVI(nn.Module):
    r"""
    Mean Field Variational Inference for approximately calculating marginals
    of dependency trees :cite:`wang-tu-2020-second`.
    """

    def __init__(self, max_iter: int = 3) -> DependencyMFVI:
        super().__init__()

        self.max_iter = max_iter

    def __repr__(self):
        return f"{self.__class__.__name__}(max_iter={self.max_iter})"

    @torch.enable_grad()
    def forward(
        self,
        scores: List[torch.Tensor],
        mask: torch.BoolTensor,
        target: Optional[torch.LongTensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            scores (~torch.Tensor, ~torch.Tensor):
                Tuple of two tensors `s_arc` and `s_sib`.
                `s_arc` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs.
                `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples.
            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
                The mask to avoid aggregation on padding tokens.
            target (~torch.LongTensor): ``[batch_size, seq_len]``.
                A Tensor of gold-standard dependent-head pairs.
                Default: ``None``.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
        """

        logits = self.mfvi(*scores, mask)
        marginals = logits.softmax(-1)

        if target is None:
            return marginals
        loss = F.cross_entropy(logits[mask], target[mask])

        return loss, marginals

    def mfvi(self, s_arc, s_sib, mask):
        batch_size, seq_len = mask.shape
        ls, rs = torch.stack(torch.where(mask.new_ones(seq_len, seq_len))).view(-1, seq_len, seq_len).sort(0)[0]
        mask = mask.index_fill(1, ls.new_tensor(0), 1)
        # [seq_len, seq_len, batch_size], (h->m)
        mask = (mask.unsqueeze(-1) & mask.unsqueeze(-2)).permute(2, 1, 0)
        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
        mask2o = mask.unsqueeze(1) & mask.unsqueeze(2)
        mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1)
        mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1)
        # [seq_len, seq_len, batch_size], (h->m)
        s_arc = s_arc.permute(2, 1, 0)
        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
        s_sib = s_sib.permute(2, 1, 3, 0) * mask2o

        # posterior distributions
        # [seq_len, seq_len, batch_size], (h->m)
        q = s_arc

        for _ in range(self.max_iter):
            q = q.softmax(0)
            # q(ij) = s(ij) + sum(q(ik)s^sib(ij,ik)), k != i,j
            q = s_arc + (q.unsqueeze(1) * s_sib).sum(2)

        return q.permute(2, 1, 0)
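
# A minimal usage sketch for DependencyMFVI: random first- and second-order scores
# with the shapes documented in ``forward`` above; all values are arbitrary
# placeholders rather than outputs of a real biaffine/triaffine scorer.
#
#   >>> batch_size, seq_len = 2, 6
#   >>> s_arc = torch.randn(batch_size, seq_len, seq_len)
#   >>> s_sib = torch.randn(batch_size, seq_len, seq_len, seq_len)
#   >>> mask = torch.ones(batch_size, seq_len).bool()
#   >>> mask[:, 0] = False  # the pseudo-root position carries no loss
#   >>> target = torch.randint(seq_len, (batch_size, seq_len))
#   >>> loss, marginals = DependencyMFVI(max_iter=3)((s_arc, s_sib), mask, target)
#   >>> marginals.shape
#   torch.Size([2, 6, 6])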


class DependencyLBP(nn.Module):
    r"""
    Loopy Belief Propagation for approximately calculating marginals
    of dependency trees :cite:`smith-eisner-2008-dependency`.
    """

    def __init__(self, max_iter: int = 3) -> DependencyLBP:
        super().__init__()

        self.max_iter = max_iter

    def __repr__(self):
        return f"{self.__class__.__name__}(max_iter={self.max_iter})"

    @torch.enable_grad()
    def forward(
        self,
        scores: List[torch.Tensor],
        mask: torch.BoolTensor,
        target: Optional[torch.LongTensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            scores (~torch.Tensor, ~torch.Tensor):
                Tuple of two tensors `s_arc` and `s_sib`.
                `s_arc` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs.
                `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples.
            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
                The mask to avoid aggregation on padding tokens.
            target (~torch.LongTensor): ``[batch_size, seq_len]``.
                A Tensor of gold-standard dependent-head pairs.
                Default: ``None``.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
        """

        logits = self.lbp(*scores, mask)
        marginals = logits.softmax(-1)

        if target is None:
            return marginals
        loss = F.cross_entropy(logits[mask], target[mask])

        return loss, marginals

    def lbp(self, s_arc, s_sib, mask):
        batch_size, seq_len = mask.shape
        ls, rs = torch.stack(torch.where(mask.new_ones(seq_len, seq_len))).view(-1, seq_len, seq_len).sort(0)[0]
        mask = mask.index_fill(1, ls.new_tensor(0), 1)
        # [seq_len, seq_len, batch_size], (h->m)
        mask = (mask.unsqueeze(-1) & mask.unsqueeze(-2)).permute(2, 1, 0)
        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
        mask2o = mask.unsqueeze(1) & mask.unsqueeze(2)
        mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1)
        mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1)
        # [seq_len, seq_len, batch_size], (h->m)
        s_arc = s_arc.permute(2, 1, 0)
        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
        s_sib = s_sib.permute(2, 1, 3, 0).masked_fill_(~mask2o, MIN)

        # log beliefs
        # [seq_len, seq_len, batch_size], (h->m)
        q = s_arc
        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
        m_sib = s_sib.new_zeros(seq_len, seq_len, seq_len, batch_size)

        for _ in range(self.max_iter):
            q = q.log_softmax(0)
            # m(ik->ij) = logsumexp(q(ik) - m(ij->ik) + s(ij->ik))
            m = q.unsqueeze(2) - m_sib
            # TODO: better solution for OOM
            m_sib = torch.logaddexp(m.logsumexp(0), m + s_sib).transpose(1, 2).log_softmax(0)
            # q(ij) = s(ij) + sum(m(ik->ij)), k != i,j
            q = s_arc + (m_sib * mask2o).sum(2)

        return q.permute(2, 1, 0)
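
# A minimal usage sketch for DependencyLBP; the interface mirrors DependencyMFVI,
# so the same randomly generated scores can be passed through either module.
# Without a ``target``, only the arc marginals are returned.
#
#   >>> batch_size, seq_len = 2, 6
#   >>> s_arc = torch.randn(batch_size, seq_len, seq_len)
#   >>> s_sib = torch.randn(batch_size, seq_len, seq_len, seq_len)
#   >>> mask = torch.ones(batch_size, seq_len).bool()
#   >>> mask[:, 0] = False
#   >>> marginals = DependencyLBP(max_iter=3)((s_arc, s_sib), mask)
#   >>> marginals.shape
#   torch.Size([2, 6, 6])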


class ConstituencyMFVI(nn.Module):
    r"""
    Mean Field Variational Inference for approximately calculating marginals
    of constituent trees.
    """

    def __init__(self, max_iter: int = 3) -> ConstituencyMFVI:
        super().__init__()

        self.max_iter = max_iter

    def __repr__(self):
        return f"{self.__class__.__name__}(max_iter={self.max_iter})"

    @torch.enable_grad()
    def forward(
        self,
        scores: List[torch.Tensor],
        mask: torch.BoolTensor,
        target: Optional[torch.LongTensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            scores (~torch.Tensor, ~torch.Tensor):
                Tuple of two tensors `s_span` and `s_pair`.
                `s_span` (``[batch_size, seq_len, seq_len]``) holds scores of all possible spans.
                `s_pair` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of second-order triples.
            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
                The mask to avoid aggregation on padding tokens.
            target (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
                A Tensor of gold-standard constituent spans.
                Default: ``None``.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
        """

        logits = self.mfvi(*scores, mask)
        marginals = logits.sigmoid()

        if target is None:
            return marginals
        loss = F.binary_cross_entropy_with_logits(logits[mask], target[mask].float())

        return loss, marginals

    def mfvi(self, s_span, s_pair, mask):
        batch_size, seq_len, _ = mask.shape
        ls, rs = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len).sort(0)[0]
        # [seq_len, seq_len, batch_size], (l->r)
        mask = mask.movedim(0, 2)
        # [seq_len, seq_len, seq_len, batch_size], (l->r->b)
        mask2o = mask.unsqueeze(2).repeat(1, 1, seq_len, 1)
        mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1)
        mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1)
        # [seq_len, seq_len, batch_size], (l->r)
        s_span = s_span.movedim(0, 2)
        # [seq_len, seq_len, seq_len, batch_size], (l->r->b)
        s_pair = s_pair.permute(1, 2, 3, 0) * mask2o

        # posterior distributions
        # [seq_len, seq_len, batch_size], (l->r)
        q = s_span

        for _ in range(self.max_iter):
            q = q.sigmoid()
            # q(ij) = s(ij) + sum(q(jk)*s^pair(ij,jk)), k != i,j
            q = s_span + (q.unsqueeze(1) * s_pair).sum(2)

        return q.permute(2, 0, 1)
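
# A minimal usage sketch for ConstituencyMFVI: span and pair scores of the shapes
# documented above, with a strictly upper-triangular mask so that only fences
# ``i < j`` are scored; the gold span chart here is an arbitrary placeholder.
#
#   >>> batch_size, seq_len = 2, 6
#   >>> s_span = torch.randn(batch_size, seq_len, seq_len)
#   >>> s_pair = torch.randn(batch_size, seq_len, seq_len, seq_len)
#   >>> mask = torch.ones(batch_size, seq_len, seq_len).triu_(1).bool()
#   >>> target = torch.zeros(batch_size, seq_len, seq_len).bool()
#   >>> loss, marginals = ConstituencyMFVI(max_iter=3)((s_span, s_pair), mask, target)
#   >>> marginals.shape
#   torch.Size([2, 6, 6])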


class ConstituencyLBP(nn.Module):
    r"""
    Loopy Belief Propagation for approximately calculating marginals
    of constituent trees.
    """

    def __init__(self, max_iter: int = 3) -> ConstituencyLBP:
        super().__init__()

        self.max_iter = max_iter

    def __repr__(self):
        return f"{self.__class__.__name__}(max_iter={self.max_iter})"

    @torch.enable_grad()
    def forward(
        self,
        scores: List[torch.Tensor],
        mask: torch.BoolTensor,
        target: Optional[torch.LongTensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            scores (~torch.Tensor, ~torch.Tensor):
                Tuple of two tensors `s_span` and `s_pair`.
                `s_span` (``[batch_size, seq_len, seq_len]``) holds scores of all possible spans.
                `s_pair` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of second-order triples.
            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
                The mask to avoid aggregation on padding tokens.
            target (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
                A Tensor of gold-standard constituent spans.
                Default: ``None``.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
        """

        logits = self.lbp(*scores, mask)
        marginals = logits.softmax(-1)[..., 1]

        if target is None:
            return marginals
        loss = F.cross_entropy(logits[mask], target[mask].long())

        return loss, marginals

    def lbp(self, s_span, s_pair, mask):
        batch_size, seq_len, _ = mask.shape
        ls, rs = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len).sort(0)[0]
        # [seq_len, seq_len, batch_size], (l->r)
        mask = mask.movedim(0, 2)
        # [seq_len, seq_len, seq_len, batch_size], (l->r->b)
        mask2o = mask.unsqueeze(2).repeat(1, 1, seq_len, 1)
        mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1)
        mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1)
        # [2, seq_len, seq_len, batch_size], (l->r)
        s_span = torch.stack((torch.zeros_like(s_span), s_span)).permute(0, 3, 2, 1)
        # [seq_len, seq_len, seq_len, batch_size], (l->r->p)
        s_pair = s_pair.permute(2, 1, 3, 0)

        # log beliefs
        # [2, seq_len, seq_len, batch_size], (h->m)
        q = s_span
        # [2, seq_len, seq_len, seq_len, batch_size], (h->m->s)
        m_pair = s_pair.new_zeros(2, seq_len, seq_len, seq_len, batch_size)

        for _ in range(self.max_iter):
            q = q.log_softmax(0)
            # m(ik->ij) = logsumexp(q(ik) - m(ij->ik) + s(ij->ik))
            m = q.unsqueeze(3) - m_pair
            m_pair = torch.stack((m.logsumexp(0), torch.stack((m[0], m[1] + s_pair)).logsumexp(0))).log_softmax(0)
            # q(ij) = s(ij) + sum(m(ik->ij)), k != i,j
            q = s_span + (m_pair.transpose(2, 3) * mask2o).sum(3)

        return q.permute(3, 2, 1, 0)
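
# A minimal usage sketch for ConstituencyLBP with the same inputs as the MFVI
# variant above; the returned marginals are the probabilities of each span
# being a constituent, taken from the second channel of the final softmax.
#
#   >>> batch_size, seq_len = 2, 6
#   >>> s_span = torch.randn(batch_size, seq_len, seq_len)
#   >>> s_pair = torch.randn(batch_size, seq_len, seq_len, seq_len)
#   >>> mask = torch.ones(batch_size, seq_len, seq_len).triu_(1).bool()
#   >>> marginals = ConstituencyLBP(max_iter=3)((s_span, s_pair), mask)
#   >>> marginals.shape
#   torch.Size([2, 6, 6])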


class SemanticDependencyMFVI(nn.Module):
    r"""
    Mean Field Variational Inference for approximately calculating marginals
    of semantic dependency trees :cite:`wang-etal-2019-second`.
    """

    def __init__(self, max_iter: int = 3) -> SemanticDependencyMFVI:
        super().__init__()

        self.max_iter = max_iter

    def __repr__(self):
        return f"{self.__class__.__name__}(max_iter={self.max_iter})"

    @torch.enable_grad()
    def forward(
        self,
        scores: List[torch.Tensor],
        mask: torch.BoolTensor,
        target: Optional[torch.LongTensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            scores (~torch.Tensor, ~torch.Tensor):
                Tuple of four tensors `s_edge`, `s_sib`, `s_cop` and `s_grd`.
                `s_edge` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs.
                `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples.
                `s_cop` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-coparent triples.
                `s_grd` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-grandparent triples.
            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
                The mask to avoid aggregation on padding tokens.
            target (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
                A Tensor of gold-standard dependent-head pairs.
                Default: ``None``.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
        """

        logits = self.mfvi(*scores, mask)
        marginals = logits.sigmoid()

        if target is None:
            return marginals
        loss = F.binary_cross_entropy_with_logits(logits[mask], target[mask].float())

        return loss, marginals

    def mfvi(self, s_edge, s_sib, s_cop, s_grd, mask):
        _, seq_len, _ = mask.shape
        hs, ms = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len)
        # [seq_len, seq_len, batch_size], (h->m)
        mask = mask.permute(2, 1, 0)
        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
        mask2o = mask.unsqueeze(1) & mask.unsqueeze(2)
        mask2o = mask2o & hs.unsqueeze(-1).ne(hs.new_tensor(range(seq_len))).unsqueeze(-1)
        mask2o = mask2o & ms.unsqueeze(-1).ne(ms.new_tensor(range(seq_len))).unsqueeze(-1)
        mask2o.diagonal().fill_(0)
        # [seq_len, seq_len, batch_size], (h->m)
        s_edge = s_edge.permute(2, 1, 0)
        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
        s_sib = s_sib.permute(2, 1, 3, 0) * mask2o
        # [seq_len, seq_len, seq_len, batch_size], (h->m->c)
        s_cop = s_cop.permute(2, 1, 3, 0) * mask2o
        # [seq_len, seq_len, seq_len, batch_size], (h->m->g)
        s_grd = s_grd.permute(2, 1, 3, 0) * mask2o

        # posterior distributions
        # [seq_len, seq_len, batch_size], (h->m)
        q = s_edge

        for _ in range(self.max_iter):
            q = q.sigmoid()
            # q(ij) = s(ij) + sum(q(ik)s^sib(ij,ik) + q(kj)s^cop(ij,kj) + q(jk)s^grd(ij,jk)), k != i,j
            q = s_edge + (q.unsqueeze(1) * s_sib + q.transpose(0, 1).unsqueeze(0) * s_cop + q.unsqueeze(0) * s_grd).sum(2)

        return q.permute(2, 1, 0)
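
# A minimal usage sketch for SemanticDependencyMFVI: one first-order and three
# second-order score tensors as documented above; the random 0/1 target chart
# below is an arbitrary placeholder for a gold semantic dependency graph.
#
#   >>> batch_size, seq_len = 2, 6
#   >>> s_edge = torch.randn(batch_size, seq_len, seq_len)
#   >>> s_sib, s_cop, s_grd = (torch.randn(batch_size, seq_len, seq_len, seq_len) for _ in range(3))
#   >>> mask = torch.ones(batch_size, seq_len, seq_len).bool()
#   >>> mask[:, 0] = False  # the pseudo-root token is never a dependent
#   >>> target = torch.randint(2, (batch_size, seq_len, seq_len))
#   >>> loss, marginals = SemanticDependencyMFVI()((s_edge, s_sib, s_cop, s_grd), mask, target)
#   >>> marginals.shape
#   torch.Size([2, 6, 6])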


class SemanticDependencyLBP(nn.Module):
    r"""
    Loopy Belief Propagation for approximately calculating marginals
    of semantic dependency trees :cite:`wang-etal-2019-second`.
    """

    def __init__(self, max_iter: int = 3) -> SemanticDependencyLBP:
        super().__init__()

        self.max_iter = max_iter

    def __repr__(self):
        return f"{self.__class__.__name__}(max_iter={self.max_iter})"

    @torch.enable_grad()
    def forward(
        self,
        scores: List[torch.Tensor],
        mask: torch.BoolTensor,
        target: Optional[torch.LongTensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            scores (~torch.Tensor, ~torch.Tensor):
                Tuple of four tensors `s_edge`, `s_sib`, `s_cop` and `s_grd`.
                `s_edge` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs.
                `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples.
                `s_cop` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-coparent triples.
                `s_grd` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-grandparent triples.
            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
                The mask to avoid aggregation on padding tokens.
            target (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
                A Tensor of gold-standard dependent-head pairs.
                Default: ``None``.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
                The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``.
        """

        logits = self.lbp(*scores, mask)
        marginals = logits.softmax(-1)[..., 1]

        if target is None:
            return marginals
        loss = F.cross_entropy(logits[mask], target[mask])

        return loss, marginals

    def lbp(self, s_edge, s_sib, s_cop, s_grd, mask):
        lens = mask[..., 0].sum(1)
        _, seq_len, _ = mask.shape
        hs, ms = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len)
        # [seq_len, seq_len, batch_size], (h->m)
        mask = mask.permute(2, 1, 0)
        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
        mask2o = mask.unsqueeze(1) & mask.unsqueeze(2)
        mask2o = mask2o & hs.unsqueeze(-1).ne(hs.new_tensor(range(seq_len))).unsqueeze(-1)
        mask2o = mask2o & ms.unsqueeze(-1).ne(ms.new_tensor(range(seq_len))).unsqueeze(-1)
        mask2o.diagonal().fill_(0)
        # [2, seq_len, seq_len, batch_size], (h->m)
        s_edge = torch.stack((torch.zeros_like(s_edge), s_edge)).permute(0, 3, 2, 1)
        # [seq_len, seq_len, seq_len, batch_size], (h->m->s)
        s_sib = s_sib.permute(2, 1, 3, 0)
        # [seq_len, seq_len, seq_len, batch_size], (h->m->c)
        s_cop = s_cop.permute(2, 1, 3, 0)
        # [seq_len, seq_len, seq_len, batch_size], (h->m->g)
        s_grd = s_grd.permute(2, 1, 3, 0)

        # log beliefs
        # [2, seq_len, seq_len, batch_size], (h->m)
        q = s_edge
        # sibling factor
        # [2, seq_len, seq_len, seq_len, batch_size], (h->m->s)
        m_sib = s_sib.new_zeros(2, *mask2o.shape)
        # coparent factor
        # [2, seq_len, seq_len, seq_len, batch_size], (h->m->c)
        m_cop = s_cop.new_zeros(2, *mask2o.shape)
        # grandparent factor
        # [2, seq_len, seq_len, seq_len, batch_size], (h->m->g)
        m_grd = s_grd.new_zeros(2, *mask2o.shape)
        # tree factor
        # [2, seq_len, seq_len, batch_size], (h->m)
        m_tree = torch.zeros_like(s_edge)

        for _ in range(self.max_iter):
            # sibling factor
            v_sib = q.unsqueeze(2) - m_sib
            m_sib = torch.stack((v_sib.logsumexp(0), torch.stack((v_sib[0], v_sib[1] + s_sib)).logsumexp(0))).log_softmax(0)
            # coparent factor
            v_cop = q.transpose(1, 2).unsqueeze(1) - m_cop
            m_cop = torch.stack((v_cop.logsumexp(0), torch.stack((v_cop[0], v_cop[1] + s_cop)).logsumexp(0))).log_softmax(0)
            # grandparent factor
            v_grd = q.unsqueeze(1) - m_grd
            m_grd = torch.stack((v_grd.logsumexp(0), torch.stack((v_grd[0], v_grd[1] + s_grd)).logsumexp(0))).log_softmax(0)
            # tree factor
            v_tree = q - m_tree
            b_tree = DependencyCRF((v_tree[1] - v_tree[0]).permute(2, 1, 0), lens).marginals.permute(2, 1, 0)
            b_tree = torch.stack((1 - b_tree, b_tree))
            m_tree = (b_tree.clamp(torch.finfo().eps).log() - v_tree).log_softmax(0)
            # q(ij) = s(ij) + sum(m(ik->ij)), k != i,j
            q = s_edge + ((m_sib + m_cop + m_grd).transpose(2, 3) * mask2o).sum(3) + m_tree

        return q.permute(3, 2, 1, 0)
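
# A minimal usage sketch for SemanticDependencyLBP; the inputs mirror the MFVI
# variant above. The edge scores are created with ``requires_grad=True`` because
# the tree factor obtains its beliefs from DependencyCRF marginals, which are
# computed by back-propagation through the scores.
#
#   >>> batch_size, seq_len = 2, 6
#   >>> s_edge = torch.randn(batch_size, seq_len, seq_len, requires_grad=True)
#   >>> s_sib, s_cop, s_grd = (torch.randn(batch_size, seq_len, seq_len, seq_len) for _ in range(3))
#   >>> mask = torch.ones(batch_size, seq_len, seq_len).bool()
#   >>> mask[:, 0] = False
#   >>> target = torch.randint(2, (batch_size, seq_len, seq_len))
#   >>> loss, marginals = SemanticDependencyLBP()((s_edge, s_sib, s_cop, s_grd), mask, target)
#   >>> marginals.shape
#   torch.Size([2, 6, 6])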