#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/VulkanPackedContext.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

/*
 * Packed context for the Vulkan batch normalization op.
 *
 * Derives from VulkanPackedContext (whose unpack() it overrides) and from
 * torch::jit::CustomClassHolder. Constructor and pack() are only declared
 * here; their definitions live outside this header.
 */
class BatchNormPackedContext final : virtual public VulkanPackedContext,
                                     public torch::jit::CustomClassHolder {
 private:
  // List handed back by unpack().
  // NOTE(review): presumably populated by the constructor with the original
  // (unpacked) arguments in ListArgs order - confirm against the .cpp.
  c10::impl::GenericList unpacked_;

 public:
  /*
   * Builds a context from the optional batch-norm tensors and epsilon.
   * The parameter order matches the indices named in ListArgs.
   */
  BatchNormPackedContext(
      const std::optional<Tensor>& weight_opt,
      const std::optional<Tensor>& bias_opt,
      const std::optional<Tensor>& running_mean_opt,
      const std::optional<Tensor>& running_var_opt,
      double eps);

  /*
   * Assigns a name to each index in the packed/unpacked list.
   */
  struct ListArgs final {
    static constexpr uint32_t kWeight = 0u;
    static constexpr uint32_t kBias = 1u;
    static constexpr uint32_t kRunningMean = 2u;
    static constexpr uint32_t kRunningVar = 3u;
    static constexpr uint32_t kEps = 4u;

    // Total number of named entries above.
    static constexpr uint32_t kNumArgs = 5u;
  };

  /*
   * Creates a context from a generic list.
   * NOTE(review): the list is presumably expected to follow the ListArgs
   * layout (kNumArgs entries) - verify against the definition.
   */
  static BatchNormPackedContext pack(c10::impl::GenericList);

  /*
   * Returns unpacked_ by value.
   *
   * TORCH_CHECK fails with "unpacked_ does not have any elements!" when the
   * list is empty.
   */
  const c10::impl::GenericList unpack() const override {
    TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");

    return unpacked_;
  }
};

/*
 * Creates a reference-counted BatchNormPackedContext.
 *
 * The momentum and cudnn_enable parameters are unnamed in this declaration;
 * cudnn_enable is marked deprecated by the inline comment.
 * NOTE(review): how `training` is handled is not visible from this header -
 * confirm against the definition.
 */
c10::intrusive_ptr<BatchNormPackedContext> create_batchnorm_context(
    std::optional<Tensor>&& weight_opt,
    std::optional<Tensor>&& bias_opt,
    std::optional<Tensor>&& running_mean_opt,
    std::optional<Tensor>&& running_var_opt,
    bool training,
    double /* momentum */,
    double eps,
    bool /* cudnn_enable, deprecated */);

/*
 * Runs batch normalization on input_arg using a previously created context
 * and returns the resulting Tensor.
 */
Tensor run_batchnorm_context(
    const Tensor& input_arg,
    const c10::intrusive_ptr<BatchNormPackedContext>& context);

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */