xref: /aosp_15_r20/external/pytorch/aten/src/ATen/WrapDimUtilsMulti.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/WrapDimUtils.h>
4 #include <c10/core/TensorImpl.h>
5 #include <c10/util/irange.h>
6 #include <bitset>
7 #include <sstream>
8 
9 namespace at {
10 
11 // This is in an extra file to work around strange interaction of
12 // bitset on Windows with operator overloading
13 
// Number of bits in the bitset produced by dim_list_to_bitset; that function
// rejects any ndims larger than this value.
constexpr size_t dim_bitset_size{64};
15 
dim_list_to_bitset(OptionalIntArrayRef opt_dims,size_t ndims)16 inline std::bitset<dim_bitset_size> dim_list_to_bitset(
17     OptionalIntArrayRef opt_dims,
18     size_t ndims) {
19   TORCH_CHECK(
20       ndims <= dim_bitset_size,
21       "only tensors with up to ",
22       dim_bitset_size,
23       " dims are supported");
24   std::bitset<dim_bitset_size> seen;
25   if (opt_dims.has_value()) {
26     auto dims = opt_dims.value();
27     for (const auto i : c10::irange(dims.size())) {
28       size_t dim = maybe_wrap_dim(dims[i], static_cast<int64_t>(ndims));
29       TORCH_CHECK(
30           !seen[dim],
31           "dim ",
32           dim,
33           " appears multiple times in the list of dims");
34       seen[dim] = true;
35     }
36   } else {
37     for (size_t dim = 0; dim < ndims; dim++) {
38       seen[dim] = true;
39     }
40   }
41   return seen;
42 }
43 
44 } // namespace at
45