1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "armnn/backends/Workload.hpp" 9 10 #include <graph_status.h> 11 #include <model_runner.h> 12 13 #include <memory> 14 #include <string> 15 #include <vector> 16 17 namespace armnn 18 { 19 20 bool TosaRefPreCompiledWorkloadValidate(std::string* reasonIfUnsupported); 21 22 class TosaRefPreCompiledWorkload : public BaseWorkload<PreCompiledQueueDescriptor> 23 { 24 public: 25 TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor, 26 const WorkloadInfo& info); 27 void Execute() const override; 28 29 private: SupportsTensorHandleReplacement() const30 bool SupportsTensorHandleReplacement() const override 31 { 32 return true; 33 } 34 ReplaceInputTensorHandle(ITensorHandle * tensorHandle,unsigned int slot)35 void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override 36 { 37 this->m_Data.m_Inputs[slot] = tensorHandle; 38 } 39 ReplaceOutputTensorHandle(ITensorHandle * tensorHandle,unsigned int slot)40 void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override 41 { 42 this->m_Data.m_Outputs[slot] = tensorHandle; 43 } 44 45 template <typename T> 46 void SetInput(TosaReference::IModelRunner& runner, std::string inputName, uint32_t inputIndex) const; 47 48 template <typename T> 49 void GetOutput(TosaReference::IModelRunner& runner, std::string outputName, uint32_t outputIndex) const; 50 51 WorkloadInfo m_workloadInfo; 52 }; 53 54 } //namespace armnn 55