#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/IndexKernel.h>
#include <ATen/native/TensorAdvancedIndexing.h> // For at::native::index_out
#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/CUDAFunctions.h>
#else
#include <ATen/ops/index_cuda_dispatch.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/masked_scatter_native.h>
#include <ATen/ops/masked_select_native.h>
#endif

namespace at::native {

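// masked_select is implemented as advanced indexing with a boolean index
// tensor: after broadcasting mask and self against each other, the gather
// below is equivalent to self[mask] and is performed by the CUDA index
// kernel via at::cuda::index_out.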
static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
  NoNamesGuard guard;

  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
              "masked_select: expected BoolTensor for mask");
  TORCH_CHECK(self.scalar_type() == result.scalar_type(),
              "masked_select(): self and result must have the same scalar type");

  auto mask_temp = (mask.dim() == 0)
    ? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0))
    : c10::MaybeOwned<Tensor>::borrowed(mask);
  auto self_temp = (self.dim() == 0)
    ? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0))
    : c10::MaybeOwned<Tensor>::borrowed(self);
  // Cannot reassign mask_temp and self_temp here! If they are owning and
  // expand_outplace returns a borrow, the returned borrow would dangle.
  auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
  at::cuda::index_out(
      result, *std::get<1>(mask_self_expanded),
      c10::List<std::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));

  return result;
}

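// Out-of-place variant: starts from an empty result tensor, which the
// indexing kernel resizes to hold one element per true mask entry.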
Tensor masked_select_cuda(const Tensor & self, const Tensor & mask) {
  namedinference::compute_broadcast_outnames(self, mask);
  Tensor result = at::empty({0}, self.options());
  return masked_select_out_cuda_impl(result, self, mask);
}

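// Variant that writes into a caller-provided result tensor; its dtype must
// match self (checked in the impl above).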
Tensor & masked_select_out_cuda(const Tensor & self, const Tensor & mask, Tensor & result) {
  namedinference::compute_broadcast_outnames(self, mask);
  return masked_select_out_cuda_impl(result, self, mask);
}

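// In-place masked_scatter_: writes elements of source, taken in order, into
// the positions of self where mask is true.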
Tensor & masked_scatter__cuda(Tensor& self, const Tensor& mask, const Tensor& source) {
  at::assert_no_internal_overlap(self);
  TORCH_CHECK(
      self.scalar_type() == source.scalar_type(),
      "masked_scatter_: expected self and source to have same dtypes but got ",
      self.scalar_type(),
      " and ",
      source.scalar_type());
  TORCH_CHECK(mask.dtype() == ScalarType::Bool, "masked_scatter_ only supports boolean masks, "
      "but got mask with dtype ", mask.dtype());

  c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_scatter_");

  if (self.numel() == 0) {
    return self;
  }

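  // Scratch buffer for a prefix sum over the (expanded) mask, which gives
  // each true mask position the index of the source element to copy there.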
  auto maskPrefixSum = at::empty(self.sizes(), mask.options().dtype(kLong));
  launch_masked_scatter_kernel(self, *b_mask, maskPrefixSum, source);

  return self;
}

} // namespace at::native