xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/FunctionalizeInterpreter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/functorch/FunctionalizeInterpreter.h>
2 #include <ATen/functorch/DynamicLayer.h>
3 #include <ATen/FunctionalTensorWrapper.h>
4 
5 namespace at::functorch {
6 
sanityCheckNotFunctional(const c10::OperatorHandle & op,torch::jit::Stack * stack,size_t num_args)7 static void sanityCheckNotFunctional(const c10::OperatorHandle& op, torch::jit::Stack* stack, size_t num_args) {
8   foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
9       [](const Tensor& tensor) {
10         TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensor));
11         return tensor;
12       });
13 }
14 
processImpl(const c10::OperatorHandle & op,torch::jit::Stack * stack)15 void FunctionalizeInterpreterPtr::processImpl(
16     const c10::OperatorHandle& op,
17     torch::jit::Stack* stack) {
18   // We always want to call the functionalization kernels if functionalize() is on the layer stack.
19   // It's the responsibility of the functionalization kernel to no-op and redispatch
20   // if none of the input tensors are functional.
21   setup_dispatch_key_tls(TransformType::Functionalize, DispatchKeySet(DispatchKey::Functionalize));
22   auto functionalization_add_back_views = functionalizeAddBackViews();
23   // We have some side-car TLS that we can set to toggle the functionaliation behavior.
24   // If set, then we functionalization will only remove mutations, instead of
25   // removing both mutations AND view operators.
26   at::functionalization::impl::FunctionalizationReapplyViewsGuard functional_guard(functionalization_add_back_views);
27 
28   op.callBoxed(stack);
29 
30   auto ret_size = op.schema().returns().size();
31   foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(),
32     [&](const Tensor& tensor) {
33       if (at::functionalization::impl::isFunctionalTensor(tensor)) {
34         auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
35         // Functorch is responsible for setting the level on the wrapper, since we don't
36         // have that info available in core (for now).
37         // We could just "propagate" the level from the input tensors inside of the functionalize kernels,
38         // but unfortunately we can't do that for factory operators.
39         wrapper->set_level(level());
40       }
41       return tensor;
42     }
43   );
44 }
45 
sendToNextInterpreterImpl(const c10::OperatorHandle & op,torch::jit::Stack * stack,bool grad_special_case)46 void FunctionalizeInterpreterPtr::sendToNextInterpreterImpl(
47     const c10::OperatorHandle& op,
48     torch::jit::Stack* stack,
49     bool grad_special_case) {
50   // For now, we don't support nested functionalization calls.
51   // This check just enforces that - after the functionalize kernel runs
52   // and we hit the BackModeFallback, we'll have unwrapped our FunctionalTensors
53   // so we can check that the unwrapped thing is not another (nested) FunctionalTensor.
54   auto args_size = op.schema().arguments().size();
55   sanityCheckNotFunctional(op, stack, args_size);
56 
57   // Re-dispatch
58   if (getDynamicLayerStack().empty()) {
59     sanityCheckStack(op, stack);
60   }
61   op.callBoxed(stack);
62 
63   auto ret_size = op.schema().returns().size();
64   sanityCheckNotFunctional(op, stack, ret_size);
65 }
66 
67 } // namespace at::functorch
68