xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefElementwiseBinaryWorkload.hpp"
7 
8 #include "Decoders.hpp"
9 #include "ElementwiseFunction.hpp"
10 #include "Encoders.hpp"
11 #include "RefWorkloadUtils.hpp"
12 #include "Maximum.hpp"
13 #include "Minimum.hpp"
14 
15 #include <Profiling.hpp>
16 
17 #include <armnn/TypesUtils.hpp>
18 
19 #include <functional>
20 
21 namespace armnn
22 {
23 
24 template<typename DataType>
ExecuteFunction(std::vector<ITensorHandle * > inputs,std::vector<ITensorHandle * > outputs,BinaryOperation operation)25 void ExecuteFunction(std::vector<ITensorHandle*> inputs,
26                      std::vector<ITensorHandle*> outputs,
27                      BinaryOperation operation)
28 {
29     const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
30     const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
31     const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
32 
33     const TensorShape& inShape0 = inputInfo0.GetShape();
34     const TensorShape& inShape1 = inputInfo1.GetShape();
35     const TensorShape& outShape = outputInfo.GetShape();
36 
37     std::unique_ptr<Decoder<DataType>> input0 = MakeDecoder<DataType>(inputInfo0, inputs[0]->Map());
38     std::unique_ptr<Decoder<DataType>> input1 = MakeDecoder<DataType>(inputInfo1, inputs[1]->Map());
39     std::unique_ptr<Encoder<DataType>> output = MakeEncoder<DataType>(outputInfo, outputs[0]->Map());
40 
41     using AddFunction     = ElementwiseBinaryFunction<std::plus<DataType>>;
42     using DivFunction     = ElementwiseBinaryFunction<std::divides<DataType>>;
43     using MaximumFunction = ElementwiseBinaryFunction<armnn::maximum<DataType>>;
44     using MinimumFunction = ElementwiseBinaryFunction<armnn::minimum<DataType>>;
45     using MulFunction     = ElementwiseBinaryFunction<std::multiplies<DataType>>;
46     using SubFunction     = ElementwiseBinaryFunction<std::minus<DataType>>;
47 
48     switch (operation)
49     {
50         case BinaryOperation::Add:
51         {
52             AddFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
53             break;
54         }
55         case BinaryOperation::Div:
56         {
57             DivFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
58             break;
59         }
60         case BinaryOperation::Maximum:
61         {
62             MaximumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
63             break;
64         }
65         case BinaryOperation::Minimum:
66         {
67             MinimumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
68             break;
69         }
70         case BinaryOperation::Mul:
71         {
72             MulFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
73             break;
74         }
75         case BinaryOperation::Sub:
76         {
77             SubFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
78             break;
79         }
80         default:
81         {
82             throw InvalidArgumentException(std::string("Unsupported binary operation ") +
83                                            GetBinaryOperationAsCString(operation), CHECK_LOCATION());
84         }
85     }
86 }
87 
RefElementwiseBinaryWorkload(const ElementwiseBinaryQueueDescriptor & desc,const WorkloadInfo & info)88 RefElementwiseBinaryWorkload::RefElementwiseBinaryWorkload(const ElementwiseBinaryQueueDescriptor& desc,
89                                                          const WorkloadInfo& info)
90     : RefBaseWorkload<ElementwiseBinaryQueueDescriptor>(desc, info)
91 {}
92 
Execute() const93 void RefElementwiseBinaryWorkload::Execute() const
94 {
95     Execute(m_Data.m_Inputs, m_Data.m_Outputs);
96 }
97 
ExecuteAsync(ExecutionData & executionData)98 void RefElementwiseBinaryWorkload::ExecuteAsync(ExecutionData& executionData)
99 {
100 
101     WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
102     Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
103 }
104 
Execute(std::vector<ITensorHandle * > inputs,std::vector<ITensorHandle * > outputs) const105 void RefElementwiseBinaryWorkload::Execute(std::vector<ITensorHandle*> inputs,
106                                            std::vector<ITensorHandle*> outputs) const
107 {
108     ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseBinaryWorkload_Execute");
109 
110     if (GetTensorInfo(inputs[0]).GetDataType() == DataType::Signed32)
111     {
112         ExecuteFunction<int32_t>(inputs, outputs, m_Data.m_Parameters.m_Operation);
113     }
114     else
115     {
116         ExecuteFunction<float>(inputs, outputs, m_Data.m_Parameters.m_Operation);
117     }
118 }
119 
120 } // namespace armnn
121