xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/graph_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
5 #include <vector>
6 
7 namespace torch::jit {
8 
9 TORCH_API TypePtr getTensorType(const at::Tensor& t, bool complete);
10 
11 TORCH_API TypePtr inferShapeAndTypeForInput(
12     TypePtr input_type,
13     Stack::const_iterator& s_iter,
14     const Stack::const_iterator& s_iter_end,
15     bool complete);
16 
17 TORCH_API void setInputTensorTypes(
18     Graph& g,
19     const Stack& stack,
20     bool complete,
21     const std::vector<int>& param_count_list = {});
22 
23 } // namespace torch::jit
24