1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/TypesUtils.hpp> 9 10 #include "RefBaseWorkload.hpp" 11 12 namespace armnn 13 { 14 15 template <armnn::DataType DataType> 16 class RefDebugWorkload : public TypedWorkload<DebugQueueDescriptor, DataType> 17 { 18 public: RefDebugWorkload(const DebugQueueDescriptor & descriptor,const WorkloadInfo & info)19 RefDebugWorkload(const DebugQueueDescriptor& descriptor, const WorkloadInfo& info) 20 : TypedWorkload<DebugQueueDescriptor, DataType>(descriptor, info) 21 , m_Callback(nullptr) {} 22 GetName()23 static const std::string& GetName() 24 { 25 static const std::string name = std::string("RefDebug") + GetDataTypeName(DataType) + "Workload"; 26 return name; 27 } 28 29 using TypedWorkload<DebugQueueDescriptor, DataType>::m_Data; 30 using TypedWorkload<DebugQueueDescriptor, DataType>::TypedWorkload; 31 32 void Execute() const override; 33 void ExecuteAsync(ExecutionData& executionData) override; 34 35 void RegisterDebugCallback(const DebugCallbackFunction& func) override; 36 37 private: 38 void Execute(std::vector<ITensorHandle*> inputs) const; 39 DebugCallbackFunction m_Callback; 40 }; 41 42 using RefDebugBFloat16Workload = RefDebugWorkload<DataType::BFloat16>; 43 using RefDebugFloat16Workload = RefDebugWorkload<DataType::Float16>; 44 using RefDebugFloat32Workload = RefDebugWorkload<DataType::Float32>; 45 using RefDebugQAsymmU8Workload = RefDebugWorkload<DataType::QAsymmU8>; 46 using RefDebugQAsymmS8Workload = RefDebugWorkload<DataType::QAsymmS8>; 47 using RefDebugQSymmS16Workload = RefDebugWorkload<DataType::QSymmS16>; 48 using RefDebugQSymmS8Workload = RefDebugWorkload<DataType::QSymmS8>; 49 using RefDebugSigned32Workload = RefDebugWorkload<DataType::Signed32>; 50 51 } // namespace armnn 52