xref: /aosp_15_r20/external/pytorch/torch/ao/nn/quantized/modules/dropout.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4
5__all__ = ["Dropout"]
6
7
class Dropout(torch.nn.Dropout):
    r"""Quantized counterpart of :class:`~torch.nn.Dropout`.

    This module is a pass-through placeholder: :meth:`forward` hands its
    input back unchanged, which lets models whose float version contained
    dropout run on quantized tensors in both train and eval mode.

    Args:
        p: probability of an element to be zeroed
        inplace: can optionally do the operation in-place. Default: ``False``
    """

    def forward(self, input):
        """Return ``input`` as-is; no elements are zeroed or rescaled."""
        # Identity regardless of ``self.training`` -- this module exists only
        # so that a dropout slot in the float model has a quantized stand-in.
        return input

    def _get_name(self):
        """Override :meth:`torch.nn.Module._get_name` to report a quantized name."""
        return "QuantizedDropout"

    @classmethod
    def _clone_settings(cls, source):
        """Build a new instance that copies ``p`` and ``inplace`` from ``source``."""
        return cls(p=source.p, inplace=source.inplace)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create the quantized placeholder from a float dropout module.

        Args:
            mod: module providing ``p`` and ``inplace`` attributes.
            use_precomputed_fake_quant: accepted for API compatibility with
                other ``from_float`` constructors; not used here.
        """
        return cls._clone_settings(mod)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Create the quantized placeholder from a reference dropout module.

        ``scale`` and ``zero_point`` are accepted for API compatibility but
        are ignored, because dropout performs no requantization here.
        """
        return cls._clone_settings(mod)
31