# Differentiable Sampling and Argmax

WIP, last updated: 2019.12.6

## Introduction

**Softmax** is a commonly used function for turning **unnormalized log probabilities** into a normalized probability distribution (a **categorical distribution**).

Say $\mathbf{o}$ is the output of a neural network before softmax; we call $\mathbf{o}$ the **unnormalized log probabilities**.

After softmax, we usually **sample** from this categorical distribution, or take an **argmax** to select an index. However, one can notice that neither **sampling** nor **argmax** is **differentiable**.

Researchers have proposed several techniques to make these operations differentiable. I am going to discuss them here.

## Sampling

I will introduce Gumbel Softmax [1611.01144], which makes the **sampling** procedure differentiable.

### Gumbel Max

First, we need to introduce **Gumbel Max**. In short, Gumbel Max is a trick that uses the Gumbel distribution to sample from a categorical distribution.

Say we want to sample from a categorical distribution $\mathbf{\pi}$. The usual way of doing this is to use $\mathbf{\pi}$ to partition $[0, 1]$ into intervals, sample $u \sim \text{Uniform}(0, 1)$, and see which interval $u$ falls into.

The Gumbel Max trick provides an alternative way of doing this. It uses the **Reparameterization Trick** to avoid the stochastic node during backpropagation:

$$y = \arg \max_{i} (o_i + g_i)$$

where $g_i \sim \text{Gumbel}(0, 1)$, which can be sampled by $-\log(-\log(\text{Uniform}(0, 1)))$. We can prove that $y$ is distributed according to $\mathbf{\pi} = \text{softmax}(\mathbf{o})$.
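This trick can be checked numerically. The sketch below (with arbitrary logits) compares the empirical frequencies of $\arg \max_{i} (o_i + g_i)$ against the softmax probabilities:

```python
import numpy as np

rng = np.random.default_rng(0)

# Arbitrary unnormalized log probabilities (logits) for 3 classes.
o = np.array([1.0, 0.5, -1.0])
pi = np.exp(o) / np.exp(o).sum()  # target softmax probabilities

# Gumbel Max: add Gumbel(0, 1) noise to the logits and take the argmax.
n = 100_000
g = -np.log(-np.log(rng.uniform(size=(n, 3))))
samples = np.argmax(o + g, axis=1)

# Empirical frequencies should approach the softmax probabilities.
freqs = np.bincount(samples, minlength=3) / n
print(pi, freqs)
```

With enough samples, the two vectors match to within sampling noise.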

**Prove that**

$y = \arg \max_{i} (o_i + g_i)$, where $g_i \sim \text{Gumbel}(0, 1)$ (which can be sampled by $-\log(-\log(\text{Uniform}(0, 1)))$), is distributed according to $\pi_i = \text{softmax}(\mathbf{o})_i = \frac{e^{o_i}}{\sum_{j} e^{o_j}}$.

**Prerequisites**

**Gumbel Distribution** (parameterized by location $\mu$ and scale $\beta>0$) (Wikipedia)
**CDF:** $F(x; \mu, \beta) = e^{-e^{-(x-\mu)/\beta}}$
**PDF:** $f(x; \mu, \beta) = \frac{1}{\beta} e^{-(z+e^{-z})}, \quad z = \frac{x-\mu}{\beta}$
**Mean:** $\text{E}(X) = \mu+\gamma\beta$, where $\gamma \approx 0.5772$ is the Euler–Mascheroni constant.
**Quantile Function:** $Q(p) = \mu-\beta \log(-\log(p))$ (the quantile function, also called the inverse CDF, is used to sample random variables from a distribution given its CDF)
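As a quick sanity check, we can sample through the quantile function and compare the sample mean against $\mu + \gamma\beta$ (a sketch with arbitrary $\mu$ and $\beta$):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, beta = 2.0, 1.5   # arbitrary location and scale
gamma = 0.5772        # Euler-Mascheroni constant (truncated)

# Inverse-CDF sampling: x = Q(p) = mu - beta * log(-log(p)), p ~ Uniform(0, 1)
p = rng.uniform(size=500_000)
x = mu - beta * np.log(-np.log(p))

print(x.mean(), mu + gamma * beta)  # the two should be close
```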

#### Proof

We actually want to prove that taking the argmax over samples $x_i \sim \text{Gumbel}(\mu=o_i, \beta=1)$ selects index $i$ with probability $\pi_i = \frac{e^{o_i}}{\sum_{j} e^{o_j}}$.

We can find that $\text{Gumbel}(\mu=o_i, \beta=1)$ has the following PDF and CDF:

$$f_i(x) = e^{-(x-o_i)-e^{-(x-o_i)}}, \qquad F_i(x) = e^{-e^{-(x-o_i)}}$$

Then, given a sample $x$ from the $i$-th Gumbel, the probability that all other $o_j + g_j$ ($j \neq i$) are less than $x$ is:

$$\prod_{j \neq i} F_j(x) = \prod_{j \neq i} e^{-e^{-(x-o_j)}}$$

We know the marginal density of $x$ and the conditional probability above, so we can integrate $x$ out to find the overall probability ($p(x) = \int_y p(x,y) dy = \int_y p(x|y) p(y) dy$):

$$P(y = i) = \int_{-\infty}^{\infty} f_i(x) \prod_{j \neq i} F_j(x) \, dx = \int_{-\infty}^{\infty} e^{-(x-o_i)} \, e^{-\sum_j e^{-(x-o_j)}} \, dx$$

Substituting $u = e^{-x}$ (so $-du = e^{-x} dx$) and writing $Z = \sum_j e^{o_j}$, so that $\sum_j e^{-(x-o_j)} = Zu$:

$$P(y = i) = e^{o_i} \int_{0}^{\infty} e^{-Zu} \, du = \frac{e^{o_i}}{\sum_j e^{o_j}}$$

which is exactly the softmax probability. QED.

**Reference:** https://lips.cs.princeton.edu/the-gumbel-max-trick-for-discrete-distributions/

### Gumbel Softmax

Notice that there is still an argmax in Gumbel Max, which makes it non-differentiable. Therefore, we use a softmax function to approximate the argmax procedure:

$$\pi_i = \frac{e^{(o_i+g_i)/\tau}}{\sum_{j} e^{(o_j+g_j)/\tau}}$$

where $\tau \in (0, \infty)$ is a temperature hyperparameter.

We note that the output of the Gumbel Softmax function here is a vector that sums to 1, which somewhat looks like a one-hot vector (but is not one). So far, this does not actually replace the argmax function.
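A minimal NumPy sketch of this relaxation (the logits and temperature below are arbitrary): the output sums to 1 but is generally not one-hot.

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_softmax(o, tau, rng):
    """Relaxed sample: softmax((o + g) / tau), with g ~ Gumbel(0, 1)."""
    g = -np.log(-np.log(rng.uniform(size=o.shape)))
    z = (o + g) / tau
    z = z - z.max()          # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

o = np.array([1.0, 0.5, -1.0])        # arbitrary logits
y = gumbel_softmax(o, tau=0.5, rng=rng)
print(y, y.sum())                      # sums to 1, but no entry is exactly 1
```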

To actually get a pure one-hot vector, we need to use the **Straight-Through (ST) Gumbel Trick**.
Let's look directly at the implementation of Gumbel Softmax in PyTorch
(we use the hard mode; soft mode does not produce a pure one-hot vector).
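PyTorch exposes this as `torch.nn.functional.gumbel_softmax(logits, tau, hard=True)`. Below is a minimal sketch of its straight-through core (the logits, temperature, and downstream loss are arbitrary illustrations, not the library's exact code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([1.0, 2.0, 0.5, 0.1], requires_grad=True)  # arbitrary logits

# Soft sample: softmax((logits + g) / tau), with g ~ Gumbel(0, 1)
g = -torch.log(-torch.log(torch.rand_like(logits)))
y_soft = F.softmax((logits + g) / 1.0, dim=-1)  # tau = 1.0

# Hard sample: a true one-hot vector from argmax
index = y_soft.argmax(dim=-1, keepdim=True)
y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)

# Straight-Through: forward value is y_hard, gradient flows through y_soft
ret = y_hard - y_soft.detach() + y_soft

# An arbitrary downstream loss: gradients reach `logits` despite the argmax
loss = (ret * torch.arange(4.0)).sum()
loss.backward()
print(ret.detach(), logits.grad)
```

The forward value of `ret` is exactly the one-hot `y_hard`, yet `logits.grad` is populated after `backward()`.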

When forwarding, the code uses an argmax to get an actual one-hot vector, and then computes `ret = y_hard - y_soft.detach() + y_soft`. Here `y_hard` has no gradient, and by subtracting `y_soft.detach()` and adding `y_soft` back, `ret` receives its gradient from `y_soft` without changing the forward value.

So eventually, we are able to get a pure one-hot vector in the forward pass and a gradient when backpropagating, which **makes the sampling procedure differentiable**.

Finally, let's look at how $\tau$ affects the sampling procedure. The image below shows the sampling distribution (also called the Concrete Distribution [1611.00712]) and one random sample instance under different values of the hyperparameter $\tau$.

When $\tau \rightarrow 0$, the softmax becomes an argmax and the Gumbel-Softmax distribution becomes the categorical distribution. During training, we let $\tau > 0$ to allow gradients to flow past the sample, then gradually anneal the temperature $\tau$ (but not completely to 0, as the gradients would blow up).
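The sharpening effect of $\tau$ can be checked directly: for fixed logits and one fixed Gumbel draw, lowering $\tau$ puts more mass on the largest entry (a sketch with arbitrary logits):

```python
import numpy as np

rng = np.random.default_rng(0)
o = np.array([1.0, 0.5, -1.0])                   # arbitrary logits
g = -np.log(-np.log(rng.uniform(size=o.shape)))  # one fixed Gumbel(0, 1) draw

def relaxed_sample(tau):
    """Gumbel Softmax with the fixed noise g, at temperature tau."""
    z = (o + g) / tau
    z = z - z.max()      # numerical stability
    e = np.exp(z)
    return e / e.sum()

# As tau decreases, the relaxed sample concentrates on a single entry.
for tau in (5.0, 1.0, 0.1):
    print(tau, relaxed_sample(tau).max())
```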

(Figure from Eric Jang, https://blog.evjang.com/2016/11/tutorial-categorical-variational.html)

## Argmax

How to make argmax differentiable?

Intuitively, the **Straight-Through Trick** is also applicable to softmax + argmax (or soft-argmax + argmax).
I am still not sure; this needs more digging into the literature.

Some have introduced the **soft-argmax** function. It does not actually make argmax differentiable, but uses a continuous function to approximate the softmax + argmax procedure:

$$\mathbf{\pi} = \text{softmax}(\beta \mathbf{o}), \qquad \text{soft-argmax}(\mathbf{o}) = \sum_{i} \pi_i \, i$$

where $\beta$ can be a large value to make $\mathbf{\pi}$ very much "look like" a one-hot vector.
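A minimal sketch of this idea (the logits and the value of $\beta$ are arbitrary): sharpen the softmax with $\beta$, then take the expected index.

```python
import numpy as np

def soft_argmax(o, beta):
    """Continuous surrogate for argmax: expected index under softmax(beta * o)."""
    z = beta * o
    z = z - z.max()          # numerical stability
    pi = np.exp(z) / np.exp(z).sum()
    return np.sum(pi * np.arange(len(o)))

o = np.array([0.1, 2.0, 0.5, 1.0])   # arbitrary logits; true argmax is index 1
print(soft_argmax(o, beta=10.0))     # close to 1.0 for large beta
```

The output is a real number, not an integer index, which is what makes it differentiable.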

## Discussion

**Goal**

**softmax + argmax** is used for classification; we only want the index with the highest probability. **gumbel softmax + argmax** is used for sampling; we may want to sample an index that does not have the highest probability.

**Deterministic**

**softmax + argmax** is deterministic: it gets the index with the highest probability. **gumbel softmax + argmax** is stochastic: we need to sample from a Gumbel distribution in the beginning.

**Output vector**

**softmax** and **gumbel softmax** both output a vector that sums to 1. **softmax** outputs a *normalized probability distribution*, while **gumbel softmax** outputs a *sample* that is somewhat more similar to a one-hot vector (this can be controlled by $\tau$).

The **Straight-Through Trick** can actually be applied to both **softmax + argmax** and **gumbel softmax + argmax**, which can make both of them differentiable. (?)

## Reference

Gumbel Softmax [1611.01144]

Concrete Distribution (Gumbel Softmax Distribution) [1611.00712]

Eric Jang official blog: https://blog.evjang.com/2016/11/tutorial-categorical-variational.html

PyTorch Implementation of Gumbel Softmax: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.gumbel_softmax
