xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/functions/utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/autograd/functions/utils.h>
3 
4 #include <torch/csrc/autograd/edge.h>
5 #include <torch/csrc/autograd/function.h>
6 #include <torch/csrc/autograd/variable.h>
7 
8 #include <sstream>
9 
10 namespace torch::autograd {
11 
/// Wraps the raw tensors produced by a forward computation as autograd
/// Variables.
///
/// If none of `inputs` requires grad, every defined output becomes a
/// Variable with requires_grad=false, and no graph node is built.
/// Otherwise a single node is built by calling `ctr`:
///   - it receives the inputs' next edges when grad mode is enabled;
///   - it receives an empty edge list when grad mode is disabled.
/// Every defined output is then attached to that node via
/// create_gradient_edge.
///
/// Undefined outputs stay undefined in the result, so the result has
/// the same length and order as `outputs`.
///
/// @param inputs  Variables that were fed to the forward computation.
/// @param outputs Raw output tensors; the list is taken by rvalue reference.
/// @param ctr     Factory that builds the backward node from an edge list.
/// @return One Variable per entry of `outputs`, in the same order.
variable_list wrap_outputs(
    const variable_list& inputs,
    tensor_list&& outputs,
    const function_constructor& ctr) {
  variable_list wrapped;
  wrapped.reserve(outputs.size());

  // Fast path: nothing upstream needs a gradient, so no node is built.
  if (!any_variable_requires_grad(inputs)) {
    for (auto& tensor : outputs) {
      if (!tensor.defined()) {
        wrapped.emplace_back();
        continue;
      }
      wrapped.push_back(make_variable(tensor, /*requires_grad=*/false));
    }
    return wrapped;
  }

  // Only collect the inputs' edges while grad mode is on; otherwise the node
  // is constructed with no next edges.
  edge_list next_edges;
  if (GradMode::is_enabled()) {
    next_edges = collect_next_edges(inputs);
  }
  auto grad_fn = ctr(std::move(next_edges));

  for (auto& tensor : outputs) {
    if (!tensor.defined()) {
      // Record a placeholder for the missing output. Presumably this keeps
      // grad_fn's input numbering aligned with the output positions
      // (NOTE(review): confirm against Node::add_input_metadata).
      grad_fn->add_input_metadata(Node::undefined_input());
      wrapped.emplace_back();
      continue;
    }
    auto variable = autograd::make_variable(tensor, /*requires_grad=*/false);
    autograd::create_gradient_edge(variable, grad_fn);
    wrapped.push_back(std::move(variable));
  }
  return wrapped;
}
43 
check_input_variables(const char * name,const variable_list & inputs,int args,int required_args,bool allow_undefined)44 void check_input_variables(
45     const char* name,
46     const variable_list& inputs,
47     int args,
48     int required_args,
49     bool allow_undefined) {
50   if (required_args == -1) {
51     required_args = args;
52   }
53   if (inputs.size() != (size_t)args) {
54     std::stringstream ss;
55     ss << name << ": expected " << args << " arguments (got " << inputs.size();
56     ss << ")";
57     throw std::runtime_error(ss.str());
58   }
59   for (const auto i : c10::irange(required_args)) {
60     if (!inputs[i].defined() && !allow_undefined) {
61       std::stringstream ss;
62       ss << name << ": expected Tensor at argument " << i << " (got None)";
63       throw std::runtime_error(ss.str());
64     }
65   }
66 }
67 } // namespace torch::autograd
68