#pragma once
#include <ATen/ExpandUtils.h>
#include <ATen/native/CanUse32BitIndexMath.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>

namespace at::native {

[[noreturn]]
static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
  TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx,
  " does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx);
}


static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
  // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
  std::vector<Tensor> result;
  for (const auto& index_opt : indices) {
    if (!index_opt.has_value()) {
      result.emplace_back();
    } else {
      const auto& index = *index_opt;
      if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
        if (index.scalar_type() == kByte) {
          TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
          " please use a dtype torch.bool instead.");
        }
        // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
        // corresponding dimensions in self
        for (const auto j : c10::irange(index.dim())) {
          int64_t srcIdx = static_cast<int64_t>(result.size() + j);
          if (index.size(j) != self.size(srcIdx)) {
            invalid_mask(self, srcIdx, index, j);
          }
        }
        // Replace with nonzeros
        auto nonzero = index.nonzero();
        for (const auto j : c10::irange(index.dim())) {
          result.emplace_back(nonzero.select(1, j));
        }
      } else {
        result.emplace_back(index);
      }
    }
  }
  return result;
}
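
// Illustrative example of the expansion above: given self of shape [2, 3] and
// a single kBool index of shape [2] holding {true, false}, the shape check
// passes (index.size(0) == self.size(0)) and the mask is replaced by
// index.nonzero().select(1, 0), i.e. a Long tensor holding {0}, so downstream
// indexing gathers row 0 of self. A 2-d mask would expand into two Long index
// tensors, one per mask dimension.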

static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
  for (const auto& tensor : indices) {
    if (tensor.has_value() && tensor->defined()) {
      auto scalarType = tensor->scalar_type();
      if (allow_int) {
        if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
            TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
        }
      } else {
        if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
            TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
        }
      }
    }
  }
}
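
// In short: kLong, kByte, and kBool indices always pass the check above;
// kInt passes only when allow_int is true; any other dtype (e.g. kFloat)
// raises an IndexError via TORCH_CHECK_INDEX.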

inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
  torch::List<std::optional<Tensor>> result;
  result.reserve(list.size());
  for (const Tensor& a : list) {
    result.push_back(a);
  }
  return result;
}
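
// Overload for IValue lists (e.g. arguments arriving through the dispatcher):
// entries that are not Tensors become nullopt.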
inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
  torch::List<std::optional<Tensor>> result;
  result.reserve(list.size());
  for (const IValue& a : list) {
    result.push_back(a.isTensor() ? std::optional<Tensor>(a.toTensor()) : std::optional<Tensor>());
  }
  return result;
}
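
// For example, {null, a, b, null} has a contiguous subspace (all defined
// tensors are adjacent), while {a, null, b} does not.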
static C10_UNUSED bool hasContiguousSubspace(TensorList tl) {
  // true if all the non-null tensors are adjacent
  auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
  auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
  auto start = std::find_if(tl.begin(), tl.end(), isDefined);
  auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
  auto it = std::find_if(start, stop.base(), isNull);
  return it == stop.base();
}


// Transposes the tensor and indices together so that all the non-null indices
// index the first k dimensions of the tensor. Returns the transposed tensor
// and the reordered indices. For example:
// transposeToFront(tensor, {nullptr, a, nullptr, b})
// returns
// tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
static C10_UNUSED std::tuple<Tensor, std::vector<Tensor>>
transposeToFront(const Tensor& self, TensorList indices) {
  std::vector<int64_t> dims;
  std::vector<Tensor> transposedIndices;
  dims.reserve(self.dim());
  for (const auto i : c10::irange(self.dim())) {
    if (indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back(indices[i]);
    }
  }
  for (const auto i : c10::irange(self.dim())) {
    if (!indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back();
    }
  }
  return std::make_tuple(self.permute(dims), std::move(transposedIndices));
}
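
// Like transposeToFront, but also returns the inverse permutation invPerm,
// so that output.permute(invPerm) restores the original dimension order.
// For example:
// transposeToFrontAndInvPerm(tensor, {nullptr, a, nullptr, b})
// returns
// tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}, [2, 0, 3, 1]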
inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) {
  std::vector<int64_t> dims;
  std::vector<int64_t> invPerm;
  std::vector<Tensor> transposedIndices;
  dims.reserve(self.dim());
  invPerm.resize(self.dim());
  for (const auto i : c10::irange(self.dim())) {
    if (indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back(indices[i]);
    }
  }
  for (const auto i : c10::irange(self.dim())) {
    if (!indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back();
    }
  }
  for (const auto i : c10::irange(self.dim())) {
    invPerm[dims[i]] = i;
  }
  return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
}
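
// Bundles precomputed state for an advanced-indexing kernel: the (possibly
// restrided) source tensor, the expanded Long index tensors, the sizes and
// strides of the indexed dimensions, and the number of un-indexed dimensions
// before and after them. The constructor is defined out of line.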
struct AdvancedIndex {
  AdvancedIndex(const Tensor& src, TensorList indices);

  Tensor src;
  std::vector<Tensor> indices;
  DimVector indexed_sizes;
  DimVector indexed_strides;
  int64_t dims_before;
  int64_t dims_after;
};


} //namespace at::native