xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/ops/Batchnorm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_VULKAN_API
4 
5 #include <ATen/native/vulkan/ops/Common.h>
6 #include <ATen/native/vulkan/ops/VulkanPackedContext.h>
7 #include <torch/library.h>
8 
9 namespace at {
10 namespace native {
11 namespace vulkan {
12 namespace ops {
13 
14 class BatchNormPackedContext final : virtual public VulkanPackedContext,
15                                      public torch::jit::CustomClassHolder {
16  private:
17   c10::impl::GenericList unpacked_;
18 
19  public:
20   BatchNormPackedContext(
21       const std::optional<Tensor>& weight_opt,
22       const std::optional<Tensor>& bias_opt,
23       const std::optional<Tensor>& running_mean_opt,
24       const std::optional<Tensor>& running_var_opt,
25       double eps);
26 
27   /*
28    * Assigns a name to each index in the packed/unpacked list.
29    */
30   struct ListArgs final {
31     static constexpr uint32_t kWeight = 0u;
32     static constexpr uint32_t kBias = 1u;
33     static constexpr uint32_t kRunningMean = 2u;
34     static constexpr uint32_t kRunningVar = 3u;
35     static constexpr uint32_t kEps = 4u;
36 
37     static constexpr uint32_t kNumArgs = 5u;
38   };
39 
40   static BatchNormPackedContext pack(c10::impl::GenericList);
41 
unpack()42   const c10::impl::GenericList unpack() const override {
43     TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
44 
45     return unpacked_;
46   }
47 };
48 
49 c10::intrusive_ptr<BatchNormPackedContext> create_batchnorm_context(
50     std::optional<Tensor>&& weight_opt,
51     std::optional<Tensor>&& bias_opt,
52     std::optional<Tensor>&& running_mean_opt,
53     std::optional<Tensor>&& running_var_opt,
54     bool training,
55     double /* momentum */,
56     double eps,
57     bool /* cudnn_enable, deprecated */);
58 
59 Tensor run_batchnorm_context(
60     const Tensor& input_arg,
61     const c10::intrusive_ptr<BatchNormPackedContext>& context);
62 
63 } // namespace ops
64 } // namespace vulkan
65 } // namespace native
66 } // namespace at
67 
68 #endif /* USE_VULKAN_API */
69