#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/Pool.h>
#include <c10/util/irange.h>

namespace at { namespace native {

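// Computes the output sizes of a pooling op that allows asymmetric
// left/right padding. The leading batch (N) and channel (C) dimensions
// are copied through unchanged; the remaining spatial dimensions are
// derived per-axis from the kernel, stride, padding, and dilation.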
std::vector<int64_t> pool_output_sizes(
    IntArrayRef input_size,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding_l,
    IntArrayRef padding_r,
    IntArrayRef dilation,
    bool ceil_mode) {
  std::vector<int64_t> output_size(input_size.size());
  // copy N and C
  output_size[0] = input_size[0];
  output_size[1] = input_size[1];

  for (const auto i : c10::irange(2, input_size.size())) {
    output_size[i] = pooling_output_shape_pad_lr<int64_t>(
      input_size[i],
      kernel_size[i - 2],
      padding_l[i - 2],
      padding_r[i - 2],
      stride[i - 2],
      dilation[i - 2],
      ceil_mode
    );
  }

  return output_size;
}

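// Validates the inputs to an mkldnn pointwise binary fusion: input,
// weight, other, and bias must agree in type, the input must live on
// the CPU, and its dtype must be float, bfloat16, or half.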
void check_mkldnn_binary_fusion_inputs(
    const Tensor& input,
    const Tensor& other,
    const Tensor& weight,
    const Tensor& bias) {
  if (!weight.is_mkldnn()) {
    TORCH_CHECK(
        input.options().type_equal(weight.options()),
        "Input type (",
        input.toString(),
        ") and weight type (",
        weight.toString(),
        ") should be the same");
  } else {
    TORCH_CHECK(
        input.scalar_type() == weight.scalar_type(),
        "mkldnn pointwise binary: input dtype and weight dtype should be the same");
  }
  TORCH_CHECK(
      input.options().type_equal(other.options()),
      "Input type (",
      input.toString(),
      ") and other type (",
      other.toString(),
      ") should be the same");
  TORCH_CHECK(
      !bias.defined() || (input.options().type_equal(bias.options())),
      "Input type (",
      input.toString(),
      ") and bias type (",
      bias.toString(),
      ") should be the same");
  TORCH_CHECK(
      input.device().is_cpu(),
      "mkldnn pointwise binary fusion: input's device should be CPU");
  TORCH_CHECK(
      input.scalar_type() == ScalarType::Float ||
          input.scalar_type() == ScalarType::BFloat16 ||
          input.scalar_type() == ScalarType::Half,
      "mkldnn pointwise binary: input's dtype should be float, bfloat16 or half");
  mkldnn_check_low_precision(input.scalar_type(), "mkldnn pointwise binary");
}

#if AT_MKLDNN_ENABLED()

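// Builds an AttrFunction for eltwise post-ops that take no extra
// arguments; the scalars and algorithm parameters are accepted but ignored.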
#define ATTR_FUNC(NAME)                              \
  [](torch::List<std::optional<at::Scalar>> scalars, \
     std::optional<c10::string_view> algorithm) {    \
    return ideep::attr_t::fuse_##NAME();             \
  }

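// leaky_relu reuses the relu post-op: negative_slope is forwarded as the
// eltwise alpha, which oneDNN's relu uses as the slope for negative inputs.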
AttrFunction attr_func_leaky_relu =
    [](torch::List<std::optional<at::Scalar>> scalars,
       std::optional<c10::string_view> algorithm) {
      TORCH_CHECK(
          scalars.size() == 1 &&
              scalars[0].get().toOptional<at::Scalar>().has_value(),
          "leaky_relu is expected to have one scalar input: negative_slope");
      auto alpha_value =
          scalars[0].get().toOptional<at::Scalar>().value().to<float>();
      return ideep::attr_t::fuse_relu(1.0, alpha_value);
    };

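// hardtanh lowers to a clamp post-op bounded by min_val and max_val.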
AttrFunction attr_func_hardtanh =
    [](torch::List<std::optional<at::Scalar>> scalars,
       std::optional<c10::string_view> algorithm) {
      TORCH_CHECK(
          scalars.size() == 2 &&
              scalars[0].get().toOptional<at::Scalar>().has_value() &&
              scalars[1].get().toOptional<at::Scalar>().has_value(),
          "hardtanh is expected to have two scalar inputs: min_val and max_val");

      auto lower_bound_value =
          scalars[0].get().toOptional<at::Scalar>().value().to<float>();
      auto upper_bound_value =
          scalars[1].get().toOptional<at::Scalar>().value().to<float>();
      return ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value);
    };

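// gelu selects between the erf-based ("none") and tanh-approximation
// ("tanh") oneDNN gelu algorithms based on the requested variant.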
AttrFunction attr_func_gelu = [](torch::List<std::optional<at::Scalar>> scalars,
                                 std::optional<c10::string_view> algorithm) {
  TORCH_CHECK(
      algorithm.has_value(),
      "gelu is expected to have one str input: algorithm");
  dnnl::algorithm gelu_type;
  if (algorithm.value() == "none") {
    gelu_type = dnnl::algorithm::eltwise_gelu_erf;
  } else if (algorithm.value() == "tanh") {
    gelu_type = dnnl::algorithm::eltwise_gelu_tanh;
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Unsupported gelu algorithm: ", algorithm.value());
  }

  return ideep::attr_t::fuse_gelu(1.0, 0.f, 0.f, gelu_type);
};

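// hardsigmoid is built as an explicit post-op list: eltwise_hardsigmoid
// with alpha = 1/6 and beta = 0.5, matching clamp(x / 6 + 1 / 2, 0, 1).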
AttrFunction attr_func_hardsigmoid =
    [](torch::List<std::optional<at::Scalar>> scalars,
       std::optional<c10::string_view> algorithm) {
      ideep::attr_t attr;
      ideep::post_ops po;
      po.append_eltwise(
          ideep::algorithm::eltwise_hardsigmoid, 1.0f / 6.0f, 0.5f);
      attr.set_post_ops(po);
      return attr;
    };

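// Maps a unary post-op name to the function that builds its ideep attr.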
const std::map<c10::string_view, AttrFunction>& fusion_unary_attr_map() {
  static const std::map<c10::string_view, AttrFunction> fusion_attr_map{
      {"relu", ATTR_FUNC(relu)},
      {"sigmoid", ATTR_FUNC(sigmoid)},
      {"tanh", ATTR_FUNC(tanh)},
      {"swish", ATTR_FUNC(swish)},
      {"hardswish", ATTR_FUNC(hardswish)},
      {"hardsigmoid", attr_func_hardsigmoid},
      {"leaky_relu", attr_func_leaky_relu},
      {"hardtanh", attr_func_hardtanh},
      {"gelu", attr_func_gelu},
  };
  return fusion_attr_map;
}

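// Maps a unary post-op name to its ideep eltwise algorithm.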
const std::map<c10::string_view, ideep::algorithm>& fusion_unary_alg_map() {
  static const std::map<c10::string_view, ideep::algorithm> fusion_attr_map{
      {"relu", {ideep::algorithm::eltwise_relu}},
  };
  return fusion_attr_map;
}

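// Maps a binary post-op name to its ideep binary algorithm.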
const std::map<c10::string_view, ideep::algorithm>& fusion_binary_alg_map() {
  static const std::map<c10::string_view, ideep::algorithm> fusion_attr_map{
      {"add", {ideep::algorithm::binary_add}},
      {"sub", {ideep::algorithm::binary_sub}},
      {"mul", {ideep::algorithm::binary_mul}},
      {"div", {ideep::algorithm::binary_div}},
  };
  return fusion_attr_map;
}

#endif // AT_MKLDNN_ENABLED()
}} // namespace at::native