1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN 7*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <classic/include/armnn_delegate.hpp> 10*89c4ff92SAndroid Build Coastguard Worker 11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/builtin_op_kernels.h> 12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/interpreter.h> 13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/register.h> 14*89c4ff92SAndroid Build Coastguard Worker 15*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate 16*89c4ff92SAndroid Build Coastguard Worker { 17*89c4ff92SAndroid Build Coastguard Worker 18*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("ArmnnDelegate") 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker 21*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArmnnDelegate Registered") 22*89c4ff92SAndroid Build Coastguard Worker { 23*89c4ff92SAndroid Build Coastguard Worker using namespace tflite; 24*89c4ff92SAndroid Build Coastguard Worker auto tfLiteInterpreter = std::make_unique<Interpreter>(); 25*89c4ff92SAndroid Build Coastguard Worker 26*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->AddTensors(3); 27*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->SetInputs({0, 1}); 28*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->SetOutputs({2}); 29*89c4ff92SAndroid Build Coastguard Worker 30*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->SetTensorParametersReadWrite(0, kTfLiteFloat32, "input1", {1,2,2,1}, TfLiteQuantization()); 31*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->SetTensorParametersReadWrite(1, kTfLiteFloat32, "input2", {1,2,2,1}, TfLiteQuantization()); 32*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->SetTensorParametersReadWrite(2, kTfLiteFloat32, "output", {1,2,2,1}, TfLiteQuantization()); 33*89c4ff92SAndroid Build Coastguard Worker 34*89c4ff92SAndroid Build Coastguard Worker tflite::ops::builtin::BuiltinOpResolver opResolver; 35*89c4ff92SAndroid Build Coastguard Worker const TfLiteRegistration* opRegister = opResolver.FindOp(BuiltinOperator_ADD, 1); 36*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->AddNodeWithParameters({0, 1}, {2}, "", 0, nullptr, opRegister); 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker // Create the Armnn Delegate 39*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef }; 40*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendOptions> backendOptions; 41*89c4ff92SAndroid Build Coastguard Worker backendOptions.emplace_back( 42*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions{ "BackendName", 43*89c4ff92SAndroid Build Coastguard Worker { 44*89c4ff92SAndroid Build Coastguard Worker { "Option1", 42 }, 45*89c4ff92SAndroid Build Coastguard Worker { "Option2", true } 46*89c4ff92SAndroid Build Coastguard Worker }} 47*89c4ff92SAndroid Build Coastguard Worker ); 48*89c4ff92SAndroid Build Coastguard Worker 49*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::DelegateOptions delegateOptions(backends, backendOptions); 50*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)> 51*89c4ff92SAndroid Build Coastguard Worker theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions), 52*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::TfLiteArmnnDelegateDelete); 53*89c4ff92SAndroid Build Coastguard Worker 54*89c4ff92SAndroid Build Coastguard Worker auto status = tfLiteInterpreter->ModifyGraphWithDelegate(std::move(theArmnnDelegate)); 55*89c4ff92SAndroid Build Coastguard Worker CHECK(status == kTfLiteOk); 56*89c4ff92SAndroid Build Coastguard Worker CHECK(tfLiteInterpreter != nullptr); 57*89c4ff92SAndroid Build Coastguard Worker } 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArmnnDelegateOptimizerOptionsRegistered") 60*89c4ff92SAndroid Build Coastguard Worker { 61*89c4ff92SAndroid Build Coastguard Worker using namespace tflite; 62*89c4ff92SAndroid Build Coastguard Worker auto tfLiteInterpreter = std::make_unique<Interpreter>(); 63*89c4ff92SAndroid Build Coastguard Worker 64*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->AddTensors(3); 65*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->SetInputs({0, 1}); 66*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->SetOutputs({2}); 67*89c4ff92SAndroid Build Coastguard Worker 68*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->SetTensorParametersReadWrite(0, kTfLiteFloat32, "input1", {1,2,2,1}, TfLiteQuantization()); 69*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->SetTensorParametersReadWrite(1, kTfLiteFloat32, "input2", {1,2,2,1}, TfLiteQuantization()); 70*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->SetTensorParametersReadWrite(2, kTfLiteFloat32, "output", {1,2,2,1}, TfLiteQuantization()); 71*89c4ff92SAndroid Build Coastguard Worker 72*89c4ff92SAndroid Build Coastguard Worker tflite::ops::builtin::BuiltinOpResolver opResolver; 73*89c4ff92SAndroid Build Coastguard Worker const TfLiteRegistration* opRegister = opResolver.FindOp(BuiltinOperator_ADD, 1); 74*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter->AddNodeWithParameters({0, 1}, {2}, "", 0, nullptr, opRegister); 75*89c4ff92SAndroid Build Coastguard Worker 76*89c4ff92SAndroid Build Coastguard Worker // Create the Armnn Delegate 77*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef }; 78*89c4ff92SAndroid Build Coastguard Worker 79*89c4ff92SAndroid Build Coastguard Worker armnn::OptimizerOptionsOpaque optimizerOptions(true, true, false, true); 80*89c4ff92SAndroid Build Coastguard Worker 81*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::DelegateOptions delegateOptions(backends, optimizerOptions); 82*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)> 83*89c4ff92SAndroid Build Coastguard Worker theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions), 84*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::TfLiteArmnnDelegateDelete); 85*89c4ff92SAndroid Build Coastguard Worker 86*89c4ff92SAndroid Build Coastguard Worker auto status = tfLiteInterpreter->ModifyGraphWithDelegate(std::move(theArmnnDelegate)); 87*89c4ff92SAndroid Build Coastguard Worker CHECK(status == kTfLiteOk); 88*89c4ff92SAndroid Build Coastguard Worker CHECK(tfLiteInterpreter != nullptr); 89*89c4ff92SAndroid Build Coastguard Worker } 90*89c4ff92SAndroid Build Coastguard Worker 91*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("DelegateOptions_ClassicDelegateDefault") 92*89c4ff92SAndroid Build Coastguard Worker { 93*89c4ff92SAndroid Build Coastguard Worker // Check default options can be created 94*89c4ff92SAndroid Build Coastguard Worker auto options = TfLiteArmnnDelegateOptionsDefault(); 95*89c4ff92SAndroid Build Coastguard Worker 96*89c4ff92SAndroid Build Coastguard Worker // Check Classic delegate created 97*89c4ff92SAndroid Build Coastguard Worker auto classicDelegate = armnnDelegate::TfLiteArmnnDelegateCreate(options); 98*89c4ff92SAndroid Build Coastguard Worker CHECK(classicDelegate); 99*89c4ff92SAndroid Build Coastguard Worker 100*89c4ff92SAndroid Build Coastguard Worker // Check Classic Delegate can be deleted 101*89c4ff92SAndroid Build Coastguard Worker CHECK(classicDelegate->data_); 102*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::TfLiteArmnnDelegateDelete(classicDelegate); 103*89c4ff92SAndroid Build Coastguard Worker } 104*89c4ff92SAndroid Build Coastguard Worker 105*89c4ff92SAndroid Build Coastguard Worker } 106*89c4ff92SAndroid Build Coastguard Worker 107*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate 108