//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefElementwiseBinaryWorkload.hpp"

#include "Decoders.hpp"
#include "ElementwiseFunction.hpp"
#include "Encoders.hpp"
#include "RefWorkloadUtils.hpp"
#include "Maximum.hpp"
#include "Minimum.hpp"

#include <Profiling.hpp>

#include <armnn/TypesUtils.hpp>

#include <functional>

namespace armnn
{

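// Decodes the two input tensors, applies the requested BinaryOperation element-wise
// (shape broadcasting is handled inside ElementwiseBinaryFunction) and encodes the
// result into the output tensor.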
template<typename DataType>
void ExecuteFunction(std::vector<ITensorHandle*> inputs,
                     std::vector<ITensorHandle*> outputs,
                     BinaryOperation operation)
{
    const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
    const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
    const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);

    const TensorShape& inShape0 = inputInfo0.GetShape();
    const TensorShape& inShape1 = inputInfo1.GetShape();
    const TensorShape& outShape = outputInfo.GetShape();

    std::unique_ptr<Decoder<DataType>> input0 = MakeDecoder<DataType>(inputInfo0, inputs[0]->Map());
    std::unique_ptr<Decoder<DataType>> input1 = MakeDecoder<DataType>(inputInfo1, inputs[1]->Map());
    std::unique_ptr<Encoder<DataType>> output = MakeEncoder<DataType>(outputInfo, outputs[0]->Map());

    using AddFunction = ElementwiseBinaryFunction<std::plus<DataType>>;
    using DivFunction = ElementwiseBinaryFunction<std::divides<DataType>>;
    using MaximumFunction = ElementwiseBinaryFunction<armnn::maximum<DataType>>;
    using MinimumFunction = ElementwiseBinaryFunction<armnn::minimum<DataType>>;
    using MulFunction = ElementwiseBinaryFunction<std::multiplies<DataType>>;
    using SubFunction = ElementwiseBinaryFunction<std::minus<DataType>>;

    switch (operation)
    {
        case BinaryOperation::Add:
        {
            AddFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
            break;
        }
        case BinaryOperation::Div:
        {
            DivFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
            break;
        }
        case BinaryOperation::Maximum:
        {
            MaximumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
            break;
        }
        case BinaryOperation::Minimum:
        {
            MinimumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
            break;
        }
        case BinaryOperation::Mul:
        {
            MulFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
            break;
        }
        case BinaryOperation::Sub:
        {
            SubFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
            break;
        }
        default:
        {
            throw InvalidArgumentException(std::string("Unsupported binary operation ") +
                                           GetBinaryOperationAsCString(operation), CHECK_LOCATION());
        }
    }
}

RefElementwiseBinaryWorkload::RefElementwiseBinaryWorkload(const ElementwiseBinaryQueueDescriptor& desc,
                                                           const WorkloadInfo& info)
    : RefBaseWorkload<ElementwiseBinaryQueueDescriptor>(desc, info)
{}

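// Synchronous execution: operates on the tensor handles held in the queue descriptor.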
void RefElementwiseBinaryWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

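// Asynchronous execution: the tensor handles come from the WorkingMemDescriptor
// carried in the ExecutionData rather than from the queue descriptor.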
void RefElementwiseBinaryWorkload::ExecuteAsync(ExecutionData& executionData)
{
    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
}

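// Dispatches on the input data type: Signed32 tensors use the int32_t path, every other
// supported type is handled through float decoders/encoders.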
void RefElementwiseBinaryWorkload::Execute(std::vector<ITensorHandle*> inputs,
                                           std::vector<ITensorHandle*> outputs) const
{
    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseBinaryWorkload_Execute");

    if (GetTensorInfo(inputs[0]).GetDataType() == DataType::Signed32)
    {
        ExecuteFunction<int32_t>(inputs, outputs, m_Data.m_Parameters.m_Operation);
    }
    else
    {
        ExecuteFunction<float>(inputs, outputs, m_Data.m_Parameters.m_Operation);
    }
}

} // namespace armnn