1 #include <torch/csrc/jit/passes/replacement_of_old_operators.h>
2
3 #include <c10/util/Exception.h>
4 #include <caffe2/serialize/versions.h>
5 #include <torch/csrc/jit/frontend/schema_matching.h>
6 #include <torch/csrc/jit/ir/irparser.h>
7 #include <torch/csrc/jit/operator_upgraders/upgraders.h>
8 #include <torch/csrc/jit/operator_upgraders/utils.h>
9 #include <torch/csrc/jit/operator_upgraders/version_map.h>
10 #include <torch/csrc/jit/runtime/graph_iterator.h>
11 #include <limits>
12 #include <string>
13 #include <unordered_map>
14 #include <utility>
15
16 namespace torch::jit {
17
18 struct OldOpsReplacerWithUpgraders {
OldOpsReplacerWithUpgraderstorch::jit::OldOpsReplacerWithUpgraders19 OldOpsReplacerWithUpgraders(std::shared_ptr<Graph> graph)
20 : graph_(std::move(graph)) {}
21
runtorch::jit::OldOpsReplacerWithUpgraders22 void run() {
23 if (!graph_->get_op_version().has_value()) {
24 return;
25 }
26
27 auto current_version = graph_->get_op_version().value();
28 DepthFirstGraphNodeIterator graph_it(graph_);
29 Node* node = graph_it.next();
30 while (node) {
31 // load the schema name for this op
32 std::optional<std::string> schema_name = std::nullopt;
33 if (auto op_schema = node->maybeSchema()) {
34 schema_name = getFullSchemaName(*op_schema);
35 } else {
36 schema_name = node->getHistoricSchemaName();
37 }
38
39 if (schema_name.has_value()) {
40 // this implies there was a version bump because of this operator
41 auto version_entry =
42 get_operator_version_map().find(schema_name.value());
43 if (version_entry != get_operator_version_map().end()) {
44 const auto& entry = version_entry->second;
45 auto upgrader_entry = findUpgrader(entry, current_version);
46 if (!upgrader_entry.has_value()) {
47 if (!isOpSymbolCurrent(schema_name.value(), current_version)) {
48 TORCH_INTERNAL_ASSERT(
49 false,
50 "Upgrader must be present for ",
51 schema_name.value(),
52 ". The upgrader might have deprecated");
53 }
54 node = graph_it.next();
55 continue;
56 }
57 auto upgrader_entry_val = upgrader_entry.value();
58 auto upgrader_name = upgrader_entry_val.upgrader_name;
59 auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
60 TORCH_INTERNAL_ASSERT(
61 upgrader_graph_entry != dump_upgraders_map().end(),
62 "Corresponding upgrader graph for ",
63 upgrader_name,
64 " must exist.",
65 " This upgrader"
66 " might be deprecated.");
67
68 auto upgrader_graph = upgrader_graph_entry->second;
69 // inline the upgrader function body
70 WithInsertPoint guard(node);
71 auto new_outputs = insertGraph(
72 *node->owningGraph(), *upgrader_graph, node->inputs());
73 const auto& old_outputs = node->outputs();
74 TORCH_INTERNAL_ASSERT(new_outputs.size() == old_outputs.size());
75 for (const auto i : c10::irange(old_outputs.size())) {
76 TORCH_INTERNAL_ASSERT(
77 new_outputs[i]->type() == old_outputs[i]->type())
78 old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
79 }
80 node->removeAllInputs();
81 node->destroy();
82 }
83 }
84 node = graph_it.next();
85 }
86
87 // now that we updated the graph, we want to bump the
88 // graph version too.
89 graph_->set_op_version(caffe2::serialize::kProducedFileFormatVersion);
90 }
91
92 std::shared_ptr<Graph> graph_;
93 };
94
ReplaceOldOperatorsWithUpgraders(std::shared_ptr<Graph> graph)95 TORCH_API void ReplaceOldOperatorsWithUpgraders(std::shared_ptr<Graph> graph) {
96 OldOpsReplacerWithUpgraders(std::move(graph)).run();
97 }
98
99 } // namespace torch::jit
100