#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/ExpandUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/add_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/mul_native.h>
#endif

#if !AT_MKLDNN_ENABLED()

namespace at {
namespace native {

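// Fallback stubs: compiled when ATen is built without MKLDNN support, so
// every entry point fails loudly instead of silently doing nothing.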
Tensor& mkldnn_add_out(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha,
    Tensor& result
    ) {
  TORCH_CHECK(false, "mkldnn_add_out: ATen not compiled with MKLDNN support");
}

Tensor mkldnn_add(const Tensor& self, const Tensor& other, const Scalar& alpha) {
  TORCH_CHECK(false, "mkldnn_add: ATen not compiled with MKLDNN support");
}

Tensor& mkldnn_add_(Tensor& self, const Tensor& other, const Scalar& alpha) {
  TORCH_CHECK(false, "mkldnn_add_: ATen not compiled with MKLDNN support");
}

Tensor& mkldnn_mul_out(const Tensor& self, const Tensor& other, Tensor& result) {
  TORCH_CHECK(false, "mkldnn_mul_out: ATen not compiled with MKLDNN support");
}

Tensor mkldnn_mul(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(false, "mkldnn_mul: ATen not compiled with MKLDNN support");
}

Tensor& mkldnn_mul_(Tensor& self, const Tensor& other) {
  TORCH_CHECK(false, "mkldnn_mul_: ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>

namespace at {
namespace native {

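// Builds an empty MKLDNN result for binary ops whose inputs contain no
// elements: the output shape follows broadcasting (infer_size) and the
// dtype follows standard type promotion. Gradients are rejected, so this
// path is inference-only.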
static Tensor emptyBinaryOp(const Tensor& self, const Tensor& other) {
  if (!self.requires_grad() && !other.requires_grad()) {
    auto out_size = infer_size(self.sizes(), other.sizes());
    auto out_dtype = promoteTypes(
        c10::typeMetaToScalarType(self.dtype()),
        c10::typeMetaToScalarType(other.dtype()));
    TORCH_CHECK(
        self.device() == other.device(),
        "Expected same device for binary mkldnn op");
    return empty_mkldnn(
        out_size,
        out_dtype,
        self.options().layout_opt(),
        self.options().device_opt(),
        self.options().pinned_memory_opt());
  } else {
    TORCH_CHECK(
        false,
        "MKLDNN does not support binary ops with a zero-element Tensor in training");
  }
}

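// Computes result = self + alpha * other with oneDNN's sum primitive.
// If `result` aliases `other`, the source order is swapped (and the scales
// permuted to match) so that the destination aliases the first source,
// which keeps the in-place sum well-defined.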
Tensor& mkldnn_add_out(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha,
    Tensor& result
    ) {
  ideep::tensor& x = itensor_from_mkldnn(self);
  ideep::tensor& y = itensor_from_mkldnn(other);

  ideep::tensor& z = itensor_from_mkldnn(result);
  if (result.is_same(other)) {
    const std::vector<float> scales{alpha.to<float>(), 1.0};
    ideep::sum::compute(scales, {y, x}, z);
  } else {
    const std::vector<float> scales{1.0, alpha.to<float>()};
    ideep::sum::compute(scales, {x, y}, z);
  }

  return result;
}

Tensor mkldnn_add(const Tensor& self, const Tensor& other, const Scalar& alpha) {
  if (self.numel() == 0 || other.numel() == 0) {
    return emptyBinaryOp(self, other);
  }

  ideep::tensor& x = itensor_from_mkldnn(self);
  ideep::tensor& y = itensor_from_mkldnn(other);

  ideep::tensor z;
  const std::vector<float> scales{1.0, alpha.to<float>()};
  ideep::sum::compute(scales, {x, y}, z);

  return new_with_itensor_mkldnn(std::move(z), optTypeMetaToScalarType(self.options().dtype_opt()),
                                 self.options().device_opt());
}

Tensor& mkldnn_add_(Tensor& self, const Tensor& other, const Scalar& alpha) {
  return native::mkldnn_add_out(self, other, alpha, self);
}
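
// Usage sketch (illustrative, not from this file; assumes MKLDNN-layout
// tensors created with Tensor::to_mkldnn() in inference mode):
//   at::Tensor a = at::rand({2, 3}).to_mkldnn();
//   at::Tensor b = at::rand({2, 3}).to_mkldnn();
//   at::Tensor c = at::add(a, b, /*alpha=*/2.0);  // dispatches to mkldnn_add
//   a.add_(b);                                    // dispatches to mkldnn_add_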
120 
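// Elementwise multiply. A zero-dimensional `other` is handled with oneDNN's
// eltwise_linear (a scale-by-constant pass); any other shape must match
// `self` exactly, since this path does not implement broadcasting.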
Tensor& mkldnn_mul_out(const Tensor& self, const Tensor& other, Tensor& result) {
  TORCH_CHECK(result.sizes() == self.sizes(),
             "mkldnn_mul_out: the output size should be the same as the input size");
  ideep::tensor& z = itensor_from_mkldnn(result);
  ideep::tensor& x = itensor_from_mkldnn(self);

  // for a zero-dim tensor: multiply by its scalar value via a linear eltwise pass
  if (other.ndimension() == 0) {
    ideep::eltwise_forward::compute(
      x, z, ideep::algorithm::eltwise_linear,
      ideep::prop_kind::forward_inference, /*alpha*/ other.item().to<float>());

    return result;
  } else {
    TORCH_CHECK(self.sizes() == other.sizes(),
               "mkldnn_mul_out: currently mkldnn does not support broadcasting");
    ideep::tensor y = itensor_from_mkldnn(other);
    ideep::binary::compute(x, y, z, dnnl::algorithm::binary_mul);

    return result;
  }
}

Tensor mkldnn_mul(const Tensor& self, const Tensor& other) {
  if (self.numel() == 0 || other.numel() == 0) {
    return emptyBinaryOp(self, other);
  }
  Tensor result = empty_mkldnn(self.sizes(), optTypeMetaToScalarType(self.options().dtype_opt()),
                               self.options().layout_opt(), self.options().device_opt(),
                               self.options().pinned_memory_opt());
  return native::mkldnn_mul_out(self, other, result);
}

Tensor& mkldnn_mul_(Tensor& self, const Tensor& other) {
  return native::mkldnn_mul_out(self, other, self);
}
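
// Usage sketch (illustrative; same assumptions as the add example above):
//   at::Tensor d = at::mul(a, b);  // same-shape elementwise multiply
//   a.mul_(b);                     // in-place variant, writes back into `a`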

} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED