1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "RefWorkloadUtils.hpp" 9 #include <armnn/backends/WorkloadData.hpp> 10 #include <armnn/Tensor.hpp> 11 #include <armnn/utility/Assert.hpp> 12 13 namespace armnn 14 { 15 16 template <typename DataType> Splitter(const SplitterQueueDescriptor & data,std::vector<ITensorHandle * > inputs,std::vector<ITensorHandle * > outputs)17void Splitter(const SplitterQueueDescriptor& data, 18 std::vector<ITensorHandle*> inputs, 19 std::vector<ITensorHandle*> outputs) 20 { 21 const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]); 22 23 for (unsigned int index = 0; index < inputInfo0.GetNumElements(); ++index) 24 { 25 unsigned int indices[MaxNumOfTensorDimensions] = { 0 }; 26 27 unsigned int indexRemainder = index; 28 unsigned int dimensionStride = inputInfo0.GetNumElements(); 29 30 for (unsigned int i = 0; i<inputInfo0.GetNumDimensions(); i++) 31 { 32 dimensionStride /= inputInfo0.GetShape()[i]; 33 indices[i] = indexRemainder / dimensionStride; // Use integer division to round down. 34 indexRemainder -= indices[i] * dimensionStride; 35 } 36 37 for (unsigned int viewIdx = 0; viewIdx < data.m_ViewOrigins.size(); ++viewIdx) 38 { 39 SplitterQueueDescriptor::ViewOrigin const& view = data.m_ViewOrigins[viewIdx]; 40 41 //Split view extents are defined by the size of (the corresponding) input tensor. 42 const TensorInfo& outputInfo = GetTensorInfo(outputs[viewIdx]); 43 ARMNN_ASSERT(outputInfo.GetNumDimensions() == inputInfo0.GetNumDimensions()); 44 45 // Check all dimensions to see if this element is inside the given input view. 46 bool insideView = true; 47 for (unsigned int i = 0; i<outputInfo.GetNumDimensions(); i++) 48 { 49 if (indices[i] < view.m_Origin[i]) 50 { 51 insideView = false; 52 } 53 if (indices[i] >= view.m_Origin[i] + outputInfo.GetShape()[i]) 54 { 55 insideView = false; 56 } 57 } 58 59 if (insideView) 60 { 61 unsigned int outIndex = 0; 62 unsigned int dimensionStride = 1; 63 64 for (unsigned int i = outputInfo.GetNumDimensions(); i-- > 0;) 65 { 66 outIndex += dimensionStride * (indices[i] - view.m_Origin[i]); 67 dimensionStride *= outputInfo.GetShape()[i]; 68 } 69 70 //We are within the view, to copy input data to the output corresponding to this view. 71 DataType* outputData = GetOutputTensorData<DataType>(viewIdx, data); 72 ARMNN_ASSERT(outputData); 73 74 const DataType* inputData = GetInputTensorData<DataType>(0, data); 75 ARMNN_ASSERT(inputData); 76 77 outputData[outIndex] = inputData[index]; 78 } 79 } 80 } 81 } 82 83 void Split(const SplitterQueueDescriptor& data, 84 std::vector<ITensorHandle*> inputs, 85 std::vector<ITensorHandle*> outputs); 86 } //namespace armnn 87