1 // Copyright (c) Meta Platforms, Inc. and affiliates. 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h> 7 #include <cstddef> 8 #include <cstdint> 9 #include <string> 10 #include <vector> 11 12 namespace torch { 13 namespace jit { 14 namespace xnnpack { 15 namespace delegate { 16 17 using namespace fb_xnnpack; // Specified in the schema 18 19 class XNNSerializer { 20 public: 21 // Constructors 22 // initial buffersize of 1024 which will grow 23 // automatically, constant buffer and buffer sizes initialized with dummy 24 // values as 0 index is reserved for non-constant tensors XNNSerializer()25 XNNSerializer() : XNNSerializer(1024) {} 26 XNNSerializer(size_t bufferSize)27 explicit XNNSerializer(size_t bufferSize) 28 : _builder(bufferSize), 29 _nodes(), 30 _values(), 31 _constantBuffer({CreateBuffer( 32 _builder, 33 {})}), // index 0 is reserved for non-const data 34 _bufferSizes({0}) {} 35 36 // Serializing Nodes 37 38 // Serialize add node, we are serializing the argument needed to call 39 // xnn_define_add2. Serializing these values, and at run time we build 40 // teh graph by re running xnn_define_add2 41 void serializeAddNode( 42 uint32_t input1_id, 43 uint32_t input2_id, 44 uint32_t output_id, 45 uint32_t flags); 46 47 // Serializing Values 48 void serializeTensorValue( 49 uint32_t xnn_datatype, 50 size_t num_dims, 51 std::vector<size_t> dims, 52 size_t buffer_data_idx, 53 uint32_t external_id, 54 uint32_t flags, 55 uint32_t id_out); 56 57 // finish and serialize xnngraph returning serialized data 58 std::string finishAndSerialize( 59 std::vector<uint32_t> input_ids, 60 std::vector<uint32_t> output_ids, 61 size_t num_extern_ids); 62 63 // decoupled data serialization with tensor values. This way constant tensor 64 // data can be referenced by multiple intermediate tensors. This call 65 // serializes the num_bytes of the data_ptr and returns the index it was 66 // placed in. 67 size_t serializeData(const uint8_t* data_ptr, size_t num_bytes); 68 69 private: 70 // xnnpack version we are serializing 71 const char* _version_sha1 = "ae108ef49aa5623b896fc93d4298c49d1750d9ba"; 72 73 // flatbuffer objects we will create and serialize together to create xnngraph 74 flatbuffers_fbsource::FlatBufferBuilder _builder; 75 76 // Vector of the serialized xnnpack nodes 77 std::vector<flatbuffers_fbsource::Offset<XNode>> _nodes; 78 79 // Vector of the serialized xnnpack values 80 std::vector<flatbuffers_fbsource::Offset<XValue>> _values; 81 82 std::vector<flatbuffers_fbsource::Offset<Buffer>> _constantBuffer; 83 std::vector<uint32_t> _bufferSizes; 84 }; 85 86 } // namespace delegate 87 } // namespace xnnpack 88 } // namespace jit 89 } // namespace torch 90