#pragma once

#include <ATen/core/IListRef.h>
#include <ATen/core/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/irange.h>

namespace at {

// If dim_post_expr is 0 and wrap_scalar is true, then dim must be in the
// range [-1, 0]. This is a special case for scalar tensors and manifests in
// e.g. torch.sum(scalar_tensor, 0). Otherwise, dim should be in the range
// [-dim_post_expr, dim_post_expr-1].
using c10::maybe_wrap_dim;

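// Usage sketch (illustrative values, following the semantics described
// above):
//   at::maybe_wrap_dim(-1, /*dim_post_expr=*/3); // returns 2
//   at::maybe_wrap_dim(1, 3);                    // returns 1
//   at::maybe_wrap_dim(-1, 0);                   // scalar special case: 0
//   // at::maybe_wrap_dim(3, 3) raises IndexError: valid range is [-3, 2]
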
// Wrap `dim` against the rank of `tensor`.
inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) {
  return maybe_wrap_dim(dim, tensor->dim());
}

inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
  if (tensors.empty()) {
    // Can't wrap an empty TensorList; rely on the underlying implementation
    // to throw an error if necessary.
    return dim;
  }
  return maybe_wrap_dim(dim, tensors[0].dim());
}

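// Usage sketch (illustrative): the wrap is performed against the rank of the
// first tensor in the list.
//   // given a TensorList where tensors[0].dim() == 4:
//   at::maybe_wrap_dim(-2, tensors); // returns 2
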
inline int64_t maybe_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<int64_t>>& tensor_sizes) {
  if (tensor_sizes.empty()) {
    // Can't wrap an empty list; rely on the underlying implementation to
    // throw an error if necessary.
    return dim;
  }
  return maybe_wrap_dim(dim, tensor_sizes[0].size());
}

// Given an array of dimensions `dims` of length `ndims`, this function wraps
// each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be
// specified using negative indices.
//
// Additionally, if `wrap_scalars` is true, scalar tensors with rank 0 will
// allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised
// for dimensions not in the range [-dim_post_expr, dim_post_expr).
inline void maybe_wrap_dims_n(
    int64_t* dims,
    int64_t ndims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  if (dim_post_expr <= 0) {
    if (wrap_scalars) {
      dim_post_expr = 1; // this will make range [-1, 0]
    } else {
      TORCH_CHECK_INDEX(
          ndims == 0,
          "Dimension specified as ",
          dims[0],
          " but tensor has no dimensions");
      return;
    }
  }
  int64_t min = -dim_post_expr;
  int64_t max = dim_post_expr - 1;
  for (const auto i : c10::irange(ndims)) {
    auto& dim = dims[i];
    if (dim < min || dim > max) {
      TORCH_CHECK_INDEX(
          false,
          "Dimension out of range (expected to be in range of [",
          min,
          ", ",
          max,
          "], but got ",
          dim,
          ")");
    }
    if (dim < 0)
      dim += dim_post_expr;
  }
}

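// Usage sketch (illustrative values): wrapping a raw array of dims for a
// rank-3 tensor.
//   int64_t dims[2] = {1, -1};
//   at::maybe_wrap_dims_n(dims, /*ndims=*/2, /*dim_post_expr=*/3);
//   // dims is now {1, 2}
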
// Given a contiguous container of dimensions `dims`, this function wraps
// each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be
// specified using negative indices.
//
// Additionally, if `wrap_scalars` is true, scalar tensors with rank 0 will
// allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised
// for dimensions not in the range [-dim_post_expr, dim_post_expr).
template <typename Container>
inline void maybe_wrap_dims(
    Container& dims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  return maybe_wrap_dims_n(
      dims.data(), dims.size(), dim_post_expr, wrap_scalars);
}

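// Usage sketch (illustrative): the container variant forwards to
// maybe_wrap_dims_n via data()/size().
//   std::vector<int64_t> dims = {0, -1, -2};
//   at::maybe_wrap_dims(dims, /*dim_post_expr=*/4);
//   // dims is now {0, 3, 2}
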
// Previously, size [0] tensors were the only possible empty tensors; thus, it
// wasn't possible to cat empty tensors unless all the other tensors were
// 1-dimensional, so we allowed these tensors to be "skipped" (both for wrap
// dimension behavior and dimension size checking). We maintain this behavior
// for backwards compatibility, but only for this specific size (i.e. other
// empty sizes are not skipped).
inline int64_t legacy_cat_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<int64_t>>& tensor_sizes) {
  for (auto& sizes : tensor_sizes) {
    if (sizes.size() == 1 && sizes[0] == 0) {
      continue;
    }
    return maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
  }
  return dim;
}

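// Usage sketch (illustrative values): size-[0] entries are skipped, and the
// wrap is performed against the first non-skipped entry.
//   std::vector<std::vector<int64_t>> sizes = {{0}, {2, 3}};
//   at::legacy_cat_wrap_dim(-1, sizes); // returns 1 (wrapped against rank 2)
//   // if every entry is {0}, dim is returned unchanged
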
// SymInt variant of legacy_cat_wrap_dim.
inline int64_t legacy_cat_wrap_dim_symint(
    int64_t dim,
    const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
  for (auto& sizes : tensor_sizes) {
    if (sizes.size() == 1) {
      if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[0].sym_eq(0))) {
        continue;
      }
    }
    return maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
  }
  return dim;
}

// Variant of legacy_cat_wrap_dim that inspects the tensors directly.
inline int64_t legacy_cat_wrap_dim(
    int64_t dim,
    const MaterializedITensorListRef& tensors) {
  for (const Tensor& tensor : tensors) {
    if (tensor.dim() == 1) {
      if (TORCH_GUARD_SIZE_OBLIVIOUS(tensor.sym_sizes()[0].sym_eq(0))) {
        continue;
      }
    }
    return maybe_wrap_dim(dim, tensor.dim());
  }
  return dim;
}

// Wrap all (possibly negative) dims in a vector, in-place.
inline void wrap_all_dims(
    std::vector<int64_t>& dims_to_wrap,
    int64_t tensor_total_dims) {
  for (const auto i : c10::irange(dims_to_wrap.size())) {
    dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims);
  }
}

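// Usage sketch (illustrative values): wrapping every dim for a rank-3
// tensor.
//   std::vector<int64_t> dims = {-1, 0};
//   at::wrap_all_dims(dims, /*tensor_total_dims=*/3);
//   // dims is now {2, 0}
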
} // namespace at