#include <torch/csrc/autograd/functions/basic_ops.h>

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/dynamo/compiled_autograd.h>

#include <ATen/ATen.h>

#include <memory>
#include <utility>

namespace torch::autograd {

apply(variable_list && inputs)15 auto Error::apply(variable_list&& inputs) -> variable_list {
16 throw std::runtime_error(msg);
17 }
18
compiled_args(CompiledNodeArgs & args)19 void Error::compiled_args(CompiledNodeArgs& args) {
20 // throw the error durring collect, the graph won't get compiled
21 apply(variable_list());
22 }
23
apply_with_saved(const variable_list & inputs,SwapSavedVariables & saved)24 variable_list Error::apply_with_saved(
25 const variable_list& inputs,
26 SwapSavedVariables& saved) {
27 TORCH_INTERNAL_ASSERT(false, "unreachable");
28 }
29
apply(variable_list && inputs)30 auto DelayedError::apply(variable_list&& inputs) -> variable_list {
31 tensor_list outputs;
32 outputs.reserve(inputs.size());
33 for (auto& var : inputs) {
34 // FIXME: share version counters
35 outputs.emplace_back(var.defined() ? var.tensor_data() : at::Tensor());
36 }
37 return wrap_outputs(inputs, std::move(outputs), [&](edge_list&& next_edges) {
38 return std::make_shared<Error>(msg, std::move(next_edges));
39 });
40 }
41
apply(variable_list && inputs)42 auto UndefinedGrad::apply(variable_list&& inputs) -> variable_list {
43 tensor_list outputs;
44 outputs.reserve(inputs.size());
45 for (auto& var : inputs) {
46 outputs.emplace_back(
47 var.defined() ? var.clone().tensor_data() : at::Tensor());
48 }
49 return wrap_outputs(inputs, std::move(outputs), [&](edge_list&& next_edges) {
50 return std::make_shared<UndefinedGradBackward>(std::move(next_edges));
51 });
52 }
53
apply(variable_list && output_grads)54 auto UndefinedGradBackward::apply(variable_list&& output_grads)
55 -> variable_list {
56 tensor_list input_grads;
57 output_grads.reserve(input_grads.size());
58 for (auto& grad : output_grads) {
59 (void)grad; // Suppress unused variable warning
60 input_grads.emplace_back();
61 }
62 return input_grads;
63 }
64
apply(variable_list && grads)65 auto Identity::apply(variable_list&& grads) -> variable_list {
66 return std::move(grads);
67 }
68
compiled_args(CompiledNodeArgs & args)69 void GraphRoot::compiled_args(CompiledNodeArgs& args) {
70 args.collect(outputs);
71 }
// Return the stored root gradients, swapping saved variables in before the
// copy and restoring them after, as required by the SwapSavedVariables
// protocol.
variable_list GraphRoot::apply_with_saved(
    const variable_list& inputs,
    SwapSavedVariables& saved) {
  saved.before(outputs);
  variable_list result(outputs);
  saved.after(outputs);
  return result;
}

} // namespace torch::autograd
