xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/lower_graph.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/passes/lower_graph.h>

#include <torch/csrc/jit/api/object.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/custom_class.h>
#include <unordered_map>

namespace torch::jit {

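// A Slot identifies a single attribute of a live module object: the object
// that holds the attribute together with the attribute's index within it.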
struct Slot {
  c10::intrusive_ptr<c10::ivalue::Object> obj;
  size_t offset;
  bool operator==(const Slot& other) const {
    return (this->obj == other.obj && this->offset == other.offset);
  }
};

// Remove the first module argument, replacing any access of its
// parameters/attributes with extra_ivalue input Slots that hold the values to
// pass into the graph. Used for ONNX export to remove first-class modules
// so the exporter can deal purely with parameters and inputs.
static std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
    const ModulePtr& self,
    Graph& g_,
    size_t self_offset = 0) {
  std::shared_ptr<Graph> g = g_.copy();
  // Inline to remove method/function calls
  Inline(*g);

  std::vector<Slot> extra_ivalues;

  struct SlotHash {
    std::size_t operator()(const Slot& slot) const {
      auto obj_hash = std::hash<c10::ivalue::Object*>{}(slot.obj.get());
      auto offset_hash = std::hash<size_t>{}(slot.offset);
      return c10::hash_combine(obj_hash, offset_hash);
    }
  };
  std::unordered_map<Slot, size_t, SlotHash> slot_to_offset;
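  // Work item for the scan below: a concrete module instance together with
  // one use of it in the graph (the using node and the input index).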
  struct ToScan {
    ModulePtr mod;
    Node* n;
    size_t offset;
  };
  std::vector<ToScan> to_scan;
  std::vector<Node*> to_clean; // nodes that should be dead at the end

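  // Deduplicate slots: the first time an (object, offset) pair is seen it
  // becomes a new graph input typed like the stored value; later lookups
  // reuse that input.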
  auto getOrAddSlot = [&](const Slot& slot) -> Value* {
    auto it = slot_to_offset.find(slot);
    if (it != slot_to_offset.end()) {
      size_t ivalues_start = g->inputs().size() - extra_ivalues.size();
      return g->inputs().at(ivalues_start + it->second);
    }
    extra_ivalues.emplace_back(slot);
    slot_to_offset[slot] = extra_ivalues.size() - 1;
    return g->addInput()->setType(slot.obj->getSlot(slot.offset).type());
  };

  auto self_value = g->inputs().at(self_offset);

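  // Seed the worklist with every direct use of the module input; further uses
  // are added as submodule attributes are discovered during the scan.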
  for (Use use : self_value->uses()) {
    to_scan.emplace_back(ToScan{self, use.user, use.offset});
  }
  while (!to_scan.empty()) {
    auto e = to_scan.back();
    to_scan.pop_back();

    // When we lambda-lift forks, first-class modules may be passed across
    // fork boundaries. This code recursively lowers the module in the fork call.
    if (e.n->kind() == prim::fork) {
      auto subgraph = e.n->g(attr::Subgraph);
      std::vector<Slot> new_slots;
      std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset);
      e.n->g_(attr::Subgraph, subgraph);
      for (const Slot& slot : new_slots) {
        e.n->addInput(getOrAddSlot(slot));
      }
      e.n->removeInput(e.offset);
      continue;
    }
    if (e.n->kind() == prim::PythonOp) {
      throw ErrorReport(e.n->sourceRange()) << "Couldn't export Python method.";
    }
    if (e.n->kind() != prim::GetAttr) {
      throw ErrorReport(e.n->sourceRange())
          << "temporary: the only valid use of a module is looking up an "
             "attribute but found "
          << *e.n;
    }
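    // Resolve the attribute: submodules are queued for further scanning,
    // while parameters/attributes are replaced by an extra graph input.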
    size_t slot_idx = e.mod->type()->getAttributeSlot(e.n->s(attr::name));
    auto iv = e.mod->getSlot(slot_idx);
    if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
      if (c->is_module()) {
        for (Use use : e.n->output()->uses()) {
          to_scan.emplace_back(ToScan{iv.toObject(), use.user, use.offset});
        }
        to_clean.emplace_back(e.n);
        continue;
      }
    }
    e.n->output()->replaceAllUsesWith(getOrAddSlot({e.mod, slot_idx}));
    e.n->destroy();
  }

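  // Destroy the GetAttr nodes for submodules, which should now be unused.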
  while (!to_clean.empty()) {
    Node* n = to_clean.back();
    AT_ASSERT(!n->hasUses());
    n->destroy();
    to_clean.pop_back();
  }
  AT_ASSERT(!self_value->hasUses());
  g->eraseInput(self_offset);

  return std::make_pair(std::move(g), std::move(extra_ivalues));
}

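// Materialize the collected slots as IValues: plain tensors are returned
// directly; known quantized packed-parameter objects are serialized via
// __getstate__ so ONNX export can consume them.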
static std::vector<IValue> loadTensors(const std::vector<Slot>& slots) {
  std::vector<IValue> result;
  result.reserve(slots.size());
  for (const Slot& slot : slots) {
    auto obj = slot.obj->getSlot(slot.offset);
    if (obj.isTensor()) {
      result.emplace_back(obj.toTensor());
    } else {
      // Unpack quantization packed tensor
      auto type = obj.type();
      TORCH_CHECK(
          (type ==
           getCustomClass(
               "__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
              (type ==
               getCustomClass(
                   "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) ||
              (type ==
               getCustomClass(
                   "__torch__.torch.classes.quantized.LinearPackedParamsBase")),
          "Unknown type ",
          type->repr_str(),
          " encountered in graph lowering. This type is not supported in ONNX export.");
      result.emplace_back(
          script::Object(obj.toObject()).run_method("__getstate__"));
    }
  }
  return result;
}

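// Public entry point: lower `graph` with respect to the module `self`, then
// load the values for the newly added inputs, so callers receive the lowered
// graph together with the parameter/attribute values to feed it.
//
// A minimal usage sketch, assuming a scripted `module` whose "forward" method
// is being exported (the variable names here are illustrative only):
//
//   Module module = /* some torch::jit::Module */;
//   std::shared_ptr<Graph> method_graph =
//       module.get_method("forward").graph();
//   auto lowered = LowerGraph(*method_graph, module._ivalue());
//   std::shared_ptr<Graph> graph = lowered.first;  // graph without `self`
//   std::vector<IValue> params = lowered.second;   // values for the new inputs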
std::pair<std::shared_ptr<Graph>, std::vector<IValue>> LowerGraph(
    Graph& graph,
    const ModulePtr& self) {
  auto result = lower_graph(self, graph);
  return std::make_pair(result.first, loadTensors(result.second));
}

} // namespace torch::jit