xref: /aosp_15_r20/external/pytorch/tools/autograd/templates/Functions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker // ${generated_comment}
4*da0073e9SAndroid Build Coastguard Worker 
#include <ATen/ATen.h>
#include <ATen/TensorGeometry.h>
#include <ATen/core/functional.h>

#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/saved_variable.h"
#include <torch/csrc/Export.h>

#include <c10/core/SymIntArrayRef.h>

#include <utility>
15*da0073e9SAndroid Build Coastguard Worker 
16*da0073e9SAndroid Build Coastguard Worker namespace torch { namespace autograd { namespace generated {
17*da0073e9SAndroid Build Coastguard Worker 
18*da0073e9SAndroid Build Coastguard Worker using at::Scalar;
19*da0073e9SAndroid Build Coastguard Worker using at::Tensor;
20*da0073e9SAndroid Build Coastguard Worker using at::IntArrayRef;
21*da0073e9SAndroid Build Coastguard Worker using at::ArrayRef;
22*da0073e9SAndroid Build Coastguard Worker using at::Type;
23*da0073e9SAndroid Build Coastguard Worker using at::TensorGeometry;
24*da0073e9SAndroid Build Coastguard Worker using at::ScalarType;
25*da0073e9SAndroid Build Coastguard Worker using std::optional;
26*da0073e9SAndroid Build Coastguard Worker using c10::fmap;
27*da0073e9SAndroid Build Coastguard Worker 
28*da0073e9SAndroid Build Coastguard Worker inline std::vector<Tensor> unpack_list(at::ArrayRef<SavedVariable> xs, std::shared_ptr<Node> saved_for = nullptr) {
29*da0073e9SAndroid Build Coastguard Worker   // NB: we must explicitly do the conversion in the lambda, otherwise template
30*da0073e9SAndroid Build Coastguard Worker   // deduction will give a Tensor of Variable which is not convertible
31*da0073e9SAndroid Build Coastguard Worker   return fmap(xs, [&saved_for](const SavedVariable& x) {
32*da0073e9SAndroid Build Coastguard Worker     // TODO(crcrpar): Use `std::move(saved_for)` to avoid incrementing refcount, which would need refactoring.
33*da0073e9SAndroid Build Coastguard Worker     return static_cast<Tensor>(x.unpack(saved_for));
34*da0073e9SAndroid Build Coastguard Worker   });
35*da0073e9SAndroid Build Coastguard Worker }
36*da0073e9SAndroid Build Coastguard Worker 
37*da0073e9SAndroid Build Coastguard Worker inline c10::List<std::optional<Tensor>> unpack_opt_list(at::ArrayRef<SavedVariable> xs, std::shared_ptr<Node> saved_for = nullptr) {
38*da0073e9SAndroid Build Coastguard Worker   torch::List<std::optional<Tensor>> result;
39*da0073e9SAndroid Build Coastguard Worker   result.reserve(xs.size());
40*da0073e9SAndroid Build Coastguard Worker   for (const SavedVariable& v : xs) {
41*da0073e9SAndroid Build Coastguard Worker     auto var = v.unpack(saved_for);
42*da0073e9SAndroid Build Coastguard Worker     result.push_back(var.defined() ? std::optional<Tensor>(var) : ::std::nullopt);
43*da0073e9SAndroid Build Coastguard Worker   }
44*da0073e9SAndroid Build Coastguard Worker   return result;
45*da0073e9SAndroid Build Coastguard Worker }
46*da0073e9SAndroid Build Coastguard Worker 
47*da0073e9SAndroid Build Coastguard Worker using torch::autograd::TypeAndSize;
48*da0073e9SAndroid Build Coastguard Worker 
49*da0073e9SAndroid Build Coastguard Worker ${autograd_function_declarations}
50*da0073e9SAndroid Build Coastguard Worker 
51*da0073e9SAndroid Build Coastguard Worker }}} // namespace torch::autograd::generated
52