#include <ATen/functorch/FunctionalizeInterpreter.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/FunctionalTensorWrapper.h>

namespace at::functorch {

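// Asserts that the top `num_args` entries on the stack contain no FunctionalTensors.
// Used below to check that the functionalize kernels have fully unwrapped their
// tensors before control is handed to the next interpreter.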
static void sanityCheckNotFunctional(const c10::OperatorHandle& op, torch::jit::Stack* stack, size_t num_args) {
  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
      [](const Tensor& tensor) {
        TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensor));
        return tensor;
      });
}

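// Runs `op` through the functionalization kernels, then stamps the current
// functorch level onto any FunctionalTensor outputs.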
void FunctionalizeInterpreterPtr::processImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  // We always want to call the functionalization kernels if functionalize() is on the layer stack.
  // It's the responsibility of the functionalization kernel to no-op and redispatch
  // if none of the input tensors are functional.
  setup_dispatch_key_tls(TransformType::Functionalize, DispatchKeySet(DispatchKey::Functionalize));
  auto functionalization_add_back_views = functionalizeAddBackViews();
  // We have some side-car TLS that we can set to toggle the functionalization behavior.
  // If set, functionalization will only remove mutations, instead of removing
  // both mutations AND view operators.
  at::functionalization::impl::FunctionalizationReapplyViewsGuard functional_guard(functionalization_add_back_views);

  op.callBoxed(stack);

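  // Post-process the outputs: every FunctionalTensor we return needs to be
  // tagged with this interpreter's level.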
  auto ret_size = op.schema().returns().size();
  foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(),
    [&](const Tensor& tensor) {
      if (at::functionalization::impl::isFunctionalTensor(tensor)) {
        auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
        // Functorch is responsible for setting the level on the wrapper, since we don't
        // have that info available in core (for now).
        // We could just "propagate" the level from the input tensors inside of the functionalize kernels,
        // but unfortunately we can't do that for factory operators.
        wrapper->set_level(level());
      }
      return tensor;
    }
  );
}

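// Called when this op leaves the Functionalize interpreter and is handed to the
// next interpreter in the dynamic layer stack.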
void FunctionalizeInterpreterPtr::sendToNextInterpreterImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    bool grad_special_case) {
  // For now, we don't support nested functionalization calls.
  // This check just enforces that: after the functionalize kernel runs
  // and we hit the BackModeFallback, we'll have unwrapped our FunctionalTensors,
  // so we can check that the unwrapped thing is not another (nested) FunctionalTensor.
  auto args_size = op.schema().arguments().size();
  sanityCheckNotFunctional(op, stack, args_size);

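  // If the dynamic layer stack is empty, this op has left functorch entirely,
  // so no arguments should still be wrapped in functorch wrappers.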
  // Re-dispatch
  if (getDynamicLayerStack().empty()) {
    sanityCheckStack(op, stack);
  }
  op.callBoxed(stack);

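  // The outputs coming back from the redispatch must also be plain
  // (non-functional) tensors.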
  auto ret_size = op.schema().returns().size();
  sanityCheckNotFunctional(op, stack, ret_size);
}

} // namespace at::functorch