xref: /aosp_15_r20/external/executorch/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h>
10 
11 namespace torch::executor::testing {
test_bool_input()12 void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_bool_input() {
13   TensorFactory<exec_aten::ScalarType::Bool> tf_bool;
14   TensorFactory<exec_aten::ScalarType::Float> tf_float;
15 
16   const std::vector<int32_t> sizes = {1, 2};
17 
18   exec_aten::Tensor a = tf_bool.make(sizes, /*data=*/{false, true});
19   exec_aten::Tensor out = tf_float.zeros(sizes);
20   exec_aten::Tensor res = tf_float.make(
21       sizes,
22       /*data=*/{(float)op_reference(false), (float)op_reference(true)});
23 
24   EXPECT_TENSOR_CLOSE(op_out(a, out), res);
25 }
26 
test_mismatched_input_shapes_dies()27 void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_mismatched_input_shapes_dies() {
28   if (get_supported_features()->is_aten) {
29     GTEST_SKIP() << "ATen kernel can handle mismatched input shapes";
30   }
31   TensorFactory<exec_aten::ScalarType::Float> tf;
32 
33   exec_aten::Tensor a = tf.ones(/*sizes=*/{4});
34   exec_aten::Tensor out = tf.ones(/*sizes=*/{2, 2});
35 
36   ET_EXPECT_KERNEL_FAILURE(context_, op_out(a, out));
37 }
38 
39 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_half_output_static_dynamism_support()40     test_all_real_input_half_output_static_dynamism_support() {
41 #define TEST_ENTRY(ctype, dtype)    \
42   test_floating_point_op_out<       \
43       exec_aten::ScalarType::dtype, \
44       exec_aten::ScalarType::Half>();
45   ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
46 #undef TEST_ENTRY
47 }
48 
49 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_bfloat16_output_static_dynamism_support()50     test_all_real_input_bfloat16_output_static_dynamism_support() {
51 #define TEST_ENTRY(ctype, dtype)    \
52   test_floating_point_op_out<       \
53       exec_aten::ScalarType::dtype, \
54       exec_aten::ScalarType::BFloat16>();
55   ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
56 #undef TEST_ENTRY
57 }
58 
59 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_static_dynamism_support()60     test_all_real_input_float_output_static_dynamism_support() {
61 #define TEST_ENTRY(ctype, dtype)    \
62   test_floating_point_op_out<       \
63       exec_aten::ScalarType::dtype, \
64       exec_aten::ScalarType::Float>();
65   ET_FORALL_REALH_TYPES(TEST_ENTRY);
66 #undef TEST_ENTRY
67 }
68 
69 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_static_dynamism_support()70     test_all_real_input_double_output_static_dynamism_support() {
71 #define TEST_ENTRY(ctype, dtype)    \
72   test_floating_point_op_out<       \
73       exec_aten::ScalarType::dtype, \
74       exec_aten::ScalarType::Double>();
75   ET_FORALL_REALH_TYPES(TEST_ENTRY);
76 #undef TEST_ENTRY
77 }
78 
79 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_half_output_bound_dynamism_support()80     test_all_real_input_half_output_bound_dynamism_support() {
81 #define TEST_ENTRY(ctype, dtype)    \
82   test_floating_point_op_out<       \
83       exec_aten::ScalarType::dtype, \
84       exec_aten::ScalarType::Half>( \
85       {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
86   ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
87 #undef TEST_ENTRY
88 }
89 
90 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_bfloat16_output_bound_dynamism_support()91     test_all_real_input_bfloat16_output_bound_dynamism_support() {
92 #define TEST_ENTRY(ctype, dtype)        \
93   test_floating_point_op_out<           \
94       exec_aten::ScalarType::dtype,     \
95       exec_aten::ScalarType::BFloat16>( \
96       {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
97   ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
98 #undef TEST_ENTRY
99 }
100 
101 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_bound_dynamism_support()102     test_all_real_input_float_output_bound_dynamism_support() {
103 #define TEST_ENTRY(ctype, dtype)     \
104   test_floating_point_op_out<        \
105       exec_aten::ScalarType::dtype,  \
106       exec_aten::ScalarType::Float>( \
107       {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
108   ET_FORALL_REALH_TYPES(TEST_ENTRY);
109 #undef TEST_ENTRY
110 }
111 
112 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_bound_dynamism_support()113     test_all_real_input_double_output_bound_dynamism_support() {
114 #define TEST_ENTRY(ctype, dtype)      \
115   test_floating_point_op_out<         \
116       exec_aten::ScalarType::dtype,   \
117       exec_aten::ScalarType::Double>( \
118       {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
119   ET_FORALL_REALH_TYPES(TEST_ENTRY);
120 #undef TEST_ENTRY
121 }
122 
123 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_unbound_dynamism_support()124     test_all_real_input_float_output_unbound_dynamism_support() {
125   if (!get_supported_features()->is_aten) {
126     GTEST_SKIP() << "Dynamic shape unbound not supported";
127   }
128 #define TEST_ENTRY(ctype, dtype)     \
129   test_floating_point_op_out<        \
130       exec_aten::ScalarType::dtype,  \
131       exec_aten::ScalarType::Float>( \
132       {1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
133   ET_FORALL_REALH_TYPES(TEST_ENTRY);
134 #undef TEST_ENTRY
135 }
136 
137 void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_unbound_dynamism_support()138     test_all_real_input_double_output_unbound_dynamism_support() {
139   if (!get_supported_features()->is_aten) {
140     GTEST_SKIP() << "Dynamic shape unbound not supported";
141   }
142 #define TEST_ENTRY(ctype, dtype)      \
143   test_floating_point_op_out<         \
144       exec_aten::ScalarType::dtype,   \
145       exec_aten::ScalarType::Double>( \
146       {1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
147   ET_FORALL_REALH_TYPES(TEST_ENTRY);
148 #undef TEST_ENTRY
149 }
150 
test_non_float_output_dtype_dies()151 void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_non_float_output_dtype_dies() {
152 #define TEST_ENTRY(ctype, dtype)     \
153   test_op_invalid_output_dtype_dies< \
154       exec_aten::ScalarType::Float,  \
155       exec_aten::ScalarType::dtype>();
156   ET_FORALL_INT_TYPES(TEST_ENTRY);
157 #undef TEST_ENTRY
158 }
159 
160 } // namespace torch::executor::testing
161