xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <caffe2/torch/csrc/jit/backends/xnnpack/serialization/serializer.h>
7 #include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>
8 
9 #include <sstream>
10 
11 namespace torch {
12 namespace jit {
13 namespace xnnpack {
14 namespace delegate {
15 
16 using namespace fb_xnnpack;
17 
serializeAddNode(uint32_t input1_id,uint32_t input2_id,uint32_t output_id,uint32_t flags)18 void XNNSerializer::serializeAddNode(
19     uint32_t input1_id,
20     uint32_t input2_id,
21     uint32_t output_id,
22     uint32_t flags) {
23   const auto addNode =
24       CreateXNNAdd(_builder, input1_id, input2_id, output_id, flags);
25   const auto flatbufferNode =
26       CreateXNode(_builder, XNodeUnion::XNNAdd, addNode.Union());
27   _nodes.push_back(flatbufferNode);
28 }
29 
serializeData(const uint8_t * data_ptr,size_t num_bytes)30 size_t XNNSerializer::serializeData(const uint8_t* data_ptr, size_t num_bytes) {
31   size_t constant_buffer_idx = 0;
32   // Handling the tensor _values with data
33   if (data_ptr != nullptr) {
34     // steps:
35     // 1. creating flatbuffer byte-vector for tensor data
36     auto storage = _builder.CreateVector(data_ptr, num_bytes);
37 
38     // 2. put it in the common buffer
39     constant_buffer_idx = _constantBuffer.size();
40     _constantBuffer.emplace_back(CreateBuffer(_builder, storage));
41 
42     // 3. record size into bufferSizes
43     _bufferSizes.push_back(num_bytes);
44     assert(_bufferSizes.size() == _constantBuffer.size());
45   }
46   return constant_buffer_idx;
47 }
48 
serializeTensorValue(uint32_t xnn_datatype,size_t num_dims,std::vector<size_t> dims,size_t data_buffer_idx,uint32_t external_id,uint32_t flags,uint32_t id_out)49 void XNNSerializer::serializeTensorValue(
50     uint32_t xnn_datatype,
51     size_t num_dims,
52     std::vector<size_t> dims,
53     size_t data_buffer_idx,
54     uint32_t external_id,
55     uint32_t flags,
56     uint32_t id_out) {
57   std::vector<uint32_t> serialized_dims;
58   serialized_dims.reserve(dims.size());
59   for (auto dim : dims) {
60     serialized_dims.push_back(static_cast<uint32_t>(dim));
61   }
62 
63   const auto tensorValue = CreateXNNTensorValueDirect(
64       _builder,
65       XNNDatatype(xnn_datatype),
66       num_dims,
67       &serialized_dims,
68       data_buffer_idx,
69       external_id,
70       flags,
71       id_out);
72 
73   const auto flatbufferValue =
74       CreateXValue(_builder, XValueUnion::XNNTensorValue, tensorValue.Union());
75   _values.push_back(flatbufferValue);
76 }
77 
finishAndSerialize(std::vector<uint32_t> input_ids,std::vector<uint32_t> output_ids,size_t num_extern_ids)78 std::string XNNSerializer::finishAndSerialize(
79     std::vector<uint32_t> input_ids,
80     std::vector<uint32_t> output_ids,
81     size_t num_extern_ids) {
82   auto xnnGraph = CreateXNNGraphDirect(
83       _builder,
84       _version_sha1,
85       &_nodes,
86       &_values,
87       num_extern_ids,
88       &input_ids,
89       &output_ids,
90       &_constantBuffer,
91       &_bufferSizes);
92 
93   _builder.Finish(xnnGraph);
94 
95   std::stringstream ss;
96   ss.write(
97       reinterpret_cast<char*>(_builder.GetBufferPointer()), _builder.GetSize());
98 
99   return ss.str();
100 }
101 
102 } // namespace delegate
103 } // namespace xnnpack
104 } // namespace jit
105 } // namespace torch
106