#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/Pool.h>
#include <c10/util/irange.h>

namespace at { namespace native {

std::vector<int64_t> pool_output_sizes(
    IntArrayRef input_size,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding_l,
    IntArrayRef padding_r,
    IntArrayRef dilation,
    bool ceil_mode) {
  std::vector<int64_t> output_size(input_size.size());
  // copy N and C
  output_size[0] = input_size[0];
  output_size[1] = input_size[1];

  for (const auto i : c10::irange(2, input_size.size())) {
    output_size[i] = pooling_output_shape_pad_lr<int64_t>(
        input_size[i],
        kernel_size[i - 2],
        padding_l[i - 2],
        padding_r[i - 2],
        stride[i - 2],
        dilation[i - 2],
        ceil_mode
    );
  }

  return output_size;
}
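
// Illustrative use (a sketch, not called from this file; the shapes are
// hypothetical): for an NCHW input of size {1, 3, 32, 32} pooled with a 2x2
// kernel, stride 2, zero padding, dilation 1 and ceil_mode=false, each spatial
// dim follows (32 + 0 + 0 - (2 - 1) * 1 - 1) / 2 + 1 = 16, so the result is
// {1, 3, 16, 16}:
//
//   std::vector<int64_t> out = pool_output_sizes(
//       /*input_size=*/{1, 3, 32, 32},
//       /*kernel_size=*/{2, 2},
//       /*stride=*/{2, 2},
//       /*padding_l=*/{0, 0},
//       /*padding_r=*/{0, 0},
//       /*dilation=*/{1, 1},
//       /*ceil_mode=*/false);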

void check_mkldnn_binary_fusion_inputs(
    const Tensor& input,
    const Tensor& other,
    const Tensor& weight,
    const Tensor& bias) {
  if (!weight.is_mkldnn()) {
    TORCH_CHECK(
        input.options().type_equal(weight.options()),
        "Input type (",
        input.toString(),
        ") and weight type (",
        weight.toString(),
        ") should be the same");
  } else {
    TORCH_CHECK(
        input.scalar_type() == weight.scalar_type(),
52 "mkldnn pointwise binary: input dtype and weight dtype should be the same");
53 }
54 TORCH_CHECK(
55 input.options().type_equal(other.options()),
56 "Input type (",
57 input.toString(),
58 ") and other type (",
59 other.toString(),
60 ") should be the same");
61 TORCH_CHECK(
62 !bias.defined() || (input.options().type_equal(bias.options())),
63 "Input type (",
64 input.toString(),
65 ") and bias type (",
66 bias.toString(),
67 ") should be the same");
68 TORCH_CHECK(
69 input.device().is_cpu(),
70 "mkldnn pointwise binary fusion: input's device should be CPU");
71 TORCH_CHECK(
72 input.scalar_type() == ScalarType::Float ||
73 input.scalar_type() == ScalarType::BFloat16 ||
74 input.scalar_type() == ScalarType::Half,
75 "mkldnn pointwise binary: input's dtype should be float, bfloat16 or half");
76 mkldnn_check_low_precision(input.scalar_type(), "mkldnn pointwise binary");
77 }
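
// A minimal sketch of a call that satisfies these checks (hypothetical
// tensors, not part of this file): all dense arguments are CPU float tensors
// with matching dtypes, and bias may be left undefined.
//
//   at::Tensor input  = at::randn({1, 3, 8, 8});
//   at::Tensor other  = at::randn({1, 4, 8, 8});
//   at::Tensor weight = at::randn({4, 3, 3, 3});
//   at::Tensor bias;  // an undefined bias tensor is accepted
//   check_mkldnn_binary_fusion_inputs(input, other, weight, bias);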

#if AT_MKLDNN_ENABLED()

#define ATTR_FUNC(NAME)                               \
  [](torch::List<std::optional<at::Scalar>> scalars,  \
     std::optional<c10::string_view> algorithm) {     \
    return ideep::attr_t::fuse_##NAME();              \
  }
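
// ATTR_FUNC(relu), for example, expands to a lambda equivalent to
//
//   [](torch::List<std::optional<at::Scalar>> scalars,
//      std::optional<c10::string_view> algorithm) {
//     return ideep::attr_t::fuse_relu();
//   }
//
// i.e. the scalar/algorithm arguments are accepted (to match the AttrFunction
// signature) but ignored for activations that take no extra parameters.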

AttrFunction attr_func_leaky_relu =
    [](torch::List<std::optional<at::Scalar>> scalars,
       std::optional<c10::string_view> algorithm) {
      TORCH_CHECK(
          scalars.size() == 1 &&
              scalars[0].get().toOptional<at::Scalar>().has_value(),
          "leaky_relu is expected to have one scalar input: negative_slope");
      auto alpha_value =
          scalars[0].get().toOptional<at::Scalar>().value().to<float>();
      return ideep::attr_t::fuse_relu(1.0, alpha_value);
    };

AttrFunction attr_func_hardtanh =
    [](torch::List<std::optional<at::Scalar>> scalars,
       std::optional<c10::string_view> algorithm) {
      TORCH_CHECK(
          scalars.size() == 2 &&
              scalars[0].get().toOptional<at::Scalar>().has_value() &&
              scalars[1].get().toOptional<at::Scalar>().has_value(),
106 "hardtanh is expected to have two scalar input: min_val and max_val");

      auto lower_bound_value =
          scalars[0].get().toOptional<at::Scalar>().value().to<float>();
      auto upper_bound_value =
          scalars[1].get().toOptional<at::Scalar>().value().to<float>();
      return ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value);
    };

AttrFunction attr_func_gelu = [](torch::List<std::optional<at::Scalar>> scalars,
                                 std::optional<c10::string_view> algorithm) {
  TORCH_CHECK(
      algorithm.has_value(),
      "gelu is expected to have one str input: algorithm");
  dnnl::algorithm gelu_type;
  if (algorithm.value() == "none") {
    gelu_type = dnnl::algorithm::eltwise_gelu_erf;
  } else if (algorithm.value() == "tanh") {
    gelu_type = dnnl::algorithm::eltwise_gelu_tanh;
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Unsupported gelu algorithm: ", algorithm.value());
  }

  return ideep::attr_t::fuse_gelu(1.0, 0.f, 0.f, gelu_type);
};

AttrFunction attr_func_hardsigmoid =
    [](torch::List<std::optional<at::Scalar>> scalars,
       std::optional<c10::string_view> algorithm) {
      ideep::attr_t attr;
      ideep::post_ops po;
      po.append_eltwise(
          ideep::algorithm::eltwise_hardsigmoid, 1.0f / 6.0f, 0.5f);
      attr.set_post_ops(po);
      return attr;
    };

const std::map<c10::string_view, AttrFunction>& fusion_unary_attr_map() {
  static const std::map<c10::string_view, AttrFunction> fusion_attr_map{
      {"relu", ATTR_FUNC(relu)},
      {"sigmoid", ATTR_FUNC(sigmoid)},
      {"tanh", ATTR_FUNC(tanh)},
      {"swish", ATTR_FUNC(swish)},
      {"hardswish", ATTR_FUNC(hardswish)},
      {"hardsigmoid", attr_func_hardsigmoid},
      {"leaky_relu", attr_func_leaky_relu},
      {"hardtanh", attr_func_hardtanh},
      {"gelu", attr_func_gelu},
  };
  return fusion_attr_map;
}
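
// Illustrative lookup (a sketch; real callers are the mkldnn fused ops):
// build the post-op attribute for leaky_relu with a negative_slope of 0.01.
//
//   torch::List<std::optional<at::Scalar>> scalars;
//   scalars.push_back(at::Scalar(0.01));  // negative_slope
//   ideep::attr_t attr = fusion_unary_attr_map().at("leaky_relu")(
//       scalars, /*algorithm=*/std::nullopt);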

const std::map<c10::string_view, ideep::algorithm>& fusion_unary_alg_map() {
  static const std::map<c10::string_view, ideep::algorithm> fusion_attr_map{
      {"relu", {ideep::algorithm::eltwise_relu}},
  };
  return fusion_attr_map;
}

const std::map<c10::string_view, ideep::algorithm>& fusion_binary_alg_map() {
  static const std::map<c10::string_view, ideep::algorithm> fusion_attr_map{
      {"add", {ideep::algorithm::binary_add}},
      {"sub", {ideep::algorithm::binary_sub}},
      {"mul", {ideep::algorithm::binary_mul}},
      {"div", {ideep::algorithm::binary_div}},
  };
  return fusion_attr_map;
}
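
// For example (a sketch of the mapping, not a call made here):
//
//   ideep::algorithm alg = fusion_binary_alg_map().at("add");
//   TORCH_INTERNAL_ASSERT(alg == ideep::algorithm::binary_add);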

#endif // AT_MKLDNN_ENABLED()
}} // namespace at::native