xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/replacement_of_old_operators.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/replacement_of_old_operators.h>
2 
3 #include <c10/util/Exception.h>
4 #include <caffe2/serialize/versions.h>
5 #include <torch/csrc/jit/frontend/schema_matching.h>
6 #include <torch/csrc/jit/ir/irparser.h>
7 #include <torch/csrc/jit/operator_upgraders/upgraders.h>
8 #include <torch/csrc/jit/operator_upgraders/utils.h>
9 #include <torch/csrc/jit/operator_upgraders/version_map.h>
10 #include <torch/csrc/jit/runtime/graph_iterator.h>
11 #include <limits>
12 #include <string>
13 #include <unordered_map>
14 #include <utility>
15 
16 namespace torch::jit {
17 
18 struct OldOpsReplacerWithUpgraders {
OldOpsReplacerWithUpgraderstorch::jit::OldOpsReplacerWithUpgraders19   OldOpsReplacerWithUpgraders(std::shared_ptr<Graph> graph)
20       : graph_(std::move(graph)) {}
21 
runtorch::jit::OldOpsReplacerWithUpgraders22   void run() {
23     if (!graph_->get_op_version().has_value()) {
24       return;
25     }
26 
27     auto current_version = graph_->get_op_version().value();
28     DepthFirstGraphNodeIterator graph_it(graph_);
29     Node* node = graph_it.next();
30     while (node) {
31       // load the schema name for this op
32       std::optional<std::string> schema_name = std::nullopt;
33       if (auto op_schema = node->maybeSchema()) {
34         schema_name = getFullSchemaName(*op_schema);
35       } else {
36         schema_name = node->getHistoricSchemaName();
37       }
38 
39       if (schema_name.has_value()) {
40         // this implies there was a version bump because of this operator
41         auto version_entry =
42             get_operator_version_map().find(schema_name.value());
43         if (version_entry != get_operator_version_map().end()) {
44           const auto& entry = version_entry->second;
45           auto upgrader_entry = findUpgrader(entry, current_version);
46           if (!upgrader_entry.has_value()) {
47             if (!isOpSymbolCurrent(schema_name.value(), current_version)) {
48               TORCH_INTERNAL_ASSERT(
49                   false,
50                   "Upgrader must be present for ",
51                   schema_name.value(),
52                   ". The upgrader might have deprecated");
53             }
54             node = graph_it.next();
55             continue;
56           }
57           auto upgrader_entry_val = upgrader_entry.value();
58           auto upgrader_name = upgrader_entry_val.upgrader_name;
59           auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
60           TORCH_INTERNAL_ASSERT(
61               upgrader_graph_entry != dump_upgraders_map().end(),
62               "Corresponding upgrader graph for ",
63               upgrader_name,
64               " must exist.",
65               " This upgrader"
66               " might be deprecated.");
67 
68           auto upgrader_graph = upgrader_graph_entry->second;
69           // inline the upgrader function body
70           WithInsertPoint guard(node);
71           auto new_outputs = insertGraph(
72               *node->owningGraph(), *upgrader_graph, node->inputs());
73           const auto& old_outputs = node->outputs();
74           TORCH_INTERNAL_ASSERT(new_outputs.size() == old_outputs.size());
75           for (const auto i : c10::irange(old_outputs.size())) {
76             TORCH_INTERNAL_ASSERT(
77                 new_outputs[i]->type() == old_outputs[i]->type())
78             old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
79           }
80           node->removeAllInputs();
81           node->destroy();
82         }
83       }
84       node = graph_it.next();
85     }
86 
87     // now that we updated the graph, we want to bump the
88     // graph version too.
89     graph_->set_op_version(caffe2::serialize::kProducedFileFormatVersion);
90   }
91 
92   std::shared_ptr<Graph> graph_;
93 };
94 
ReplaceOldOperatorsWithUpgraders(std::shared_ptr<Graph> graph)95 TORCH_API void ReplaceOldOperatorsWithUpgraders(std::shared_ptr<Graph> graph) {
96   OldOpsReplacerWithUpgraders(std::move(graph)).run();
97 }
98 
99 } // namespace torch::jit
100