Differentiable Sampling and Argmax

WIP, last updated: 2019.12.6


Introduction

Softmax is a commonly used function for turning unnormalized log probabilities into a normalized probability distribution (a categorical distribution).

$$\mathbf{\pi} = \text{softmax}(\mathbf{o}) = \frac{e^{\mathbf{o}}}{\sum_{j} e^{o_j}}, \qquad o_j \in (-\infty, +\infty)$$

Say $\mathbf{o}$ is the output of a neural network before softmax; we call $\mathbf{o}$ the unnormalized log probabilities (logits).

After softmax, we usually sample from this categorical distribution, or take an argmax to select the index. However, one can notice that neither the sampling nor the argmax is differentiable.
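As a quick illustration (a minimal PyTorch sketch; the tensor shapes and names are my own), both operations return integer indices, so no gradient can flow back to the logits:

```python
import torch

logits = torch.randn(4, requires_grad=True)   # unnormalized log probabilities o
probs = torch.softmax(logits, dim=-1)         # normalized categorical distribution pi

idx_sample = torch.multinomial(probs, 1)      # stochastic: draw an index from pi
idx_argmax = probs.argmax(dim=-1)             # deterministic: index of the largest probability

# Both are integer tensors with no grad_fn, so backpropagation stops here.
print(idx_sample.requires_grad, idx_argmax.requires_grad)  # False False
```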

Researchers have proposed several techniques to make this possible. I am going to discuss them here.

Sampling

I will introduce Gumbel Softmax, which makes the sampling procedure differentiable.

Gumbel Max

First, we need to introduce Gumbel Max. In short, Gumbel Max is a trick that uses the Gumbel distribution to sample from a categorical distribution.

Say we want to sample from a categorical distribution $\mathbf{\pi}$. The usual way of doing this is to use $\mathbf{\pi}$ to partition $[0, 1]$ into intervals, sample $u \sim \text{Uniform}[0, 1]$, and check which interval $u$ falls into.
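A minimal sketch of this interval (inverse-CDF) sampling in PyTorch, with a made-up distribution `pi`:

```python
import torch

pi = torch.tensor([0.2, 0.5, 0.3])             # categorical distribution, sums to 1
u = torch.rand(1)                              # u ~ Uniform[0, 1]
index = torch.searchsorted(pi.cumsum(0), u)    # intervals: [0, 0.2), [0.2, 0.7), [0.7, 1.0]
print(index.item())                            # 0, 1, or 2, with probabilities 0.2, 0.5, 0.3
```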

The Gumbel Max trick provides an alternative way of doing this. It uses the reparameterization trick to avoid having a stochastic node on the backpropagation path.

Concretely, Gumbel Max samples

$$y = \arg \max_{i} (o_i + g_i)$$

where $g_i \sim \text{Gumbel}(0, 1)$, which can be sampled as $-\log(-\log(\text{Uniform}[0, 1]))$. We can prove that $y$ is distributed according to $\mathbf{\pi}$.

Prove that

$y = \arg \max_{i} (o_i + g_i)$, where $g_i \sim \text{Gumbel}(0, 1)$ (sampled as $-\log(-\log(\text{Uniform}[0, 1]))$), is distributed with $\pi_i = \text{softmax}(\mathbf{o})_i = \frac{e^{o_i}}{\sum_{j} e^{o_j}}$.

Prerequisites

The Gumbel distribution (parameterized by location $\mu$ and scale $\beta > 0$, see Wikipedia):

  • CDF: $F(x; \mu, \beta) = e^{-e^{-(x-\mu)/\beta}}$

  • PDF: $f(x; \mu, \beta) = \frac{1}{\beta} e^{-(z + e^{-z})}, \quad z = \frac{x-\mu}{\beta}$

  • Mean: $\text{E}(X) = \mu + \gamma\beta$, where $\gamma \approx 0.5772$ is the Euler–Mascheroni constant

  • Quantile function: $Q(p) = \mu - \beta \log(-\log(p))$ (the quantile function, also called the inverse CDF, is used to sample random variables from a distribution given its CDF)

Proof

Note that $o_i + g_i \sim \text{Gumbel}(\mu = o_i, \beta = 1)$, so we actually want to prove that the index of the largest of these Gumbel variables equals $i$ with probability $\pi_i = \frac{e^{o_i}}{\sum_{j} e^{o_j}}$.

$\text{Gumbel}(\mu = o_i, \beta = 1)$ has the following PDF and CDF:

$$\begin{aligned} f(x; \mu, 1) &= e^{-(x-\mu) - e^{-(x-\mu)}} \\ F(x; \mu, 1) &= e^{-e^{-(x-\mu)}} \end{aligned}$$

Let $x_i = o_i + g_i$ denote the perturbed value. The probability that all other $x_{j \neq i}$ are less than $x_i$ is:

$$\Pr(x_i \text{ is the largest} \mid x_i, \{o_j\}) = \prod_{j \neq i} e^{-e^{-(x_i - o_j)}}$$

We know the marginal distribution of $x_i$, so we can integrate it out to find the overall probability (using $p(x) = \int_y p(x \mid y)\, p(y)\, \mathrm{d}y$):

$$\begin{aligned} \Pr(i \text{ is largest} \mid \{o_j\}) &= \int e^{-(x_i - o_i) - e^{-(x_i - o_i)}} \times \prod_{j \neq i} e^{-e^{-(x_i - o_j)}} \, \mathrm{d}x_i \\ &= \int e^{-x_i + o_i - e^{-x_i} \sum_{j} e^{o_j}} \, \mathrm{d}x_i \\ &= \frac{e^{o_i}}{\sum_{j} e^{o_j}} \end{aligned}$$

which is exactly the softmax probability. QED.

(Reference: the Gumbel-Max trick blog posts listed in the Reference section below.)
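A small empirical sanity check of the trick (a sketch; the logits and sample count are arbitrary), comparing the Gumbel Max sampling frequencies against the softmax probabilities:

```python
import torch

o = torch.tensor([1.0, 2.0, 3.0])           # unnormalized log probabilities
n = 100_000

u = torch.rand(n, 3)
g = -torch.log(-torch.log(u))               # g ~ Gumbel(0, 1)
y = (o + g).argmax(dim=-1)                  # Gumbel Max samples

print(torch.bincount(y, minlength=3) / n)   # empirical frequencies
print(torch.softmax(o, dim=-1))             # tensor([0.0900, 0.2447, 0.6652])
```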

Gumbel Softmax

Notice that there is still an argmax in Gumbel Max, which keeps it non-differentiable. Therefore, we use a softmax function to approximate the argmax procedure:

$$y_i = \frac{e^{(o_i + g_i)/\tau}}{\sum_{j} e^{(o_j + g_j)/\tau}}$$

where $\tau \in (0, \infty)$ is a temperature hyperparameter.

We note that the output of the Gumbel Softmax function is a vector that sums to 1, which somewhat looks like a one-hot vector (but is not). So far, this does not actually replace the argmax function.

To actually get a pure one-hot vector, we need the Straight-Through (ST) Gumbel trick. Let's look directly at the PyTorch implementation of gumbel_softmax (from torch.nn.functional). We use the hard mode; the soft mode does not produce a pure one-hot vector.

import warnings

import torch


def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
    # type: (Tensor, float, bool, float, int) -> Tensor
    r"""
    Samples from the Gumbel-Softmax distribution (`Link 1`_  `Link 2`_) and optionally discretizes.

    Args:
      logits: `[..., num_features]` unnormalized log probabilities
      tau: non-negative scalar temperature
      hard: if ``True``, the returned samples will be discretized as one-hot vectors,
            but will be differentiated as if it is the soft sample in autograd
      dim (int): A dimension along which softmax will be computed. Default: -1.

    Returns:
      Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
      If ``hard=True``, the returned samples will be one-hot, otherwise they will
      be probability distributions that sum to 1 across `dim`.

    .. note::
      This function is here for legacy reasons, may be removed from nn.Functional in the future.

    .. note::
      The main trick for `hard` is to do  `y_hard - y_soft.detach() + y_soft`

      It achieves two things:
      - makes the output value exactly one-hot
      (since we add then subtract y_soft value)
      - makes the gradient equal to y_soft gradient
      (since we strip all other gradients)

    Examples::
        >>> logits = torch.randn(20, 32)
        >>> # Sample soft categorical using reparametrization trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
        >>> # Sample hard categorical using "Straight-through" trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=True)

    .. _Link 1:
        https://arxiv.org/abs/1611.00712
    .. _Link 2:
        https://arxiv.org/abs/1611.01144
    """
    if eps != 1e-10:
        warnings.warn("`eps` parameter is deprecated and has no effect.")

    gumbels = -torch.empty_like(logits).exponential_().log()  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret

In the forward pass, the code uses an argmax to get an actual one-hot vector y_hard. It then returns ret = y_hard - y_soft.detach() + y_soft: y_hard carries no gradient, and by subtracting y_soft.detach() and adding y_soft, the result takes its gradient from y_soft without changing the forward value.
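As a quick check of the hard mode (a small sketch; the 5-way logits and the weight vector `w` standing in for a downstream computation are arbitrary), the forward output is exactly one-hot while gradients still reach the logits:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, requires_grad=True)

y = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(y)                    # exactly one-hot, e.g. tensor([0., 1., 0., 0., 0.], grad_fn=...)

w = torch.randn(5)          # stand-in for whatever consumes the sample
loss = (w * y).sum()
loss.backward()
print(logits.grad)          # non-zero: gradients flow through the soft sample
```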

So eventually, we get a pure one-hot vector in the forward pass and a gradient in the backward pass, which makes the sampling procedure differentiable.

Finally, let's look at how $\tau$ affects the sampling procedure. The figure in Eric Jang's blog post (image from https://arxiv.org/abs/1611.01144) shows the sampling distribution (which is also called the Concrete distribution) and one random sample instance for different values of the hyperparameter $\tau$.

When $\tau \rightarrow 0$, the softmax becomes an argmax and the Gumbel-Softmax distribution becomes the categorical distribution. During training, we let $\tau > 0$ to allow gradients past the sample, then gradually anneal the temperature $\tau$ (but not completely to 0, as the gradients would blow up).

Argmax

How to make argmax differentiable?

Intuitively, the Straight-Through trick should also be applicable to softmax + argmax (or soft-argmax + argmax), as sketched below. I am still not sure; this needs more digging in the literature.
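For concreteness, a straight-through argmax on top of a plain softmax would look like the following sketch; it simply mirrors the hard branch of gumbel_softmax above with the Gumbel noise removed, and I have not verified how well it behaves in practice:

```python
import torch

def st_argmax(logits, dim=-1):
    # One-hot of the argmax in the forward pass, softmax gradient in the backward pass.
    y_soft = logits.softmax(dim)
    index = y_soft.argmax(dim, keepdim=True)
    y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft
```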

Some works have introduced the soft-argmax function. It does not actually make argmax differentiable, but uses a continuous function to approximate the softmax + argmax procedure:

$$\mathbf{\pi} = \text{soft-argmax}(\mathbf{o}) = \frac{e^{\beta \mathbf{o}}}{\sum_{j} e^{\beta o_j}}$$

where $\beta$ can be a large value to make $\mathbf{\pi}$ very much "look like" a one-hot vector.
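A minimal sketch (the logits and $\beta = 10$ are arbitrary) showing how a large $\beta$ pushes the output toward one-hot while staying differentiable:

```python
import torch

o = torch.tensor([0.5, 2.0, 1.0], requires_grad=True)
beta = 10.0

pi = torch.softmax(beta * o, dim=-1)
print(pi)            # approximately tensor([3.1e-07, 9.9995e-01, 4.5e-05]) -- nearly one-hot
pi.max().backward()  # gradients still flow back to o
print(o.grad)
```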

Discussion

  • Goal

    • softmax + argmax is used for classification; we only want the index with the highest probability.

    • gumbel softmax + argmax is used for sampling; we may want to sample an index that does not have the highest probability.

  • Deterministic

    • softmax + argmax is deterministic: it always returns the index with the highest probability.

    • gumbel softmax + argmax is stochastic: we sample from a Gumbel distribution at the start.

  • Output vector

    • softmax and gumbel softmax both output a vector that sums to 1.

    • softmax outputs a normalized probability distribution.

    • gumbel softmax outputs a sample that looks somewhat more like a one-hot vector (how much can be controlled by $\tau$).

  • The Straight-Through trick can actually be applied to both softmax + argmax and gumbel softmax + argmax, which would make both of them differentiable. (?)

Reference


  • Gumbel Softmax: Jang, Gu, Poole, "Categorical Reparameterization with Gumbel-Softmax" [1611.01144], https://arxiv.org/abs/1611.01144

  • Concrete Distribution (Gumbel Softmax distribution): Maddison, Mnih, Teh, "The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables" [1611.00712], https://arxiv.org/abs/1611.00712

  • Eric Jang's official blog: https://blog.evjang.com/2016/11/tutorial-categorical-variational.html

  • PyTorch implementation of Gumbel Softmax: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.gumbel_softmax

  • The Gumbel-Max trick for discrete distributions: https://lips.cs.princeton.edu/the-gumbel-max-trick-for-discrete-distributions/

  • Gumbel-max trick (Tim Vieira): https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/

  • Gumbel distribution, Euler–Mascheroni constant, Quantile function (Wikipedia)