1 #include <torch/csrc/jit/ir/graph_utils.h>
2
3 namespace torch::jit {
4
getTensorType(const at::Tensor & t,bool complete)5 TypePtr getTensorType(const at::Tensor& t, bool complete) {
6 auto r = TensorType::create(t);
7 if (!complete) {
8 r = r->dimensionedOnly();
9 }
10 return r;
11 }
12
inferShapeAndTypeForInput(TypePtr input_type,Stack::const_iterator & s_iter,const Stack::const_iterator & s_iter_end,bool complete)13 TypePtr inferShapeAndTypeForInput(
14 TypePtr input_type,
15 Stack::const_iterator& s_iter,
16 const Stack::const_iterator& s_iter_end,
17 bool complete) {
18 if (auto tuple_type = input_type->cast<TupleType>()) {
19 std::vector<TypePtr> types;
20 for (const auto& sub_type : tuple_type->containedTypes()) {
21 TORCH_INTERNAL_ASSERT(s_iter != s_iter_end);
22 types.emplace_back(
23 inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete));
24 }
25 return TupleType::create(types);
26 } else if (auto list_type = input_type->cast<ListType>()) {
27 const TypePtr& sub_type = list_type->getElementType();
28 auto elem_type =
29 inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete);
30 return ListType::create(elem_type);
31 } else if (auto tensor_type = input_type->cast<TensorType>()) {
32 auto type = getTensorType(s_iter->toTensor(), complete);
33 s_iter++;
34 return type;
35 } else if (auto optional_type = input_type->cast<OptionalType>()) {
36 const TypePtr& sub_type = optional_type->getElementType();
37 auto elem_type =
38 inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete);
39 return OptionalType::create(elem_type);
40 } else {
41 // Primitive type, keep as is.
42 s_iter++;
43 return input_type;
44 }
45 }
46
setInputTensorTypes(Graph & g,const Stack & stack,bool complete,const std::vector<int> & param_count_list)47 void setInputTensorTypes(
48 Graph& g,
49 const Stack& stack,
50 bool complete,
51 const std::vector<int>& param_count_list) {
52 at::ArrayRef<Value*> input_values = g.inputs();
53 auto s_iter = stack.begin();
54 size_t list_idx = 0;
55 if (!param_count_list.empty()) {
56 TORCH_INTERNAL_ASSERT(
57 input_values.size() == param_count_list.size(),
58 " input_values:",
59 input_values.size(),
60 " vs param_count_list:",
61 param_count_list.size());
62 }
63 for (auto v : input_values) {
64 // Leave packed param types alone. This is needed for downstream passes
65 // (like alias analysis) to work properly. This will be unpacked later
66 // in unpackQuantizedWeights.
67 if (auto named_type = v->type()->cast<c10::NamedType>()) {
68 if (auto qualname = named_type->name()) {
69 if (getCustomClass(qualname->qualifiedName())) {
70 if (param_count_list.empty()) {
71 AT_ASSERT(s_iter != stack.end());
72 s_iter++;
73 } else {
74 if (param_count_list[list_idx] > 0) {
75 AT_ASSERT(s_iter != stack.end());
76 }
77 s_iter += param_count_list[list_idx];
78 }
79 list_idx++;
80 continue;
81 }
82 }
83 }
84 auto type =
85 inferShapeAndTypeForInput(v->type(), s_iter, stack.end(), complete);
86 v->setType(type);
87 list_idx++;
88 }
89 }
90
91 } // namespace torch::jit
92