xref: /aosp_15_r20/external/executorch/kernels/test/TestUtil.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
/**
 * @file
 * Kernel test utilities: macros for expecting kernel failures and for testing
 * operator memory-format (channels-last) support, plus the OperatorTest
 * fixture that checks the kernel runtime context's failure state.
 */
13*523fa7a6SAndroid Build Coastguard Worker 
14*523fa7a6SAndroid Build Coastguard Worker #pragma once
15*523fa7a6SAndroid Build Coastguard Worker 
16*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/error.h>
17*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/kernel/kernel_includes.h>
18*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/runtime.h>
19*523fa7a6SAndroid Build Coastguard Worker #include <executorch/test/utils/DeathTest.h>
20*523fa7a6SAndroid Build Coastguard Worker #include <gtest/gtest.h>
21*523fa7a6SAndroid Build Coastguard Worker 
22*523fa7a6SAndroid Build Coastguard Worker #ifdef USE_ATEN_LIB
23*523fa7a6SAndroid Build Coastguard Worker /**
24*523fa7a6SAndroid Build Coastguard Worker  * Ensure the kernel will fail when `_statement` is executed.
25*523fa7a6SAndroid Build Coastguard Worker  * @param _statement Statement to execute.
26*523fa7a6SAndroid Build Coastguard Worker  */
27*523fa7a6SAndroid Build Coastguard Worker #define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \
28*523fa7a6SAndroid Build Coastguard Worker   EXPECT_ANY_THROW(_statement)
29*523fa7a6SAndroid Build Coastguard Worker 
30*523fa7a6SAndroid Build Coastguard Worker #define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _matcher) \
31*523fa7a6SAndroid Build Coastguard Worker   EXPECT_ANY_THROW(_statement)
32*523fa7a6SAndroid Build Coastguard Worker 
33*523fa7a6SAndroid Build Coastguard Worker #define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(                                  \
34*523fa7a6SAndroid Build Coastguard Worker     tf, op, input_contiguous, expected_contiguous, channels_last_support)    \
35*523fa7a6SAndroid Build Coastguard Worker   Tensor input_channels_last = tf.channels_last_like(input_contiguous);      \
36*523fa7a6SAndroid Build Coastguard Worker   Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \
37*523fa7a6SAndroid Build Coastguard Worker                                                                              \
38*523fa7a6SAndroid Build Coastguard Worker   Tensor output_contiguous = tf.zeros_like(expected_contiguous);             \
39*523fa7a6SAndroid Build Coastguard Worker   Tensor output_channels_last = tf.channels_last_like(output_contiguous);    \
40*523fa7a6SAndroid Build Coastguard Worker                                                                              \
41*523fa7a6SAndroid Build Coastguard Worker   Tensor ret = op(input_channels_last, output_channels_last);                \
42*523fa7a6SAndroid Build Coastguard Worker   if (channels_last_support) {                                               \
43*523fa7a6SAndroid Build Coastguard Worker     EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last);           \
44*523fa7a6SAndroid Build Coastguard Worker   } else {                                                                   \
45*523fa7a6SAndroid Build Coastguard Worker     EXPECT_TENSOR_NE(output_channels_last, expected_channel_last);           \
46*523fa7a6SAndroid Build Coastguard Worker   }                                                                          \
47*523fa7a6SAndroid Build Coastguard Worker   EXPECT_TENSOR_EQ(output_channels_last, ret);
48*523fa7a6SAndroid Build Coastguard Worker 
49*523fa7a6SAndroid Build Coastguard Worker #else
50*523fa7a6SAndroid Build Coastguard Worker 
51*523fa7a6SAndroid Build Coastguard Worker #define ET_EXPECT_KERNEL_FAILURE(_context, _statement)              \
52*523fa7a6SAndroid Build Coastguard Worker   do {                                                              \
53*523fa7a6SAndroid Build Coastguard Worker     _statement;                                                     \
54*523fa7a6SAndroid Build Coastguard Worker     expect_failure();                                               \
55*523fa7a6SAndroid Build Coastguard Worker     if ((_context).failure_state() == torch::executor::Error::Ok) { \
56*523fa7a6SAndroid Build Coastguard Worker       ET_LOG(Error, "Expected kernel failure but found success.");  \
57*523fa7a6SAndroid Build Coastguard Worker       ADD_FAILURE();                                                \
58*523fa7a6SAndroid Build Coastguard Worker     }                                                               \
59*523fa7a6SAndroid Build Coastguard Worker   } while (false)
60*523fa7a6SAndroid Build Coastguard Worker 
61*523fa7a6SAndroid Build Coastguard Worker #define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _msg) \
62*523fa7a6SAndroid Build Coastguard Worker   do {                                                                \
63*523fa7a6SAndroid Build Coastguard Worker     _statement;                                                       \
64*523fa7a6SAndroid Build Coastguard Worker     expect_failure();                                                 \
65*523fa7a6SAndroid Build Coastguard Worker     if ((_context).failure_state() == torch::executor::Error::Ok) {   \
66*523fa7a6SAndroid Build Coastguard Worker       ET_LOG(Error, "Expected kernel failure but found success.");    \
67*523fa7a6SAndroid Build Coastguard Worker       ADD_FAILURE();                                                  \
68*523fa7a6SAndroid Build Coastguard Worker     }                                                                 \
69*523fa7a6SAndroid Build Coastguard Worker   } while (false)
70*523fa7a6SAndroid Build Coastguard Worker 
71*523fa7a6SAndroid Build Coastguard Worker #define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(                                  \
72*523fa7a6SAndroid Build Coastguard Worker     tf, op, input_contiguous, expected_contiguous, channels_last_support)    \
73*523fa7a6SAndroid Build Coastguard Worker   Tensor input_channels_last = tf.channels_last_like(input_contiguous);      \
74*523fa7a6SAndroid Build Coastguard Worker   Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \
75*523fa7a6SAndroid Build Coastguard Worker                                                                              \
76*523fa7a6SAndroid Build Coastguard Worker   Tensor output_contiguous = tf.zeros_like(expected_contiguous);             \
77*523fa7a6SAndroid Build Coastguard Worker   Tensor output_channels_last = tf.channels_last_like(output_contiguous);    \
78*523fa7a6SAndroid Build Coastguard Worker                                                                              \
79*523fa7a6SAndroid Build Coastguard Worker   Tensor ret = op(input_channels_last, output_channels_last);                \
80*523fa7a6SAndroid Build Coastguard Worker   if (channels_last_support) {                                               \
81*523fa7a6SAndroid Build Coastguard Worker     EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last);           \
82*523fa7a6SAndroid Build Coastguard Worker   } else {                                                                   \
83*523fa7a6SAndroid Build Coastguard Worker     EXPECT_TENSOR_NE(output_channels_last, expected_channel_last);           \
84*523fa7a6SAndroid Build Coastguard Worker   }                                                                          \
85*523fa7a6SAndroid Build Coastguard Worker   EXPECT_TENSOR_EQ(output_channels_last, ret);                               \
86*523fa7a6SAndroid Build Coastguard Worker   ET_EXPECT_KERNEL_FAILURE(                                                  \
87*523fa7a6SAndroid Build Coastguard Worker       context_, op(input_channels_last, output_contiguous));                 \
88*523fa7a6SAndroid Build Coastguard Worker   ET_EXPECT_KERNEL_FAILURE(                                                  \
89*523fa7a6SAndroid Build Coastguard Worker       context_, op(input_contiguous, output_channels_last));
90*523fa7a6SAndroid Build Coastguard Worker 
91*523fa7a6SAndroid Build Coastguard Worker #endif // USE_ATEN_LIB
92*523fa7a6SAndroid Build Coastguard Worker 
93*523fa7a6SAndroid Build Coastguard Worker /*
94*523fa7a6SAndroid Build Coastguard Worker  * Common test fixture for kernel / operator-level tests. Provides
95*523fa7a6SAndroid Build Coastguard Worker  * a runtime context object and verifies failure state post-execution.
96*523fa7a6SAndroid Build Coastguard Worker  */
97*523fa7a6SAndroid Build Coastguard Worker class OperatorTest : public ::testing::Test {
98*523fa7a6SAndroid Build Coastguard Worker  public:
OperatorTest()99*523fa7a6SAndroid Build Coastguard Worker   OperatorTest() : expect_failure_(false) {}
100*523fa7a6SAndroid Build Coastguard Worker 
SetUp()101*523fa7a6SAndroid Build Coastguard Worker   void SetUp() override {
102*523fa7a6SAndroid Build Coastguard Worker     torch::executor::runtime_init();
103*523fa7a6SAndroid Build Coastguard Worker   }
104*523fa7a6SAndroid Build Coastguard Worker 
TearDown()105*523fa7a6SAndroid Build Coastguard Worker   void TearDown() override {
106*523fa7a6SAndroid Build Coastguard Worker     // Validate error state.
107*523fa7a6SAndroid Build Coastguard Worker     if (!expect_failure_) {
108*523fa7a6SAndroid Build Coastguard Worker       EXPECT_EQ(context_.failure_state(), torch::executor::Error::Ok);
109*523fa7a6SAndroid Build Coastguard Worker     } else {
110*523fa7a6SAndroid Build Coastguard Worker       EXPECT_NE(context_.failure_state(), torch::executor::Error::Ok);
111*523fa7a6SAndroid Build Coastguard Worker     }
112*523fa7a6SAndroid Build Coastguard Worker   }
113*523fa7a6SAndroid Build Coastguard Worker 
expect_failure()114*523fa7a6SAndroid Build Coastguard Worker   void expect_failure() {
115*523fa7a6SAndroid Build Coastguard Worker     expect_failure_ = true;
116*523fa7a6SAndroid Build Coastguard Worker   }
117*523fa7a6SAndroid Build Coastguard Worker 
118*523fa7a6SAndroid Build Coastguard Worker  protected:
119*523fa7a6SAndroid Build Coastguard Worker   executorch::runtime::KernelRuntimeContext context_;
120*523fa7a6SAndroid Build Coastguard Worker   bool expect_failure_;
121*523fa7a6SAndroid Build Coastguard Worker };
122