// xref: /aosp_15_r20/external/executorch/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.cpp
// (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h>

#include <cstdint>
#include <vector>

11*523fa7a6SAndroid Build Coastguard Worker namespace torch::executor::testing {
test_bool_input()12*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_bool_input() {
13*523fa7a6SAndroid Build Coastguard Worker   TensorFactory<exec_aten::ScalarType::Bool> tf_bool;
14*523fa7a6SAndroid Build Coastguard Worker   TensorFactory<exec_aten::ScalarType::Float> tf_float;
15*523fa7a6SAndroid Build Coastguard Worker 
16*523fa7a6SAndroid Build Coastguard Worker   const std::vector<int32_t> sizes = {1, 2};
17*523fa7a6SAndroid Build Coastguard Worker 
18*523fa7a6SAndroid Build Coastguard Worker   exec_aten::Tensor a = tf_bool.make(sizes, /*data=*/{false, true});
19*523fa7a6SAndroid Build Coastguard Worker   exec_aten::Tensor out = tf_float.zeros(sizes);
20*523fa7a6SAndroid Build Coastguard Worker   exec_aten::Tensor res = tf_float.make(
21*523fa7a6SAndroid Build Coastguard Worker       sizes,
22*523fa7a6SAndroid Build Coastguard Worker       /*data=*/{(float)op_reference(false), (float)op_reference(true)});
23*523fa7a6SAndroid Build Coastguard Worker 
24*523fa7a6SAndroid Build Coastguard Worker   EXPECT_TENSOR_CLOSE(op_out(a, out), res);
25*523fa7a6SAndroid Build Coastguard Worker }
26*523fa7a6SAndroid Build Coastguard Worker 
test_mismatched_input_shapes_dies()27*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_mismatched_input_shapes_dies() {
28*523fa7a6SAndroid Build Coastguard Worker   if (get_supported_features()->is_aten) {
29*523fa7a6SAndroid Build Coastguard Worker     GTEST_SKIP() << "ATen kernel can handle mismatched input shapes";
30*523fa7a6SAndroid Build Coastguard Worker   }
31*523fa7a6SAndroid Build Coastguard Worker   TensorFactory<exec_aten::ScalarType::Float> tf;
32*523fa7a6SAndroid Build Coastguard Worker 
33*523fa7a6SAndroid Build Coastguard Worker   exec_aten::Tensor a = tf.ones(/*sizes=*/{4});
34*523fa7a6SAndroid Build Coastguard Worker   exec_aten::Tensor out = tf.ones(/*sizes=*/{2, 2});
35*523fa7a6SAndroid Build Coastguard Worker 
36*523fa7a6SAndroid Build Coastguard Worker   ET_EXPECT_KERNEL_FAILURE(context_, op_out(a, out));
37*523fa7a6SAndroid Build Coastguard Worker }
38*523fa7a6SAndroid Build Coastguard Worker 
39*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_half_output_static_dynamism_support()40*523fa7a6SAndroid Build Coastguard Worker     test_all_real_input_half_output_static_dynamism_support() {
41*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)    \
42*523fa7a6SAndroid Build Coastguard Worker   test_floating_point_op_out<       \
43*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype, \
44*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::Half>();
45*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
46*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
47*523fa7a6SAndroid Build Coastguard Worker }
48*523fa7a6SAndroid Build Coastguard Worker 
49*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_bfloat16_output_static_dynamism_support()50*523fa7a6SAndroid Build Coastguard Worker     test_all_real_input_bfloat16_output_static_dynamism_support() {
51*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)    \
52*523fa7a6SAndroid Build Coastguard Worker   test_floating_point_op_out<       \
53*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype, \
54*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::BFloat16>();
55*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
56*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
57*523fa7a6SAndroid Build Coastguard Worker }
58*523fa7a6SAndroid Build Coastguard Worker 
59*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_static_dynamism_support()60*523fa7a6SAndroid Build Coastguard Worker     test_all_real_input_float_output_static_dynamism_support() {
61*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)    \
62*523fa7a6SAndroid Build Coastguard Worker   test_floating_point_op_out<       \
63*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype, \
64*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::Float>();
65*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_REALH_TYPES(TEST_ENTRY);
66*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
67*523fa7a6SAndroid Build Coastguard Worker }
68*523fa7a6SAndroid Build Coastguard Worker 
69*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_static_dynamism_support()70*523fa7a6SAndroid Build Coastguard Worker     test_all_real_input_double_output_static_dynamism_support() {
71*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)    \
72*523fa7a6SAndroid Build Coastguard Worker   test_floating_point_op_out<       \
73*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype, \
74*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::Double>();
75*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_REALH_TYPES(TEST_ENTRY);
76*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
77*523fa7a6SAndroid Build Coastguard Worker }
78*523fa7a6SAndroid Build Coastguard Worker 
79*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_half_output_bound_dynamism_support()80*523fa7a6SAndroid Build Coastguard Worker     test_all_real_input_half_output_bound_dynamism_support() {
81*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)    \
82*523fa7a6SAndroid Build Coastguard Worker   test_floating_point_op_out<       \
83*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype, \
84*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::Half>( \
85*523fa7a6SAndroid Build Coastguard Worker       {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
86*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
87*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
88*523fa7a6SAndroid Build Coastguard Worker }
89*523fa7a6SAndroid Build Coastguard Worker 
90*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_bfloat16_output_bound_dynamism_support()91*523fa7a6SAndroid Build Coastguard Worker     test_all_real_input_bfloat16_output_bound_dynamism_support() {
92*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)        \
93*523fa7a6SAndroid Build Coastguard Worker   test_floating_point_op_out<           \
94*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype,     \
95*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::BFloat16>( \
96*523fa7a6SAndroid Build Coastguard Worker       {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
97*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
98*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
99*523fa7a6SAndroid Build Coastguard Worker }
100*523fa7a6SAndroid Build Coastguard Worker 
101*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_bound_dynamism_support()102*523fa7a6SAndroid Build Coastguard Worker     test_all_real_input_float_output_bound_dynamism_support() {
103*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)     \
104*523fa7a6SAndroid Build Coastguard Worker   test_floating_point_op_out<        \
105*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype,  \
106*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::Float>( \
107*523fa7a6SAndroid Build Coastguard Worker       {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
108*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_REALH_TYPES(TEST_ENTRY);
109*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
110*523fa7a6SAndroid Build Coastguard Worker }
111*523fa7a6SAndroid Build Coastguard Worker 
112*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_bound_dynamism_support()113*523fa7a6SAndroid Build Coastguard Worker     test_all_real_input_double_output_bound_dynamism_support() {
114*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)      \
115*523fa7a6SAndroid Build Coastguard Worker   test_floating_point_op_out<         \
116*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype,   \
117*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::Double>( \
118*523fa7a6SAndroid Build Coastguard Worker       {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
119*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_REALH_TYPES(TEST_ENTRY);
120*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
121*523fa7a6SAndroid Build Coastguard Worker }
122*523fa7a6SAndroid Build Coastguard Worker 
123*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_unbound_dynamism_support()124*523fa7a6SAndroid Build Coastguard Worker     test_all_real_input_float_output_unbound_dynamism_support() {
125*523fa7a6SAndroid Build Coastguard Worker   if (!get_supported_features()->is_aten) {
126*523fa7a6SAndroid Build Coastguard Worker     GTEST_SKIP() << "Dynamic shape unbound not supported";
127*523fa7a6SAndroid Build Coastguard Worker   }
128*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)     \
129*523fa7a6SAndroid Build Coastguard Worker   test_floating_point_op_out<        \
130*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype,  \
131*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::Float>( \
132*523fa7a6SAndroid Build Coastguard Worker       {1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
133*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_REALH_TYPES(TEST_ENTRY);
134*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
135*523fa7a6SAndroid Build Coastguard Worker }
136*523fa7a6SAndroid Build Coastguard Worker 
137*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_unbound_dynamism_support()138*523fa7a6SAndroid Build Coastguard Worker     test_all_real_input_double_output_unbound_dynamism_support() {
139*523fa7a6SAndroid Build Coastguard Worker   if (!get_supported_features()->is_aten) {
140*523fa7a6SAndroid Build Coastguard Worker     GTEST_SKIP() << "Dynamic shape unbound not supported";
141*523fa7a6SAndroid Build Coastguard Worker   }
142*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)      \
143*523fa7a6SAndroid Build Coastguard Worker   test_floating_point_op_out<         \
144*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype,   \
145*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::Double>( \
146*523fa7a6SAndroid Build Coastguard Worker       {1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
147*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_REALH_TYPES(TEST_ENTRY);
148*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
149*523fa7a6SAndroid Build Coastguard Worker }
150*523fa7a6SAndroid Build Coastguard Worker 
test_non_float_output_dtype_dies()151*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_non_float_output_dtype_dies() {
152*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype)     \
153*523fa7a6SAndroid Build Coastguard Worker   test_op_invalid_output_dtype_dies< \
154*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::Float,  \
155*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType::dtype>();
156*523fa7a6SAndroid Build Coastguard Worker   ET_FORALL_INT_TYPES(TEST_ENTRY);
157*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
158*523fa7a6SAndroid Build Coastguard Worker }
159*523fa7a6SAndroid Build Coastguard Worker 
160*523fa7a6SAndroid Build Coastguard Worker } // namespace torch::executor::testing