// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/xnnpack/serialization/serializer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5 
6 #include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>
7 #include <cstddef>
8 #include <cstdint>
9 #include <string>
10 #include <vector>
11 
12 namespace torch {
13 namespace jit {
14 namespace xnnpack {
15 namespace delegate {
16 
17 using namespace fb_xnnpack; // Specified in the schema
18 
19 class XNNSerializer {
20  public:
21   // Constructors
22   // initial buffersize of 1024 which will grow
23   // automatically, constant buffer and buffer sizes initialized with dummy
24   // values as 0 index is reserved for non-constant tensors
XNNSerializer()25   XNNSerializer() : XNNSerializer(1024) {}
26 
XNNSerializer(size_t bufferSize)27   explicit XNNSerializer(size_t bufferSize)
28       : _builder(bufferSize),
29         _nodes(),
30         _values(),
31         _constantBuffer({CreateBuffer(
32             _builder,
33             {})}), // index 0 is reserved for non-const data
34         _bufferSizes({0}) {}
35 
36   // Serializing Nodes
37 
38   // Serialize add node, we are serializing the argument needed to call
39   // xnn_define_add2. Serializing these values, and at run time we build
40   // teh graph by re running xnn_define_add2
41   void serializeAddNode(
42       uint32_t input1_id,
43       uint32_t input2_id,
44       uint32_t output_id,
45       uint32_t flags);
46 
47   // Serializing Values
48   void serializeTensorValue(
49       uint32_t xnn_datatype,
50       size_t num_dims,
51       std::vector<size_t> dims,
52       size_t buffer_data_idx,
53       uint32_t external_id,
54       uint32_t flags,
55       uint32_t id_out);
56 
57   // finish and serialize xnngraph returning serialized data
58   std::string finishAndSerialize(
59       std::vector<uint32_t> input_ids,
60       std::vector<uint32_t> output_ids,
61       size_t num_extern_ids);
62 
63   // decoupled data serialization with tensor values. This way constant tensor
64   // data can be referenced by multiple intermediate tensors. This call
65   // serializes the num_bytes of the data_ptr and returns the index it was
66   // placed in.
67   size_t serializeData(const uint8_t* data_ptr, size_t num_bytes);
68 
69  private:
70   // xnnpack version we are serializing
71   const char* _version_sha1 = "ae108ef49aa5623b896fc93d4298c49d1750d9ba";
72 
73   // flatbuffer objects we will create and serialize together to create xnngraph
74   flatbuffers_fbsource::FlatBufferBuilder _builder;
75 
76   // Vector of the serialized xnnpack nodes
77   std::vector<flatbuffers_fbsource::Offset<XNode>> _nodes;
78 
79   // Vector of the serialized xnnpack values
80   std::vector<flatbuffers_fbsource::Offset<XValue>> _values;
81 
82   std::vector<flatbuffers_fbsource::Offset<Buffer>> _constantBuffer;
83   std::vector<uint32_t> _bufferSizes;
84 };
85 
86 } // namespace delegate
87 } // namespace xnnpack
88 } // namespace jit
89 } // namespace torch
90