xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h>
2 
3 #include <torch/csrc/jit/ir/alias_analysis.h>
4 #include <torch/csrc/jit/jit_log.h>
5 #include <torch/csrc/jit/passes/quantization/helper.h>
6 #include <torch/csrc/jit/runtime/graph_iterator.h>
7 
8 namespace torch {
9 namespace jit {
10 
11 namespace {
12 
DBRQuantRemoveRedundantAliasesImpl(const Method & method)13 void DBRQuantRemoveRedundantAliasesImpl(const Method& method) {
14   auto g = method.graph();
15   const bool is_frozen = false;
16   const bool descend_function_calls = true;
17   AliasDb alias_db(g, is_frozen, descend_function_calls);
18   // find the alias nodes
19   std::vector<Node*> alias_nodes;
20   DepthFirstGraphNodeIterator it(g);
21   Node* node = nullptr;
22   while ((node = it.next()) != nullptr) {
23     if (node->kind() == Symbol::aten("alias")) {
24       alias_nodes.push_back(node);
25     }
26   }
27 
28   // remove the alias nodes, if it is safe to do so
29   for (auto* node : alias_nodes) {
30     GRAPH_DEBUG(*node);
31 
32     Value* input_value = node->input();
33     Value* output_value = node->output();
34 
35     bool always_safe_to_mutate = alias_db.safeToChangeAliasingRelationship(
36         node->inputs(), node->outputs());
37 
38     const auto g_in = g->inputs();
39     const auto g_out = g->outputs();
40     bool is_input =
41         std::find(g_in.begin(), g_in.end(), input_value) != g_in.end();
42     bool is_output =
43         std::find(g_out.begin(), g_out.end(), output_value) != g_out.end();
44     // We assume that aliasing is safe to update on inputs and outputs if they
45     // do not have writers.
46     bool input_safe_to_mutate =
47         (is_input && !alias_db.hasWriters(input_value) &&
48          !alias_db.hasWriters(output_value));
49     bool output_safe_to_mutate =
50         (is_output && !alias_db.hasWriters(input_value) &&
51          !alias_db.hasWriters(output_value));
52 
53     if (always_safe_to_mutate || input_safe_to_mutate ||
54         output_safe_to_mutate) {
55       output_value->replaceAllUsesWith(input_value);
56       node->destroy();
57     }
58   }
59 }
60 
61 } // namespace
62 
DBRQuantRemoveRedundantAliases(Module & module)63 Module DBRQuantRemoveRedundantAliases(Module& module) {
64   for (const auto& child : module.modules()) {
65     for (const auto& method : child.get_methods()) {
66       DBRQuantRemoveRedundantAliasesImpl(method);
67     }
68   }
69 
70   return module;
71 }
72 
73 } // namespace jit
74 } // namespace torch
75