1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 /** 10 * @file 11 * Kernel Test utilities. 12 */ 13 14 #pragma once 15 16 #include <executorch/runtime/core/error.h> 17 #include <executorch/runtime/kernel/kernel_includes.h> 18 #include <executorch/runtime/platform/runtime.h> 19 #include <executorch/test/utils/DeathTest.h> 20 #include <gtest/gtest.h> 21 22 #ifdef USE_ATEN_LIB 23 /** 24 * Ensure the kernel will fail when `_statement` is executed. 25 * @param _statement Statement to execute. 26 */ 27 #define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \ 28 EXPECT_ANY_THROW(_statement) 29 30 #define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _matcher) \ 31 EXPECT_ANY_THROW(_statement) 32 33 #define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( \ 34 tf, op, input_contiguous, expected_contiguous, channels_last_support) \ 35 Tensor input_channels_last = tf.channels_last_like(input_contiguous); \ 36 Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \ 37 \ 38 Tensor output_contiguous = tf.zeros_like(expected_contiguous); \ 39 Tensor output_channels_last = tf.channels_last_like(output_contiguous); \ 40 \ 41 Tensor ret = op(input_channels_last, output_channels_last); \ 42 if (channels_last_support) { \ 43 EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last); \ 44 } else { \ 45 EXPECT_TENSOR_NE(output_channels_last, expected_channel_last); \ 46 } \ 47 EXPECT_TENSOR_EQ(output_channels_last, ret); 48 49 #else 50 51 #define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \ 52 do { \ 53 _statement; \ 54 expect_failure(); \ 55 if ((_context).failure_state() == torch::executor::Error::Ok) { \ 56 ET_LOG(Error, "Expected kernel failure but found success."); \ 57 ADD_FAILURE(); \ 58 } \ 59 } while (false) 60 61 #define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _msg) \ 62 do { \ 63 _statement; \ 64 expect_failure(); \ 65 if ((_context).failure_state() == torch::executor::Error::Ok) { \ 66 ET_LOG(Error, "Expected kernel failure but found success."); \ 67 ADD_FAILURE(); \ 68 } \ 69 } while (false) 70 71 #define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( \ 72 tf, op, input_contiguous, expected_contiguous, channels_last_support) \ 73 Tensor input_channels_last = tf.channels_last_like(input_contiguous); \ 74 Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \ 75 \ 76 Tensor output_contiguous = tf.zeros_like(expected_contiguous); \ 77 Tensor output_channels_last = tf.channels_last_like(output_contiguous); \ 78 \ 79 Tensor ret = op(input_channels_last, output_channels_last); \ 80 if (channels_last_support) { \ 81 EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last); \ 82 } else { \ 83 EXPECT_TENSOR_NE(output_channels_last, expected_channel_last); \ 84 } \ 85 EXPECT_TENSOR_EQ(output_channels_last, ret); \ 86 ET_EXPECT_KERNEL_FAILURE( \ 87 context_, op(input_channels_last, output_contiguous)); \ 88 ET_EXPECT_KERNEL_FAILURE( \ 89 context_, op(input_contiguous, output_channels_last)); 90 91 #endif // USE_ATEN_LIB 92 93 /* 94 * Common test fixture for kernel / operator-level tests. Provides 95 * a runtime context object and verifies failure state post-execution. 96 */ 97 class OperatorTest : public ::testing::Test { 98 public: OperatorTest()99 OperatorTest() : expect_failure_(false) {} 100 SetUp()101 void SetUp() override { 102 torch::executor::runtime_init(); 103 } 104 TearDown()105 void TearDown() override { 106 // Validate error state. 107 if (!expect_failure_) { 108 EXPECT_EQ(context_.failure_state(), torch::executor::Error::Ok); 109 } else { 110 EXPECT_NE(context_.failure_state(), torch::executor::Error::Ok); 111 } 112 } 113 expect_failure()114 void expect_failure() { 115 expect_failure_ = true; 116 } 117 118 protected: 119 executorch::runtime::KernelRuntimeContext context_; 120 bool expect_failure_; 121 }; 122