xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TensorAdvancedIndexingUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <ATen/native/IndexingUtils.h>
4 #include <ATen/native/TensorIterator.h>
5 
6 namespace at::native {
7 namespace {
8 #ifndef STRIP_ERROR_MESSAGES
shapes_as_str(TensorList tensors)9 inline std::string shapes_as_str(TensorList tensors) {
10   std::ostringstream os;
11   bool first = true;
12   for (auto& tensor : tensors) {
13     if (tensor.defined()) {
14       if (!first) {
15         os << ", ";
16       }
17       os << tensor.sizes();
18       first = false;
19     }
20   }
21   return os.str();
22 }
23 #endif
24 } // anonymous namespace
25 
canDispatchToMaskedFill(const Tensor & self,const torch::List<std::optional<at::Tensor>> & indices,const Tensor & value)26 inline std::tuple<bool, Tensor> canDispatchToMaskedFill(const Tensor& self, const torch::List<std::optional<at::Tensor>>& indices,
27 const Tensor& value){
28   if (!(value.numel() ==1 && value.device().is_cpu())){
29     return std::make_tuple(false,Tensor());
30   }
31   int64_t num_ind = 0;
32   Tensor mask;
33   auto self_device = self.device();
34   for (const std::optional<Tensor>& i: indices) {
35     if (!i.has_value() || !(*i).defined()){
36       num_ind++;
37     } else {
38       const Tensor &index = *i;
39       if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
40           index.device() != self_device || mask.defined()){
41         return std::make_tuple(false, Tensor());
42       } else {
43         mask = index;
44         for (const auto j : c10::irange(index.dim())) {
45           int64_t srcIdx = num_ind + j;
46           TORCH_CHECK_INDEX(index.size(j) == self.size(srcIdx), "The shape of the mask ", index.sizes(), " at index ", j,
47   " does not match the shape of the indexed tensor ", self.sizes(), " at index ", srcIdx);
48         }
49         num_ind += mask.ndimension();
50       }
51     }
52   }
53   for (C10_UNUSED const auto i : c10::irange(num_ind, self.ndimension())) {
54     mask = mask.unsqueeze(-1);
55   }
56   return std::make_tuple(true, mask);
57 }
58 
make_info(Tensor self,IOptTensorListRef orig)59 inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
60   checkIndexTensorTypes(orig, /*allow_int*/ true);
61   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
62   auto indices = expandTensors(self, orig);
63   // next broadcast all index tensors together
64   try {
65     indices = expand_outplace(indices);
66   } catch (std::exception& e) {
67     TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
68                    " with shapes ", shapes_as_str(indices));
69   }
70   // add missing null Tensors so that it matches self.dim()
71   while (indices.size() < (size_t)self.dim()) {
72     indices.emplace_back();
73   }
74   // if the non-null indices are not all adjacent, transpose self and indices
75   // together so that they're adjacent at the front
76   if (!hasContiguousSubspace(indices)) {
77     std::tie(self, indices) = transposeToFront(self, indices);
78   }
79   // Ensure indices are on the same device as self
80   for (auto & indice : indices) {
81     if (indice.defined() && indice.device() != self.device()) {
82       indice = indice.to(self.device());
83     }
84   }
85   for (auto & indice : indices) {
86     if (indice.defined() && indice.dtype() == at::kInt) {
87       indice = indice.to(at::kLong);
88     }
89   }
90 
91   return AdvancedIndex(self, indices);
92 }
93 
94 } // namespace at::native
95