# Differentiable Sampling and Argmax

WIP, last updated: 2019.12.6

## Introduction

**Softmax** is a commonly used function for turning an **unnormalized log probability** into a normalized probability (or **categorical distribution**).

$\mathbf{\pi} = \text{softmax}(\mathbf{o}) = \frac{e^{\mathbf{o}}}{\sum_{j} e^{o_j}}, \quad o_j \in (-\infty, +\infty)$
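A minimal pure-Python sketch of the formula above (the function name is an assumption, not from a library). Subtracting $\max_j o_j$ before exponentiating is a standard numerical-stability step that the formula itself does not require. It cancels in the ratio, so the result is unchanged.

```python
import math

def softmax(o):
    """Map unnormalized log probabilities o to a categorical distribution pi."""
    m = max(o)                            # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in o]   # e^{o_j - m}; the shift cancels in the ratio
    total = sum(exps)
    return [e / total for e in exps]

pi = softmax([2.0, 1.0, 0.1])
print(pi)  # approximately [0.659, 0.242, 0.099]
```

Without the shift, `math.exp(1000.0)` raises `OverflowError`. With it, `softmax([1000.0, 1000.0])` returns `[0.5, 0.5]`.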

If $\mathbf{o}$ is the output of a neural network before the softmax, we call $\mathbf{o}$ the **unnormalized log probability** (also known as the **logits**).

After the softmax, we usually either **sample** from this categorical distribution or take the **argmax** to select an index. However, neither **sampling** nor **argmax** is **differentiable**: both output a discrete index, so gradients cannot flow back to $\mathbf{o}$ through them.

Researchers have proposed several methods to work around this. I discuss them below.

## Sampling

I will introduce Gumbel Softmax [1611.01144], which makes the **sampling** procedure differentiable.

### Gumbel Max

First, we need to introduce **Gumbel Max**. In short, Gumbel Max is a trick that uses the Gumbel distribution to sample from a categorical distribution.

Say we want to sample from a categorical distribution $\mathbf{\pi}$. The usual way is to use $\mathbf{\pi}$ to partition $[0, 1]$ into consecutive intervals of lengths $\pi_1, \pi_2, \dots$, draw $u \sim \text{Uniform}[0, 1]$, and return the index of the interval that $u$ falls into.
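For comparison with Gumbel Max, here is a minimal pure-Python sketch of this interval-based sampler (the function name is hypothetical). The cumulative sums of $\mathbf{\pi}$ mark the interval boundaries, and a binary search finds the interval containing $u$.

```python
import bisect
import itertools
import random

def sample_categorical(pi, rng=random):
    """Draw index i with probability pi[i] by locating a uniform draw in [0, 1)."""
    cdf = list(itertools.accumulate(pi))   # right endpoints of the intervals
    u = rng.random()                        # u ~ Uniform[0, 1)
    i = bisect.bisect_right(cdf, u)         # first interval whose right endpoint exceeds u
    return min(i, len(pi) - 1)              # guard against sum(pi) being slightly below 1

pi = [0.2, 0.5, 0.3]
rng = random.Random(0)
n = 100_000
counts = [0] * len(pi)
for _ in range(n):
    counts[sample_categorical(pi, rng)] += 1
freqs = [c / n for c in counts]
print(freqs)  # close to pi
```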

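The Gumbel Max trick provides an alternative way of doing this. It uses the **Reparameterization Trick**: the randomness comes from noise that does not depend on $\mathbf{o}$, so the stochastic node is kept off the path that backpropagation has to traverse.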

#### Proof

which is exactly a softmax probability. QED.

**Reference:** https://lips.cs.princeton.edu/the-gumbel-max-trick-for-discrete-distributions/

### Gumbel Softmax

Notice that Gumbel Max still contains an argmax, so it is still not differentiable. Therefore, we use a softmax function to approximate the argmax.

Note that the output of the Gumbel Softmax function is a vector that sums to 1 and resembles a one-hot vector, but is not one. So far, then, this does not actually replace the argmax.

To get an actual one-hot vector, we need the **Straight-Through (ST) Gumbel Trick**. Let's look directly at the PyTorch implementation of Gumbel Softmax, `torch.nn.functional.gumbel_softmax` (we use the hard mode, `hard=True`; the soft mode does not produce a pure one-hot vector).

In the forward pass, the code uses an argmax to get an actual one-hot vector `y_hard`, and then returns `ret = y_hard - y_soft.detach() + y_soft`. Here `y_hard` has no gradient. Subtracting `y_soft.detach()` and adding `y_soft` leaves the forward value unchanged, but lets the gradient flow through `y_soft`.
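PyTorch's autograd derives the gradient of `ret` automatically. To make the mechanism explicit without a framework, here is a dependency-free sketch that does both passes by hand (the function names are hypothetical, not PyTorch API). It assumes some loss whose gradient with respect to `ret` is `grad_out`. The forward pass yields `y_hard`, and the backward pass treats `y_hard` and `y_soft.detach()` as constants, so only the softmax Jacobian of `y_soft` contributes.

```python
import math

def softmax(v):
    m = max(v)                                   # stability shift; cancels in the ratio
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def st_forward(scores, tau=1.0):
    """Forward pass of ret = y_hard - y_soft.detach() + y_soft.

    `scores` are the perturbed logits o_i + g_i. Returns (ret, y_soft); ret equals
    the one-hot y_hard up to floating-point rounding.
    """
    y_soft = softmax([s / tau for s in scores])
    k = max(range(len(y_soft)), key=lambda i: y_soft[i])          # argmax index
    y_hard = [1.0 if i == k else 0.0 for i in range(len(y_soft))]
    ret = [h - s + s for h, s in zip(y_hard, y_soft)]              # value: y_hard
    return ret, y_soft

def st_backward(y_soft, grad_out, tau=1.0):
    """Backward pass: only y_soft carries gradient. Applies the softmax Jacobian
    dy_i/ds_j = y_i * (delta_ij - y_j) / tau to grad_out = dL/d(ret),
    returning dL/d(scores)."""
    dot = sum(g * y for g, y in zip(grad_out, y_soft))
    return [y * (g - dot) / tau for g, y in zip(grad_out, y_soft)]

scores = [1.2, 0.3, -0.5]   # e.g. o_i + g_i for one draw
ret, y_soft = st_forward(scores, tau=0.5)
grad = st_backward(y_soft, grad_out=[1.0, 0.0, 0.0], tau=0.5)
print(ret)   # one-hot, up to rounding
print(grad)  # non-zero: the gradient reaches the scores through y_soft
```

The forward value is discrete while the backward signal is the smooth softmax gradient, which is exactly the split that `y_hard - y_soft.detach() + y_soft` produces under autograd.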

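So, in the end, we get a pure one-hot vector in the forward pass and a gradient in the backward pass, which **makes the sampling procedure differentiable**. Strictly speaking, this gradient is that of the soft sample `y_soft` rather than of `y_hard`, so it is a biased estimate.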

(Source: Eric Jang, https://blog.evjang.com/2016/11/tutorial-categorical-variational.html)

## Argmax

How can we make argmax differentiable?

Intuitively, the **Straight-Through Trick** should also apply to softmax + argmax (or soft-argmax + argmax). I am still not sure; this needs more digging in the literature.

Some works introduce the soft-argmax function. It does not make argmax itself differentiable; instead, it uses a continuous function to approximate the softmax + argmax procedure.

## Discussion

### Goal

**softmax + argmax** is used for classification: we only want the index with the highest probability. **gumbel softmax + argmax** is used for sampling: we may want to sample an index that does not have the highest probability.

### Deterministic

**softmax + argmax** is deterministic: it always returns the index with the highest probability. **gumbel softmax + argmax** is stochastic: we first need to sample noise from a Gumbel distribution.
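To make the contrast concrete, here is a small pure-Python sketch (function names are hypothetical). Repeating softmax + argmax on the same logits always gives the same index, while the Gumbel-perturbed argmax returns different indices across draws.

```python
import math
import random

def argmax(v):
    return max(range(len(v)), key=lambda i: v[i])

def softmax_argmax(o):
    """Deterministic: softmax preserves order, so this always equals argmax(o)."""
    m = max(o)
    e = [math.exp(x - m) for x in o]
    s = sum(e)
    return argmax([x / s for x in e])

def gumbel_argmax(o, rng=random):
    """Stochastic: add Gumbel(0, 1) noise to each logit before the argmax."""
    g = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in o]  # clamp keeps logs finite
    return argmax([oi + gi for oi, gi in zip(o, g)])

rng = random.Random(0)
o = [1.0, 0.5, 0.0]
det = {softmax_argmax(o) for _ in range(1000)}
sto = {gumbel_argmax(o, rng) for _ in range(1000)}
print(det)  # {0}
print(sto)  # lower-probability indices show up as well
```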

### Output vector

**softmax** and **gumbel softmax** both output a vector that sums to 1. **softmax** outputs a *normalized probability distribution*.

The **Straight-Through Trick** can in principle be applied to both **softmax + argmax** and **gumbel softmax + argmax**, which would make both of them differentiable. (?)

## Reference

Gumbel Softmax [1611.01144]

Concrete Distribution (Gumbel Softmax Distribution) [1611.00712]

Eric Jang's blog: https://blog.evjang.com/2016/11/tutorial-categorical-variational.html

PyTorch Implementation of Gumbel Softmax: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.gumbel_softmax

$y = \arg \max_{i} (o_i +g_i)$

where $g_i \sim \text{Gumbel}(0, 1)$, which can be sampled by $-\log(-\log(\text{Uniform}[0, 1]))$. We can prove that $y$ is distributed according to $\mathbf{\pi}$.
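A minimal pure-Python sketch of Gumbel Max sampling (helper names are hypothetical, and the clamp on $u$ is an added assumption that keeps both logarithms finite). Over many draws, the empirical frequencies of $y$ should match $\text{softmax}(\mathbf{o})$.

```python
import math
import random

def sample_gumbel(rng=random):
    """g ~ Gumbel(0, 1) via g = -log(-log(u)), u ~ Uniform(0, 1)."""
    u = max(rng.random(), 1e-12)   # keep u strictly inside (0, 1)
    return -math.log(-math.log(u))

def gumbel_max_sample(o, rng=random):
    """Return y = argmax_i (o_i + g_i) with independent g_i ~ Gumbel(0, 1)."""
    perturbed = [oi + sample_gumbel(rng) for oi in o]
    return max(range(len(o)), key=lambda i: perturbed[i])

o = [2.0, 1.0, 0.1]
rng = random.Random(42)
n = 100_000
counts = [0] * len(o)
for _ in range(n):
    counts[gumbel_max_sample(o, rng)] += 1
freqs = [c / n for c in counts]
z = sum(math.exp(x) for x in o)
pi = [math.exp(x) / z for x in o]
print(freqs)  # empirical frequencies of y
print(pi)     # softmax(o), approximately [0.659, 0.242, 0.099]
```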

That is, $y = \arg\max_{i} (o_i + g_i)$, where $g_i \sim \text{Gumbel}(0, 1)$ is sampled as $-\log(-\log(\text{Uniform}[0, 1]))$, is distributed with $\pi_i = \text{softmax}(\mathbf{o})_i = \frac{e^{o_i}}{\sum_{j} e^{o_j}}$.

**Gumbel Distribution** (parameterized by location $\mu$ and scale $\beta > 0$; see Wikipedia):

**CDF:** $F(x; \mu, \beta) = e^{-e^{-(x-\mu)/\beta}}$

**PDF:** $f(x; \mu, \beta) = \frac{1}{\beta} e^{-(z+e^{-z})}, \quad z = \frac{x-\mu}{\beta}$

**Mean:** $\text{E}(X) = \mu+\gamma\beta$, where $\gamma \approx 0.5772$ is the Euler–Mascheroni constant.

**Quantile Function:** $Q(p) = \mu-\beta \log(-\log(p))$. The quantile function, also called the inverse CDF, is used to sample a random variable from a distribution given its CDF: if $u \sim \text{Uniform}(0, 1)$, then $Q(u)$ follows that distribution.
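These formulas translate directly into a sampler. The pure-Python sketch below (hypothetical helper names) samples $\text{Gumbel}(\mu, \beta)$ by inverse-transform sampling, passing $u \sim \text{Uniform}(0, 1)$ through $Q$, and compares the sample mean with $\mu + \gamma\beta$. Setting $\mu = 0, \beta = 1$ recovers the $-\log(-\log u)$ recipe used for Gumbel Max.

```python
import math
import random

EULER_GAMMA = 0.5772156649015329  # Euler-Mascheroni constant

def gumbel_cdf(x, mu=0.0, beta=1.0):
    """F(x; mu, beta) = exp(-exp(-(x - mu) / beta))."""
    return math.exp(-math.exp(-(x - mu) / beta))

def gumbel_quantile(p, mu=0.0, beta=1.0):
    """Inverse CDF: Q(p) = mu - beta * log(-log(p)) for p in (0, 1)."""
    return mu - beta * math.log(-math.log(p))

def sample_gumbel(mu=0.0, beta=1.0, rng=random):
    """Inverse-transform sampling: push a uniform draw through the quantile function."""
    u = max(rng.random(), 1e-12)   # stay strictly inside (0, 1)
    return gumbel_quantile(u, mu, beta)

mu, beta = 1.5, 2.0
rng = random.Random(0)
n = 200_000
mean = sum(sample_gumbel(mu, beta, rng) for _ in range(n)) / n
print(mean, mu + EULER_GAMMA * beta)  # sample mean vs. E(X) = mu + gamma * beta
```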

Since each perturbed logit $o_i + g_i \sim \text{Gumbel}(\mu=o_i, \beta=1)$, what we actually want to prove is that the index of the largest perturbed logit is distributed with $\pi_i = \frac{e^{o_i}}{\sum_{j} e^{o_j}}$.

With $\beta = 1$, the PDF and CDF of $\text{Gumbel}(\mu, 1)$ are:

$$\begin{aligned} f(x; \mu, 1) &= e^{-(x-\mu) - e^{-(x-\mu)}} \\ F(x; \mu, 1) &= e^{-e^{-(x-\mu)}} \end{aligned}$$

Let $x_i = o_i + g_i$ denote the perturbed logits. Then, conditioned on $x_i$, the probability that all other $x_{j \neq i}$ are less than $x_i$ is:

$\Pr(x_i ~\text{is the largest} \mid x_i, \{o_{j}\}) = \prod_{j \neq i} e^{-e^{-(x_i - o_j)}}$

We know the marginal distribution of $x_i$, so we can integrate it out to find the overall probability ($p(x) = \int_y p(x,y)\, dy = \int_y p(x|y) p(y)\, dy$). The third step substitutes $t = e^{-x_i}$:

$$\begin{aligned} \Pr(\text{$i$ is largest} \mid \{o_{j}\}) &= \int_{-\infty}^{\infty} e^{-(x_i-o_i)-e^{-(x_i-o_i)}} \times \prod_{j\neq i}e^{-e^{-(x_i-o_j)}} \,\mathrm{d}x_i \\ &= \int_{-\infty}^{\infty} e^{-x_i + o_i - e^{-x_i} \sum_{j} e^{o_j}} \,\mathrm{d}x_i \\ &= e^{o_i} \int_{0}^{\infty} e^{-t \sum_{j} e^{o_j}} \,\mathrm{d}t \\ &= \frac{e^{o_i}}{\sum_{j}e^{o_j}} \end{aligned}$$
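The derivation can be sanity-checked numerically. This pure-Python sketch (hypothetical helper names; the integration range and step count are assumptions chosen so the tails are negligible) evaluates the first integral, the density of $\text{Gumbel}(o_i, 1)$ times the product of the other CDFs, with the midpoint rule and compares it with $e^{o_i} / \sum_j e^{o_j}$.

```python
import math

def gumbel_pdf(x, mu):
    """PDF of Gumbel(mu, 1)."""
    z = x - mu
    return math.exp(-z - math.exp(-z))

def gumbel_cdf(x, mu):
    """CDF of Gumbel(mu, 1)."""
    return math.exp(-math.exp(-(x - mu)))

def prob_is_largest(o, i, lo=-20.0, hi=40.0, steps=60_000):
    """Midpoint rule for the integral of f(x; o_i, 1) * prod_{j != i} F(x; o_j, 1)."""
    dx = (hi - lo) / steps
    total = 0.0
    for k in range(steps):
        x = lo + (k + 0.5) * dx
        p = gumbel_pdf(x, o[i])
        for j, oj in enumerate(o):
            if j != i:
                p *= gumbel_cdf(x, oj)
        total += p * dx
    return total

o = [1.0, 2.0, 0.5]
z = sum(math.exp(v) for v in o)
numeric = [prob_is_largest(o, i) for i in range(len(o))]
closed_form = [math.exp(v) / z for v in o]
print(numeric)
print(closed_form)
```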

$y_i = \frac{e^{(o_i+g_i) / \tau}}{\sum_{j}e^{(o_j+g_j) / \tau}}$

where $\tau \in (0, \infty)$ is a temperature hyperparameter.
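A pure-Python sketch of drawing one Gumbel Softmax sample from the formula above. The function name is hypothetical; in PyTorch one would call `torch.nn.functional.gumbel_softmax` instead. Reusing the same random seed keeps the Gumbel noise fixed, so the loop isolates the effect of $\tau$: smaller temperatures give a sharper, more one-hot-like vector.

```python
import math
import random

def gumbel_softmax_sample(o, tau, rng=random):
    """y_i = exp((o_i + g_i) / tau) / sum_j exp((o_j + g_j) / tau), g_i ~ Gumbel(0, 1)."""
    g = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in o]  # clamp keeps logs finite
    scores = [(oi + gi) / tau for oi, gi in zip(o, g)]
    m = max(scores)                          # stability shift; cancels in the ratio
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

o = [2.0, 1.0, 0.1]
for tau in (5.0, 1.0, 0.1):
    y = gumbel_softmax_sample(o, tau, random.Random(0))  # same noise for every tau
    print(tau, [round(v, 3) for v in y])
```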

Finally, let's look at how $\tau$ affects the sampling procedure. The image below shows the sampling distribution (also called the Concrete Distribution [1611.00712]) and one random sample for different values of $\tau$.

When $\tau \rightarrow 0$, the softmax approaches an argmax and the Gumbel-Softmax distribution approaches the categorical distribution. During training, we keep $\tau > 0$ so that gradients can pass through the sample, then gradually anneal $\tau$ (but not all the way to 0, since the gradients would blow up).
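The text does not specify how $\tau$ is annealed. As a hypothetical illustration, the sketch below uses an exponential decay clamped at a floor and updated only every few steps, so that $\tau$ shrinks during training but never reaches 0. The function name and the values of `tau0`, `rate`, `tau_min`, and `every` are all assumptions.

```python
import math

def anneal_tau(step, tau0=1.0, rate=1e-4, tau_min=0.5, every=1000):
    """Hypothetical schedule: tau = max(tau_min, tau0 * exp(-rate * t)), where t is
    `step` rounded down to a multiple of `every`, so tau changes every `every` steps."""
    t = (step // every) * every
    return max(tau_min, tau0 * math.exp(-rate * t))

for step in (0, 999, 1000, 5000, 10_000, 100_000):
    print(step, round(anneal_tau(step), 4))
```

The floor `tau_min` is what implements "not all the way to 0".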

$\mathbf{\pi} = \text{soft-argmax}(\mathbf{o}) = \frac{e^{\beta \mathbf{o}}}{\sum_{j} e^{\beta o_j}}$

where a large $\beta$ makes $\mathbf{\pi}$ look very much like a one-hot vector.
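A minimal pure-Python sketch of the soft-argmax defined above (the function name is hypothetical). Increasing $\beta$ pushes the output toward the one-hot argmax vector while keeping it a smooth function of $\mathbf{o}$.

```python
import math

def soft_argmax(o, beta):
    """pi = softmax(beta * o): a continuous stand-in for the one-hot argmax vector."""
    scaled = [beta * x for x in o]
    m = max(scaled)                        # stability shift; cancels in the ratio
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

o = [2.0, 1.0, 0.1]
for beta in (1.0, 5.0, 50.0):
    print(beta, [round(p, 4) for p in soft_argmax(o, beta)])
```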

**gumbel softmax** outputs a *sample* that looks more like a one-hot vector (how much more is controlled by $\tau$).