1 /* 2 * Copyright (c) Qualcomm Innovation Center, Inc. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 #pragma once 9 10 #include <executorch/backends/qualcomm/aot/wrappers/ParamWrapper.h> 11 #include <executorch/backends/qualcomm/aot/wrappers/TensorWrapper.h> 12 #include <executorch/runtime/core/error.h> 13 14 #include <memory> 15 16 namespace executorch { 17 namespace backends { 18 namespace qnn { 19 class TensorParamWrapper final : public ParamWrapper { 20 public: TensorParamWrapper(std::string name,std::shared_ptr<TensorWrapper> static_tensor)21 explicit TensorParamWrapper( 22 std::string name, 23 std::shared_ptr<TensorWrapper> static_tensor) 24 : ParamWrapper(QNN_PARAMTYPE_TENSOR, std::move(name)), 25 static_tensor_wrapper_(std::move(static_tensor)) {} 26 // Populate Qnn tensorParam with tensor wrapper PopulateQnnParam()27 executorch::runtime::Error PopulateQnnParam() override { 28 // executorch::runtime::Error out if underlying tensor is not static: 29 if (!static_tensor_wrapper_->IsTensorStatic()) 30 return executorch::runtime::Error::Internal; 31 qnn_param_.tensorParam = static_tensor_wrapper_->CloneTensorStruct(); 32 return executorch::runtime::Error::Ok; 33 } 34 35 // Accessor functions: GetData()36 const void* GetData() const { 37 return static_tensor_wrapper_->GetStaticTensorData(); 38 } 39 GetTensorWrapper()40 std::shared_ptr<TensorWrapper> GetTensorWrapper() { 41 return static_tensor_wrapper_; 42 } 43 44 private: 45 std::shared_ptr<TensorWrapper> static_tensor_wrapper_; 46 }; 47 } // namespace qnn 48 } // namespace backends 49 } // namespace executorch 50