#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/TensorIterator.h>
#include <c10/util/irange.h>

#include <sstream>
#include <string>

namespace at::native {
namespace {
#ifndef STRIP_ERROR_MESSAGES
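// Renders the sizes of every defined tensor in `tensors` as a comma-separated
// list; used only to build indexing error messages, so it is compiled out
// when STRIP_ERROR_MESSAGES is defined.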
inline std::string shapes_as_str(TensorList tensors) {
  std::ostringstream os;
  bool first = true;
  for (auto& tensor : tensors) {
    if (tensor.defined()) {
      if (!first) {
        os << ", ";
      }
      os << tensor.sizes();
      first = false;
    }
  }
  return os.str();
}
#endif
} // anonymous namespace
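
// Checks whether an index_put with these arguments can be lowered to the much
// cheaper masked_fill_: this requires a single boolean/byte mask index (on the
// same device as self, with a shape matching the dimensions it covers) and a
// one-element CPU value. Returns {true, mask} on success, with the mask
// unsqueezed so it broadcasts against self; otherwise {false, Tensor()}.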
inline std::tuple<bool, Tensor> canDispatchToMaskedFill(
    const Tensor& self,
    const torch::List<std::optional<at::Tensor>>& indices,
    const Tensor& value) {
  if (!(value.numel() == 1 && value.device().is_cpu())) {
    return std::make_tuple(false, Tensor());
  }
  int64_t num_ind = 0;
  Tensor mask;
  auto self_device = self.device();
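  // Scan the indices: undefined entries just consume a dimension of self,
  // while the single defined entry must be a byte or bool mask on the same
  // device whose shape matches the dimensions it covers. A second defined
  // index disqualifies the fast path.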
  for (const std::optional<Tensor>& i : indices) {
    if (!i.has_value() || !(*i).defined()) {
      num_ind++;
    } else {
      const Tensor& index = *i;
      if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
          index.device() != self_device || mask.defined()) {
        return std::make_tuple(false, Tensor());
      } else {
        mask = index;
        for (const auto j : c10::irange(index.dim())) {
          int64_t srcIdx = num_ind + j;
          TORCH_CHECK_INDEX(
              index.size(j) == self.size(srcIdx),
              "The shape of the mask ", index.sizes(), " at index ", j,
              " does not match the shape of the indexed tensor ",
              self.sizes(), " at index ", srcIdx);
        }
        num_ind += mask.ndimension();
      }
    }
  }
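  // Pad the mask with trailing singleton dimensions so that it broadcasts
  // against any dimensions of self not covered by the indices.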
  for (C10_UNUSED const auto i : c10::irange(num_ind, self.ndimension())) {
    mask = mask.unsqueeze(-1);
  }
  return std::make_tuple(true, mask);
}
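
// Usage sketch (illustrative, not part of this header): a caller implementing
// index_put can try the fast path before building the generic kernel.
//
//   auto dispatch = canDispatchToMaskedFill(self, indices, value);
//   if (std::get<0>(dispatch)) {
//     return self.masked_fill_(std::get<1>(dispatch), value.item());
//   }
//   // ...otherwise fall through to the generic path via make_info.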
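
// Normalizes the optional index tensors for generic advanced indexing:
// byte/bool masks are expanded into long index tensors, all indices are
// broadcast to a common shape, missing trailing indices are filled in as
// undefined tensors, non-adjacent indices are transposed to the front, and
// every index is moved to self's device with int32 indices promoted to int64.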
inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
  checkIndexTensorTypes(orig, /*allow_int*/ true);
  // first expand BoolTensor (masks) or ByteTensor (masks) into one or more LongTensors
  auto indices = expandTensors(self, orig);
  // next broadcast all index tensors together
  try {
    indices = expand_outplace(indices);
  } catch (std::exception&) {
    TORCH_CHECK_INDEX(
        false,
        "shape mismatch: indexing tensors could not be broadcast together"
        " with shapes ", shapes_as_str(indices));
  }
  // add missing null tensors so that the list matches self.dim()
  while (indices.size() < (size_t)self.dim()) {
    indices.emplace_back();
  }
  // if the non-null indices are not all adjacent, transpose self and indices
  // together so that they're adjacent at the front
  if (!hasContiguousSubspace(indices)) {
    std::tie(self, indices) = transposeToFront(self, indices);
  }
  // ensure the indices are on the same device as self
  for (auto& indice : indices) {
    if (indice.defined() && indice.device() != self.device()) {
      indice = indice.to(self.device());
    }
  }
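  // The indexing kernels expect int64 indices, so promote any int32 index
  // tensors that checkIndexTensorTypes allowed through.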
  for (auto& indice : indices) {
    if (indice.defined() && indice.dtype() == at::kInt) {
      indice = indice.to(at::kLong);
    }
  }

  return AdvancedIndex(self, indices);
}
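
// Usage sketch (illustrative; `make_index_iterator` is a hypothetical helper
// standing in for whatever builds a TensorIterator out of the AdvancedIndex):
//
//   auto info = make_info(self, indices);
//   auto iter = make_index_iterator(info);
//   // ...run the index/index_put kernel over `iter`.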

} // namespace at::native