1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h>
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker namespace torch::executor::testing {
test_bool_input()12*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_bool_input() {
13*523fa7a6SAndroid Build Coastguard Worker TensorFactory<exec_aten::ScalarType::Bool> tf_bool;
14*523fa7a6SAndroid Build Coastguard Worker TensorFactory<exec_aten::ScalarType::Float> tf_float;
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Worker const std::vector<int32_t> sizes = {1, 2};
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor a = tf_bool.make(sizes, /*data=*/{false, true});
19*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor out = tf_float.zeros(sizes);
20*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor res = tf_float.make(
21*523fa7a6SAndroid Build Coastguard Worker sizes,
22*523fa7a6SAndroid Build Coastguard Worker /*data=*/{(float)op_reference(false), (float)op_reference(true)});
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Worker EXPECT_TENSOR_CLOSE(op_out(a, out), res);
25*523fa7a6SAndroid Build Coastguard Worker }
26*523fa7a6SAndroid Build Coastguard Worker
test_mismatched_input_shapes_dies()27*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_mismatched_input_shapes_dies() {
28*523fa7a6SAndroid Build Coastguard Worker if (get_supported_features()->is_aten) {
29*523fa7a6SAndroid Build Coastguard Worker GTEST_SKIP() << "ATen kernel can handle mismatched input shapes";
30*523fa7a6SAndroid Build Coastguard Worker }
31*523fa7a6SAndroid Build Coastguard Worker TensorFactory<exec_aten::ScalarType::Float> tf;
32*523fa7a6SAndroid Build Coastguard Worker
33*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor a = tf.ones(/*sizes=*/{4});
34*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor out = tf.ones(/*sizes=*/{2, 2});
35*523fa7a6SAndroid Build Coastguard Worker
36*523fa7a6SAndroid Build Coastguard Worker ET_EXPECT_KERNEL_FAILURE(context_, op_out(a, out));
37*523fa7a6SAndroid Build Coastguard Worker }
38*523fa7a6SAndroid Build Coastguard Worker
39*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_half_output_static_dynamism_support()40*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_half_output_static_dynamism_support() {
41*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
42*523fa7a6SAndroid Build Coastguard Worker test_floating_point_op_out< \
43*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype, \
44*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::Half>();
45*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
46*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
47*523fa7a6SAndroid Build Coastguard Worker }
48*523fa7a6SAndroid Build Coastguard Worker
49*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_bfloat16_output_static_dynamism_support()50*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_bfloat16_output_static_dynamism_support() {
51*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
52*523fa7a6SAndroid Build Coastguard Worker test_floating_point_op_out< \
53*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype, \
54*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::BFloat16>();
55*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
56*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
57*523fa7a6SAndroid Build Coastguard Worker }
58*523fa7a6SAndroid Build Coastguard Worker
59*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_static_dynamism_support()60*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_float_output_static_dynamism_support() {
61*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
62*523fa7a6SAndroid Build Coastguard Worker test_floating_point_op_out< \
63*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype, \
64*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::Float>();
65*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_REALH_TYPES(TEST_ENTRY);
66*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
67*523fa7a6SAndroid Build Coastguard Worker }
68*523fa7a6SAndroid Build Coastguard Worker
69*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_static_dynamism_support()70*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_double_output_static_dynamism_support() {
71*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
72*523fa7a6SAndroid Build Coastguard Worker test_floating_point_op_out< \
73*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype, \
74*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::Double>();
75*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_REALH_TYPES(TEST_ENTRY);
76*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
77*523fa7a6SAndroid Build Coastguard Worker }
78*523fa7a6SAndroid Build Coastguard Worker
79*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_half_output_bound_dynamism_support()80*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_half_output_bound_dynamism_support() {
81*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
82*523fa7a6SAndroid Build Coastguard Worker test_floating_point_op_out< \
83*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype, \
84*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::Half>( \
85*523fa7a6SAndroid Build Coastguard Worker {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
86*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
87*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
88*523fa7a6SAndroid Build Coastguard Worker }
89*523fa7a6SAndroid Build Coastguard Worker
90*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_bfloat16_output_bound_dynamism_support()91*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_bfloat16_output_bound_dynamism_support() {
92*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
93*523fa7a6SAndroid Build Coastguard Worker test_floating_point_op_out< \
94*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype, \
95*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::BFloat16>( \
96*523fa7a6SAndroid Build Coastguard Worker {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
97*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
98*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
99*523fa7a6SAndroid Build Coastguard Worker }
100*523fa7a6SAndroid Build Coastguard Worker
101*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_bound_dynamism_support()102*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_float_output_bound_dynamism_support() {
103*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
104*523fa7a6SAndroid Build Coastguard Worker test_floating_point_op_out< \
105*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype, \
106*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::Float>( \
107*523fa7a6SAndroid Build Coastguard Worker {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
108*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_REALH_TYPES(TEST_ENTRY);
109*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
110*523fa7a6SAndroid Build Coastguard Worker }
111*523fa7a6SAndroid Build Coastguard Worker
112*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_bound_dynamism_support()113*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_double_output_bound_dynamism_support() {
114*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
115*523fa7a6SAndroid Build Coastguard Worker test_floating_point_op_out< \
116*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype, \
117*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::Double>( \
118*523fa7a6SAndroid Build Coastguard Worker {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
119*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_REALH_TYPES(TEST_ENTRY);
120*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
121*523fa7a6SAndroid Build Coastguard Worker }
122*523fa7a6SAndroid Build Coastguard Worker
123*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_float_output_unbound_dynamism_support()124*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_float_output_unbound_dynamism_support() {
125*523fa7a6SAndroid Build Coastguard Worker if (!get_supported_features()->is_aten) {
126*523fa7a6SAndroid Build Coastguard Worker GTEST_SKIP() << "Dynamic shape unbound not supported";
127*523fa7a6SAndroid Build Coastguard Worker }
128*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
129*523fa7a6SAndroid Build Coastguard Worker test_floating_point_op_out< \
130*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype, \
131*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::Float>( \
132*523fa7a6SAndroid Build Coastguard Worker {1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
133*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_REALH_TYPES(TEST_ENTRY);
134*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
135*523fa7a6SAndroid Build Coastguard Worker }
136*523fa7a6SAndroid Build Coastguard Worker
137*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::
test_all_real_input_double_output_unbound_dynamism_support()138*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_double_output_unbound_dynamism_support() {
139*523fa7a6SAndroid Build Coastguard Worker if (!get_supported_features()->is_aten) {
140*523fa7a6SAndroid Build Coastguard Worker GTEST_SKIP() << "Dynamic shape unbound not supported";
141*523fa7a6SAndroid Build Coastguard Worker }
142*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
143*523fa7a6SAndroid Build Coastguard Worker test_floating_point_op_out< \
144*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype, \
145*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::Double>( \
146*523fa7a6SAndroid Build Coastguard Worker {1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
147*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_REALH_TYPES(TEST_ENTRY);
148*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
149*523fa7a6SAndroid Build Coastguard Worker }
150*523fa7a6SAndroid Build Coastguard Worker
test_non_float_output_dtype_dies()151*523fa7a6SAndroid Build Coastguard Worker void UnaryUfuncRealHBBF16ToFloatHBF16Test::test_non_float_output_dtype_dies() {
152*523fa7a6SAndroid Build Coastguard Worker #define TEST_ENTRY(ctype, dtype) \
153*523fa7a6SAndroid Build Coastguard Worker test_op_invalid_output_dtype_dies< \
154*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::Float, \
155*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::dtype>();
156*523fa7a6SAndroid Build Coastguard Worker ET_FORALL_INT_TYPES(TEST_ENTRY);
157*523fa7a6SAndroid Build Coastguard Worker #undef TEST_ENTRY
158*523fa7a6SAndroid Build Coastguard Worker }
159*523fa7a6SAndroid Build Coastguard Worker
160*523fa7a6SAndroid Build Coastguard Worker } // namespace torch::executor::testing
161