// xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/ElementwiseFunction.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #pragma once
7 
8 #include "BaseIterator.hpp"
9 #include <armnn/Tensor.hpp>
10 
11 namespace armnn
12 {
13 
14 template <typename Functor>
15 struct ElementwiseBinaryFunction
16 {
17     using OutType = typename Functor::result_type;
18     using InType = typename Functor::first_argument_type;
19 
20     ElementwiseBinaryFunction(const TensorShape& inShape0,
21                               const TensorShape& inShape1,
22                               const TensorShape& outShape,
23                               Decoder<InType>& inData0,
24                               Decoder<InType>& inData1,
25                               Encoder<OutType>& outData);
26 };
27 
28 template <typename Functor>
29 struct ElementwiseUnaryFunction
30 {
31     using OutType = typename Functor::result_type;
32     using InType = typename Functor::argument_type;
33 
34     ElementwiseUnaryFunction(const TensorShape& inShape,
35                              const TensorShape& outShape,
36                              Decoder<InType>& inData,
37                              Encoder<OutType>& outData);
38 };
39 
40 template <typename Functor>
41 struct LogicalBinaryFunction
42 {
43     using OutType = bool;
44     using InType = bool;
45 
46     LogicalBinaryFunction(const TensorShape& inShape0,
47                           const TensorShape& inShape1,
48                           const TensorShape& outShape,
49                           Decoder<InType>& inData0,
50                           Decoder<InType>& inData1,
51                           Encoder<OutType>& outData);
52 };
53 
54 template <typename Functor>
55 struct LogicalUnaryFunction
56 {
57     using OutType = bool;
58     using InType = bool;
59 
60     LogicalUnaryFunction(const TensorShape& inShape,
61                          const TensorShape& outShape,
62                          Decoder<InType>& inData,
63                          Encoder<OutType>& outData);
64 };
65 
66 } //namespace armnn
67