xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_mobile_type_parser.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <test/cpp/jit/test_utils.h>
3 
4 #include <ATen/core/jit_type.h>
5 #include <torch/csrc/jit/mobile/type_parser.h>
6 
7 namespace torch {
8 namespace jit {
9 
// Parse success cases: c10::parseType should accept these inputs.
// A bare primitive type name parses to the corresponding singleton type.
TEST(MobileTypeParserTest, Int) {
  const std::string int_ps("int");
  const auto int_tp = c10::parseType(int_ps);
  EXPECT_EQ(*int_tp, *IntType::get());
}
16 
// Nested containers (Tuple / Optional / Dict / List) parse to a TupleType that
// compares equal to the same structure built by hand.
TEST(MobileTypeParserTest, NestedContainersAnnotationStr) {
  const std::string tuple_ps(
      "Tuple[str, Optional[float], Dict[str, List[Tensor]], int]");
  const auto tuple_tp = c10::parseType(tuple_ps);

  std::vector<TypePtr> args = {
      c10::StringType::get(),
      c10::OptionalType::create(c10::FloatType::get()),
      c10::DictType::create(
          c10::StringType::get(),
          c10::ListType::create(c10::TensorType::get())),
      c10::IntType::get()};
  const auto expected_tp = c10::TupleType::create(std::move(args));

  ASSERT_EQ(*tuple_tp, *expected_tp);
}
30 
// A TorchBind class name parses and round-trips through annotation_str()
// unchanged.
TEST(MobileTypeParserTest, TorchBindClass) {
  const std::string class_ps("__torch__.torch.classes.rnn.CellParamsBase");
  const auto class_tp = c10::parseType(class_ps);
  const std::string class_annotation = class_tp->annotation_str();
  ASSERT_EQ(class_ps, class_annotation);
}
37 
// A List whose element is a TorchBind class parses to a list type, and its
// element type prints back as the class's qualified name.
TEST(MobileTypeParserTest, ListOfTorchBindClass) {
  const std::string list_ps(
      "List[__torch__.torch.classes.rnn.CellParamsBase]");
  const auto list_tp = c10::parseType(list_ps);
  // ASSERT rather than EXPECT: the containedType(0) access below is only
  // meaningful once we know the parsed type is a list.
  ASSERT_TRUE(list_tp->isSubtypeOf(AnyListType::get()));
  EXPECT_EQ(
      "__torch__.torch.classes.rnn.CellParamsBase",
      list_tp->containedType(0)->annotation_str());
}
46 
// Extra whitespace inside a type string is tolerated. It neither changes the
// parsed type nor leaks into the printed annotation.
TEST(MobileTypeParserTest, NestedContainersAnnotationStrWithSpaces) {
  const std::string tuple_space_ps(
      "Tuple[  str, Optional[float], Dict[str, List[Tensor ]]  , int]");
  const auto tuple_space_tp = c10::parseType(tuple_space_ps);

  // The same type written without the stray spaces must parse identically.
  const auto tuple_tp = c10::parseType(
      std::string("Tuple[str, Optional[float], Dict[str, List[Tensor]], int]"));
  EXPECT_EQ(*tuple_space_tp, *tuple_tp);

  // The annotation string must not contain the input's odd whitespace.
  // EXPECT_EQ against npos reports the offending position on failure.
  const std::string tuple_space_tps = tuple_space_tp->annotation_str();
  EXPECT_EQ(tuple_space_tps.find("[ "), std::string::npos);
  EXPECT_EQ(tuple_space_tps.find(" ]"), std::string::npos);
  EXPECT_EQ(tuple_space_tps.find(" ,"), std::string::npos);
}
57 
// An inline NamedTuple definition ("<qualified name>[NamedTuple, [[field,
// type], ...]]") parses, and its annotation_str() is the qualified name only.
TEST(MobileTypeParserTest, NamedTuple) {
  const std::string named_tuple_ps(
      "__torch__.base_models.preproc_types.PreprocOutputType["
      "    NamedTuple, ["
      "        [float_features, Tensor],"
      "        [id_list_features, List[Tensor]],"
      "        [label,  Tensor],"
      "        [weight, Tensor],"
      "        [prod_prediction, Tuple[Tensor, Tensor]],"
      "        [id_score_list_features, List[Tensor]],"
      "        [embedding_features, List[Tensor]],"
      "        [teacher_label, Tensor]"
      "        ]"
      "    ]");

  const c10::TypePtr named_tuple_tp = c10::parseType(named_tuple_ps);
  const std::string named_tuple_annotation_str =
      named_tuple_tp->annotation_str();
  ASSERT_EQ(
      named_tuple_annotation_str,
      "__torch__.base_models.preproc_types.PreprocOutputType");
}
79 
// When several type strings are parsed together, a NamedTuple defined in an
// earlier entry can be referenced by bare name from a later entry (here as a
// Dict value type), and the reference compares equal to the definition.
TEST(MobileTypeParserTest, DictNestedNamedTupleTypeList) {
  // NOTE(review): this definition closes the field list with one "]" but,
  // unlike the string in the NamedTuple test, has no second "]" for the outer
  // "PreprocOutputType[". Left unchanged on purpose. Confirm how the parser
  // treats the outer bracket before "balancing" it.
  const std::string type_str_1(
      "__torch__.base_models.preproc_types.PreprocOutputType["
      "  NamedTuple, ["
      "      [float_features, Tensor],"
      "      [id_list_features, List[Tensor]],"
      "      [label,  Tensor],"
      "      [weight, Tensor],"
      "      [prod_prediction, Tuple[Tensor, Tensor]],"
      "      [id_score_list_features, List[Tensor]],"
      "      [embedding_features, List[Tensor]],"
      "      [teacher_label, Tensor]"
      "      ]");
  const std::string type_str_2(
      "Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]");

  // Non-const: the list overload of c10::parseType takes a mutable reference.
  std::vector<std::string> type_strs = {type_str_1, type_str_2};
  const std::vector<c10::TypePtr> named_tuple_tps = c10::parseType(type_strs);

  // Guard the indexing below so a short result fails cleanly instead of
  // reading out of bounds.
  ASSERT_EQ(named_tuple_tps.size(), type_strs.size());
  EXPECT_EQ(*named_tuple_tps[1]->containedType(0), *c10::StringType::get());
  EXPECT_EQ(*named_tuple_tps[0], *named_tuple_tps[1]->containedType(1));
}
100 
// NamedTuples defined in separate list entries can reference each other by
// name, provided each definition precedes its use (ccc, then bbb using ccc,
// then aaa using bbb).
TEST(MobileTypeParserTest, NamedTupleNestedNamedTupleTypeList) {
  // The leading space before "__torch__" is intentional input.
  const std::string type_str_1(
      " __torch__.ccc.xxx ["
      "    NamedTuple, ["
      "      [field_name_c_1, Tensor],"
      "      [field_name_c_2, Tuple[Tensor, Tensor]]"
      "    ]"
      "]");
  // NOTE(review): this string has one more "]" than "[". Left unchanged;
  // confirm whether the trailing bracket is significant to the parser.
  const std::string type_str_2(
      "__torch__.bbb.xxx ["
      "    NamedTuple,["
      "        [field_name_b, __torch__.ccc.xxx]]"
      "    ]"
      "]");
  const std::string type_str_3(
      "__torch__.aaa.xxx["
      "    NamedTuple, ["
      "        [field_name_a, __torch__.bbb.xxx]"
      "    ]"
      "]");

  // Non-const: the list overload of c10::parseType takes a mutable reference.
  std::vector<std::string> type_strs = {type_str_1, type_str_2, type_str_3};
  const std::vector<c10::TypePtr> named_tuple_tps = c10::parseType(type_strs);

  // Guard the [2] access below.
  ASSERT_EQ(named_tuple_tps.size(), type_strs.size());
  const std::string named_tuple_annotation_str =
      named_tuple_tps[2]->annotation_str();
  ASSERT_EQ(named_tuple_annotation_str, "__torch__.aaa.xxx");
}
128 
// NamedTuple definitions can be nested inline three levels deep in a single
// type string; the outermost type reports its own qualified name via str().
TEST(MobileTypeParserTest, NamedTupleNestedNamedTuple) {
  const std::string named_tuple_ps(
      "__torch__.aaa.xxx["
      "    NamedTuple, ["
      "        [field_name_a, __torch__.bbb.xxx ["
      "            NamedTuple, ["
      "                [field_name_b, __torch__.ccc.xxx ["
      "                    NamedTuple, ["
      "                      [field_name_c_1, Tensor],"
      "                      [field_name_c_2, Tuple[Tensor, Tensor]]"
      "                    ]"
      "                ]"
      "                ]"
      "            ]"
      "        ]"
      "        ]"
      "    ]   "
      "]");

  const c10::TypePtr named_tuple_tp = c10::parseType(named_tuple_ps);
  // This test checks str() (not annotation_str()); name the local accordingly.
  const std::string named_tuple_str = named_tuple_tp->str();
  ASSERT_EQ(named_tuple_str, "__torch__.aaa.xxx");
}
152 
153 // Parse throw cases
// An empty string is not a type.
TEST(MobileTypeParserTest, Empty) {
  const std::string empty_ps;
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(c10::parseType(empty_ps));
}
159 
// Type names are case-sensitive: "tensor" is not "Tensor".
TEST(MobileTypeParserTest, TypoRaises) {
  const std::string lowercase_tensor("List[tensor]");
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(c10::parseType(lowercase_tensor));
}
165 
// A container whose "[" is never closed is rejected.
TEST(MobileTypeParserTest, MismatchBracketRaises) {
  const std::string unclosed_bracket("List[Tensor");
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(c10::parseType(unclosed_bracket));
}
171 
// More "[" than "]" (including a stray bracket before the element type) is
// rejected.
TEST(MobileTypeParserTest, MismatchBracketRaises2) {
  const std::string extra_open_bracket("List[[Tensor]");
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(c10::parseType(extra_open_bracket));
}
177 
// A Dict needs both a key and a value type.
TEST(MobileTypeParserTest, DictWithoutValueRaises) {
  const std::string dict_without_value("Dict[Tensor]");
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(c10::parseType(dict_without_value));
}
183 
// A List takes exactly one element type; two is an argument-count mismatch.
TEST(MobileTypeParserTest, ListArgCountMismatchRaises) {
  const std::string list_with_two_args("List[int, str]");
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(c10::parseType(list_with_two_args));
}
190 
// A trailing comma with no value type leaves the Dict one argument short.
TEST(MobileTypeParserTest, DictArgCountMismatchRaises) {
  const std::string trailing_comma("Dict[str,]");
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(c10::parseType(trailing_comma));
}
196 
// A valid type followed by leftover tokens is rejected rather than silently
// truncated.
TEST(MobileTypeParserTest, ValidTypeWithExtraStuffRaises) {
  const std::string extra_stuff("int int");
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(c10::parseType(extra_stuff));
}
202 
// Input that does not start with an identifier (here, a parenthesised type)
// is rejected.
TEST(MobileTypeParserTest, NonIdentifierRaises) {
  const std::string parenthesized("(int)");
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(c10::parseType(parenthesized));
}
208 
// Same two strings as DictNestedNamedTupleTypeList, but in reverse order. The
// Dict references PreprocOutputType before the entry that defines it, so
// parsing is expected to fail with a "can't find definition" error.
TEST(MobileTypeParserTest, DictNestedNamedTupleTypeListRaises) {
  const std::string type_str_1(
      "Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]");
  const std::string type_str_2(
      "__torch__.base_models.preproc_types.PreprocOutputType["
      "  NamedTuple, ["
      "      [float_features, Tensor],"
      "      [id_list_features, List[Tensor]],"
      "      [label,  Tensor],"
      "      [weight, Tensor],"
      "      [prod_prediction, Tuple[Tensor, Tensor]],"
      "      [id_score_list_features, List[Tensor]],"
      "      [embedding_features, List[Tensor]],"
      "      [teacher_label, Tensor]"
      "      ]");

  // Non-const: the list overload of c10::parseType takes a mutable reference.
  std::vector<std::string> type_strs = {type_str_1, type_str_2};
  const std::string error_message =
      R"(Can't find definition for the type: __torch__.base_models.preproc_types.PreprocOutputType)";
  ASSERT_THROWS_WITH_MESSAGE(c10::parseType(type_strs), error_message);
}
229 
230 } // namespace jit
231 } // namespace torch
232