# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import List
import torch
import torch.nn as nn
class TokenDropout(nn.Module):
r"""
:class:`TokenDropout` seeks to randomly zero the vectors of some tokens with the probability of `p`.
Args:
p (float):
The probability of an element to be zeroed. Default: 0.5.
Examples:
>>> batch_size, seq_len, hidden_size = 1, 3, 5
>>> x = torch.ones(batch_size, seq_len, hidden_size)
>>> nn.Dropout()(x)
tensor([[[0., 2., 2., 0., 0.],
[2., 2., 0., 2., 2.],
[2., 2., 2., 2., 0.]]])
>>> TokenDropout()(x)
tensor([[[2., 2., 2., 2., 2.],
[0., 0., 0., 0., 0.],
[2., 2., 2., 2., 2.]]])
"""
def __init__(self, p: float = 0.5) -> TokenDropout:
super().__init__()
self.p = p
def __repr__(self):
return f"{self.__class__.__name__}(p={self.p})"
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
Args:
x (~torch.Tensor):
A tensor of any shape.
Returns:
A tensor with the same shape as `x`.
"""
if not self.training:
return x
return x * (x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) / (1 - self.p)).unsqueeze(-1)
[docs]class SharedDropout(nn.Module):
r"""
:class:`SharedDropout` differs from the vanilla dropout strategy in that the dropout mask is shared across one dimension.
Args:
p (float):
The probability of an element to be zeroed. Default: 0.5.
batch_first (bool):
If ``True``, the input and output tensors are provided as ``[batch_size, seq_len, *]``.
Default: ``True``.
Examples:
>>> batch_size, seq_len, hidden_size = 1, 3, 5
>>> x = torch.ones(batch_size, seq_len, hidden_size)
>>> nn.Dropout()(x)
tensor([[[0., 2., 2., 0., 0.],
[2., 2., 0., 2., 2.],
[2., 2., 2., 2., 0.]]])
>>> SharedDropout()(x)
tensor([[[2., 0., 2., 0., 2.],
[2., 0., 2., 0., 2.],
[2., 0., 2., 0., 2.]]])
"""
def __init__(self, p: float = 0.5, batch_first: bool = True) -> SharedDropout:
super().__init__()
self.p = p
self.batch_first = batch_first
def __repr__(self):
s = f"p={self.p}"
if self.batch_first:
s += f", batch_first={self.batch_first}"
return f"{self.__class__.__name__}({s})"
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
Args:
x (~torch.Tensor):
A tensor of any shape.
Returns:
A tensor with the same shape as `x`.
"""
if not self.training:
return x
return x * self.get_mask(x[:, 0], self.p).unsqueeze(1) if self.batch_first else self.get_mask(x[0], self.p)
@staticmethod
def get_mask(x: torch.Tensor, p: float) -> torch.FloatTensor:
return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p)
[docs]class IndependentDropout(nn.Module):
r"""
For :math:`N` tensors, they use different dropout masks respectively.
When :math:`N-M` of them are dropped, the remaining :math:`M` ones are scaled by a factor of :math:`N/M` to compensate,
and when all of them are dropped together, zeros are returned.
Args:
p (float):
The probability of an element to be zeroed. Default: 0.5.
Examples:
>>> batch_size, seq_len, hidden_size = 1, 3, 5
>>> x, y = torch.ones(batch_size, seq_len, hidden_size), torch.ones(batch_size, seq_len, hidden_size)
>>> x, y = IndependentDropout()(x, y)
>>> x
tensor([[[1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.],
[2., 2., 2., 2., 2.]]])
>>> y
tensor([[[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[0., 0., 0., 0., 0.]]])
"""
def __init__(self, p: float = 0.5) -> IndependentDropout:
super().__init__()
self.p = p
def __repr__(self):
return f"{self.__class__.__name__}(p={self.p})"
[docs] def forward(self, *items: List[torch.Tensor]) -> List[torch.Tensor]:
r"""
Args:
items (List[~torch.Tensor]):
A list of tensors that have the same shape except the last dimension.
Returns:
A tensors are of the same shape as `items`.
"""
if not self.training:
return items
masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) for x in items]
total = sum(masks)
scale = len(items) / total.max(torch.ones_like(total))
masks = [mask * scale for mask in masks]
return [item * mask.unsqueeze(-1) for item, mask in zip(items, masks)]