// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/BatchRulesHelper.h>

#include <torch/library.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <c10/util/irange.h>
#include <ATen/FuncTorchTLS.h>
#include <iostream>

namespace at::functorch {

void setDynamicLayerFrontBackKeysIncluded(bool included) {
  c10::impl::tls_set_dispatch_key_included(DispatchKey::FuncTorchDynamicLayerFrontMode, included);
  c10::impl::tls_set_dispatch_key_included(DispatchKey::FuncTorchDynamicLayerBackMode, included);
}

DynamicLayer::DynamicLayer(
    TransformType transform_type,
    int64_t layerId,
    std::optional<c10::SymInt> batchSize,
    std::optional<RandomnessType> randomness,
    std::optional<bool> prev_grad_mode,
    std::optional<bool> prev_fwd_grad_mode,
    std::optional<bool> functionalize_add_back_views)
{
  if (transform_type == TransformType::Grad) {
    TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
  }
  if (transform_type == TransformType::Jvp) {
    TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
  }
  switch (transform_type) {
    case TransformType::Vmap:
      interpreter_ = Interpreter::Vmap(layerId, std::move(batchSize.value()), randomness.value());
      break;
    case TransformType::Grad:
      interpreter_ = Interpreter::Grad(layerId, prev_grad_mode.value());
      break;
    case TransformType::Jvp:
      interpreter_ = Interpreter::Jvp(layerId, prev_fwd_grad_mode.value());
      break;
    case TransformType::Functionalize:
      interpreter_ = Interpreter::Functionalize(layerId, functionalize_add_back_views.value());
      break;
    default:
      TORCH_INTERNAL_ASSERT(false);
  }
}

TransformType DynamicLayer::key() const {
  return interpreter_.key();
}

int64_t DynamicLayer::layerId() const {
  return interpreter_.level();
}

c10::SymInt DynamicLayer::batchSize() const {
  return VmapInterpreterPtr(&interpreter_).batchSize();
}

RandomnessType DynamicLayer::randomness() const {
  return VmapInterpreterPtr(&interpreter_).randomness();
}

// functorch stores some TLS. Inside the TLS is the stack of transforms.
// Unfortunately, since functorch isn't a part of libtorch, we have
// a level of indirection. FuncTorchTLSBase is the interface that lives in libtorch,
// while FuncTorchTLS implements all the methods and stores data.
//
// TODO: after functorch C++ code is moved into PyTorch, we can get rid of
// this layer of indirection.
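//
// For example (an illustrative sketch, not the exact call sites): autograd
// code in libtorch only sees the FuncTorchTLSBase interface and calls hooks
// roughly like functorchTLSAccessor()->checkSupportsRetainGrad(); the
// concrete FuncTorchTLS defined below is what actually stores the
// DynamicLayer stack.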
class FuncTorchTLS : public FuncTorchTLSBase {
 public:
  FuncTorchTLS() = default;

  std::unique_ptr<FuncTorchTLSBase> deepcopy() const override {
    auto result = std::make_unique<FuncTorchTLS>();
    result->dynamicLayerStack = dynamicLayerStack;
    return result;
  }

  int64_t checkSupportsSingleLevelAutogradFunction() const override {
    TORCH_INTERNAL_ASSERT(dynamicLayerStack.empty() || getSingleLevelAutogradFunctionAllowed(),
        "functorch functions (vmap, grad, vjp, etc.) incorrectly used with ",
        "torch.autograd.function._SingleLevelFunction. ",
        "This is not expected, please file a bug.");
    return 0;
  }

  void checkSupportsCppAutogradFunction() const override {
    TORCH_CHECK(
        dynamicLayerStack.empty(),
        "cannot use C++ torch::autograd::Function with functorch transforms (vmap, grad, vjp, etc)");
  }

  void checkSupportsInplaceRequiresGrad() const override {
    TORCH_CHECK(dynamicLayerStack.empty() || allow_inplace_requires_grad_,
        "You are attempting to call Tensor.requires_grad_() (or perhaps using ",
        "torch.autograd.functional.* APIs) inside of a function being transformed ",
        "by a functorch transform. ",
        "This is unsupported, please attempt to use the functorch transforms ",
        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() "
        "outside of a function being transformed instead.");
  }

  void checkSupportsRetainGrad() const override {
    TORCH_CHECK(dynamicLayerStack.empty(),
        "You are attempting to call Tensor.retain_grad() ",
        "inside of a function being transformed ",
        "by a functorch transform. ",
        "This is unsupported, please attempt to use the functorch transforms ",
        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call retain_grad() "
        "outside of a function being transformed instead.");
  }

  std::vector<DynamicLayer> dynamicLayerStack;
  bool allow_inplace_requires_grad_ = false;
  bool allow_single_level_autograd_function_ = false;
};

static FuncTorchTLS* getRawFunctorchTLS() {
  auto& state = functorchTLSAccessor();
  if (state == nullptr) {
    state = std::make_unique<FuncTorchTLS>();
  }
  // Raw pointer usage OK, `state` keeps the pointer alive
  FuncTorchTLSBase* raw_state = state.get();
  FuncTorchTLS* result = static_cast<FuncTorchTLS*>(raw_state);
  return result;
}

void setInplaceRequiresGradAllowed(bool allowed) {
  auto* functorch_tls = getRawFunctorchTLS();
  functorch_tls->allow_inplace_requires_grad_ = allowed;
}

bool getInplaceRequiresGradAllowed() {
  auto* functorch_tls = getRawFunctorchTLS();
  return functorch_tls->allow_inplace_requires_grad_;
}

void setSingleLevelAutogradFunctionAllowed(bool allowed) {
  auto* functorch_tls = getRawFunctorchTLS();
  functorch_tls->allow_single_level_autograd_function_ = allowed;
}

bool getSingleLevelAutogradFunctionAllowed() {
  auto* functorch_tls = getRawFunctorchTLS();
  return functorch_tls->allow_single_level_autograd_function_;
}

static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
  return getRawFunctorchTLS()->dynamicLayerStack;
}
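// Returns the "life handle" for the interpreter at `level`: a shared_ptr<bool>
// that is flipped to false when the corresponding transform exits
// (see popDynamicLayerAndDeleteMetadata below). Tensor wrappers hold onto this
// handle so they can tell whether the level they were created under is still
// alive (see isDeadTensorWrapper / unwrapIfDead below).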
const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level) {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  TORCH_INTERNAL_ASSERT(
      (int64_t)dynamicLayerStack.size() >= level && level >= 1,
      "If you're trying to construct a tensor with the current level (",
      level,
      ") then the interpreter for that level must be on the DynamicLayerStack ");

  auto& dynamic_layer = dynamicLayerStack[level - 1];
  return dynamic_layer.interpreter().is_alive_ptr();
}

std::optional<DynamicLayer> maybeCurrentDynamicLayer() {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  if (dynamicLayerStack.empty()) {
    return {};
  }
  return dynamicLayerStack.back();
}
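// RAII guard: the constructor stashes the current thread-local dispatch key
// set into the interpreter at the top of the DynamicLayer stack; the
// destructor force-restores that saved key set and clears it from the
// interpreter. This is how each layer remembers the TLS it was entered with.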
struct SaveLocalDispatchKeySet {
 public:
  SaveLocalDispatchKeySet() {
    auto& dynamicLayerStack = dynamicLayerStackAccessor();
    TORCH_INTERNAL_ASSERT(!dynamicLayerStack.empty());
    auto& layer = dynamicLayerStack.back();
    auto tmp = c10::impl::tls_local_dispatch_key_set();
    layer.interpreter().saveLocalDispatchKeySet(tmp);
  }
  ~SaveLocalDispatchKeySet() {
    auto& dynamicLayerStack = dynamicLayerStackAccessor();
    TORCH_INTERNAL_ASSERT(!dynamicLayerStack.empty());
    auto& layer = dynamicLayerStack.back();
    auto tmp = layer.interpreter().getSavedLocalDispatchKeySet();
    layer.interpreter().clearSavedLocalDispatchKeySet();
    c10::impl::_force_tls_local_dispatch_key_set(tmp);
  }
  SaveLocalDispatchKeySet(const SaveLocalDispatchKeySet&) = delete;
  SaveLocalDispatchKeySet& operator=(const SaveLocalDispatchKeySet&) = delete;
};

const std::vector<DynamicLayer>& getDynamicLayerStack() {
  return dynamicLayerStackAccessor();
}

void setDynamicLayerStack(const std::vector<DynamicLayer>& stack) {
  dynamicLayerStackAccessor() = stack;
}

DynamicLayer popDynamicLayer() {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  TORCH_INTERNAL_ASSERT(!dynamicLayerStack.empty());
  auto result = dynamicLayerStack.back();
  dynamicLayerStack.pop_back();

  if (dynamicLayerStack.empty()) {
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
    if (c10::show_dispatch_trace_enabled()) {
      std::cout << "DynamicLayer off" << std::endl;
    }
#endif
    setDynamicLayerFrontBackKeysIncluded(false);
  }

  return result;
}

int64_t pushDynamicLayer(DynamicLayer&& dynamic_layer) {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  int64_t layerId = 1 + dynamicLayerStack.size();
  TORCH_INTERNAL_ASSERT(layerId == dynamic_layer.layerId());
  dynamicLayerStack.emplace_back(std::move(dynamic_layer));

  if (layerId == 1) {
    setDynamicLayerFrontBackKeysIncluded(true);
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
    if (c10::show_dispatch_trace_enabled()) {
      std::cout << "DynamicLayer on" << std::endl;
    }
#endif
  }

  return layerId;
}

int64_t initAndPushDynamicLayer(
    TransformType transform_type,
    std::optional<c10::SymInt> batch_size,
    std::optional<RandomnessType> randomness,
    std::optional<bool> prev_grad_mode,
    std::optional<bool> prev_fwd_grad_mode,
    std::optional<bool> functionalize_add_back_views) {
  const auto& dynamicLayerStack = dynamicLayerStackAccessor();
  const auto layerId = 1 + dynamicLayerStack.size();
  DynamicLayer new_layer(transform_type, layerId, std::move(batch_size), randomness, prev_grad_mode, prev_fwd_grad_mode, functionalize_add_back_views);
  // NB: this function should be called while holding the GIL to avoid races
  new_layer.interpreter().set_is_alive(true);
  pushDynamicLayer(std::move(new_layer));

  if (transform_type == TransformType::Grad) {
    TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
  }
  if (transform_type == TransformType::Jvp) {
    TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
  }
  return layerId;
}

DynamicLayer popDynamicLayerAndDeleteMetadata() {
  auto result = popDynamicLayer();

  // NB: this function should be called while holding the GIL to avoid races
  result.interpreter().set_is_alive(false);
  return result;
}
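// A TensorWrapper is "dead" if the transform it was created under has already
// exited (its interpreter's life handle was flipped to false by
// popDynamicLayerAndDeleteMetadata). Dead wrappers can escape a transform,
// e.g. via captured state; unwrapIfDead peels them off so later dispatches
// see the underlying value instead.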
bool isDeadTensorWrapper(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    return false;
  }
  return !wrapped->is_alive();
}

Tensor unwrapIfDead(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    return tensor;
  }
  if (wrapped->is_alive()) {
    return tensor;
  }
  return wrapped->value();
}
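// Applies `func` to every Tensor found in args[begin, end), in place,
// including Tensors nested inside TensorList and Tensor?[] arguments.
// Used, for example, to unwrap (or wrap) all tensor arguments sitting on the
// dispatcher stack before redispatching.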
void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
    std::function<Tensor(const Tensor&)> func) {
  auto func_with_bool = [&](const Tensor& tensor, bool unused) { return func(tensor); };
  foreachTensorInplaceWithFlag(args, begin, end, std::bitset<64>(), func_with_bool);
}

void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
    const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func){
  TORCH_INTERNAL_ASSERT(begin >= 0);
  TORCH_INTERNAL_ASSERT(end >= 0);
  TORCH_INTERNAL_ASSERT(begin <= end);
  for (int64_t relative_idx = 0; relative_idx < end - begin; relative_idx++) {
    const bool flag = use_flag_relative[relative_idx] == 1;

    const auto idx = relative_idx + begin;
    auto ivalue = args[idx];
    // Tensor?[] translates to a c10::List<IValue> so we need to peek inside List
    if (ivalue.isList()) {
      bool modified = false;
      // TODO: might be more efficient if we scan first then not copy? Depends.
      auto list = ivalue.toList().copy();
      for (const auto list_idx : c10::irange(0, list.size())) {
        const auto& elt = list.get(list_idx);
        if (elt.isTensor()) {
          list.set(list_idx, func(elt.toTensor(), flag));
          modified = true;
        }
      }
      if (modified) {
        args[idx] = list;
      }
      continue;
    }
    if (ivalue.isTensorList()) {
      auto list = ivalue.toTensorList();
      for (const auto list_idx : c10::irange(0, list.size())) {
        list[list_idx] = func(list[list_idx], flag);
      }
      args[idx] = list;
    }
    TORCH_INTERNAL_ASSERT(!ivalue.isGenericDict(), "No operators can accept GenericDict");
    if (!ivalue.isTensor()) {
      continue;
    }
    Tensor value = ivalue.toTensor();
    Tensor replacement = func(value, flag);
    args[idx] = std::move(replacement);
    // sanity checks
    if (ivalue.toTensor().defined()) {
      TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());
    }
  }
}

std::ostream& operator<< (std::ostream& os, const DynamicLayer& layer) {
  os << layer.layerId() << ":" << layer.key();
  return os;
}

std::ostream& operator<< (std::ostream& os, const std::vector<DynamicLayer>& dls) {
  os << "DynamicLayerStack[ ";
  for (const auto& layer : dls) {
    os << layer << " ";
  }
  os << "]";
  return os;
}
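// Returns true if the schema follows the standard in-place convention:
// exactly one return, the first argument is written to (annotated (a!)),
// no other argument carries alias info, and the written-to argument is the
// one being returned. For example, a schema of the form (illustrative)
//   add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
// satisfies all of these, whereas an out= variant (which writes to `out`
// rather than the first argument) does not.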
bool isInplaceOp(const FunctionSchema& schema) {
  if (!schema.is_mutable() || schema.returns().size() != 1) {
    return false;
  }
  // Check that the first argument is being written to
  const auto& first_arg_alias_info = schema.arguments().begin()->alias_info();
  if (!first_arg_alias_info || !first_arg_alias_info->isWrite()) {
    return false;
  }
  // Check that none of the other args are being aliased
  for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) {
    const auto& alias_info = it->alias_info();
    if (alias_info) {
      return false;
    }
  }
  // Check that the first tensor is being returned (i.e., output has a (a!))
  const auto& return_alias_info = schema.returns()[0].alias_info();
  return return_alias_info && return_alias_info->isWrite();
}
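// Returns the index of the output that may alias input `immutable_input_idx`,
// or nullopt if that input does not alias any output. For a view op with a
// schema of the form (illustrative)
//   transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
// input 0 maps to output 0.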
std::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input_idx) {
  for (size_t res_idx = 0; res_idx != schema.returns().size(); ++res_idx) {
    if (schema.may_contain_alias(SchemaArgument(SchemaArgType::input, immutable_input_idx), SchemaArgument(SchemaArgType::output, res_idx))) {
      // for everything currently in native_functions, each input aliases at
      // most one output (a tensor list counts as one output)
      return res_idx;
    }
  }
  return std::nullopt;
}

#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
static void dump_local_tls() {
  auto tls = c10::impl::tls_local_dispatch_key_set();
  std::cout << "[Local Include] " << tls.included_ << std::endl;
  std::cout << "[Local Exclude] " << tls.excluded_ << std::endl;
}
#endif
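// RAII guard that temporarily removes the top DynamicLayer: the constructor
// pops it off the stack (keeping it in `layer_`) and the destructor pushes it
// back. Used by dynamicLayerBack below so an op can be sent to the next
// interpreter as if the current transform weren't active.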
struct WithoutTop {
  WithoutTop();
  ~WithoutTop();
  DynamicLayer layer_;
};

WithoutTop::WithoutTop(): layer_(popDynamicLayer()) {}
WithoutTop::~WithoutTop() {
  pushDynamicLayer(std::move(layer_));
}

// NOTE: [functorch front and back key fallbacks]
//
// Please read NOTE: [functorch interpreter stack] first for some context.
// The following doc also provides some visuals:
// https://docs.google.com/document/d/14qyaa3xIjmVxYiMLlIlQErunYgR_uR1WupsKMZlnGY4/edit
//
// functorch's "stack of transforms" is implemented as the following:
// - each transform is associated with one or more dispatch keys in the PyTorch
//   dispatcher. For example, vmap -> {FuncTorchBatched, FuncTorchVmapMode},
//   Autograd -> {Autograd{Backend}, ADInplaceOrView}
// - Whenever a functorch transform is active, the FuncTorchDynamicLayer{Front, Back}Mode
//   keys are added to the dispatcher's local dispatch key set.
//
// DynamicLayerFrontMode is responsible for:
// 1. selecting the transform that is at the top of the stack and grabbing its
//    interpreter
// 2. Calling interpreter.process(), which does the following:
//    2a. enables/disables a bunch of dispatch keys, so that the only dispatch
//        keys that are enabled are the ones that belong to the transform.
//    2b. redispatching
//
// Eventually, DynamicLayerBackMode captures the redispatch from the transforms.
// DynamicLayerBackMode is responsible for:
// - redirecting back to DynamicLayerFrontMode
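//
// Rough illustration of one op call with two transforms on the stack (a
// sketch of the flow, not an exact dispatcher trace):
//
//   op -> FuncTorchDynamicLayerFrontMode (dynamicLayerFrontFallback)
//      -> top-of-stack interpreter.process(): enable only that transform's
//         keys, then redispatch
//      -> that transform's kernels run (e.g. batching rules, autograd)
//      -> FuncTorchDynamicLayerBackMode (dynamicLayerBackFallback)
//      -> WithoutTop pops the top layer; sendToNextInterpreter() redispatches
//      -> FuncTorchDynamicLayerFrontMode again, now with the next interpreter
//         on top; this repeats until no transforms remain and the op reaches
//         a regular backend kernel.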

static void dynamicLayerFrontFallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  TORCH_INTERNAL_ASSERT(!dynamicLayerStack.empty());
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
  if (c10::show_dispatch_trace_enabled()) {
    std::cout << dynamicLayerStack << std::endl;
    dump_local_tls();
  }
#endif
  // Save the current LocalDispatchKeySet (to the current DynamicLayer).
  // Upon exiting the current scope, that LocalDispatchKeySet gets restored.
  // When the current DynamicLayer dispatches to the next (inner) DynamicLayer,
  // it will also temporarily restore the saved LocalDispatchKeySet.
  SaveLocalDispatchKeySet guard;

  // Unwrap escaped GradWrappers
  auto num_args = op.schema().arguments().size();
  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), unwrapIfDead);

  auto& layer = dynamicLayerStack.back();
  layer.interpreter().process(op, stack);
}

static c10::impl::ForceDispatchKeyGuard
restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
  return c10::impl::ForceDispatchKeyGuard(key_set);
}

// Right now a bool (grad_special_case) is sufficient because this is the only
// special case for grad. If we need to add more special cases, it would be
// more scalable to add an enum so we know which op we're looking at without
// inspecting the schema.
static void dynamicLayerBack(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case) {
  auto restore_guard = restoreLocalDispatchKeySetRAII(
      dynamicLayerStackAccessor().back().interpreter().getSavedLocalDispatchKeySet());
  WithoutTop guard;

  // WithoutTop stores the popped DynamicLayer object.
  guard.layer_.interpreter().sendToNextInterpreter(op, stack, grad_special_case);
}

// Used for functions that have aliasing operations but should be treated as if
// they're out-of-place (e.g. lift_fresh).
static void dynamicLayerBackGradSpecialCase(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  return dynamicLayerBack(op, stack, true);
}

static void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  return dynamicLayerBack(op, stack, false);
}

TORCH_LIBRARY_IMPL(_, FuncTorchDynamicLayerFrontMode, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>());
}

TORCH_LIBRARY_IMPL(_, FuncTorchDynamicLayerBackMode, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}

#define SPECIAL_GRAD_CASE(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackGradSpecialCase>());

TORCH_LIBRARY_IMPL(aten, FuncTorchDynamicLayerBackMode, m) {
  // lift_fresh: the output must be freshly allocated and should be wrapped;
  // the user shouldn't have access to the input version.
  // alias: this is needed for the CompositeImplicit instance norm
  // (running_mean/var get set to a wrapped value). It's not a user-facing
  // function, but it is more prone to possible errors.
  SPECIAL_GRAD_CASE(lift_fresh);
  SPECIAL_GRAD_CASE(alias);
}

} // namespace at::functorch