xref: /aosp_15_r20/external/pytorch/torch/quasirandom.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
class SobolEngine:
    r"""
    The :class:`torch.quasirandom.SobolEngine` is an engine for generating
    (scrambled) Sobol sequences. Sobol sequences are an example of low
    discrepancy quasi-random sequences.

    This implementation of an engine for Sobol sequences is capable of
    sampling sequences up to a maximum dimension of 21201. It uses direction
    numbers from https://web.maths.unsw.edu.au/~fkuo/sobol/ obtained using the
    search criterion D(6) up to the dimension 21201. This is the recommended
    choice by the authors.

    References:
      - Art B. Owen. Scrambling Sobol and Niederreiter-Xing points.
        Journal of Complexity, 14(4):466-489, December 1998.

      - I. M. Sobol. The distribution of points in a cube and the accurate
        evaluation of integrals.
        Zh. Vychisl. Mat. i Mat. Phys., 7:784-802, 1967.

    Args:
        dimension (Int): The dimensionality of the sequence to be drawn
        scramble (bool, optional): Setting this to ``True`` will produce
                                   scrambled Sobol sequences. Scrambling is
                                   capable of producing better Sobol
                                   sequences. Default: ``False``.
        seed (Int, optional): This is the seed for the scrambling. The seed
                              of the random number generator is set to this,
                              if specified. Otherwise, it uses a random seed.
                              Default: ``None``

    Examples::

        >>> # xdoctest: +SKIP("unseeded random state")
        >>> soboleng = torch.quasirandom.SobolEngine(dimension=5)
        >>> soboleng.draw(3)
        tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
                [0.7500, 0.2500, 0.2500, 0.2500, 0.7500]])
    """

    # Bits of precision per coordinate: the state has MAXBIT columns per
    # dimension and integer points are divided by 2**MAXBIT when returned.
    MAXBIT = 30
    # Largest supported dimension (see the class docstring).
    MAXDIM = 21201

    def __init__(self, dimension, scramble=False, seed=None):
        if dimension > self.MAXDIM or dimension < 1:
            raise ValueError(
                "Supported range of dimensionality "
                f"for SobolEngine is [1, {self.MAXDIM}]"
            )

        self.seed = seed
        self.scramble = scramble
        self.dimension = dimension

        cpu = torch.device("cpu")

        # Per-dimension state of MAXBIT integers (presumably the direction
        # numbers referenced above), filled in place by the native initializer.
        self.sobolstate = torch.zeros(
            dimension, self.MAXBIT, device=cpu, dtype=torch.long
        )
        torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)

        if not self.scramble:
            self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long)
        else:
            # Sets ``self.shift`` and scrambles ``self.sobolstate``.
            self._scramble()

        # Current integer point of the sequence; it starts at the shift,
        # which is also the first point returned.
        self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
        # Kept for backward compatibility. ``draw`` builds the first point via
        # ``_first_point_as`` so it is exact in the requested dtype and never
        # aliases engine state.
        self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
        self.num_generated = 0

    def draw(
        self,
        n: int = 1,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`n` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
        of the result is :math:`(n, dimension)`. Drawing ``n=0`` points returns
        an empty tensor and leaves the state of the engine unchanged.

        Args:
            n (Int, optional): The length of sequence of points to draw.
                               Default: 1
            out (Tensor, optional): The output tensor. If given, it is resized
                                    to the shape of the result, filled with
                                    the drawn points and returned.
            dtype (:class:`torch.dtype`, optional): the desired data type of the
                                                    returned tensor. If
                                                    ``None``, the current
                                                    default dtype is used.
                                                    Default: ``None``
        """
        if dtype is None:
            dtype = torch.get_default_dtype()

        if n == 0:
            # Nothing to draw. Handled up front so the first-draw path below
            # never asks the native kernel for ``n - 1 == -1`` points.
            result = torch.empty(
                (0, self.dimension), dtype=dtype, device=self.sobolstate.device
            )
        elif self.num_generated == 0:
            # The first point is the shift itself; only the remaining
            # ``n - 1`` points are produced by the native kernel.
            first_point = self._first_point_as(dtype)
            if n == 1:
                result = first_point
            else:
                result, self.quasi = torch._sobol_engine_draw(
                    self.quasi,
                    n - 1,
                    self.sobolstate,
                    self.dimension,
                    self.num_generated,
                    dtype=dtype,
                )
                result = torch.cat((first_point, result), dim=-2)
        else:
            # The first point consumed no kernel step, so the kernel's step
            # counter lags ``num_generated`` by one.
            result, self.quasi = torch._sobol_engine_draw(
                self.quasi,
                n,
                self.sobolstate,
                self.dimension,
                self.num_generated - 1,
                dtype=dtype,
            )

        self.num_generated += n

        if out is not None:
            out.resize_as_(result).copy_(result)
            return out

        return result

    def draw_base2(
        self,
        m: int,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
        of the result is :math:`(2**m, dimension)`.

        Args:
            m (Int): The (base2) exponent of the number of points to draw.
            out (Tensor, optional): The output tensor
            dtype (:class:`torch.dtype`, optional): the desired data type of the
                                                    returned tensor.
                                                    Default: ``None``

        Raises:
            ValueError: if the total number of points generated so far plus
                ``2**m`` is not a power of two.
        """
        n = 2**m
        total_n = self.num_generated + n
        # ``x & (x - 1) == 0`` holds exactly when x is a power of two
        # (Python's ``&`` binds tighter than ``==``).
        if not (total_n & (total_n - 1) == 0):
            raise ValueError(
                "The balance properties of Sobol' points require "
                f"n to be a power of 2. {self.num_generated} points have been "
                f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
                "If you still want to do this, please use "
                "'SobolEngine.draw()' instead."
            )
        return self.draw(n=n, out=out, dtype=dtype)

    def reset(self):
        r"""
        Function to reset the ``SobolEngine`` to base state.

        Returns:
            The engine itself, to allow chaining.
        """
        self.quasi.copy_(self.shift)
        self.num_generated = 0
        return self

    def fast_forward(self, n):
        r"""
        Function to fast-forward the state of the ``SobolEngine`` by
        :attr:`n` steps. This is equivalent to drawing :attr:`n` samples
        without using the samples.

        Args:
            n (Int): The number of steps to fast-forward by.

        Returns:
            The engine itself, to allow chaining.
        """
        if n == 0:
            # Nothing to skip; avoid handing ``n - 1 == -1`` to the kernel.
            return self
        if self.num_generated == 0:
            # The first point (the shift) needs no kernel step.
            torch._sobol_engine_ff_(
                self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
            )
        else:
            torch._sobol_engine_ff_(
                self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
            )
        self.num_generated += n
        return self

    def _first_point_as(self, dtype: torch.dtype) -> torch.Tensor:
        # Return the first point (the shift scaled by 2**-MAXBIT) as a new
        # (1, dimension) tensor of ``dtype``. The integer shift is converted to
        # float64 (exact for values below 2**53) and divided by a power of two
        # (exact), so the only rounding is the final cast to ``dtype``; the
        # result does not depend on the default dtype at construction time.
        # A fresh tensor is built on every call so callers may mutate it.
        scaled = self.shift.to(torch.float64) / 2**self.MAXBIT
        return scaled.to(dtype).reshape(1, -1)

    def _scramble(self):
        # Use a dedicated generator when a seed was given; otherwise
        # ``g is None`` and torch's global RNG is used.
        g: Optional[torch.Generator] = None
        if self.seed is not None:
            g = torch.Generator()
            g.manual_seed(self.seed)

        cpu = torch.device("cpu")

        # Generate shift vector: MAXBIT random bits per dimension packed into
        # one integer (bit k weighted by 2**k), i.e. a value in [0, 2**MAXBIT).
        shift_ints = torch.randint(
            2, (self.dimension, self.MAXBIT), device=cpu, generator=g
        )
        self.shift = torch.mv(
            shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
        )

        # Generate lower triangular matrices (stacked across dimensions): one
        # random binary MAXBIT x MAXBIT lower-triangular matrix per dimension.
        ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
        ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril()

        torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)

    def __repr__(self):
        # Only non-default options are shown.
        fmt_string = [f"dimension={self.dimension}"]
        if self.scramble:
            fmt_string += ["scramble=True"]
        if self.seed is not None:
            fmt_string += [f"seed={self.seed}"]
        return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
217*da0073e9SAndroid Build Coastguard Worker        return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
218