#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>

#include <c10/util/irange.h>

#include <ATen/core/TorchDispatchUtils.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>

#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>

#include <optional>
#include <utility>
#include <vector>

namespace torch::autograd {

namespace {

template <typename F>
void _foreach_tensor(
    F fn,
    torch::jit::Stack* stack,
    size_t stack_start,
    size_t size) {
  // Enumerate over tensors in a stack, including ones in TensorLists
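  // For example, with arguments (Tensor a, int 3, TensorList [b, c]) starting
  // at stack_start, fn is invoked as (idx_tensor, idx_arg) = (0, 0) for a,
  // then (1, 2) and (2, 2) for b and c; non-tensor arguments are skipped.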
  int idx_tensor = 0;
  for (const auto idx_arg : c10::irange(size)) {
    auto& ivalue = (*stack)[stack_start + idx_arg];
    if (ivalue.isTensor()) { // true for optional tensor that has value
      const auto& tensor = ivalue.toTensor();
      fn(idx_tensor, idx_arg, tensor);
      idx_tensor++;
    } else if (ivalue.isTensorList()) {
      for (const auto& iv : ivalue.toListRef()) {
        const auto& tensor = iv.toTensor();
        fn(idx_tensor, idx_arg, tensor);
        idx_tensor++;
      }
    }
  }
}

AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn;

} // namespace

void setAutogradFallbackMode(AutogradFallbackMode mode) {
  TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'");
  kAutogradFallbackMode = mode;
}

AutogradFallbackMode getAutogradFallbackMode() {
  return kAutogradFallbackMode;
}
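// Note (assumption): the Python-side toggles for these modes (e.g.
// torch._C._set_autograd_fallback_mode as used by the test suite) are
// private bindings; the exact names are an internal detail and may change.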

static void warnAutogradNotImplemented(const std::string& op_name) {
  TORCH_WARN(
      op_name,
      ": an autograd kernel was not registered to the Autograd key(s) ",
      "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
      "This behavior is deprecated and will be removed in a future version of PyTorch. ",
      "If your operator is differentiable, please ensure you have registered an "
      "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
      "DispatchKey::CompositeImplicitAutograd). If your operator is not "
      "differentiable, or to squash this warning and use the previous behavior, "
      "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
}
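// A minimal sketch of the opt-out the warning above suggests, for a custom op
// that is intentionally not differentiable (the "myops" namespace and "myop"
// name are placeholders):
//
//   TORCH_LIBRARY_IMPL(myops, Autograd, m) {
//     m.impl("myop", torch::CppFunction::makeFallthrough());
//   }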

struct WarnNotImplemented : public Node {
  WarnNotImplemented(
      std::string op_name,
      size_t num_outputs,
      edge_list&& next_edges)
      : Node(std::move(next_edges)),
        op_name(std::move(op_name)),
        num_outputs(num_outputs) {}

  WarnNotImplemented(std::string op_name, size_t num_outputs)
      : op_name(std::move(op_name)), num_outputs(num_outputs) {}

  variable_list apply(variable_list&& inputs) override;

  std::string op_name;
  size_t num_outputs;
};

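// Backward node installed by the Warn-mode fallback: when the autograd engine
// reaches it, it emits the "autograd kernel was not registered" warning and
// returns undefined tensors, which the engine interprets as "no gradient"
// for the corresponding inputs.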
auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list {
  warnAutogradNotImplemented(op_name);
  std::vector<at::Tensor> output(num_outputs);
  return output;
}

static void basicAutogradNotImplementedFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto& op_name = schema.operator_name().name;
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  const auto stack_start = stack->size() - num_arguments;

  if (getAutogradFallbackMode() == AutogradFallbackMode::Nothing) {
    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
    return;
  }
  TORCH_INTERNAL_ASSERT(
      getAutogradFallbackMode() == AutogradFallbackMode::Warn);

  bool any_input_requires_grad = false;
  _foreach_tensor(
      [&](size_t _, size_t idx_arg, const at::Tensor& t) {
        if (t.requires_grad()) {
          any_input_requires_grad = true;
        }
      },
      stack,
      stack_start,
      num_arguments);
  // Optimization: TLS access can be slow, so we only query GradMode (which
  // lives in TLS) after the requires_grad checks have shown it is necessary.
  any_input_requires_grad = any_input_requires_grad && GradMode::is_enabled();

  std::shared_ptr<WarnNotImplemented> grad_fn;
  if (any_input_requires_grad) {
    // NB: It is standard to collect edges from all tensors
    // (see generated/VariableTypeEverything.cpp for examples)
    std::vector<const at::Tensor*> all_tensors_on_stack;
    _foreach_tensor(
        [&](size_t _, size_t idx_arg, const at::Tensor& t) {
          all_tensors_on_stack.push_back(&t);
        },
        stack,
        stack_start,
        num_arguments);
    grad_fn = std::shared_ptr<WarnNotImplemented>(
        new WarnNotImplemented(op_name, all_tensors_on_stack.size()),
        deleteNode);
    grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack));
  }

  op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);

  if (any_input_requires_grad) {
    // NB: if the operator mutates any inputs in-place and does not return them
    // as outputs, we are unable to lazily raise a warning. This is OK because
    // we don't expect many existing operators to do this: it requires manually
    // registering an autograd kernel without going through autograd.Function,
    // which takes a fair amount of technical expertise.
    _foreach_tensor(
        [&](size_t _, size_t idx_ret, const at::Tensor& t) {
          if (!isDifferentiableType(t.scalar_type())) {
            return;
          }
          const bool is_mutable_output =
              schema.is_aliasing({c10::SchemaArgType::output, idx_ret}) &&
              schema.is_mutable({c10::SchemaArgType::output, idx_ret});

          // If the post-autograd implementation returns Tensors that require
          // grad, then we install a hook that will warn during the backwards.
          //
          // NB: If the operation is inplace and the inputs were views,
          // it is possible that the history was rebased and the hook will
          // not warn in all places where it should. That is, the following
          // won't warn:
          // >>> x = torch.randn(3, 3, requires_grad=True)
          // >>> z = x.clone()
          // >>> w = z[0]
          // >>> k = w[0]
          // >>> y = op(k)
          // >>> torch.autograd.grad(z.sum(), w)
          if (t.requires_grad()) {
            t.register_hook([op_name](const at::Tensor& grad) {
              warnAutogradNotImplemented(op_name);
            });
            // If history is rebased, then we will attempt to warn
            // on the view's base. This will catch most cases (because
            // users typically call .backward() and backprop through
            // the entire program).
            if (t.is_view() && is_mutable_output) {
              const auto& base = t._base();
              if (base.requires_grad()) {
                // Can only register_hook on tensors that require grad.
                base.register_hook([op_name](const at::TensorBase& grad) {
                  warnAutogradNotImplemented(op_name);
                });
              }
            }
            return;
          }

          // If the post-autograd implementation returns any Tensors that
          // don't require grad, then we install the WarnNotImplemented grad_fn.
          // This grad_fn warns in backward and returns undefined tensor
          // gradients.
          //
          // NOTE [autograd fallback and in-place operations]
          // If the schema says the output is mutable, and the output
          // is an input, and the input is a view Tensor, then...
          // we're not sure if set_history is OK to do, so we just skip
          // adding the grad_fn. Builtin operators do rebase_history here,
          // but custom operators may have multiple Tensor(a!) returns,
          // rebase_history assumes a single Tensor(a!) return, and in general
          // custom ops don't have a good in-place story.
          if (!is_mutable_output) {
            set_history(t, grad_fn);
          }
        },
        stack,
        stack->size() - num_returns,
        num_returns);
  }
}

torch::CppFunction basicAutogradNotImplementedFallback() {
  return torch::CppFunction::makeFromBoxedFunction<
      &basicAutogradNotImplementedFallbackImpl>();
}
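// A sketch of how a boxed fallback like the one above gets installed for an
// entire dispatch key (PyTorch wires up its default autograd fallback along
// these lines; the exact registration site is an internal detail):
//
//   TORCH_LIBRARY_IMPL(_, Autograd, m) {
//     m.fallback(torch::autograd::basicAutogradNotImplementedFallback());
//   }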

void VariableHooks::basic_autograd_not_implemented_fallback(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) const {
  basicAutogradNotImplementedFallbackImpl(op, dispatch_keys, stack);
}

static void autogradNotImplementedFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  // Mimics a subset of the logic of a VariableType NotImplemented kernel
  // See gen_variable_type.py
  const auto& schema = op.schema();
  const auto& op_name = schema.operator_name().name;
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  const auto stack_start = stack->size() - num_arguments;
  const bool grad_mode = GradMode::is_enabled();
  std::vector<const at::Tensor*> tensors_requiring_grad_on_stack;

  // Keep track of which outputs are the result of in-place modification
  // so we can rebase_history if necessary
  std::vector<bool> is_inplace_output(num_returns, false);
  bool any_is_inplace_output = false;
  std::vector<bool> is_aliased_output(num_returns, false);
  std::optional<size_t> aliased_output_idx;

  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_aliasing({c10::SchemaArgType::output, i})) {
      if (schema.is_mutable({c10::SchemaArgType::output, i})) {
        is_inplace_output[i] = true;
        any_is_inplace_output = true;
      } else {
        TORCH_CHECK(
            !aliased_output_idx.has_value(),
            "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
            "Non-composite functions where multiple outputs are aliased with inputs aren't supported. "
            "Please rewrite your function as a composite function.");
        aliased_output_idx = i;
      }
      is_aliased_output[i] = true;
    }
  }

  int64_t aliased_input_idx = -1;
  for (const auto i : c10::irange(num_arguments)) {
    if (schema.is_aliasing({c10::SchemaArgType::input, i}) &&
        !schema.is_mutable({c10::SchemaArgType::input, i})) {
      TORCH_CHECK(
          aliased_input_idx == -1,
          "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_input_idx = static_cast<int64_t>(i);
    }
  }

  size_t num_tensor_inputs = 0; // Only used for DEBUG-only checks
  _foreach_tensor(
      [&](size_t _, size_t idx_arg, const at::Tensor& t) {
        if (grad_mode && t.requires_grad()) {
          tensors_requiring_grad_on_stack.push_back(&t);
        }
        num_tensor_inputs++;
        TORCH_CHECK_NOT_IMPLEMENTED(
            !isFwGradDefined(t),
            "Trying to use forward AD with ",
            op_name,
            " that does not support it.");
      },
      stack,
      stack_start,
      num_arguments);

  const bool any_requires_grad = !tensors_requiring_grad_on_stack.empty();
  const bool has_out_arg = std::any_of(
      schema.arguments().begin(),
      schema.arguments().end(),
      [](const c10::Argument& arg) { return arg.is_out(); });

  _foreach_tensor(
      [&](size_t _, size_t i, const at::Tensor& t) {
        if (schema.is_mutable({c10::SchemaArgType::input, i})) {
          if (has_out_arg) {
            // Normally out argument overloads would not support any arguments
            // that require grad. However, we loosen this check to maintain
            // backward compatibility.
            // See https://github.com/pytorch/pytorch/issues/120988
            if (can_mutate_inplace(t, any_requires_grad) !=
                can_mutate_inplace_result::success) {
              throw_error_out_requires_grad(schema.name().c_str());
            }
          } else {
            check_inplace(t, any_requires_grad);
          }
        }
      },
      stack,
      stack_start,
      num_arguments);

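  // Note: unlike WarnNotImplemented above, the NotImplemented node (see
  // functions/basic_ops.h) hard-errors when the backward pass reaches it,
  // with a "derivative for <op> is not implemented" message.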
  std::shared_ptr<NotImplemented> grad_fn;
  if (any_requires_grad) {
    grad_fn = std::shared_ptr<NotImplemented>(
        new NotImplemented(op_name), deleteNode);
    grad_fn->set_next_edges(
        collect_next_edges(tensors_requiring_grad_on_stack));
  }

#ifndef NDEBUG
  // See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
  auto stack_args_copy =
      std::vector<c10::IValue>(stack->begin() + stack_start, stack->end());
  std::vector<c10::intrusive_ptr<c10::TensorImpl>> impl_saved;
  impl_saved.reserve(num_tensor_inputs);
  std::vector<std::optional<c10::Storage>> storage_saved;
  storage_saved.reserve(num_tensor_inputs);
  _foreach_tensor(
      [&](size_t idx, size_t _, const at::Tensor& t) {
        storage_saved.push_back(
            t.has_storage() ? std::optional<c10::Storage>(t.storage())
                            : std::nullopt);
        impl_saved.push_back(t.getIntrusivePtr());
      },
      &stack_args_copy,
      0,
      num_arguments);
#endif
  if (aliased_input_idx != -1 || any_is_inplace_output) {
    at::AutoDispatchBelowAutograd guard;
    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
  } else {
    // If neither in-place nor view
    at::AutoDispatchBelowADInplaceOrView guard;
    op.redispatchBoxed(
        dispatch_keys & c10::after_ADInplaceOrView_keyset, stack);
  }
#ifndef NDEBUG
  _foreach_tensor(
      [&](size_t idx_tensor, size_t _, const at::Tensor& t) {
        // Skip next two for chunk_cat, see
        // https://github.com/pytorch/pytorch/issues/130073
        if (storage_saved.at(idx_tensor).has_value() &&
            op_name != "aten::_chunk_cat")
          TORCH_INTERNAL_ASSERT(
              storage_saved.at(idx_tensor).value().is_alias_of(t.storage()),
              op_name);
        if (impl_saved.at(idx_tensor) && op_name != "aten::_chunk_cat")
          TORCH_INTERNAL_ASSERT(
              impl_saved.at(idx_tensor) == t.getIntrusivePtr(), op_name);
      },
      &stack_args_copy,
      0,
      num_arguments);
  _foreach_tensor(
      [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
        if (at::impl::tensor_has_dispatch(t) ||
            at::impl::dispatch_mode_enabled() ||
            // NJT components are expected to be reused; skip use_count() check
            op_name.rfind("aten::_nested_get", 0) == 0)
          return;
        // Skip test_parallel_materialize
        // For details see https://github.com/pytorch/pytorch/issues/130073
        if (op_name == "aten::_test_parallel_materialize" ||
            op_name == "aten::_test_optional_intlist" ||
            op_name == "aten::_test_optional_filled_intlist" ||
            op_name == "aten::_test_optional_floatlist")
          return;
        if (!is_inplace_output[idx_ret])
          TORCH_INTERNAL_ASSERT(
              t.use_count() <= 1, op_name); // Okay to return undefined tensor
        // note(crcrpar): `_foreach_norm` returns a list of scalar Tensors and
        // each Tensor shares a storage of a hidden, intermediate 1D Tensor
        // created inside the CUDA implementation. This is because the
        // reference implementation of nvidia/apex repo returns this 1D Tensor
        // where each element represents the norm of corresponding input Tensor,
        // here I want to return the same number of Tensors as the input
        // TensorList, see https://github.com/pytorch/pytorch/issues/93940
        // Skip native_channel_shuffle as well as transformer_encoder
        // For details see https://github.com/pytorch/pytorch/issues/130073
        if (!is_aliased_output[idx_ret] && t.has_storage() &&
            op_name != "aten::_foreach_norm" &&
            op_name != "aten::_transformer_encoder_layer_fwd" &&
            op_name != "aten::native_channel_shuffle")
          TORCH_INTERNAL_ASSERT(t.storage().use_count() == 1);
      },
      stack,
      stack->size() - num_returns,
      num_returns);
  // There should be only a single base-view pair, make sure their storage is
  // aliased.
  if (aliased_input_idx != -1 && aliased_output_idx.has_value()) {
    const c10::IValue& aliased_input_iv = stack_args_copy[aliased_input_idx];
    const c10::IValue& aliased_output_iv =
        (*stack)[stack->size() - num_returns + *aliased_output_idx];
    TORCH_INTERNAL_ASSERT(aliased_input_iv.isTensor(), op_name);
    TORCH_INTERNAL_ASSERT(
        aliased_output_iv.isTensor() || aliased_output_iv.isTensorList(),
        op_name);
    const at::Tensor& aliased_input = aliased_input_iv.toTensor();
    if (aliased_input.has_storage()) {
      if (aliased_output_iv.isTensor()) {
        // Take the output from the output IValue (not the input one) so the
        // assert below actually compares input and output storages.
        const at::Tensor& aliased_output = aliased_output_iv.toTensor();
        // for now, skip asserts for subclasses
        // TODO: Fix the aliasing situation involving subclasses
        if (!at::impl::dispatch_mode_enabled() &&
            !at::impl::tensor_has_dispatch(aliased_input) &&
            !at::impl::tensor_has_dispatch(aliased_output)) {
          TORCH_INTERNAL_ASSERT(
              aliased_input.storage().is_alias_of(aliased_output.storage()),
              op_name);
        }
      } else {
        const auto aliased_output_vec = aliased_output_iv.toTensorVector();
        for (const auto& aliased_output : aliased_output_vec) {
          // for now, skip asserts for subclasses
          // TODO: Fix the aliasing situation involving subclasses
          if (!at::impl::dispatch_mode_enabled() &&
              !at::impl::tensor_has_dispatch(aliased_input) &&
              !at::impl::tensor_has_dispatch(aliased_output)) {
            TORCH_INTERNAL_ASSERT(
                aliased_input.storage().is_alias_of(aliased_output.storage()),
                op_name);
          }
        }
      }
    }
  }
#endif

  if (any_requires_grad) {
    _foreach_tensor(
        [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
          if (isDifferentiableType(t.scalar_type())) {
            if (is_inplace_output[idx_ret]) {
              rebase_history(t, grad_fn);
            } else {
              set_history(t, grad_fn);
            }
          }
        },
        stack,
        stack->size() - num_returns,
        num_returns);
  }
}

torch::CppFunction autogradNotImplementedFallback() {
  return torch::CppFunction::makeFromBoxedFunction<
      &autogradNotImplementedFallbackImpl>();
}
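// A minimal sketch (in the spirit of the dispatcher tutorial referenced
// below) of explicitly registering this fallback as the Autograd kernel for
// a custom op that has no derivative; "myops" and "myadd" are placeholder
// names:
//
//   TORCH_LIBRARY_IMPL(myops, Autograd, m) {
//     m.impl("myadd", torch::autograd::autogradNotImplementedFallback());
//   }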

static void autogradNotImplementedInplaceOrViewFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  // Mimics a subset of the logic from ADInplaceOrViewType kernel:
  // - see gen_inplace_or_view_type.py
  // - this should only be used with autogradNotImplementedFallback above
  // - For more information see
  //   https://pytorch.org/tutorials/advanced/dispatcher
  //
  // NOTE [ Limitations of ADInplaceOrView boxed kernel ]
  //
  // This op should only be used with the autogradNotImplementedFallback kernel
  // because there is some logic we need specifically to enforce that even
  // if we do in-place on views created in this kernel, the proper "derivative
  // is not implemented" error is still raised.
  //
  // Just like the codegened kernel, we try to enforce some things:
  // - For views: we enforce that the view relationship is between the first
  //   input and the first output (which may be either a Tensor or a vector of
  //   Tensors).
  // - For inplace (TODO?): enforce that the same op cannot be both a view and
  //   inplace; that is not allowed in the gen_inplace_or_view logic.
  const auto& schema = op.schema();
  const auto& op_name = schema.operator_name().name;
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  const auto stack_start = stack->size() - num_arguments;

  at::Tensor aliased_input;

  int64_t aliased_output_idx = -1;
  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_aliasing({c10::SchemaArgType::output, i}) &&
        !schema.is_mutable({c10::SchemaArgType::output, i})) {
      TORCH_CHECK(
          aliased_output_idx == -1,
          "Fallback ADInplaceOrView kernel expects only a single output in the operator schema to have a "
          "non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple outputs are aliased with inputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_output_idx = static_cast<int64_t>(i);
    }
  }

  std::optional<size_t> aliased_input_idx;
  for (const auto i : c10::irange(num_arguments)) {
    if (schema.is_aliasing({c10::SchemaArgType::input, i}) &&
        !schema.is_mutable({c10::SchemaArgType::input, i})) {
      TORCH_CHECK(
          !aliased_input_idx.has_value(),
          "Fallback ADInplaceOrView kernel expects only a single input in the operator schema to have a "
          "non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_input_idx = i;
      const c10::IValue& aliased_input_iv =
          (*stack)[stack_start + i]; // get a reference to an ivalue on the
                                     // stack
      TORCH_CHECK(aliased_input_iv.isTensor());
      aliased_input =
          aliased_input_iv.toTensor(); // TODO: Can we avoid saving this tensor
                                       // and incurring the refcount bump?
    }
  }
  // See NOTE [ Limitations of ADInplaceOrView boxed kernel ] above
  TORCH_CHECK(
      (!aliased_input_idx.has_value() && aliased_output_idx == -1) ||
          (aliased_input_idx.has_value() && aliased_input_idx.value() == 0 &&
           aliased_output_idx == 0),
      "Fallback ADInplaceOrView kernel can only create view relationships between the first "
      "input and the first output (the output can be a vector of tensors). Please change the "
      "order of your operator's parameters so that this is the case.");
  const bool is_view = aliased_input_idx.has_value();

  {
    at::AutoDispatchBelowADInplaceOrView guard;
    op.redispatchBoxed(
        dispatch_keys & c10::after_ADInplaceOrView_keyset, stack);
  }

  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_mutable({c10::SchemaArgType::output, i})) {
      increment_version((*stack)[stack->size() - num_returns + i].toTensor());
    }
  }

  if (is_view) {
    c10::IValue& aliased_output_iv =
        (*stack)[stack->size() - num_returns + aliased_output_idx];

    // See NOTE [ View + Inplace detection ] for more details about this logic
    // We always need this view_func because otherwise if we do in-place
    // on this view, we would implicitly use AsStridedBackward instead
    // of the NotImplemented node. For the cross-dtype/non-strided
    // cases, we would create something like this anyway
    auto error_msg =
        ("Mutating the view " + op_name +
         " which does not have a derivative implemented is forbidden.");
    auto erroring_view_func = std::make_unique<ErroringViewFunc>(error_msg);

    const auto erroring_rev_view_func = [op_name = op_name](const at::Tensor&) {
      TORCH_CHECK(
          false,
          "Accessing the reverse view for ",
          op_name,
          " which does not have a derivative implemented is forbidden.");
      return at::Tensor();
    };

    if (aliased_output_iv.isTensorList()) {
      auto aliased_output = aliased_output_iv.toTensorVector();
      for (auto& sub_output : aliased_output) {
        as_view(
            /* base=*/aliased_input,
            /* tensor=*/sub_output,
            /* is_bw_differentiable=*/true,
            /* is_fw_differentiable=*/true,
            /* view_func=*/std::move(erroring_view_func),
            /* rev_view_func=*/erroring_rev_view_func,
            /* creation_meta=*/
            InferenceMode::is_enabled()
                ? CreationMeta::INFERENCE_MODE
                : (at::GradMode::is_enabled() ? CreationMeta::MULTI_OUTPUT_NODE
                                              : CreationMeta::NO_GRAD_MODE));
      }
      auto result = std::move(aliased_output);
      stack->at(stack->size() - num_returns + aliased_output_idx) = result;
    } else {
      TORCH_CHECK(aliased_output_iv.isTensor());
      auto result = as_view(
          /* base=*/aliased_input,
          /* tensor=*/std::move(aliased_output_iv).toTensor(),
          /* is_bw_differentiable=*/true,
          /* is_fw_differentiable=*/true,
          /* view_func=*/std::move(erroring_view_func),
          /* rev_view_func=*/erroring_rev_view_func,
          /* creation_meta=*/
          InferenceMode::is_enabled()
              ? CreationMeta::INFERENCE_MODE
              : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT
                                            : CreationMeta::NO_GRAD_MODE));
      stack->at(stack->size() - num_returns + aliased_output_idx) =
          std::move(result);
    }
  }
}

torch::CppFunction autogradNotImplementedInplaceOrViewFallback() {
  return torch::CppFunction::makeFromBoxedFunction<
      &autogradNotImplementedInplaceOrViewFallbackImpl>();
}
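// A minimal sketch of registering this kernel alongside the Autograd fallback
// above for the same custom op, per the NOTE about pairing them ("myops" and
// "myview" are placeholder names):
//
//   TORCH_LIBRARY_IMPL(myops, ADInplaceOrView, m) {
//     m.impl(
//         "myview",
//         torch::autograd::autogradNotImplementedInplaceOrViewFallback());
//   }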

} // namespace torch::autograd