1 #include <c10/util/irange.h>
2 #include <torch/csrc/autograd/functions/utils.h>
3
4 #include <torch/csrc/autograd/edge.h>
5 #include <torch/csrc/autograd/function.h>
6 #include <torch/csrc/autograd/variable.h>
7
8 #include <sstream>
9
10 namespace torch::autograd {
11
/// Wraps the raw tensors produced by an autograd function as Variables.
///
/// @param inputs  The Variables the function was applied to. They decide
///                whether any input requires grad and, when grad mode is on,
///                supply the next edges of the newly constructed node.
/// @param outputs The freshly computed tensors. Taken by rvalue reference:
///                each defined element is moved into its wrapping Variable,
///                so callers must not read `outputs` after this call.
/// @param ctr     Factory that builds the backward Node from an edge list.
/// @return One Variable per element of `outputs`, in the same order. An
///         undefined tensor yields an undefined (default-constructed) Variable.
variable_list wrap_outputs(
    const variable_list& inputs,
    tensor_list&& outputs,
    const function_constructor& ctr) {
  variable_list result;
  result.reserve(outputs.size());

  if (!any_variable_requires_grad(inputs)) {
    // No input needs a gradient: wrap each output without creating a node.
    for (auto& output : outputs) {
      if (output.defined()) {
        // `outputs` was handed over as an rvalue, so move instead of copying.
        // NOTE(review): make_variable appears to take its tensor by value;
        // moving then avoids an extra reference-count bump / copy — verify.
        result.push_back(autograd::make_variable(
            std::move(output), /*requires_grad=*/false));
      } else {
        result.emplace_back();
      }
    }
    return result;
  }

  // At least one input requires grad: build a single backward node shared by
  // every output. When grad mode is disabled the node gets no next edges.
  auto grad_fn =
      ctr(GradMode::is_enabled() ? collect_next_edges(inputs) : edge_list());
  for (auto& output : outputs) {
    if (output.defined()) {
      auto variable =
          autograd::make_variable(std::move(output), /*requires_grad=*/false);
      autograd::create_gradient_edge(variable, grad_fn);
      result.push_back(std::move(variable));
    } else {
      // Register a placeholder input on grad_fn for the missing output.
      // NOTE(review): presumably create_gradient_edge registers metadata for
      // defined outputs, so this keeps grad_fn's inputs aligned with the
      // outputs list — confirm against create_gradient_edge.
      grad_fn->add_input_metadata(Node::undefined_input());
      result.emplace_back();
    }
  }
  return result;
}
43
check_input_variables(const char * name,const variable_list & inputs,int args,int required_args,bool allow_undefined)44 void check_input_variables(
45 const char* name,
46 const variable_list& inputs,
47 int args,
48 int required_args,
49 bool allow_undefined) {
50 if (required_args == -1) {
51 required_args = args;
52 }
53 if (inputs.size() != (size_t)args) {
54 std::stringstream ss;
55 ss << name << ": expected " << args << " arguments (got " << inputs.size();
56 ss << ")";
57 throw std::runtime_error(ss.str());
58 }
59 for (const auto i : c10::irange(required_args)) {
60 if (!inputs[i].defined() && !allow_undefined) {
61 std::stringstream ss;
62 ss << name << ": expected Tensor at argument " << i << " (got None)";
63 throw std::runtime_error(ss.str());
64 }
65 }
66 }
67 } // namespace torch::autograd
68