1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "Encoders.hpp" 9 #include "Decoders.hpp" 10 11 #include <armnn/backends/WorkloadData.hpp> 12 13 namespace armnn 14 { 15 16 class BatchMatMul { 17 public: 18 BatchMatMul(const BatchMatMulDescriptor& params, 19 const TensorInfo& inputXInfo, 20 const TensorInfo& inputYInfo, 21 const TensorInfo& outputInfo, 22 Decoder<float>& inputXDecoder, 23 Decoder<float>& inputYDecoder, 24 Encoder<float>& outputEncoder); 25 26 private: 27 enum DataSlot 28 { 29 InputX = 0, 30 InputY = 1, 31 Output = 2 32 }; 33 34 const BatchMatMulDescriptor& params; 35 TensorInfo inputXInfo; 36 TensorInfo inputYInfo; 37 TensorInfo outputInfo; 38 Decoder<float>& inputXDecoder; 39 Decoder<float>& inputYDecoder; 40 Encoder<float>& outputEncoder; 41 42 std::vector<float> inputXData; 43 std::vector<float> inputYData; 44 45 void ApplyBatchMatMul(); 46 47 void ApplyParams(); 48 49 void Transpose(DataSlot type); 50 51 void Adjoint(DataSlot type); 52 53 void RecurseTensor(const TensorInfo& tensorInfo, 54 std::function<void(const std::vector<unsigned int>&)> const& operation, 55 std::vector<unsigned int>& curIdx, 56 unsigned int curDim); 57 58 // Adjusts it for when input tensors are of unequal rank 59 void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul, 60 std::pair<unsigned int, unsigned int>& axesYToMul); 61 62 float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {}); 63 64 void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx); 65 66 // Takes into account broadcasting 67 void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx); 68 69 unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx); 70 }; 71 72 } // namespace armnn