//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Encoders.hpp"
#include "Decoders.hpp"

#include <armnn/backends/WorkloadData.hpp>

#include <functional>
#include <utility>
#include <vector>

namespace armnn
{

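// Reference implementation of batched matrix multiplication. It applies the
// transpose/adjoint options described by the BatchMatMulDescriptor to the decoded
// inputs, multiplies the matrices across the batch dimensions (with broadcasting)
// and encodes the result to the output tensor.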
class BatchMatMul {
public:
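    // Constructing a BatchMatMul performs the computation: the inputs are decoded,
    // the descriptor's parameters are applied, and the result is written through
    // outputEncoder.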
    BatchMatMul(const BatchMatMulDescriptor& params,
                const TensorInfo& inputXInfo,
                const TensorInfo& inputYInfo,
                const TensorInfo& outputInfo,
                Decoder<float>& inputXDecoder,
                Decoder<float>& inputYDecoder,
                Encoder<float>& outputEncoder);

private:
    // Identifies which tensor a helper function operates on.
    enum DataSlot
    {
        InputX = 0,
        InputY = 1,
        Output = 2
    };

    const BatchMatMulDescriptor& params;
    TensorInfo inputXInfo;
    TensorInfo inputYInfo;
    TensorInfo outputInfo;
    Decoder<float>& inputXDecoder;
    Decoder<float>& inputYDecoder;
    Encoder<float>& outputEncoder;

    // Decoded copies of the input data, modified in place when the descriptor's
    // transpose/adjoint parameters are applied.
    std::vector<float> inputXData;
    std::vector<float> inputYData;

    // Performs the batched matrix multiplication and writes the result via outputEncoder.
    void ApplyBatchMatMul();

    // Applies the descriptor's transpose/adjoint parameters to the input data.
    void ApplyParams();

    // Transposes the data of the given slot in place.
    void Transpose(DataSlot type);

    // Replaces the data of the given slot with its adjoint (adjugate) in place.
    void Adjoint(DataSlot type);

    // Recursively visits every index of the tensor, invoking the given operation on each.
    void RecurseTensor(const TensorInfo& tensorInfo,
                       std::function<void(const std::vector<unsigned int>&)> const& operation,
                       std::vector<unsigned int>& curIdx,
                       unsigned int curDim);

    // Adjusts the axes to multiply for when the input tensors are of unequal rank.
    void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
                                        std::pair<unsigned int, unsigned int>& axesYToMul);

    // Reads the value at the given index (optionally from customData rather than the slot's own data).
    float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {});

    // Writes the value at the given index.
    void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);

    // Adjusts the index to a valid position in the slot's tensor, taking broadcasting into account.
    void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);

    // Computes the flat (1D) buffer index corresponding to the given multidimensional index.
    unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
};
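
// A minimal usage sketch, mirroring how a reference workload is expected to drive this
// class. The buffer pointers (inputXBuffer, inputYBuffer, outputBuffer) and the populated
// BatchMatMulDescriptor are assumptions for illustration; MakeDecoder/MakeEncoder come
// from Decoders.hpp and Encoders.hpp.
//
//     std::unique_ptr<Decoder<float>> inputXDecoder = MakeDecoder<float>(inputXInfo, inputXBuffer);
//     std::unique_ptr<Decoder<float>> inputYDecoder = MakeDecoder<float>(inputYInfo, inputYBuffer);
//     std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(outputInfo, outputBuffer);
//
//     // Construction performs the multiplication and writes the result through outputEncoder.
//     BatchMatMul bmm(descriptor, inputXInfo, inputYInfo, outputInfo,
//                     *inputXDecoder, *inputYDecoder, *outputEncoder);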

} // namespace armnn