r"""
The ``distributions`` package contains parameterizable probability distributions
and sampling functions. This allows the construction of stochastic computation
graphs and stochastic gradient estimators for optimization. This package
generally follows the design of the `TensorFlow Distributions`_ package.

.. _`TensorFlow Distributions`:
    https://arxiv.org/abs/1711.10604

It is not possible to directly backpropagate through random samples. However,
there are two main methods for creating surrogate functions that can be
backpropagated through. These are the score function estimator/likelihood ratio
estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
seen as the basis for policy gradient methods in reinforcement learning, and the
pathwise derivative estimator is commonly seen in the reparameterization trick
in variational autoencoders. Whilst the score function only requires the value
of samples :math:`f(x)`, the pathwise derivative requires the derivative
:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
example. For more details see
`Gradient Estimation Using Stochastic Computation Graphs`_ .

.. _`Gradient Estimation Using Stochastic Computation Graphs`:
     https://arxiv.org/abs/1506.05254
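
The practical difference between the two is that a plain sample is detached
from the graph while a reparameterized sample is not; a minimal sketch using a
:class:`~torch.distributions.Normal` built from a parameter that requires
gradients::

    import torch
    from torch.distributions import Normal

    loc = torch.tensor(0.0, requires_grad=True)
    m = Normal(loc, 1.0)
    m.sample().requires_grad   # False: sampling is detached from the graph
    m.rsample().requires_grad  # True: the reparameterized sample is differentiable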

Score function
^^^^^^^^^^^^^^

When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:

.. math::

    \Delta\theta  = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.

In practice we would sample an action from the output of a network, apply this
action in an environment, and then use ``log_prob`` to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::

    probs = policy_network(state)
    # Note that this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()
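
The snippet above assumes a surrounding environment and network. As a
self-contained sketch (the constant ``reward`` and the zero-initialized
``logits`` below are illustrative stand-ins for an environment return and a
policy network), gradients still reach the parameters through ``log_prob``
even though the sample itself is not differentiable::

    import torch
    from torch.distributions import Categorical

    logits = torch.zeros(4, requires_grad=True)  # toy policy over 4 actions
    m = Categorical(logits=logits)
    action = m.sample()                  # sampling is non-differentiable
    reward = 1.0                         # stand-in for the environment's reward
    loss = -m.log_prob(action) * reward
    loss.backward()
    assert logits.grad is not None       # gradient arrived via the score function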

Pathwise derivative
^^^^^^^^^^^^^^^^^^^

The other way to implement these stochastic/policy gradients would be to use the
reparameterization trick from the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows::

    params = policy_network(state)
    m = Normal(*params)
    # Any distribution with .has_rsample == True could work based on the application
    action = m.rsample()
    next_state, reward = env.step(action)  # Assuming that reward is differentiable
    loss = -reward
    loss.backward()
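
Again as a self-contained sketch (the quadratic ``reward`` below is an
illustrative stand-in for a differentiable environment), the gradient flows
from the loss back into the distribution parameters through the reparameterized
sample::

    import torch
    from torch.distributions import Normal

    loc = torch.tensor(0.0, requires_grad=True)
    log_scale = torch.tensor(0.0, requires_grad=True)

    m = Normal(loc, log_scale.exp())
    action = m.rsample()             # differentiable w.r.t. loc and log_scale
    reward = -(action - 2.0) ** 2    # toy differentiable reward
    loss = -reward
    loss.backward()
    assert loc.grad is not None and log_scale.grad is not None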
"""

from . import transforms
from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .fishersnedecor import FisherSnedecor
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_cauchy import HalfCauchy
from .half_normal import HalfNormal
from .independent import Independent
from .inverse_gamma import InverseGamma
from .kl import _add_kl_info, kl_divergence, register_kl
from .kumaraswamy import Kumaraswamy
from .laplace import Laplace
from .lkj_cholesky import LKJCholesky
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import LowRankMultivariateNormal
from .mixture_same_family import MixtureSameFamily
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .negative_binomial import NegativeBinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
from .pareto import Pareto
from .poisson import Poisson
from .relaxed_bernoulli import RelaxedBernoulli
from .relaxed_categorical import RelaxedOneHotCategorical
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .transforms import *  # noqa: F403
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from .wishart import Wishart


_add_kl_info()
del _add_kl_info

__all__ = [
    "Bernoulli",
    "Beta",
    "Binomial",
    "Categorical",
    "Cauchy",
    "Chi2",
    "ContinuousBernoulli",
    "Dirichlet",
    "Distribution",
    "Exponential",
    "ExponentialFamily",
    "FisherSnedecor",
    "Gamma",
    "Geometric",
    "Gumbel",
    "HalfCauchy",
    "HalfNormal",
    "Independent",
    "InverseGamma",
    "Kumaraswamy",
    "LKJCholesky",
    "Laplace",
    "LogNormal",
    "LogisticNormal",
    "LowRankMultivariateNormal",
    "MixtureSameFamily",
    "Multinomial",
    "MultivariateNormal",
    "NegativeBinomial",
    "Normal",
    "OneHotCategorical",
    "OneHotCategoricalStraightThrough",
    "Pareto",
    "RelaxedBernoulli",
    "RelaxedOneHotCategorical",
    "StudentT",
    "Poisson",
    "Uniform",
    "VonMises",
    "Weibull",
    "Wishart",
    "TransformedDistribution",
    "biject_to",
    "kl_divergence",
    "register_kl",
    "transform_to",
]
__all__.extend(transforms.__all__)