xref: /aosp_15_r20/external/pytorch/test/mobile/model_test/sampling_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2
3
# Exercises the ops listed under https://pytorch.org/docs/stable/torch.html#random-sampling
5
6
class SamplingOpsModule(torch.nn.Module):
    """Call each random-sampling op from the ``torch`` docs page once.

    ``forward`` takes no inputs. It invokes every out-of-place sampling
    function and every in-place sampling method exactly once, so the
    resulting graph contains the corresponding operators. The return value
    is only the number of tensors produced; the sampled values are discarded.
    """

    def forward(self):
        # Values drawn from [0, 1): usable as bernoulli probabilities and as
        # non-negative poisson rates.
        a = torch.empty(3, 3).uniform_(0.0, 1.0)
        size = (1, 4)
        # Exactly two non-zero categories, so drawing 2 samples without
        # replacement in multinomial is valid.
        weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
        # NOTE: ``len`` receives many positional arguments. Eager Python
        # rejects this (TypeError: len() takes exactly one argument), so this
        # module is only meant to run after ``torch.jit.script``. It
        # presumably relies on TorchScript packing trailing positional
        # arguments into a list-typed ``len`` overload.
        # TODO(review): confirm the overload. Rewriting this as ``len([...])``
        # would change which operators the scripted model contains.
        #
        # Not exercised here: torch.seed(), torch.manual_seed(0),
        # torch.initial_seed().
        return len(
            torch.bernoulli(a),
            torch.multinomial(weights, 2),
            torch.normal(2.0, 3.0, size),
            torch.poisson(a),
            torch.rand(2, 3),
            torch.rand_like(a),
            torch.randint(10, size),
            torch.randint_like(a, 4),
            # ``randn`` (not a second ``rand``) fills the randn slot of the
            # docs list; ``rand`` is already covered above.
            torch.randn(4),
            torch.randn_like(a),
            torch.randperm(4),
            # In-place sampling methods; each one overwrites ``a``.
            a.bernoulli_(),
            a.cauchy_(),
            a.exponential_(),
            a.geometric_(0.5),
            a.log_normal_(),
            a.normal_(),
            a.random_(),
            a.uniform_(),
        )
36