1 #include <torch/csrc/python_headers.h>
2
3 #include <torch/csrc/jit/frontend/tracer.h>
4 #include <torch/csrc/jit/passes/dead_code_elimination.h>
5 #include <torch/csrc/jit/passes/inliner.h>
6 #include <torch/csrc/jit/passes/lower_tuples.h>
7 #include <torch/csrc/jit/python/pybind.h>
8 #include <torch/csrc/jit/python/python_tracer.h>
9 #include <torch/csrc/jit/serialization/export.h>
10 #include <torch/csrc/utils/python_strings.h>
11
12 #include <c10/util/Exception.h>
13 #include <c10/util/irange.h>
14
15 #include <sstream>
16
17 using namespace torch::autograd;
18 using namespace torch::jit;
19 using namespace torch::jit::tracer;
20
21 namespace torch::jit::tracer {
22
23 // Python interpreter retrieval routine adapted from
24 // https://stackoverflow.com/a/8706144
// Walks the Python interpreter's frame chain, starting at the frame returned
// by PyEval_GetFrame() and following PyFrame_GetBack(), and returns one
// StackEntry per frame, in that visiting order.
//
// For every frame, a Source is built whose text is the function name, whose
// filename is the code object's co_filename and whose starting line is the
// line for the frame's last instruction (PyCode_Addr2Line). The entry's
// SourceRange spans the whole function-name text of that Source.
//
// Acquires the GIL for the duration of the walk.
std::vector<StackEntry> _pythonCallstack() {
  pybind11::gil_scoped_acquire gil;
  PyFrameObject* frame = PyEval_GetFrame();
  // Take our own reference so every frame visited in the loop below can be
  // released uniformly with Py_DECREF.
  Py_XINCREF(frame);
  std::vector<StackEntry> entries;

  while (nullptr != frame) {
    // Owned (smart-pointer managed) reference to this frame's code object.
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    size_t line = PyCode_Addr2Line(code.get(), PyFrame_GetLasti(frame));
    std::string filename = THPUtils_unpackString(code->co_filename);
    std::string funcname = THPUtils_unpackString(code->co_name);
    // NOTE(review): if THPUtils_unpackString throws here, the reference held
    // on `frame` is never released (Py_DECREF below runs only on the normal
    // path) — consider an RAII holder for `frame`.
    auto source = std::make_shared<Source>(funcname, filename, line);
    // The first StackEntry field receives the function name (consumers read
    // it back as `entry.filename`).
    entries.emplace_back(
        StackEntry{funcname, SourceRange(source, 0, funcname.size())});
    // Fetch the caller before dropping our reference to the current frame.
    auto new_frame = PyFrame_GetBack(frame);
    Py_DECREF(frame);
    frame = new_frame;
  }
  return entries;
}
45
getPythonInterpreterSourceRange()46 SourceRange getPythonInterpreterSourceRange() {
47 auto cs = pythonCallstack();
48 std::optional<std::string> source_filename;
49 size_t source_line = 0;
50 std::stringstream stack_trace;
51 for (const auto& entry : cs) {
52 auto& range = entry.range;
53 if (range.source()) {
54 auto& src = range.source();
55 if (src && src->filename()) {
56 auto line =
57 src->starting_line_no() + src->lineno_for_offset(range.start());
58 stack_trace << *(src->filename()) << "(" << line
59 << "): " << entry.filename << "\n";
60 if (!source_filename) {
61 source_filename = *(src->filename());
62 source_line = line;
63 }
64 }
65 }
66 }
67
68 auto stack_trace_text = stack_trace.str();
69 auto source =
70 std::make_shared<Source>(stack_trace_text, source_filename, source_line);
71 return SourceRange(source, 0, stack_trace_text.size());
72 }
73
createGraphByTracingWithDict(const py::function & func,const py::dict & inputs_dict,const Stack & trace_inputs,const py::function & var_name_lookup_fn,bool strict,bool force_outplace,Module * self,const std::vector<std::string> & argument_names)74 std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
75 const py::function& func,
76 const py::dict& inputs_dict,
77 const Stack& trace_inputs,
78 const py::function& var_name_lookup_fn,
79 bool strict,
80 bool force_outplace,
81 Module* self,
82 const std::vector<std::string>& argument_names) {
83 C10_LOG_API_USAGE_ONCE("torch.tracer");
84
85 auto lookup_fn_adapter =
86 [var_name_lookup_fn](const Variable& var) -> std::string {
87 pybind11::gil_scoped_acquire ag;
88 return py::cast<std::string>(var_name_lookup_fn(var));
89 };
90
91 // The argument_names parameter is parsed in python and its order
92 // is the same as the arguments' decalaration order in forward() method.
93 // These name shall be added to the graph as debug name and the order
94 // should align with the traceable stack we generated by the python dict.
95 std::vector<std::string> compact_argument_names;
96 Stack compact_trace_inputs;
97 for (const auto& argument_name : argument_names) {
98 if (inputs_dict.contains(argument_name)) {
99 compact_argument_names.push_back(argument_name);
100 }
101 }
102 for (const auto& compact_argument_name : compact_argument_names) {
103 for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) {
104 if (py::cast<std::string>(it->first) == compact_argument_name) {
105 compact_trace_inputs.push_back(
106 toIValue(it->second, tryToInferType(it->second).type()));
107 }
108 }
109 }
110
111 auto outs = tracer::trace(
112 std::move(compact_trace_inputs),
113 [&](const Stack& inputs) -> Stack {
114 // We just leave the inputs_dict as it was and pass it to forward
115 // method.
116 auto out = func(**inputs_dict);
117 if (out.ptr() == Py_None) {
118 AT_ERROR(
119 "The traced function didn't return any values! Side-effects are not "
120 "captured in traces, so it would be a no-op.");
121 }
122 return {toTypeInferredIValue(out)};
123 },
124 lookup_fn_adapter,
125 strict,
126 force_outplace,
127 self,
128 compact_argument_names);
129 return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs));
130 }
131
createGraphByTracing(const py::function & func,Stack trace_inputs,const py::function & var_name_lookup_fn,bool strict,bool force_outplace,Module * self,const std::vector<std::string> & argument_names)132 std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
133 const py::function& func,
134 Stack trace_inputs,
135 const py::function& var_name_lookup_fn,
136 bool strict,
137 bool force_outplace,
138 Module* self,
139 const std::vector<std::string>& argument_names) {
140 C10_LOG_API_USAGE_ONCE("torch.tracer");
141
142 auto lookup_fn_adapter =
143 [var_name_lookup_fn](const Variable& var) -> std::string {
144 pybind11::gil_scoped_acquire ag;
145 return py::cast<std::string>(var_name_lookup_fn(var));
146 };
147
148 auto outs = tracer::trace(
149 std::move(trace_inputs),
150 [&func](Stack inputs) -> Stack {
151 size_t num_func_inputs = inputs.size();
152 py::tuple py_inputs(num_func_inputs);
153 for (const auto i : c10::irange(num_func_inputs)) {
154 py_inputs[i] = py::cast(inputs[i]);
155 }
156 auto out = func(*py_inputs);
157 if (out.ptr() == Py_None) {
158 AT_ERROR(
159 "The traced function didn't return any values! Side-effects are not "
160 "captured in traces, so it would be a no-op.");
161 }
162 return {toTypeInferredIValue(out)};
163 },
164 lookup_fn_adapter,
165 strict,
166 force_outplace,
167 self,
168 argument_names);
169 return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs));
170 }
171
preRecordPythonTrace(THPObjectPtr pyobj,const std::string & arg_types,at::ArrayRef<Variable> inputs,pyobj_list scalar_args)172 Node* preRecordPythonTrace(
173 THPObjectPtr pyobj,
174 const std::string& arg_types,
175 at::ArrayRef<Variable> inputs,
176 pyobj_list scalar_args) {
177 THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(), "apply"));
178 if (!apply) {
179 throw python_error();
180 }
181
182 auto& graph = getTracingState()->graph;
183
184 Node* n = graph->createPythonOp(
185 std::move(apply), arg_types, std::move(scalar_args));
186 recordSourceLocation(n);
187
188 for (const Variable& input : inputs) {
189 n->addInput(getValueTrace(input));
190 }
191
192 graph->insertNode(n);
193
194 return n;
195 }
196
pythonRecordSourceLocation(Node * n)197 void pythonRecordSourceLocation(Node* n) {
198 n->setSourceRange(getPythonInterpreterSourceRange());
199 }
200
pythonWarn(const std::string & reason)201 void pythonWarn(const std::string& reason) {
202 pybind11::gil_scoped_acquire gil;
203 auto warn_class = py::module::import("torch.jit").attr("TracerWarning");
204 PyErr_WarnEx(warn_class.ptr(), reason.c_str(), 1);
205 }
206
initPythonTracerBindings(PyObject * module)207 void initPythonTracerBindings(PyObject* module) {
208 setPythonCallstack(_pythonCallstack);
209 setRecordSourceLocation(pythonRecordSourceLocation);
210
211 auto m = py::handle(module).cast<py::module>();
212 py::class_<TracingState, std::shared_ptr<TracingState>>(
213 m, "TracingState", py::dynamic_attr())
214 // NB: no constructor; you have to get it from C++ code
215 .def(
216 "__repr__",
217 [](const TracingState& s) {
218 std::ostringstream ss;
219 ss << "<TracingState " << (const void*)&s << ">";
220 return ss.str();
221 })
222 .def(
223 "__str__",
224 [](const TracingState& s) -> std::string {
225 std::ostringstream ss;
226 ss << *s.graph;
227 return ss.str();
228 })
229 .def(
230 "push_scope",
231 [](TracingState& s, const std::string& scope_name) {
232 s.graph->push_scope(scope_name);
233 })
234 .def("pop_scope", [](TracingState& s) { s.graph->pop_scope(); })
235 .def(
236 "current_scope",
237 [](TracingState& s) {
238 return s.graph->current_scope()->name().toUnqualString();
239 })
240 .def(
241 "set_graph",
242 [](TracingState& s, std::shared_ptr<Graph> g) {
243 s.graph = std::move(g);
244 })
245 .def("graph", [](TracingState& s) { return s.graph; });
246
247 m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
248 m.def(
249 "_create_graph_by_tracing",
250 createGraphByTracing,
251 py::arg("func"),
252 py::arg("inputs"),
253 py::arg("var_name_lookup_fn"),
254 py::arg("strict"),
255 py::arg("force_outplace"),
256 py::arg("self") = nullptr,
257 py::arg("argument_names") = std::vector<std::string>());
258 m.def("_get_tracing_state", []() { return getTracingState(); });
259 m.def("_set_tracing_state", [](std::shared_ptr<TracingState> state) {
260 return setTracingState(std::move(state));
261 });
262 m.def("_get_value_trace", [](const Variable& var) {
263 return getValueTrace(var);
264 });
265 m.def("_set_value_trace", [](const Variable& var, Value* value) {
266 return setValueTrace(var, value);
267 });
268 m.def("_tracer_set_get_unique_name_fn", [](const py::function& func) {
269 const auto& tracing_state = getTracingState();
270 AT_ASSERT(tracing_state);
271 tracing_state->lookup_var_name_fn =
272 [func](const Variable& var) -> std::string {
273 pybind11::gil_scoped_acquire ag;
274 return py::cast<std::string>(func(var));
275 };
276 });
277 m.def("_tracer_set_force_outplace", [](bool force_outplace) {
278 const auto& tracing_state = getTracingState();
279 AT_ASSERT(tracing_state);
280 tracing_state->force_outplace = force_outplace;
281 });
282 }
283
284 } // namespace torch::jit::tracer
285