xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>

#include <c10/util/irange.h>

#include <ATen/core/TorchDispatchUtils.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>

#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>

#include <optional>
#include <utility>
#include <vector>

namespace torch::autograd {

namespace {

template <typename F>
void _foreach_tensor(
    F fn,
    torch::jit::Stack* stack,
    size_t stack_start,
    size_t size) {
  // Enumerate over tensors in a stack, including ones in TensorLists
  int idx_tensor = 0;
  for (const auto idx_arg : c10::irange(size)) {
    auto& ivalue = (*stack)[stack_start + idx_arg];
    if (ivalue.isTensor()) { // true for optional tensor that has value
      const auto& tensor = ivalue.toTensor();
      fn(idx_tensor, idx_arg, tensor);
      idx_tensor++;
    } else if (ivalue.isTensorList()) {
      for (const auto& iv : ivalue.toListRef()) {
        const auto& tensor = iv.toTensor();
        fn(idx_tensor, idx_arg, tensor);
        idx_tensor++;
      }
    }
  }
}
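
// Illustrative usage sketch (not part of the logic below): visit every tensor
// among the last `num_arguments` IValues on a dispatcher stack. The callback
// receives a running tensor index, the argument index, and the tensor itself.
//
//   size_t tensor_count = 0;
//   _foreach_tensor(
//       [&](size_t idx_tensor, size_t idx_arg, const at::Tensor& t) {
//         tensor_count++;
//       },
//       stack,
//       stack->size() - num_arguments,
//       num_arguments);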

AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn;

} // namespace

void setAutogradFallbackMode(AutogradFallbackMode mode) {
  TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'");
  kAutogradFallbackMode = mode;
}

AutogradFallbackMode getAutogradFallbackMode() {
  return kAutogradFallbackMode;
}
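
// Illustrative usage sketch (not part of this file's logic): switch the
// fallback between warning and silently falling through. 'Error' is rejected
// above as not yet implemented.
//
//   // Squash the deprecation warning and keep the legacy behavior:
//   torch::autograd::setAutogradFallbackMode(
//       torch::autograd::AutogradFallbackMode::Nothing);
//   TORCH_INTERNAL_ASSERT(
//       torch::autograd::getAutogradFallbackMode() ==
//       torch::autograd::AutogradFallbackMode::Nothing);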

static void warnAutogradNotImplemented(const std::string& op_name) {
  TORCH_WARN(
      op_name,
      ": an autograd kernel was not registered to the Autograd key(s) ",
      "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
      "This behavior is deprecated and will be removed in a future version of PyTorch. ",
      "If your operator is differentiable, please ensure you have registered an "
      "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
      "DispatchKey::CompositeImplicitAutograd). If your operator is not "
      "differentiable, or to squash this warning and use the previous behavior, "
      "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
}
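
// Illustrative sketch of the remedy the warning above suggests (the library
// and operator names "myops::my_nondiff_op" are hypothetical):
//
//   TORCH_LIBRARY_IMPL(myops, Autograd, m) {
//     // Declare the op non-differentiable: autograd falls through to the
//     // backend kernel and this warning is not installed.
//     m.impl("my_nondiff_op", torch::CppFunction::makeFallthrough());
//   }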

struct WarnNotImplemented : public Node {
  WarnNotImplemented(
      std::string op_name,
      size_t num_outputs,
      edge_list&& next_edges)
      : Node(std::move(next_edges)),
        op_name(std::move(op_name)),
        num_outputs(num_outputs) {}

  WarnNotImplemented(std::string op_name, size_t num_outputs)
      : op_name(std::move(op_name)), num_outputs(num_outputs) {}

  variable_list apply(variable_list&& inputs) override;

  std::string op_name;
  size_t num_outputs;
};

auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list {
  warnAutogradNotImplemented(op_name);
  std::vector<at::Tensor> output(num_outputs);
  return output;
}

static void basicAutogradNotImplementedFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto& op_name = schema.operator_name().name;
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  const auto stack_start = stack->size() - num_arguments;

  if (getAutogradFallbackMode() == AutogradFallbackMode::Nothing) {
    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
    return;
  }
  TORCH_INTERNAL_ASSERT(
      getAutogradFallbackMode() == AutogradFallbackMode::Warn);

  bool any_input_requires_grad = false;
  _foreach_tensor(
      [&](size_t _, size_t idx_arg, const at::Tensor& t) {
        if (t.requires_grad()) {
          any_input_requires_grad = true;
        }
      },
      stack,
      stack_start,
      num_arguments);
  // Optimization: TLS access can be slow, so we only check it when necessary
  // by putting it after the requires_grad checks.
  any_input_requires_grad = any_input_requires_grad && GradMode::is_enabled();

  std::shared_ptr<WarnNotImplemented> grad_fn;
  if (any_input_requires_grad) {
    // NB: It is standard to collect edges from all tensors
    // (see generated/VariableTypeEverything.cpp for examples)
    std::vector<const at::Tensor*> all_tensors_on_stack;
    _foreach_tensor(
        [&](size_t _, size_t idx_arg, const at::Tensor& t) {
          all_tensors_on_stack.push_back(&t);
        },
        stack,
        stack_start,
        num_arguments);
    grad_fn = std::shared_ptr<WarnNotImplemented>(
        new WarnNotImplemented(op_name, all_tensors_on_stack.size()),
        deleteNode);
    grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack));
  }

  op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);

  if (any_input_requires_grad) {
    // NB: if the operator mutates any inputs in-place and does not return them
    // as outputs, we are unable to lazily raise a warning. This is OK because
    // we don't expect many existing operators to do this because of the amount
    // of technical expertise necessary (you would need to manually register an
    // autograd kernel without using autograd.Function)
    _foreach_tensor(
        [&](size_t _, size_t idx_ret, const at::Tensor& t) {
          if (!isDifferentiableType(t.scalar_type())) {
            return;
          }
          const bool is_mutable_output =
              schema.is_aliasing({c10::SchemaArgType::output, idx_ret}) &&
              schema.is_mutable({c10::SchemaArgType::output, idx_ret});

          // If the post-autograd implementation returns Tensors that require
          // grad, then we install a hook that will warn during the backwards.
          //
          // NB: If the operation is inplace and the inputs were views,
          // it is possible that the history was rebased and the hook will
          // not warn in all places where it should. That is, the following
          // won't warn:
          // >>> x = torch.randn(3, 3, requires_grad=True)
          // >>> z = x.clone()
          // >>> w = z[0]
          // >>> k = w[0]
          // >>> y = op(k)
          // >>> torch.autograd.grad(z.sum(), w)
          if (t.requires_grad()) {
            t.register_hook([op_name](const at::Tensor& grad) {
              warnAutogradNotImplemented(op_name);
            });
            // If history is rebased, then we will attempt to warn
            // on the view's base. This will catch most cases (because
            // users typically call .backward() and backprop through
            // the entire program).
            if (t.is_view() && is_mutable_output) {
              const auto& base = t._base();
              if (base.requires_grad()) {
                // Can only register_hook on tensors that require grad.
                base.register_hook([op_name](const at::TensorBase& grad) {
                  warnAutogradNotImplemented(op_name);
                });
              }
            }
            return;
          }

          // If the post-autograd implementation returns any Tensors that
          // don't require grad, then we install the WarnNotImplemented grad_fn.
          // This grad_fn warns in backward and returns undefined tensor
          // gradients.
          //
          // NOTE [autograd fallback and in-place operations]
          // If the schema says the output is mutable, and the output
          // is an input, and the input is a view Tensor, then...
          // we're not sure if set_history is OK to do, so we just skip
          // adding the grad_fn. Builtin operators do rebase_history here,
          // but custom operators may have multiple Tensor(a!) returns,
          // rebase_history assumes single Tensor(a!) return, and in general
          // custom ops don't have a good in-place story.
          if (!is_mutable_output) {
            set_history(t, grad_fn);
          }
        },
        stack,
        stack->size() - num_returns,
        num_returns);
  }
}

torch::CppFunction basicAutogradNotImplementedFallback() {
  return torch::CppFunction::makeFromBoxedFunction<
      &basicAutogradNotImplementedFallbackImpl>();
}
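
// Illustrative registration sketch (an assumption about how a backend might
// install this as a dispatch-key-wide fallback; PyTorch itself wires this up
// during library initialization):
//
//   TORCH_LIBRARY_IMPL(_, AutogradCPU, m) {
//     m.fallback(torch::autograd::basicAutogradNotImplementedFallback());
//   }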

void VariableHooks::basic_autograd_not_implemented_fallback(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) const {
  basicAutogradNotImplementedFallbackImpl(op, dispatch_keys, stack);
}

static void autogradNotImplementedFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  // Mimics a subset of the logic of a VariableType NotImplemented kernel
  // See gen_variable_type.py
  const auto& schema = op.schema();
  const auto& op_name = schema.operator_name().name;
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  const auto stack_start = stack->size() - num_arguments;
  const bool grad_mode = GradMode::is_enabled();
  std::vector<const at::Tensor*> tensors_requiring_grad_on_stack;

  // Keep track of which outputs are outputs of in-place modification,
  // so we can rebase_history if necessary.
  std::vector<bool> is_inplace_output(num_returns, false);
  bool any_is_inplace_output = false;
  std::vector<bool> is_aliased_output(num_returns, false);
  std::optional<size_t> aliased_output_idx;

  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_aliasing({c10::SchemaArgType::output, i})) {
      if (schema.is_mutable({c10::SchemaArgType::output, i})) {
        is_inplace_output[i] = true;
        any_is_inplace_output = true;
      } else {
        TORCH_CHECK(
            !aliased_output_idx.has_value(),
            "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
            "Non-composite functions where multiple outputs are aliased with inputs aren't supported. "
            "Please rewrite your function as a composite function.");
        aliased_output_idx = i;
      }
      is_aliased_output[i] = true;
    }
  }

  int64_t aliased_input_idx = -1;
  for (const auto i : c10::irange(num_arguments)) {
    if (schema.is_aliasing({c10::SchemaArgType::input, i}) &&
        !schema.is_mutable({c10::SchemaArgType::input, i})) {
      TORCH_CHECK(
          aliased_input_idx == -1,
          "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_input_idx = static_cast<int64_t>(i);
    }
  }

  size_t num_tensor_inputs = 0; // Only used for DEBUG-only checks
  _foreach_tensor(
      [&](size_t _, size_t idx_arg, const at::Tensor& t) {
        if (grad_mode && t.requires_grad()) {
          tensors_requiring_grad_on_stack.push_back(&t);
        }
        num_tensor_inputs++;
        TORCH_CHECK_NOT_IMPLEMENTED(
            !isFwGradDefined(t),
            "Trying to use forward AD with ",
            op_name,
            " that does not support it.");
      },
      stack,
      stack_start,
      num_arguments);

  const bool any_requires_grad = !tensors_requiring_grad_on_stack.empty();
  const bool has_out_arg = std::any_of(
      schema.arguments().begin(),
      schema.arguments().end(),
      [](const c10::Argument& arg) { return arg.is_out(); });

  _foreach_tensor(
      [&](size_t _, size_t i, const at::Tensor& t) {
        if (schema.is_mutable({c10::SchemaArgType::input, i})) {
          if (has_out_arg) {
            // Normally out argument overloads would not support any arguments
            // that require grad. However, we loosen this check to maintain
            // backward compatibility.
            // See https://github.com/pytorch/pytorch/issues/120988
            if (can_mutate_inplace(t, any_requires_grad) !=
                can_mutate_inplace_result::success) {
              throw_error_out_requires_grad(schema.name().c_str());
            }
          } else {
            check_inplace(t, any_requires_grad);
          }
        }
      },
      stack,
      stack_start,
      num_arguments);

  std::shared_ptr<NotImplemented> grad_fn;
  if (any_requires_grad) {
    grad_fn = std::shared_ptr<NotImplemented>(
        new NotImplemented(op_name), deleteNode);
    grad_fn->set_next_edges(
        collect_next_edges(tensors_requiring_grad_on_stack));
  }

#ifndef NDEBUG
  // See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
  auto stack_args_copy =
      std::vector<c10::IValue>(stack->begin() + stack_start, stack->end());
  std::vector<c10::intrusive_ptr<c10::TensorImpl>> impl_saved;
  impl_saved.reserve(num_tensor_inputs);
  std::vector<std::optional<c10::Storage>> storage_saved;
  storage_saved.reserve(num_tensor_inputs);
  _foreach_tensor(
      [&](size_t idx, size_t _, const at::Tensor& t) {
        storage_saved.push_back(
            t.has_storage() ? std::optional<c10::Storage>(t.storage())
                            : std::nullopt);
        impl_saved.push_back(t.getIntrusivePtr());
      },
      &stack_args_copy,
      0,
      num_arguments);
#endif
  if (aliased_input_idx != -1 || any_is_inplace_output) {
    at::AutoDispatchBelowAutograd guard;
    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
  } else {
    // If neither in-place nor view
    at::AutoDispatchBelowADInplaceOrView guard;
    op.redispatchBoxed(
        dispatch_keys & c10::after_ADInplaceOrView_keyset, stack);
  }
#ifndef NDEBUG
  _foreach_tensor(
      [&](size_t idx_tensor, size_t _, const at::Tensor& t) {
        // Skip next two for chunk_cat, see
        // https://github.com/pytorch/pytorch/issues/130073
        if (storage_saved.at(idx_tensor).has_value() &&
            op_name != "aten::_chunk_cat")
          TORCH_INTERNAL_ASSERT(
              storage_saved.at(idx_tensor).value().is_alias_of(t.storage()),
              op_name);
        if (impl_saved.at(idx_tensor) && op_name != "aten::_chunk_cat")
          TORCH_INTERNAL_ASSERT(
              impl_saved.at(idx_tensor) == t.getIntrusivePtr(), op_name);
      },
      &stack_args_copy,
      0,
      num_arguments);
  _foreach_tensor(
      [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
        if (at::impl::tensor_has_dispatch(t) ||
            at::impl::dispatch_mode_enabled() ||
            // NJT components are expected to be reused; skip use_count() check
            op_name.rfind("aten::_nested_get", 0) == 0)
          return;
        // Skip test_parallel_materialize
        // For details see https://github.com/pytorch/pytorch/issues/130073
        if (op_name == "aten::_test_parallel_materialize" ||
            op_name == "aten::_test_optional_intlist" ||
            op_name == "aten::_test_optional_filled_intlist" ||
            op_name == "aten::_test_optional_floatlist")
          return;
        if (!is_inplace_output[idx_ret])
          TORCH_INTERNAL_ASSERT(
              t.use_count() <= 1, op_name); // Okay to return undefined tensor
        // note(crcrpar): `_foreach_norm` returns a list of scalar Tensors, and
        // each Tensor shares the storage of a hidden, intermediate 1D Tensor
        // created inside the CUDA implementation. This is because the
        // reference implementation in the nvidia/apex repo returns this 1D
        // Tensor, where each element represents the norm of the corresponding
        // input Tensor, whereas here we want to return the same number of
        // Tensors as the input TensorList.
        // See https://github.com/pytorch/pytorch/issues/93940
        // Skip native_channel_shuffle as well as transformer_encoder
        // For details see https://github.com/pytorch/pytorch/issues/130073
        if (!is_aliased_output[idx_ret] && t.has_storage() &&
            op_name != "aten::_foreach_norm" &&
            op_name != "aten::_transformer_encoder_layer_fwd" &&
            op_name != "aten::native_channel_shuffle")
          TORCH_INTERNAL_ASSERT(t.storage().use_count() == 1);
      },
      stack,
      stack->size() - num_returns,
      num_returns);
  // There should be only a single base-view pair, make sure their storage is
  // aliased.
  if (aliased_input_idx != -1 && aliased_output_idx.has_value()) {
    const c10::IValue& aliased_input_iv = stack_args_copy[aliased_input_idx];
    const c10::IValue& aliased_output_iv =
        (*stack)[stack->size() - num_returns + *aliased_output_idx];
    TORCH_INTERNAL_ASSERT(aliased_input_iv.isTensor(), op_name);
    TORCH_INTERNAL_ASSERT(
        aliased_output_iv.isTensor() || aliased_output_iv.isTensorList(),
        op_name);
    const at::Tensor& aliased_input = aliased_input_iv.toTensor();
    if (aliased_input.has_storage()) {
      if (aliased_output_iv.isTensor()) {
        const at::Tensor& aliased_output = aliased_output_iv.toTensor();
        // for now, skip asserts for subclasses
        // TODO: Fix the aliasing situation involving subclasses
        if (!at::impl::dispatch_mode_enabled() &&
            !at::impl::tensor_has_dispatch(aliased_input) &&
            !at::impl::tensor_has_dispatch(aliased_output)) {
          TORCH_INTERNAL_ASSERT(
              aliased_input.storage().is_alias_of(aliased_output.storage()),
              op_name);
        }
      } else {
        const auto aliased_output_vec = aliased_output_iv.toTensorVector();
        for (const auto& aliased_output : aliased_output_vec) {
          // for now, skip asserts for subclasses
          // TODO: Fix the aliasing situation involving subclasses
          if (!at::impl::dispatch_mode_enabled() &&
              !at::impl::tensor_has_dispatch(aliased_input) &&
              !at::impl::tensor_has_dispatch(aliased_output)) {
            TORCH_INTERNAL_ASSERT(
                aliased_input.storage().is_alias_of(aliased_output.storage()),
                op_name);
          }
        }
      }
    }
  }
#endif

  if (any_requires_grad) {
    _foreach_tensor(
        [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
          if (isDifferentiableType(t.scalar_type())) {
            if (is_inplace_output[idx_ret]) {
              rebase_history(t, grad_fn);
            } else {
              set_history(t, grad_fn);
            }
          }
        },
        stack,
        stack->size() - num_returns,
        num_returns);
  }
}

torch::CppFunction autogradNotImplementedFallback() {
  return torch::CppFunction::makeFromBoxedFunction<
      &autogradNotImplementedFallbackImpl>();
}
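
// Illustrative per-operator registration sketch for this kernel at the
// Autograd key; it is meant to be paired with the ADInplaceOrView fallback
// defined further below (the library and operator names "myops::my_op" are
// hypothetical):
//
//   TORCH_LIBRARY_IMPL(myops, Autograd, m) {
//     m.impl("my_op", torch::autograd::autogradNotImplementedFallback());
//   }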

static void autogradNotImplementedInplaceOrViewFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  // Mimics a subset of the logic from ADInplaceOrViewType kernel:
  // - see gen_inplace_or_view_type.py
  // - this should only be used with autogradNotImplementedFallback above
  // - For more information see
  // https://pytorch.org/tutorials/advanced/dispatcher
  //
  // NOTE [ Limitations of ADInplaceOrView boxed kernel ]
  //
  // This kernel should only be used with the autogradNotImplementedFallback
  // kernel because it contains logic needed specifically to ensure that even
  // if we do in-place operations on views created in this kernel, the proper
  // "derivative is not implemented" error is still raised.
  //
  // Just like the codegened kernel, we try to enforce some things:
  // - For views: we enforce that the view relationship is between the first
  //   input and the first output (which may be either a Tensor or a vector of
  //   Tensors); see the illustrative schema sketch below.
  // - For inplace (TODO?): enforce that the same op cannot be both a view and
  //   in-place, which is not allowed in the gen_inplace_or_view logic.
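  //
  // Illustrative schema sketch of the view constraint above (the operator
  // names are hypothetical):
  //
  //   Accepted: "myops::my_view(Tensor(a) self, int dim) -> Tensor(a)"
  //             (the aliased input and output are both in the first position)
  //   Rejected: "myops::bad_view(int dim, Tensor(a) self) -> Tensor(a)"
  //             (the aliased input is not the first argument, so the
  //              TORCH_CHECK below fires)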
  const auto& schema = op.schema();
  const auto& op_name = schema.operator_name().name;
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  const auto stack_start = stack->size() - num_arguments;

  at::Tensor aliased_input;

  int64_t aliased_output_idx = -1;
  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_aliasing({c10::SchemaArgType::output, i}) &&
        !schema.is_mutable({c10::SchemaArgType::output, i})) {
      TORCH_CHECK(
          aliased_output_idx == -1,
          "Fallback ADInplaceOrView kernel expects only a single output in the operator schema to have a "
          "non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple outputs are aliased with inputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_output_idx = static_cast<int64_t>(i);
    }
  }

  std::optional<size_t> aliased_input_idx;
  for (const auto i : c10::irange(num_arguments)) {
    if (schema.is_aliasing({c10::SchemaArgType::input, i}) &&
        !schema.is_mutable({c10::SchemaArgType::input, i})) {
      TORCH_CHECK(
          !aliased_input_idx.has_value(),
          "Fallback ADInplaceOrView kernel expects only a single input in the operator schema to have a "
          "non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_input_idx = i;
      const c10::IValue& aliased_input_iv =
          (*stack)[stack_start + i]; // get a reference to an ivalue on the
                                     // stack
      TORCH_CHECK(aliased_input_iv.isTensor());
      aliased_input =
          aliased_input_iv.toTensor(); // TODO: Can we avoid saving this tensor
                                       // and incurring the refcount bump?
    }
  }
  // See NOTE [ Limitations of ADInplaceOrView boxed kernel ] above
  TORCH_CHECK(
      (!aliased_input_idx.has_value() && aliased_output_idx == -1) ||
          (aliased_input_idx.has_value() && aliased_input_idx.value() == 0 &&
           aliased_output_idx == 0),
      "Fallback ADInplaceOrView kernel can only create view relationships between the first "
      "input and the first output (the output can be a vector of tensors). Please change the "
      "order of your operator's parameters so that this is the case.");
  const bool is_view = aliased_input_idx.has_value();

  {
    at::AutoDispatchBelowADInplaceOrView guard;
    op.redispatchBoxed(
        dispatch_keys & c10::after_ADInplaceOrView_keyset, stack);
  }

  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_mutable({c10::SchemaArgType::output, i})) {
      increment_version((*stack)[stack->size() - num_returns + i].toTensor());
    }
  }

  if (is_view) {
    c10::IValue& aliased_output_iv =
        (*stack)[stack->size() - num_returns + aliased_output_idx];

    // See NOTE [ View + Inplace detection ] for more details about this logic
    // We always need this view_func because otherwise if we do in-place
    // on this view, we would implicitly use AsStridedBackward instead
    // of the NotImplemented node. For the cross-dtype/non-strided
    // cases, we would create something like this anyway
    auto error_msg =
        ("Mutating the view " + op_name +
         " which does not have a derivative implemented is forbidden.");
    auto erroring_view_func = std::make_unique<ErroringViewFunc>(error_msg);

    const auto erroring_rev_view_func = [op_name = op_name](const at::Tensor&) {
      TORCH_CHECK(
          false,
          "Accessing the reverse view for ",
          op_name,
          " which does not have a derivative implemented is forbidden.");
      return at::Tensor();
    };

    if (aliased_output_iv.isTensorList()) {
      auto aliased_output = aliased_output_iv.toTensorVector();
      for (auto& sub_output : aliased_output) {
        as_view(
            /* base=*/aliased_input,
            /* tensor=*/sub_output,
            /* is_bw_differentiable=*/true,
            /* is_fw_differentiable=*/true,
            /* view_func=*/std::move(erroring_view_func),
            /* rev_view_func=*/erroring_rev_view_func,
            /* creation_meta=*/
            InferenceMode::is_enabled()
                ? CreationMeta::INFERENCE_MODE
                : (at::GradMode::is_enabled() ? CreationMeta::MULTI_OUTPUT_NODE
                                              : CreationMeta::NO_GRAD_MODE));
      }
      auto result = std::move(aliased_output);
      stack->at(stack->size() - num_returns + aliased_output_idx) = result;
    } else {
      TORCH_CHECK(aliased_output_iv.isTensor());
      auto result = as_view(
          /* base=*/aliased_input,
          /* tensor=*/std::move(aliased_output_iv).toTensor(),
          /* is_bw_differentiable=*/true,
          /* is_fw_differentiable=*/true,
          /* view_func=*/std::move(erroring_view_func),
          /* rev_view_func=*/erroring_rev_view_func,
          /* creation_meta=*/
          InferenceMode::is_enabled()
              ? CreationMeta::INFERENCE_MODE
              : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT
                                            : CreationMeta::NO_GRAD_MODE));
      stack->at(stack->size() - num_returns + aliased_output_idx) =
          std::move(result);
    }
  }
}

torch::CppFunction autogradNotImplementedInplaceOrViewFallback() {
  return torch::CppFunction::makeFromBoxedFunction<
      &autogradNotImplementedInplaceOrViewFallbackImpl>();
}
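
// Illustrative companion registration sketch for the ADInplaceOrView key,
// pairing with the Autograd registration shown after
// autogradNotImplementedFallback() above (the library and operator names
// "myops::my_op" are hypothetical):
//
//   TORCH_LIBRARY_IMPL(myops, ADInplaceOrView, m) {
//     m.impl(
//         "my_op",
//         torch::autograd::autogradNotImplementedInplaceOrViewFallback());
//   }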

} // namespace torch::autograd