Source code for supar.structs.dist
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Iterable, Union
import torch
import torch.autograd as autograd
from supar.structs.semiring import (CrossEntropySemiring, EntropySemiring,
KLDivergenceSemiring, KMaxSemiring,
LogSemiring, MaxSemiring, SampledSemiring,
Semiring)
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property
[docs]class StructuredDistribution(Distribution):
r"""
Base class for structured distribution :math:`p(y)` :cite:`eisner-2016-inside,goodman-1999-semiring,li-eisner-2009-first`.
Args:
scores (torch.Tensor):
Log potentials, also for high-order cases.
"""
def __init__(self, scores: torch.Tensor, **kwargs) -> StructuredDistribution:
self.scores = scores.requires_grad_() if isinstance(scores, torch.Tensor) else [s.requires_grad_() for s in scores]
self.kwargs = kwargs
def __repr__(self):
return f"{self.__class__.__name__}()"
def __add__(self, other: 'StructuredDistribution') -> StructuredDistribution:
return self.__class__(torch.stack((self.scores, other.scores), -1), lens=self.lens)
@lazy_property
def log_partition(self):
r"""
Computes the log partition function of the distribution :math:`p(y)`.
"""
return self.forward(LogSemiring)
@lazy_property
def marginals(self):
r"""
Computes marginal probabilities of the distribution :math:`p(y)`.
"""
return self.backward(self.log_partition.sum())
@lazy_property
def max(self):
r"""
Computes the max score of the distribution :math:`p(y)`.
"""
return self.forward(MaxSemiring)
@lazy_property
def argmax(self):
r"""
Computes :math:`\arg\max_y p(y)` of the distribution :math:`p(y)`.
"""
return self.backward(self.max.sum())
@lazy_property
def mode(self):
return self.argmax
[docs] def kmax(self, k: int) -> torch.Tensor:
r"""
Computes the k-max of the distribution :math:`p(y)`.
"""
return self.forward(KMaxSemiring(k))
[docs] def topk(self, k: int) -> Union[torch.Tensor, Iterable]:
r"""
Computes the k-argmax of the distribution :math:`p(y)`.
"""
raise NotImplementedError
[docs] def sample(self):
r"""
Obtains a structured sample from the distribution :math:`y \sim p(y)`.
TODO: multi-sampling.
"""
return self.backward(self.forward(SampledSemiring).sum()).detach()
@lazy_property
def entropy(self):
r"""
Computes entropy :math:`H[p]` of the distribution :math:`p(y)`.
"""
return self.forward(EntropySemiring)
[docs] def cross_entropy(self, other: 'StructuredDistribution') -> torch.Tensor:
r"""
Computes cross-entropy :math:`H[p,q]` of self and another distribution.
Args:
other (~supar.structs.dist.StructuredDistribution): Comparison distribution.
"""
return (self + other).forward(CrossEntropySemiring)
[docs] def kl(self, other: 'StructuredDistribution') -> torch.Tensor:
r"""
Computes KL-divergence :math:`KL[p \parallel q]=H[p,q]-H[p]` of self and another distribution.
Args:
other (~supar.structs.dist.StructuredDistribution): Comparison distribution.
"""
return (self + other).forward(KLDivergenceSemiring)
[docs] def log_prob(self, value: torch.LongTensor, *args, **kwargs) -> torch.Tensor:
"""
Computes log probability over values :math:`p(y)`.
"""
return self.score(value, *args, **kwargs) - self.log_partition
def score(self, value: torch.LongTensor, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError
@torch.enable_grad()
def forward(self, semiring: Semiring) -> torch.Tensor:
raise NotImplementedError
def backward(self, log_partition: torch.Tensor) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
grads = autograd.grad(log_partition, self.scores, create_graph=True)
return grads[0] if isinstance(self.scores, torch.Tensor) else grads