// aten/src/ATen/native/cuda/IndexKernel.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/IndexKernel.h>
#include <ATen/native/TensorAdvancedIndexing.h>  // For at::native::index_out
#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/CUDAFunctions.h>
#else
#include <ATen/ops/index_cuda_dispatch.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/masked_scatter_native.h>
#include <ATen/ops/masked_select_native.h>
#endif


namespace at::native {

static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
  NoNamesGuard guard;

  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
              "masked_select: expected BoolTensor for mask");
  TORCH_CHECK(self.scalar_type() == result.scalar_type(),
              "masked_select(): self and result must have the same scalar type");

  // Promote 0-dim mask/self to 1-D so that expansion and boolean indexing
  // below behave uniformly (masked_select always returns a 1-D tensor).
  auto mask_temp = (mask.dim() == 0)
    ? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0))
    : c10::MaybeOwned<Tensor>::borrowed(mask);
  auto self_temp = (self.dim() == 0)
    ? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0))
    : c10::MaybeOwned<Tensor>::borrowed(self);

  // Cannot reassign to mask_temp and self_temp here! If they are
  // owning and expand_outplace returns a borrow, the returned borrow
  // would dangle.
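  //
  // A hypothetical sketch of the hazard (not code from this file):
  //   std::tie(mask_temp, self_temp) = expand_outplace(*mask_temp, *self_temp);
  // If mask_temp owns its Tensor and expand_outplace returns a borrow of it,
  // the assignment destroys the owned Tensor while the borrow still points
  // at it.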
  auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
  at::cuda::index_out(
      result, *std::get<1>(mask_self_expanded),
      c10::List<std::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));

  return result;
}
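
// The impl above realizes masked_select as advanced indexing with a single
// boolean index. A minimal illustrative sketch of that equivalence through
// the public ATen API (hypothetical example, not part of this file; assumes
// a CUDA build):
//
//   Tensor t = at::arange(6, at::device(at::kCUDA)).reshape({2, 3});
//   Tensor m = t.gt(2);                          // bool mask, same shape
//   Tensor a = at::masked_select(t, m);          // 1-D: {3, 4, 5}
//   c10::List<std::optional<Tensor>> idx({m});
//   Tensor b = at::index(t, idx);                // same values as a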

Tensor masked_select_cuda(const Tensor & self, const Tensor & mask) {
  namedinference::compute_broadcast_outnames(self, mask);
  Tensor result = at::empty({0}, self.options());
  return masked_select_out_cuda_impl(result, self, mask);
}

Tensor & masked_select_out_cuda(const Tensor & self, const Tensor & mask, Tensor & result) {
  namedinference::compute_broadcast_outnames(self, mask);
  return masked_select_out_cuda_impl(result, self, mask);
}
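
// Hedged usage sketch for the two entry points above (illustrative only;
// all variable names are local to the example):
//
//   Tensor self = at::randn({4}, at::device(at::kCUDA));
//   Tensor mask = self.gt(0);
//   Tensor out1 = at::masked_select(self, mask);   // functional variant
//   Tensor out2 = at::empty({0}, self.options());
//   at::masked_select_out(out2, self, mask);       // out= variant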

Tensor & masked_scatter__cuda(Tensor& self, const Tensor& mask, const Tensor& source) {
  at::assert_no_internal_overlap(self);
  TORCH_CHECK(
      self.scalar_type() == source.scalar_type(),
      "masked_scatter_: expected self and source to have same dtypes but got ",
      self.scalar_type(),
      " and ",
      source.scalar_type());
  TORCH_CHECK(mask.dtype() == ScalarType::Bool, "masked_scatter_ only supports boolean masks, "
     "but got mask with dtype ", mask.dtype());

  c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_scatter_");

  // Nothing to scatter into.
  if (self.numel() == 0) {
    return self;
  }

  // Scratch buffer: the kernel computes a prefix sum over the broadcast mask
  // here, which maps each true position to its read offset into source.
  auto maskPrefixSum = at::empty(self.sizes(), mask.options().dtype(kLong));
  launch_masked_scatter_kernel(self, *b_mask, maskPrefixSum, source);

  return self;
}
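
// Hedged usage sketch for masked_scatter_ (illustrative only): source is
// consumed element-by-element in row-major order at the positions where
// mask is true, so it must provide at least mask.sum() elements:
//
//   Tensor self = at::zeros({2, 2}, at::device(at::kCUDA));
//   Tensor mask = at::tensor({true, false, false, true},
//                            at::device(at::kCUDA)).reshape({2, 2});
//   Tensor src  = at::tensor({1.0f, 2.0f}, at::device(at::kCUDA));
//   self.masked_scatter_(mask, src);   // self is now {{1, 0}, {0, 2}}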

}  // namespace at::native