xref: /aosp_15_r20/external/executorch/backends/xnnpack/runtime/XNNCompiler.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/backends/xnnpack/runtime/XNNCompiler.h>
10*523fa7a6SAndroid Build Coastguard Worker #include <executorch/backends/xnnpack/runtime/XNNHeader.h>
11*523fa7a6SAndroid Build Coastguard Worker #include <executorch/backends/xnnpack/serialization/schema_generated.h>
12*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/threadpool/threadpool.h>
13*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14*523fa7a6SAndroid Build Coastguard Worker #include <unordered_map>
15*523fa7a6SAndroid Build Coastguard Worker 
16*523fa7a6SAndroid Build Coastguard Worker #pragma clang diagnostic ignored "-Wmissing-prototypes"
17*523fa7a6SAndroid Build Coastguard Worker #pragma clang diagnostic ignored "-Wglobal-constructors"
18*523fa7a6SAndroid Build Coastguard Worker 
19*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
20*523fa7a6SAndroid Build Coastguard Worker namespace backends {
21*523fa7a6SAndroid Build Coastguard Worker namespace xnnpack {
22*523fa7a6SAndroid Build Coastguard Worker namespace delegate {
23*523fa7a6SAndroid Build Coastguard Worker 
24*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Error;
25*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::MemoryAllocator;
26*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Result;
27*523fa7a6SAndroid Build Coastguard Worker 
/*
 * Provide compile-time allocation.
 *
 * Buffers handed out by this allocator are owned by it and are released
 * together when the allocator object is destroyed.
 */
class CompileAllocator {
 public:
  /*
   * Allocate `size` bytes of uninitialized memory which will be automatically
   * freed at the end of the compilation process (i.e. when this allocator is
   * destroyed). The returned pointer is non-owning.
   */
  void* allocateTemporary(size_t size) {
    // Give the new buffer an owner *before* the vector may reallocate: if
    // push_back throws, the temporary unique_ptr frees the buffer instead of
    // leaking it (the previous raw-pointer emplace_back leaked in that case).
    temporaries_.push_back(std::unique_ptr<uint8_t[]>(new uint8_t[size]));
    return temporaries_.back().get();
  }

 private:
  // Owns every buffer returned by allocateTemporary().
  std::vector<std::unique_ptr<uint8_t[]>> temporaries_;
};
46*523fa7a6SAndroid Build Coastguard Worker 
47*523fa7a6SAndroid Build Coastguard Worker // Flatbuffer types
48*523fa7a6SAndroid Build Coastguard Worker using ValuePtr = const fb_xnnpack::XValue*;
49*523fa7a6SAndroid Build Coastguard Worker using NodePtr = const fb_xnnpack::XNode*;
50*523fa7a6SAndroid Build Coastguard Worker using GraphPtr = const fb_xnnpack::XNNGraph*;
51*523fa7a6SAndroid Build Coastguard Worker using DataType = fb_xnnpack::XNNDatatype;
52*523fa7a6SAndroid Build Coastguard Worker 
53*523fa7a6SAndroid Build Coastguard Worker // Type for define node function. This is the function signature
54*523fa7a6SAndroid Build Coastguard Worker // for any function that takes in a flatbuffer node and defines it
55*523fa7a6SAndroid Build Coastguard Worker // into our xnn_subgraph
56*523fa7a6SAndroid Build Coastguard Worker using DefineNodeFunc = Error (*)(
57*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t,
58*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>&,
59*523fa7a6SAndroid Build Coastguard Worker     NodePtr,
60*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph*) noexcept;
61*523fa7a6SAndroid Build Coastguard Worker 
/*
Convert a tensor from fp32 to bf16 using round-to-nearest-even.

bf16 keeps the upper 16 bits of an IEEE-754 binary32 value. The lower 16
bits are rounded away by adding a bias of 0x7FFF (plus 1 when the kept
least-significant bit is odd, which breaks ties toward even) before
truncating. NaNs are handled separately so that rounding can never turn a
NaN into an infinity or carry into the sign bit.

@param f32_data       Input buffer of `numel` floats.
@param bf16_data_out  Output buffer of `numel` bf16 bit patterns.
@param numel          Number of elements to convert (0 is a no-op).
*/
void convertF32TensorToBF16(
    const float* f32_data,
    uint16_t* bf16_data_out,
    size_t numel) {
  // size_t index: an `unsigned int` index would wrap (and never terminate)
  // for numel > UINT_MAX.
  for (size_t i = 0; i < numel; i++) {
    uint32_t f32_bits;
    memcpy(&f32_bits, &f32_data[i], sizeof(float));

    if ((f32_bits & 0x7FFFFFFFu) > 0x7F800000u) {
      // NaN: truncate and force the quiet bit so the result stays a NaN even
      // if the payload lived only in the discarded low bits.
      bf16_data_out[i] = static_cast<uint16_t>((f32_bits >> 16) | 0x0040u);
      continue;
    }

    const uint32_t kept_lsb = (f32_bits >> 16) & 1u;
    const uint32_t rounding_bias = 0x7FFFu + kept_lsb;
    // Cannot overflow uint32_t: the largest non-NaN input (+/-inf) is
    // 0x{7F,FF}800000, and adding at most 0x8000 stays in range.
    bf16_data_out[i] =
        static_cast<uint16_t>((f32_bits + rounding_bias) >> 16);
  }
}
78*523fa7a6SAndroid Build Coastguard Worker 
79*523fa7a6SAndroid Build Coastguard Worker /*
80*523fa7a6SAndroid Build Coastguard Worker Gets the output min and output max for a given node operator
81*523fa7a6SAndroid Build Coastguard Worker */
getOutputMinMax(const NodePtr node)82*523fa7a6SAndroid Build Coastguard Worker std::pair<float, float> getOutputMinMax(const NodePtr node) noexcept {
83*523fa7a6SAndroid Build Coastguard Worker   float output_min = -std::numeric_limits<float>::infinity();
84*523fa7a6SAndroid Build Coastguard Worker   float output_max = std::numeric_limits<float>::infinity();
85*523fa7a6SAndroid Build Coastguard Worker   auto output_min_max = node->output_min_max();
86*523fa7a6SAndroid Build Coastguard Worker   if (output_min_max != nullptr) {
87*523fa7a6SAndroid Build Coastguard Worker     output_min = output_min_max->output_min();
88*523fa7a6SAndroid Build Coastguard Worker     output_max = output_min_max->output_max();
89*523fa7a6SAndroid Build Coastguard Worker   }
90*523fa7a6SAndroid Build Coastguard Worker 
91*523fa7a6SAndroid Build Coastguard Worker   return {output_min, output_max};
92*523fa7a6SAndroid Build Coastguard Worker }
93*523fa7a6SAndroid Build Coastguard Worker 
94*523fa7a6SAndroid Build Coastguard Worker /*
95*523fa7a6SAndroid Build Coastguard Worker Converts flatbuffer xnn data type to xnnpack data type
96*523fa7a6SAndroid Build Coastguard Worker */
getDataType(const DataType & data_type)97*523fa7a6SAndroid Build Coastguard Worker xnn_datatype getDataType(const DataType& data_type) {
98*523fa7a6SAndroid Build Coastguard Worker   switch (data_type) {
99*523fa7a6SAndroid Build Coastguard Worker     case DataType::xnn_datatype_fp32:
100*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_fp32;
101*523fa7a6SAndroid Build Coastguard Worker     case DataType::xnn_datatype_fp16:
102*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_fp16;
103*523fa7a6SAndroid Build Coastguard Worker     case DataType::xnn_datatype_qint8:
104*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_qint8;
105*523fa7a6SAndroid Build Coastguard Worker     case DataType::xnn_datatype_quint8:
106*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_quint8;
107*523fa7a6SAndroid Build Coastguard Worker     case DataType::xnn_datatype_qint32:
108*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_qint32;
109*523fa7a6SAndroid Build Coastguard Worker     case DataType::xnn_datatype_qcint8:
110*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_qcint8;
111*523fa7a6SAndroid Build Coastguard Worker     case DataType::xnn_datatype_qcint32:
112*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_qcint32;
113*523fa7a6SAndroid Build Coastguard Worker     case DataType::xnn_datatype_qcint4:
114*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_qcint4;
115*523fa7a6SAndroid Build Coastguard Worker     case DataType::xnn_datatype_qdint8:
116*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_qdint8;
117*523fa7a6SAndroid Build Coastguard Worker     case DataType::xnn_datatype_qbint4:
118*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_qbint4;
119*523fa7a6SAndroid Build Coastguard Worker     default:
120*523fa7a6SAndroid Build Coastguard Worker       return xnn_datatype::xnn_datatype_invalid;
121*523fa7a6SAndroid Build Coastguard Worker   }
122*523fa7a6SAndroid Build Coastguard Worker }
123*523fa7a6SAndroid Build Coastguard Worker 
isQuantizedDataType(const xnn_datatype data_type)124*523fa7a6SAndroid Build Coastguard Worker bool isQuantizedDataType(const xnn_datatype data_type) {
125*523fa7a6SAndroid Build Coastguard Worker   switch (data_type) {
126*523fa7a6SAndroid Build Coastguard Worker     case xnn_datatype::xnn_datatype_qint8:
127*523fa7a6SAndroid Build Coastguard Worker     case xnn_datatype::xnn_datatype_quint8:
128*523fa7a6SAndroid Build Coastguard Worker     case xnn_datatype::xnn_datatype_qint32:
129*523fa7a6SAndroid Build Coastguard Worker     case xnn_datatype::xnn_datatype_qcint8:
130*523fa7a6SAndroid Build Coastguard Worker     case xnn_datatype::xnn_datatype_qcint32:
131*523fa7a6SAndroid Build Coastguard Worker     case xnn_datatype::xnn_datatype_qcint4:
132*523fa7a6SAndroid Build Coastguard Worker     case xnn_datatype::xnn_datatype_qdint8:
133*523fa7a6SAndroid Build Coastguard Worker       return true;
134*523fa7a6SAndroid Build Coastguard Worker     default:
135*523fa7a6SAndroid Build Coastguard Worker       return false;
136*523fa7a6SAndroid Build Coastguard Worker   }
137*523fa7a6SAndroid Build Coastguard Worker }
138*523fa7a6SAndroid Build Coastguard Worker 
139*523fa7a6SAndroid Build Coastguard Worker /**
140*523fa7a6SAndroid Build Coastguard Worker Converts dims from uint32 to size_t. Takes in a flatbuffer vector
141*523fa7a6SAndroid Build Coastguard Worker of uint32_t and returns a std::vector of size_t. XNNPACK takes in
142*523fa7a6SAndroid Build Coastguard Worker dims of size_t* but tensor shape is serialized in flatbuffer as
143*523fa7a6SAndroid Build Coastguard Worker int32_t. As a result, we need to static cast the shapes to size_t
144*523fa7a6SAndroid Build Coastguard Worker */
145*523fa7a6SAndroid Build Coastguard Worker template <typename T = size_t>
flatbufferDimsToVector(const flatbuffers::Vector<uint32_t> * fb_dims)146*523fa7a6SAndroid Build Coastguard Worker std::vector<T> flatbufferDimsToVector(
147*523fa7a6SAndroid Build Coastguard Worker     const flatbuffers::Vector<uint32_t>* fb_dims) {
148*523fa7a6SAndroid Build Coastguard Worker   std::vector<T> dims_data;
149*523fa7a6SAndroid Build Coastguard Worker   dims_data.reserve(fb_dims->size());
150*523fa7a6SAndroid Build Coastguard Worker   for (auto fb_dim : *fb_dims) {
151*523fa7a6SAndroid Build Coastguard Worker     dims_data.push_back(static_cast<T>(fb_dim));
152*523fa7a6SAndroid Build Coastguard Worker   }
153*523fa7a6SAndroid Build Coastguard Worker   return dims_data;
154*523fa7a6SAndroid Build Coastguard Worker }
155*523fa7a6SAndroid Build Coastguard Worker 
156*523fa7a6SAndroid Build Coastguard Worker /**
157*523fa7a6SAndroid Build Coastguard Worker Gets the constant data pointer associated with the given tensor value.
158*523fa7a6SAndroid Build Coastguard Worker Obtaining the constant data pointer can either be from within the flatbuffer
159*523fa7a6SAndroid Build Coastguard Worker payload (deprecated) or via offsets to the constant_data_ptr. If no constant
160*523fa7a6SAndroid Build Coastguard Worker data associated with the tensor value, then returns nullptr.
161*523fa7a6SAndroid Build Coastguard Worker */
getConstantDataPtr(const fb_xnnpack::XNNTensorValue * tensor_value,GraphPtr flatbuffer_graph,const uint8_t * constant_data_ptr)162*523fa7a6SAndroid Build Coastguard Worker const uint8_t* getConstantDataPtr(
163*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNTensorValue* tensor_value,
164*523fa7a6SAndroid Build Coastguard Worker     GraphPtr flatbuffer_graph,
165*523fa7a6SAndroid Build Coastguard Worker     const uint8_t* constant_data_ptr) {
166*523fa7a6SAndroid Build Coastguard Worker   auto buffer_idx = tensor_value->constant_buffer_idx();
167*523fa7a6SAndroid Build Coastguard Worker   if (buffer_idx) {
168*523fa7a6SAndroid Build Coastguard Worker     if (!constant_data_ptr) {
169*523fa7a6SAndroid Build Coastguard Worker       // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
170*523fa7a6SAndroid Build Coastguard Worker       // window
171*523fa7a6SAndroid Build Coastguard Worker       const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
172*523fa7a6SAndroid Build Coastguard Worker       return constant_buffer[buffer_idx]->storage()->data();
173*523fa7a6SAndroid Build Coastguard Worker     } else {
174*523fa7a6SAndroid Build Coastguard Worker       const auto& constant_data_offsets = *flatbuffer_graph->constant_data();
175*523fa7a6SAndroid Build Coastguard Worker       uint64_t constant_data_offset =
176*523fa7a6SAndroid Build Coastguard Worker           constant_data_offsets[buffer_idx]->offset();
177*523fa7a6SAndroid Build Coastguard Worker       return constant_data_ptr + constant_data_offset;
178*523fa7a6SAndroid Build Coastguard Worker     }
179*523fa7a6SAndroid Build Coastguard Worker   }
180*523fa7a6SAndroid Build Coastguard Worker 
181*523fa7a6SAndroid Build Coastguard Worker   return nullptr;
182*523fa7a6SAndroid Build Coastguard Worker }
183*523fa7a6SAndroid Build Coastguard Worker 
184*523fa7a6SAndroid Build Coastguard Worker /**
185*523fa7a6SAndroid Build Coastguard Worker Define serialized tensor value into
186*523fa7a6SAndroid Build Coastguard Worker the subgraph. While also keeping track of the remapped ids from
187*523fa7a6SAndroid Build Coastguard Worker the serialized id to the newly generated id.
188*523fa7a6SAndroid Build Coastguard Worker */
defineTensor(xnn_subgraph_t subgraph_ptr,std::unordered_map<uint32_t,uint32_t> & remapped_ids,ValuePtr value,GraphPtr flatbuffer_graph,const uint8_t * constant_data_ptr,std::vector<uint32_t> & input_ids,std::vector<uint32_t> & output_ids,CompileAllocator & allocator)189*523fa7a6SAndroid Build Coastguard Worker Error defineTensor(
190*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
191*523fa7a6SAndroid Build Coastguard Worker     std::unordered_map<uint32_t, uint32_t>& remapped_ids,
192*523fa7a6SAndroid Build Coastguard Worker     ValuePtr value,
193*523fa7a6SAndroid Build Coastguard Worker     GraphPtr flatbuffer_graph,
194*523fa7a6SAndroid Build Coastguard Worker     const uint8_t* constant_data_ptr,
195*523fa7a6SAndroid Build Coastguard Worker     std::vector<uint32_t>& input_ids,
196*523fa7a6SAndroid Build Coastguard Worker     std::vector<uint32_t>& output_ids,
197*523fa7a6SAndroid Build Coastguard Worker     CompileAllocator& allocator) {
198*523fa7a6SAndroid Build Coastguard Worker   const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
199*523fa7a6SAndroid Build Coastguard Worker   const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;
200*523fa7a6SAndroid Build Coastguard Worker 
201*523fa7a6SAndroid Build Coastguard Worker   switch (value->xvalue_union_type()) {
202*523fa7a6SAndroid Build Coastguard Worker     case fb_xnnpack::XValueUnion::XNNTensorValue: {
203*523fa7a6SAndroid Build Coastguard Worker       tensor_value = value->xvalue_union_as_XNNTensorValue();
204*523fa7a6SAndroid Build Coastguard Worker       break;
205*523fa7a6SAndroid Build Coastguard Worker     }
206*523fa7a6SAndroid Build Coastguard Worker     case fb_xnnpack::XValueUnion::XNNQuantizedTensorValue: {
207*523fa7a6SAndroid Build Coastguard Worker       qtensor_value = value->xvalue_union_as_XNNQuantizedTensorValue();
208*523fa7a6SAndroid Build Coastguard Worker       tensor_value = qtensor_value->tensor_value();
209*523fa7a6SAndroid Build Coastguard Worker       break;
210*523fa7a6SAndroid Build Coastguard Worker     }
211*523fa7a6SAndroid Build Coastguard Worker     default: {
212*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_OR_RETURN_ERROR(
213*523fa7a6SAndroid Build Coastguard Worker           false,
214*523fa7a6SAndroid Build Coastguard Worker           NotImplemented,
215*523fa7a6SAndroid Build Coastguard Worker           "Unhandled value type: %s",
216*523fa7a6SAndroid Build Coastguard Worker           fb_xnnpack::EnumNameXValueUnion(value->xvalue_union_type()));
217*523fa7a6SAndroid Build Coastguard Worker     }
218*523fa7a6SAndroid Build Coastguard Worker   }
219*523fa7a6SAndroid Build Coastguard Worker 
220*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
221*523fa7a6SAndroid Build Coastguard Worker       tensor_value != nullptr,
222*523fa7a6SAndroid Build Coastguard Worker       Internal,
223*523fa7a6SAndroid Build Coastguard Worker       "Deserialized Tensor is Null, this should never happen");
224*523fa7a6SAndroid Build Coastguard Worker 
225*523fa7a6SAndroid Build Coastguard Worker   // Get tensor dims, here we need to use a vector in order
226*523fa7a6SAndroid Build Coastguard Worker   // to properly convert the uint32_t* to size_t*
227*523fa7a6SAndroid Build Coastguard Worker   std::vector<size_t> dims_data = flatbufferDimsToVector(tensor_value->dims());
228*523fa7a6SAndroid Build Coastguard Worker 
229*523fa7a6SAndroid Build Coastguard Worker   // XNNPACK Id
230*523fa7a6SAndroid Build Coastguard Worker   uint32_t id = XNN_INVALID_VALUE_ID;
231*523fa7a6SAndroid Build Coastguard Worker 
232*523fa7a6SAndroid Build Coastguard Worker   // Get Pointer to constant data from flatbuffer, if its non-constant
233*523fa7a6SAndroid Build Coastguard Worker   // it is a nullptr
234*523fa7a6SAndroid Build Coastguard Worker   const uint8_t* buffer_ptr =
235*523fa7a6SAndroid Build Coastguard Worker       getConstantDataPtr(tensor_value, flatbuffer_graph, constant_data_ptr);
236*523fa7a6SAndroid Build Coastguard Worker 
237*523fa7a6SAndroid Build Coastguard Worker   xnn_status status;
238*523fa7a6SAndroid Build Coastguard Worker   // The type we might have to convert to
239*523fa7a6SAndroid Build Coastguard Worker   auto dq_datatype = getDataType(tensor_value->dq_datatype());
240*523fa7a6SAndroid Build Coastguard Worker 
241*523fa7a6SAndroid Build Coastguard Worker   if (dq_datatype != xnn_datatype::xnn_datatype_invalid) {
242*523fa7a6SAndroid Build Coastguard Worker     if (dq_datatype != xnn_datatype::xnn_datatype_qint8) {
243*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_OR_RETURN_ERROR(
244*523fa7a6SAndroid Build Coastguard Worker           false,
245*523fa7a6SAndroid Build Coastguard Worker           Internal,
246*523fa7a6SAndroid Build Coastguard Worker           "Only int8_t is supported for dq_datatype for now, got: %d",
247*523fa7a6SAndroid Build Coastguard Worker           dq_datatype);
248*523fa7a6SAndroid Build Coastguard Worker     } else {
249*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_OR_RETURN_ERROR(
250*523fa7a6SAndroid Build Coastguard Worker           (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT),
251*523fa7a6SAndroid Build Coastguard Worker           Internal,
252*523fa7a6SAndroid Build Coastguard Worker           "Dynamic quantization of tensor is only allowed for the external input tensor value for now! got flags: %u",
253*523fa7a6SAndroid Build Coastguard Worker           tensor_value->flags());
254*523fa7a6SAndroid Build Coastguard Worker     }
255*523fa7a6SAndroid Build Coastguard Worker   }
256*523fa7a6SAndroid Build Coastguard Worker 
257*523fa7a6SAndroid Build Coastguard Worker   if (qtensor_value == nullptr) {
258*523fa7a6SAndroid Build Coastguard Worker     // FP32 tensor
259*523fa7a6SAndroid Build Coastguard Worker     if (!isQuantizedDataType(dq_datatype)) {
260*523fa7a6SAndroid Build Coastguard Worker       // Define non-quantied tensor
261*523fa7a6SAndroid Build Coastguard Worker       status = xnn_define_tensor_value(
262*523fa7a6SAndroid Build Coastguard Worker           /*subgraph=*/subgraph_ptr,
263*523fa7a6SAndroid Build Coastguard Worker           /*datatype=*/getDataType(tensor_value->datatype()),
264*523fa7a6SAndroid Build Coastguard Worker           /*num_dims=*/tensor_value->num_dims(),
265*523fa7a6SAndroid Build Coastguard Worker           /*dims=*/dims_data.data(),
266*523fa7a6SAndroid Build Coastguard Worker           /*data=*/buffer_ptr,
267*523fa7a6SAndroid Build Coastguard Worker           /*external_id=*/tensor_value->external_id(),
268*523fa7a6SAndroid Build Coastguard Worker           /*flags=*/tensor_value->flags(),
269*523fa7a6SAndroid Build Coastguard Worker           /*id_out=*/&id);
270*523fa7a6SAndroid Build Coastguard Worker     } else if (dq_datatype != xnn_datatype::xnn_datatype_invalid) {
271*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_OR_RETURN_ERROR(
272*523fa7a6SAndroid Build Coastguard Worker           isQuantizedDataType(dq_datatype),
273*523fa7a6SAndroid Build Coastguard Worker           Internal,
274*523fa7a6SAndroid Build Coastguard Worker           "Dynamic quantization can only produce supported quantized dtypes");
275*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_OR_RETURN_ERROR(
276*523fa7a6SAndroid Build Coastguard Worker           tensor_value->external_id() != XNN_INVALID_VALUE_ID,
277*523fa7a6SAndroid Build Coastguard Worker           Internal,
278*523fa7a6SAndroid Build Coastguard Worker           "Dynamic quantization can only work with external inputs for now, got an internal ID");
279*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_OR_RETURN_ERROR(
280*523fa7a6SAndroid Build Coastguard Worker           buffer_ptr == nullptr,
281*523fa7a6SAndroid Build Coastguard Worker           Internal,
282*523fa7a6SAndroid Build Coastguard Worker           "Dynamic quantization can only work with external inputs for now, got const data");
283*523fa7a6SAndroid Build Coastguard Worker 
284*523fa7a6SAndroid Build Coastguard Worker       switch (dq_datatype) {
285*523fa7a6SAndroid Build Coastguard Worker         case xnn_datatype::xnn_datatype_qint8: {
286*523fa7a6SAndroid Build Coastguard Worker           // HACK TO Maintain FC/BC for ASR this will be removed after 01/2024
287*523fa7a6SAndroid Build Coastguard Worker 
288*523fa7a6SAndroid Build Coastguard Worker           // When encountering a dynamically quantized tensor via dq_datatype,
289*523fa7a6SAndroid Build Coastguard Worker           // which is the old flow for serializing dynamically quantized linear.
290*523fa7a6SAndroid Build Coastguard Worker           // We replace the definition of a single tensor with a new dynamic
291*523fa7a6SAndroid Build Coastguard Worker           // Quantization pattern. We change the pattern from:
292*523fa7a6SAndroid Build Coastguard Worker           //     serialized_qd_input
293*523fa7a6SAndroid Build Coastguard Worker           //           to
294*523fa7a6SAndroid Build Coastguard Worker           // (fp32_input --> convert --> qdint8_input)
295*523fa7a6SAndroid Build Coastguard Worker 
296*523fa7a6SAndroid Build Coastguard Worker           status = xnn_define_dynamically_quantized_tensor_value(
297*523fa7a6SAndroid Build Coastguard Worker               /*subgraph=*/subgraph_ptr,
298*523fa7a6SAndroid Build Coastguard Worker               /*datatype=*/xnn_datatype_qdint8,
299*523fa7a6SAndroid Build Coastguard Worker               /*num_dims=*/tensor_value->num_dims(),
300*523fa7a6SAndroid Build Coastguard Worker               /*num_nonbatch_dims=*/1, // always do per token quantization
301*523fa7a6SAndroid Build Coastguard Worker               /*dims=*/dims_data.data(),
302*523fa7a6SAndroid Build Coastguard Worker               /*external_id=*/XNN_INVALID_VALUE_ID, // always internal value id
303*523fa7a6SAndroid Build Coastguard Worker               /*flags=*/0, // this is netiher external input or output
304*523fa7a6SAndroid Build Coastguard Worker               /*id_out=*/&id);
305*523fa7a6SAndroid Build Coastguard Worker 
306*523fa7a6SAndroid Build Coastguard Worker           // this is the FP16 or FP32 external value that is being dynamically
307*523fa7a6SAndroid Build Coastguard Worker           // quantized
308*523fa7a6SAndroid Build Coastguard Worker           uint32_t float_id;
309*523fa7a6SAndroid Build Coastguard Worker           enum xnn_datatype fp_datatype = getDataType(tensor_value->datatype());
310*523fa7a6SAndroid Build Coastguard Worker           status = xnn_define_tensor_value(
311*523fa7a6SAndroid Build Coastguard Worker               /*subgraph=*/subgraph_ptr,
312*523fa7a6SAndroid Build Coastguard Worker               /*datatype=*/fp_datatype,
313*523fa7a6SAndroid Build Coastguard Worker               /*num_dims=*/tensor_value->num_dims(),
314*523fa7a6SAndroid Build Coastguard Worker               /*dims=*/dims_data.data(),
315*523fa7a6SAndroid Build Coastguard Worker               /*data=*/buffer_ptr,
316*523fa7a6SAndroid Build Coastguard Worker               /*external_id=*/tensor_value->external_id(),
317*523fa7a6SAndroid Build Coastguard Worker               /*flags=*/tensor_value->flags(),
318*523fa7a6SAndroid Build Coastguard Worker               /*id_out=*/&float_id);
319*523fa7a6SAndroid Build Coastguard Worker 
320*523fa7a6SAndroid Build Coastguard Worker           // Define dynamic conversion from float to qdint8
321*523fa7a6SAndroid Build Coastguard Worker           status = xnn_define_convert(
322*523fa7a6SAndroid Build Coastguard Worker               /*subgraph=*/subgraph_ptr,
323*523fa7a6SAndroid Build Coastguard Worker               /*input_id=*/float_id,
324*523fa7a6SAndroid Build Coastguard Worker               /*output_id=*/id,
325*523fa7a6SAndroid Build Coastguard Worker               /*flags=*/0);
326*523fa7a6SAndroid Build Coastguard Worker           break;
327*523fa7a6SAndroid Build Coastguard Worker         }
328*523fa7a6SAndroid Build Coastguard Worker         default:
329*523fa7a6SAndroid Build Coastguard Worker           ET_CHECK_OR_RETURN_ERROR(
330*523fa7a6SAndroid Build Coastguard Worker               false,
331*523fa7a6SAndroid Build Coastguard Worker               NotImplemented,
332*523fa7a6SAndroid Build Coastguard Worker               "Unhandled Dyanmic Quantization dtype: %d",
333*523fa7a6SAndroid Build Coastguard Worker               dq_datatype);
334*523fa7a6SAndroid Build Coastguard Worker       }
335*523fa7a6SAndroid Build Coastguard Worker     } else {
336*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_OR_RETURN_ERROR(false, NotImplemented, "Unhandled fp32 tensor");
337*523fa7a6SAndroid Build Coastguard Worker     }
338*523fa7a6SAndroid Build Coastguard Worker   } else {
339*523fa7a6SAndroid Build Coastguard Worker     // define tensor for quantized
340*523fa7a6SAndroid Build Coastguard Worker     switch (qtensor_value->quant_params_type()) {
341*523fa7a6SAndroid Build Coastguard Worker       case fb_xnnpack::XNNQuantParams::PerTensorQuant: {
342*523fa7a6SAndroid Build Coastguard Worker         auto qparams = qtensor_value->quant_params_as_PerTensorQuant();
343*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(
344*523fa7a6SAndroid Build Coastguard Worker             Debug,
345*523fa7a6SAndroid Build Coastguard Worker             "define quant tensor (per tensor): buffer_ptr: %p, scale: %f, zp: %u\n",
346*523fa7a6SAndroid Build Coastguard Worker             buffer_ptr,
347*523fa7a6SAndroid Build Coastguard Worker             qparams->scale(),
348*523fa7a6SAndroid Build Coastguard Worker             qparams->zero_point());
349*523fa7a6SAndroid Build Coastguard Worker         status = xnn_define_quantized_tensor_value(
350*523fa7a6SAndroid Build Coastguard Worker             /*subgraph=*/subgraph_ptr,
351*523fa7a6SAndroid Build Coastguard Worker             /*datatype=*/getDataType(tensor_value->datatype()),
352*523fa7a6SAndroid Build Coastguard Worker             /*zero_point=*/qparams->zero_point(),
353*523fa7a6SAndroid Build Coastguard Worker             /*scale=*/qparams->scale(),
354*523fa7a6SAndroid Build Coastguard Worker             /*num_dims=*/tensor_value->num_dims(),
355*523fa7a6SAndroid Build Coastguard Worker             /*dims=*/dims_data.data(),
356*523fa7a6SAndroid Build Coastguard Worker             /*data=*/buffer_ptr,
357*523fa7a6SAndroid Build Coastguard Worker             /*external_id=*/tensor_value->external_id(),
358*523fa7a6SAndroid Build Coastguard Worker             /*flags=*/tensor_value->flags(),
359*523fa7a6SAndroid Build Coastguard Worker             /*id_out=*/&id);
360*523fa7a6SAndroid Build Coastguard Worker         break;
361*523fa7a6SAndroid Build Coastguard Worker       }
362*523fa7a6SAndroid Build Coastguard Worker       case fb_xnnpack::XNNQuantParams::PerChannelQuant: {
363*523fa7a6SAndroid Build Coastguard Worker         auto qparams = qtensor_value->quant_params_as_PerChannelQuant();
364*523fa7a6SAndroid Build Coastguard Worker         enum xnn_datatype dtype = getDataType(tensor_value->datatype());
365*523fa7a6SAndroid Build Coastguard Worker         int32_t zero_point =
366*523fa7a6SAndroid Build Coastguard Worker             (dtype == xnn_datatype::xnn_datatype_qcint4 ? 8 : 0);
367*523fa7a6SAndroid Build Coastguard Worker 
368*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(
369*523fa7a6SAndroid Build Coastguard Worker             Debug,
370*523fa7a6SAndroid Build Coastguard Worker             "define quant tensor (per channel): buffer_ptr: %p, scale.numel(): %u, channel_dim: %u, dtype: %u, zero_point: %d\n",
371*523fa7a6SAndroid Build Coastguard Worker             buffer_ptr,
372*523fa7a6SAndroid Build Coastguard Worker             qparams->scale()->size(),
373*523fa7a6SAndroid Build Coastguard Worker             qparams->channel_dim(),
374*523fa7a6SAndroid Build Coastguard Worker             dtype,
375*523fa7a6SAndroid Build Coastguard Worker             zero_point);
376*523fa7a6SAndroid Build Coastguard Worker         status = xnn_define_channelwise_quantized_tensor_value_v2(
377*523fa7a6SAndroid Build Coastguard Worker             /*subgraph=*/subgraph_ptr,
378*523fa7a6SAndroid Build Coastguard Worker             /*datatype=*/dtype,
379*523fa7a6SAndroid Build Coastguard Worker             /*zero_point=*/zero_point,
380*523fa7a6SAndroid Build Coastguard Worker             /*scale=*/qparams->scale()->data(),
381*523fa7a6SAndroid Build Coastguard Worker             /*num_dims=*/tensor_value->num_dims(),
382*523fa7a6SAndroid Build Coastguard Worker             /*channel_dim*/ qparams->channel_dim(),
383*523fa7a6SAndroid Build Coastguard Worker             /*dims=*/dims_data.data(),
384*523fa7a6SAndroid Build Coastguard Worker             /*data=*/buffer_ptr,
385*523fa7a6SAndroid Build Coastguard Worker             /*external_id=*/tensor_value->external_id(),
386*523fa7a6SAndroid Build Coastguard Worker             /*flags=*/tensor_value->flags(),
387*523fa7a6SAndroid Build Coastguard Worker             /*id_out=*/&id);
388*523fa7a6SAndroid Build Coastguard Worker         break;
389*523fa7a6SAndroid Build Coastguard Worker       }
390*523fa7a6SAndroid Build Coastguard Worker       case fb_xnnpack::XNNQuantParams::PerChannelGroupQuant: {
391*523fa7a6SAndroid Build Coastguard Worker         xnn_datatype datatype = getDataType(tensor_value->datatype());
392*523fa7a6SAndroid Build Coastguard Worker         ET_CHECK_OR_RETURN_ERROR(
393*523fa7a6SAndroid Build Coastguard Worker             datatype == xnn_datatype::xnn_datatype_qbint4,
394*523fa7a6SAndroid Build Coastguard Worker             Internal,
395*523fa7a6SAndroid Build Coastguard Worker             "Unsupported datatype for per channel group quantization: %d",
396*523fa7a6SAndroid Build Coastguard Worker             datatype);
397*523fa7a6SAndroid Build Coastguard Worker         auto qparams = qtensor_value->quant_params_as_PerChannelGroupQuant();
398*523fa7a6SAndroid Build Coastguard Worker         size_t group_size = qparams->group_size();
399*523fa7a6SAndroid Build Coastguard Worker         size_t output_channels = tensor_value->dims()->Get(0);
400*523fa7a6SAndroid Build Coastguard Worker         size_t input_channels = tensor_value->dims()->Get(1);
401*523fa7a6SAndroid Build Coastguard Worker 
402*523fa7a6SAndroid Build Coastguard Worker         const uint16_t* scale_data = nullptr;
403*523fa7a6SAndroid Build Coastguard Worker         uint32_t scale_numel = 0;
404*523fa7a6SAndroid Build Coastguard Worker 
405*523fa7a6SAndroid Build Coastguard Worker         // Block scales are preferably serialized as bf16 but can also be
406*523fa7a6SAndroid Build Coastguard Worker         // serialized as fp32 for backwards compatability.
407*523fa7a6SAndroid Build Coastguard Worker         if (qparams->scale_bf16() != nullptr) {
408*523fa7a6SAndroid Build Coastguard Worker           scale_data =
409*523fa7a6SAndroid Build Coastguard Worker               static_cast<const uint16_t*>(qparams->scale_bf16()->data());
410*523fa7a6SAndroid Build Coastguard Worker           scale_numel = qparams->scale_bf16()->size();
411*523fa7a6SAndroid Build Coastguard Worker         } else {
412*523fa7a6SAndroid Build Coastguard Worker           // Read fp32 scales, convert to bf16.
413*523fa7a6SAndroid Build Coastguard Worker           auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary(
414*523fa7a6SAndroid Build Coastguard Worker               qparams->scale()->size() * sizeof(uint16_t)));
415*523fa7a6SAndroid Build Coastguard Worker           scale_numel = qparams->scale()->size();
416*523fa7a6SAndroid Build Coastguard Worker           convertF32TensorToBF16(
417*523fa7a6SAndroid Build Coastguard Worker               qparams->scale()->data(), conv_buffer, scale_numel);
418*523fa7a6SAndroid Build Coastguard Worker           scale_data = conv_buffer;
419*523fa7a6SAndroid Build Coastguard Worker         }
420*523fa7a6SAndroid Build Coastguard Worker 
421*523fa7a6SAndroid Build Coastguard Worker         ET_CHECK_OR_RETURN_ERROR(
422*523fa7a6SAndroid Build Coastguard Worker             scale_numel == output_channels * input_channels / group_size,
423*523fa7a6SAndroid Build Coastguard Worker             Internal,
424*523fa7a6SAndroid Build Coastguard Worker             "scale size %zu != output channels %zu * group size %zu",
425*523fa7a6SAndroid Build Coastguard Worker             static_cast<size_t>(scale_numel),
426*523fa7a6SAndroid Build Coastguard Worker             output_channels,
427*523fa7a6SAndroid Build Coastguard Worker             group_size);
428*523fa7a6SAndroid Build Coastguard Worker         int32_t zero_point =
429*523fa7a6SAndroid Build Coastguard Worker             (datatype == xnn_datatype::xnn_datatype_qbint4 ? 8 : 0);
430*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(
431*523fa7a6SAndroid Build Coastguard Worker             Debug,
432*523fa7a6SAndroid Build Coastguard Worker             "define quant tensor (per channel group): buffer_ptr: %p, scale.numel(): %u, channel_dim: %u, grpup_size: %zu, output_channels: %zu, dtype: %u, zero_point: %d, datatype: %d\n",
433*523fa7a6SAndroid Build Coastguard Worker             buffer_ptr,
434*523fa7a6SAndroid Build Coastguard Worker             scale_numel,
435*523fa7a6SAndroid Build Coastguard Worker             qparams->channel_dim(),
436*523fa7a6SAndroid Build Coastguard Worker             group_size,
437*523fa7a6SAndroid Build Coastguard Worker             output_channels,
438*523fa7a6SAndroid Build Coastguard Worker             datatype,
439*523fa7a6SAndroid Build Coastguard Worker             zero_point,
440*523fa7a6SAndroid Build Coastguard Worker             datatype);
441*523fa7a6SAndroid Build Coastguard Worker 
442*523fa7a6SAndroid Build Coastguard Worker         status = xnn_define_blockwise_quantized_tensor_value(
443*523fa7a6SAndroid Build Coastguard Worker             /*subgraph=*/subgraph_ptr,
444*523fa7a6SAndroid Build Coastguard Worker             /*datatype=*/datatype,
445*523fa7a6SAndroid Build Coastguard Worker             /*zero_point=*/zero_point,
446*523fa7a6SAndroid Build Coastguard Worker             /*scale=*/scale_data,
447*523fa7a6SAndroid Build Coastguard Worker             /*num_dims=*/tensor_value->num_dims(),
448*523fa7a6SAndroid Build Coastguard Worker             /*channel_dim=*/qparams->channel_dim(),
449*523fa7a6SAndroid Build Coastguard Worker             /*block_size=*/qparams->group_size(),
450*523fa7a6SAndroid Build Coastguard Worker             /*dims=*/dims_data.data(),
451*523fa7a6SAndroid Build Coastguard Worker             /*data=*/buffer_ptr,
452*523fa7a6SAndroid Build Coastguard Worker             /*external_id=*/tensor_value->external_id(),
453*523fa7a6SAndroid Build Coastguard Worker             /*flags=*/tensor_value->flags(),
454*523fa7a6SAndroid Build Coastguard Worker             /*id_out=*/&id);
455*523fa7a6SAndroid Build Coastguard Worker         break;
456*523fa7a6SAndroid Build Coastguard Worker       }
457*523fa7a6SAndroid Build Coastguard Worker       case fb_xnnpack::XNNQuantParams::PerTokenDynamicQuant: {
458*523fa7a6SAndroid Build Coastguard Worker         auto qparams = qtensor_value->quant_params_as_PerTokenDynamicQuant();
459*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(
460*523fa7a6SAndroid Build Coastguard Worker             Debug,
461*523fa7a6SAndroid Build Coastguard Worker             "define quant tensor (dynamic): num_dims: %i, num_nonbatch_dims: %i\n",
462*523fa7a6SAndroid Build Coastguard Worker             tensor_value->num_dims(),
463*523fa7a6SAndroid Build Coastguard Worker             qparams->num_nonbatch_dims());
464*523fa7a6SAndroid Build Coastguard Worker         ET_CHECK_OR_RETURN_ERROR(
465*523fa7a6SAndroid Build Coastguard Worker             buffer_ptr == nullptr,
466*523fa7a6SAndroid Build Coastguard Worker             Internal,
467*523fa7a6SAndroid Build Coastguard Worker             "Dynamically quantized tensor should not have constant data but found non-nullptr");
468*523fa7a6SAndroid Build Coastguard Worker         // TODO(T179441835): Dynamic Quantization with num_nonbatch_dims > 1
469*523fa7a6SAndroid Build Coastguard Worker         ET_CHECK_OR_RETURN_ERROR(
470*523fa7a6SAndroid Build Coastguard Worker             qparams->num_nonbatch_dims() == 1,
471*523fa7a6SAndroid Build Coastguard Worker             Internal,
472*523fa7a6SAndroid Build Coastguard Worker             "Dynamically Quantized Tensors currently only support per token quantization");
473*523fa7a6SAndroid Build Coastguard Worker         status = xnn_define_dynamically_quantized_tensor_value(
474*523fa7a6SAndroid Build Coastguard Worker             /*subgraph=*/subgraph_ptr,
475*523fa7a6SAndroid Build Coastguard Worker             /*datatype=*/getDataType(tensor_value->datatype()),
476*523fa7a6SAndroid Build Coastguard Worker             /*num_dims=*/tensor_value->num_dims(),
477*523fa7a6SAndroid Build Coastguard Worker             /*num_nonbatch_dims*/ qparams->num_nonbatch_dims(),
478*523fa7a6SAndroid Build Coastguard Worker             /*dims=*/dims_data.data(),
479*523fa7a6SAndroid Build Coastguard Worker             /*external_id=*/tensor_value->external_id(),
480*523fa7a6SAndroid Build Coastguard Worker             /*flags=*/tensor_value->flags(),
481*523fa7a6SAndroid Build Coastguard Worker             /*id_out=*/&id);
482*523fa7a6SAndroid Build Coastguard Worker         break;
483*523fa7a6SAndroid Build Coastguard Worker       }
484*523fa7a6SAndroid Build Coastguard Worker       default: {
485*523fa7a6SAndroid Build Coastguard Worker         ET_CHECK_OR_RETURN_ERROR(
486*523fa7a6SAndroid Build Coastguard Worker             false,
487*523fa7a6SAndroid Build Coastguard Worker             NotImplemented,
488*523fa7a6SAndroid Build Coastguard Worker             "Unhandled Quantization Parameters: %s",
489*523fa7a6SAndroid Build Coastguard Worker             fb_xnnpack::EnumNameXNNQuantParams(
490*523fa7a6SAndroid Build Coastguard Worker                 qtensor_value->quant_params_type()));
491*523fa7a6SAndroid Build Coastguard Worker       }
492*523fa7a6SAndroid Build Coastguard Worker     }
493*523fa7a6SAndroid Build Coastguard Worker   }
494*523fa7a6SAndroid Build Coastguard Worker 
495*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
496*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
497*523fa7a6SAndroid Build Coastguard Worker       Internal,
498*523fa7a6SAndroid Build Coastguard Worker       "Failed to define tensor %i with code: %s",
499*523fa7a6SAndroid Build Coastguard Worker       tensor_value->id_out(),
500*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
501*523fa7a6SAndroid Build Coastguard Worker 
502*523fa7a6SAndroid Build Coastguard Worker   // map serialized id to newly generated id
503*523fa7a6SAndroid Build Coastguard Worker   remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id));
504*523fa7a6SAndroid Build Coastguard Worker 
505*523fa7a6SAndroid Build Coastguard Worker   // Add external ids to either list of input or output ids
506*523fa7a6SAndroid Build Coastguard Worker   if (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT) {
507*523fa7a6SAndroid Build Coastguard Worker     input_ids.push_back(tensor_value->external_id());
508*523fa7a6SAndroid Build Coastguard Worker   }
509*523fa7a6SAndroid Build Coastguard Worker   if (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_OUTPUT) {
510*523fa7a6SAndroid Build Coastguard Worker     output_ids.push_back(tensor_value->external_id());
511*523fa7a6SAndroid Build Coastguard Worker   }
512*523fa7a6SAndroid Build Coastguard Worker 
513*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
514*523fa7a6SAndroid Build Coastguard Worker };
515*523fa7a6SAndroid Build Coastguard Worker 
// Evaluates `x` and discards the result as void. The node-definition
// functions below use it for parameters (e.g. the flatbuffer graph) that are
// part of their shared signature but are not otherwise read.
#define MAYBE_UNUSED(x) static_cast<void>(x)
517*523fa7a6SAndroid Build Coastguard Worker 
518*523fa7a6SAndroid Build Coastguard Worker /*
519*523fa7a6SAndroid Build Coastguard Worker Define serialized add node into the subgraph, using the remapped ids
520*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining
521*523fa7a6SAndroid Build Coastguard Worker the tensor value
522*523fa7a6SAndroid Build Coastguard Worker */
defineAddNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)523*523fa7a6SAndroid Build Coastguard Worker Error defineAddNode(
524*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
525*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
526*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
527*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
528*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
529*523fa7a6SAndroid Build Coastguard Worker 
530*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
531*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNAdd();
532*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_add2(
533*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
534*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
535*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
536*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
537*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
538*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
539*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
540*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
541*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
542*523fa7a6SAndroid Build Coastguard Worker       Internal,
543*523fa7a6SAndroid Build Coastguard Worker       "Failed to create add node %i with code: %s",
544*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
545*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
546*523fa7a6SAndroid Build Coastguard Worker 
547*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
548*523fa7a6SAndroid Build Coastguard Worker };
549*523fa7a6SAndroid Build Coastguard Worker 
550*523fa7a6SAndroid Build Coastguard Worker /*
551*523fa7a6SAndroid Build Coastguard Worker Define Minimum operator Node into the subgraph
552*523fa7a6SAndroid Build Coastguard Worker */
defineMinimumNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)553*523fa7a6SAndroid Build Coastguard Worker Error defineMinimumNode(
554*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
555*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
556*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
557*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
558*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
559*523fa7a6SAndroid Build Coastguard Worker 
560*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNMinimum();
561*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_minimum2(
562*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
563*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
564*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
565*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
566*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
567*523fa7a6SAndroid Build Coastguard Worker 
568*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
569*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
570*523fa7a6SAndroid Build Coastguard Worker       Internal,
571*523fa7a6SAndroid Build Coastguard Worker       "Failed to create minumum node %i with code: %s",
572*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
573*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
574*523fa7a6SAndroid Build Coastguard Worker 
575*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
576*523fa7a6SAndroid Build Coastguard Worker };
577*523fa7a6SAndroid Build Coastguard Worker 
578*523fa7a6SAndroid Build Coastguard Worker /*
579*523fa7a6SAndroid Build Coastguard Worker Define subtract operator Node into the subgraph
580*523fa7a6SAndroid Build Coastguard Worker */
defineSubtractNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)581*523fa7a6SAndroid Build Coastguard Worker Error defineSubtractNode(
582*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
583*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
584*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
585*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
586*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
587*523fa7a6SAndroid Build Coastguard Worker 
588*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNSubtract();
589*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
590*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_subtract(
591*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
592*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
593*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
594*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
595*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
596*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
597*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
598*523fa7a6SAndroid Build Coastguard Worker 
599*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
600*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
601*523fa7a6SAndroid Build Coastguard Worker       Internal,
602*523fa7a6SAndroid Build Coastguard Worker       "Failed to create subtract node %i with code: %s",
603*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
604*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
605*523fa7a6SAndroid Build Coastguard Worker 
606*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
607*523fa7a6SAndroid Build Coastguard Worker };
608*523fa7a6SAndroid Build Coastguard Worker 
609*523fa7a6SAndroid Build Coastguard Worker /*
610*523fa7a6SAndroid Build Coastguard Worker Define Multiply operator Node into the subgraph
611*523fa7a6SAndroid Build Coastguard Worker */
defineMultiplyNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)612*523fa7a6SAndroid Build Coastguard Worker Error defineMultiplyNode(
613*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
614*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
615*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
616*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
617*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
618*523fa7a6SAndroid Build Coastguard Worker 
619*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNMultiply();
620*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
621*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_multiply2(
622*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
623*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
624*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
625*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
626*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
627*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
628*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
629*523fa7a6SAndroid Build Coastguard Worker 
630*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
631*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
632*523fa7a6SAndroid Build Coastguard Worker       Internal,
633*523fa7a6SAndroid Build Coastguard Worker       "Failed to create multiply node %i with code: %s",
634*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
635*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
636*523fa7a6SAndroid Build Coastguard Worker 
637*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
638*523fa7a6SAndroid Build Coastguard Worker };
639*523fa7a6SAndroid Build Coastguard Worker 
#ifdef ENABLE_XNNPACK_KLEIDI
/*
Decide whether an XNNConvert node can use the QP8 path.

Returns true only if both of the following hold:
  - the convert's output is a quantized tensor value of datatype qdint8, and
  - at least one XNNFullyConnected node takes that output as input1 and has a
    quantized filter of datatype qbint4.
If one such linear node exists, QP8 is assumed usable for every linear node
consuming this convert output. Returns false if the graph has no
xvalues/xnodes vectors.
*/
bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
  assert(node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNConvert);
  auto graph_node = node->xnode_union_as_XNNConvert();
  auto cvt_output_id = graph_node->output_id();

  // Generated flatbuffers accessors return nullptr for an absent vector
  // field; guard before dereferencing them below.
  const auto* xvalues = graph->xvalues();
  const auto* xnodes = graph->xnodes();
  if (xvalues == nullptr || xnodes == nullptr) {
    return false;
  }

  // True iff the quantized tensor value with serialized id `id` has datatype
  // `dtype`; false if no quantized tensor value carries that id.
  auto check_dtype = [xvalues](uint32_t id, DataType dtype) -> bool {
    assert(
        dtype == DataType::xnn_datatype_qdint8 ||
        dtype == DataType::xnn_datatype_qbint4);
    for (auto value : *xvalues) {
      if (value->xvalue_union_type() !=
          fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
        continue;
      }
      auto tensor =
          value->xvalue_union_as_XNNQuantizedTensorValue()->tensor_value();
      if (tensor->id_out() == id) {
        return tensor->datatype() == dtype;
      }
    }
    return false;
  };

  // Check if the output tensor is qdint8 else bail early.
  if (!check_dtype(cvt_output_id, DataType::xnn_datatype_qdint8)) {
    return false;
  }

  // Look for a linear node fed by this convert with a qbint4 filter.
  // (Named `consumer` so it does not shadow the `node` parameter.)
  for (auto consumer : *xnodes) {
    if (consumer->xnode_union_type() !=
        fb_xnnpack::XNodeUnion::XNNFullyConnected) {
      continue;
    }
    auto linear_node = consumer->xnode_union_as_XNNFullyConnected();
    if (linear_node->input1_id() == cvt_output_id &&
        check_dtype(linear_node->filter_id(), DataType::xnn_datatype_qbint4)) {
      return true;
    }
  }
  return false;
}
#endif // ENABLE_XNNPACK_KLEIDI
686*523fa7a6SAndroid Build Coastguard Worker 
687*523fa7a6SAndroid Build Coastguard Worker /*
688*523fa7a6SAndroid Build Coastguard Worker Define Convert operator Node into the subgraph
689*523fa7a6SAndroid Build Coastguard Worker */
defineConvertNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * flatbuffer_graph)690*523fa7a6SAndroid Build Coastguard Worker Error defineConvertNode(
691*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
692*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
693*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
694*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* flatbuffer_graph) noexcept {
695*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(flatbuffer_graph);
696*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNConvert();
697*523fa7a6SAndroid Build Coastguard Worker 
698*523fa7a6SAndroid Build Coastguard Worker   int32_t flags = graph_node->flags();
699*523fa7a6SAndroid Build Coastguard Worker #ifdef ENABLE_XNNPACK_KLEIDI
700*523fa7a6SAndroid Build Coastguard Worker // This is not currently exposed at include/xnnpack.h yet once it is
701*523fa7a6SAndroid Build Coastguard Worker // we can remove this runtime logic and do this ahead-of-time
702*523fa7a6SAndroid Build Coastguard Worker #define XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM 0x00000100;
703*523fa7a6SAndroid Build Coastguard Worker   if (isQP8(flatbuffer_graph, node)) {
704*523fa7a6SAndroid Build Coastguard Worker     flags |= XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM;
705*523fa7a6SAndroid Build Coastguard Worker     ET_LOG(
706*523fa7a6SAndroid Build Coastguard Worker         Debug,
707*523fa7a6SAndroid Build Coastguard Worker         "Setting XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM flag for convert node %i",
708*523fa7a6SAndroid Build Coastguard Worker         node->debug_handle());
709*523fa7a6SAndroid Build Coastguard Worker   }
710*523fa7a6SAndroid Build Coastguard Worker #endif
711*523fa7a6SAndroid Build Coastguard Worker 
712*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_convert(
713*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
714*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
715*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
716*523fa7a6SAndroid Build Coastguard Worker       flags);
717*523fa7a6SAndroid Build Coastguard Worker 
718*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
719*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
720*523fa7a6SAndroid Build Coastguard Worker       Internal,
721*523fa7a6SAndroid Build Coastguard Worker       "Failed to create convert node %i with code: %s",
722*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
723*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
724*523fa7a6SAndroid Build Coastguard Worker 
725*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
726*523fa7a6SAndroid Build Coastguard Worker };
727*523fa7a6SAndroid Build Coastguard Worker /*
728*523fa7a6SAndroid Build Coastguard Worker Define serialized linear(fully-connected) node into the subgraph using
729*523fa7a6SAndroid Build Coastguard Worker the remapped ids to map the serialized ids, to the new ids generated
730*523fa7a6SAndroid Build Coastguard Worker when defining the tensor values
731*523fa7a6SAndroid Build Coastguard Worker */
defineFullyConnectedNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)732*523fa7a6SAndroid Build Coastguard Worker Error defineFullyConnectedNode(
733*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
734*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
735*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
736*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
737*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
738*523fa7a6SAndroid Build Coastguard Worker 
739*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNFullyConnected();
740*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
741*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_fully_connected(
742*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
743*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
744*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
745*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
746*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->filter_id()),
747*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->bias_id()),
748*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
749*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
750*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
751*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
752*523fa7a6SAndroid Build Coastguard Worker       Internal,
753*523fa7a6SAndroid Build Coastguard Worker       "Failed to create linear node %i, with code: %s",
754*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
755*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
756*523fa7a6SAndroid Build Coastguard Worker 
757*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
758*523fa7a6SAndroid Build Coastguard Worker };
759*523fa7a6SAndroid Build Coastguard Worker 
760*523fa7a6SAndroid Build Coastguard Worker /*
761*523fa7a6SAndroid Build Coastguard Worker Define serialized clamp node into the subgraph, using the remapped ids
762*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining
763*523fa7a6SAndroid Build Coastguard Worker the tensor value
764*523fa7a6SAndroid Build Coastguard Worker */
defineClampNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)765*523fa7a6SAndroid Build Coastguard Worker Error defineClampNode(
766*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
767*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
768*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
769*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
770*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
771*523fa7a6SAndroid Build Coastguard Worker 
772*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
773*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNClamp();
774*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_clamp(
775*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
776*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
777*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
778*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
779*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
780*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
781*523fa7a6SAndroid Build Coastguard Worker 
782*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
783*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
784*523fa7a6SAndroid Build Coastguard Worker       Internal,
785*523fa7a6SAndroid Build Coastguard Worker       "Failed to create hardtanh node %i with code: %s",
786*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
787*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
788*523fa7a6SAndroid Build Coastguard Worker 
789*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
790*523fa7a6SAndroid Build Coastguard Worker }
791*523fa7a6SAndroid Build Coastguard Worker 
792*523fa7a6SAndroid Build Coastguard Worker /*
793*523fa7a6SAndroid Build Coastguard Worker Define serialized softmax node into the subgraph, using the remapped ids
794*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining
795*523fa7a6SAndroid Build Coastguard Worker the tensor value
796*523fa7a6SAndroid Build Coastguard Worker */
defineSoftmaxNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)797*523fa7a6SAndroid Build Coastguard Worker Error defineSoftmaxNode(
798*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
799*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
800*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
801*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
802*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
803*523fa7a6SAndroid Build Coastguard Worker 
804*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNSoftmax();
805*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_softmax(
806*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
807*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
808*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
809*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
810*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
811*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
812*523fa7a6SAndroid Build Coastguard Worker       Internal,
813*523fa7a6SAndroid Build Coastguard Worker       "Failed to create softmax node %i with code: %s",
814*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
815*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
816*523fa7a6SAndroid Build Coastguard Worker 
817*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
818*523fa7a6SAndroid Build Coastguard Worker }
819*523fa7a6SAndroid Build Coastguard Worker 
820*523fa7a6SAndroid Build Coastguard Worker /*
821*523fa7a6SAndroid Build Coastguard Worker Define serialized sigmoid node into the subgraph, using the remapped ids
822*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining
823*523fa7a6SAndroid Build Coastguard Worker the tensor value
824*523fa7a6SAndroid Build Coastguard Worker */
defineSigmoidNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)825*523fa7a6SAndroid Build Coastguard Worker Error defineSigmoidNode(
826*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
827*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
828*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
829*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
830*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
831*523fa7a6SAndroid Build Coastguard Worker 
832*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNSigmoid();
833*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_sigmoid(
834*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
835*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
836*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
837*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
838*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
839*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
840*523fa7a6SAndroid Build Coastguard Worker       Internal,
841*523fa7a6SAndroid Build Coastguard Worker       "Failed to create sigmoid node %i with code: %s",
842*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
843*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
844*523fa7a6SAndroid Build Coastguard Worker 
845*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
846*523fa7a6SAndroid Build Coastguard Worker }
847*523fa7a6SAndroid Build Coastguard Worker 
848*523fa7a6SAndroid Build Coastguard Worker /*
849*523fa7a6SAndroid Build Coastguard Worker Define serialized floor node into the subgraph, using the remapped ids
850*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining
851*523fa7a6SAndroid Build Coastguard Worker the tensor value
852*523fa7a6SAndroid Build Coastguard Worker */
defineFloorNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)853*523fa7a6SAndroid Build Coastguard Worker Error defineFloorNode(
854*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
855*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
856*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
857*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
858*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
859*523fa7a6SAndroid Build Coastguard Worker 
860*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNFloor();
861*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_floor(
862*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
863*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
864*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
865*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
866*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
867*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
868*523fa7a6SAndroid Build Coastguard Worker       Internal,
869*523fa7a6SAndroid Build Coastguard Worker       "Failed to create floor node %i with code: %s",
870*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
871*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
872*523fa7a6SAndroid Build Coastguard Worker 
873*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
874*523fa7a6SAndroid Build Coastguard Worker }
875*523fa7a6SAndroid Build Coastguard Worker 
defineGlobalAvgPooling2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)876*523fa7a6SAndroid Build Coastguard Worker Error defineGlobalAvgPooling2dNode(
877*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
878*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
879*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
880*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
881*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
882*523fa7a6SAndroid Build Coastguard Worker 
883*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNGlobalAvgPooling2d();
884*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
885*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_global_average_pooling_2d(
886*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
887*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
888*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
889*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
890*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
891*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
892*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
893*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
894*523fa7a6SAndroid Build Coastguard Worker       Internal,
895*523fa7a6SAndroid Build Coastguard Worker       "Failed to create global average pooling node %i with code: %s",
896*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
897*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
898*523fa7a6SAndroid Build Coastguard Worker 
899*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
900*523fa7a6SAndroid Build Coastguard Worker }
901*523fa7a6SAndroid Build Coastguard Worker 
defineAvgPooling2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)902*523fa7a6SAndroid Build Coastguard Worker Error defineAvgPooling2dNode(
903*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
904*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
905*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
906*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
907*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
908*523fa7a6SAndroid Build Coastguard Worker 
909*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNAvgPooling2d();
910*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
911*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_average_pooling_2d(
912*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
913*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_top(),
914*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_right(),
915*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_bottom(),
916*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_left(),
917*523fa7a6SAndroid Build Coastguard Worker       graph_node->pooling_height(),
918*523fa7a6SAndroid Build Coastguard Worker       graph_node->pooling_width(),
919*523fa7a6SAndroid Build Coastguard Worker       graph_node->stride_height(),
920*523fa7a6SAndroid Build Coastguard Worker       graph_node->stride_width(),
921*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
922*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
923*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
924*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
925*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
926*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
927*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
928*523fa7a6SAndroid Build Coastguard Worker       Internal,
929*523fa7a6SAndroid Build Coastguard Worker       "Failed to create average pooling node %i with code: %s",
930*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
931*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
932*523fa7a6SAndroid Build Coastguard Worker 
933*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
934*523fa7a6SAndroid Build Coastguard Worker }
935*523fa7a6SAndroid Build Coastguard Worker 
936*523fa7a6SAndroid Build Coastguard Worker /*
937*523fa7a6SAndroid Build Coastguard Worker Define serialized conv2d node into the subgraph, using the remapped ids
938*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining the
939*523fa7a6SAndroid Build Coastguard Worker tensor value
940*523fa7a6SAndroid Build Coastguard Worker */
defineConv2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)941*523fa7a6SAndroid Build Coastguard Worker Error defineConv2dNode(
942*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
943*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
944*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
945*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
946*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
947*523fa7a6SAndroid Build Coastguard Worker 
948*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNConv2d();
949*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
950*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_convolution_2d(
951*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
952*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_top(),
953*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_right(),
954*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_bottom(),
955*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_left(),
956*523fa7a6SAndroid Build Coastguard Worker       graph_node->kernel_height(),
957*523fa7a6SAndroid Build Coastguard Worker       graph_node->kernel_width(),
958*523fa7a6SAndroid Build Coastguard Worker       graph_node->subsampling_height(),
959*523fa7a6SAndroid Build Coastguard Worker       graph_node->subsampling_width(),
960*523fa7a6SAndroid Build Coastguard Worker       graph_node->dilation_height(),
961*523fa7a6SAndroid Build Coastguard Worker       graph_node->dilation_width(),
962*523fa7a6SAndroid Build Coastguard Worker       graph_node->groups(),
963*523fa7a6SAndroid Build Coastguard Worker       graph_node->group_input_channels(),
964*523fa7a6SAndroid Build Coastguard Worker       graph_node->group_output_channels(),
965*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
966*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
967*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
968*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->filter_id()),
969*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->bias_id()),
970*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
971*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
972*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
973*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
974*523fa7a6SAndroid Build Coastguard Worker       Internal,
975*523fa7a6SAndroid Build Coastguard Worker       "Failed to create convolution node %i with code: %s",
976*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
977*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
978*523fa7a6SAndroid Build Coastguard Worker 
979*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
980*523fa7a6SAndroid Build Coastguard Worker }
981*523fa7a6SAndroid Build Coastguard Worker 
982*523fa7a6SAndroid Build Coastguard Worker /*
983*523fa7a6SAndroid Build Coastguard Worker Define serialized maxpool2d node into the subgraph, using the remapped ids
984*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining the
985*523fa7a6SAndroid Build Coastguard Worker tensor value
986*523fa7a6SAndroid Build Coastguard Worker */
defineMaxPooling2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)987*523fa7a6SAndroid Build Coastguard Worker Error defineMaxPooling2dNode(
988*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
989*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
990*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
991*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
992*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
993*523fa7a6SAndroid Build Coastguard Worker 
994*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNMaxPooling2d();
995*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
996*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_max_pooling_2d(
997*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
998*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_top(),
999*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_right(),
1000*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_bottom(),
1001*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_left(),
1002*523fa7a6SAndroid Build Coastguard Worker       graph_node->pooling_height(),
1003*523fa7a6SAndroid Build Coastguard Worker       graph_node->pooling_width(),
1004*523fa7a6SAndroid Build Coastguard Worker       graph_node->stride_height(),
1005*523fa7a6SAndroid Build Coastguard Worker       graph_node->stride_width(),
1006*523fa7a6SAndroid Build Coastguard Worker       graph_node->dilation_height(),
1007*523fa7a6SAndroid Build Coastguard Worker       graph_node->dilation_width(),
1008*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
1009*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
1010*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1011*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1012*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1013*523fa7a6SAndroid Build Coastguard Worker 
1014*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1015*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1016*523fa7a6SAndroid Build Coastguard Worker       Internal,
1017*523fa7a6SAndroid Build Coastguard Worker       "Failed to create maxpool2d node %i with code: %s",
1018*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1019*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1020*523fa7a6SAndroid Build Coastguard Worker 
1021*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1022*523fa7a6SAndroid Build Coastguard Worker }
1023*523fa7a6SAndroid Build Coastguard Worker 
1024*523fa7a6SAndroid Build Coastguard Worker /*
1025*523fa7a6SAndroid Build Coastguard Worker Define serialized div node into the subgraph
1026*523fa7a6SAndroid Build Coastguard Worker */
defineDivNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1027*523fa7a6SAndroid Build Coastguard Worker Error defineDivNode(
1028*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1029*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1030*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1031*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1032*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1033*523fa7a6SAndroid Build Coastguard Worker 
1034*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNDiv();
1035*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
1036*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_divide(
1037*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1038*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
1039*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
1040*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
1041*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
1042*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1043*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1044*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1045*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1046*523fa7a6SAndroid Build Coastguard Worker       Internal,
1047*523fa7a6SAndroid Build Coastguard Worker       "Failed to create div node %i with code: %s",
1048*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1049*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1050*523fa7a6SAndroid Build Coastguard Worker 
1051*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1052*523fa7a6SAndroid Build Coastguard Worker }
1053*523fa7a6SAndroid Build Coastguard Worker 
1054*523fa7a6SAndroid Build Coastguard Worker /*
1055*523fa7a6SAndroid Build Coastguard Worker Define serialized static transpose node into the subgraph, using the remapped
1056*523fa7a6SAndroid Build Coastguard Worker ids to map the serialized ids, to the new ids generated when defining the
1057*523fa7a6SAndroid Build Coastguard Worker tensor value
1058*523fa7a6SAndroid Build Coastguard Worker */
defineStaticTransposeNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1059*523fa7a6SAndroid Build Coastguard Worker Error defineStaticTransposeNode(
1060*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1061*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1062*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1063*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1064*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1065*523fa7a6SAndroid Build Coastguard Worker 
1066*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNStaticTranspose();
1067*523fa7a6SAndroid Build Coastguard Worker 
1068*523fa7a6SAndroid Build Coastguard Worker   // Get tensor dims, we need to convert the uint32_t* to size_t*
1069*523fa7a6SAndroid Build Coastguard Worker   std::vector<size_t> dims_data = flatbufferDimsToVector(graph_node->perm());
1070*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_static_transpose(
1071*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1072*523fa7a6SAndroid Build Coastguard Worker       graph_node->num_dims(),
1073*523fa7a6SAndroid Build Coastguard Worker       dims_data.data(),
1074*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1075*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1076*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1077*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1078*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1079*523fa7a6SAndroid Build Coastguard Worker       Internal,
1080*523fa7a6SAndroid Build Coastguard Worker       "Failed to create sigmoid node %i with code: %s",
1081*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1082*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1083*523fa7a6SAndroid Build Coastguard Worker 
1084*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1085*523fa7a6SAndroid Build Coastguard Worker }
1086*523fa7a6SAndroid Build Coastguard Worker 
1087*523fa7a6SAndroid Build Coastguard Worker /*
1088*523fa7a6SAndroid Build Coastguard Worker Define serialized static resize bilinear 2d node into the subgraph, using the
1089*523fa7a6SAndroid Build Coastguard Worker remapped ids to map the serialized ids, to the new ids generated when defining
1090*523fa7a6SAndroid Build Coastguard Worker the tensor value
1091*523fa7a6SAndroid Build Coastguard Worker */
defineStaticResizeBilinear2DNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1092*523fa7a6SAndroid Build Coastguard Worker Error defineStaticResizeBilinear2DNode(
1093*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1094*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1095*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1096*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1097*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1098*523fa7a6SAndroid Build Coastguard Worker 
1099*523fa7a6SAndroid Build Coastguard Worker   const fb_xnnpack::XNNStaticResizeBilinear2D* graph_node =
1100*523fa7a6SAndroid Build Coastguard Worker       node->xnode_union_as_XNNStaticResizeBilinear2D();
1101*523fa7a6SAndroid Build Coastguard Worker 
1102*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_static_resize_bilinear_2d(
1103*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1104*523fa7a6SAndroid Build Coastguard Worker       graph_node->new_height(),
1105*523fa7a6SAndroid Build Coastguard Worker       graph_node->new_width(),
1106*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1107*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1108*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1109*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1110*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1111*523fa7a6SAndroid Build Coastguard Worker       Internal,
1112*523fa7a6SAndroid Build Coastguard Worker       "Failed to create StaticResizeBilinear2DNode node %i with code: %s",
1113*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1114*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1115*523fa7a6SAndroid Build Coastguard Worker 
1116*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1117*523fa7a6SAndroid Build Coastguard Worker }
1118*523fa7a6SAndroid Build Coastguard Worker 
1119*523fa7a6SAndroid Build Coastguard Worker /*
1120*523fa7a6SAndroid Build Coastguard Worker Define serialized static constant pad node into the subgraph, using the
1121*523fa7a6SAndroid Build Coastguard Worker remapped ids to map the serialized ids, to the new ids generated when defining
1122*523fa7a6SAndroid Build Coastguard Worker the tensor value
1123*523fa7a6SAndroid Build Coastguard Worker */
defineStaticConstantPadNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1124*523fa7a6SAndroid Build Coastguard Worker Error defineStaticConstantPadNode(
1125*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1126*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1127*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1128*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1129*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1130*523fa7a6SAndroid Build Coastguard Worker 
1131*523fa7a6SAndroid Build Coastguard Worker   const fb_xnnpack::XNNStaticConstantPad* graph_node =
1132*523fa7a6SAndroid Build Coastguard Worker       node->xnode_union_as_XNNStaticConstantPad();
1133*523fa7a6SAndroid Build Coastguard Worker 
1134*523fa7a6SAndroid Build Coastguard Worker   std::vector<size_t> pre_paddings_dims =
1135*523fa7a6SAndroid Build Coastguard Worker       flatbufferDimsToVector(graph_node->pre_paddings());
1136*523fa7a6SAndroid Build Coastguard Worker   std::vector<size_t> post_paddings_dims =
1137*523fa7a6SAndroid Build Coastguard Worker       flatbufferDimsToVector(graph_node->post_paddings());
1138*523fa7a6SAndroid Build Coastguard Worker 
1139*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_static_constant_pad(
1140*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1141*523fa7a6SAndroid Build Coastguard Worker       pre_paddings_dims.data(),
1142*523fa7a6SAndroid Build Coastguard Worker       post_paddings_dims.data(),
1143*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_value(),
1144*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1145*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1146*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1147*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1148*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1149*523fa7a6SAndroid Build Coastguard Worker       Internal,
1150*523fa7a6SAndroid Build Coastguard Worker       "Failed to create StaticConstantPad node %i with code: %s",
1151*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1152*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1153*523fa7a6SAndroid Build Coastguard Worker 
1154*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1155*523fa7a6SAndroid Build Coastguard Worker }
1156*523fa7a6SAndroid Build Coastguard Worker 
1157*523fa7a6SAndroid Build Coastguard Worker /*
1158*523fa7a6SAndroid Build Coastguard Worker Define serialized depthwise conv2d node into the subgraph, using the remapped
1159*523fa7a6SAndroid Build Coastguard Worker ids to map the serialized ids, to the new ids generated when defining the
1160*523fa7a6SAndroid Build Coastguard Worker tensor value
1161*523fa7a6SAndroid Build Coastguard Worker */
defineDepthwiseConv2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1162*523fa7a6SAndroid Build Coastguard Worker Error defineDepthwiseConv2dNode(
1163*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1164*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1165*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1166*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1167*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1168*523fa7a6SAndroid Build Coastguard Worker 
1169*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNDepthwiseConv2d();
1170*523fa7a6SAndroid Build Coastguard Worker   std::pair<float, float> min_max = getOutputMinMax(node);
1171*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_depthwise_convolution_2d(
1172*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1173*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_top(),
1174*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_right(),
1175*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_bottom(),
1176*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_left(),
1177*523fa7a6SAndroid Build Coastguard Worker       graph_node->kernel_height(),
1178*523fa7a6SAndroid Build Coastguard Worker       graph_node->kernel_width(),
1179*523fa7a6SAndroid Build Coastguard Worker       graph_node->subsampling_height(),
1180*523fa7a6SAndroid Build Coastguard Worker       graph_node->subsampling_width(),
1181*523fa7a6SAndroid Build Coastguard Worker       graph_node->dilation_height(),
1182*523fa7a6SAndroid Build Coastguard Worker       graph_node->dilation_width(),
1183*523fa7a6SAndroid Build Coastguard Worker       graph_node->group_output_channels() /
1184*523fa7a6SAndroid Build Coastguard Worker           graph_node->group_input_channels(), // depth_multiplier
1185*523fa7a6SAndroid Build Coastguard Worker       graph_node->groups(), // input_channels = groups for depthwise conv
1186*523fa7a6SAndroid Build Coastguard Worker       min_max.first,
1187*523fa7a6SAndroid Build Coastguard Worker       min_max.second,
1188*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
1189*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->filter_id()),
1190*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->bias_id()),
1191*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1192*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1193*523fa7a6SAndroid Build Coastguard Worker 
1194*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1195*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1196*523fa7a6SAndroid Build Coastguard Worker       Internal,
1197*523fa7a6SAndroid Build Coastguard Worker       "Failed to create depthwise convolution node %i with code: %s",
1198*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1199*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1200*523fa7a6SAndroid Build Coastguard Worker 
1201*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1202*523fa7a6SAndroid Build Coastguard Worker }
1203*523fa7a6SAndroid Build Coastguard Worker 
defineStaticReshapeNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1204*523fa7a6SAndroid Build Coastguard Worker Error defineStaticReshapeNode(
1205*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1206*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1207*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1208*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1209*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1210*523fa7a6SAndroid Build Coastguard Worker 
1211*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNStaticReshape();
1212*523fa7a6SAndroid Build Coastguard Worker 
1213*523fa7a6SAndroid Build Coastguard Worker   // Get tensor dims, we need to convert the uint32_t* to size_t*
1214*523fa7a6SAndroid Build Coastguard Worker   std::vector<size_t> dims_data =
1215*523fa7a6SAndroid Build Coastguard Worker       flatbufferDimsToVector(graph_node->new_shape());
1216*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_static_reshape(
1217*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1218*523fa7a6SAndroid Build Coastguard Worker       graph_node->num_dims(),
1219*523fa7a6SAndroid Build Coastguard Worker       dims_data.data(),
1220*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1221*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1222*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1223*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1224*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1225*523fa7a6SAndroid Build Coastguard Worker       Internal,
1226*523fa7a6SAndroid Build Coastguard Worker       "Failed to create squeeze node %i with code: %s",
1227*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1228*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1229*523fa7a6SAndroid Build Coastguard Worker 
1230*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1231*523fa7a6SAndroid Build Coastguard Worker }
1232*523fa7a6SAndroid Build Coastguard Worker 
1233*523fa7a6SAndroid Build Coastguard Worker /*
1234*523fa7a6SAndroid Build Coastguard Worker Define serialized maxpool2d node into the subgraph, using the remapped ids
1235*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining the
1236*523fa7a6SAndroid Build Coastguard Worker tensor value
1237*523fa7a6SAndroid Build Coastguard Worker */
defineArgMaxPooling2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1238*523fa7a6SAndroid Build Coastguard Worker Error defineArgMaxPooling2dNode(
1239*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1240*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1241*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1242*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1243*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1244*523fa7a6SAndroid Build Coastguard Worker 
1245*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNArgMaxPooling2d();
1246*523fa7a6SAndroid Build Coastguard Worker 
1247*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_argmax_pooling_2d(
1248*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1249*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_top(),
1250*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_right(),
1251*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_bottom(),
1252*523fa7a6SAndroid Build Coastguard Worker       graph_node->padding_left(),
1253*523fa7a6SAndroid Build Coastguard Worker       graph_node->pooling_height(),
1254*523fa7a6SAndroid Build Coastguard Worker       graph_node->pooling_width(),
1255*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1256*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_value_id()),
1257*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_index_id()),
1258*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1259*523fa7a6SAndroid Build Coastguard Worker 
1260*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1261*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1262*523fa7a6SAndroid Build Coastguard Worker       Internal,
1263*523fa7a6SAndroid Build Coastguard Worker       "Failed to create argmaxpool2d node %i with code: %s",
1264*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1265*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1266*523fa7a6SAndroid Build Coastguard Worker 
1267*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1268*523fa7a6SAndroid Build Coastguard Worker }
1269*523fa7a6SAndroid Build Coastguard Worker 
1270*523fa7a6SAndroid Build Coastguard Worker /*
1271*523fa7a6SAndroid Build Coastguard Worker Define serialized square root node into the subgraph, using the remapped ids
1272*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining the
1273*523fa7a6SAndroid Build Coastguard Worker tensor value
1274*523fa7a6SAndroid Build Coastguard Worker */
defineSquareRootNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1275*523fa7a6SAndroid Build Coastguard Worker Error defineSquareRootNode(
1276*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1277*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1278*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1279*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1280*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1281*523fa7a6SAndroid Build Coastguard Worker 
1282*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNSquareRoot();
1283*523fa7a6SAndroid Build Coastguard Worker 
1284*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_square_root(
1285*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1286*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1287*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1288*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1289*523fa7a6SAndroid Build Coastguard Worker 
1290*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1291*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1292*523fa7a6SAndroid Build Coastguard Worker       Internal,
1293*523fa7a6SAndroid Build Coastguard Worker       "Failed to create square root node %i with code: %s",
1294*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1295*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1296*523fa7a6SAndroid Build Coastguard Worker 
1297*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1298*523fa7a6SAndroid Build Coastguard Worker }
1299*523fa7a6SAndroid Build Coastguard Worker 
1300*523fa7a6SAndroid Build Coastguard Worker /*
1301*523fa7a6SAndroid Build Coastguard Worker Define serialized ceiling node into the subgraph, using the remapped ids
1302*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining the
1303*523fa7a6SAndroid Build Coastguard Worker tensor value
1304*523fa7a6SAndroid Build Coastguard Worker */
defineCeilingNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1305*523fa7a6SAndroid Build Coastguard Worker Error defineCeilingNode(
1306*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1307*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1308*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1309*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1310*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1311*523fa7a6SAndroid Build Coastguard Worker 
1312*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNCeiling();
1313*523fa7a6SAndroid Build Coastguard Worker 
1314*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_ceiling(
1315*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1316*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1317*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1318*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1319*523fa7a6SAndroid Build Coastguard Worker 
1320*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1321*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1322*523fa7a6SAndroid Build Coastguard Worker       Internal,
1323*523fa7a6SAndroid Build Coastguard Worker       "Failed to create ceiling node %i with code: %s",
1324*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1325*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1326*523fa7a6SAndroid Build Coastguard Worker 
1327*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1328*523fa7a6SAndroid Build Coastguard Worker }
1329*523fa7a6SAndroid Build Coastguard Worker 
1330*523fa7a6SAndroid Build Coastguard Worker /*
1331*523fa7a6SAndroid Build Coastguard Worker Define serialized hardswish node into the subgraph, using the remapped ids
1332*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining the
1333*523fa7a6SAndroid Build Coastguard Worker tensor value
1334*523fa7a6SAndroid Build Coastguard Worker */
defineHardswishNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1335*523fa7a6SAndroid Build Coastguard Worker Error defineHardswishNode(
1336*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1337*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1338*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1339*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1340*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1341*523fa7a6SAndroid Build Coastguard Worker 
1342*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNHardswish();
1343*523fa7a6SAndroid Build Coastguard Worker 
1344*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_hardswish(
1345*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1346*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1347*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1348*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1349*523fa7a6SAndroid Build Coastguard Worker 
1350*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1351*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1352*523fa7a6SAndroid Build Coastguard Worker       Internal,
1353*523fa7a6SAndroid Build Coastguard Worker       "Failed to create hardswish node %i with code: %s",
1354*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1355*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1356*523fa7a6SAndroid Build Coastguard Worker 
1357*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1358*523fa7a6SAndroid Build Coastguard Worker }
1359*523fa7a6SAndroid Build Coastguard Worker 
1360*523fa7a6SAndroid Build Coastguard Worker /*
1361*523fa7a6SAndroid Build Coastguard Worker Define serialized leaky relu node into the subgraph, using the remapped ids
1362*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining the
1363*523fa7a6SAndroid Build Coastguard Worker tensor value
1364*523fa7a6SAndroid Build Coastguard Worker */
defineLeakyReLUNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1365*523fa7a6SAndroid Build Coastguard Worker Error defineLeakyReLUNode(
1366*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1367*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1368*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1369*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1370*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1371*523fa7a6SAndroid Build Coastguard Worker 
1372*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNLeakyReLU();
1373*523fa7a6SAndroid Build Coastguard Worker 
1374*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_leaky_relu(
1375*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1376*523fa7a6SAndroid Build Coastguard Worker       graph_node->negative_slope(),
1377*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1378*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1379*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1380*523fa7a6SAndroid Build Coastguard Worker 
1381*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1382*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1383*523fa7a6SAndroid Build Coastguard Worker       Internal,
1384*523fa7a6SAndroid Build Coastguard Worker       "Failed to create leaky relu node %i with code: %s",
1385*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1386*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1387*523fa7a6SAndroid Build Coastguard Worker 
1388*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1389*523fa7a6SAndroid Build Coastguard Worker }
1390*523fa7a6SAndroid Build Coastguard Worker 
1391*523fa7a6SAndroid Build Coastguard Worker /*
1392*523fa7a6SAndroid Build Coastguard Worker Define serialized maximum node into the subgraph, using the remapped ids
1393*523fa7a6SAndroid Build Coastguard Worker to map the serialized ids, to the new ids generated when defining the
1394*523fa7a6SAndroid Build Coastguard Worker tensor value
1395*523fa7a6SAndroid Build Coastguard Worker */
defineMaximumNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1396*523fa7a6SAndroid Build Coastguard Worker Error defineMaximumNode(
1397*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1398*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1399*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1400*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1401*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1402*523fa7a6SAndroid Build Coastguard Worker 
1403*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNMaximum();
1404*523fa7a6SAndroid Build Coastguard Worker 
1405*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_maximum2(
1406*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1407*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
1408*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
1409*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1410*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1411*523fa7a6SAndroid Build Coastguard Worker 
1412*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1413*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1414*523fa7a6SAndroid Build Coastguard Worker       Internal,
1415*523fa7a6SAndroid Build Coastguard Worker       "Failed to create maximum node %i with code: %s",
1416*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1417*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1418*523fa7a6SAndroid Build Coastguard Worker 
1419*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1420*523fa7a6SAndroid Build Coastguard Worker }
1421*523fa7a6SAndroid Build Coastguard Worker 
1422*523fa7a6SAndroid Build Coastguard Worker /*
1423*523fa7a6SAndroid Build Coastguard Worker Define Negate node into subgraph, using the remapped ids to map the
1424*523fa7a6SAndroid Build Coastguard Worker serialized ids, to the new ids generated when defining the tensor value
1425*523fa7a6SAndroid Build Coastguard Worker */
defineNegateNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1426*523fa7a6SAndroid Build Coastguard Worker Error defineNegateNode(
1427*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1428*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1429*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1430*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1431*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1432*523fa7a6SAndroid Build Coastguard Worker 
1433*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNNegate();
1434*523fa7a6SAndroid Build Coastguard Worker 
1435*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_negate(
1436*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1437*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1438*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1439*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1440*523fa7a6SAndroid Build Coastguard Worker 
1441*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1442*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1443*523fa7a6SAndroid Build Coastguard Worker       Internal,
1444*523fa7a6SAndroid Build Coastguard Worker       "Failed to create negate node %i with code: %s",
1445*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1446*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1447*523fa7a6SAndroid Build Coastguard Worker 
1448*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1449*523fa7a6SAndroid Build Coastguard Worker }
1450*523fa7a6SAndroid Build Coastguard Worker 
1451*523fa7a6SAndroid Build Coastguard Worker /*
1452*523fa7a6SAndroid Build Coastguard Worker Defines square node into subgraph using the remapped ids to map the
1453*523fa7a6SAndroid Build Coastguard Worker serialized ids to the new ids generated when defining the tensor value
1454*523fa7a6SAndroid Build Coastguard Worker */
defineSquareNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1455*523fa7a6SAndroid Build Coastguard Worker Error defineSquareNode(
1456*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1457*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1458*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1459*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1460*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1461*523fa7a6SAndroid Build Coastguard Worker 
1462*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNSquare();
1463*523fa7a6SAndroid Build Coastguard Worker 
1464*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_square(
1465*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1466*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1467*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1468*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1469*523fa7a6SAndroid Build Coastguard Worker 
1470*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1471*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1472*523fa7a6SAndroid Build Coastguard Worker       Internal,
1473*523fa7a6SAndroid Build Coastguard Worker       "Failed to create square node %i with code: %s",
1474*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1475*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1476*523fa7a6SAndroid Build Coastguard Worker 
1477*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1478*523fa7a6SAndroid Build Coastguard Worker }
1479*523fa7a6SAndroid Build Coastguard Worker 
1480*523fa7a6SAndroid Build Coastguard Worker /*
1481*523fa7a6SAndroid Build Coastguard Worker Defines square node into subgraph using the remapped ids to map the
1482*523fa7a6SAndroid Build Coastguard Worker serialized ids to the new ids generated when defining the tensor value
1483*523fa7a6SAndroid Build Coastguard Worker */
defineELUNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1484*523fa7a6SAndroid Build Coastguard Worker Error defineELUNode(
1485*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1486*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1487*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1488*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1489*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1490*523fa7a6SAndroid Build Coastguard Worker 
1491*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNELU();
1492*523fa7a6SAndroid Build Coastguard Worker 
1493*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_elu(
1494*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1495*523fa7a6SAndroid Build Coastguard Worker       graph_node->alpha(),
1496*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1497*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1498*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1499*523fa7a6SAndroid Build Coastguard Worker 
1500*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1501*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1502*523fa7a6SAndroid Build Coastguard Worker       Internal,
1503*523fa7a6SAndroid Build Coastguard Worker       "Failed to create ELU node %i with code: %s",
1504*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1505*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1506*523fa7a6SAndroid Build Coastguard Worker 
1507*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1508*523fa7a6SAndroid Build Coastguard Worker }
1509*523fa7a6SAndroid Build Coastguard Worker 
1510*523fa7a6SAndroid Build Coastguard Worker /*
1511*523fa7a6SAndroid Build Coastguard Worker Defines absolute value node into subgraph using the remapped ids to map the
1512*523fa7a6SAndroid Build Coastguard Worker serialized ids to the new ids generated when defining the tensor value
1513*523fa7a6SAndroid Build Coastguard Worker */
defineAbsNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1514*523fa7a6SAndroid Build Coastguard Worker Error defineAbsNode(
1515*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1516*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1517*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1518*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1519*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1520*523fa7a6SAndroid Build Coastguard Worker 
1521*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNAbs();
1522*523fa7a6SAndroid Build Coastguard Worker 
1523*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_abs(
1524*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1525*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1526*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1527*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1528*523fa7a6SAndroid Build Coastguard Worker 
1529*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1530*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1531*523fa7a6SAndroid Build Coastguard Worker       Internal,
1532*523fa7a6SAndroid Build Coastguard Worker       "Failed to create abs node %i with code: %s",
1533*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1534*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1535*523fa7a6SAndroid Build Coastguard Worker 
1536*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1537*523fa7a6SAndroid Build Coastguard Worker }
1538*523fa7a6SAndroid Build Coastguard Worker 
1539*523fa7a6SAndroid Build Coastguard Worker /*
1540*523fa7a6SAndroid Build Coastguard Worker Defines serialized prelu node into the subgraph,
1541*523fa7a6SAndroid Build Coastguard Worker using the remapped ids to map the serialized ids,
1542*523fa7a6SAndroid Build Coastguard Worker to the new ids generated when defining the tensor value
1543*523fa7a6SAndroid Build Coastguard Worker */
definePReLUNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1544*523fa7a6SAndroid Build Coastguard Worker Error definePReLUNode(
1545*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1546*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1547*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1548*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1549*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1550*523fa7a6SAndroid Build Coastguard Worker 
1551*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNPReLU();
1552*523fa7a6SAndroid Build Coastguard Worker 
1553*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_prelu(
1554*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1555*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
1556*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
1557*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1558*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1559*523fa7a6SAndroid Build Coastguard Worker 
1560*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1561*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1562*523fa7a6SAndroid Build Coastguard Worker       Internal,
1563*523fa7a6SAndroid Build Coastguard Worker       "Failed to create prelu node %i with code: %s",
1564*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1565*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1566*523fa7a6SAndroid Build Coastguard Worker 
1567*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1568*523fa7a6SAndroid Build Coastguard Worker }
1569*523fa7a6SAndroid Build Coastguard Worker 
1570*523fa7a6SAndroid Build Coastguard Worker /*
1571*523fa7a6SAndroid Build Coastguard Worker Defines serialized concatenate2 node into the subgraph,
1572*523fa7a6SAndroid Build Coastguard Worker using the remapped ids to map the serialized ids,
1573*523fa7a6SAndroid Build Coastguard Worker to the new ids generated when defining the tensor value
1574*523fa7a6SAndroid Build Coastguard Worker */
defineConcatenate2Node(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1575*523fa7a6SAndroid Build Coastguard Worker Error defineConcatenate2Node(
1576*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1577*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1578*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1579*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1580*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1581*523fa7a6SAndroid Build Coastguard Worker 
1582*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNConcatenate2();
1583*523fa7a6SAndroid Build Coastguard Worker 
1584*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_concatenate2(
1585*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1586*523fa7a6SAndroid Build Coastguard Worker       graph_node->axis(),
1587*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
1588*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
1589*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1590*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1591*523fa7a6SAndroid Build Coastguard Worker 
1592*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1593*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1594*523fa7a6SAndroid Build Coastguard Worker       Internal,
1595*523fa7a6SAndroid Build Coastguard Worker       "Failed to create cat2 node %i with code: %s",
1596*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1597*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1598*523fa7a6SAndroid Build Coastguard Worker 
1599*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1600*523fa7a6SAndroid Build Coastguard Worker }
1601*523fa7a6SAndroid Build Coastguard Worker 
1602*523fa7a6SAndroid Build Coastguard Worker /*
1603*523fa7a6SAndroid Build Coastguard Worker Defines serialized concatenate2 node into the subgraph,
1604*523fa7a6SAndroid Build Coastguard Worker using the remapped ids to map the serialized ids,
1605*523fa7a6SAndroid Build Coastguard Worker to the new ids generated when defining the tensor value
1606*523fa7a6SAndroid Build Coastguard Worker */
defineConcatenate3Node(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1607*523fa7a6SAndroid Build Coastguard Worker Error defineConcatenate3Node(
1608*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1609*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1610*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1611*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1612*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1613*523fa7a6SAndroid Build Coastguard Worker 
1614*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNConcatenate3();
1615*523fa7a6SAndroid Build Coastguard Worker 
1616*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_concatenate3(
1617*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1618*523fa7a6SAndroid Build Coastguard Worker       graph_node->axis(),
1619*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
1620*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
1621*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input3_id()),
1622*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1623*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1624*523fa7a6SAndroid Build Coastguard Worker 
1625*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1626*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1627*523fa7a6SAndroid Build Coastguard Worker       Internal,
1628*523fa7a6SAndroid Build Coastguard Worker       "Failed to create cat3 node %i with code: %s",
1629*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1630*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1631*523fa7a6SAndroid Build Coastguard Worker 
1632*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1633*523fa7a6SAndroid Build Coastguard Worker }
1634*523fa7a6SAndroid Build Coastguard Worker 
1635*523fa7a6SAndroid Build Coastguard Worker /*
1636*523fa7a6SAndroid Build Coastguard Worker Defines serialized concatenate2 node into the subgraph,
1637*523fa7a6SAndroid Build Coastguard Worker using the remapped ids to map the serialized ids,
1638*523fa7a6SAndroid Build Coastguard Worker to the new ids generated when defining the tensor value
1639*523fa7a6SAndroid Build Coastguard Worker */
defineConcatenate4Node(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1640*523fa7a6SAndroid Build Coastguard Worker Error defineConcatenate4Node(
1641*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1642*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1643*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1644*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1645*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1646*523fa7a6SAndroid Build Coastguard Worker 
1647*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNConcatenate4();
1648*523fa7a6SAndroid Build Coastguard Worker 
1649*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_concatenate4(
1650*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1651*523fa7a6SAndroid Build Coastguard Worker       graph_node->axis(),
1652*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
1653*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
1654*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input3_id()),
1655*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input4_id()),
1656*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1657*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1658*523fa7a6SAndroid Build Coastguard Worker 
1659*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1660*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1661*523fa7a6SAndroid Build Coastguard Worker       Internal,
1662*523fa7a6SAndroid Build Coastguard Worker       "Failed to create cat4 node %i with code: %s",
1663*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1664*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1665*523fa7a6SAndroid Build Coastguard Worker 
1666*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1667*523fa7a6SAndroid Build Coastguard Worker }
1668*523fa7a6SAndroid Build Coastguard Worker 
1669*523fa7a6SAndroid Build Coastguard Worker /*
1670*523fa7a6SAndroid Build Coastguard Worker Defines serialized static_slice node into the subgraph,
1671*523fa7a6SAndroid Build Coastguard Worker using the remapped ids to map the serialized ids,
1672*523fa7a6SAndroid Build Coastguard Worker to the new ids generated when defining the tensor value
1673*523fa7a6SAndroid Build Coastguard Worker */
defineStaticSliceNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1674*523fa7a6SAndroid Build Coastguard Worker Error defineStaticSliceNode(
1675*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1676*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1677*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1678*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1679*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1680*523fa7a6SAndroid Build Coastguard Worker 
1681*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNStaticSlice();
1682*523fa7a6SAndroid Build Coastguard Worker 
1683*523fa7a6SAndroid Build Coastguard Worker   std::vector<size_t> offsets = flatbufferDimsToVector(graph_node->offsets());
1684*523fa7a6SAndroid Build Coastguard Worker   std::vector<size_t> sizes = flatbufferDimsToVector(graph_node->sizes());
1685*523fa7a6SAndroid Build Coastguard Worker 
1686*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_static_slice(
1687*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1688*523fa7a6SAndroid Build Coastguard Worker       graph_node->num_dims(),
1689*523fa7a6SAndroid Build Coastguard Worker       offsets.data(),
1690*523fa7a6SAndroid Build Coastguard Worker       sizes.data(),
1691*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input_id()),
1692*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1693*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1694*523fa7a6SAndroid Build Coastguard Worker 
1695*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1696*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1697*523fa7a6SAndroid Build Coastguard Worker       Internal,
1698*523fa7a6SAndroid Build Coastguard Worker       "Failed to create static slice node %i with code: %s",
1699*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1700*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1701*523fa7a6SAndroid Build Coastguard Worker 
1702*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1703*523fa7a6SAndroid Build Coastguard Worker }
1704*523fa7a6SAndroid Build Coastguard Worker 
1705*523fa7a6SAndroid Build Coastguard Worker /*
1706*523fa7a6SAndroid Build Coastguard Worker Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
1707*523fa7a6SAndroid Build Coastguard Worker using the remapped ids to map the serialized ids,
1708*523fa7a6SAndroid Build Coastguard Worker to the new ids generated when defining the tensor value
1709*523fa7a6SAndroid Build Coastguard Worker */
defineScaledDotProductAttentionNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1710*523fa7a6SAndroid Build Coastguard Worker Error defineScaledDotProductAttentionNode(
1711*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1712*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1713*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1714*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1715*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1716*523fa7a6SAndroid Build Coastguard Worker 
1717*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention();
1718*523fa7a6SAndroid Build Coastguard Worker 
1719*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_scaled_dot_product_attention(
1720*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1721*523fa7a6SAndroid Build Coastguard Worker       xnn_attention_logits_cap_type_none, // cap_type
1722*523fa7a6SAndroid Build Coastguard Worker       nullptr, // cap_value - not used
1723*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->query_id()),
1724*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->key_id()),
1725*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->value_id()),
1726*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->scale_id()),
1727*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->mask_id()),
1728*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1729*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1730*523fa7a6SAndroid Build Coastguard Worker 
1731*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1732*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1733*523fa7a6SAndroid Build Coastguard Worker       Internal,
1734*523fa7a6SAndroid Build Coastguard Worker       "Failed to create SDPA node %i with code: %s",
1735*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1736*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1737*523fa7a6SAndroid Build Coastguard Worker 
1738*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1739*523fa7a6SAndroid Build Coastguard Worker }
1740*523fa7a6SAndroid Build Coastguard Worker 
1741*523fa7a6SAndroid Build Coastguard Worker /*
1742*523fa7a6SAndroid Build Coastguard Worker Defines batch matrix multiply node into the subgraph,
1743*523fa7a6SAndroid Build Coastguard Worker using the remapped ids to map the serialized ids,
1744*523fa7a6SAndroid Build Coastguard Worker to the new ids generated when defining the tensor value
1745*523fa7a6SAndroid Build Coastguard Worker */
defineBatchMatrixMultiplyNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1746*523fa7a6SAndroid Build Coastguard Worker Error defineBatchMatrixMultiplyNode(
1747*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1748*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1749*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1750*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1751*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1752*523fa7a6SAndroid Build Coastguard Worker 
1753*523fa7a6SAndroid Build Coastguard Worker   auto graph_node = node->xnode_union_as_XNNBatchMatrixMultiply();
1754*523fa7a6SAndroid Build Coastguard Worker 
1755*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_define_batch_matrix_multiply(
1756*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr,
1757*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input1_id()),
1758*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->input2_id()),
1759*523fa7a6SAndroid Build Coastguard Worker       remapped_ids.at(graph_node->output_id()),
1760*523fa7a6SAndroid Build Coastguard Worker       graph_node->flags());
1761*523fa7a6SAndroid Build Coastguard Worker 
1762*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1763*523fa7a6SAndroid Build Coastguard Worker       status == xnn_status_success,
1764*523fa7a6SAndroid Build Coastguard Worker       Internal,
1765*523fa7a6SAndroid Build Coastguard Worker       "Failed to create BMM node %i with code: %s",
1766*523fa7a6SAndroid Build Coastguard Worker       node->debug_handle(),
1767*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1768*523fa7a6SAndroid Build Coastguard Worker 
1769*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
1770*523fa7a6SAndroid Build Coastguard Worker }
1771*523fa7a6SAndroid Build Coastguard Worker 
1772*523fa7a6SAndroid Build Coastguard Worker /*
1773*523fa7a6SAndroid Build Coastguard Worker Returns not Implemented Error code. This function is meant to be
1774*523fa7a6SAndroid Build Coastguard Worker called when the compiler encountes a XNodeType from the flatbuffer
1775*523fa7a6SAndroid Build Coastguard Worker that has not yet been implemented
1776*523fa7a6SAndroid Build Coastguard Worker */
defineNotImplementedNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1777*523fa7a6SAndroid Build Coastguard Worker Error defineNotImplementedNode(
1778*523fa7a6SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr,
1779*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1780*523fa7a6SAndroid Build Coastguard Worker     const NodePtr node,
1781*523fa7a6SAndroid Build Coastguard Worker     const fb_xnnpack::XNNGraph* graph) noexcept {
1782*523fa7a6SAndroid Build Coastguard Worker   MAYBE_UNUSED(graph);
1783*523fa7a6SAndroid Build Coastguard Worker 
1784*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1785*523fa7a6SAndroid Build Coastguard Worker       false,
1786*523fa7a6SAndroid Build Coastguard Worker       NotImplemented,
1787*523fa7a6SAndroid Build Coastguard Worker       "Unhandled node type: %s",
1788*523fa7a6SAndroid Build Coastguard Worker       fb_xnnpack::EnumNameXNodeUnion(node->xnode_union_type()));
1789*523fa7a6SAndroid Build Coastguard Worker }
1790*523fa7a6SAndroid Build Coastguard Worker 
1791*523fa7a6SAndroid Build Coastguard Worker /*
1792*523fa7a6SAndroid Build Coastguard Worker Returns the pointer to the defineNode function that handles the given
1793*523fa7a6SAndroid Build Coastguard Worker XNode type
1794*523fa7a6SAndroid Build Coastguard Worker */
1795*523fa7a6SAndroid Build Coastguard Worker #define _DEFINE(name)                     \
1796*523fa7a6SAndroid Build Coastguard Worker   case fb_xnnpack::XNodeUnion::XNN##name: \
1797*523fa7a6SAndroid Build Coastguard Worker     return &define##name##Node;
1798*523fa7a6SAndroid Build Coastguard Worker 
getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType)1799*523fa7a6SAndroid Build Coastguard Worker DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
1800*523fa7a6SAndroid Build Coastguard Worker   switch (nodeType) {
1801*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Add)
1802*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(FullyConnected)
1803*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Softmax)
1804*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Sigmoid)
1805*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(StaticTranspose)
1806*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Clamp)
1807*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Conv2d)
1808*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Div)
1809*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(StaticResizeBilinear2D)
1810*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(StaticConstantPad)
1811*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(AvgPooling2d)
1812*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Minimum)
1813*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(DepthwiseConv2d)
1814*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(MaxPooling2d)
1815*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Multiply)
1816*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Subtract)
1817*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Floor)
1818*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Convert)
1819*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(GlobalAvgPooling2d)
1820*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(StaticReshape)
1821*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(ArgMaxPooling2d)
1822*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(SquareRoot)
1823*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Ceiling)
1824*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Hardswish)
1825*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(LeakyReLU)
1826*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Maximum)
1827*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Negate)
1828*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Square)
1829*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(ELU)
1830*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Abs)
1831*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(PReLU)
1832*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Concatenate2)
1833*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Concatenate3)
1834*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(Concatenate4)
1835*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(StaticSlice)
1836*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(ScaledDotProductAttention)
1837*523fa7a6SAndroid Build Coastguard Worker     _DEFINE(BatchMatrixMultiply)
1838*523fa7a6SAndroid Build Coastguard Worker     case fb_xnnpack::XNodeUnion::NONE:
1839*523fa7a6SAndroid Build Coastguard Worker     default: // Adding here as a catch all, just in case
1840*523fa7a6SAndroid Build Coastguard Worker       return &defineNotImplementedNode;
1841*523fa7a6SAndroid Build Coastguard Worker   }
1842*523fa7a6SAndroid Build Coastguard Worker }
1843*523fa7a6SAndroid Build Coastguard Worker #undef _DEFINE
1844*523fa7a6SAndroid Build Coastguard Worker 
1845*523fa7a6SAndroid Build Coastguard Worker /*
1846*523fa7a6SAndroid Build Coastguard Worker Builds the xnnpack runtime object using the buffer pointer. The buffer pointer
1847*523fa7a6SAndroid Build Coastguard Worker must be a valid pointer to the serialized xnnpack object. It also fills the
1848*523fa7a6SAndroid Build Coastguard Worker XNNExecutor object with the built xnn_runtime and the input/output ids.
1849*523fa7a6SAndroid Build Coastguard Worker */
compileModel(const void * buffer_pointer,size_t num_bytes,XNNExecutor * executor,MemoryAllocator * runtime_allocator,xnn_workspace_t workspace)1850*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD Error XNNCompiler::compileModel(
1851*523fa7a6SAndroid Build Coastguard Worker     const void* buffer_pointer,
1852*523fa7a6SAndroid Build Coastguard Worker     size_t num_bytes,
1853*523fa7a6SAndroid Build Coastguard Worker     XNNExecutor* executor,
1854*523fa7a6SAndroid Build Coastguard Worker     MemoryAllocator* runtime_allocator,
1855*523fa7a6SAndroid Build Coastguard Worker     xnn_workspace_t workspace) {
1856*523fa7a6SAndroid Build Coastguard Worker   Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
1857*523fa7a6SAndroid Build Coastguard Worker   const uint8_t* flatbuffer_data = nullptr;
1858*523fa7a6SAndroid Build Coastguard Worker   const uint8_t* constant_data = nullptr;
1859*523fa7a6SAndroid Build Coastguard Worker   CompileAllocator compile_allocator;
1860*523fa7a6SAndroid Build Coastguard Worker 
1861*523fa7a6SAndroid Build Coastguard Worker   // Header status can only either be Error::Ok or Error::NotFound
1862*523fa7a6SAndroid Build Coastguard Worker   if (header.ok()) {
1863*523fa7a6SAndroid Build Coastguard Worker     flatbuffer_data = reinterpret_cast<const uint8_t*>(buffer_pointer) +
1864*523fa7a6SAndroid Build Coastguard Worker         header->flatbuffer_offset;
1865*523fa7a6SAndroid Build Coastguard Worker     constant_data = reinterpret_cast<const uint8_t*>(buffer_pointer) +
1866*523fa7a6SAndroid Build Coastguard Worker         header->constant_data_offset;
1867*523fa7a6SAndroid Build Coastguard Worker   } else if (header.error() == Error::NotFound) {
1868*523fa7a6SAndroid Build Coastguard Worker     flatbuffer_data = reinterpret_cast<const uint8_t*>(buffer_pointer);
1869*523fa7a6SAndroid Build Coastguard Worker   } else {
1870*523fa7a6SAndroid Build Coastguard Worker     ET_LOG(Error, "XNNHeader may be corrupt");
1871*523fa7a6SAndroid Build Coastguard Worker     return header.error();
1872*523fa7a6SAndroid Build Coastguard Worker   }
1873*523fa7a6SAndroid Build Coastguard Worker 
1874*523fa7a6SAndroid Build Coastguard Worker   // Temporarily support identifier XN00 and XN01
1875*523fa7a6SAndroid Build Coastguard Worker   bool is_supported_version =
1876*523fa7a6SAndroid Build Coastguard Worker       strncmp(flatbuffers::GetBufferIdentifier(flatbuffer_data), "XN00", 4) ==
1877*523fa7a6SAndroid Build Coastguard Worker           0 ||
1878*523fa7a6SAndroid Build Coastguard Worker       strncmp(flatbuffers::GetBufferIdentifier(flatbuffer_data), "XN01", 4) ==
1879*523fa7a6SAndroid Build Coastguard Worker           0;
1880*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1881*523fa7a6SAndroid Build Coastguard Worker       is_supported_version,
1882*523fa7a6SAndroid Build Coastguard Worker       DelegateInvalidCompatibility,
1883*523fa7a6SAndroid Build Coastguard Worker       "XNNPACK Delegate Serialization Format version identifier '%.4s' != expected XN00 or XN01'",
1884*523fa7a6SAndroid Build Coastguard Worker       flatbuffers::GetBufferIdentifier(flatbuffer_data));
1885*523fa7a6SAndroid Build Coastguard Worker 
1886*523fa7a6SAndroid Build Coastguard Worker   auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(flatbuffer_data);
1887*523fa7a6SAndroid Build Coastguard Worker   // initialize xnnpack
1888*523fa7a6SAndroid Build Coastguard Worker   xnn_status status = xnn_initialize(/*allocator =*/nullptr);
1889*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1890*523fa7a6SAndroid Build Coastguard Worker       xnn_status_success == status,
1891*523fa7a6SAndroid Build Coastguard Worker       Internal,
1892*523fa7a6SAndroid Build Coastguard Worker       "XNN Initialize failed with code: %s",
1893*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1894*523fa7a6SAndroid Build Coastguard Worker 
1895*523fa7a6SAndroid Build Coastguard Worker   // create xnnpack subgraph
1896*523fa7a6SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph_ptr = nullptr;
1897*523fa7a6SAndroid Build Coastguard Worker   status = xnn_create_subgraph(
1898*523fa7a6SAndroid Build Coastguard Worker       /*external_value_ids=*/flatbuffer_graph->num_externs(),
1899*523fa7a6SAndroid Build Coastguard Worker       /*flags=*/0,
1900*523fa7a6SAndroid Build Coastguard Worker       &subgraph_ptr);
1901*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1902*523fa7a6SAndroid Build Coastguard Worker       xnn_status_success == status,
1903*523fa7a6SAndroid Build Coastguard Worker       Internal,
1904*523fa7a6SAndroid Build Coastguard Worker       "XNN Subgraph creation failed with code: %s",
1905*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1906*523fa7a6SAndroid Build Coastguard Worker 
1907*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph(
1908*523fa7a6SAndroid Build Coastguard Worker       subgraph_ptr, &xnn_delete_subgraph);
1909*523fa7a6SAndroid Build Coastguard Worker 
1910*523fa7a6SAndroid Build Coastguard Worker   // mapping from old ids to new created value ids
1911*523fa7a6SAndroid Build Coastguard Worker   // The old ids that were serialied were generated AoT, since
1912*523fa7a6SAndroid Build Coastguard Worker   // we are re-defining tensor values, the defined IDs could be
1913*523fa7a6SAndroid Build Coastguard Worker   // different from the ones generated AoT, as a result, we need
1914*523fa7a6SAndroid Build Coastguard Worker   // a new mapping from the old ids to the newly created ones
1915*523fa7a6SAndroid Build Coastguard Worker   std::unordered_map<uint32_t, uint32_t> remapped_ids;
1916*523fa7a6SAndroid Build Coastguard Worker   // Invalid ids do not need to be remapped
1917*523fa7a6SAndroid Build Coastguard Worker   remapped_ids.emplace(XNN_INVALID_VALUE_ID, XNN_INVALID_VALUE_ID);
1918*523fa7a6SAndroid Build Coastguard Worker 
1919*523fa7a6SAndroid Build Coastguard Worker   // External Ids for inputs and outputs
1920*523fa7a6SAndroid Build Coastguard Worker   std::vector<uint32_t> input_ids;
1921*523fa7a6SAndroid Build Coastguard Worker   std::vector<uint32_t> output_ids;
1922*523fa7a6SAndroid Build Coastguard Worker   Error err = Error::Ok;
1923*523fa7a6SAndroid Build Coastguard Worker   for (auto value : *flatbuffer_graph->xvalues()) {
1924*523fa7a6SAndroid Build Coastguard Worker     err = defineTensor(
1925*523fa7a6SAndroid Build Coastguard Worker         subgraph.get(),
1926*523fa7a6SAndroid Build Coastguard Worker         remapped_ids,
1927*523fa7a6SAndroid Build Coastguard Worker         value,
1928*523fa7a6SAndroid Build Coastguard Worker         flatbuffer_graph,
1929*523fa7a6SAndroid Build Coastguard Worker         constant_data,
1930*523fa7a6SAndroid Build Coastguard Worker         input_ids,
1931*523fa7a6SAndroid Build Coastguard Worker         output_ids,
1932*523fa7a6SAndroid Build Coastguard Worker         compile_allocator);
1933*523fa7a6SAndroid Build Coastguard Worker 
1934*523fa7a6SAndroid Build Coastguard Worker     if (err != Error::Ok) {
1935*523fa7a6SAndroid Build Coastguard Worker       return err;
1936*523fa7a6SAndroid Build Coastguard Worker     }
1937*523fa7a6SAndroid Build Coastguard Worker   }
1938*523fa7a6SAndroid Build Coastguard Worker 
1939*523fa7a6SAndroid Build Coastguard Worker   for (auto node : *flatbuffer_graph->xnodes()) {
1940*523fa7a6SAndroid Build Coastguard Worker     err = getDefineNodeFunc(node->xnode_union_type())(
1941*523fa7a6SAndroid Build Coastguard Worker         subgraph.get(), remapped_ids, node, flatbuffer_graph);
1942*523fa7a6SAndroid Build Coastguard Worker     if (err != Error::Ok) {
1943*523fa7a6SAndroid Build Coastguard Worker       return err;
1944*523fa7a6SAndroid Build Coastguard Worker     }
1945*523fa7a6SAndroid Build Coastguard Worker   }
1946*523fa7a6SAndroid Build Coastguard Worker   uint32_t runtime_flags = 0;
1947*523fa7a6SAndroid Build Coastguard Worker 
1948*523fa7a6SAndroid Build Coastguard Worker #if defined(ENABLE_XNNPACK_PROFILING) || defined(ET_EVENT_TRACER_ENABLED)
1949*523fa7a6SAndroid Build Coastguard Worker   runtime_flags |= XNN_FLAG_BASIC_PROFILING;
1950*523fa7a6SAndroid Build Coastguard Worker #endif
1951*523fa7a6SAndroid Build Coastguard Worker 
1952*523fa7a6SAndroid Build Coastguard Worker   xnn_runtime_t runtime_ptr = nullptr;
1953*523fa7a6SAndroid Build Coastguard Worker 
1954*523fa7a6SAndroid Build Coastguard Worker #ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
1955*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1956*523fa7a6SAndroid Build Coastguard Worker       workspace != nullptr, Internal, "Failed to initialize XNNPACK workspace");
1957*523fa7a6SAndroid Build Coastguard Worker   status = xnn_create_runtime_v4(
1958*523fa7a6SAndroid Build Coastguard Worker       subgraph.get(),
1959*523fa7a6SAndroid Build Coastguard Worker       /*weight_cache=*/nullptr, // TODO - support weight cache
1960*523fa7a6SAndroid Build Coastguard Worker       workspace,
1961*523fa7a6SAndroid Build Coastguard Worker       ::executorch::extension::threadpool::get_pthreadpool(),
1962*523fa7a6SAndroid Build Coastguard Worker       runtime_flags,
1963*523fa7a6SAndroid Build Coastguard Worker       &runtime_ptr);
1964*523fa7a6SAndroid Build Coastguard Worker #else
1965*523fa7a6SAndroid Build Coastguard Worker   status = xnn_create_runtime_v3(
1966*523fa7a6SAndroid Build Coastguard Worker       subgraph.get(),
1967*523fa7a6SAndroid Build Coastguard Worker       /*weight_cache=*/nullptr, // TODO - support weight cache
1968*523fa7a6SAndroid Build Coastguard Worker       ::executorch::extension::threadpool::get_pthreadpool(),
1969*523fa7a6SAndroid Build Coastguard Worker       runtime_flags,
1970*523fa7a6SAndroid Build Coastguard Worker       &runtime_ptr);
1971*523fa7a6SAndroid Build Coastguard Worker #endif
1972*523fa7a6SAndroid Build Coastguard Worker 
1973*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
1974*523fa7a6SAndroid Build Coastguard Worker       xnn_status_success == status,
1975*523fa7a6SAndroid Build Coastguard Worker       Internal,
1976*523fa7a6SAndroid Build Coastguard Worker       "XNN Runtime creation failed with code: %s",
1977*523fa7a6SAndroid Build Coastguard Worker       xnn_status_to_string(status));
1978*523fa7a6SAndroid Build Coastguard Worker 
1979*523fa7a6SAndroid Build Coastguard Worker   err = executor->initialize( // NOLINT: runtime_ptr is non-null
1980*523fa7a6SAndroid Build Coastguard Worker       runtime_ptr,
1981*523fa7a6SAndroid Build Coastguard Worker       std::move(input_ids),
1982*523fa7a6SAndroid Build Coastguard Worker       std::move(output_ids));
1983*523fa7a6SAndroid Build Coastguard Worker 
1984*523fa7a6SAndroid Build Coastguard Worker   return err;
1985*523fa7a6SAndroid Build Coastguard Worker };
1986*523fa7a6SAndroid Build Coastguard Worker 
1987*523fa7a6SAndroid Build Coastguard Worker } // namespace delegate
1988*523fa7a6SAndroid Build Coastguard Worker } // namespace xnnpack
1989*523fa7a6SAndroid Build Coastguard Worker } // namespace backends
1990*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
1991