#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else #include #include #include #include #include #include #endif namespace at::native { using namespace sobol_utils; /// This is the core function to draw samples from a `SobolEngine` given /// its state variables (`sobolstate` and `quasi`). `dimension` can be /// inferred from `sobolstate`, but choosing to pass it explicitly to avoid /// an extra operation to obtain the size of the first dimension of /// `sobolstate`. std::tuple _sobol_engine_draw(const Tensor& quasi, int64_t n, const Tensor& sobolstate, int64_t dimension, int64_t num_generated, std::optional dtype) { TORCH_CHECK(sobolstate.dtype() == at::kLong, "sobolstate needs to be of type ", at::kLong); TORCH_CHECK(quasi.dtype() == at::kLong, "quasi needs to be of type ", at::kLong); Tensor wquasi = quasi.clone(at::MemoryFormat::Contiguous); auto result_dtype = dtype.has_value() ? dtype.value() : at::kFloat; Tensor result = at::empty({n, dimension}, sobolstate.options().dtype(result_dtype)); AT_DISPATCH_FLOATING_TYPES(result_dtype, "_sobol_engine_draw", [&]() -> void { // We deal with `data` and `strides` due to performance issues. int64_t l; int64_t* wquasi_data = wquasi.data_ptr(); int64_t* sobolstate_data = sobolstate.data_ptr(); scalar_t* result_data = result.data_ptr(); int64_t wquasi_stride = wquasi.stride(0); int64_t sobolstate_row_stride = sobolstate.stride(0), sobolstate_col_stride = sobolstate.stride(1); int64_t result_row_stride = result.stride(0), result_col_stride = result.stride(1); for (int64_t i = 0; i < n; i++, num_generated++) { l = rightmost_zero(num_generated); for (const auto j : c10::irange(dimension)) { wquasi_data[j * wquasi_stride] ^= sobolstate_data[j * sobolstate_row_stride + l * sobolstate_col_stride]; result_data[i * result_row_stride + j * result_col_stride] = wquasi_data[j * wquasi_stride]; } } }); result.mul_(RECIPD); return std::tuple(result, wquasi); } /// This is the core function to fast-forward a `SobolEngine` given /// its state variables (`sobolstate` and `quasi`). `dimension` can be /// inferred from `sobolstate`, but is passed as an argument for the same reasons /// specified above. Tensor& _sobol_engine_ff_(Tensor& quasi, int64_t n, const Tensor& sobolstate, int64_t dimension, int64_t num_generated) { TORCH_CHECK(sobolstate.dtype() == at::kLong, "sobolstate needs to be of type ", at::kLong); TORCH_CHECK(quasi.dtype() == at::kLong, "quasi needs to be of type ", at::kLong); // We deal with `data` and `strides` due to performance issues. // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t l; int64_t* quasi_data = quasi.data_ptr(); int64_t* sobolstate_data = sobolstate.data_ptr(); int64_t quasi_stride = quasi.stride(0); int64_t sobolstate_row_stride = sobolstate.stride(0), sobolstate_col_stride = sobolstate.stride(1); for (int64_t i = 0; i < n; i++, num_generated++) { l = rightmost_zero(num_generated); for (const auto j : c10::irange(dimension)) { quasi_data[j * quasi_stride] ^= sobolstate_data[j * sobolstate_row_stride + l * sobolstate_col_stride]; } } return quasi; } /// This is an implicit function used for randomizing the state variables of the. /// `SobolEngine`. Arguments are a randomized `sobolstate` state variables /// and a list of random lower triangular matrices consisting of 0s and 1s. `dimension` is /// passed explicitly again. Tensor& _sobol_engine_scramble_(Tensor& sobolstate, const Tensor& ltm, int64_t dimension) { TORCH_CHECK(sobolstate.dtype() == at::kLong, "sobolstate needs to be of type ", at::kLong); /// Require a tensor accessor for `sobolstate` auto ss_a = sobolstate.accessor(); /// For every tensor in the list of tensors, the diagonals are made 1 /// Require a dot product of every row with a specific vector of each of the matrices in `ltm`. /// Instead, we perform an element-wise product of all the matrices and sum over the last dimension. /// The required product of the m^{th} row in the d^{th} square matrix in `ltm` can be accessed /// using ltm_d_a[d][m] m and d are zero-indexed Tensor diag_true = ltm.clone(at::MemoryFormat::Contiguous); diag_true.diagonal(0, -2, -1).fill_(1); Tensor ltm_dots = cdot_pow2(diag_true); auto ltm_d_a = ltm_dots.accessor(); /// Main scrambling loop for (const auto d : c10::irange(dimension)) { for (const auto j : c10::irange(MAXBIT)) { int64_t vdj = ss_a[d][j], l = 1, t2 = 0; for (int64_t p = MAXBIT - 1; p >= 0; --p) { int64_t lsmdp = ltm_d_a[d][p]; int64_t t1 = 0; for (const auto k : c10::irange(MAXBIT)) { t1 += (bitsubseq(lsmdp, k, 1) * bitsubseq(vdj, k, 1)); } t1 = t1 % 2; t2 = t2 + t1 * l; l = l << 1; } ss_a[d][j] = t2; } } return sobolstate; } /// This is a core function to initialize the main state variable of a `SobolEngine`. /// `dimension` is passed explicitly as well (see why above) Tensor& _sobol_engine_initialize_state_(Tensor& sobolstate, int64_t dimension) { TORCH_CHECK(sobolstate.dtype() == at::kLong, "sobolstate needs to be of type ", at::kLong); /// Use a tensor accessor for `sobolstate` auto ss_a = sobolstate.accessor(); /// First row of `sobolstate` is all 1s for (const auto m : c10::irange(MAXBIT)) { ss_a[0][m] = 1; } /// Remaining rows of sobolstate (row 2 through dim, indexed by [1:dim]) for (const auto d : c10::irange(1, dimension)) { int64_t p = poly[d]; int64_t m = bit_length(p) - 1; // First m elements of row d comes from initsobolstate for (const auto i : c10::irange(m)) { ss_a[d][i] = initsobolstate[d][i]; } // Fill in remaining elements of v as in Section 2 (top of pg. 90) of: // P. Bratley and B. L. Fox. Algorithm 659: Implementing sobol's // quasirandom sequence generator. ACM Trans. // Math. Softw., 14(1):88-100, Mar. 1988. for (const auto j : c10::irange(m, MAXBIT)) { int64_t newv = ss_a[d][j - m]; int64_t pow2 = 1; for (const auto k : c10::irange(m)) { pow2 <<= 1; if ((p >> (m - 1 - k)) & 1) { newv = newv ^ (pow2 * ss_a[d][j - k - 1]); } } ss_a[d][j] = newv; } } /// Multiply each column of sobolstate by power of 2: /// sobolstate * [2^(maxbit-1), 2^(maxbit-2),..., 2, 1] Tensor pow2s = at::pow( 2, at::native::arange( (MAXBIT - 1), -1, -1, optTypeMetaToScalarType(sobolstate.options().dtype_opt()), sobolstate.options().layout_opt(), sobolstate.options().device_opt(), sobolstate.options().pinned_memory_opt())); sobolstate.mul_(pow2s); return sobolstate; } } // namespace at::native