# -*- coding: utf-8 -*-
import itertools
from functools import reduce

import torch

from supar.structs.fn import sampled_logsumexp, sparsemax
from supar.utils.common import MIN


class Semiring(object):
    r"""
    Base semiring class :cite:`goodman-1999-semiring`.

    A semiring is defined by a tuple :math:`\langle K, \oplus, \otimes, \mathbf{0}, \mathbf{1} \rangle`, where
    :math:`K` is a set of values,
    :math:`\oplus` is commutative and associative with identity element :math:`\mathbf{0}`, and
    :math:`\otimes` is associative with identity element :math:`\mathbf{1}` and distributes over :math:`\oplus`.
    """

    zero = 0
    one = 1

    @classmethod
    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

    @classmethod
    def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y

    @classmethod
    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return x.sum(dim)

    @classmethod
    def prod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return x.prod(dim)

    @classmethod
    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return x.cumsum(dim)

    @classmethod
    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return x.cumprod(dim)

    @classmethod
    def dot(cls, x: torch.Tensor, y: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return cls.sum(cls.mul(x, y), dim)

    @classmethod
    def times(cls, *x: torch.Tensor) -> torch.Tensor:
        return reduce(cls.mul, x)

    @classmethod
    def zero_(cls, x: torch.Tensor) -> torch.Tensor:
        return x.fill_(cls.zero)

    @classmethod
    def one_(cls, x: torch.Tensor) -> torch.Tensor:
        return x.fill_(cls.one)

    @classmethod
    def zero_mask(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        return x.masked_fill(mask, cls.zero)

    @classmethod
    def zero_mask_(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        return x.masked_fill_(mask, cls.zero)

    @classmethod
    def one_mask(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        return x.masked_fill(mask, cls.one)

    @classmethod
    def one_mask_(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        return x.masked_fill_(mask, cls.one)

    @classmethod
    def zeros_like(cls, x: torch.Tensor) -> torch.Tensor:
        return x.new_full(x.shape, cls.zero)

    @classmethod
    def ones_like(cls, x: torch.Tensor) -> torch.Tensor:
        return x.new_full(x.shape, cls.one)

    @classmethod
    def convert(cls, x: torch.Tensor) -> torch.Tensor:
        return x

    @classmethod
    def unconvert(cls, x: torch.Tensor) -> torch.Tensor:
        return x
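

# A minimal usage sketch (values assumed for illustration, not from the library's docs):
# under the base (real) semiring, `dot` is an ordinary inner product and `times`
# chains elementwise products.
#
#   >>> x, y = torch.ones(3), torch.full((3,), 2.)
#   >>> Semiring.dot(x, y, -1)
#   tensor(6.)
#   >>> Semiring.times(x, y, y)
#   tensor([4., 4., 4.])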


class LogSemiring(Semiring):
    r"""
    Log-space semiring :math:`\langle \mathrm{logsumexp}, +, -\infty, 0 \rangle`.
    """

    zero = MIN
    one = 0

    @classmethod
    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x.logaddexp(y)

    @classmethod
    def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

    @classmethod
    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return x.logsumexp(dim)

    @classmethod
    def prod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return x.sum(dim)

    @classmethod
    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return x.logcumsumexp(dim)

    @classmethod
    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return x.cumsum(dim)
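

# A minimal sketch (shapes assumed for illustration): summation in this semiring
# is logsumexp, so summing per-step scores yields the log-partition function.
#
#   >>> s = torch.randn(2, 5)
#   >>> torch.allclose(LogSemiring.sum(s, -1), s.logsumexp(-1))
#   True
#   >>> torch.allclose(LogSemiring.mul(s, s), s + s)
#   True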


class MaxSemiring(LogSemiring):
    r"""
    Max semiring :math:`\langle \mathrm{max}, +, -\infty, 0 \rangle`.
    """

    @classmethod
    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x.max(y)

    @classmethod
    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return x.max(dim)[0]

    @classmethod
    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # `cummax` returns a (values, indices) pair; only the values are needed
        return x.cummax(dim)[0]
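

# A minimal sketch (values assumed): replacing logsumexp with max turns the
# partition computation into Viterbi-style best-score search.
#
#   >>> MaxSemiring.sum(torch.tensor([1., 3., 2.]), -1)
#   tensor(3.)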


def KMaxSemiring(k):
    r"""
    k-max semiring :math:`\langle \mathrm{kmax}, +, [-\infty, -\infty, \dots], [0, -\infty, \dots] \rangle`.
    """

    class KMaxSemiring(LogSemiring):

        @classmethod
        def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            # the k best of the union of two k-best lists
            return torch.cat((x, y), -1).topk(k, -1)[0]

        @classmethod
        def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            # the k best over all pairwise combinations of two k-best lists
            return (x.unsqueeze(-1) + y.unsqueeze(-2)).flatten(-2).topk(k, -1)[0]

        @classmethod
        def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
            return x.movedim(dim, -1).flatten(-2).topk(k, -1)[0]

        @classmethod
        def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
            return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.add)), dim)

        @classmethod
        def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
            return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.mul)), dim)

        @classmethod
        def one_(cls, x: torch.Tensor) -> torch.Tensor:
            x[..., :1].fill_(cls.one)
            x[..., 1:].fill_(cls.zero)
            return x

        @classmethod
        def convert(cls, x: torch.Tensor) -> torch.Tensor:
            # pad each score with k-1 zero (i.e., very negative) slots to form a k-best list
            return torch.cat((x.unsqueeze(-1), cls.zero_(x.new_empty(*x.shape, k - 1))), -1)

    return KMaxSemiring
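

# A minimal sketch (values assumed): `convert` pads each score into a k-best
# list, and `sum` then yields the k highest scores along a dimension.
#
#   >>> S = KMaxSemiring(2)
#   >>> x = S.convert(torch.tensor([1., 3., 2.]))  # [3, 2]-shaped k-best lists
#   >>> S.sum(x, 0)
#   tensor([3., 2.])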


class ExpectationSemiring(Semiring):
    r"""
    Expectation semiring :math:`\langle \oplus, +, [0, 0], [1, 0] \rangle` :cite:`li-eisner-2009-first`.

    Each value is a pair :math:`[p, p \cdot r]` of a weight and a weighted score,
    so that summing over all derivations yields :math:`[Z, \sum_{d \in D} p(d) r(d)]`.
    A practical application is computing the entropy
    :math:`H(p) = \log Z - \frac{1}{Z}\sum_{d \in D} p(d) r(d)`.
    """

    @classmethod
    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

    @classmethod
    def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # [p1, p1*r1] x [p2, p2*r2] -> [p1*p2, p1*p2*(r1+r2)]
        return torch.stack((x[..., 0] * y[..., 0], x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]), -1)

    @classmethod
    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return x.sum(dim)

    @classmethod
    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.add)), dim)

    @classmethod
    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.mul)), dim)

    @classmethod
    def zero_(cls, x: torch.Tensor) -> torch.Tensor:
        return x.fill_(cls.zero)

    @classmethod
    def one_(cls, x: torch.Tensor) -> torch.Tensor:
        x[..., 0].fill_(cls.one)
        x[..., 1].fill_(cls.zero)
        return x
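

# A minimal sketch (values assumed): with pairs (p_d, p_d * r_d), semiring
# summation accumulates both the total mass Z and the expectation term.
#
#   >>> x = torch.tensor([[.5, .5 * 1.], [.5, .5 * 3.]])
#   >>> ExpectationSemiring.sum(x, 0)  # Z = 1, sum of p_d * r_d = 2
#   tensor([1., 2.])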


class EntropySemiring(LogSemiring):
    r"""
    Entropy expectation semiring :math:`\langle \oplus, +, [-\infty, 0], [0, 0] \rangle`,
    where :math:`\oplus` aggregates the log-values and the running distributional entropy :math:`H[p]`
    :cite:`li-eisner-2009-first,hwa-2000-sample,kim-etal-2019-unsupervised`.
    """

    @classmethod
    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return cls.sum(torch.stack((x, y)), 0)

    @classmethod
    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # x[..., 0] holds log-values; x[..., 1] holds the running entropy
        p = x[..., 0].logsumexp(dim)
        r = x[..., 0] - p.unsqueeze(dim)
        r = r.exp().mul(x[..., 1] - r).sum(dim)
        return torch.stack((p, r), -1)

    @classmethod
    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.add)), dim)

    @classmethod
    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.mul)), dim)

    @classmethod
    def zero_(cls, x: torch.Tensor) -> torch.Tensor:
        x[..., 0].fill_(cls.zero)
        x[..., 1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x: torch.Tensor) -> torch.Tensor:
        return x.fill_(cls.one)

    @classmethod
    def convert(cls, x: torch.Tensor) -> torch.Tensor:
        return torch.stack((x, cls.ones_like(x)), -1)

    @classmethod
    def unconvert(cls, x: torch.Tensor) -> torch.Tensor:
        return x[..., 1]
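

# A minimal sketch (values assumed): converting log-probabilities, summing
# them out, and unconverting recovers the Shannon entropy.
#
#   >>> s = torch.full((2,), 0.5).log()
#   >>> EntropySemiring.unconvert(EntropySemiring.sum(EntropySemiring.convert(s), 0))
#   tensor(0.6931)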


class CrossEntropySemiring(LogSemiring):
    r"""
    Cross-entropy expectation semiring :math:`\langle \oplus, +, [-\infty, -\infty, 0], [0, 0, 0] \rangle`,
    where :math:`\oplus` aggregates the log-values of the two distributions and their running cross entropy
    :math:`H[p, q]` :cite:`li-eisner-2009-first`.
    """

    @classmethod
    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return cls.sum(torch.stack((x, y)), 0)

    @classmethod
    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # x[..., 0] and x[..., 1] hold log-values of p and q; x[..., -1] the running cross entropy
        p = x[..., :-1].logsumexp(dim)
        r = x[..., :-1] - p.unsqueeze(dim)
        r = r[..., 0].exp().mul(x[..., -1] - r[..., 1]).sum(dim)
        return torch.cat((p, r.unsqueeze(-1)), -1)

    @classmethod
    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.add)), dim)

    @classmethod
    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.mul)), dim)

    @classmethod
    def zero_(cls, x: torch.Tensor) -> torch.Tensor:
        x[..., :-1].fill_(cls.zero)
        x[..., -1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x: torch.Tensor) -> torch.Tensor:
        return x.fill_(cls.one)

    @classmethod
    def convert(cls, x: torch.Tensor) -> torch.Tensor:
        return torch.cat((x, cls.one_(torch.empty_like(x[..., :1]))), -1)

    @classmethod
    def unconvert(cls, x: torch.Tensor) -> torch.Tensor:
        return x[..., -1]
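

# A minimal sketch (values assumed): the last channel accumulates
# H(p, q) = -sum_i p_i * log q_i for paired log-distributions.
#
#   >>> p = torch.tensor([.5, .5]).log()
#   >>> q = torch.tensor([.25, .75]).log()
#   >>> x = CrossEntropySemiring.convert(torch.stack((p, q), -1))
#   >>> CrossEntropySemiring.unconvert(CrossEntropySemiring.sum(x, 0))
#   tensor(0.8370)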


class KLDivergenceSemiring(LogSemiring):
    r"""
    KL divergence expectation semiring :math:`\langle \oplus, +, [-\infty, -\infty, 0], [0, 0, 0] \rangle`,
    where :math:`\oplus` aggregates the log-values of the two distributions and their running KL divergence
    :math:`KL[p \parallel q]` :cite:`li-eisner-2009-first`.
    """

    @classmethod
    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return cls.sum(torch.stack((x, y)), 0)

    @classmethod
    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # x[..., 0] and x[..., 1] hold log-values of p and q; x[..., -1] the running KL divergence
        p = x[..., :-1].logsumexp(dim)
        r = x[..., :-1] - p.unsqueeze(dim)
        r = r[..., 0].exp().mul(x[..., -1] - r[..., 1] + r[..., 0]).sum(dim)
        return torch.cat((p, r.unsqueeze(-1)), -1)

    @classmethod
    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.add)), dim)

    @classmethod
    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.mul)), dim)

    @classmethod
    def zero_(cls, x: torch.Tensor) -> torch.Tensor:
        x[..., :-1].fill_(cls.zero)
        x[..., -1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x: torch.Tensor) -> torch.Tensor:
        return x.fill_(cls.one)

    @classmethod
    def convert(cls, x: torch.Tensor) -> torch.Tensor:
        return torch.cat((x, cls.one_(torch.empty_like(x[..., :1]))), -1)

    @classmethod
    def unconvert(cls, x: torch.Tensor) -> torch.Tensor:
        return x[..., -1]
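

# A minimal sketch (values assumed): with the same inputs as above, the last
# channel accumulates KL(p || q) = sum_i p_i * log(p_i / q_i).
#
#   >>> p = torch.tensor([.5, .5]).log()
#   >>> q = torch.tensor([.25, .75]).log()
#   >>> x = KLDivergenceSemiring.convert(torch.stack((p, q), -1))
#   >>> KLDivergenceSemiring.unconvert(KLDivergenceSemiring.sum(x, 0))
#   tensor(0.1438)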


class SampledSemiring(LogSemiring):
    r"""
    Sampling semiring :math:`\langle \mathrm{logsumexp}, +, -\infty, 0 \rangle`,
    which implements an exact forward-filtering, backward-sampling approach.
    """

    @classmethod
    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return cls.sum(torch.stack((x, y)), 0)

    @classmethod
    def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return sampled_logsumexp(x, dim)

    @classmethod
    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.add)), dim)

    @classmethod
    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.mul)), dim)
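

# A minimal sketch (assuming `sampled_logsumexp` computes logsumexp in the
# forward pass and propagates a sampled one-hot gradient in the backward pass):
#
#   >>> s = torch.randn(5, requires_grad=True)
#   >>> SampledSemiring.sum(s, -1).backward()
#   >>> s.grad  # expected to select one index at random, in proportion to exp(s)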


class SparsemaxSemiring(LogSemiring):
    r"""
    Sparsemax semiring :math:`\langle \mathrm{sparsemax}, +, -\infty, 0 \rangle`
    :cite:`martins-etal-2016-sparsemax,mensch-etal-2018-dp,correia-etal-2020-efficient`.
    """

    @classmethod
    def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return cls.sum(torch.stack((x, y)), 0)

    @staticmethod
    def sum(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        p = sparsemax(x, dim)
        return x.mul(p).sum(dim) - p.norm(p=2, dim=dim)

    @classmethod
    def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.add)), dim)

    @classmethod
    def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        return torch.stack(list(itertools.accumulate(x.unbind(dim), cls.mul)), dim)
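

# A minimal sketch (values assumed): unlike softmax, the sparsemax projection
# used by `sum` can assign exactly zero probability to low-scoring entries.
#
#   >>> sparsemax(torch.tensor([1.5, 1., -2.]), -1)
#   tensor([0.7500, 0.2500, 0.0000])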