1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <caffe2/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h>
7 #include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>
8
9 #include <ATen/Utils.h>
10
11 namespace torch {
12 namespace jit {
13 namespace xnnpack {
14 namespace delegate {
15
compileModel(const void * buffer_pointer,size_t num_bytes,XNNExecutor * executor)16 void XNNCompiler::compileModel(
17 const void* buffer_pointer,
18 size_t num_bytes,
19 XNNExecutor* executor) {
20 auto output_min = -std::numeric_limits<float>::infinity();
21 auto output_max = std::numeric_limits<float>::infinity();
22
23 auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(buffer_pointer);
24 // initialize xnnpack
25 xnn_status status = xnn_initialize(/*allocator =*/nullptr);
26 TORCH_CHECK(xnn_status_success == status, "Failed to initialize xnnpack");
27
28 // create xnnpack subgraph
29 xnn_subgraph_t subgraph_ptr = nullptr;
30 status = xnn_create_subgraph(
31 /*external_value_ids=*/flatbuffer_graph->num_externs(),
32 /*flags=*/0,
33 &subgraph_ptr);
34 TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph");
35
36 // mapping from old ids to new created value ids
37 // The old ids that were serialied were generated AoT, since
38 // we are re-defining tensor values, the defined IDs could be
39 // different from the ones generated AoT, as a result, we need
40 // a new mapping from the old ids to the newly created ones
41 std::unordered_map<uint32_t, uint32_t> remapped_ids;
42
43 for (auto value : *flatbuffer_graph->xvalues()) {
44 switch (value->xvalue_type()) {
45 case fb_xnnpack::XValueUnion::XNNTensorValue: {
46 auto tensor_value = value->xvalue_as_XNNTensorValue();
47
48 std::vector<size_t> dims_data;
49 for (auto dim : *tensor_value->dims()) {
50 dims_data.push_back(static_cast<size_t>(dim));
51 }
52
53 uint32_t id = XNN_INVALID_VALUE_ID;
54 const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
55 auto buffer_idx = tensor_value->constant_buffer_idx();
56 const auto buffer_ptr = buffer_idx == 0
57 ? nullptr
58 : constant_buffer[buffer_idx]->storage()->data();
59 status = xnn_define_tensor_value(
60 /*subgraph=*/subgraph_ptr,
61 /*datatype=*/xnn_datatype_fp32,
62 /*num_dims=*/tensor_value->num_dims(),
63 /*dims=*/dims_data.data(),
64 /*data=*/buffer_ptr,
65 /*external_id=*/tensor_value->external_id(),
66 /*flags=*/tensor_value->flags(),
67 /*id_out=*/&id);
68 TORCH_CHECK(
69 status == xnn_status_success,
70 "Failed to define tensor values in graph")
71 // map serialized id to newly generated id
72 remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id));
73 break;
74 }
75 default: {
76 TORCH_CHECK(false, "Unhandled value type found in deserialization");
77 }
78 }
79 }
80
81 for (auto node : *flatbuffer_graph->xnodes()) {
82 switch (node->xnode_type()) {
83 case fb_xnnpack::XNodeUnion::XNNAdd: {
84 auto graph_node = node->xnode_as_XNNAdd();
85 status = xnn_define_add2(
86 subgraph_ptr,
87 output_min,
88 output_max,
89 remapped_ids.at(graph_node->input1_id()),
90 remapped_ids.at(graph_node->input2_id()),
91 remapped_ids.at(graph_node->output_id()),
92 graph_node->flags());
93 TORCH_CHECK(status == xnn_status_success, "Failed to create add node")
94 break;
95 }
96 default:
97 TORCH_CHECK(false, "Unhandled node type found in deserialization");
98 }
99 }
100
101 xnn_runtime_t runtime_ptr = nullptr;
102 status = xnn_create_runtime_v2(subgraph_ptr, nullptr, 0, &runtime_ptr);
103 TORCH_CHECK(xnn_status_success == status);
104
105 executor->runtime_ =
106 std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)>(
107 runtime_ptr, xnn_delete_runtime);
108
109 for (auto old_id : *flatbuffer_graph->input_ids()) {
110 executor->input_ids_.emplace_back(remapped_ids.at(old_id));
111 }
112
113 for (auto old_id : *flatbuffer_graph->output_ids()) {
114 executor->output_ids_.emplace_back(remapped_ids.at(old_id));
115 }
116 };
117
118 } // namespace delegate
119 } // namespace xnnpack
120 } // namespace jit
121 } // namespace torch
122