xref: /aosp_15_r20/external/executorch/kernels/test/TestUtil.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 /**
10  * @file
11  * Kernel Test utilities.
12  */
13 
14 #pragma once
15 
16 #include <executorch/runtime/core/error.h>
17 #include <executorch/runtime/kernel/kernel_includes.h>
18 #include <executorch/runtime/platform/runtime.h>
19 #include <executorch/test/utils/DeathTest.h>
20 #include <gtest/gtest.h>
21 
22 #ifdef USE_ATEN_LIB
/**
 * Ensure the kernel will fail when `_statement` is executed.
 *
 * In this (USE_ATEN_LIB) build, failure is detected as `_statement` throwing
 * an exception of any type. `_context` is not used here; it is accepted only
 * so call sites have the same signature as the non-ATen variant, which
 * inspects the context's failure state instead.
 *
 * @param _context Unused in this build.
 * @param _statement Statement to execute.
 */
#define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \
  EXPECT_ANY_THROW(_statement)
29 
/**
 * Same check as ET_EXPECT_KERNEL_FAILURE: passes when `_statement` throws.
 *
 * `_msg` mirrors the three-argument form used by the non-ATen build but is
 * not matched against the thrown exception here.
 */
#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _msg) \
  ET_EXPECT_KERNEL_FAILURE(_context, _statement)
32 
/**
 * Check how `op` handles channels-last tensors (USE_ATEN_LIB build).
 *
 * Makes channels-last copies (via `tf.channels_last_like`) of
 * `input_contiguous`, of `expected_contiguous`, and of a `tf.zeros_like`
 * output, then calls `op(input, out)`. If `channels_last_support` is true the
 * output must equal the expected tensor; otherwise it must differ from it.
 * Either way, the Tensor returned by `op` must equal the output argument.
 *
 * The expansion is a braced block so its temporaries stay local: the macro
 * can be used more than once in a test and cannot collide with names the
 * calling test already declares. Works with or without a trailing `;`.
 *
 * @param tf Tensor factory providing zeros_like() and channels_last_like().
 * @param op Callable invoked as op(input, out) whose result converts to Tensor.
 * @param input_contiguous Input tensor; a channels-last copy is passed to op.
 * @param expected_contiguous Expected result; compared in channels-last form.
 * @param channels_last_support Whether op should produce the expected result
 *     for channels-last tensors.
 */
#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(                                   \
    tf, op, input_contiguous, expected_contiguous, channels_last_support)     \
  {                                                                           \
    Tensor input_channels_last = tf.channels_last_like(input_contiguous);     \
    Tensor expected_channels_last =                                           \
        tf.channels_last_like(expected_contiguous);                           \
                                                                              \
    Tensor output_contiguous = tf.zeros_like(expected_contiguous);            \
    Tensor output_channels_last = tf.channels_last_like(output_contiguous);   \
                                                                              \
    Tensor ret = op(input_channels_last, output_channels_last);               \
    if (channels_last_support) {                                              \
      EXPECT_TENSOR_EQ(output_channels_last, expected_channels_last);         \
    } else {                                                                  \
      EXPECT_TENSOR_NE(output_channels_last, expected_channels_last);         \
    }                                                                         \
    EXPECT_TENSOR_EQ(output_channels_last, ret);                              \
  }
48 
49 #else
50 
/**
 * Ensure the kernel reports a failure through `_context` when `_statement` is
 * executed.
 *
 * The context's failure state is reset to Error::Ok before `_statement` runs,
 * so the check observes only the outcome of `_statement` itself. Without the
 * reset, a failure recorded by an earlier call would make a later, wrongly
 * successful call pass this check.
 *
 * Also calls expect_failure() (OperatorTest::expect_failure), so this macro
 * must be used inside an OperatorTest-based test. OperatorTest::TearDown()
 * then requires a non-Ok failure state.
 *
 * @param _context KernelRuntimeContext the kernel under test reports to.
 * @param _statement Statement to execute.
 */
#define ET_EXPECT_KERNEL_FAILURE(_context, _statement)              \
  do {                                                              \
    (_context).fail(torch::executor::Error::Ok);                    \
    _statement;                                                     \
    expect_failure();                                               \
    if ((_context).failure_state() == torch::executor::Error::Ok) { \
      ET_LOG(Error, "Expected kernel failure but found success.");  \
      ADD_FAILURE();                                                \
    }                                                               \
  } while (false)
60 
/**
 * Same check as ET_EXPECT_KERNEL_FAILURE: `_statement` must leave `_context`
 * in a non-Ok failure state.
 *
 * `_msg` keeps the three-argument signature of the USE_ATEN_LIB variant but
 * is not inspected by this expansion.
 */
#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _msg) \
  ET_EXPECT_KERNEL_FAILURE(_context, _statement)
70 
/**
 * Check how `op` handles channels-last tensors (non-ATen build).
 *
 * Makes channels-last copies (via `tf.channels_last_like`) of
 * `input_contiguous`, of `expected_contiguous`, and of a `tf.zeros_like`
 * output, then calls `op(input, out)`. If `channels_last_support` is true the
 * output must equal the expected tensor; otherwise it must differ from it.
 * Either way, the Tensor returned by `op` must equal the output argument.
 *
 * It then requires that `op` fails for both mixed combinations:
 * (channels-last input, contiguous output) and (contiguous input,
 * channels-last output). `context_.fail(Error::Ok)` is called before each
 * check. Each check therefore sees only its own call, not a failure recorded
 * earlier (for example by the unsupported-layout call above).
 *
 * Uses the fixture member `context_` and, through ET_EXPECT_KERNEL_FAILURE,
 * expect_failure(), so it must be used inside an OperatorTest-based test.
 * The expansion is a braced block so its temporaries stay local: the macro
 * can be used more than once in a test and cannot collide with the caller's
 * names. Works with or without a trailing `;`.
 *
 * @param tf Tensor factory providing zeros_like() and channels_last_like().
 * @param op Callable invoked as op(input, out) whose result converts to Tensor.
 * @param input_contiguous Input tensor; a channels-last copy is passed to op.
 * @param expected_contiguous Expected result; compared in channels-last form.
 * @param channels_last_support Whether op should produce the expected result
 *     for channels-last tensors.
 */
#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(                                   \
    tf, op, input_contiguous, expected_contiguous, channels_last_support)     \
  {                                                                           \
    Tensor input_channels_last = tf.channels_last_like(input_contiguous);     \
    Tensor expected_channels_last =                                           \
        tf.channels_last_like(expected_contiguous);                           \
                                                                              \
    Tensor output_contiguous = tf.zeros_like(expected_contiguous);            \
    Tensor output_channels_last = tf.channels_last_like(output_contiguous);   \
                                                                              \
    Tensor ret = op(input_channels_last, output_channels_last);               \
    if (channels_last_support) {                                              \
      EXPECT_TENSOR_EQ(output_channels_last, expected_channels_last);         \
    } else {                                                                  \
      EXPECT_TENSOR_NE(output_channels_last, expected_channels_last);         \
    }                                                                         \
    EXPECT_TENSOR_EQ(output_channels_last, ret);                              \
                                                                              \
    /* Reset before each check so it cannot pass on a stale failure. */       \
    context_.fail(torch::executor::Error::Ok);                                \
    ET_EXPECT_KERNEL_FAILURE(                                                 \
        context_, op(input_channels_last, output_contiguous));                \
    context_.fail(torch::executor::Error::Ok);                                \
    ET_EXPECT_KERNEL_FAILURE(                                                 \
        context_, op(input_contiguous, output_channels_last));                \
  }
90 
91 #endif // USE_ATEN_LIB
92 
93 /*
94  * Common test fixture for kernel / operator-level tests. Provides
95  * a runtime context object and verifies failure state post-execution.
96  */
97 class OperatorTest : public ::testing::Test {
98  public:
OperatorTest()99   OperatorTest() : expect_failure_(false) {}
100 
SetUp()101   void SetUp() override {
102     torch::executor::runtime_init();
103   }
104 
TearDown()105   void TearDown() override {
106     // Validate error state.
107     if (!expect_failure_) {
108       EXPECT_EQ(context_.failure_state(), torch::executor::Error::Ok);
109     } else {
110       EXPECT_NE(context_.failure_state(), torch::executor::Error::Ok);
111     }
112   }
113 
expect_failure()114   void expect_failure() {
115     expect_failure_ = true;
116   }
117 
118  protected:
119   executorch::runtime::KernelRuntimeContext context_;
120   bool expect_failure_;
121 };
122