// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <caffe2/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h>

#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <unordered_map>
#include <vector>

#include <ATen/Utils.h>
#include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>

11 namespace torch {
12 namespace jit {
13 namespace xnnpack {
14 namespace delegate {
15 
compileModel(const void * buffer_pointer,size_t num_bytes,XNNExecutor * executor)16 void XNNCompiler::compileModel(
17     const void* buffer_pointer,
18     size_t num_bytes,
19     XNNExecutor* executor) {
20   auto output_min = -std::numeric_limits<float>::infinity();
21   auto output_max = std::numeric_limits<float>::infinity();
22 
23   auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(buffer_pointer);
24   // initialize xnnpack
25   xnn_status status = xnn_initialize(/*allocator =*/nullptr);
26   TORCH_CHECK(xnn_status_success == status, "Failed to initialize xnnpack");
27 
28   // create xnnpack subgraph
29   xnn_subgraph_t subgraph_ptr = nullptr;
30   status = xnn_create_subgraph(
31       /*external_value_ids=*/flatbuffer_graph->num_externs(),
32       /*flags=*/0,
33       &subgraph_ptr);
34   TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph");
35 
36   // mapping from old ids to new created value ids
37   // The old ids that were serialied were generated AoT, since
38   // we are re-defining tensor values, the defined IDs could be
39   // different from the ones generated AoT, as a result, we need
40   // a new mapping from the old ids to the newly created ones
41   std::unordered_map<uint32_t, uint32_t> remapped_ids;
42 
43   for (auto value : *flatbuffer_graph->xvalues()) {
44     switch (value->xvalue_type()) {
45       case fb_xnnpack::XValueUnion::XNNTensorValue: {
46         auto tensor_value = value->xvalue_as_XNNTensorValue();
47 
48         std::vector<size_t> dims_data;
49         for (auto dim : *tensor_value->dims()) {
50           dims_data.push_back(static_cast<size_t>(dim));
51         }
52 
53         uint32_t id = XNN_INVALID_VALUE_ID;
54         const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
55         auto buffer_idx = tensor_value->constant_buffer_idx();
56         const auto buffer_ptr = buffer_idx == 0
57             ? nullptr
58             : constant_buffer[buffer_idx]->storage()->data();
59         status = xnn_define_tensor_value(
60             /*subgraph=*/subgraph_ptr,
61             /*datatype=*/xnn_datatype_fp32,
62             /*num_dims=*/tensor_value->num_dims(),
63             /*dims=*/dims_data.data(),
64             /*data=*/buffer_ptr,
65             /*external_id=*/tensor_value->external_id(),
66             /*flags=*/tensor_value->flags(),
67             /*id_out=*/&id);
68         TORCH_CHECK(
69             status == xnn_status_success,
70             "Failed to define tensor values in graph")
71         // map serialized id to newly generated id
72         remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id));
73         break;
74       }
75       default: {
76         TORCH_CHECK(false, "Unhandled value type found in deserialization");
77       }
78     }
79   }
80 
81   for (auto node : *flatbuffer_graph->xnodes()) {
82     switch (node->xnode_type()) {
83       case fb_xnnpack::XNodeUnion::XNNAdd: {
84         auto graph_node = node->xnode_as_XNNAdd();
85         status = xnn_define_add2(
86             subgraph_ptr,
87             output_min,
88             output_max,
89             remapped_ids.at(graph_node->input1_id()),
90             remapped_ids.at(graph_node->input2_id()),
91             remapped_ids.at(graph_node->output_id()),
92             graph_node->flags());
93         TORCH_CHECK(status == xnn_status_success, "Failed to create add node")
94         break;
95       }
96       default:
97         TORCH_CHECK(false, "Unhandled node type found in deserialization");
98     }
99   }
100 
101   xnn_runtime_t runtime_ptr = nullptr;
102   status = xnn_create_runtime_v2(subgraph_ptr, nullptr, 0, &runtime_ptr);
103   TORCH_CHECK(xnn_status_success == status);
104 
105   executor->runtime_ =
106       std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)>(
107           runtime_ptr, xnn_delete_runtime);
108 
109   for (auto old_id : *flatbuffer_graph->input_ids()) {
110     executor->input_ids_.emplace_back(remapped_ids.at(old_id));
111   }
112 
113   for (auto old_id : *flatbuffer_graph->output_ids()) {
114     executor->output_ids_.emplace_back(remapped_ids.at(old_id));
115   }
116 };
117 
118 } // namespace delegate
119 } // namespace xnnpack
120 } // namespace jit
121 } // namespace torch
122