Understand Gumbel-Softmax Distribution with Examples

December 17, 2024

In this tutorial, we will use some examples to show what the Gumbel-Softmax distribution is and how to use it.

Gumbel-Softmax Distribution

It is defined as:

\(y_i = \dfrac{\exp((\log(\pi_i) + g_i)/\tau)}{\sum_{j=1}^{k}\exp((\log(\pi_j) + g_j)/\tau)}, \quad i = 1, \dots, k\)

Here:

\(\pi_i\): the probability of class \(i\) (out of \(k\) classes) in a categorical distribution. It must be larger than 0.

\(g_i\): a sample drawn from Gumbel(0, 1)

\(\tau\): a temperature parameter

To understand the effect of \(\tau\), you can read the two articles below (a short sketch follows them):

An Explanation of Softmax Function with Hyperparameter

LLM Temperature Explained
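
To get a quick intuition before reading those articles: a small \(\tau\) pushes each sample close to a one-hot vector, while a large \(\tau\) pushes it toward a uniform distribution. Here is a minimal sketch (the probabilities [0.1, 0.3, 0.6] are just illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.log(torch.tensor([0.1, 0.3, 0.6]))  # logits = log(pi)

# small tau -> nearly one-hot samples; large tau -> nearly uniform samples
for tau in [0.1, 1.0, 10.0]:
    y = F.gumbel_softmax(logits, tau=tau, hard=False, dim=-1)
    print(tau, y)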

How to Create \(g_i\)?

\(g_i\) is a sample drawn from Gumbel(0, 1). The Gumbel(0, 1) distribution can be sampled using inverse transform sampling: draw \(u \sim \text{Uniform}(0, 1)\) and compute \(g = -\log(-\log(u))\).

Here is an example:

import torch

def sample_gumbel(shape, eps=1e-20):
    # inverse transform sampling: u ~ Uniform(0, 1), g = -log(-log(u))
    U = torch.rand(shape)
    if torch.cuda.is_available():
        U = U.cuda()
    return -torch.log(-torch.log(U + eps) + eps)
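
Combining sample_gumbel() with the definition above, here is a minimal sketch of drawing a full Gumbel-Softmax sample by hand. The helper name gumbel_softmax_sample() is our own (not part of PyTorch), and it assumes \(\text{logits} = \log(\pi_i)\):

import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, tau=1.0):
    # logits = log(pi_i); add Gumbel(0, 1) noise, then apply a tempered softmax
    g = sample_gumbel(logits.shape).to(logits.device)
    return F.softmax((logits + g) / tau, dim=-1)

# usage with illustrative probabilities pi = [0.2, 0.5, 0.3]
logits = torch.log(torch.tensor([[0.2, 0.5, 0.3]]))
print(gumbel_softmax_sample(logits, tau=0.5))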

Why Use Gumbel-Softmax?

The most important reason is that the argmax() function does not support backpropagation and gradient computation.

TensorFlow tf.argmax() does not Support Backprop and Gradient Operation – TensorFlow Tutorial

With Gumbel-Softmax, we can select the embedding with the maximum probability while keeping the operation differentiable.

For example:

import torch
from torch import nn
import torch.nn.functional as F

# model: one gate network per modality; gumbel-softmax picks one modality
class MyModel(nn.Module):
    def __init__(self, dim_m, dim_gate, n_mods):
        super().__init__()
        self.n_mods = n_mods
        self.net_gate = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(dim_m, dim_gate),
                nn.LeakyReLU(),
                nn.Linear(dim_gate, 1),
                nn.Sigmoid()
            ) for _ in range(n_mods)]
        )

    def forward(self, x):
        ...  # some operations on x that produce gates and uni_emb
        # gates.shape = [bsz, n_mods], e.g., gates[i] = [0.4, 0.7, 0.2]
        # hard=True returns a one-hot vector, e.g., gates_onehot[i] = [0, 1, 0]
        gates_onehot = F.gumbel_softmax(gates, tau=0.1, hard=True, dim=-1)
        # uni_emb.shape = [bsz, dim_m, n_mods]: one embedding per modality
        select_emb = torch.mul(gates_onehot.unsqueeze(1), uni_emb)
        # keep only the selected modality embedding: shape = [bsz, dim_m]
        select_emb = select_emb.sum(dim=-1)
        return select_emb
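
To verify the gradient claim above (argmax() returns an integer index that blocks backprop, while gumbel_softmax() with hard=True returns a one-hot tensor that is still differentiated like the soft sample), here is a minimal, self-contained sketch:

import torch
import torch.nn.functional as F

logits = torch.randn(4, requires_grad=True)

# argmax() only gives an integer index: no gradient can flow back to logits
idx = torch.argmax(logits)
print(idx.requires_grad)  # False

# gumbel_softmax(hard=True) gives a one-hot float tensor that is
# differentiated as if it were the soft sample (straight-through trick)
onehot = F.gumbel_softmax(logits, tau=0.1, hard=True, dim=-1)
loss = (onehot * torch.arange(4.0)).sum()
loss.backward()
print(logits.grad)  # non-zero gradients reach logits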

How to Use Gumbel-Softmax?

In PyTorch, we can use torch.nn.functional.gumbel_softmax(). It is defined as:

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)

Parameters:

  • logits (Tensor) – […, num_features] unnormalized log probabilities.
  • tau (float) – non-negative scalar temperature
  • hard (bool) – if True, the returned samples will be discretized as one-hot vectors, but will be differentiated as if it is the soft sample in autograd
  • dim (int) – A dimension along which softmax will be computed. Default: -1.


From the source code, we can find:

\(\text{logits}_i = \log(\pi_i)\)
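
That is, if you start from probabilities \(\pi\) rather than unnormalized log probabilities, take their log before calling gumbel_softmax(). A minimal sketch with illustrative probabilities:

import torch
import torch.nn.functional as F

pi = torch.tensor([[0.1, 0.2, 0.7]])  # categorical probabilities, each > 0
logits = torch.log(pi)                # logits = log(pi_i)
y = F.gumbel_softmax(logits, tau=0.5, hard=False, dim=-1)
print(y)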

Here is a fuller example:

import torch
import torch.nn.functional as F

logits = torch.randn(5, 10)

# Sample soft categorical using the reparametrization trick:
x = F.gumbel_softmax(logits, tau=0.1, hard=False, dim=-1)
print(x.shape)
print(x)
# each row of the soft sample sums to 1
print(x.sum(dim=-1))

# Sample hard categorical (one-hot) using the straight-through trick:
y = F.gumbel_softmax(logits, tau=0.1, hard=True, dim=-1)
print(y.shape)
print(y)

Output:

torch.Size([5, 10])
tensor([[5.2353e-11, 1.6716e-13, 6.1010e-13, 6.5021e-02, 4.5803e-24, 1.6481e-13,
         1.1888e-19, 9.3498e-01, 3.5390e-20, 5.6035e-22],
        [5.2005e-05, 2.2679e-07, 9.9995e-01, 8.2990e-09, 1.3104e-15, 2.1313e-06,
         6.7306e-11, 1.5186e-18, 5.4508e-16, 2.3432e-15],
        [8.5472e-10, 5.6161e-05, 9.1573e-01, 2.4018e-15, 8.4214e-02, 2.5155e-08,
         6.9891e-12, 8.9868e-16, 2.6720e-11, 1.0618e-09],
        [9.0234e-13, 3.9960e-14, 2.0421e-13, 1.0000e+00, 3.0450e-24, 7.6137e-07,
         1.9128e-15, 6.1501e-14, 8.3176e-10, 1.6146e-13],
        [7.0892e-19, 9.7890e-11, 9.7465e-08, 5.2453e-18, 3.5562e-19, 1.0000e+00,
         7.0251e-25, 9.6118e-11, 1.9668e-12, 9.9246e-13]])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
torch.Size([5, 10])
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])
