xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/SobolEngineOpsUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /// This file contains some tensor-agnostic operations to be used in the
2 /// core functions of the `SobolEngine`
3 #include <ATen/core/Tensor.h>
4 
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #else
8 #include <ATen/ops/arange.h>
9 #include <ATen/ops/mul.h>
10 #include <ATen/ops/pow.h>
11 #endif
12 
13 namespace at::native::sobol_utils {
14 
/// Returns the minimum number of bits needed to represent the integer `n`,
/// i.e. how many times `n` can be halved (integer division) before reaching 0.
/// Non-positive inputs yield 0.
inline int64_t bit_length(const int64_t n) {
  int64_t count = 0;
  int64_t value = n;
  while (value > 0) {
    value /= 2;
    ++count;
  }
  return count;
}
21 
/// Returns the zero-indexed position of the rightmost zero bit of `n`,
/// which equals the number of trailing one bits (e.g. 0b1011 -> 2).
/// Negative inputs yield 0, because `n % 2` is never 1 for them.
inline int64_t rightmost_zero(const int64_t n) {
  int64_t position = 0;
  int64_t value = n;
  // Strip trailing one bits; the number stripped is the index of the first zero.
  while (value % 2 == 1) {
    value /= 2;
    ++position;
  }
  return position;
}
29 
/// Returns the `length`-bit field of `n` that starts at bit position `pos`
/// (bit 0 is the least significant), shifted down to bit 0.
/// A non-positive `length` yields 0; a `length` of 64 or more keeps every bit
/// of `n >> pos`. `pos` is used directly as a shift count, so it is assumed to
/// lie in [0, 63] — TODO confirm callers guarantee this.
inline int64_t bitsubseq(const int64_t n, const int64_t pos, const int64_t length) {
  // Build the mask in unsigned 64-bit arithmetic. A plain `1 << length` is
  // evaluated in 32-bit `int`, which overflows (undefined behavior) once
  // `length` reaches 31, and shifting by >= the type width is also undefined.
  uint64_t mask = 0;
  if (length >= 64) {
    mask = ~uint64_t{0};
  } else if (length > 0) {
    mask = (uint64_t{1} << length) - 1;
  }
  return static_cast<int64_t>(static_cast<uint64_t>(n >> pos) & mask);
}
35 
36 /// Function to perform the inner product between a batched square matrix and a power of 2 vector
cdot_pow2(const at::Tensor & bmat)37 inline at::Tensor cdot_pow2(const at::Tensor& bmat) {
38   at::Tensor inter = at::arange(bmat.size(-1) - 1, -1, -1, bmat.options());
39   inter = at::pow(2, inter).expand_as(bmat);
40   return at::mul(inter, bmat).sum(-1);
41 }
42 
/// All definitions below this point are data. They are constant and must not
/// be modified without notice.

constexpr int64_t MAXDIM = 21201;
constexpr int64_t MAXDEG = 18;
constexpr int64_t MAXBIT = 30;
// Keep 2^MAXBIT representable as a positive signed 64-bit value.
static_assert(MAXBIT > 0 && MAXBIT < 63, "MAXBIT must fit in a signed 64-bit shift");
// Shift a 64-bit one: a bare `1` is an `int`, so `1 << MAXBIT` would overflow
// (undefined behavior) if MAXBIT were ever raised to 31 or more.
constexpr int64_t LARGEST_NUMBER = int64_t{1} << MAXBIT;
// 2^-MAXBIT; a power of two, so the narrowing to float is exact.
constexpr float RECIPD = static_cast<float>(1.0 / LARGEST_NUMBER);

// Declared here, defined in another translation unit.
extern const int64_t poly[MAXDIM];
extern const int64_t initsobolstate[MAXDIM][MAXDEG];
54 
55 } // namespace at::native::sobol_utils
56