// xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/RefDebugWorkload.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

6 #pragma once
7 
8 #include <armnn/TypesUtils.hpp>
9 
10 #include "RefBaseWorkload.hpp"
11 
12 namespace armnn
13 {
14 
15 template <armnn::DataType DataType>
16 class RefDebugWorkload : public TypedWorkload<DebugQueueDescriptor, DataType>
17 {
18 public:
RefDebugWorkload(const DebugQueueDescriptor & descriptor,const WorkloadInfo & info)19     RefDebugWorkload(const DebugQueueDescriptor& descriptor, const WorkloadInfo& info)
20     : TypedWorkload<DebugQueueDescriptor, DataType>(descriptor, info)
21     , m_Callback(nullptr) {}
22 
GetName()23     static const std::string& GetName()
24     {
25         static const std::string name = std::string("RefDebug") + GetDataTypeName(DataType) + "Workload";
26         return name;
27     }
28 
29     using TypedWorkload<DebugQueueDescriptor, DataType>::m_Data;
30     using TypedWorkload<DebugQueueDescriptor, DataType>::TypedWorkload;
31 
32     void Execute() const override;
33     void ExecuteAsync(ExecutionData& executionData)  override;
34 
35     void RegisterDebugCallback(const DebugCallbackFunction& func) override;
36 
37 private:
38     void Execute(std::vector<ITensorHandle*> inputs) const;
39     DebugCallbackFunction m_Callback;
40 };
41 
42 using RefDebugBFloat16Workload   = RefDebugWorkload<DataType::BFloat16>;
43 using RefDebugFloat16Workload   = RefDebugWorkload<DataType::Float16>;
44 using RefDebugFloat32Workload   = RefDebugWorkload<DataType::Float32>;
45 using RefDebugQAsymmU8Workload  = RefDebugWorkload<DataType::QAsymmU8>;
46 using RefDebugQAsymmS8Workload  = RefDebugWorkload<DataType::QAsymmS8>;
47 using RefDebugQSymmS16Workload  = RefDebugWorkload<DataType::QSymmS16>;
48 using RefDebugQSymmS8Workload   = RefDebugWorkload<DataType::QSymmS8>;
49 using RefDebugSigned32Workload  = RefDebugWorkload<DataType::Signed32>;
50 
51 } // namespace armnn
52