// torch/csrc/jit/ir/graph_utils.cpp
#include <torch/csrc/jit/ir/graph_utils.h>

namespace torch::jit {

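// Builds a TensorType describing `t`. With `complete` set, the full type
// (including concrete sizes and strides) is kept; otherwise the type is
// reduced to rank-only shape information via dimensionedOnly().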
TypePtr getTensorType(const at::Tensor& t, bool complete) {
  auto r = TensorType::create(t);
  if (!complete) {
    r = r->dimensionedOnly();
  }
  return r;
}

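// Walks `input_type` structurally, refining tensor types from the concrete
// values in [s_iter, s_iter_end). Tuples recurse into each contained type
// (consuming stack entries accordingly), lists and optionals refine their
// element type, tensors consume one stack entry, and primitive types are
// returned unchanged while their stack entry is skipped. For example, an
// input typed Tuple[Tensor, int] consumes two stack entries: the tensor
// refines the first element's type, and the int passes through as-is.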
TypePtr inferShapeAndTypeForInput(
    TypePtr input_type,
    Stack::const_iterator& s_iter,
    const Stack::const_iterator& s_iter_end,
    bool complete) {
  if (auto tuple_type = input_type->cast<TupleType>()) {
    std::vector<TypePtr> types;
    for (const auto& sub_type : tuple_type->containedTypes()) {
      TORCH_INTERNAL_ASSERT(s_iter != s_iter_end);
      types.emplace_back(
          inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete));
    }
    return TupleType::create(types);
  } else if (auto list_type = input_type->cast<ListType>()) {
    const TypePtr& sub_type = list_type->getElementType();
    auto elem_type =
        inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete);
    return ListType::create(elem_type);
  } else if (auto tensor_type = input_type->cast<TensorType>()) {
    auto type = getTensorType(s_iter->toTensor(), complete);
    s_iter++;
    return type;
  } else if (auto optional_type = input_type->cast<OptionalType>()) {
    const TypePtr& sub_type = optional_type->getElementType();
    auto elem_type =
        inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete);
    return OptionalType::create(elem_type);
  } else {
    // Primitive type, keep as is.
    s_iter++;
    return input_type;
  }
}

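// Refines the types of g's inputs in place from the runtime values in
// `stack`. Custom-class inputs (e.g. packed quantized params) are skipped;
// when `param_count_list` is non-empty, it gives the number of stack entries
// each graph input consumes, so the iterator can be advanced past skipped
// inputs.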
void setInputTensorTypes(
    Graph& g,
    const Stack& stack,
    bool complete,
    const std::vector<int>& param_count_list) {
  at::ArrayRef<Value*> input_values = g.inputs();
  auto s_iter = stack.begin();
  size_t list_idx = 0;
  if (!param_count_list.empty()) {
    TORCH_INTERNAL_ASSERT(
        input_values.size() == param_count_list.size(),
        " input_values:",
        input_values.size(),
        " vs param_count_list:",
        param_count_list.size());
  }
  for (auto v : input_values) {
    // Leave packed param types alone. This is needed for downstream passes
    // (like alias analysis) to work properly. This will be unpacked later
    // in unpackQuantizedWeights.
    if (auto named_type = v->type()->cast<c10::NamedType>()) {
      if (auto qualname = named_type->name()) {
        if (getCustomClass(qualname->qualifiedName())) {
          if (param_count_list.empty()) {
            AT_ASSERT(s_iter != stack.end());
            s_iter++;
          } else {
            if (param_count_list[list_idx] > 0) {
              AT_ASSERT(s_iter != stack.end());
            }
            s_iter += param_count_list[list_idx];
          }
          list_idx++;
          continue;
        }
      }
    }
    auto type =
        inferShapeAndTypeForInput(v->type(), s_iter, stack.end(), complete);
    v->setType(type);
    list_idx++;
  }
}

} // namespace torch::jit
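
// A minimal usage sketch (illustrative only, not compiled as part of this
// file; assumes a graph whose inputs are plain tensors). Stack is
// std::vector<c10::IValue>, so example inputs can be pushed directly:
//
//   Stack stack;
//   stack.emplace_back(at::randn({2, 3}));
//   setInputTensorTypes(*graph, stack, /*complete=*/false,
//                       /*param_count_list=*/{});
//   // graph->inputs()[0]->type() now carries rank (2) but no concrete
//   // sizes; complete=true would keep sizes and strides as well.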