//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <DelegateTestInterpreter.hpp>
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker namespace delegateTestInterpreter
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker
DelegateTestInterpreter(std::vector<char> & modelBuffer,const std::vector<armnn::BackendId> & backends,const std::string & customOp,bool disableFallback)13*89c4ff92SAndroid Build Coastguard Worker DelegateTestInterpreter::DelegateTestInterpreter(std::vector<char>& modelBuffer,
14*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::BackendId>& backends,
15*89c4ff92SAndroid Build Coastguard Worker const std::string& customOp,
16*89c4ff92SAndroid Build Coastguard Worker bool disableFallback)
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker TfLiteModel* tfLiteModel = delegateTestInterpreter::CreateTfLiteModel(modelBuffer);
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker TfLiteInterpreterOptions* options = delegateTestInterpreter::CreateTfLiteInterpreterOptions();
21*89c4ff92SAndroid Build Coastguard Worker if (!customOp.empty())
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker options->mutable_op_resolver = delegateTestInterpreter::GenerateCustomOpResolver(customOp);
24*89c4ff92SAndroid Build Coastguard Worker }
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker // Disable fallback by default for unit tests unless specified.
27*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::DelegateOptions delegateOptions(backends);
28*89c4ff92SAndroid Build Coastguard Worker delegateOptions.DisableTfLiteRuntimeFallback(disableFallback);
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker auto armnnDelegate = armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions);
31*89c4ff92SAndroid Build Coastguard Worker TfLiteInterpreterOptionsAddDelegate(options, armnnDelegate);
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker m_TfLiteDelegate = armnnDelegate;
34*89c4ff92SAndroid Build Coastguard Worker m_TfLiteInterpreter = TfLiteInterpreterCreate(tfLiteModel, options);
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker // The options and model can be deleted after the interpreter is created.
37*89c4ff92SAndroid Build Coastguard Worker TfLiteInterpreterOptionsDelete(options);
38*89c4ff92SAndroid Build Coastguard Worker TfLiteModelDelete(tfLiteModel);
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker
DelegateTestInterpreter(std::vector<char> & modelBuffer,const armnnDelegate::DelegateOptions & delegateOptions,const std::string & customOp)41*89c4ff92SAndroid Build Coastguard Worker DelegateTestInterpreter::DelegateTestInterpreter(std::vector<char>& modelBuffer,
42*89c4ff92SAndroid Build Coastguard Worker const armnnDelegate::DelegateOptions& delegateOptions,
43*89c4ff92SAndroid Build Coastguard Worker const std::string& customOp)
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker TfLiteModel* tfLiteModel = delegateTestInterpreter::CreateTfLiteModel(modelBuffer);
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker TfLiteInterpreterOptions* options = delegateTestInterpreter::CreateTfLiteInterpreterOptions();
48*89c4ff92SAndroid Build Coastguard Worker if (!customOp.empty())
49*89c4ff92SAndroid Build Coastguard Worker {
50*89c4ff92SAndroid Build Coastguard Worker options->mutable_op_resolver = delegateTestInterpreter::GenerateCustomOpResolver(customOp);
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker
53*89c4ff92SAndroid Build Coastguard Worker auto armnnDelegate = armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions);
54*89c4ff92SAndroid Build Coastguard Worker TfLiteInterpreterOptionsAddDelegate(options, armnnDelegate);
55*89c4ff92SAndroid Build Coastguard Worker
56*89c4ff92SAndroid Build Coastguard Worker m_TfLiteDelegate = armnnDelegate;
57*89c4ff92SAndroid Build Coastguard Worker m_TfLiteInterpreter = TfLiteInterpreterCreate(tfLiteModel, options);
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker // The options and model can be deleted after the interpreter is created.
60*89c4ff92SAndroid Build Coastguard Worker TfLiteInterpreterOptionsDelete(options);
61*89c4ff92SAndroid Build Coastguard Worker TfLiteModelDelete(tfLiteModel);
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker
Cleanup()64*89c4ff92SAndroid Build Coastguard Worker void DelegateTestInterpreter::Cleanup()
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker TfLiteInterpreterDelete(m_TfLiteInterpreter);
67*89c4ff92SAndroid Build Coastguard Worker
68*89c4ff92SAndroid Build Coastguard Worker if (m_TfLiteDelegate)
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::TfLiteArmnnDelegateDelete(static_cast<TfLiteDelegate*>(m_TfLiteDelegate));
71*89c4ff92SAndroid Build Coastguard Worker }
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker
} // namespace delegateTestInterpreter