1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker // ${generated_comment} 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Worker #include <ATen/ATen.h> 6*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/functional.h> 7*da0073e9SAndroid Build Coastguard Worker #include <ATen/TensorGeometry.h> 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker #include "torch/csrc/autograd/function.h" 10*da0073e9SAndroid Build Coastguard Worker #include "torch/csrc/autograd/variable.h" 11*da0073e9SAndroid Build Coastguard Worker #include "torch/csrc/autograd/saved_variable.h" 12*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/Export.h> 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker #include <c10/core/SymIntArrayRef.h> 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker namespace torch { namespace autograd { namespace generated { 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker using at::Scalar; 19*da0073e9SAndroid Build Coastguard Worker using at::Tensor; 20*da0073e9SAndroid Build Coastguard Worker using at::IntArrayRef; 21*da0073e9SAndroid Build Coastguard Worker using at::ArrayRef; 22*da0073e9SAndroid Build Coastguard Worker using at::Type; 23*da0073e9SAndroid Build Coastguard Worker using at::TensorGeometry; 24*da0073e9SAndroid Build Coastguard Worker using at::ScalarType; 25*da0073e9SAndroid Build Coastguard Worker using std::optional; 26*da0073e9SAndroid Build Coastguard Worker using c10::fmap; 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker inline std::vector<Tensor> unpack_list(at::ArrayRef<SavedVariable> xs, std::shared_ptr<Node> saved_for = nullptr) { 29*da0073e9SAndroid Build Coastguard Worker // NB: we must explicitly do the conversion in the lambda, otherwise template 30*da0073e9SAndroid Build Coastguard Worker // deduction will give a Tensor of Variable which is not convertible 31*da0073e9SAndroid Build Coastguard Worker return fmap(xs, [&saved_for](const SavedVariable& x) { 32*da0073e9SAndroid Build Coastguard Worker // TODO(crcrpar): Use `std::move(saved_for)` to avoid incrementing refcount, which would need refactoring. 33*da0073e9SAndroid Build Coastguard Worker return static_cast<Tensor>(x.unpack(saved_for)); 34*da0073e9SAndroid Build Coastguard Worker }); 35*da0073e9SAndroid Build Coastguard Worker } 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker inline c10::List<std::optional<Tensor>> unpack_opt_list(at::ArrayRef<SavedVariable> xs, std::shared_ptr<Node> saved_for = nullptr) { 38*da0073e9SAndroid Build Coastguard Worker torch::List<std::optional<Tensor>> result; 39*da0073e9SAndroid Build Coastguard Worker result.reserve(xs.size()); 40*da0073e9SAndroid Build Coastguard Worker for (const SavedVariable& v : xs) { 41*da0073e9SAndroid Build Coastguard Worker auto var = v.unpack(saved_for); 42*da0073e9SAndroid Build Coastguard Worker result.push_back(var.defined() ? std::optional<Tensor>(var) : ::std::nullopt); 43*da0073e9SAndroid Build Coastguard Worker } 44*da0073e9SAndroid Build Coastguard Worker return result; 45*da0073e9SAndroid Build Coastguard Worker } 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker using torch::autograd::TypeAndSize; 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker ${autograd_function_declarations} 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker }}} // namespace torch::autograd::generated 52