1 #pragma once
2 #include <ATen/ExpandUtils.h>
3 #include <ATen/native/CanUse32BitIndexMath.h>
4 #include <ATen/native/TensorIterator.h>
5 #include <ATen/core/IListRef.h>
6 #include <c10/util/irange.h>
7
8 namespace at::native {
9
10 [[noreturn]]
invalid_mask(const Tensor & self,int64_t idx,const Tensor & mask,int64_t maskIdx)11 static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
12 TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx,
13 " does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx);
14 }
15
16
expandTensors(const Tensor & self,IOptTensorListRef indices)17 static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
18 // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
19 std::vector<Tensor> result;
20 for (const auto& index_opt : indices) {
21 if (!index_opt.has_value()) {
22 result.emplace_back();
23 } else {
24 const auto& index = *index_opt;
25 if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
26 if (index.scalar_type() == kByte) {
27 TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
28 " please use a dtype torch.bool instead.");
29 }
30 // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
31 // corresponding dimensions in self
32 for (const auto j : c10::irange(index.dim())) {
33 int64_t srcIdx = static_cast<int64_t>(result.size() + j);
34 if (index.size(j) != self.size(srcIdx)) {
35 invalid_mask(self, srcIdx, index, j);
36 }
37 }
38 // Replace with nonzeros
39 auto nonzero = index.nonzero();
40 for (const auto j : c10::irange(index.dim())) {
41 result.emplace_back(nonzero.select(1, j));
42 }
43 } else {
44 result.emplace_back(index);
45 }
46 }
47 }
48 return result;
49 }
50
51 static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
52 for (const auto& tensor : indices) {
53 if (tensor.has_value() && tensor->defined()) {
54 auto scalarType = tensor->scalar_type();
55 if (allow_int) {
56 if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
57 TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
58 }
59 } else {
60 if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
61 TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
62 }
63 }
64 }
65 }
66 }
67
toListOfOptionalTensors(ArrayRef<Tensor> list)68 inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
69 torch::List<std::optional<Tensor>> result;
70 result.reserve(list.size());
71 for (const Tensor& a : list) {
72 result.push_back(a);
73 }
74 return result;
75 }
76
toListOfOptionalTensors(ArrayRef<IValue> list)77 inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
78 torch::List<std::optional<Tensor>> result;
79 result.reserve(list.size());
80 for (const IValue& a : list) {
81 result.push_back(a.isTensor() ? std::optional<Tensor>(a.toTensor()) : std::optional<Tensor>());
82 }
83 return result;
84 }
85
hasContiguousSubspace(TensorList tl)86 static C10_UNUSED bool hasContiguousSubspace(TensorList tl) {
87 // true if all the non-null tensors are adjacent
88 auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
89 auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
90 auto start = std::find_if(tl.begin(), tl.end(), isDefined);
91 auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
92 auto it = std::find_if(start, stop.base(), isNull);
93 return it == stop.base();
94 }
95
96
97 // Transposes the tensor and indices together so that all the non-null indices
98 // index the first k dimensions of the tensor. Returns the transposed tensor
99 // and the reordered indices. For example:
100 // transposeToFront(tensor, {nullptr, a, nullptr, b})
101 // returns
102 // tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
103 static C10_UNUSED std::tuple<Tensor, std::vector<Tensor>>
transposeToFront(const Tensor & self,TensorList indices)104 transposeToFront(const Tensor& self, TensorList indices) {
105 std::vector<int64_t> dims;
106 std::vector<Tensor> transposedIndices;
107 dims.reserve(self.dim());
108 for (const auto i : c10::irange(self.dim())) {
109 if (indices[i].defined()) {
110 dims.push_back(i);
111 transposedIndices.emplace_back(indices[i]);
112 }
113 }
114 for (const auto i : c10::irange(self.dim())) {
115 if (!indices[i].defined()) {
116 dims.push_back(i);
117 transposedIndices.emplace_back();
118 }
119 }
120 return std::make_tuple(self.permute(dims), std::move(transposedIndices));
121 }
122
123 inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
transposeToFrontAndInvPerm(const Tensor & self,TensorList indices)124 transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) {
125 std::vector<int64_t> dims;
126 std::vector<int64_t> invPerm;
127 std::vector<Tensor> transposedIndices;
128 dims.reserve(self.dim());
129 invPerm.resize(self.dim());
130 for (const auto i : c10::irange(self.dim())) {
131 if (indices[i].defined()) {
132 dims.push_back(i);
133 transposedIndices.emplace_back(indices[i]);
134 }
135 }
136 for (const auto i : c10::irange(self.dim())) {
137 if (!indices[i].defined()) {
138 dims.push_back(i);
139 transposedIndices.emplace_back();
140 }
141 }
142 for (const auto i : c10::irange(self.dim())) {
143 invPerm[dims[i]] = i;
144 }
145 return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
146 }
147
148 struct AdvancedIndex {
149 AdvancedIndex(const Tensor& src, TensorList indices);
150
151 Tensor src;
152 std::vector<Tensor> indices;
153 DimVector indexed_sizes;
154 DimVector indexed_strides;
155 int64_t dims_before;
156 int64_t dims_after;
157 };
158
159
160 } //namespace at::native
161