1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "ClBaseWorkload.hpp" 9 10 #include "arm_compute/runtime/Tensor.h" 11 #include "arm_compute/runtime/CL/functions/CLGather.h" 12 #include "arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h" 13 #include "arm_compute/runtime/CL/functions/CLReductionOperation.h" 14 #include "arm_compute/runtime/CL/functions/CLReshapeLayer.h" 15 16 namespace armnn 17 { 18 arm_compute::Status ClGatherNdWorkloadValidate(const TensorInfo& params, 19 const TensorInfo& indices, 20 const TensorInfo& output); 21 22 class ClGatherNdWorkload : public ClBaseWorkload<GatherNdQueueDescriptor> 23 { 24 public: 25 ClGatherNdWorkload(const GatherNdQueueDescriptor& descriptor, 26 const WorkloadInfo& info, 27 const arm_compute::CLCompileContext& clCompileContext); 28 virtual void Execute() const override; 29 30 private: 31 arm_compute::CLTensor m_FlattenedCoeff; 32 arm_compute::CLTensor m_OutputMul; 33 arm_compute::CLTensor m_FlattenedIndices; 34 arm_compute::CLTensor m_OutputGather; 35 36 mutable arm_compute::CLPixelWiseMultiplication m_MulLayer; 37 mutable arm_compute::CLReductionOperation m_ReduceSumLayer; 38 mutable arm_compute::CLGather m_GatherLayer; 39 mutable arm_compute::CLReshapeLayer m_ReshapeLayer; 40 }; 41 42 } //namespace armnn