Reparameterization trick

From Wikipedia, the free encyclopedia

The reparameterization trick (aka "reparameterization gradient estimator") is a technique used in statistical machine learning, particularly in variational inference, variational autoencoders, and stochastic optimization. It allows gradients to be computed efficiently through random variables, enabling parametric probability models to be optimized by stochastic gradient descent and reducing the variance of gradient estimators.

It was developed in the 1980s in operations research, under the names "pathwise gradients" and "stochastic gradients".[1][2] Its use in variational inference was proposed in 2013.[3]

Mathematics

Let z be a random variable with distribution qϕ(z), where ϕ is a vector containing the parameters of the distribution.

REINFORCE estimator

Consider an objective function of the form:

    L(\phi) = \mathbb{E}_{z \sim q_\phi(z)}[f(z)]

Without the reparameterization trick, estimating the gradient \nabla_\phi L(\phi) can be challenging, because the parameter appears in the distribution of the random variable itself. In more detail, we have to statistically estimate:

    \nabla_\phi L(\phi) = \nabla_\phi \int q_\phi(z) f(z) \, dz

The REINFORCE estimator, widely used in reinforcement learning and especially policy gradient methods,[4] uses the log-derivative identity \nabla_\phi q_\phi(z) = q_\phi(z) \nabla_\phi \ln q_\phi(z):

    \nabla_\phi L(\phi) = \int q_\phi(z) \, \nabla_\phi(\ln q_\phi(z)) \, f(z) \, dz = \mathbb{E}_{z \sim q_\phi(z)}[\nabla_\phi(\ln q_\phi(z)) \, f(z)]

This allows the gradient to be estimated by Monte Carlo sampling:

    \nabla_\phi L(\phi) \approx \frac{1}{N} \sum_{i=1}^{N} \nabla_\phi(\ln q_\phi(z_i)) \, f(z_i)

The REINFORCE estimator has high variance, and many methods have been developed to reduce it.[5]
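As a toy illustration (the choice of f(z) = z² and q_φ = 𝒩(μ, 1), and the sample size, are illustrative assumptions, not from the source), the REINFORCE estimator can be sketched in NumPy. For this objective the true gradient with respect to μ is 2μ:

```python
import numpy as np

# Score-function (REINFORCE) estimate of d/dmu E_{z ~ N(mu,1)}[z^2].
# For this toy f and q the true gradient is 2*mu, i.e. 2.0 for mu = 1.0.
rng = np.random.default_rng(0)
mu, N = 1.0, 100_000
z = rng.normal(mu, 1.0, size=N)

score = z - mu                    # grad_mu ln q(z) for a unit-variance Gaussian
grad_est = np.mean(score * z**2)  # Monte Carlo average of score * f(z)
print(grad_est)                   # close to 2.0, but a noisy estimate
```

Even with 100,000 samples, the estimate fluctuates noticeably across seeds, which is the high-variance behavior the text describes.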

Reparameterization estimator

The reparameterization trick expresses z as:

    z = g_\phi(\epsilon), \quad \epsilon \sim p(\epsilon)

Here, g_\phi is a deterministic function parameterized by \phi, and \epsilon is a noise variable drawn from a fixed distribution p(\epsilon). This gives:

    L(\phi) = \mathbb{E}_{\epsilon \sim p(\epsilon)}[f(g_\phi(\epsilon))]

Now, the gradient can be estimated as:

    \nabla_\phi L(\phi) = \mathbb{E}_{\epsilon \sim p(\epsilon)}[\nabla_\phi f(g_\phi(\epsilon))] \approx \frac{1}{N} \sum_{i=1}^{N} \nabla_\phi f(g_\phi(\epsilon_i))
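The pathwise estimator can be sketched for the same kind of toy objective (f(z) = z² with z ~ 𝒩(μ, 1); these choices and the sample size are illustrative assumptions, not from the source). Here g_μ(ε) = μ + ε, so the gradient passes through the sample deterministically:

```python
import numpy as np

# Pathwise (reparameterization) estimate of d/dmu E_{z ~ N(mu,1)}[z^2],
# using z = g_mu(eps) = mu + eps with eps ~ N(0, 1).
rng = np.random.default_rng(0)
mu, N = 1.0, 100_000
eps = rng.normal(size=N)

# d/dmu f(g_mu(eps)) = f'(mu + eps) * dg/dmu = 2*(mu + eps) * 1
grad_est = np.mean(2.0 * (mu + eps))
print(grad_est)  # close to the true gradient 2.0
```

For this objective the pathwise estimator has markedly lower per-sample variance than the REINFORCE estimator, which is the practical motivation for the trick.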

Examples

For some common distributions, the reparameterization trick takes specific forms:

Normal distribution: For z \sim \mathcal{N}(\mu, \sigma^2), we can use:

    z = \mu + \sigma \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1)

Exponential distribution: For z \sim \operatorname{Exp}(\lambda), we can use (by inverting the CDF):

    z = -\frac{1}{\lambda} \log(\epsilon), \quad \epsilon \sim \operatorname{Uniform}(0, 1)

Discrete distributions can be reparameterized via the Gumbel distribution (the Gumbel-softmax trick, or "concrete distribution").[6]
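The inverse-CDF reparameterization of the exponential distribution can be checked numerically; a minimal sketch (the value of λ and the sample size are illustrative assumptions):

```python
import numpy as np

# Inverse-CDF reparameterization of Exp(lam): z = -log(u)/lam, u ~ Uniform(0,1).
rng = np.random.default_rng(0)
lam, N = 2.0, 200_000
u = rng.uniform(size=N)
z = -np.log(u) / lam

mean_est = z.mean()                     # should approach E[z] = 1/lam = 0.5
# dz/dlam = log(u)/lam^2, so the pathwise gradient of E[z] w.r.t. lam
# is estimated by averaging it; the exact value is d(1/lam)/dlam = -1/lam^2.
grad_est = np.mean(np.log(u) / lam**2)
print(mean_est, grad_est)
```

Because the noise u is fixed under differentiation, the same draw of u serves both the sample and its parameter gradient.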

In general, any distribution whose CDF is differentiable with respect to its parameters can be reparameterized by inverting the multivariate CDF and then applying implicit differentiation. See [1] for an exposition and application to the Gamma, Beta, Dirichlet, and von Mises distributions.

Applications

Variational autoencoder

[Figure: The scheme of the reparameterization trick. The noise variable ε is injected into the latent space z as an external input, so that the gradient can be backpropagated without passing through a stochastic node during the update.]

[Figure: The scheme of a variational autoencoder after the reparameterization trick.]

In Variational Autoencoders (VAEs), the VAE objective function, known as the Evidence Lower Bound (ELBO), is given by:

    \mathrm{ELBO}(\phi, \theta) = \mathbb{E}_{z \sim q_\phi(z|x)}[\log p_\theta(x|z)] - D_{\mathrm{KL}}(q_\phi(z|x) \,\|\, p(z))

where q_\phi(z|x) is the encoder (recognition model), p_\theta(x|z) is the decoder (generative model), and p(z) is the prior distribution over latent variables. The gradient of the ELBO with respect to \theta is simply

    \nabla_\theta \mathrm{ELBO}(\phi, \theta) = \mathbb{E}_{z \sim q_\phi(z|x)}[\nabla_\theta \log p_\theta(x|z)] \approx \frac{1}{L} \sum_{l=1}^{L} \nabla_\theta \log p_\theta(x|z_l)

but the gradient with respect to \phi requires the trick. Express the sampling operation z \sim q_\phi(z|x) as:

    z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

where \mu_\phi(x) and \sigma_\phi(x) are the outputs of the encoder network, and \odot denotes element-wise multiplication. Then we have

    \nabla_\phi \mathrm{ELBO}(\phi, \theta) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}[\nabla_\phi \log p_\theta(x|z) - \nabla_\phi \log q_\phi(z|x) + \nabla_\phi \log p(z)]

where z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon. This allows us to estimate the gradient by Monte Carlo sampling:

    \nabla_\phi \mathrm{ELBO}(\phi, \theta) \approx \frac{1}{L} \sum_{l=1}^{L} [\nabla_\phi \log p_\theta(x|z_l) - \nabla_\phi \log q_\phi(z_l|x) + \nabla_\phi \log p(z_l)]

where z_l = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon_l and \epsilon_l \sim \mathcal{N}(0, I) for l = 1, \ldots, L.

This formulation enables backpropagation through the sampling process, allowing for end-to-end training of the VAE model using stochastic gradient descent or its variants.
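The sampling step itself reduces to a few lines. A minimal NumPy sketch, in which the helper name `reparameterize` and the example values of `mu` and `log_var` are illustrative assumptions (in a real VAE they would be produced by the encoder network):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical encoder outputs for one input x; VAE encoders commonly
# emit the log-variance rather than sigma for numerical stability.
mu = np.array([0.5, -1.0])
log_var = np.array([0.0, 0.2])

def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I); gradients w.r.t.
    mu and log_var flow through this deterministic expression."""
    eps = rng.normal(size=mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

z = reparameterize(mu, log_var, rng)
print(z.shape)  # (2,): one latent sample per latent dimension
```

In an autodiff framework the same expression is what makes backpropagation from the decoder loss into the encoder parameters possible.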

Variational inference

More generally, the trick allows using stochastic gradient descent for variational inference. Let the variational objective (ELBO) be of the form:

    \mathrm{ELBO}(\phi) = \mathbb{E}_{z \sim q_\phi(z)}[\log p(x, z) - \log q_\phi(z)]

Using the reparameterization trick, we can estimate the gradient of this objective with respect to \phi:

    \nabla_\phi \mathrm{ELBO}(\phi) \approx \frac{1}{L} \sum_{l=1}^{L} \nabla_\phi \left[ \log p(x, g_\phi(\epsilon_l)) - \log q_\phi(g_\phi(\epsilon_l)) \right], \quad \epsilon_l \sim p(\epsilon)
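This procedure can be run end to end on a toy conjugate model, where the answer is known in closed form. The model, step size, and iteration count below are illustrative assumptions, not from the source: prior z ~ 𝒩(0, 1), likelihood x | z ~ 𝒩(z, 1), observed x = 2, so the exact posterior is 𝒩(1, 0.5):

```python
import numpy as np

# Reparameterized single-sample SGD for variational inference in a toy
# conjugate-Gaussian model. q(z) = N(m, e^{2*ls}) should approach the
# exact posterior N(1, 0.5), i.e. mean ~1.0 and std ~0.707.
rng = np.random.default_rng(0)
x = 2.0
m, ls = 0.0, 0.0                         # variational mean and log-std

for _ in range(5000):
    eps = rng.normal()
    s = np.exp(ls)
    z = m + s * eps                      # reparameterized sample of q
    dlogp_dz = (x - z) - z               # d/dz [log p(x|z) + log p(z)]
    grad_m = dlogp_dz                    # chain rule: dz/dm = 1
    grad_ls = dlogp_dz * s * eps + 1.0   # dz/dls = s*eps; +1 from -E[log q]
    m += 0.01 * grad_m
    ls += 0.01 * grad_ls

print(m, np.exp(ls))  # both near the exact posterior values
```

Each update uses L = 1 Monte Carlo sample; averaging more samples per step would lower the gradient noise at extra cost.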

Dropout

The reparameterization trick has been applied to reduce the variance in dropout, a regularization technique in neural networks. The original dropout can be reparameterized with Bernoulli distributions:

    y = (W \odot \epsilon) x, \quad \epsilon_{ij} \sim \operatorname{Bernoulli}(\alpha_{ij})

where W is the weight matrix, x is the input, and \alpha_{ij} are the (fixed) dropout rates.

More generally, distributions other than the Bernoulli can be used, such as Gaussian noise:

    y_i = \mu_i + \sigma_i \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, 1)

where \mu_i = \mathbf{m}_i^\top x and \sigma_i^2 = \mathbf{v}_i^\top x^2, with \mathbf{m}_i and \mathbf{v}_i being the mean and variance of the i-th output neuron, and x^2 denoting element-wise squaring. The reparameterization trick can be applied to all such cases, resulting in the variational dropout method.[7]
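The Gaussian-noise form above amounts to one noisy linear layer. A minimal sketch, in which the weight values `m` and `v` are illustrative assumptions (in variational dropout they are learned parameters):

```python
import numpy as np

# Gaussian-noise reparameterization of dropout for a single linear layer
# mapping a 3-dimensional input to 2 output neurons.
rng = np.random.default_rng(0)
x = np.array([1.0, 2.0, 3.0])
m = np.array([[0.1, 0.2, 0.3],
              [0.4, 0.5, 0.6]])   # mean weights  m_i
v = np.abs(m) * 0.05              # variance weights v_i (toy values)

mu = m @ x                  # per-neuron mean:      mu_i    = m_i . x
sigma = np.sqrt(v @ x**2)   # per-neuron std:       sigma_i = sqrt(v_i . x^2)
eps = rng.normal(size=mu.shape)
y = mu + sigma * eps        # reparameterized noisy activations
print(y.shape)  # (2,)
```

Because the noise enters additively through mu and sigma, gradients with respect to m and v flow through the sample, exactly as in the generic estimator.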

References

  1. ^ a b Figurnov, Mikhail; Mohamed, Shakir; Mnih, Andriy (2018). "Implicit Reparameterization Gradients". Advances in Neural Information Processing Systems 31.
  2. ^ Fu, Michael C. (2006). "Gradient estimation". Handbooks in Operations Research and Management Science 13: 575–616.
  3. ^ Kingma, Diederik P.; Welling, Max (2013). "Auto-Encoding Variational Bayes". arXiv:1312.6114.
  4. ^ Williams, Ronald J. (1992). "Simple statistical gradient-following algorithms for connectionist reinforcement learning". Machine Learning 8: 229–256.
  5. ^ Greensmith, Evan; Bartlett, Peter L.; Baxter, Jonathan (2004). "Variance reduction techniques for gradient estimates in reinforcement learning". Journal of Machine Learning Research 5: 1471–1530.
  6. ^ Jang, Eric; Gu, Shixiang; Poole, Ben (2017). "Categorical Reparameterization with Gumbel-Softmax". arXiv:1611.01144.
  7. ^ Kingma, Diederik P.; Salimans, Tim; Welling, Max (2015). "Variational Dropout and the Local Reparameterization Trick". Advances in Neural Information Processing Systems 28. arXiv:1506.02557.
