#pragma once

#include <torch/csrc/jit/ir/ir.h>

#include <vector>

// Helpers for deriving JIT IR types from runtime values (tensors and
// interpreter stack entries) and for attaching such types to a Graph's inputs.

namespace torch::jit {

// Returns a JIT type (TypePtr) describing the runtime tensor `t`.
//
// `t`        - the tensor whose metadata is turned into a type.
// `complete` - selects how much of the tensor's metadata the returned type
//              carries. NOTE(review): presumably `true` yields a fully
//              specified tensor type (e.g. sizes/strides) and `false` a coarser
//              one — confirm against the definition.
TORCH_API TypePtr getTensorType(const at::Tensor& t, bool complete);

// Infers a concrete type for a graph input declared as `input_type`, using the
// runtime values found in the stack range [s_iter, s_iter_end).
//
// `input_type` - the declared type of the input being refined.
// `s_iter`     - taken by non-const reference, so the callee may advance it
//                past the stack values it consumes; do not assume it is
//                unchanged on return.
// `s_iter_end` - end of the usable stack range.
// `complete`   - presumably has the same meaning as in getTensorType —
//                TODO confirm.
// NOTE(review): an aggregate `input_type` (e.g. a tuple) presumably consumes
// one stack entry per element — verify in the implementation.
TORCH_API TypePtr inferShapeAndTypeForInput(
    TypePtr input_type,
    Stack::const_iterator& s_iter,
    const Stack::const_iterator& s_iter_end,
    bool complete);

// Refines the types of `g`'s inputs using the runtime values in `stack`.
// `g` is taken by non-const reference and may be modified.
//
// `g`                - graph whose input types are updated.
// `stack`            - runtime values corresponding to the graph inputs
//                      (presumably in input order — TODO confirm).
// `complete`         - presumably has the same meaning as in getTensorType.
// `param_count_list` - optional; defaults to empty. NOTE(review): appears to
//                      describe how many stack entries belong to each graph
//                      input, with an empty list meaning the default mapping —
//                      confirm in the .cpp.
TORCH_API void setInputTensorTypes(
    Graph& g,
    const Stack& stack,
    bool complete,
    const std::vector<int>& param_count_list = {});

} // namespace torch::jit