#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <torch/library.h>
#include <vector>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

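// Vulkan implementation of masked_fill: produces a tensor with the broadcasted
// shape of self and mask, where every element covered by a True mask value is
// replaced with the given scalar. Registered below (under USE_VULKAN_API) as
// the Vulkan kernel for aten::masked_fill.Scalar.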
Tensor masked_fill_scalar(
    const Tensor& self_arg,
    const Tensor& mask_arg,
    const Scalar& value) {
  utils::is_broadcastable(self_arg, mask_arg);

  api::Context* const context = api::context();

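  // Move both inputs to the Vulkan backend if they are not already there; the
  // shader consumes the mask image directly and writes into the output image.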
  const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();

  const Tensor mask = mask_arg.is_vulkan() ? mask_arg : mask_arg.vulkan();
  const vTensor& v_mask = convert(mask);

  // compute the output shape by broadcasting the shapes of self and mask
  auto in_ndims = safe_downcast<uint32_t>(self_arg.dim());
  auto in_sizes = self_arg.sizes();
  auto mask_sizes = mask_arg.sizes();
  std::vector<int64_t> out_sizes = utils::broadcast_size(self_arg, mask_arg);
  TORCH_INTERNAL_ASSERT(!out_sizes.empty(), "output shape is empty!");

  // generalize the shapes of the output and the mask to 4D
  uvec4 generalized_out_sizes{1u, 1u, 1u, 1u},
      generalized_mask_sizes{1u, 1u, 1u, 1u};
  int add_out_ndims = static_cast<int>(4 - out_sizes.size());
  for (int i = 0; (unsigned)i < out_sizes.size(); i++) {
    generalized_out_sizes.data[i + add_out_ndims] = out_sizes[i];
  }
  int add_mask_ndims = static_cast<int>(4 - mask_sizes.size());
  for (int i = 0; (unsigned)i < mask_sizes.size(); i++) {
    generalized_mask_sizes.data[i + add_mask_ndims] = mask_sizes[i];
  }

  auto out_ndims = safe_downcast<uint32_t>(out_sizes.size());

  // channels of the mask and the output after padding to the nearest multiple of 4
  uint32_t mask_c_aligned =
      api::utils::align_up(generalized_mask_sizes.data[1u], 4u);
  uint32_t out_c_aligned =
      api::utils::align_up(generalized_out_sizes.data[1u], 4u);

  // compute the repeats needed to produce a tensor of out_sizes by repeating self
  auto add_ndims = out_ndims - in_ndims;
  std::vector<int64_t> repeats;
  for (int i = 0; (unsigned)i < out_ndims; i++) {
    if ((unsigned)i < add_ndims || in_sizes[i - add_ndims] == 1) {
      repeats.push_back(out_sizes[i]);
    } else {
      repeats.push_back(1);
    }
  }

  // generate the output of out_sizes by repeating self
  at::Tensor out = self.repeat(repeats);
  vTensor& v_out = convert(out);

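  // This struct mirrors the uniform block consumed by the masked_fill shader;
  // fill0/fill1 are explicit padding so each ivec3 occupies a full 16 bytes,
  // matching the shader's std140-style layout.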
  const struct Block final {
    ivec3 outExtents;
    int32_t fill0;
    ivec3 maskExtents;
    int32_t fill1;
    uvec4 outTensorSize;
    uvec4 maskTensorSize;
    uvec2 alignedChannelInfo;
    float value;
  } block{
      api::utils::make_ivec3(v_out.extents()),
      0,
      api::utils::make_ivec3(v_mask.extents()),
      0,
      generalized_out_sizes,
      generalized_mask_sizes,
      {out_c_aligned, mask_c_aligned},
      value.to<float>(),
  };

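  // Upload the shader parameters to a uniform buffer; the pipeline barrier is
  // populated by the image() accesses recorded below for the compute dispatch.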
  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  // One possible implementation of masked_fill is to repeat the mask into a
  // broadcasted mask of the same shape as the output, and then fill elements
  // of the output with value where the mask is True. However, repeating the
  // mask would incur extra time and space overhead. Instead, the shader
  // traverses the original mask and computes the corresponding broadcasted
  // positions in the output tensor whenever a mask value is True.
  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(masked_fill),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_mask.extents(),
      // local work group size
      adaptive_work_group_size(v_mask.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_out.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
      v_mask.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_out);
}

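// The Tensor overload only accepts a 0-dimensional value tensor; it unwraps
// the value and delegates to the scalar kernel above.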
Tensor masked_fill_tensor(
    const Tensor& self_arg,
    const Tensor& mask_arg,
    const Tensor& value) {
  TORCH_CHECK(
      value.dim() == 0,
      "masked_fill only supports a 0-dimensional value tensor, but got tensor with ",
      value.dim(),
      " dimension(s).");
  return masked_fill_scalar(self_arg, mask_arg, value.item<float>());
}

#ifdef USE_VULKAN_API

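// Register the Vulkan kernels for both masked_fill overloads with the ATen
// dispatcher.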
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::masked_fill.Scalar"),
      TORCH_FN(masked_fill_scalar));
  m.impl(
      TORCH_SELECTIVE_NAME("aten::masked_fill.Tensor"),
      TORCH_FN(masked_fill_tensor));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at