// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/ops/MaskedFill.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <torch/library.h>
#include <vector>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

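// Fills elements of the (possibly broadcast) copy of `self_arg` with `value`
// wherever `mask_arg` is True, and returns the result as a Vulkan tensor.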
Tensor masked_fill_scalar(
    const Tensor& self_arg,
    const Tensor& mask_arg,
    const Scalar& value) {
  utils::is_broadcastable(self_arg, mask_arg);

  api::Context* const context = api::context();

  const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();

  const Tensor mask = mask_arg.is_vulkan() ? mask_arg : mask_arg.vulkan();
  const vTensor& v_mask = convert(mask);

  // compute the output shape by broadcasting the shapes of self and mask
  auto in_ndims = safe_downcast<uint32_t>(self_arg.dim());
  auto in_sizes = self_arg.sizes();
  auto mask_sizes = mask_arg.sizes();
  std::vector<int64_t> out_sizes = utils::broadcast_size(self_arg, mask_arg);
  TORCH_INTERNAL_ASSERT(!out_sizes.empty(), "output shape is empty!");
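  // For illustration: a self of shape {3, 1, 5} and a mask of shape {2, 1}
  // broadcast to an output shape of {3, 2, 5}.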

  // generalize the shape of output and mask to 4D
  uvec4 generalized_out_sizes{1u, 1u, 1u, 1u},
      generalized_mask_sizes{1u, 1u, 1u, 1u};
  int add_out_ndims = static_cast<int>(4 - out_sizes.size());
  for (int i = 0; (unsigned)i < out_sizes.size(); i++) {
    generalized_out_sizes.data[i + add_out_ndims] = out_sizes[i];
  }
  int add_mask_ndims = static_cast<int>(4 - mask_sizes.size());
  for (int i = 0; (unsigned)i < mask_sizes.size(); i++) {
    generalized_mask_sizes.data[i + add_mask_ndims] = mask_sizes[i];
  }
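  // For illustration: an output shape of {3, 2, 5} is generalized to
  // {1, 3, 2, 5} and a mask shape of {2, 1} to {1, 1, 2, 1}.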

  auto out_ndims = safe_downcast<uint32_t>(out_sizes.size());

  // channels of mask and output after padding to nearest multiple of 4
  uint32_t mask_c_aligned =
      api::utils::align_up(generalized_mask_sizes.data[1u], 4u);
  uint32_t out_c_aligned =
      api::utils::align_up(generalized_out_sizes.data[1u], 4u);
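  // For illustration: generalized channel counts of 3 (output) and 1 (mask)
  // are both aligned up to 4.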

  // compute the repeats needed to produce a tensor of out_sizes by repeating
  // self
  auto add_ndims = out_ndims - in_ndims;
  std::vector<int64_t> repeats;
  for (int i = 0; (unsigned)i < out_ndims; i++) {
    if ((unsigned)i < add_ndims || in_sizes[i - add_ndims] == 1) {
      repeats.push_back(out_sizes[i]);
    } else {
      repeats.push_back(1);
    }
  }
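  // For illustration: a self of shape {3, 1, 5} with out_sizes {3, 2, 5} gives
  // add_ndims == 0 and repeats == {1, 2, 1}, i.e. self is tiled along dim 1.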

  // generate the output of out_sizes by repeating self
  at::Tensor out = self.repeat(repeats);
  vTensor& v_out = convert(out);

  const struct Block final {
    ivec3 outExtents;
    int32_t fill0;
    ivec3 maskExtents;
    int32_t fill1;
    uvec4 outTensorSize;
    uvec4 maskTensorSize;
    uvec2 alignedChannelInfo;
    float value;
  } block{
      api::utils::make_ivec3(v_out.extents()),
      0,
      api::utils::make_ivec3(v_mask.extents()),
      0,
      generalized_out_sizes,
      generalized_mask_sizes,
      {out_c_aligned, mask_c_aligned},
      value.to<float>(),
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};
  // One possible implementation of masked_fill is to repeat the mask into a
  // broadcasted mask of the same shape as the output, and then fill elements
  // of the output with value wherever the mask is True. However, repeating the
  // mask would incur extra time and space overhead. Instead, the shader
  // traverses the original mask and computes the corresponding broadcasted
  // positions in the output tensor whenever a mask value is True.
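  // For illustration: with a mask of shape {1, 1, 2, 1} and an output of shape
  // {1, 3, 2, 5}, a True mask element at index (0, 0, h, 0) fills output
  // positions (0, c, h, w) for every c in [0, 3) and w in [0, 5).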
  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(masked_fill),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_mask.extents(),
      // local work group size
      adaptive_work_group_size(v_mask.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_out.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
      v_mask.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_out);
}

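// Tensor-overload of masked_fill: accepts a 0-dimensional value tensor and
// forwards its scalar value to masked_fill_scalar.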
Tensor masked_fill_tensor(
    const Tensor& self_arg,
    const Tensor& mask_arg,
    const Tensor& value) {
  TORCH_CHECK(
      value.dim() == 0,
      "masked_fill only supports a 0-dimensional value tensor, but got tensor with ",
      value.dim(),
      " dimension(s).");
  return masked_fill_scalar(self_arg, mask_arg, value.item<float>());
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::masked_fill.Scalar"),
      TORCH_FN(masked_fill_scalar));
  m.impl(
      TORCH_SELECTIVE_NAME("aten::masked_fill.Tensor"),
      TORCH_FN(masked_fill_tensor));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at