r"""
The ``distributions`` package contains parameterizable probability distributions
and sampling functions. This allows the construction of stochastic computation
graphs and stochastic gradient estimators for optimization. This package
generally follows the design of the `TensorFlow Distributions`_ package.

.. _`TensorFlow Distributions`:
    https://arxiv.org/abs/1711.10604

It is not possible to directly backpropagate through random samples. However,
there are two main methods for creating surrogate functions that can be
backpropagated through: the score function estimator (also known as the
likelihood ratio estimator or REINFORCE) and the pathwise derivative estimator.
REINFORCE is commonly seen as the basis for policy gradient methods in
reinforcement learning, and the pathwise derivative estimator underlies the
reparameterization trick used in variational autoencoders. Whilst the score
function estimator only requires the values :math:`f(x)` of the samples, the
pathwise derivative estimator also requires the derivative :math:`f'(x)`. The
following sections illustrate both estimators with a reinforcement learning
example. For more details, see
`Gradient Estimation Using Stochastic Computation Graphs`_.

.. _`Gradient Estimation Using Stochastic Computation Graphs`:
    https://arxiv.org/abs/1506.05254

Score function
^^^^^^^^^^^^^^

When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:

.. math::

    \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.
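
As a sketch of why ``sample`` and ``log_prob`` suffice, assume that the reward
:math:`r` does not depend on :math:`\theta` and that differentiation and
integration may be interchanged. Using the log-derivative identity
:math:`\partial p / \partial\theta = p \, \partial\log p / \partial\theta`, the
gradient of the expected reward (an integral over actions, or a sum for
discrete actions such as a categorical policy) becomes an expectation that can
be estimated from samples:

.. math::

    \frac{\partial}{\partial\theta} \mathbb{E}_{a \sim p(a|\pi^\theta(s))}\left[r\right]
    = \int r \, \frac{\partial p(a|\pi^\theta(s))}{\partial\theta} \, da
    = \mathbb{E}_{a \sim p(a|\pi^\theta(s))}\left[r \, \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}\right]

The update rule above is therefore a single-sample, unbiased (though often
high-variance) estimate of this gradient, scaled by the learning rate
:math:`\alpha`.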

In practice we would sample an action from the output of a network, apply this
action in an environment, and then use ``log_prob`` to construct an equivalent
surrogate loss function. Note that the loss is negated because optimizers
perform gradient descent, whilst the rule above assumes gradient ascent. With a
categorical policy, the code for implementing REINFORCE would be as follows::

    probs = policy_network(state)
    # Note that this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()

Pathwise derivative
^^^^^^^^^^^^^^^^^^^

The other way to implement these stochastic/policy gradients is to use the
reparameterization trick provided by the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable is constructed as a parameterized deterministic
function of a parameter-free random variable (for example,
:math:`x = \mu + \sigma\epsilon` with :math:`\epsilon \sim \mathcal{N}(0, 1)`
for a normal distribution). The reparameterized sample is therefore
differentiable with respect to the parameters. The code for implementing the
pathwise derivative would be as follows::

    params = policy_network(state)
    m = Normal(*params)
    # Any distribution with .has_rsample == True can be used, depending on the application
    action = m.rsample()
    next_state, reward = env.step(action)  # Assumes the reward is differentiable w.r.t. the action
    loss = -reward
    loss.backward()
"""

from . import transforms
from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .fishersnedecor import FisherSnedecor
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_cauchy import HalfCauchy
from .half_normal import HalfNormal
from .independent import Independent
from .inverse_gamma import InverseGamma
from .kl import _add_kl_info, kl_divergence, register_kl
from .kumaraswamy import Kumaraswamy
from .laplace import Laplace
from .lkj_cholesky import LKJCholesky
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import LowRankMultivariateNormal
from .mixture_same_family import MixtureSameFamily
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .negative_binomial import NegativeBinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
from .pareto import Pareto
from .poisson import Poisson
from .relaxed_bernoulli import RelaxedBernoulli
from .relaxed_categorical import RelaxedOneHotCategorical
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .transforms import *  # noqa: F403
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from .wishart import Wishart


_add_kl_info()
del _add_kl_info

__all__ = [
    "Bernoulli",
    "Beta",
    "Binomial",
    "Categorical",
    "Cauchy",
    "Chi2",
    "ContinuousBernoulli",
    "Dirichlet",
    "Distribution",
    "Exponential",
    "ExponentialFamily",
    "FisherSnedecor",
    "Gamma",
    "Geometric",
    "Gumbel",
    "HalfCauchy",
    "HalfNormal",
    "Independent",
    "InverseGamma",
    "Kumaraswamy",
    "LKJCholesky",
    "Laplace",
    "LogNormal",
    "LogisticNormal",
    "LowRankMultivariateNormal",
    "MixtureSameFamily",
    "Multinomial",
    "MultivariateNormal",
    "NegativeBinomial",
    "Normal",
    "OneHotCategorical",
    "OneHotCategoricalStraightThrough",
    "Pareto",
    "RelaxedBernoulli",
    "RelaxedOneHotCategorical",
    "StudentT",
    "Poisson",
    "Uniform",
    "VonMises",
    "Weibull",
    "Wishart",
    "TransformedDistribution",
    "biject_to",
    "kl_divergence",
    "register_kl",
    "transform_to",
]
__all__.extend(transforms.__all__)