// aten/src/ATen/native/vulkan/ops/Batchnorm.cpp
#include <ATen/Context.h>
#include <ATen/native/vulkan/ops/Batchnorm.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

namespace batchnorm {

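// Uniform parameter block consumed by the batchnorm compute shader: the output
// texture extents, the number of packed channel groups (ceil(C / 4)), and the
// epsilon term.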
struct Params final {
  api::utils::ivec3 out_extents;
  int32_t c4;
  float eps;
};

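// Records a dispatch of the batchnorm compute shader that normalizes v_input
// with the provided weight, bias, running mean and running variance textures,
// writing the result into v_output.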
static void record_op(
    api::Context* const context,
    vTensor& v_output,
    const vTensor& v_input,
    const vTensor& v_weight,
    const vTensor& v_bias,
    const vTensor& v_running_mean,
    const vTensor& v_running_var,
    const float eps) {
  api::PipelineBarrier pipeline_barrier{};

  api::utils::uvec3 global_size = v_output.extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

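  // Channels are packed four to a texel in the Vulkan backend, so the shader
  // iterates over ceil(num_features / 4) channel groups.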
  uint32_t num_features = get_dim<Dim4D::Channel>(v_input.sizes());
  uint32_t channels_ext = api::utils::div_up(num_features, 4u);

  Params block{
      api::utils::make_ivec3(v_output.extents()),
      api::utils::safe_downcast<int32_t>(channels_ext),
      eps,
  };

  api::UniformParamsBuffer params(context, block);

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(batchnorm),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_running_mean.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_running_var.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());
}

} // namespace batchnorm

namespace {

using namespace api::utils;

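// Implementation of aten::batch_norm for the Vulkan backend. Only inference
// (eval mode) is supported, and the input must be a 4D tensor whose channel
// count is divisible by 4 so that it maps onto the channel-packed texture
// layout.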
Tensor batch_norm(
    const at::Tensor& input_arg,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    const std::optional<Tensor>& running_mean_opt /* optional */,
    const std::optional<Tensor>& running_var_opt /* optional */,
    bool training,
    double /* momentum, not used in eval mode */,
    double eps,
    bool /* cudnn_enable, deprecated */) {
  TORCH_CHECK(!training, "Only evaluation mode is supported!");
  TORCH_CHECK(input_arg.dim() == 4, "Input must have dim == 4!");
  TORCH_CHECK(
      get_dim<Dim4D::Channel>(input_arg) % 4 == 0,
      "Input must have channels divisible by 4!");

  return run_batchnorm_context(
      input_arg,
      c10::make_intrusive<BatchNormPackedContext>(BatchNormPackedContext(
          weight_opt, bias_opt, running_mean_opt, running_var_opt, eps)));
}

#ifdef USE_VULKAN_API

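// Register the Vulkan implementation of aten::batch_norm with the dispatcher.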
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::batch_norm"), TORCH_FN(batch_norm));
}

#endif /* USE_VULKAN_API */

} // namespace

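// Prepacks the batchnorm parameters (weight, bias, running mean, running
// variance, and epsilon) as Vulkan tensors so they can be reused across calls
// to run_batchnorm_context.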
BatchNormPackedContext::BatchNormPackedContext(
    const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt,
    double eps)
    : unpacked_{c10::AnyType::get()} {
  packed_.reserve(ListArgs::kNumArgs);

  // Each optional tensor arg, if provided, should be a 1-dimensional tensor.
  // To achieve more efficient packing as a texture, they are first reshaped to
  // {N, 1, 1}. Eventually this rearrangement should happen automatically in
  // vTensor itself.

  // Weight
  TORCH_CHECK(weight_opt, "Weight must be provided!");
  TORCH_CHECK(weight_opt->dim() == 1, "Weight must have ndim == 1!");

  const int64_t num_features =
      api::utils::safe_downcast<int64_t>(weight_opt->numel());
  const Tensor weight_3d = weight_opt->reshape({num_features, 1, 1});
  packed_.emplace_back(weight_3d.vulkan());

  // Bias
  TORCH_CHECK(bias_opt, "Bias must be provided!");
  TORCH_CHECK(bias_opt->dim() == 1, "Bias must have ndim == 1!");
  TORCH_CHECK(
      bias_opt->numel() == num_features,
      "Bias must have the same numel as weight!");

  const Tensor bias_3d = bias_opt->reshape({num_features, 1, 1});
  packed_.emplace_back(bias_3d.vulkan());

  // Running Mean
  TORCH_CHECK(running_mean_opt, "Running mean must be provided!");
  TORCH_CHECK(running_mean_opt->dim() == 1, "Running mean must have ndim == 1");
  TORCH_CHECK(
      running_mean_opt->numel() == num_features,
      "Running mean must have the same numel as weight!");

  const Tensor running_mean_3d =
      running_mean_opt->reshape({num_features, 1, 1});
  packed_.emplace_back(running_mean_3d.vulkan());

  // Running var
  TORCH_CHECK(running_var_opt, "Running var must be provided!");
  TORCH_CHECK(running_var_opt->dim() == 1, "Running var must have ndim == 1");
  TORCH_CHECK(
      running_var_opt->numel() == num_features,
      "Running var must have the same numel as weight!");

  const Tensor running_var_3d = running_var_opt->reshape({num_features, 1, 1});
  packed_.emplace_back(running_var_3d.vulkan());

  // Epsilon
  packed_.emplace_back(eps);

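  // Keep the unpacked arguments around only if the global context has not been
  // configured to release weights after prepacking.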
  if (!at::globalContext().releaseWeightsWhenPrepacking()) {
    unpacked_.reserve(ListArgs::kNumArgs);
    unpacked_.emplace_back(weight_opt);
    unpacked_.emplace_back(bias_opt);
    unpacked_.emplace_back(running_mean_opt);
    unpacked_.emplace_back(running_var_opt);
    unpacked_.emplace_back(eps);
  }
}

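// Rebuilds a BatchNormPackedContext from a generic list holding the unpacked
// arguments, in the order defined by ListArgs.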
BatchNormPackedContext BatchNormPackedContext::pack(
    c10::impl::GenericList unpacked) {
  return BatchNormPackedContext(
      get_optional_tensor(unpacked, ListArgs::kWeight),
      get_optional_tensor(unpacked, ListArgs::kBias),
      get_optional_tensor(unpacked, ListArgs::kRunningMean),
      get_optional_tensor(unpacked, ListArgs::kRunningVar),
      unpacked.get(ListArgs::kEps).toDouble());
}

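// Factory for creating a prepacked batchnorm context; training, momentum, and
// cudnn_enable are accepted only so the signature mirrors aten::batch_norm and
// are ignored here.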
c10::intrusive_ptr<BatchNormPackedContext> create_batchnorm_context(
    std::optional<Tensor>&& weight_opt,
    std::optional<Tensor>&& bias_opt,
    std::optional<Tensor>&& running_mean_opt,
    std::optional<Tensor>&& running_var_opt,
    bool training,
    double /* momentum */,
    double eps,
    bool /* cudnn_enable, deprecated */) {
  return c10::make_intrusive<BatchNormPackedContext>(BatchNormPackedContext(
      weight_opt, bias_opt, running_mean_opt, running_var_opt, eps));
}

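// Runs the batchnorm shader on input_arg using the tensors stored in the
// prepacked context and returns the result as a Vulkan tensor.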
Tensor run_batchnorm_context(
    const Tensor& input_arg,
    const c10::intrusive_ptr<BatchNormPackedContext>& batchnorm_context) {
  api::Context* const context = api::context();

  const vTensor& v_input = convert(input_arg);

  const vTensor& v_weight = convert(
      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kWeight)
          .toTensor());

  const vTensor& v_bias = convert(
      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kBias)
          .toTensor());

  const vTensor& v_running_mean = convert(
      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kRunningMean)
          .toTensor());

  const vTensor& v_running_var = convert(
      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kRunningVar)
          .toTensor());

  const float eps = api::utils::safe_downcast<float>(
      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kEps)
          .toDouble());

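  // Allocate the output vTensor with the same sizes and dtype as the input.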
  vTensor v_output{
      context,
      v_input.sizes(),
      v_input.dtype(),
  };

  batchnorm::record_op(
      context,
      v_output,
      v_input,
      v_weight,
      v_bias,
      v_running_mean,
      v_running_var,
      eps);

  return convert(v_output);
}

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at