/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/**
 * @file
 * Kernel Test utilities.
 */

#pragma once

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>
#include <gtest/gtest.h>

#ifdef USE_ATEN_LIB
/**
 * Ensure the kernel will fail when `_statement` is executed.
 *
 * In this (USE_ATEN_LIB) configuration the macro expands to
 * `EXPECT_ANY_THROW(_statement)`, i.e. "failure" means that the statement
 * throws.
 *
 * @param _context Not used by this variant; accepted so that call sites have
 *     the same shape in both build configurations.
 * @param _statement Statement to execute.
 */
#define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \
  EXPECT_ANY_THROW(_statement)

/**
 * Same expansion as ET_EXPECT_KERNEL_FAILURE in this configuration.
 *
 * @param _context Not used by this variant.
 * @param _statement Statement to execute.
 * @param _matcher Not used by this variant; no message matching is performed.
 */
#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _matcher) \
  EXPECT_ANY_THROW(_statement)

/**
 * Runs `op` on channels-last copies of the given tensors and checks the
 * result.
 *
 * The expansion is a sequence of statements that is NOT wrapped in its own
 * block: it declares `input_channels_last`, `expected_channel_last`,
 * `output_contiguous`, `output_channels_last` and `ret` directly in the
 * enclosing scope, so it can appear at most once per scope. `Tensor`,
 * `EXPECT_TENSOR_EQ` and `EXPECT_TENSOR_NE` must be available at the
 * expansion site. The expansion already ends with a semicolon.
 *
 * @param tf Object providing `channels_last_like()` and `zeros_like()`.
 * @param op Callable invoked as `op(input, out)`; its result is stored in a
 *     `Tensor` and expected to compare equal to `out`.
 * @param input_contiguous Input tensor from which the channels-last input is
 *     derived.
 * @param expected_contiguous Expected output from which the channels-last
 *     expectation is derived.
 * @param channels_last_support When true the output must equal the
 *     expectation; otherwise it must differ from it.
 */
#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(                                   \
    tf, op, input_contiguous, expected_contiguous, channels_last_support)     \
  Tensor input_channels_last = tf.channels_last_like(input_contiguous);       \
  Tensor expected_channel_last = tf.channels_last_like(expected_contiguous);  \
                                                                              \
  Tensor output_contiguous = tf.zeros_like(expected_contiguous);              \
  Tensor output_channels_last = tf.channels_last_like(output_contiguous);     \
                                                                              \
  Tensor ret = op(input_channels_last, output_channels_last);                 \
  if (channels_last_support) {                                                \
    EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last);            \
  } else {                                                                    \
    EXPECT_TENSOR_NE(output_channels_last, expected_channel_last);            \
  }                                                                           \
  EXPECT_TENSOR_EQ(output_channels_last, ret);

#else

51*523fa7a6SAndroid Build Coastguard Worker #define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \ 52*523fa7a6SAndroid Build Coastguard Worker do { \ 53*523fa7a6SAndroid Build Coastguard Worker _statement; \ 54*523fa7a6SAndroid Build Coastguard Worker expect_failure(); \ 55*523fa7a6SAndroid Build Coastguard Worker if ((_context).failure_state() == torch::executor::Error::Ok) { \ 56*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "Expected kernel failure but found success."); \ 57*523fa7a6SAndroid Build Coastguard Worker ADD_FAILURE(); \ 58*523fa7a6SAndroid Build Coastguard Worker } \ 59*523fa7a6SAndroid Build Coastguard Worker } while (false) 60*523fa7a6SAndroid Build Coastguard Worker 61*523fa7a6SAndroid Build Coastguard Worker #define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _msg) \ 62*523fa7a6SAndroid Build Coastguard Worker do { \ 63*523fa7a6SAndroid Build Coastguard Worker _statement; \ 64*523fa7a6SAndroid Build Coastguard Worker expect_failure(); \ 65*523fa7a6SAndroid Build Coastguard Worker if ((_context).failure_state() == torch::executor::Error::Ok) { \ 66*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "Expected kernel failure but found success."); \ 67*523fa7a6SAndroid Build Coastguard Worker ADD_FAILURE(); \ 68*523fa7a6SAndroid Build Coastguard Worker } \ 69*523fa7a6SAndroid Build Coastguard Worker } while (false) 70*523fa7a6SAndroid Build Coastguard Worker 71*523fa7a6SAndroid Build Coastguard Worker #define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( \ 72*523fa7a6SAndroid Build Coastguard Worker tf, op, input_contiguous, expected_contiguous, channels_last_support) \ 73*523fa7a6SAndroid Build Coastguard Worker Tensor input_channels_last = tf.channels_last_like(input_contiguous); \ 74*523fa7a6SAndroid Build Coastguard Worker Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \ 75*523fa7a6SAndroid Build Coastguard Worker \ 76*523fa7a6SAndroid Build Coastguard Worker Tensor output_contiguous = 
tf.zeros_like(expected_contiguous); \ 77*523fa7a6SAndroid Build Coastguard Worker Tensor output_channels_last = tf.channels_last_like(output_contiguous); \ 78*523fa7a6SAndroid Build Coastguard Worker \ 79*523fa7a6SAndroid Build Coastguard Worker Tensor ret = op(input_channels_last, output_channels_last); \ 80*523fa7a6SAndroid Build Coastguard Worker if (channels_last_support) { \ 81*523fa7a6SAndroid Build Coastguard Worker EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last); \ 82*523fa7a6SAndroid Build Coastguard Worker } else { \ 83*523fa7a6SAndroid Build Coastguard Worker EXPECT_TENSOR_NE(output_channels_last, expected_channel_last); \ 84*523fa7a6SAndroid Build Coastguard Worker } \ 85*523fa7a6SAndroid Build Coastguard Worker EXPECT_TENSOR_EQ(output_channels_last, ret); \ 86*523fa7a6SAndroid Build Coastguard Worker ET_EXPECT_KERNEL_FAILURE( \ 87*523fa7a6SAndroid Build Coastguard Worker context_, op(input_channels_last, output_contiguous)); \ 88*523fa7a6SAndroid Build Coastguard Worker ET_EXPECT_KERNEL_FAILURE( \ 89*523fa7a6SAndroid Build Coastguard Worker context_, op(input_contiguous, output_channels_last)); 90*523fa7a6SAndroid Build Coastguard Worker 91*523fa7a6SAndroid Build Coastguard Worker #endif // USE_ATEN_LIB 92*523fa7a6SAndroid Build Coastguard Worker 93*523fa7a6SAndroid Build Coastguard Worker /* 94*523fa7a6SAndroid Build Coastguard Worker * Common test fixture for kernel / operator-level tests. Provides 95*523fa7a6SAndroid Build Coastguard Worker * a runtime context object and verifies failure state post-execution. 
96*523fa7a6SAndroid Build Coastguard Worker */ 97*523fa7a6SAndroid Build Coastguard Worker class OperatorTest : public ::testing::Test { 98*523fa7a6SAndroid Build Coastguard Worker public: OperatorTest()99*523fa7a6SAndroid Build Coastguard Worker OperatorTest() : expect_failure_(false) {} 100*523fa7a6SAndroid Build Coastguard Worker SetUp()101*523fa7a6SAndroid Build Coastguard Worker void SetUp() override { 102*523fa7a6SAndroid Build Coastguard Worker torch::executor::runtime_init(); 103*523fa7a6SAndroid Build Coastguard Worker } 104*523fa7a6SAndroid Build Coastguard Worker TearDown()105*523fa7a6SAndroid Build Coastguard Worker void TearDown() override { 106*523fa7a6SAndroid Build Coastguard Worker // Validate error state. 107*523fa7a6SAndroid Build Coastguard Worker if (!expect_failure_) { 108*523fa7a6SAndroid Build Coastguard Worker EXPECT_EQ(context_.failure_state(), torch::executor::Error::Ok); 109*523fa7a6SAndroid Build Coastguard Worker } else { 110*523fa7a6SAndroid Build Coastguard Worker EXPECT_NE(context_.failure_state(), torch::executor::Error::Ok); 111*523fa7a6SAndroid Build Coastguard Worker } 112*523fa7a6SAndroid Build Coastguard Worker } 113*523fa7a6SAndroid Build Coastguard Worker expect_failure()114*523fa7a6SAndroid Build Coastguard Worker void expect_failure() { 115*523fa7a6SAndroid Build Coastguard Worker expect_failure_ = true; 116*523fa7a6SAndroid Build Coastguard Worker } 117*523fa7a6SAndroid Build Coastguard Worker 118*523fa7a6SAndroid Build Coastguard Worker protected: 119*523fa7a6SAndroid Build Coastguard Worker executorch::runtime::KernelRuntimeContext context_; 120*523fa7a6SAndroid Build Coastguard Worker bool expect_failure_; 121*523fa7a6SAndroid Build Coastguard Worker }; 122