xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/interpreter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/mobile/interpreter.h>
2 
3 #include <ATen/core/class_type.h>
4 #include <ATen/core/dynamic_type.h>
5 #include <ATen/core/function.h>
6 #include <ATen/core/jit_type.h>
7 #include <ATen/core/operator_name.h>
8 #include <ATen/record_function.h>
9 #include <c10/util/Exception.h>
10 #include <c10/util/irange.h>
11 #include <torch/csrc/jit/backends/backend_exception.h>
12 #include <torch/csrc/jit/mobile/function.h>
13 #include <torch/csrc/jit/mobile/observer.h>
14 #include <torch/csrc/jit/mobile/promoted_prim_ops.h>
15 #include <torch/csrc/jit/runtime/jit_exception.h>
16 #include <torch/csrc/jit/runtime/vararg_functions.h>
17 
18 namespace torch::jit {
19 char const* toString(OpCode op);
20 std::ostream& operator<<(std::ostream& out, Instruction inst);
21 namespace mobile {
InterpreterState(const Code & code)22 InterpreterState::InterpreterState(const Code& code) {
23   enterFrame(code);
24 }
25 
26 namespace {
27 static thread_local std::vector<DebugHandle> exception_debug_handles_;
createObject(Stack & stack,const at::ClassTypePtr & type)28 void createObject(Stack& stack, const at::ClassTypePtr& type) {
29   auto userObj = c10::ivalue::Object::create(
30       c10::StrongTypePtr(type->compilation_unit(), type),
31       type->numAttributes());
32   push(stack, std::move(userObj));
33 }
34 
isinstance(Stack & stack,at::ArrayRef<at::TypePtr> types)35 void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
36   at::TypePtr ty = pop(stack).type<c10::DynamicType>();
37   for (const at::TypePtr& candidate : types) {
38     if (ty->isSubtypeOf(*candidate)) {
39       push(stack, true);
40       return;
41     }
42   }
43   push(stack, false);
44 }
45 } // namespace
46 
47 using namespace at;
48 
getInterpretersExceptionDebugHandles()49 const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles() {
50   return exception_debug_handles_;
51 }
52 
enterFrame(const Code & code)53 void InterpreterState::enterFrame(const Code& code) {
54   frames_.emplace_back(code);
55   registers_.resize(registers_.size() + code.register_size_);
56 }
57 
leaveFrame()58 void InterpreterState::leaveFrame() {
59   registers_.resize(
60       registers_.size() - frames_.back().getCode().register_size_);
61   frames_.pop_back();
62 }
63 
saveExceptionDebugHandles()64 void InterpreterState::saveExceptionDebugHandles() {
65   std::vector<DebugHandle> exception_debug_handles;
66   for (auto frame = frames_.crbegin(); frame != frames_.crend(); frame++) {
67     size_t pc = frame->getPC() - (frame != frames_.crbegin() ? 1 : 0);
68     if (auto handle = frame->getDebugHandle(pc)) {
69       exception_debug_handles.push_back(*handle);
70     } else {
71       exception_debug_handles.push_back(-1);
72     }
73   }
74   exception_debug_handles_ = std::move(exception_debug_handles);
75 }
76 
callFunction(torch::jit::Function & f,Stack & stack)77 void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
78   bool newFrame =
79       f.call(stack, [&](const mobile::Code& code) { enterFrame(code); });
80   (frames_.rbegin() + (newFrame ? 1 : 0))->step();
81 }
82 
run(Stack & stack)83 bool InterpreterState::run(Stack& stack) {
84   while (true) {
85     try {
86       auto& frame = frames_.back();
87       const auto& code = frame.getCode();
88       const auto pc = frame.getPC();
89       auto inst = frame.getInstruction();
90       // If no valid debug handle found then just log pc.
91       // This is possible when we did not save debug handles
92 
93       DebugHandle debug_handle = pc;
94       if (auto handle = frame.getDebugHandle()) {
95         debug_handle = *handle;
96       }
97 
98       // std::cout << "RUNNING " << pc << " " << code.instructions_[pc];
99       // if (inst.op == OP) {
100       //   std::cout << ", " << code.op_names_[inst.X].name;
101       //   if (!code.op_names_[inst.X].overload_name.empty()) {
102       //     std::cout << "." << code.op_names_[inst.X].overload_name;
103       //   }
104       // }
105       // std::cout << std::endl;
106 
107       // TODO(iliacher): remove the workaround after RecordFunction is in
108       // Dispatcher
109       // Check with iliacher if has been done.
110       // Plus this is not safe as if you throw exception record function will be
111       // left enabled. That is a TODO
112       // NOTE: this recordFunction logic takes up ~2-3% of cpu cycles in some
113       // workflows. do we need it and/or can we opt-out of
114       // isRecordFunctionEnabled with a macro? if we delete it, things appear to
115       // work just fine.
116       bool prev_value = isRecordFunctionEnabled();
117       if (!prev_value) {
118         // enable only for the RecordFunction
119         enableRecordFunction(true);
120       }
121       switch (inst.op) {
122         case OP: {
123           if (at::hasGlobalCallbacks()) {
124             if (auto* mobile_debug_info = static_cast<MobileDebugInfo*>(
125                     c10::ThreadLocalDebugInfo::get(
126                         c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) {
127               mobile_debug_info->setOpIdx(pc);
128             }
129           }
130           if (inst.X < 0 ||
131               static_cast<size_t>(inst.X) >= code.op_names_.size() ||
132               static_cast<size_t>(inst.X) >= code.operators_.size()) {
133             TORCH_CHECK(false, "Can't load op with index: ", inst.X);
134           }
135           RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
136               code.op_names_[inst.X].name, debug_handle, stack);
137           code.operators_[inst.X](stack);
138           frame.step();
139         } break;
140         case OPN: {
141           if (inst.X < 0 ||
142               static_cast<size_t>(inst.X) >= code.op_names_.size() ||
143               static_cast<size_t>(inst.X) >= code.operators_.size()) {
144             TORCH_CHECK(false, "Can't load op with index: ", inst.X);
145           }
146           stack.emplace_back(inst.N);
147           RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
148               code.op_names_[inst.X].name, debug_handle, stack);
149           code.operators_[inst.X](stack);
150           frame.step();
151         } break;
152         case CALL: {
153           auto& function = *frame.getCode().functions_.at(inst.X);
154           callFunction(function, stack);
155         } break;
156         case INTERFACE_CALL: {
157           if (inst.X < 0 ||
158               static_cast<size_t>(inst.X) >= code.constants_.size()) {
159             TORCH_CHECK(false, "Can't load constant with index: ", inst.X);
160           }
161           if (inst.N == 0 || inst.N > stack.size()) {
162             TORCH_CHECK(
163                 false,
164                 "INTERFACE_CALL N=",
165                 inst.N,
166                 " not in range [1, ",
167                 stack.size(),
168                 "]");
169           }
170           torch::jit::Function& method =
171               peek(stack, 0, inst.N)
172                   .toObject()
173                   ->type()
174                   ->getMethod(code.constants_[inst.X].toStringRef());
175           RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
176               method.name(), debug_handle, stack);
177           callFunction(method, stack);
178         } break;
179         case LOAD:
180           stack.emplace_back(reg(inst.X));
181           frame.step();
182           break;
183         case MOVE:
184           stack.emplace_back(std::move(reg(inst.X)));
185           frame.step();
186           break;
187         case STORE:
188           reg(inst.X) = pop(stack);
189           frame.step();
190           break;
191         case STOREN:
192           for (size_t i = inst.N; i > 0; --i) {
193             reg(inst.X + i - 1) = pop(stack);
194           }
195           frame.step();
196           break;
197         case DROP:
198           pop(stack);
199           frame.step();
200           break;
201         case DROPR:
202           reg(inst.X) = IValue();
203           frame.step();
204           break;
205         case LOADC:
206           if (inst.X < 0 ||
207               static_cast<size_t>(inst.X) >= code.constants_.size()) {
208             TORCH_CHECK(false, "Can't load constant with index: ", inst.X);
209           }
210           stack.emplace_back(code.constants_[inst.X]);
211           frame.step();
212           break;
213         case GET_ATTR: {
214           auto userObj = pop(stack).toObject();
215           auto value = userObj->getSlot(inst.X);
216           push(stack, std::move(value));
217           frame.step();
218         } break;
219         case SET_ATTR: {
220           auto v = pop(stack);
221           auto userObj = pop(stack).toObject();
222           // Mobile only: since the number of slots is not known, resize the
223           // numAttributes before setSlot.
224           while (static_cast<int>(userObj->type()->numAttributes()) <= inst.X) {
225             std::stringstream ss;
226             ss << userObj->type()->numAttributes();
227             userObj->type()->addAttribute(ss.str(), c10::NoneType::get());
228           }
229           userObj->setSlot(inst.X, std::move(v));
230           frame.step();
231         } break;
232         case JF:
233           frame.jump(pop(stack).toBool() ? 1 : inst.X);
234           break;
235         case JMP:
236           frame.jump(inst.X);
237           break;
238         case LOOP: {
239           // stack: iteration_count, max_iter, cond, loop_carried_deps...
240           auto sframe = stack.end() - (inst.N + 1);
241           int64_t trip_count = sframe[0].toInt();
242           int64_t max_trip_count = sframe[1].toInt();
243           bool cond = sframe[2].toBool();
244           if (trip_count < max_trip_count && cond) {
245             sframe[2] = trip_count;
246             sframe[0] = trip_count + 1;
247             frame.step();
248           } else {
249             size_t n_loop_carried = inst.N - 2;
250             for (const auto i : c10::irange(n_loop_carried)) {
251               sframe[i] = std::move(sframe[i + 3]);
252             }
253             drop(stack, 3); // iteration_count, max_iter, cond
254             frame.jump(inst.X);
255           }
256         } break;
257         case RET:
258           leaveFrame();
259           if (!frames_.empty()) {
260             continue;
261           }
262           return false;
263         case LIST_CONSTRUCT: {
264           listConstruct(stack, *code.types_.at(inst.X), inst.N);
265           frame.step();
266         } break;
267         case LIST_UNPACK: {
268           listUnpack(stack, inst.X);
269           frame.step();
270         } break;
271         case TUPLE_CONSTRUCT: {
272           tupleConstruct(stack, inst.X);
273           frame.step();
274         } break;
275         case TUPLE_SLICE: {
276           tupleSlice(stack, inst.X, inst.X + inst.N);
277           frame.step();
278         } break;
279         case TUPLE_INDEX: {
280           tupleIndex(stack);
281           frame.step();
282         } break;
283         case RAISE_EXCEPTION: {
284           raiseExceptionWithMessage(stack);
285           frame.step();
286         } break;
287         case __IS__: {
288           is(stack);
289           frame.step();
290         } break;
291         case UN_INITIALIZED: {
292           unInitialized(stack);
293           frame.step();
294         } break;
295         case __ISNOT__: {
296           isNot(stack);
297           frame.step();
298         } break;
299         case FORMAT: {
300           format(stack, inst.X);
301           frame.step();
302         } break;
303         case DEVICE: {
304           device(stack);
305           frame.step();
306         } break;
307         case DTYPE: {
308           dtype(stack);
309           frame.step();
310         } break;
311         case DIM: {
312           dim(stack);
313           frame.step();
314         } break;
315         case __NOT__: {
316           _not(stack);
317           frame.step();
318         } break;
319         case DICT_INDEX: {
320           dictIndex(stack);
321           frame.step();
322         } break;
323         case TO_LIST: {
324           toList(stack);
325           frame.step();
326         } break;
327         case NUM_TO_TENSOR: {
328           numToTensorScalar(stack);
329           frame.step();
330         } break;
331         case IS_CUDA: {
332           isCuda(stack);
333           frame.step();
334         } break;
335         case DICT_CONSTRUCT: {
336           dictConstruct(stack, *code.types_.at(inst.X), inst.N);
337           frame.step();
338         } break;
339         case NAMED_TUPLE_CONSTRUCT: {
340           namedTupleConstruct(stack, code.types_.at(inst.X), inst.N);
341           frame.step();
342         } break;
343         case CREATE_OBJECT: {
344           auto type = code.types_.at(inst.X)->expect<c10::ClassType>();
345           createObject(stack, type);
346           frame.step();
347         } break;
348         case ISINSTANCE: {
349           at::ArrayRef<TypePtr> types(&code.types_.at(inst.X), inst.N);
350           isinstance(stack, types);
351           frame.step();
352         } break;
353         case WARN: {
354           drop(stack, 1);
355           // Note: Please don't move the pop(stack) code below into the
356           // TORCH_WARN macro since TORCH_WARN fails to evaluate its arguments
357           // when STRIP_ERROR_MESSAGES is defined (which happens for production
358           // mobile builds). This will cause the stack to be in an inconsistent
359           // state. It has previously resulted in a SEV (S22350).
360           TORCH_WARN(stack.back().toStringRef());
361           stack.pop_back();
362           frame.step();
363         } break;
364         default:
365           AT_ERROR(toString(inst.op), " is invalid.");
366       }
367 
368       if (!prev_value) {
369         enableRecordFunction(false);
370       }
371       // This exception must be caught first as it derived from c10::Error
372     } catch (c10::BackendRuntimeException& e) {
373       saveExceptionDebugHandles();
374       TORCH_RETHROW(e);
375     } catch (c10::Error& error) {
376       // Reason for catching and rethrowing the error is so that we can
377       // set the exception pc that is queried later
378       saveExceptionDebugHandles();
379       TORCH_RETHROW(error);
380     } catch (...) {
381       saveExceptionDebugHandles();
382       throw;
383     }
384     //  for (auto val : stack) {
385     //    if (val.isTensor()) {
386     //      std::cout << val.toTensor().sizes() << std::endl;
387     //    } else {
388     //      std::cout << val << std::endl;
389     //    }
390     //  }
391   }
392   return false;
393 }
394 
reg(size_t reg)395 IValue& InterpreterState::reg(size_t reg) {
396   TORCH_CHECK(
397       reg > 0 && reg <= registers_.size(), "Invalid register index: ", reg);
398   return *(registers_.end() - reg);
399 }
400 
401 } // namespace mobile
402 } // namespace torch::jit
403