// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/variable_tensor_list.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 
4 namespace torch::jit {
5 
6 // a wrapper to mark places where we expect all the at::Tensors to be
7 // variables
8 struct variable_tensor_list : public std::vector<at::Tensor> {
9   variable_tensor_list() = default;
10   template <class InputIt>
variable_tensor_listvariable_tensor_list11   variable_tensor_list(InputIt first, InputIt last)
12       : std::vector<at::Tensor>(first, last) {}
variable_tensor_listvariable_tensor_list13   explicit variable_tensor_list(std::vector<at::Tensor>&& tensor)
14       : std::vector<at::Tensor>(std::move(tensor)) {}
15 };
16 
17 } // namespace torch::jit
18