Differentiable Sampling and Argmax
WIP, last updated: 2019.12.6
Introduction
Softmax is a commonly used function for turning unnormalized log probabilities into a normalized probability (or categorical) distribution.
After softmax, we usually either sample from this categorical distribution or take an argmax to select the index with the highest probability. However, one can notice that neither the sampling nor the argmax operation is differentiable.
Researchers have proposed several techniques to make this possible. I am going to discuss them here.
Sampling
I will introduce Gumbel Softmax [1611.01144], which makes the sampling procedure differentiable.
Gumbel Max
First, we need to introduce Gumbel Max. In short, Gumbel Max is a trick that uses the Gumbel distribution to sample from a categorical distribution.
The Gumbel Max trick provides an alternative way of doing this. It uses the reparameterization trick: all stochasticity is moved into independent Gumbel noise, so the stochastic node no longer sits on the backpropagation path.
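As a minimal sketch in PyTorch (the function name `gumbel_max_sample` is my own; `logits` are assumed to be unnormalized log probabilities), the trick looks like:

```python
import torch

def gumbel_max_sample(logits):
    """Sample from Categorical(softmax(logits)) via the Gumbel Max trick."""
    # g_i = -log(-log(u_i)) with u_i ~ Uniform(0, 1) is a Gumbel(0, 1) sample.
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u))
    # The noise g is independent of the parameters; the argmax of the
    # perturbed logits is distributed as Categorical(softmax(logits)).
    return torch.argmax(logits + g, dim=-1)
```

Drawing many samples and counting frequencies recovers the softmax probabilities, which is exactly the claim proved below.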
Prove that

$$P\Big(\arg\max_i \, (g_i + \log \pi_i) = k\Big) = \frac{\pi_k}{\sum_j \pi_j} = \operatorname{softmax}(\log \pi)_k,$$

where $g_1, \dots, g_n$ are i.i.d. samples from $\mathrm{Gumbel}(0, 1)$.

Prerequisites

The standard Gumbel distribution has CDF $F(g) = \exp(-e^{-g})$ and density $f(g) = e^{-g} \exp(-e^{-g})$.

Proof

Let $z_i = \log \pi_i$. Index $k$ is the argmax iff $z_k + g_k \ge z_i + g_i$ for all $i \ne k$. Conditioning on $g_k = g$ and using the independence of the $g_i$,

$$P(k = \arg\max) = \int_{-\infty}^{\infty} f(g) \prod_{i \ne k} F(z_k + g - z_i)\, dg = \int_{-\infty}^{\infty} e^{-g} \exp\Big(-e^{-g} \sum_i e^{z_i - z_k}\Big)\, dg = \frac{e^{z_k}}{\sum_i e^{z_i}},$$

which is exactly a softmax probability. QED.
Reference: https://lips.cs.princeton.edu/the-gumbel-max-trick-for-discrete-distributions/
Gumbel Softmax
Notice that there is still an argmax in Gumbel Max, which keeps it non-differentiable. Therefore, we use a softmax function to approximate this argmax procedure.
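Concretely, the relaxation from [1611.01144] replaces the argmax with a temperature-controlled softmax:

```latex
y_i = \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}{\sum_{j} \exp\big((\log \pi_j + g_j)/\tau\big)},
```

where the $g_i$ are i.i.d. $\mathrm{Gumbel}(0,1)$ samples and $\tau$ is the temperature. As $\tau \to 0$, $y$ approaches the one-hot argmax vector; larger $\tau$ gives a smoother, more uniform vector.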
We note that the output of the Gumbel Softmax function here is a vector that sums to 1, which somewhat looks like a one-hot vector (but is not). So by itself, this does not actually replace the argmax function.
To actually get a pure one-hot vector, we need to use the Straight-Through (ST) Gumbel trick. Let's directly see an implementation of Gumbel Softmax in PyTorch (we use the hard mode; soft mode does not produce a pure one-hot vector).
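A minimal sketch of the hard mode, following the same logic as PyTorch's `torch.nn.functional.gumbel_softmax` (the `tau` and `hard` parameter names mirror that API; this is an illustrative reimplementation, not the library source):

```python
import torch
import torch.nn.functional as F

def gumbel_softmax(logits, tau=1.0, hard=True):
    # Sample Gumbel(0, 1) noise: -log(E) with E ~ Exponential(1).
    gumbels = -torch.empty_like(logits).exponential_().log()
    # Soft (relaxed) sample: softmax of the perturbed logits at temperature tau.
    y_soft = F.softmax((logits + gumbels) / tau, dim=-1)
    if hard:
        # Forward: build an actual one-hot vector from the argmax.
        index = y_soft.argmax(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
        # Straight-Through: one-hot value forward, gradient of y_soft backward.
        return y_hard - y_soft.detach() + y_soft
    return y_soft
```

The interesting part is the last line before the return in hard mode, which the next paragraph unpacks.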
When forwarding, the code uses an argmax to get an actual one-hot vector. It then uses `ret = y_hard - y_soft.detach() + y_soft`: `y_hard` has no grad, and by subtracting `y_soft.detach()` and adding `y_soft`, the result picks up the gradient of `y_soft` without modifying the forward value.
So eventually, we are able to get a pure one-hot vector in the forward pass and a gradient when backpropagating, which makes the sampling procedure differentiable.
(Figure from Eric Jang: https://blog.evjang.com/2016/11/tutorial-categorical-variational.html)
Argmax
How do we make argmax differentiable?
Intuitively, the Straight-Through trick should also be applicable to softmax + argmax (or soft-argmax + argmax). I am still not sure; this needs more digging in the literature.
Some works have introduced the soft-argmax function. It does not actually make argmax differentiable, but uses a continuous function to approximate the softmax + argmax procedure.
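As a sketch (the function name `soft_argmax` and the sharpness parameter `beta` are my own choices), soft-argmax computes a softmax-weighted average of the index positions, so the output is a continuous approximation of the argmax index:

```python
import torch

def soft_argmax(scores, beta=100.0):
    # Differentiable approximation of argmax over the last dimension:
    # a softmax-weighted average of index positions. Larger beta sharpens
    # the softmax, pushing the result toward the true argmax index.
    weights = torch.softmax(beta * scores, dim=-1)
    indices = torch.arange(scores.shape[-1], device=scores.device,
                           dtype=scores.dtype)
    return (weights * indices).sum(dim=-1)
```

Note the trade-off: the output is a real number near the winning index rather than an exact integer, which is why this approximates rather than replaces argmax.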
Discussion
Goal
softmax + argmax is used for classification, where we only want the index with the highest probability.
gumbel softmax + argmax is used for sampling, where we may want to sample an index that does not have the highest probability.
Deterministic
softmax + argmax is deterministic: it always returns the index with the highest probability.
gumbel softmax + argmax is stochastic: we need to sample from a Gumbel distribution at the beginning.
Output vector
softmax and gumbel softmax both output a vector that sums to 1.
softmax outputs a normalized probability distribution, while gumbel softmax outputs a relaxed (approximately one-hot) sample from that distribution.
The Straight-Through trick can seemingly be applied to both softmax + argmax and gumbel softmax + argmax, which would make both of them differentiable. (?)
Reference
Gumbel Softmax [1611.01144]
Concrete Distribution (Gumbel Softmax Distribution) [1611.00712]
Eric Jang official blog: https://blog.evjang.com/2016/11/tutorial-categorical-variational.html
PyTorch Implementation of Gumbel Softmax: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.gumbel_softmax