//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>

#include "RefWorkloadUtils.hpp"

#include <cstring>

namespace armnn
{

struct RefShapeWorkload : public RefBaseWorkload<ShapeQueueDescriptor>
{
public:
    using RefBaseWorkload<ShapeQueueDescriptor>::RefBaseWorkload;

    virtual void Execute() const override
    {
        Execute(m_Data.m_Inputs, m_Data.m_Outputs);
    }

    void ExecuteAsync(ExecutionData& executionData) override
    {
        WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
        Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
    }

private:
    void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
    {
        // The Shape operator writes the input tensor's dimensions into the output tensor.
        const TensorShape Shape = GetTensorInfo(inputs[0]).GetShape();

        const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);

        // One output element per input dimension, sized according to the output data type.
        unsigned int numBytes =
            GetTensorInfo(inputs[0]).GetNumDimensions() * GetDataTypeSize(outputInfo.GetDataType());

        // Copy the shape values into the mapped output buffer, then release the mapping.
        std::memcpy(outputs[0]->Map(), &Shape, numBytes);
        outputs[0]->Unmap();
    }
};

} //namespace armnn