# mypy: allow-untyped-defs
from typing import Optional

import torch


class SobolEngine:
    r"""
    The :class:`torch.quasirandom.SobolEngine` is an engine for generating
    (scrambled) Sobol sequences. Sobol sequences are an example of low
    discrepancy quasi-random sequences.

    This implementation of an engine for Sobol sequences is capable of
    sampling sequences up to a maximum dimension of 21201. It uses direction
    numbers from https://web.maths.unsw.edu.au/~fkuo/sobol/ obtained using the
    search criterion D(6) up to the dimension 21201. This is the recommended
    choice by the authors.

    References:
      - Art B. Owen. Scrambling Sobol and Niederreiter-Xing points.
        Journal of Complexity, 14(4):466-489, December 1998.

      - I. M. Sobol. The distribution of points in a cube and the accurate
        evaluation of integrals.
        Zh. Vychisl. Mat. i Mat. Phys., 7:784-802, 1967.

    Args:
        dimension (int): The dimensionality of the sequence to be drawn.
            Must lie in ``[1, 21201]``.
        scramble (bool, optional): Setting this to ``True`` will produce
            scrambled Sobol sequences. Scrambling is capable of producing
            better Sobol sequences. Default: ``False``.
        seed (int, optional): The seed for the scrambling. If specified, a
            dedicated random number generator seeded with this value is used;
            otherwise the global default generator is used. Only has an
            effect when ``scramble=True``. Default: ``None``

    Raises:
        ValueError: If ``dimension`` is outside ``[1, 21201]``.

    Examples::

        >>> # xdoctest: +SKIP("unseeded random state")
        >>> soboleng = torch.quasirandom.SobolEngine(dimension=5)
        >>> soboleng.draw(3)
        tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
                [0.7500, 0.2500, 0.2500, 0.2500, 0.7500]])
    """

    # Number of bits per coordinate: points are integers in [0, 2**MAXBIT)
    # scaled by 1 / 2**MAXBIT.
    MAXBIT = 30
    # Largest dimension supported by the bundled direction numbers.
    MAXDIM = 21201

    def __init__(self, dimension, scramble=False, seed=None):
        if dimension > self.MAXDIM or dimension < 1:
            raise ValueError(
                "Supported range of dimensionality "
                f"for SobolEngine is [1, {self.MAXDIM}]"
            )

        self.seed = seed
        self.scramble = scramble
        self.dimension = dimension

        cpu = torch.device("cpu")

        # One row of MAXBIT integers per dimension, filled in place by the
        # native initializer.
        self.sobolstate = torch.zeros(
            dimension, self.MAXBIT, device=cpu, dtype=torch.long
        )
        torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)

        if not self.scramble:
            self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long)
        else:
            # Sets ``self.shift`` and scrambles ``self.sobolstate``.
            self._scramble()

        # Integer state of the sequence; it starts at the shift.
        self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
        # The first point of the sequence is the scaled shift itself. Compute
        # it in float64 so the MAXBIT-bit integers are represented exactly:
        # dividing the int64 tensor directly would promote to the default
        # dtype (usually float32), silently truncating the first point when a
        # float64 draw is requested and possibly rounding it up to 1.0.
        # Casting this float64 value down to float32 later gives the same
        # result as the previous direct float32 computation.
        self._first_point = (
            self.quasi.to(torch.float64) / 2**self.MAXBIT
        ).reshape(1, -1)
        self.num_generated = 0

    def draw(
        self,
        n: int = 1,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`n` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
        of the result is :math:`(n, dimension)`.

        Args:
            n (int, optional): The length of sequence of points to draw.
                Default: 1
            out (Tensor, optional): The output tensor. It is resized to the
                result's shape and the result is copied into it.
            dtype (:class:`torch.dtype`, optional): the desired data type of the
                                                    returned tensor.
                                                    Default: ``None``, meaning
                                                    :func:`torch.get_default_dtype`.

        Returns:
            Tensor: the drawn points (``out`` itself when it is given).
        """
        if dtype is None:
            dtype = torch.get_default_dtype()

        if n == 0:
            # Nothing to draw. Handled up front so that a fresh engine does not
            # pass ``n - 1 == -1`` to the native kernel in the branch below.
            result = torch.empty(
                0, self.dimension, dtype=dtype, device=self.quasi.device
            )
        elif self.num_generated == 0:
            # The first point is the initial state, so the native kernel only
            # has to produce the remaining n - 1 points.
            if n == 1:
                result = self._first_point.to(dtype)
            else:
                result, self.quasi = torch._sobol_engine_draw(
                    self.quasi,
                    n - 1,
                    self.sobolstate,
                    self.dimension,
                    self.num_generated,
                    dtype=dtype,
                )
                result = torch.cat((self._first_point.to(dtype), result), dim=-2)
        else:
            # The kernel's counter excludes the first point, which is produced
            # here rather than by the kernel, hence ``num_generated - 1``.
            result, self.quasi = torch._sobol_engine_draw(
                self.quasi,
                n,
                self.sobolstate,
                self.dimension,
                self.num_generated - 1,
                dtype=dtype,
            )

        self.num_generated += n

        if out is not None:
            out.resize_as_(result).copy_(result)
            return out

        return result

    def draw_base2(
        self,
        m: int,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
        of the result is :math:`(2**m, dimension)`.

        Args:
            m (int): The (base2) exponent of the number of points to draw.
            out (Tensor, optional): The output tensor
            dtype (:class:`torch.dtype`, optional): the desired data type of the
                                                    returned tensor.
                                                    Default: ``None``

        Raises:
            ValueError: If the total number of points generated so far plus
                ``2**m`` would not be a power of two.
        """
        n = 2**m
        total_n = self.num_generated + n
        # x & (x - 1) == 0 holds exactly when x is a power of two (or zero).
        if not (total_n & (total_n - 1) == 0):
            raise ValueError(
                "The balance properties of Sobol' points require "
                f"n to be a power of 2. {self.num_generated} points have been "
                f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
                "If you still want to do this, please use "
                "'SobolEngine.draw()' instead."
            )
        return self.draw(n=n, out=out, dtype=dtype)

    def reset(self):
        r"""
        Function to reset the ``SobolEngine`` to base state.

        Returns:
            SobolEngine: ``self``, to allow call chaining.
        """
        self.quasi.copy_(self.shift)
        self.num_generated = 0
        return self

    def fast_forward(self, n):
        r"""
        Function to fast-forward the state of the ``SobolEngine`` by
        :attr:`n` steps. This is equivalent to drawing :attr:`n` samples
        without using the samples.

        Args:
            n (int): The number of steps to fast-forward by.

        Returns:
            SobolEngine: ``self``, to allow call chaining.
        """
        # Same bookkeeping as ``draw``: the first point is the initial state
        # and is not advanced by the native kernel.
        if self.num_generated == 0:
            torch._sobol_engine_ff_(
                self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
            )
        else:
            torch._sobol_engine_ff_(
                self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
            )
        self.num_generated += n
        return self

    def _scramble(self):
        # Use a dedicated, seeded generator when a seed was given; otherwise
        # ``generator=None`` falls back to the global default generator.
        g: Optional[torch.Generator] = None
        if self.seed is not None:
            g = torch.Generator()
            g.manual_seed(self.seed)

        cpu = torch.device("cpu")

        # Generate shift vector: MAXBIT random bits per dimension, packed into
        # one integer per dimension (bit i weighted by 2**i).
        shift_ints = torch.randint(
            2, (self.dimension, self.MAXBIT), device=cpu, generator=g
        )
        self.shift = torch.mv(
            shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
        )

        # Generate lower triangular matrices (stacked across dimensions) of
        # random 0/1 entries, then let the native op scramble the direction
        # numbers in ``self.sobolstate`` with them.
        ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
        ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril()

        torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)

    def __repr__(self):
        fmt_string = [f"dimension={self.dimension}"]
        if self.scramble:
            fmt_string += ["scramble=True"]
        if self.seed is not None:
            fmt_string += [f"seed={self.seed}"]
        return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"