1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #pragma once 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <DelegateTestInterpreterUtils.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendId.hpp> 13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Exceptions.hpp> 14*89c4ff92SAndroid Build Coastguard Worker 15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/core/c/c_api.h> 16*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/kernel_util.h> 17*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/custom_ops_register.h> 18*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/register.h> 19*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/c_api_internal.h> 20*89c4ff92SAndroid Build Coastguard Worker 21*89c4ff92SAndroid Build Coastguard Worker namespace delegateTestInterpreter 22*89c4ff92SAndroid Build Coastguard Worker { 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker class DelegateTestInterpreter 25*89c4ff92SAndroid Build Coastguard Worker { 26*89c4ff92SAndroid Build Coastguard Worker public: 27*89c4ff92SAndroid Build Coastguard Worker /// Create TfLite Interpreter only DelegateTestInterpreter(std::vector<char> & modelBuffer,const std::string & customOp="")28*89c4ff92SAndroid Build Coastguard Worker DelegateTestInterpreter(std::vector<char>& modelBuffer, const std::string& customOp = "") 29*89c4ff92SAndroid Build Coastguard Worker { 30*89c4ff92SAndroid Build Coastguard Worker TfLiteModel* model = delegateTestInterpreter::CreateTfLiteModel(modelBuffer); 31*89c4ff92SAndroid Build Coastguard Worker 32*89c4ff92SAndroid Build Coastguard Worker TfLiteInterpreterOptions* options = delegateTestInterpreter::CreateTfLiteInterpreterOptions(); 33*89c4ff92SAndroid Build Coastguard Worker if (!customOp.empty()) 34*89c4ff92SAndroid Build Coastguard Worker { 35*89c4ff92SAndroid Build Coastguard Worker options->mutable_op_resolver = delegateTestInterpreter::GenerateCustomOpResolver(customOp); 36*89c4ff92SAndroid Build Coastguard Worker } 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker m_TfLiteInterpreter = TfLiteInterpreterCreate(model, options); 39*89c4ff92SAndroid Build Coastguard Worker m_TfLiteDelegate = nullptr; 40*89c4ff92SAndroid Build Coastguard Worker 41*89c4ff92SAndroid Build Coastguard Worker // The options and model can be deleted after the interpreter is created. 42*89c4ff92SAndroid Build Coastguard Worker TfLiteInterpreterOptionsDelete(options); 43*89c4ff92SAndroid Build Coastguard Worker TfLiteModelDelete(model); 44*89c4ff92SAndroid Build Coastguard Worker } 45*89c4ff92SAndroid Build Coastguard Worker 46*89c4ff92SAndroid Build Coastguard Worker /// Create Interpreter with default Arm NN Classic/Opaque Delegate applied 47*89c4ff92SAndroid Build Coastguard Worker DelegateTestInterpreter(std::vector<char>& model, 48*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::BackendId>& backends, 49*89c4ff92SAndroid Build Coastguard Worker const std::string& customOp = "", 50*89c4ff92SAndroid Build Coastguard Worker bool disableFallback = true); 51*89c4ff92SAndroid Build Coastguard Worker 52*89c4ff92SAndroid Build Coastguard Worker /// Create Interpreter with Arm NN Classic/Opaque Delegate applied and DelegateOptions 53*89c4ff92SAndroid Build Coastguard Worker DelegateTestInterpreter(std::vector<char>& model, 54*89c4ff92SAndroid Build Coastguard Worker const armnnDelegate::DelegateOptions& delegateOptions, 55*89c4ff92SAndroid Build Coastguard Worker const std::string& customOp = ""); 56*89c4ff92SAndroid Build Coastguard Worker 57*89c4ff92SAndroid Build Coastguard Worker /// Allocate the TfLiteTensors within the graph. 58*89c4ff92SAndroid Build Coastguard Worker /// This must be called before FillInputTensor(values, index) and Invoke(). AllocateTensors()59*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus AllocateTensors() 60*89c4ff92SAndroid Build Coastguard Worker { 61*89c4ff92SAndroid Build Coastguard Worker return TfLiteInterpreterAllocateTensors(m_TfLiteInterpreter); 62*89c4ff92SAndroid Build Coastguard Worker } 63*89c4ff92SAndroid Build Coastguard Worker 64*89c4ff92SAndroid Build Coastguard Worker /// Copy a buffer of values into an input tensor at a given index. 65*89c4ff92SAndroid Build Coastguard Worker template<typename T> FillInputTensor(std::vector<T> & inputValues,int index)66*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus FillInputTensor(std::vector<T>& inputValues, int index) 67*89c4ff92SAndroid Build Coastguard Worker { 68*89c4ff92SAndroid Build Coastguard Worker TfLiteTensor* inputTensor = delegateTestInterpreter::GetInputTensorFromInterpreter(m_TfLiteInterpreter, index); 69*89c4ff92SAndroid Build Coastguard Worker return delegateTestInterpreter::CopyFromBufferToTensor(inputTensor, inputValues); 70*89c4ff92SAndroid Build Coastguard Worker } 71*89c4ff92SAndroid Build Coastguard Worker 72*89c4ff92SAndroid Build Coastguard Worker /// Copy a boolean buffer of values into an input tensor at a given index. 73*89c4ff92SAndroid Build Coastguard Worker /// Boolean types get converted to a bit representation in a vector. 74*89c4ff92SAndroid Build Coastguard Worker /// vector.data() returns a void pointer instead of a pointer to bool, so the tensor needs to be accessed directly. FillInputTensor(std::vector<bool> & inputValues,int index)75*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus FillInputTensor(std::vector<bool>& inputValues, int index) 76*89c4ff92SAndroid Build Coastguard Worker { 77*89c4ff92SAndroid Build Coastguard Worker TfLiteTensor* inputTensor = delegateTestInterpreter::GetInputTensorFromInterpreter(m_TfLiteInterpreter, index); 78*89c4ff92SAndroid Build Coastguard Worker if(inputTensor->type != kTfLiteBool) 79*89c4ff92SAndroid Build Coastguard Worker { 80*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("Input tensor at the given index is not of bool type: " + std::to_string(index)); 81*89c4ff92SAndroid Build Coastguard Worker } 82*89c4ff92SAndroid Build Coastguard Worker 83*89c4ff92SAndroid Build Coastguard Worker // Make sure there is enough bytes allocated to copy into. 84*89c4ff92SAndroid Build Coastguard Worker if(inputTensor->bytes < inputValues.size() * sizeof(bool)) 85*89c4ff92SAndroid Build Coastguard Worker { 86*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("Input tensor has not been allocated to match number of input values."); 87*89c4ff92SAndroid Build Coastguard Worker } 88*89c4ff92SAndroid Build Coastguard Worker 89*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < inputValues.size(); ++i) 90*89c4ff92SAndroid Build Coastguard Worker { 91*89c4ff92SAndroid Build Coastguard Worker inputTensor->data.b[i] = inputValues[i]; 92*89c4ff92SAndroid Build Coastguard Worker } 93*89c4ff92SAndroid Build Coastguard Worker 94*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk; 95*89c4ff92SAndroid Build Coastguard Worker } 96*89c4ff92SAndroid Build Coastguard Worker 97*89c4ff92SAndroid Build Coastguard Worker /// Run the interpreter either on TFLite Runtime or Arm NN Delegate. 98*89c4ff92SAndroid Build Coastguard Worker /// AllocateTensors() must be called before Invoke(). Invoke()99*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus Invoke() 100*89c4ff92SAndroid Build Coastguard Worker { 101*89c4ff92SAndroid Build Coastguard Worker return TfLiteInterpreterInvoke(m_TfLiteInterpreter); 102*89c4ff92SAndroid Build Coastguard Worker } 103*89c4ff92SAndroid Build Coastguard Worker 104*89c4ff92SAndroid Build Coastguard Worker /// Return a buffer of values from the output tensor at a given index. 105*89c4ff92SAndroid Build Coastguard Worker /// This must be called after Invoke(). 106*89c4ff92SAndroid Build Coastguard Worker template<typename T> GetOutputResult(int index)107*89c4ff92SAndroid Build Coastguard Worker std::vector<T> GetOutputResult(int index) 108*89c4ff92SAndroid Build Coastguard Worker { 109*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* outputTensor = 110*89c4ff92SAndroid Build Coastguard Worker delegateTestInterpreter::GetOutputTensorFromInterpreter(m_TfLiteInterpreter, index); 111*89c4ff92SAndroid Build Coastguard Worker 112*89c4ff92SAndroid Build Coastguard Worker int64_t n = tflite::NumElements(outputTensor); 113*89c4ff92SAndroid Build Coastguard Worker std::vector<T> output; 114*89c4ff92SAndroid Build Coastguard Worker output.resize(n); 115*89c4ff92SAndroid Build Coastguard Worker 116*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus status = TfLiteTensorCopyToBuffer(outputTensor, output.data(), output.size() * sizeof(T)); 117*89c4ff92SAndroid Build Coastguard Worker if(status != kTfLiteOk) 118*89c4ff92SAndroid Build Coastguard Worker { 119*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("An error occurred when copying output buffer."); 120*89c4ff92SAndroid Build Coastguard Worker } 121*89c4ff92SAndroid Build Coastguard Worker 122*89c4ff92SAndroid Build Coastguard Worker return output; 123*89c4ff92SAndroid Build Coastguard Worker } 124*89c4ff92SAndroid Build Coastguard Worker 125*89c4ff92SAndroid Build Coastguard Worker /// Return a buffer of values from the output tensor at a given index. This must be called after Invoke(). 126*89c4ff92SAndroid Build Coastguard Worker /// Boolean types get converted to a bit representation in a vector. 127*89c4ff92SAndroid Build Coastguard Worker /// vector.data() returns a void pointer instead of a pointer to bool, so the tensor needs to be accessed directly. GetOutputResult(int index)128*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> GetOutputResult(int index) 129*89c4ff92SAndroid Build Coastguard Worker { 130*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* outputTensor = 131*89c4ff92SAndroid Build Coastguard Worker delegateTestInterpreter::GetOutputTensorFromInterpreter(m_TfLiteInterpreter, index); 132*89c4ff92SAndroid Build Coastguard Worker if(outputTensor->type != kTfLiteBool) 133*89c4ff92SAndroid Build Coastguard Worker { 134*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("Output tensor at the given index is not of bool type: " + std::to_string(index)); 135*89c4ff92SAndroid Build Coastguard Worker } 136*89c4ff92SAndroid Build Coastguard Worker 137*89c4ff92SAndroid Build Coastguard Worker int64_t n = tflite::NumElements(outputTensor); 138*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> output(n, false); 139*89c4ff92SAndroid Build Coastguard Worker output.reserve(n); 140*89c4ff92SAndroid Build Coastguard Worker 141*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < output.size(); ++i) 142*89c4ff92SAndroid Build Coastguard Worker { 143*89c4ff92SAndroid Build Coastguard Worker output[i] = outputTensor->data.b[i]; 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker return output; 146*89c4ff92SAndroid Build Coastguard Worker } 147*89c4ff92SAndroid Build Coastguard Worker 148*89c4ff92SAndroid Build Coastguard Worker /// Return a buffer of dimensions from the output tensor at a given index. GetOutputShape(int index)149*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> GetOutputShape(int index) 150*89c4ff92SAndroid Build Coastguard Worker { 151*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* outputTensor = 152*89c4ff92SAndroid Build Coastguard Worker delegateTestInterpreter::GetOutputTensorFromInterpreter(m_TfLiteInterpreter, index); 153*89c4ff92SAndroid Build Coastguard Worker int32_t numDims = TfLiteTensorNumDims(outputTensor); 154*89c4ff92SAndroid Build Coastguard Worker 155*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> dims; 156*89c4ff92SAndroid Build Coastguard Worker dims.reserve(numDims); 157*89c4ff92SAndroid Build Coastguard Worker 158*89c4ff92SAndroid Build Coastguard Worker for (int32_t i = 0; i < numDims; ++i) 159*89c4ff92SAndroid Build Coastguard Worker { 160*89c4ff92SAndroid Build Coastguard Worker dims.push_back(TfLiteTensorDim(outputTensor, i)); 161*89c4ff92SAndroid Build Coastguard Worker } 162*89c4ff92SAndroid Build Coastguard Worker return dims; 163*89c4ff92SAndroid Build Coastguard Worker } 164*89c4ff92SAndroid Build Coastguard Worker 165*89c4ff92SAndroid Build Coastguard Worker /// Delete TfLiteInterpreter and the TfLiteDelegate/TfLiteOpaqueDelegate 166*89c4ff92SAndroid Build Coastguard Worker void Cleanup(); 167*89c4ff92SAndroid Build Coastguard Worker 168*89c4ff92SAndroid Build Coastguard Worker private: 169*89c4ff92SAndroid Build Coastguard Worker TfLiteInterpreter* m_TfLiteInterpreter; 170*89c4ff92SAndroid Build Coastguard Worker 171*89c4ff92SAndroid Build Coastguard Worker /// m_TfLiteDelegate can be TfLiteDelegate or TfLiteOpaqueDelegate 172*89c4ff92SAndroid Build Coastguard Worker void* m_TfLiteDelegate; 173*89c4ff92SAndroid Build Coastguard Worker }; 174*89c4ff92SAndroid Build Coastguard Worker 175*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace