xref: /aosp_15_r20/external/armnn/delegate/classic/src/test/DelegateTestInterpreter.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <DelegateTestInterpreter.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker namespace delegateTestInterpreter
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker 
DelegateTestInterpreter(std::vector<char> & modelBuffer,const std::vector<armnn::BackendId> & backends,const std::string & customOp,bool disableFallback)13*89c4ff92SAndroid Build Coastguard Worker DelegateTestInterpreter::DelegateTestInterpreter(std::vector<char>& modelBuffer,
14*89c4ff92SAndroid Build Coastguard Worker                                                  const std::vector<armnn::BackendId>& backends,
15*89c4ff92SAndroid Build Coastguard Worker                                                  const std::string& customOp,
16*89c4ff92SAndroid Build Coastguard Worker                                                  bool disableFallback)
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker     TfLiteModel* tfLiteModel = delegateTestInterpreter::CreateTfLiteModel(modelBuffer);
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker     TfLiteInterpreterOptions* options = delegateTestInterpreter::CreateTfLiteInterpreterOptions();
21*89c4ff92SAndroid Build Coastguard Worker     if (!customOp.empty())
22*89c4ff92SAndroid Build Coastguard Worker     {
23*89c4ff92SAndroid Build Coastguard Worker         options->mutable_op_resolver = delegateTestInterpreter::GenerateCustomOpResolver(customOp);
24*89c4ff92SAndroid Build Coastguard Worker     }
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker     // Disable fallback by default for unit tests unless specified.
27*89c4ff92SAndroid Build Coastguard Worker     armnnDelegate::DelegateOptions delegateOptions(backends);
28*89c4ff92SAndroid Build Coastguard Worker     delegateOptions.DisableTfLiteRuntimeFallback(disableFallback);
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     auto armnnDelegate = armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions);
31*89c4ff92SAndroid Build Coastguard Worker     TfLiteInterpreterOptionsAddDelegate(options, armnnDelegate);
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     m_TfLiteDelegate = armnnDelegate;
34*89c4ff92SAndroid Build Coastguard Worker     m_TfLiteInterpreter = TfLiteInterpreterCreate(tfLiteModel, options);
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker     // The options and model can be deleted after the interpreter is created.
37*89c4ff92SAndroid Build Coastguard Worker     TfLiteInterpreterOptionsDelete(options);
38*89c4ff92SAndroid Build Coastguard Worker     TfLiteModelDelete(tfLiteModel);
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker 
DelegateTestInterpreter(std::vector<char> & modelBuffer,const armnnDelegate::DelegateOptions & delegateOptions,const std::string & customOp)41*89c4ff92SAndroid Build Coastguard Worker DelegateTestInterpreter::DelegateTestInterpreter(std::vector<char>& modelBuffer,
42*89c4ff92SAndroid Build Coastguard Worker                                                  const armnnDelegate::DelegateOptions& delegateOptions,
43*89c4ff92SAndroid Build Coastguard Worker                                                  const std::string& customOp)
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker     TfLiteModel* tfLiteModel = delegateTestInterpreter::CreateTfLiteModel(modelBuffer);
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     TfLiteInterpreterOptions* options = delegateTestInterpreter::CreateTfLiteInterpreterOptions();
48*89c4ff92SAndroid Build Coastguard Worker     if (!customOp.empty())
49*89c4ff92SAndroid Build Coastguard Worker     {
50*89c4ff92SAndroid Build Coastguard Worker         options->mutable_op_resolver = delegateTestInterpreter::GenerateCustomOpResolver(customOp);
51*89c4ff92SAndroid Build Coastguard Worker     }
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     auto armnnDelegate = armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions);
54*89c4ff92SAndroid Build Coastguard Worker     TfLiteInterpreterOptionsAddDelegate(options, armnnDelegate);
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     m_TfLiteDelegate = armnnDelegate;
57*89c4ff92SAndroid Build Coastguard Worker     m_TfLiteInterpreter = TfLiteInterpreterCreate(tfLiteModel, options);
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     // The options and model can be deleted after the interpreter is created.
60*89c4ff92SAndroid Build Coastguard Worker     TfLiteInterpreterOptionsDelete(options);
61*89c4ff92SAndroid Build Coastguard Worker     TfLiteModelDelete(tfLiteModel);
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker 
Cleanup()64*89c4ff92SAndroid Build Coastguard Worker void DelegateTestInterpreter::Cleanup()
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker     TfLiteInterpreterDelete(m_TfLiteInterpreter);
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     if (m_TfLiteDelegate)
69*89c4ff92SAndroid Build Coastguard Worker     {
70*89c4ff92SAndroid Build Coastguard Worker         armnnDelegate::TfLiteArmnnDelegateDelete(static_cast<TfLiteDelegate*>(m_TfLiteDelegate));
71*89c4ff92SAndroid Build Coastguard Worker     }
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace