1 #include <torch/csrc/jit/mobile/interpreter.h>
2
3 #include <ATen/core/class_type.h>
4 #include <ATen/core/dynamic_type.h>
5 #include <ATen/core/function.h>
6 #include <ATen/core/jit_type.h>
7 #include <ATen/core/operator_name.h>
8 #include <ATen/record_function.h>
9 #include <c10/util/Exception.h>
10 #include <c10/util/irange.h>
11 #include <torch/csrc/jit/backends/backend_exception.h>
12 #include <torch/csrc/jit/mobile/function.h>
13 #include <torch/csrc/jit/mobile/observer.h>
14 #include <torch/csrc/jit/mobile/promoted_prim_ops.h>
15 #include <torch/csrc/jit/runtime/jit_exception.h>
16 #include <torch/csrc/jit/runtime/vararg_functions.h>
17
18 namespace torch::jit {
19 char const* toString(OpCode op);
20 std::ostream& operator<<(std::ostream& out, Instruction inst);
21 namespace mobile {
// Seeds the interpreter with a single initial frame that executes `code`.
InterpreterState::InterpreterState(const Code& code) {
  enterFrame(code);
}
25
26 namespace {
27 static thread_local std::vector<DebugHandle> exception_debug_handles_;
createObject(Stack & stack,const at::ClassTypePtr & type)28 void createObject(Stack& stack, const at::ClassTypePtr& type) {
29 auto userObj = c10::ivalue::Object::create(
30 c10::StrongTypePtr(type->compilation_unit(), type),
31 type->numAttributes());
32 push(stack, std::move(userObj));
33 }
34
isinstance(Stack & stack,at::ArrayRef<at::TypePtr> types)35 void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
36 at::TypePtr ty = pop(stack).type<c10::DynamicType>();
37 for (const at::TypePtr& candidate : types) {
38 if (ty->isSubtypeOf(*candidate)) {
39 push(stack, true);
40 return;
41 }
42 }
43 push(stack, false);
44 }
45 } // namespace
46
47 using namespace at;
48
// Returns the debug handles captured (one per frame, innermost first) when
// the most recent exception escaped InterpreterState::run on this thread.
const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles() {
  return exception_debug_handles_;
}
52
enterFrame(const Code & code)53 void InterpreterState::enterFrame(const Code& code) {
54 frames_.emplace_back(code);
55 registers_.resize(registers_.size() + code.register_size_);
56 }
57
leaveFrame()58 void InterpreterState::leaveFrame() {
59 registers_.resize(
60 registers_.size() - frames_.back().getCode().register_size_);
61 frames_.pop_back();
62 }
63
saveExceptionDebugHandles()64 void InterpreterState::saveExceptionDebugHandles() {
65 std::vector<DebugHandle> exception_debug_handles;
66 for (auto frame = frames_.crbegin(); frame != frames_.crend(); frame++) {
67 size_t pc = frame->getPC() - (frame != frames_.crbegin() ? 1 : 0);
68 if (auto handle = frame->getDebugHandle(pc)) {
69 exception_debug_handles.push_back(*handle);
70 } else {
71 exception_debug_handles.push_back(-1);
72 }
73 }
74 exception_debug_handles_ = std::move(exception_debug_handles);
75 }
76
callFunction(torch::jit::Function & f,Stack & stack)77 void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
78 bool newFrame =
79 f.call(stack, [&](const mobile::Code& code) { enterFrame(code); });
80 (frames_.rbegin() + (newFrame ? 1 : 0))->step();
81 }
82
// Main bytecode dispatch loop. Executes instructions from the innermost frame
// until every frame has returned (RET on the last frame), leaving results on
// `stack`. Always returns false on normal completion. On exception, the
// per-frame debug handles are captured (saveExceptionDebugHandles) before the
// exception is rethrown to the caller.
bool InterpreterState::run(Stack& stack) {
  while (true) {
    try {
      auto& frame = frames_.back();
      const auto& code = frame.getCode();
      const auto pc = frame.getPC();
      auto inst = frame.getInstruction();
      // If no valid debug handle found then just log pc.
      // This is possible when we did not save debug handles

      DebugHandle debug_handle = pc;
      if (auto handle = frame.getDebugHandle()) {
        debug_handle = *handle;
      }

      // std::cout << "RUNNING " << pc << " " << code.instructions_[pc];
      // if (inst.op == OP) {
      //   std::cout << ", " << code.op_names_[inst.X].name;
      //   if (!code.op_names_[inst.X].overload_name.empty()) {
      //     std::cout << "." << code.op_names_[inst.X].overload_name;
      //   }
      // }
      // std::cout << std::endl;

      // TODO(iliacher): remove the workaround after RecordFunction is in
      // Dispatcher
      // Check with iliacher if has been done.
      // Plus this is not safe as if you throw exception record function will be
      // left enabled. That is a TODO
      // NOTE: this recordFunction logic takes up ~2-3% of cpu cycles in some
      // workflows. do we need it and/or can we opt-out of
      // isRecordFunctionEnabled with a macro? if we delete it, things appear to
      // work just fine.
      bool prev_value = isRecordFunctionEnabled();
      if (!prev_value) {
        // enable only for the RecordFunction
        enableRecordFunction(true);
      }
      switch (inst.op) {
        // Operator call: inst.X indexes into the code's operator table.
        case OP: {
          if (at::hasGlobalCallbacks()) {
            // Expose the current op index to any attached mobile profiler.
            if (auto* mobile_debug_info = static_cast<MobileDebugInfo*>(
                    c10::ThreadLocalDebugInfo::get(
                        c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) {
              mobile_debug_info->setOpIdx(pc);
            }
          }
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.op_names_.size() ||
              static_cast<size_t>(inst.X) >= code.operators_.size()) {
            TORCH_CHECK(false, "Can't load op with index: ", inst.X);
          }
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              code.op_names_[inst.X].name, debug_handle, stack);
          code.operators_[inst.X](stack);
          frame.step();
        } break;
        // Vararg operator call: like OP, but first pushes the argument count
        // (inst.N) so the op can read a variable number of inputs.
        case OPN: {
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.op_names_.size() ||
              static_cast<size_t>(inst.X) >= code.operators_.size()) {
            TORCH_CHECK(false, "Can't load op with index: ", inst.X);
          }
          stack.emplace_back(inst.N);
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              code.op_names_[inst.X].name, debug_handle, stack);
          code.operators_[inst.X](stack);
          frame.step();
        } break;
        // Direct call to a function in this code object's function table.
        case CALL: {
          auto& function = *frame.getCode().functions_.at(inst.X);
          callFunction(function, stack);
        } break;
        // Dynamic dispatch: method name comes from constants_[inst.X], the
        // receiver is the first of the inst.N values peeked from the stack.
        case INTERFACE_CALL: {
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.constants_.size()) {
            TORCH_CHECK(false, "Can't load constant with index: ", inst.X);
          }
          if (inst.N == 0 || inst.N > stack.size()) {
            TORCH_CHECK(
                false,
                "INTERFACE_CALL N=",
                inst.N,
                " not in range [1, ",
                stack.size(),
                "]");
          }
          torch::jit::Function& method =
              peek(stack, 0, inst.N)
                  .toObject()
                  ->type()
                  ->getMethod(code.constants_[inst.X].toStringRef());
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              method.name(), debug_handle, stack);
          callFunction(method, stack);
        } break;
        // Register <-> stack traffic. Registers are indexed by inst.X.
        case LOAD:
          stack.emplace_back(reg(inst.X));
          frame.step();
          break;
        case MOVE:
          // Like LOAD but leaves the register in a moved-from state.
          stack.emplace_back(std::move(reg(inst.X)));
          frame.step();
          break;
        case STORE:
          reg(inst.X) = pop(stack);
          frame.step();
          break;
        case STOREN:
          // Pops inst.N values into registers X..X+N-1 (topmost value goes to
          // the highest register).
          for (size_t i = inst.N; i > 0; --i) {
            reg(inst.X + i - 1) = pop(stack);
          }
          frame.step();
          break;
        case DROP:
          pop(stack);
          frame.step();
          break;
        case DROPR:
          // Release the value held in register X.
          reg(inst.X) = IValue();
          frame.step();
          break;
        case LOADC:
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.constants_.size()) {
            TORCH_CHECK(false, "Can't load constant with index: ", inst.X);
          }
          stack.emplace_back(code.constants_[inst.X]);
          frame.step();
          break;
        case GET_ATTR: {
          auto userObj = pop(stack).toObject();
          auto value = userObj->getSlot(inst.X);
          push(stack, std::move(value));
          frame.step();
        } break;
        case SET_ATTR: {
          auto v = pop(stack);
          auto userObj = pop(stack).toObject();
          // Mobile only: since the number of slots is not known, resize the
          // numAttributes before setSlot.
          while (static_cast<int>(userObj->type()->numAttributes()) <= inst.X) {
            std::stringstream ss;
            ss << userObj->type()->numAttributes();
            userObj->type()->addAttribute(ss.str(), c10::NoneType::get());
          }
          userObj->setSlot(inst.X, std::move(v));
          frame.step();
        } break;
        // Conditional jump: falls through (jump by 1) when the popped
        // condition is true, otherwise jumps by inst.X.
        case JF:
          frame.jump(pop(stack).toBool() ? 1 : inst.X);
          break;
        case JMP:
          frame.jump(inst.X);
          break;
        case LOOP: {
          // stack: iteration_count, max_iter, cond, loop_carried_deps...
          auto sframe = stack.end() - (inst.N + 1);
          int64_t trip_count = sframe[0].toInt();
          int64_t max_trip_count = sframe[1].toInt();
          bool cond = sframe[2].toBool();
          if (trip_count < max_trip_count && cond) {
            // Continue looping: expose the current trip count to the body and
            // pre-increment the stored counter for the next iteration.
            sframe[2] = trip_count;
            sframe[0] = trip_count + 1;
            frame.step();
          } else {
            // Exit: shift the loop-carried values down over the bookkeeping
            // slots, drop the bookkeeping, and jump past the loop body.
            size_t n_loop_carried = inst.N - 2;
            for (const auto i : c10::irange(n_loop_carried)) {
              sframe[i] = std::move(sframe[i + 3]);
            }
            drop(stack, 3); // iteration_count, max_iter, cond
            frame.jump(inst.X);
          }
        } break;
        case RET:
          leaveFrame();
          if (!frames_.empty()) {
            // Returned into a caller frame; keep executing it.
            continue;
          }
          // Outermost frame returned: execution is complete.
          return false;
        // Container construction / destructuring ops; inst.X typically picks
        // a type from types_ and inst.N a count of stack operands.
        case LIST_CONSTRUCT: {
          listConstruct(stack, *code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case LIST_UNPACK: {
          listUnpack(stack, inst.X);
          frame.step();
        } break;
        case TUPLE_CONSTRUCT: {
          tupleConstruct(stack, inst.X);
          frame.step();
        } break;
        case TUPLE_SLICE: {
          tupleSlice(stack, inst.X, inst.X + inst.N);
          frame.step();
        } break;
        case TUPLE_INDEX: {
          tupleIndex(stack);
          frame.step();
        } break;
        case RAISE_EXCEPTION: {
          raiseExceptionWithMessage(stack);
          frame.step();
        } break;
        // Promoted primitive ops implemented in promoted_prim_ops.h; each
        // consumes/produces values on the stack and advances the frame.
        case __IS__: {
          is(stack);
          frame.step();
        } break;
        case UN_INITIALIZED: {
          unInitialized(stack);
          frame.step();
        } break;
        case __ISNOT__: {
          isNot(stack);
          frame.step();
        } break;
        case FORMAT: {
          format(stack, inst.X);
          frame.step();
        } break;
        case DEVICE: {
          device(stack);
          frame.step();
        } break;
        case DTYPE: {
          dtype(stack);
          frame.step();
        } break;
        case DIM: {
          dim(stack);
          frame.step();
        } break;
        case __NOT__: {
          _not(stack);
          frame.step();
        } break;
        case DICT_INDEX: {
          dictIndex(stack);
          frame.step();
        } break;
        case TO_LIST: {
          toList(stack);
          frame.step();
        } break;
        case NUM_TO_TENSOR: {
          numToTensorScalar(stack);
          frame.step();
        } break;
        case IS_CUDA: {
          isCuda(stack);
          frame.step();
        } break;
        case DICT_CONSTRUCT: {
          dictConstruct(stack, *code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case NAMED_TUPLE_CONSTRUCT: {
          namedTupleConstruct(stack, code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case CREATE_OBJECT: {
          auto type = code.types_.at(inst.X)->expect<c10::ClassType>();
          createObject(stack, type);
          frame.step();
        } break;
        case ISINSTANCE: {
          // inst.N candidate types starting at types_[inst.X].
          at::ArrayRef<TypePtr> types(&code.types_.at(inst.X), inst.N);
          isinstance(stack, types);
          frame.step();
        } break;
        case WARN: {
          drop(stack, 1);
          // Note: Please don't move the pop(stack) code below into the
          // TORCH_WARN macro since TORCH_WARN fails to evaluate its arguments
          // when STRIP_ERROR_MESSAGES is defined (which happens for production
          // mobile builds). This will cause the stack to be in an inconsistent
          // state. It has previously resulted in a SEV (S22350).
          TORCH_WARN(stack.back().toStringRef());
          stack.pop_back();
          frame.step();
        } break;
        default:
          AT_ERROR(toString(inst.op), " is invalid.");
      }

      if (!prev_value) {
        enableRecordFunction(false);
      }
      // This exception must be caught first as it derived from c10::Error
    } catch (c10::BackendRuntimeException& e) {
      saveExceptionDebugHandles();
      TORCH_RETHROW(e);
    } catch (c10::Error& error) {
      // Reason for catching and rethrowing the error is so that we can
      // set the exception pc that is queried later
      saveExceptionDebugHandles();
      TORCH_RETHROW(error);
    } catch (...) {
      // Unknown exception type: still record where we were before rethrowing.
      saveExceptionDebugHandles();
      throw;
    }
    //  for (auto val : stack) {
    //    if (val.isTensor()) {
    //      std::cout << val.toTensor().sizes() << std::endl;
    //    } else {
    //      std::cout << val << std::endl;
    //    }
    //  }
  }
  return false;
}
394
reg(size_t reg)395 IValue& InterpreterState::reg(size_t reg) {
396 TORCH_CHECK(
397 reg > 0 && reg <= registers_.size(), "Invalid register index: ", reg);
398 return *(registers_.end() - reg);
399 }
400
401 } // namespace mobile
402 } // namespace torch::jit
403