1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h>
10
11 namespace torch::executor::testing {
test_bool_input()12 void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_bool_input() {
13 TensorFactory<exec_aten::ScalarType::Bool> tf_bool;
14 TensorFactory<exec_aten::ScalarType::Float> tf_float;
15
16 const std::vector<int32_t> sizes = {1, 2};
17
18 exec_aten::Tensor a = tf_bool.make(sizes, /*data=*/{false, true});
19 exec_aten::Tensor out = tf_float.zeros(sizes);
20 exec_aten::Tensor res = tf_float.make(
21 sizes,
22 /*data=*/{(float)op_reference(false), (float)op_reference(true)});
23
24 EXPECT_TENSOR_CLOSE(op_out(a, out), res);
25 }
26
test_mismatched_input_shapes_dies()27 void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_mismatched_input_shapes_dies() {
28 if (get_supported_features()->is_aten) {
29 GTEST_SKIP() << "ATen kernel can handle mismatched input shapes";
30 }
31 TensorFactory<exec_aten::ScalarType::Float> tf;
32
33 exec_aten::Tensor a = tf.ones(/*sizes=*/{4});
34 exec_aten::Tensor out = tf.ones(/*sizes=*/{2, 2});
35
36 ET_EXPECT_KERNEL_FAILURE(context_, op_out(a, out));
37 }
38
39 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_half_output_static_dynamism_support()40 test_all_real_input_half_output_static_dynamism_support() {
41 #define TEST_ENTRY(ctype, dtype) \
42 test_floating_point_op_out< \
43 exec_aten::ScalarType::dtype, \
44 exec_aten::ScalarType::Half>();
45 ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
46 #undef TEST_ENTRY
47 }
48
49 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_bfloat16_output_static_dynamism_support()50 test_all_real_input_bfloat16_output_static_dynamism_support() {
51 #define TEST_ENTRY(ctype, dtype) \
52 test_floating_point_op_out< \
53 exec_aten::ScalarType::dtype, \
54 exec_aten::ScalarType::BFloat16>();
55 ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
56 #undef TEST_ENTRY
57 }
58
59 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_static_dynamism_support()60 test_all_real_input_float_output_static_dynamism_support() {
61 #define TEST_ENTRY(ctype, dtype) \
62 test_floating_point_op_out< \
63 exec_aten::ScalarType::dtype, \
64 exec_aten::ScalarType::Float>();
65 ET_FORALL_REALH_TYPES(TEST_ENTRY);
66 #undef TEST_ENTRY
67 }
68
69 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_static_dynamism_support()70 test_all_real_input_double_output_static_dynamism_support() {
71 #define TEST_ENTRY(ctype, dtype) \
72 test_floating_point_op_out< \
73 exec_aten::ScalarType::dtype, \
74 exec_aten::ScalarType::Double>();
75 ET_FORALL_REALH_TYPES(TEST_ENTRY);
76 #undef TEST_ENTRY
77 }
78
79 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_half_output_bound_dynamism_support()80 test_all_real_input_half_output_bound_dynamism_support() {
81 #define TEST_ENTRY(ctype, dtype) \
82 test_floating_point_op_out< \
83 exec_aten::ScalarType::dtype, \
84 exec_aten::ScalarType::Half>( \
85 {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
86 ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
87 #undef TEST_ENTRY
88 }
89
90 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_bfloat16_output_bound_dynamism_support()91 test_all_real_input_bfloat16_output_bound_dynamism_support() {
92 #define TEST_ENTRY(ctype, dtype) \
93 test_floating_point_op_out< \
94 exec_aten::ScalarType::dtype, \
95 exec_aten::ScalarType::BFloat16>( \
96 {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
97 ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
98 #undef TEST_ENTRY
99 }
100
101 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_bound_dynamism_support()102 test_all_real_input_float_output_bound_dynamism_support() {
103 #define TEST_ENTRY(ctype, dtype) \
104 test_floating_point_op_out< \
105 exec_aten::ScalarType::dtype, \
106 exec_aten::ScalarType::Float>( \
107 {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
108 ET_FORALL_REALH_TYPES(TEST_ENTRY);
109 #undef TEST_ENTRY
110 }
111
112 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_bound_dynamism_support()113 test_all_real_input_double_output_bound_dynamism_support() {
114 #define TEST_ENTRY(ctype, dtype) \
115 test_floating_point_op_out< \
116 exec_aten::ScalarType::dtype, \
117 exec_aten::ScalarType::Double>( \
118 {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
119 ET_FORALL_REALH_TYPES(TEST_ENTRY);
120 #undef TEST_ENTRY
121 }
122
123 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_unbound_dynamism_support()124 test_all_real_input_float_output_unbound_dynamism_support() {
125 if (!get_supported_features()->is_aten) {
126 GTEST_SKIP() << "Dynamic shape unbound not supported";
127 }
128 #define TEST_ENTRY(ctype, dtype) \
129 test_floating_point_op_out< \
130 exec_aten::ScalarType::dtype, \
131 exec_aten::ScalarType::Float>( \
132 {1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
133 ET_FORALL_REALH_TYPES(TEST_ENTRY);
134 #undef TEST_ENTRY
135 }
136
137 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_unbound_dynamism_support()138 test_all_real_input_double_output_unbound_dynamism_support() {
139 if (!get_supported_features()->is_aten) {
140 GTEST_SKIP() << "Dynamic shape unbound not supported";
141 }
142 #define TEST_ENTRY(ctype, dtype) \
143 test_floating_point_op_out< \
144 exec_aten::ScalarType::dtype, \
145 exec_aten::ScalarType::Double>( \
146 {1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
147 ET_FORALL_REALH_TYPES(TEST_ENTRY);
148 #undef TEST_ENTRY
149 }
150
test_non_float_output_dtype_dies()151 void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_non_float_output_dtype_dies() {
152 #define TEST_ENTRY(ctype, dtype) \
153 test_op_invalid_output_dtype_dies< \
154 exec_aten::ScalarType::Float, \
155 exec_aten::ScalarType::dtype>();
156 ET_FORALL_INT_TYPES(TEST_ENTRY);
157 #undef TEST_ENTRY
158 }
159
160 } // namespace torch::executor::testing
161