#include <ATen/Context.h>
#include <ATen/native/vulkan/ops/Batchnorm.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

namespace batchnorm {

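// Mirrors the uniform block consumed by the batchnorm compute shader:
// the output texture extents, the number of channel texels (channels / 4),
// and the epsilon added to the variance term.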
struct Params final {
  api::utils::ivec3 out_extents;
  int32_t c4;
  float eps;
};

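// Records a single batchnorm compute dispatch into the current Vulkan command
// buffer: normalizes v_input with the prepacked running statistics, applies
// the affine weight and bias, and writes the result into v_output.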
static void record_op(
    api::Context* const context,
    vTensor& v_output,
    const vTensor& v_input,
    const vTensor& v_weight,
    const vTensor& v_bias,
    const vTensor& v_running_mean,
    const vTensor& v_running_var,
    const float eps) {
  api::PipelineBarrier pipeline_barrier{};

  api::utils::uvec3 global_size = v_output.extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

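  // Channels are packed four to a texel in the backing image, so the shader
  // indexes the channel dimension in groups of four (hence the divisibility
  // check in batch_norm below).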
  uint32_t num_features = get_dim<Dim4D::Channel>(v_input.sizes());
  uint32_t channels_ext = api::utils::div_up(num_features, 4u);

  Params block{
      api::utils::make_ivec3(v_output.extents()),
      api::utils::safe_downcast<int32_t>(channels_ext),
      eps,
  };

  api::UniformParamsBuffer params(context, block);

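  // The argument order below must match the descriptor set layout declared by
  // the batchnorm shader: the output image first, then the five input images,
  // then the params uniform buffer.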
  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(batchnorm),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_running_mean.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_running_var.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());
}

} // namespace batchnorm

namespace {

using namespace api::utils;

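// Eval-mode implementation of aten::batch_norm for the Vulkan backend. The
// optional weight, bias, and running statistics are packed into a
// BatchNormPackedContext on every call before the dispatch is recorded.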
Tensor batch_norm(
    const at::Tensor& input_arg,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    const std::optional<Tensor>& running_mean_opt /* optional */,
    const std::optional<Tensor>& running_var_opt /* optional */,
    bool training,
    double /* momentum, not used in eval mode */,
    double eps,
    bool /* cudnn_enable, deprecated */) {
  TORCH_CHECK(!training, "Only evaluation mode is supported!");
  TORCH_CHECK(input_arg.dim() == 4, "Input must have dim == 4!");
  TORCH_CHECK(
      get_dim<Dim4D::Channel>(input_arg) % 4 == 0,
      "Input must have channels divisible by 4!");

  return run_batchnorm_context(
      input_arg,
      c10::make_intrusive<BatchNormPackedContext>(BatchNormPackedContext(
          weight_opt, bias_opt, running_mean_opt, running_var_opt, eps)));
}

#ifdef USE_VULKAN_API

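// Route aten::batch_norm to the Vulkan implementation above when the input is
// dispatched with the Vulkan key.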
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::batch_norm"), TORCH_FN(batch_norm));
}

#endif /* USE_VULKAN_API */

} // namespace

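// Validates the optional batchnorm arguments, reshapes each 1-D tensor to
// {C, 1, 1} for texture packing, and stores the resulting Vulkan tensors
// (plus epsilon) in the packed argument list.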
BatchNormPackedContext::BatchNormPackedContext(
    const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt,
    double eps)
    : unpacked_{c10::AnyType::get()} {
  packed_.reserve(ListArgs::kNumArgs);

  // Each optional tensor arg, if provided, should be a 1-dimensional tensor.
  // To achieve more efficient packing as a texture, they are first reshaped
  // to {N, 1, 1}. Eventually this rearrangement should happen automatically
  // in vTensor itself.

  // Weight
  TORCH_CHECK(weight_opt, "Weight must be provided!");
  TORCH_CHECK(weight_opt->dim() == 1, "Weight must have ndim == 1!");

  const int64_t num_features =
      api::utils::safe_downcast<int64_t>(weight_opt->numel());
  const Tensor weight_3d = weight_opt->reshape({num_features, 1, 1});
  packed_.emplace_back(weight_3d.vulkan());

  // Bias
  TORCH_CHECK(bias_opt, "Bias must be provided!");
  TORCH_CHECK(bias_opt->dim() == 1, "Bias must have ndim == 1!");
  TORCH_CHECK(
      bias_opt->numel() == num_features,
      "Bias must have the same numel as weight!");

  const Tensor bias_3d = bias_opt->reshape({num_features, 1, 1});
  packed_.emplace_back(bias_3d.vulkan());

  // Running Mean
  TORCH_CHECK(running_mean_opt, "Running mean must be provided!");
  TORCH_CHECK(running_mean_opt->dim() == 1, "Running mean must have ndim == 1");
  TORCH_CHECK(
      running_mean_opt->numel() == num_features,
      "Running mean must have the same numel as weight!");

  const Tensor running_mean_3d =
      running_mean_opt->reshape({num_features, 1, 1});
  packed_.emplace_back(running_mean_3d.vulkan());

  // Running var
  TORCH_CHECK(running_var_opt, "Running var must be provided!");
  TORCH_CHECK(running_var_opt->dim() == 1, "Running var must have ndim == 1");
  TORCH_CHECK(
      running_var_opt->numel() == num_features,
      "Running var must have the same numel as weight!");

  const Tensor running_var_3d = running_var_opt->reshape({num_features, 1, 1});
  packed_.emplace_back(running_var_3d.vulkan());

  // Epsilon
  packed_.emplace_back(eps);

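  // Keep the original (unpacked) arguments around as well, unless the global
  // context says prepacked weights may be released to save memory.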
  if (!at::globalContext().releaseWeightsWhenPrepacking()) {
    unpacked_.reserve(ListArgs::kNumArgs);
    unpacked_.emplace_back(weight_opt);
    unpacked_.emplace_back(bias_opt);
    unpacked_.emplace_back(running_mean_opt);
    unpacked_.emplace_back(running_var_opt);
    unpacked_.emplace_back(eps);
  }
}

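// Rebuilds a packed context from a previously unpacked argument list,
// typically when a prepacked module is reloaded.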
BatchNormPackedContext BatchNormPackedContext::pack(
    c10::impl::GenericList unpacked) {
  return BatchNormPackedContext(
      get_optional_tensor(unpacked, ListArgs::kWeight),
      get_optional_tensor(unpacked, ListArgs::kBias),
      get_optional_tensor(unpacked, ListArgs::kRunningMean),
      get_optional_tensor(unpacked, ListArgs::kRunningVar),
      unpacked.get(ListArgs::kEps).toDouble());
}

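// Factory used by the prepacking flow: wraps the batchnorm arguments in a
// BatchNormPackedContext so the CPU-to-Vulkan transfer happens once, up front,
// rather than on every forward call.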
c10::intrusive_ptr<BatchNormPackedContext> create_batchnorm_context(
    std::optional<Tensor>&& weight_opt,
    std::optional<Tensor>&& bias_opt,
    std::optional<Tensor>&& running_mean_opt,
    std::optional<Tensor>&& running_var_opt,
    bool training,
    double /* momentum */,
    double eps,
    bool /* cudnn_enable, deprecated */) {
  return c10::make_intrusive<BatchNormPackedContext>(BatchNormPackedContext(
      weight_opt, bias_opt, running_mean_opt, running_var_opt, eps));
}

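// Executes batchnorm on Vulkan using a previously created packed context:
// unwraps the packed Vulkan tensors, allocates the output, and records the
// compute dispatch.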
Tensor run_batchnorm_context(
    const Tensor& input_arg,
    const c10::intrusive_ptr<BatchNormPackedContext>& batchnorm_context) {
  api::Context* const context = api::context();

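  // convert() unwraps the vTensor backing each Vulkan tensor; the weight,
  // bias, and running statistics were already moved to Vulkan storage when
  // the context was packed.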
  const vTensor& v_input = convert(input_arg);

  const vTensor& v_weight = convert(
      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kWeight)
          .toTensor());

  const vTensor& v_bias = convert(
      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kBias)
          .toTensor());

  const vTensor& v_running_mean = convert(
      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kRunningMean)
          .toTensor());

  const vTensor& v_running_var = convert(
      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kRunningVar)
          .toTensor());

  const float eps = api::utils::safe_downcast<float>(
      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kEps)
          .toDouble());

  vTensor v_output{
      context,
      v_input.sizes(),
      v_input.dtype(),
  };

  batchnorm::record_op(
      context,
      v_output,
      v_input,
      v_weight,
      v_bias,
      v_running_mean,
      v_running_var,
      eps);

  return convert(v_output);
}

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at