# Differentiable Sampling and Argmax

## Introduction

**Softmax** is a commonly used function for turning a vector of **unnormalized log probabilities** into a normalized probability vector (a **categorical distribution**).

After softmax, we usually either **sample** from this categorical distribution or take an **argmax** to select an index. However, neither the **sampling** nor the **argmax** is **differentiable**.

Researchers have proposed several methods to work around this. I am going to discuss them here.

## Sampling

I will introduce Gumbel Softmax [1611.01144], which makes the **sampling** procedure differentiable.

### Gumbel Max

First, we need to introduce **Gumbel Max**. In short, Gumbel Max is a trick that uses the Gumbel distribution to sample from a categorical distribution.

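The Gumbel Max trick provides an alternative way of sampling from a categorical distribution: add independent Gumbel noise to each unnormalized log probability and take the argmax. It uses the **reparameterization trick** to move the randomness into the noise, so that backpropagation does not have to pass through a stochastic node.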
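
The procedure can be sketched in plain Python. This is a minimal illustration rather than a reference implementation; the helper names (`sample_gumbel`, `gumbel_max_sample`, `empirical_frequencies`) and the example logits are made up for this sketch. It draws many Gumbel Max samples so that the empirical frequencies can be compared with the softmax probabilities.

```python
import math
import random
from collections import Counter


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) sample via inverse transform: g = -log(-log(u))."""
    u = rng.random()
    while u == 0.0:  # random() may return 0.0, for which log(u) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(logits, rng):
    """Perturb each logit with Gumbel noise and return the argmax index."""
    perturbed = [x + sample_gumbel(rng) for x in logits]
    return max(range(len(logits)), key=perturbed.__getitem__)


def empirical_frequencies(logits, n_samples, seed=0):
    """Fraction of Gumbel Max samples that land on each index."""
    rng = random.Random(seed)
    counts = Counter(gumbel_max_sample(logits, rng) for _ in range(n_samples))
    return [counts[i] / n_samples for i in range(len(logits))]


logits = [1.0, 2.0, 0.5]  # hypothetical unnormalized log probabilities
freqs = empirical_frequencies(logits, n_samples=20000)
probs = softmax(logits)
```

With enough samples, `freqs` should be close to `probs`, which is what the proof below establishes exactly.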

#### Proof
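
A sketch of the standard argument, in notation assumed here rather than taken from the text above: unnormalized log probabilities $x_1, \dots, x_n$, independent noise $g_i \sim \mathrm{Gumbel}(0, 1)$ with CDF $F(g) = \exp(-e^{-g})$ and density $f(g) = \exp(-g - e^{-g})$, and perturbed values $z_i = x_i + g_i$. Index $k$ is the argmax exactly when every other $z_j$ lies below $z_k$, so integrating over the value $z$ taken by $z_k$ gives

$$
\begin{aligned}
P\bigl(\arg\max_i z_i = k\bigr)
&= \int_{-\infty}^{\infty} f(z - x_k) \prod_{j \neq k} F(z - x_j) \, dz \\
&= \int_{-\infty}^{\infty} \exp\bigl(x_k - z - e^{x_k - z}\bigr) \prod_{j \neq k} \exp\bigl(-e^{x_j - z}\bigr) \, dz \\
&= e^{x_k} \int_{-\infty}^{\infty} e^{-z} \exp\Bigl(-e^{-z} \sum_{j=1}^{n} e^{x_j}\Bigr) \, dz \\
&= e^{x_k} \int_{0}^{\infty} \exp\Bigl(-u \sum_{j=1}^{n} e^{x_j}\Bigr) \, du \qquad (u = e^{-z}) \\
&= \frac{e^{x_k}}{\sum_{j=1}^{n} e^{x_j}},
\end{aligned}
$$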

which is exactly a softmax probability. QED.

**Reference:** https://lips.cs.princeton.edu/the-gumbel-max-trick-for-discrete-distributions/

### Gumbel Softmax

Notice that there is still an argmax in Gumbel Max, which keeps it non-differentiable. Therefore, we use a softmax function to approximate this argmax step.

Note that the output of the Gumbel Softmax function is a vector that sums to 1. It looks somewhat like a one-hot vector, but it is not one. So far, this does not actually replace the argmax function.
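
As a rough illustration, the relaxed sample can be computed in plain Python. This is a sketch under assumptions: the function name `gumbel_softmax_sample` and the temperature parameter `tau` are chosen for this example (the temperature comes from the Gumbel Softmax paper, not from the text above), and the logits are hypothetical.

```python
import math
import random


def gumbel_softmax_sample(logits, tau, rng):
    """Relaxed sample: softmax((logits + Gumbel noise) / tau)."""
    noisy = []
    for x in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        gumbel = -math.log(-math.log(u))
        noisy.append((x + gumbel) / tau)
    m = max(noisy)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
logits = [1.0, 2.0, 0.5]  # hypothetical unnormalized log probabilities
soft = gumbel_softmax_sample(logits, tau=1.0, rng=rng)    # smooth vector
sharp = gumbel_softmax_sample(logits, tau=0.01, rng=rng)  # typically close to one-hot
```

Both outputs sum to 1. A small `tau` usually pushes the vector toward a one-hot shape, but the result is still a continuous vector and not a true one-hot sample.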

To actually get a pure one-hot vector, we need to use the **Straight-Through (ST) Gumbel Trick**.
Let's look directly at the implementation of Gumbel Softmax in PyTorch.
(We use the hard mode; the soft mode does not return a pure one-hot vector.)

In the forward pass, the code uses an argmax to get an actual one-hot vector, `y_hard`.
It then computes `ret = y_hard - y_soft.detach() + y_soft`. Here `y_hard` has no gradient. By subtracting `y_soft.detach()` and adding `y_soft`, the result takes its gradient from `y_soft` without changing the forward value.

So in the end we get a pure one-hot vector in the forward pass and a gradient in the backward pass, which **makes the sampling procedure differentiable**.
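
The straight-through step described above can be sketched in PyTorch as follows. This is an illustrative reimplementation based on that description, not a verbatim copy of the library source. The function name `st_gumbel_softmax` and the `weights` vector are made up for the example. In practice, `torch.nn.functional.gumbel_softmax(logits, tau, hard=True)` provides this behavior.

```python
import torch
import torch.nn.functional as F


def st_gumbel_softmax(logits, tau=1.0):
    """Straight-through Gumbel Softmax: one-hot forward, soft gradient backward."""
    # Gumbel(0, 1) noise, obtained as -log of Exponential(1) samples.
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = F.softmax((logits + gumbels) / tau, dim=-1)

    # Forward value: a hard one-hot vector built from the argmax of y_soft.
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)

    # (y_hard - y_soft.detach()) carries no gradient, so the gradient of the
    # result is that of y_soft, while the forward value equals y_hard.
    return y_hard - y_soft.detach() + y_soft


torch.manual_seed(0)
logits = torch.tensor([1.0, 2.0, 0.5], requires_grad=True)  # hypothetical logits
sample = st_gumbel_softmax(logits, tau=1.0)

# Any downstream loss works; unequal weights give a non-trivial gradient.
weights = torch.tensor([0.1, 0.2, 0.3])
loss = (sample * weights).sum()
loss.backward()
```

After `backward()`, `logits.grad` is populated even though `sample` is one-hot (up to floating-point rounding) in the forward pass.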

Source: Eric Jang, https://blog.evjang.com/2016/11/tutorial-categorical-variational.html

## Argmax

How to make argmax differentiable?

Intuitively, the **Straight-Through Trick** should also be applicable to softmax + argmax (or softargmax + argmax).
I am still not sure about this; it needs more digging in the literature.

Some have introduced the soft-argmax function. It does not actually make argmax differentiable, but uses a continuous function to approximate the softmax + argmax procedure.
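
One common form of soft-argmax is the expected index under a sharpened softmax. The sketch below assumes that form; the function name `soft_argmax`, the sharpness parameter `beta`, and the example scores are all hypothetical and not taken from a specific paper.

```python
import math


def soft_argmax(scores, beta=10.0):
    """Expected index under softmax(beta * scores): a continuous proxy for argmax."""
    scaled = [beta * s for s in scores]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return sum(i * e / total for i, e in enumerate(exps))


scores = [0.1, 0.3, 2.0]  # the true argmax is index 2
approx_index = soft_argmax(scores, beta=10.0)   # close to 2
uniform_index = soft_argmax(scores, beta=0.0)   # mean index under a uniform distribution
```

The output is a real number rather than an integer index. A larger `beta` moves it closer to the true argmax, while a smaller `beta` gives a smoother function of the scores.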

## Discussion

Goal

- **softmax + argmax** is used for classification: we only want the index with the highest probability.
- **gumbel softmax + argmax** is used for sampling: we may want to sample an index that does not have the highest probability.

Deterministic

- **softmax + argmax** is deterministic: it returns the index with the highest probability.
- **gumbel softmax + argmax** is stochastic: we first need to sample noise from a Gumbel distribution.

Output vector

- **softmax** and **gumbel softmax** both output a vector that sums to 1.
- **softmax** outputs a *normalized probability distribution*.

The **Straight-Through Trick** can actually be applied to both **softmax + argmax** and **gumbel softmax + argmax**, which can make both of them differentiable. (?)

## Reference

Gumbel Softmax [1611.01144]

Concrete Distribution (Gumbel Softmax Distribution) [1611.00712]

Eric Jang official blog: https://blog.evjang.com/2016/11/tutorial-categorical-variational.html

PyTorch Implementation of Gumbel Softmax: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.gumbel_softmax

