// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/BatchRulesHelper.h>

#include <torch/library.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <c10/util/irange.h>
#include <ATen/FuncTorchTLS.h>
#include <iostream>

namespace at::functorch {

void setDynamicLayerFrontBackKeysIncluded(bool included) {
  c10::impl::tls_set_dispatch_key_included(DispatchKey::FuncTorchDynamicLayerFrontMode, included);
  c10::impl::tls_set_dispatch_key_included(DispatchKey::FuncTorchDynamicLayerBackMode, included);
}

DynamicLayer::DynamicLayer(
    TransformType transform_type,
    int64_t layerId,
    std::optional<c10::SymInt> batchSize,
    std::optional<RandomnessType> randomness,
    std::optional<bool> prev_grad_mode,
    std::optional<bool> prev_fwd_grad_mode,
    std::optional<bool> functionalize_add_back_views)
{
  if (transform_type == TransformType::Grad) {
    TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
  }
  if (transform_type == TransformType::Jvp) {
    TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
  }
  switch (transform_type) {
    case TransformType::Vmap:
      interpreter_ = Interpreter::Vmap(layerId, std::move(batchSize.value()), randomness.value());
      break;
    case TransformType::Grad:
      interpreter_ = Interpreter::Grad(layerId, prev_grad_mode.value());
      break;
    case TransformType::Jvp:
      interpreter_ = Interpreter::Jvp(layerId, prev_fwd_grad_mode.value());
      break;
    case TransformType::Functionalize:
      interpreter_ = Interpreter::Functionalize(layerId, functionalize_add_back_views.value());
      break;
    default:
      TORCH_INTERNAL_ASSERT(false);
  }
}

TransformType DynamicLayer::key() const {
  return interpreter_.key();
}

int64_t DynamicLayer::layerId() const {
  return interpreter_.level();
}

c10::SymInt DynamicLayer::batchSize() const {
  return VmapInterpreterPtr(&interpreter_).batchSize();
}

RandomnessType DynamicLayer::randomness() const {
  return VmapInterpreterPtr(&interpreter_).randomness();
}

// functorch stores some TLS. Inside the TLS is the stack of transforms.
// Unfortunately, since functorch isn't a part of libtorch, we have
// a level of indirection. FuncTorchTLSBase is the interface that lives in libtorch,
// while FuncTorchTLS implements all the methods and stores data.
//
// TODO: after functorch C++ code is moved into PyTorch, we can get rid of
// this layer of indirection.
class FuncTorchTLS : public FuncTorchTLSBase {
 public:
  FuncTorchTLS() = default;

  std::unique_ptr<FuncTorchTLSBase> deepcopy() const override {
    auto result = std::make_unique<FuncTorchTLS>();
    result->dynamicLayerStack = dynamicLayerStack;
    return result;
  }

  int64_t checkSupportsSingleLevelAutogradFunction() const override {
    TORCH_INTERNAL_ASSERT(dynamicLayerStack.empty() || getSingleLevelAutogradFunctionAllowed(),
        "functorch functions (vmap, grad, vjp, etc.) incorrectly used with ",
        "torch.autograd.function._SingleLevelFunction. ",
        "This is not expected, please file a bug.");
    return 0;
  }

  void checkSupportsCppAutogradFunction() const override {
    TORCH_CHECK(
        dynamicLayerStack.empty(),
        "cannot use C++ torch::autograd::Function with functorch transforms (vmap, grad, vjp, etc)");
  }

  void checkSupportsInplaceRequiresGrad() const override {
    TORCH_CHECK(dynamicLayerStack.empty() || allow_inplace_requires_grad_,
        "You are attempting to call Tensor.requires_grad_() (or perhaps using ",
        "torch.autograd.functional.* APIs) inside of a function being transformed ",
        "by a functorch transform. ",
        "This is unsupported, please attempt to use the functorch transforms ",
        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() "
        "outside of a function being transformed instead.");
  }
  void checkSupportsRetainGrad() const override {
    TORCH_CHECK(dynamicLayerStack.empty(),
        "You are attempting to call Tensor.retain_grad() ",
        "inside of a function being transformed ",
        "by a functorch transform. ",
        "This is unsupported, please attempt to use the functorch transforms ",
        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call retain_grad() "
        "outside of a function being transformed instead.");
  }

  std::vector<DynamicLayer> dynamicLayerStack;
  bool allow_inplace_requires_grad_ = false;
  bool allow_single_level_autograd_function_ = false;
};

static FuncTorchTLS* getRawFunctorchTLS() {
  auto& state = functorchTLSAccessor();
  if (state == nullptr) {
    state = std::make_unique<FuncTorchTLS>();
  }
  // Raw pointer usage OK, `state` keeps the pointer alive
  FuncTorchTLSBase* raw_state = state.get();
  FuncTorchTLS* result = static_cast<FuncTorchTLS*>(raw_state);
  return result;
}

void setInplaceRequiresGradAllowed(bool allowed) {
  auto* functorch_tls = getRawFunctorchTLS();
  functorch_tls->allow_inplace_requires_grad_ = allowed;
}

bool getInplaceRequiresGradAllowed() {
  auto* functorch_tls = getRawFunctorchTLS();
  return functorch_tls->allow_inplace_requires_grad_;
}

void setSingleLevelAutogradFunctionAllowed(bool allowed) {
  auto* functorch_tls = getRawFunctorchTLS();
  functorch_tls->allow_single_level_autograd_function_ = allowed;
}

bool getSingleLevelAutogradFunctionAllowed() {
  auto* functorch_tls = getRawFunctorchTLS();
  return functorch_tls->allow_single_level_autograd_function_;
}

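// Usage sketch for the two flag pairs above (illustrative only; the real call
// sites live outside this file, typically behind a Python-level context
// manager): callers are expected to save and restore the previous value, e.g.
//
//   bool prev = getInplaceRequiresGradAllowed();
//   setInplaceRequiresGradAllowed(true);
//   // ... run code that needs requires_grad_() under a transform ...
//   setInplaceRequiresGradAllowed(prev);
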
static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
  return getRawFunctorchTLS()->dynamicLayerStack;
}

const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level) {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  TORCH_INTERNAL_ASSERT(
      (int64_t)dynamicLayerStack.size() >= level && level >= 1,
      "If you're trying to construct a tensor with the current level (",
      level,
      ") then the interpreter for that level must be on the DynamicLayerStack ");

  auto& dynamic_layer = dynamicLayerStack[level - 1];
  return dynamic_layer.interpreter().is_alive_ptr();
}

std::optional<DynamicLayer> maybeCurrentDynamicLayer() {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  if (dynamicLayerStack.empty()) {
    return {};
  }
  return dynamicLayerStack.back();
}

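// RAII helper: the constructor stashes the thread's current LocalDispatchKeySet
// into the interpreter of the topmost DynamicLayer; the destructor reads the
// saved key set back out, clears it, and force-restores it as the thread's
// LocalDispatchKeySet. See dynamicLayerFrontFallback below for how this is used
// when dispatching through the interpreter stack.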
struct SaveLocalDispatchKeySet {
 public:
  SaveLocalDispatchKeySet() {
    auto& dynamicLayerStack = dynamicLayerStackAccessor();
    TORCH_INTERNAL_ASSERT(!dynamicLayerStack.empty());
    auto& layer = dynamicLayerStack.back();
    auto tmp = c10::impl::tls_local_dispatch_key_set();
    layer.interpreter().saveLocalDispatchKeySet(tmp);
  }
  ~SaveLocalDispatchKeySet() {
    auto& dynamicLayerStack = dynamicLayerStackAccessor();
    TORCH_INTERNAL_ASSERT(!dynamicLayerStack.empty());
    auto& layer = dynamicLayerStack.back();
    auto tmp = layer.interpreter().getSavedLocalDispatchKeySet();
    layer.interpreter().clearSavedLocalDispatchKeySet();
    c10::impl::_force_tls_local_dispatch_key_set(tmp);
  }
  SaveLocalDispatchKeySet(const SaveLocalDispatchKeySet&) = delete;
  SaveLocalDispatchKeySet& operator=(const SaveLocalDispatchKeySet&) = delete;
};

const std::vector<DynamicLayer>& getDynamicLayerStack() {
  return dynamicLayerStackAccessor();
}

void setDynamicLayerStack(const std::vector<DynamicLayer>& stack) {
  dynamicLayerStackAccessor() = stack;
}

DynamicLayer popDynamicLayer() {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  TORCH_INTERNAL_ASSERT(!dynamicLayerStack.empty());
  auto result = dynamicLayerStack.back();
  dynamicLayerStack.pop_back();

  if (dynamicLayerStack.empty()) {
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
    if (c10::show_dispatch_trace_enabled()) {
      std::cout << "DynamicLayer off" << std::endl;
    }
#endif
    setDynamicLayerFrontBackKeysIncluded(false);
  }

  return result;
}

int64_t pushDynamicLayer(DynamicLayer&& dynamic_layer) {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  int64_t layerId = 1 + dynamicLayerStack.size();
  TORCH_INTERNAL_ASSERT(layerId == dynamic_layer.layerId());
  dynamicLayerStack.emplace_back(std::move(dynamic_layer));

  if (layerId == 1) {
    setDynamicLayerFrontBackKeysIncluded(true);
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
    if (c10::show_dispatch_trace_enabled()) {
      std::cout << "DynamicLayer on" << std::endl;
    }
#endif
  }

  return layerId;
}

int64_t initAndPushDynamicLayer(
    TransformType transform_type,
    std::optional<c10::SymInt> batch_size,
    std::optional<RandomnessType> randomness,
    std::optional<bool> prev_grad_mode,
    std::optional<bool> prev_fwd_grad_mode,
    std::optional<bool> functionalize_add_back_views) {
  const auto& dynamicLayerStack = dynamicLayerStackAccessor();
  const auto layerId = 1 + dynamicLayerStack.size();
  DynamicLayer new_layer(transform_type, layerId, std::move(batch_size), randomness, prev_grad_mode, prev_fwd_grad_mode, functionalize_add_back_views);
  // NB: this function should be called while holding the GIL to avoid races
  new_layer.interpreter().set_is_alive(true);
  pushDynamicLayer(std::move(new_layer));

  if (transform_type == TransformType::Grad) {
    TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
  }
  if (transform_type == TransformType::Jvp) {
    TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
  }
  return layerId;
}

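// Usage sketch (illustrative only; the real call sites are the transform entry
// points and Python bindings, not this file): a transform such as grad is
// expected to bracket the user function with a push/pop pair along the lines of
//
//   int64_t level = initAndPushDynamicLayer(
//       TransformType::Grad,
//       /*batch_size=*/std::nullopt,
//       /*randomness=*/std::nullopt,
//       /*prev_grad_mode=*/c10::GradMode::is_enabled(),
//       /*prev_fwd_grad_mode=*/std::nullopt,
//       /*functionalize_add_back_views=*/std::nullopt);
//   // ... wrap inputs at `level`, run the user function, unwrap outputs ...
//   popDynamicLayerAndDeleteMetadata();
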
DynamicLayer popDynamicLayerAndDeleteMetadata() {
  auto result = popDynamicLayer();

  // NB: this function should be called while holding the GIL to avoid races
  result.interpreter().set_is_alive(false);
  return result;
}

bool isDeadTensorWrapper(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    return false;
  }
  return !wrapped->is_alive();
}

Tensor unwrapIfDead(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    return tensor;
  }
  if (wrapped->is_alive()) {
    return tensor;
  }
  return wrapped->value();
}

void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
    std::function<Tensor(const Tensor&)> func) {
  auto func_with_bool = [&](const Tensor& tensor, bool unused) { return func(tensor); };
  foreachTensorInplaceWithFlag(args, begin, end, std::bitset<64>(), func_with_bool);
}

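// Note on the bitset (describing the loop below): bit i of use_flag_relative
// corresponds to the i-th argument in [begin, end) and is forwarded to `func`
// as its bool argument, letting callers single out specific arguments without
// changing the traversal.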
void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
    const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func) {
  TORCH_INTERNAL_ASSERT(begin >= 0);
  TORCH_INTERNAL_ASSERT(end >= 0);
  TORCH_INTERNAL_ASSERT(begin <= end);
  for (int64_t relative_idx = 0; relative_idx < end - begin; relative_idx++) {
    const bool flag = use_flag_relative[relative_idx] == 1;

    const auto idx = relative_idx + begin;
    auto ivalue = args[idx];
    // Tensor?[] translates to a c10::List<IValue> so we need to peek inside List
    if (ivalue.isList()) {
      bool modified = false;
      // TODO: might be more efficient if we scan first then not copy? Depends.
      auto list = ivalue.toList().copy();
      for (const auto list_idx : c10::irange(0, list.size())) {
        const auto& elt = list.get(list_idx);
        if (elt.isTensor()) {
          list.set(list_idx, func(elt.toTensor(), flag));
          modified = true;
        }
      }
      if (modified) {
        args[idx] = list;
      }
      continue;
    }
    if (ivalue.isTensorList()) {
      auto list = ivalue.toTensorList();
      for (const auto list_idx : c10::irange(0, list.size())) {
        list[list_idx] = func(list[list_idx], flag);
      }
      args[idx] = list;
    }
    TORCH_INTERNAL_ASSERT(!ivalue.isGenericDict(), "No operators can accept GenericDict");
    if (!ivalue.isTensor()) {
      continue;
    }
    Tensor value = ivalue.toTensor();
    Tensor replacement = func(value, flag);
    args[idx] = std::move(replacement);
    // sanity checks
    if (ivalue.toTensor().defined()) {
      TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());
    }
  }
}

std::ostream& operator<< (std::ostream& os, const DynamicLayer& layer) {
  os << layer.layerId() << ":" << layer.key();
  return os;
}
std::ostream& operator<< (std::ostream& os, const std::vector<DynamicLayer>& dls) {
  os << "DynamicLayerStack[ ";
  for (const auto& layer : dls) {
    os << layer << " ";
  }
  os << "]";
  return os;
}

bool isInplaceOp(const FunctionSchema& schema) {
  if (!schema.is_mutable() || schema.returns().size() != 1) {
    return false;
  }
  // Check that the first argument is being written to
  const auto& first_arg_alias_info = schema.arguments().begin()->alias_info();
  if (!first_arg_alias_info || !first_arg_alias_info->isWrite()) {
    return false;
  }
  // Check that none of the other args are being aliased
  for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) {
    const auto& alias_info = it->alias_info();
    if (alias_info) {
      return false;
    }
  }
  // Check that the first tensor is being returned (i.e., output has a (a!))
  const auto& return_alias_info = schema.returns()[0].alias_info();
  return return_alias_info && return_alias_info->isWrite();
}

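// Illustrative schemas (quoted from native_functions.yaml as a reading aid for
// the checks above):
// - "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
//   qualifies: the first argument is written to and is the value returned.
// - "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
//   does not: the schema is mutable, but the written-to argument is `out`,
//   not the first argument.
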
std::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input_idx) {
  for (size_t res_idx = 0; res_idx != schema.returns().size(); ++res_idx) {
    if (schema.may_contain_alias(SchemaArgument(SchemaArgType::input, immutable_input_idx), SchemaArgument(SchemaArgType::output, res_idx))) {
      return res_idx; // for everything currently in native_functions, each input aliases at most one output (tensor list counts as one output)
    }
  }
  return std::nullopt;
}

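// Illustrative example (reading aid only): for the view schema
// "transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)",
// findAliasedOutput(schema, 0) returns 0 because input 0 may alias output 0;
// for "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
// it returns std::nullopt since no output aliases any input.
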
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
static void dump_local_tls() {
  auto tls = c10::impl::tls_local_dispatch_key_set();
  std::cout << "[Local Include] " << tls.included_ << std::endl;
  std::cout << "[Local Exclude] " << tls.excluded_ << std::endl;
}
#endif

struct WithoutTop {
  WithoutTop();
  ~WithoutTop();
  DynamicLayer layer_;
};

WithoutTop::WithoutTop(): layer_(popDynamicLayer()) {}
WithoutTop::~WithoutTop() {
  pushDynamicLayer(std::move(layer_));
}

// NOTE: [functorch front and back key fallbacks]
//
// Please read NOTE: [functorch interpreter stack] first for some context.
// The following doc also provides some visuals:
// https://docs.google.com/document/d/14qyaa3xIjmVxYiMLlIlQErunYgR_uR1WupsKMZlnGY4/edit
//
// functorch's "stack of transforms" is implemented as the following:
// - each transform is associated with one or more dispatch keys in the PyTorch
//   dispatcher. For example, vmap -> {FuncTorchBatched, FuncTorchVmapMode},
//   Autograd -> {Autograd{Backend}, ADInplaceOrView}
// - Whenever a functorch transform is active, the FuncTorchDynamicLayer{Front, Back}Mode
//   keys are added to the dispatcher's local dispatch key set.
//
// DynamicLayerFrontMode is responsible for:
// 1. selecting the transform that is at the top of the stack and grabbing its
//    interpreter
// 2. Calling interpreter.process(), which does the following:
// 2a. enables/disables a bunch of dispatch keys, so that the only dispatch
//     keys that are enabled are the ones that belong to the transform.
// 2b. redispatching
//
// Eventually, DynamicLayerBackMode captures the redispatch from the transforms.
// DynamicLayerBackMode is responsible for:
// - redirecting back to DynamicLayerFrontMode

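// Illustrative walkthrough (a sketch, not code in this file): under
// vmap(grad(f))(x), vmap pushes level 1 and grad pushes level 2, so the stack
// is [Vmap(1), Grad(2)]. When an operator is called inside f,
// DynamicLayerFrontMode selects the top interpreter (Grad), which does its
// autograd bookkeeping and redispatches; DynamicLayerBackMode then pops that
// layer and sends the op to the next interpreter (Vmap), which applies the
// batching rule before the op reaches a regular backend kernel.
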
static void dynamicLayerFrontFallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  TORCH_INTERNAL_ASSERT(!dynamicLayerStack.empty());
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
  if (c10::show_dispatch_trace_enabled()) {
    std::cout << dynamicLayerStack << std::endl;
    dump_local_tls();
  }
#endif
  // Save the current LocalDispatchKeySet (to the current DynamicLayer).
  // Upon exiting the current scope, that LocalDispatchKeySet gets restored.
  // When the current DynamicLayer dispatches to the next (inner) DynamicLayer,
  // it will also temporarily restore the saved LocalDispatchKeySet.
  SaveLocalDispatchKeySet guard;

  // Unwrap escaped GradWrappers
  auto num_args = op.schema().arguments().size();
  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), unwrapIfDead);

  auto& layer = dynamicLayerStack.back();
  layer.interpreter().process(op, stack);
}

static c10::impl::ForceDispatchKeyGuard
restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
  return c10::impl::ForceDispatchKeyGuard(key_set);
}

// Right now, grad_special_case as a bool is sufficient because this is the only
// special case for grad. If we need to add more special cases, it's more scalable
// to add an enum to know which op we're looking at without looking at the schema.
static void dynamicLayerBack(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case) {
  auto restore_guard = restoreLocalDispatchKeySetRAII(
      dynamicLayerStackAccessor().back().interpreter().getSavedLocalDispatchKeySet());
  WithoutTop guard;

  // WithoutTop stores the popped DynamicLayer object.
  guard.layer_.interpreter().sendToNextInterpreter(op, stack, grad_special_case);
}

// Used for functions that have aliasing operations but should be treated like
// they're out of place (e.g. lift_fresh).
static void dynamicLayerBackGradSpecialCase(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  return dynamicLayerBack(op, stack, true);
}

static void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  return dynamicLayerBack(op, stack, false);
}

TORCH_LIBRARY_IMPL(_, FuncTorchDynamicLayerFrontMode, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>());
}

TORCH_LIBRARY_IMPL(_, FuncTorchDynamicLayerBackMode, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}

#define SPECIAL_GRAD_CASE(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackGradSpecialCase>());

TORCH_LIBRARY_IMPL(aten, FuncTorchDynamicLayerBackMode, m) {
  // lift_fresh: it must be freshly allocated and should be wrapped. User shouldn't have access to input version
  // alias: this is needed for the CompositeImplicit instance norm (running_mean/var get set to be a wrapped value)
  //        It's not a user-facing function, but is more prone to possible errors
  SPECIAL_GRAD_CASE(lift_fresh);
  SPECIAL_GRAD_CASE(alias);
}

} // namespace at::functorch