xref: /aosp_15_r20/external/pytorch/aten/src/ATen/WrapDimUtilsMulti.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/WrapDimUtils.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/core/TensorImpl.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
6*da0073e9SAndroid Build Coastguard Worker #include <bitset>
7*da0073e9SAndroid Build Coastguard Worker #include <sstream>
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker namespace at {
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker // This is in an extra file to work around strange interaction of
12*da0073e9SAndroid Build Coastguard Worker // bitset on Windows with operator overloading
13*da0073e9SAndroid Build Coastguard Worker 
/// Number of bits in the bitsets returned by dim_list_to_bitset; this is also
/// the largest tensor rank (`ndims`) that dim_list_to_bitset accepts.
constexpr size_t dim_bitset_size{64};

dim_list_to_bitset(OptionalIntArrayRef opt_dims,size_t ndims)16*da0073e9SAndroid Build Coastguard Worker inline std::bitset<dim_bitset_size> dim_list_to_bitset(
17*da0073e9SAndroid Build Coastguard Worker     OptionalIntArrayRef opt_dims,
18*da0073e9SAndroid Build Coastguard Worker     size_t ndims) {
19*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
20*da0073e9SAndroid Build Coastguard Worker       ndims <= dim_bitset_size,
21*da0073e9SAndroid Build Coastguard Worker       "only tensors with up to ",
22*da0073e9SAndroid Build Coastguard Worker       dim_bitset_size,
23*da0073e9SAndroid Build Coastguard Worker       " dims are supported");
24*da0073e9SAndroid Build Coastguard Worker   std::bitset<dim_bitset_size> seen;
25*da0073e9SAndroid Build Coastguard Worker   if (opt_dims.has_value()) {
26*da0073e9SAndroid Build Coastguard Worker     auto dims = opt_dims.value();
27*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(dims.size())) {
28*da0073e9SAndroid Build Coastguard Worker       size_t dim = maybe_wrap_dim(dims[i], static_cast<int64_t>(ndims));
29*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(
30*da0073e9SAndroid Build Coastguard Worker           !seen[dim],
31*da0073e9SAndroid Build Coastguard Worker           "dim ",
32*da0073e9SAndroid Build Coastguard Worker           dim,
33*da0073e9SAndroid Build Coastguard Worker           " appears multiple times in the list of dims");
34*da0073e9SAndroid Build Coastguard Worker       seen[dim] = true;
35*da0073e9SAndroid Build Coastguard Worker     }
36*da0073e9SAndroid Build Coastguard Worker   } else {
37*da0073e9SAndroid Build Coastguard Worker     for (size_t dim = 0; dim < ndims; dim++) {
38*da0073e9SAndroid Build Coastguard Worker       seen[dim] = true;
39*da0073e9SAndroid Build Coastguard Worker     }
40*da0073e9SAndroid Build Coastguard Worker   }
41*da0073e9SAndroid Build Coastguard Worker   return seen;
42*da0073e9SAndroid Build Coastguard Worker }
43*da0073e9SAndroid Build Coastguard Worker 
44*da0073e9SAndroid Build Coastguard Worker } // namespace at
45