// xref: /aosp_15_r20/external/executorch/backends/qualcomm/aot/wrappers/TensorParamWrapper.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 #pragma once
9 
#include <executorch/backends/qualcomm/aot/wrappers/ParamWrapper.h>
#include <executorch/backends/qualcomm/aot/wrappers/TensorWrapper.h>
#include <executorch/runtime/core/error.h>

#include <memory>
#include <string>
#include <utility>
15 
16 namespace executorch {
17 namespace backends {
18 namespace qnn {
19 class TensorParamWrapper final : public ParamWrapper {
20  public:
TensorParamWrapper(std::string name,std::shared_ptr<TensorWrapper> static_tensor)21   explicit TensorParamWrapper(
22       std::string name,
23       std::shared_ptr<TensorWrapper> static_tensor)
24       : ParamWrapper(QNN_PARAMTYPE_TENSOR, std::move(name)),
25         static_tensor_wrapper_(std::move(static_tensor)) {}
26   // Populate Qnn tensorParam with tensor wrapper
PopulateQnnParam()27   executorch::runtime::Error PopulateQnnParam() override {
28     // executorch::runtime::Error out if underlying tensor is not static:
29     if (!static_tensor_wrapper_->IsTensorStatic())
30       return executorch::runtime::Error::Internal;
31     qnn_param_.tensorParam = static_tensor_wrapper_->CloneTensorStruct();
32     return executorch::runtime::Error::Ok;
33   }
34 
35   // Accessor functions:
GetData()36   const void* GetData() const {
37     return static_tensor_wrapper_->GetStaticTensorData();
38   }
39 
GetTensorWrapper()40   std::shared_ptr<TensorWrapper> GetTensorWrapper() {
41     return static_tensor_wrapper_;
42   }
43 
44  private:
45   std::shared_ptr<TensorWrapper> static_tensor_wrapper_;
46 };
47 } // namespace qnn
48 } // namespace backends
49 } // namespace executorch
50