xref: /aosp_15_r20/external/armnn/delegate/classic/src/test/ArmnnClassicDelegateTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
7*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <classic/include/armnn_delegate.hpp>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/builtin_op_kernels.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/interpreter.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/register.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("ArmnnDelegate")
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArmnnDelegate Registered")
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker     using namespace tflite;
24*89c4ff92SAndroid Build Coastguard Worker     auto tfLiteInterpreter = std::make_unique<Interpreter>();
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->AddTensors(3);
27*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->SetInputs({0, 1});
28*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->SetOutputs({2});
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->SetTensorParametersReadWrite(0, kTfLiteFloat32, "input1", {1,2,2,1}, TfLiteQuantization());
31*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->SetTensorParametersReadWrite(1, kTfLiteFloat32, "input2", {1,2,2,1}, TfLiteQuantization());
32*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->SetTensorParametersReadWrite(2, kTfLiteFloat32, "output", {1,2,2,1}, TfLiteQuantization());
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker     tflite::ops::builtin::BuiltinOpResolver opResolver;
35*89c4ff92SAndroid Build Coastguard Worker     const TfLiteRegistration* opRegister = opResolver.FindOp(BuiltinOperator_ADD, 1);
36*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->AddNodeWithParameters({0, 1}, {2}, "", 0, nullptr, opRegister);
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     // Create the Armnn Delegate
39*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
40*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendOptions> backendOptions;
41*89c4ff92SAndroid Build Coastguard Worker     backendOptions.emplace_back(
42*89c4ff92SAndroid Build Coastguard Worker         armnn::BackendOptions{ "BackendName",
43*89c4ff92SAndroid Build Coastguard Worker                                {
44*89c4ff92SAndroid Build Coastguard Worker                                   { "Option1", 42 },
45*89c4ff92SAndroid Build Coastguard Worker                                   { "Option2", true }
46*89c4ff92SAndroid Build Coastguard Worker                                }}
47*89c4ff92SAndroid Build Coastguard Worker     );
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     armnnDelegate::DelegateOptions delegateOptions(backends, backendOptions);
50*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
51*89c4ff92SAndroid Build Coastguard Worker                        theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
52*89c4ff92SAndroid Build Coastguard Worker                                         armnnDelegate::TfLiteArmnnDelegateDelete);
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     auto status = tfLiteInterpreter->ModifyGraphWithDelegate(std::move(theArmnnDelegate));
55*89c4ff92SAndroid Build Coastguard Worker     CHECK(status == kTfLiteOk);
56*89c4ff92SAndroid Build Coastguard Worker     CHECK(tfLiteInterpreter != nullptr);
57*89c4ff92SAndroid Build Coastguard Worker }
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArmnnDelegateOptimizerOptionsRegistered")
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker     using namespace tflite;
62*89c4ff92SAndroid Build Coastguard Worker     auto tfLiteInterpreter = std::make_unique<Interpreter>();
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->AddTensors(3);
65*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->SetInputs({0, 1});
66*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->SetOutputs({2});
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->SetTensorParametersReadWrite(0, kTfLiteFloat32, "input1", {1,2,2,1}, TfLiteQuantization());
69*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->SetTensorParametersReadWrite(1, kTfLiteFloat32, "input2", {1,2,2,1}, TfLiteQuantization());
70*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->SetTensorParametersReadWrite(2, kTfLiteFloat32, "output", {1,2,2,1}, TfLiteQuantization());
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     tflite::ops::builtin::BuiltinOpResolver opResolver;
73*89c4ff92SAndroid Build Coastguard Worker     const TfLiteRegistration* opRegister = opResolver.FindOp(BuiltinOperator_ADD, 1);
74*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter->AddNodeWithParameters({0, 1}, {2}, "", 0, nullptr, opRegister);
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     // Create the Armnn Delegate
77*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker     armnn::OptimizerOptionsOpaque optimizerOptions(true, true, false, true);
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker     armnnDelegate::DelegateOptions delegateOptions(backends, optimizerOptions);
82*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
83*89c4ff92SAndroid Build Coastguard Worker                        theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
84*89c4ff92SAndroid Build Coastguard Worker                                         armnnDelegate::TfLiteArmnnDelegateDelete);
85*89c4ff92SAndroid Build Coastguard Worker 
86*89c4ff92SAndroid Build Coastguard Worker     auto status = tfLiteInterpreter->ModifyGraphWithDelegate(std::move(theArmnnDelegate));
87*89c4ff92SAndroid Build Coastguard Worker     CHECK(status == kTfLiteOk);
88*89c4ff92SAndroid Build Coastguard Worker     CHECK(tfLiteInterpreter != nullptr);
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("DelegateOptions_ClassicDelegateDefault")
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker     // Check default options can be created
94*89c4ff92SAndroid Build Coastguard Worker     auto options = TfLiteArmnnDelegateOptionsDefault();
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     // Check Classic delegate created
97*89c4ff92SAndroid Build Coastguard Worker     auto classicDelegate = armnnDelegate::TfLiteArmnnDelegateCreate(options);
98*89c4ff92SAndroid Build Coastguard Worker     CHECK(classicDelegate);
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     // Check Classic Delegate can be deleted
101*89c4ff92SAndroid Build Coastguard Worker     CHECK(classicDelegate->data_);
102*89c4ff92SAndroid Build Coastguard Worker     armnnDelegate::TfLiteArmnnDelegateDelete(classicDelegate);
103*89c4ff92SAndroid Build Coastguard Worker }
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
108