1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <caffe2/torch/csrc/jit/backends/xnnpack/serialization/serializer.h>
7 #include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>
8
9 #include <sstream>
10
11 namespace torch {
12 namespace jit {
13 namespace xnnpack {
14 namespace delegate {
15
16 using namespace fb_xnnpack;
17
serializeAddNode(uint32_t input1_id,uint32_t input2_id,uint32_t output_id,uint32_t flags)18 void XNNSerializer::serializeAddNode(
19 uint32_t input1_id,
20 uint32_t input2_id,
21 uint32_t output_id,
22 uint32_t flags) {
23 const auto addNode =
24 CreateXNNAdd(_builder, input1_id, input2_id, output_id, flags);
25 const auto flatbufferNode =
26 CreateXNode(_builder, XNodeUnion::XNNAdd, addNode.Union());
27 _nodes.push_back(flatbufferNode);
28 }
29
serializeData(const uint8_t * data_ptr,size_t num_bytes)30 size_t XNNSerializer::serializeData(const uint8_t* data_ptr, size_t num_bytes) {
31 size_t constant_buffer_idx = 0;
32 // Handling the tensor _values with data
33 if (data_ptr != nullptr) {
34 // steps:
35 // 1. creating flatbuffer byte-vector for tensor data
36 auto storage = _builder.CreateVector(data_ptr, num_bytes);
37
38 // 2. put it in the common buffer
39 constant_buffer_idx = _constantBuffer.size();
40 _constantBuffer.emplace_back(CreateBuffer(_builder, storage));
41
42 // 3. record size into bufferSizes
43 _bufferSizes.push_back(num_bytes);
44 assert(_bufferSizes.size() == _constantBuffer.size());
45 }
46 return constant_buffer_idx;
47 }
48
serializeTensorValue(uint32_t xnn_datatype,size_t num_dims,std::vector<size_t> dims,size_t data_buffer_idx,uint32_t external_id,uint32_t flags,uint32_t id_out)49 void XNNSerializer::serializeTensorValue(
50 uint32_t xnn_datatype,
51 size_t num_dims,
52 std::vector<size_t> dims,
53 size_t data_buffer_idx,
54 uint32_t external_id,
55 uint32_t flags,
56 uint32_t id_out) {
57 std::vector<uint32_t> serialized_dims;
58 serialized_dims.reserve(dims.size());
59 for (auto dim : dims) {
60 serialized_dims.push_back(static_cast<uint32_t>(dim));
61 }
62
63 const auto tensorValue = CreateXNNTensorValueDirect(
64 _builder,
65 XNNDatatype(xnn_datatype),
66 num_dims,
67 &serialized_dims,
68 data_buffer_idx,
69 external_id,
70 flags,
71 id_out);
72
73 const auto flatbufferValue =
74 CreateXValue(_builder, XValueUnion::XNNTensorValue, tensorValue.Union());
75 _values.push_back(flatbufferValue);
76 }
77
finishAndSerialize(std::vector<uint32_t> input_ids,std::vector<uint32_t> output_ids,size_t num_extern_ids)78 std::string XNNSerializer::finishAndSerialize(
79 std::vector<uint32_t> input_ids,
80 std::vector<uint32_t> output_ids,
81 size_t num_extern_ids) {
82 auto xnnGraph = CreateXNNGraphDirect(
83 _builder,
84 _version_sha1,
85 &_nodes,
86 &_values,
87 num_extern_ids,
88 &input_ids,
89 &output_ids,
90 &_constantBuffer,
91 &_bufferSizes);
92
93 _builder.Finish(xnnGraph);
94
95 std::stringstream ss;
96 ss.write(
97 reinterpret_cast<char*>(_builder.GetBufferPointer()), _builder.GetSize());
98
99 return ss.str();
100 }
101
102 } // namespace delegate
103 } // namespace xnnpack
104 } // namespace jit
105 } // namespace torch
106