In this tutorial, we will use some examples to show what the Gumbel-Softmax distribution is and how to use it.
Gumbel-Softmax Distribution
It is defined as:

\(y_i = \dfrac{\exp\big((\log(\pi_i) + g_i)/\tau\big)}{\sum_{j=1}^{k} \exp\big((\log(\pi_j) + g_j)/\tau\big)}\), for \(i = 1, \dots, k\)
Here:
\(\pi_i\): the probability of class \(i\) under a categorical distribution. It must be larger than 0.
\(g_i\): a sample drawn from Gumbel(0, 1)
\(\tau\): a temperature parameter
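As a quick illustration, here is a minimal sketch (my own example, not PyTorch's implementation) that evaluates this formula directly for a small probability vector \(\pi\):

import torch

# a small categorical distribution pi over k = 3 classes (every entry > 0)
pi = torch.tensor([0.2, 0.5, 0.3])

# g_i: independent Gumbel(0, 1) noise
g = torch.distributions.Gumbel(0.0, 1.0).sample(pi.shape)

tau = 1.0  # temperature

# y_i = exp((log(pi_i) + g_i) / tau) / sum_j exp((log(pi_j) + g_j) / tau)
y = torch.softmax((torch.log(pi) + g) / tau, dim=-1)
print(y)        # a random point on the probability simplex
print(y.sum())  # tensor(1.)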
To understand the effect of \(\tau\), you can read:
An Explanation of Softmax Function with Hyperparameter
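In short, a small \(\tau\) pushes the samples toward one-hot vectors, while a large \(\tau\) pushes them toward a uniform distribution. A quick sketch (my own example):

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 3.0]])

# low temperature: the sample is nearly one-hot
print(F.gumbel_softmax(logits, tau=0.1, dim=-1))

# high temperature: the sample is spread out, close to uniform
print(F.gumbel_softmax(logits, tau=10.0, dim=-1))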
How to Create \(g_i\)?
\(g_i\) is a sample drawn from Gumbel(0, 1). The Gumbel(0, 1) distribution can be sampled using inverse transform sampling: draw \(u \sim \text{Uniform}(0, 1)\) and compute \(g = -\log(-\log(u))\).
Here is an example:
import torch

def sample_gumbel(shape, eps=1e-20):
    # draw u ~ Uniform(0, 1) and apply the inverse transform g = -log(-log(u));
    # eps avoids taking log(0)
    U = torch.rand(shape)
    if torch.cuda.is_available():
        U = U.cuda()
    return -torch.log(-torch.log(U + eps) + eps)
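For instance, we can call it like this (a hypothetical usage of the function above):

g = sample_gumbel((5, 10))
print(g.shape)  # torch.Size([5, 10])
# the mean of Gumbel(0, 1) is about 0.577, so the average of many samples approaches it
print(sample_gumbel((100000,)).mean())  # roughly 0.577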
Why Use Gumbel-Softmax
The most important reason is that the argmax() function does not support backpropagation and gradient computation.
TensorFlow tf.argmax() does not Support Backprop and Gradient Operation – TensorFlow Tutorial
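Here is a minimal sketch of the problem in PyTorch (my own example): argmax() returns integer indices, so the result carries no gradient information and autograd cannot flow back through it.

import torch

logits = torch.randn(3, requires_grad=True)
idx = torch.argmax(logits)
print(idx.requires_grad)  # False: the graph is cut, no gradient reaches logits
print(idx.grad_fn)        # None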
With Gumbel-Softmax, we can select the embedding with the maximum probability while keeping the operation differentiable.
For example:
import torch
from torch import nn
import torch.nn.functional as F

# model:
class MyModel(nn.Module):
    def __init__(self, dim_m, dim_gate, n_mods):
        super().__init__()
        self.dim_m, self.dim_gate, self.n_mods = dim_m, dim_gate, n_mods
        # one gate network per modality, each producing a score in (0, 1)
        self.net_gate = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(self.dim_m, self.dim_gate),
                nn.LeakyReLU(),
                nn.Linear(self.dim_gate, 1),
                nn.Sigmoid()
            ) for _ in range(self.n_mods)]
        )

    def forward(self, x):
        ...
        # some operations ...
        # gates.shape = [bsz, n_mods], e.g., gates[i] = [0.4, 0.7, 0.2]
        # hard=True returns a one-hot vector, e.g., gates_onehot[i] = [0, 1, 0]
        gates_onehot = F.gumbel_softmax(gates, tau=0.1, hard=True, dim=-1)
        # uni_emb: shape = [mod_dim, bsz, n_mods]
        # the one-hot broadcasts over mod_dim; summing out n_mods keeps only
        # the embedding of the selected modality
        select_emb = torch.mul(gates_onehot, uni_emb)
        select_emb = select_emb.sum(dim=-1)  # shape = [mod_dim, bsz]
        return select_emb
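To see what the one-hot multiplication does, here is a small standalone sketch (the shapes and names are only illustrative, following the fragment above):

import torch
import torch.nn.functional as F

bsz, mod_dim, n_mods = 2, 4, 3
gates = torch.randn(bsz, n_mods)             # unnormalized gate scores
uni_emb = torch.randn(mod_dim, bsz, n_mods)  # one embedding per modality

# one-hot selection per sample, still differentiable in autograd
gates_onehot = F.gumbel_softmax(gates, tau=0.1, hard=True, dim=-1)  # [bsz, n_mods]
select_emb = torch.mul(gates_onehot, uni_emb).sum(dim=-1)           # [mod_dim, bsz]

print(gates_onehot)
print(select_emb.shape)  # torch.Size([4, 2])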
How to Use Gumbel-Softmax?
In PyTorch, we can use torch.nn.functional.gumbel_softmax(). It is defined as:
torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)
Parameters:
- logits (Tensor) – […, num_features] unnormalized log probabilities.
- tau (float) – non-negative scalar temperature
- hard (bool) – if True, the returned samples will be discretized as one-hot vectors, but will be differentiated as if it is the soft sample in autograd
- dim (int) – A dimension along which softmax will be computed. Default: -1.
From the source code, we can find:
\(\text{logits}_i = \log(\pi_i)\)
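So if you already have normalized probabilities \(\pi_i\), you can pass their log as the logits argument. A small sketch (my own example):

import torch
import torch.nn.functional as F

pi = torch.tensor([[0.2, 0.5, 0.3]])  # categorical probabilities, every entry > 0
logits = torch.log(pi)                # logits = log(pi_i)
y = F.gumbel_softmax(logits, tau=1.0, dim=-1)
print(y)  # a random soft sample on the simplex, each row sums to 1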
Here is an example:
import torch
import torch.nn.functional as F

logits = torch.randn(5, 10)

# sample a soft categorical using the reparametrization trick:
x = F.gumbel_softmax(logits, tau=0.1, hard=False, dim=-1)
print(x.shape)
print(x)
# each row sums to 1
print(x.sum(dim=-1))

# sample a hard (one-hot) categorical, still differentiable in autograd:
x_hard = F.gumbel_softmax(logits, tau=0.1, hard=True, dim=-1)
print(x_hard.shape)
print(x_hard)
Output:
torch.Size([5, 10])
tensor([[5.2353e-11, 1.6716e-13, 6.1010e-13, 6.5021e-02, 4.5803e-24, 1.6481e-13,
         1.1888e-19, 9.3498e-01, 3.5390e-20, 5.6035e-22],
        [5.2005e-05, 2.2679e-07, 9.9995e-01, 8.2990e-09, 1.3104e-15, 2.1313e-06,
         6.7306e-11, 1.5186e-18, 5.4508e-16, 2.3432e-15],
        [8.5472e-10, 5.6161e-05, 9.1573e-01, 2.4018e-15, 8.4214e-02, 2.5155e-08,
         6.9891e-12, 8.9868e-16, 2.6720e-11, 1.0618e-09],
        [9.0234e-13, 3.9960e-14, 2.0421e-13, 1.0000e+00, 3.0450e-24, 7.6137e-07,
         1.9128e-15, 6.1501e-14, 8.3176e-10, 1.6146e-13],
        [7.0892e-19, 9.7890e-11, 9.7465e-08, 5.2453e-18, 3.5562e-19, 1.0000e+00,
         7.0251e-25, 9.6118e-11, 1.9668e-12, 9.9246e-13]])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
torch.Size([5, 10])
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])
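Finally, here is a short sketch (my own example) of why hard=True matters: the forward pass produces a one-hot vector, but gradients still reach the logits through the soft sample (the straight-through trick), which is exactly what argmax() cannot do.

import torch
import torch.nn.functional as F

logits = torch.randn(5, 10, requires_grad=True)
values = torch.randn(10)  # hypothetical per-class values, just to build a loss

x_hard = F.gumbel_softmax(logits, tau=0.1, hard=True, dim=-1)  # one-hot forward pass
loss = (x_hard * values).sum()
loss.backward()

print(logits.grad.shape)  # torch.Size([5, 10]): gradients flow back to the logits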