// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/VmapInterpreter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/functorch/VmapInterpreter.h>
2 #include <ATen/functorch/DynamicLayer.h>
3 
4 namespace at::functorch {
5 
processImpl(const c10::OperatorHandle & op,torch::jit::Stack * stack)6 void VmapInterpreterPtr::processImpl(
7     const c10::OperatorHandle& op,
8     torch::jit::Stack* stack) {
9   setup_dispatch_key_tls(TransformType::Vmap, DispatchKeySet(DispatchKey::FuncTorchVmapMode));
10   op.callBoxed(stack);
11 }
12 
sendToNextInterpreterImpl(const c10::OperatorHandle & op,torch::jit::Stack * stack,bool grad_special_case)13 void VmapInterpreterPtr::sendToNextInterpreterImpl(
14     const c10::OperatorHandle& op,
15     torch::jit::Stack* stack,
16     bool grad_special_case) {
17   // Re-dispatch
18   if (getDynamicLayerStack().empty()) {
19     sanityCheckStack(op, stack);
20   }
21   op.callBoxed(stack);
22 }
23 
24 } // namespace at::functorch
25