1 #include <torch/csrc/jit/passes/lower_graph.h>
2
3 #include <torch/csrc/jit/api/object.h>
4 #include <torch/csrc/jit/frontend/error_report.h>
5 #include <torch/csrc/jit/passes/inliner.h>
6 #include <torch/custom_class.h>
7 #include <unordered_map>
8
9 namespace torch::jit {
10
11 struct Slot {
12 c10::intrusive_ptr<c10::ivalue::Object> obj;
13 size_t offset;
operator ==torch::jit::Slot14 bool operator==(const Slot& other) const {
15 return (this->obj == other.obj && this->offset == other.offset);
16 }
17 };
18
19 // remove the first module argument, replacing any access of its
20 // parameters/attributes with extra_ivalue input Slots that hold what value to
21 // pass into the graph. Used for ONNX export to remove first-class modules
22 // so it can deal purely with parameters and inputs
lower_graph(const ModulePtr & self,Graph & g_,size_t self_offset=0)23 static std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
24 const ModulePtr& self,
25 Graph& g_,
26 size_t self_offset = 0) {
27 std::shared_ptr<Graph> g = g_.copy();
28 // Inline to remove method/function calls
29 Inline(*g);
30
31 std::vector<Slot> extra_ivalues;
32
33 struct SlotHash {
34 std::size_t operator()(const Slot& slot) const {
35 auto obj_hash = std::hash<c10::ivalue::Object*>{}(slot.obj.get());
36 auto offset_hash = std::hash<size_t>{}(slot.offset);
37 return c10::hash_combine(obj_hash, offset_hash);
38 }
39 };
40 std::unordered_map<Slot, size_t, SlotHash> slot_to_offset;
41 struct ToScan {
42 ModulePtr mod;
43 Node* n;
44 size_t offset;
45 };
46 std::vector<ToScan> to_scan;
47 std::vector<Node*> to_clean; // nodes that should be dead at the end
48
49 auto getOrAddSlot = [&](const Slot& slot) -> Value* {
50 auto it = slot_to_offset.find(slot);
51 if (it != slot_to_offset.end()) {
52 size_t ivalues_start = g->inputs().size() - extra_ivalues.size();
53 return g->inputs().at(ivalues_start + it->second);
54 }
55 extra_ivalues.emplace_back(slot);
56 slot_to_offset[slot] = extra_ivalues.size() - 1;
57 return g->addInput()->setType(slot.obj->getSlot(slot.offset).type());
58 };
59
60 auto self_value = g->inputs().at(self_offset);
61
62 for (Use use : self_value->uses()) {
63 to_scan.emplace_back(ToScan{self, use.user, use.offset});
64 }
65 while (!to_scan.empty()) {
66 auto e = to_scan.back();
67 to_scan.pop_back();
68
69 // when we lambda lift forks, first-class modules may be passed across
70 // forks. This code recursively lowers the module in the fork call.
71 if (e.n->kind() == prim::fork) {
72 auto subgraph = e.n->g(attr::Subgraph);
73 std::vector<Slot> new_slots;
74 std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset);
75 e.n->g_(attr::Subgraph, subgraph);
76 for (const Slot& slot : new_slots) {
77 e.n->addInput(getOrAddSlot(slot));
78 }
79 e.n->removeInput(e.offset);
80 continue;
81 }
82 if (e.n->kind() == prim::PythonOp) {
83 throw ErrorReport(e.n->sourceRange()) << "Couldn't export Python method.";
84 }
85 if (e.n->kind() != prim::GetAttr) {
86 throw ErrorReport(e.n->sourceRange())
87 << "temporary: the only valid use of a module is looking up an "
88 "attribute but found "
89 << *e.n;
90 }
91 size_t slot_idx = e.mod->type()->getAttributeSlot(e.n->s(attr::name));
92 auto iv = e.mod->getSlot(slot_idx);
93 if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
94 if (c->is_module()) {
95 for (Use use : e.n->output()->uses()) {
96 to_scan.emplace_back(ToScan{iv.toObject(), use.user, use.offset});
97 }
98 to_clean.emplace_back(e.n);
99 continue;
100 }
101 }
102 e.n->output()->replaceAllUsesWith(getOrAddSlot({e.mod, slot_idx}));
103 e.n->destroy();
104 }
105
106 while (!to_clean.empty()) {
107 Node* n = to_clean.back();
108 AT_ASSERT(!n->hasUses());
109 n->destroy();
110 to_clean.pop_back();
111 }
112 AT_ASSERT(!self_value->hasUses());
113 g->eraseInput(self_offset);
114
115 return std::make_pair(std::move(g), std::move(extra_ivalues));
116 }
117
loadTensors(const std::vector<Slot> & slots)118 static std::vector<IValue> loadTensors(const std::vector<Slot>& slots) {
119 std::vector<IValue> result;
120 result.reserve(slots.size());
121 for (const Slot& slot : slots) {
122 auto obj = slot.obj->getSlot(slot.offset);
123 if (obj.isTensor()) {
124 result.emplace_back(obj.toTensor());
125 } else {
126 // Unpack quantization packed tensor
127 auto type = obj.type();
128 TORCH_CHECK(
129 (type ==
130 getCustomClass(
131 "__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
132 (type ==
133 getCustomClass(
134 "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) ||
135 (type ==
136 getCustomClass(
137 "__torch__.torch.classes.quantized.LinearPackedParamsBase")),
138 "Unknown type ",
139 type->repr_str(),
140 " encountered in graph lowering. This type is not supported in ONNX export.");
141 result.emplace_back(
142 script::Object(obj.toObject()).run_method("__getstate__"));
143 }
144 }
145 return result;
146 }
147
LowerGraph(Graph & graph,const ModulePtr & self)148 std::pair<std::shared_ptr<Graph>, std::vector<IValue>> LowerGraph(
149 Graph& graph,
150 const ModulePtr& self) {
151 auto result = lower_graph(self, graph);
152 return std::make_pair(result.first, loadTensors(result.second));
153 }
154
155 } // namespace torch::jit
156