xref: /aosp_15_r20/external/pytorch/test/test_function_schema.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: unknown"]
2
3import torch
4from torch._C import parse_schema
5from torch.testing._internal.common_utils import run_tests, TestCase
6
7
8class TestFunctionSchema(TestCase):
9    def test_serialize_and_deserialize(self):
10        schemas = torch._C._jit_get_all_schemas()
11        # so far we have around 1700 registered schemas
12        self.assertGreater(len(schemas), 1000)
13        for schema in schemas:
14            parsed_schema = parse_schema(str(schema))
15            self.assertEqual(parsed_schema, schema)
16            self.assertTrue(parsed_schema.is_backward_compatible_with(schema))
17
18    def test_out_schema(self):
19        schema_with_out = parse_schema(
20            "any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
21        )
22        self.assertTrue(schema_with_out.arguments[-1].is_out)
23        schema_without_out = parse_schema(
24            "any.not_out(Tensor self, Tensor b) -> Tensor"
25        )
26        self.assertFalse(schema_without_out.arguments[-1].is_out)
27
28    def test_hash_schema(self):
29        schema1 = parse_schema("any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
30        schema2 = parse_schema("any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
31        self.assertEqual(hash(schema1), hash(schema2))
32
33        schema3 = parse_schema(
34            "any.not_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
35        )
36        self.assertNotEqual(hash(schema2), hash(schema3))
37
38        schema4 = parse_schema(
39            "foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)"
40        )
41        self.assertNotEqual(hash(schema2), hash(schema4))
42
43        # schemas with different default value, or different kw-only arg, should have different hash
44        default_val_schema0 = parse_schema("foo(Tensor self, int a = 2) -> Tensor(a!)")
45        default_val_schema1 = parse_schema("foo(Tensor self, int a = 3) -> Tensor(a!)")
46        default_val_schema2 = parse_schema(
47            "foo(Tensor self, *, int a = 2) -> Tensor(a!)"
48        )
49        self.assertNotEqual(hash(default_val_schema0), hash(default_val_schema1))
50        self.assertNotEqual(hash(default_val_schema0), hash(default_val_schema2))
51
52        # schema with different alias annotation should have different hash
53        alias_schema = parse_schema("foo(Tensor(a!) self, int a = 2) -> Tensor(a!)")
54        self.assertNotEqual(hash(default_val_schema0), hash(alias_schema))
55        alias_schema2 = parse_schema("foo(Tensor(b!) self, int a = 2) -> Tensor(a!)")
56        self.assertNotEqual(hash(alias_schema), hash(alias_schema2))
57
58        # schema with different alias infos
59        alias_schema3 = parse_schema(
60            "foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)"
61        )
62        alias_schema4 = parse_schema(
63            "foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(b!)"
64        )
65        alias_schema5 = parse_schema(
66            "foo(Tensor self, *, int a, int b=1, Tensor(b!) out, Tensor(a!) b) -> Tensor(a!)"
67        )
68        self.assertNotEqual(hash(alias_schema3), hash(alias_schema4))
69        self.assertNotEqual(hash(alias_schema3), hash(alias_schema5))
70
71    def test_backward_compatible_structure(self):
72        old_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor")
73        # BC: A new schema without changes.
74        new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor")
75        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
76        # No-BC: A new schema with different name.
77        new_schema = parse_schema("any_.over(Tensor self, *, Tensor b) -> Tensor")
78        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
79        # No-BC: A new schema with different overload name.
80        new_schema = parse_schema("any.other(Tensor self, *, Tensor b) -> Tensor")
81        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
82        # No-BC: A new schema that adds vararg.
83        new_schema = parse_schema("any.over(Tensor self, *, Tensor b, ...) -> Tensor")
84        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
85        # No-BC: A new schema with different number of outputs.
86        new_schema = parse_schema(
87            "any.over(Tensor self, *, Tensor b) -> (Tensor, Tensor)"
88        )
89        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
90
91    def test_backward_compatible_outputs(self):
92        old_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor")
93        # No-BC: A new schema with output becoming of optional type.
94        new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor?")
95        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
96        # BC: (the opposite case) An schema where the output is not of optional type anymore.
97        self.assertTrue(old_schema.is_backward_compatible_with(new_schema))
98        # No-BC: A new schema with a different output type.
99        new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> int")
100        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
101        # No-BC: A new schema with a different output type.
102        new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor out")
103        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
104
105    def test_backward_compatible_arguments(self):
106        old_schema = parse_schema("any(Tensor self, *, Tensor b, int c) -> Tensor")
107        # No-BC: A new schema with less arguments.
108        new_schema = parse_schema("any(Tensor self, *, Tensor b) -> Tensor")
109        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
110        # No-BC: A new schema with more arguments, appended, but no default value.
111        new_schema = parse_schema(
112            "any(Tensor self, *, Tensor b, int c, int d) -> Tensor"
113        )
114        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
115        # BC: A new schema with more arguments, appended, that have a default value.
116        new_schema = parse_schema(
117            "any(Tensor self, *, Tensor b, int c, int d=1) -> Tensor"
118        )
119        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
120        # No-BC: A new schema with more arguments, not-appended, that have a default value.
121        new_schema = parse_schema(
122            "any(Tensor self, int d=1, *, Tensor b, int c) -> Tensor"
123        )
124        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
125        # BC: A new schema where old kwargs becomes positional.
126        new_schema = parse_schema("any(Tensor self, Tensor b, *, int c) -> Tensor")
127        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
128        # BC: (the opposite case) A new schema where an old positional argument becomes kwarg.
129        self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
130        # BC: A new schema where all old kwargs become positional.
131        new_schema = parse_schema("any(Tensor self, Tensor b, int c) -> Tensor")
132        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
133        # BC: (the opposite case) A new schema where all old positional arguments become kwarg.
134        self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
135        # No-BC: A new schema where old kwargs appear in different order.
136        new_schema = parse_schema("any(Tensor self, *, int c, Tensor b) -> Tensor")
137        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
138        # BC: A new schema where argument becomes of type optional.
139        new_schema = parse_schema("any(Tensor self, *, Tensor b, int? c) -> Tensor")
140        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
141        # BC: A new schema where argument gains a default value.
142        new_schema = parse_schema("any(Tensor self, *, Tensor b, int c=1) -> Tensor")
143        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
144        # No-BC: A new schema where argument is "renamed".
145        new_schema = parse_schema(
146            "any(Tensor self, *, Tensor b, int renamed) -> Tensor"
147        )
148        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
149        # No-BC: A new schema where argument type changes to an incompatible type.
150        new_schema = parse_schema("any(Tensor self, *, Tensor b, int[] c) -> Tensor")
151        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
152
153    def test_backward_compatible_with_smart_serialization(self):
154        # cases where out arg is provided
155        old_schema = parse_schema(
156            "foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)"
157        )
158        new_schema_same_out = parse_schema(
159            "foo(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)"
160        )
161        new_schema_wrong_default = parse_schema(
162            "foo(Tensor self, *, int b=1, int a, Tensor(a!) out) -> Tensor(a!)"
163        )
164        new_schema_more_out = parse_schema(
165            "foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)"
166        )
167        new_schema_wrong_pos = parse_schema(
168            "foo(Tensor self, *, int a, int b=1, Tensor(b!) b, Tensor(a!) out) -> Tensor(a!)"
169        )
170        self.assertTrue(new_schema_same_out.is_backward_compatible_with(old_schema))
171        self.assertTrue(new_schema_more_out.is_backward_compatible_with(old_schema))
172        self.assertFalse(
173            new_schema_wrong_default.is_backward_compatible_with(old_schema)
174        )
175        self.assertFalse(new_schema_wrong_pos.is_backward_compatible_with(old_schema))
176
177        # cases where out arg is not provided
178        old_schema_without_arg = parse_schema("foo(Tensor self, int a, int b=1) -> int")
179        new_schema_without_arg = parse_schema(
180            "foo(Tensor self, int a, int b=1, int c=2) -> int"
181        )
182        new_schema_without_arg_multiple_default = parse_schema(
183            "foo(Tensor self, int a, int b=1, int c=2, int d=3) -> int"
184        )
185        new_schema_without_arg_wrong_pos = parse_schema(
186            "foo(Tensor self, int a, int c=2, int b=1) -> int"
187        )
188        self.assertTrue(
189            new_schema_without_arg.is_backward_compatible_with(old_schema_without_arg)
190        )
191        self.assertTrue(
192            new_schema_without_arg_multiple_default.is_backward_compatible_with(
193                old_schema_without_arg
194            )
195        )
196        self.assertFalse(
197            new_schema_without_arg_wrong_pos.is_backward_compatible_with(
198                old_schema_without_arg
199            )
200        )
201
202    def test_string_optional_parameter_default_value(self):
203        schema_a = parse_schema('example::op(str? order="NCHW") -> (Tensor)')
204        schema_b = parse_schema(str(schema_a))
205        self.assertEqual(schema_a, schema_b)
206
207    def test_forward_compatible_arguments_without_out(self):
208        old_schema = parse_schema("any(Tensor self, int a, int b=1) -> Tensor")
209        # deleting default arg is FC compatible
210        new_schema = parse_schema("any(Tensor self, int a) -> Tensor")
211        is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
212        self.assertTrue(is_fc)
213        # adding default arg is FC compatible
214        new_schema = parse_schema("any(Tensor self, int a, int b=1, int c=1) -> Tensor")
215        is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
216        self.assertTrue(is_fc)
217        # adding default arg with container type is NOT FC compatible
218        new_schema = parse_schema(
219            "any(Tensor self, int a, int b=1, int[2] c=1) -> Tensor"
220        )
221        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
222        self.assertFalse(is_fc)
223        self.assertEqual(
224            reason,
225            "Function schema is not forward compatible since the new argument"
226            " 'c' of type int[] has a container type as its default value.",
227        )
228        # updating the default value of a default arg is NOT FC compatible
229        new_schema = parse_schema("any(Tensor self, int a, int b=4) -> Tensor")
230        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
231        self.assertFalse(is_fc)
232        self.assertEqual(
233            reason, "'b' is not forward compatible with the older version of the schema"
234        )
235        # updating the arg name of a default arg is NOT FC compatible
236        new_schema = parse_schema("any(Tensor self, int a, int c=1) -> Tensor")
237        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
238        self.assertFalse(is_fc)
239        self.assertEqual(
240            reason, "'c' is not forward compatible with the older version of the schema"
241        )
242        # not adding default arg in the end is NOT FC compatible
243        new_schema = parse_schema("any(Tensor self, int a, int c=1, int b=1) -> Tensor")
244        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
245        self.assertFalse(is_fc)
246        self.assertEqual(
247            reason, "'c' is not forward compatible with the older version of the schema"
248        )
249        # making default arg into positional arg is NOT FC compatible
250        new_schema = parse_schema("any(Tensor self, int a, int b) -> Tensor")
251        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
252        self.assertFalse(is_fc)
253        self.assertEqual(
254            reason, "'b' is not forward compatible with the older version of the schema"
255        )
256        # making positional arg into default arg is NOT FC compatible
257        new_schema = parse_schema("any(Tensor self, int a=1, int b=1) -> Tensor")
258        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
259        self.assertFalse(is_fc)
260        self.assertEqual(
261            reason, "'a' is not forward compatible with the older version of the schema"
262        )
263
264    def test_forward_compatible_arguments_real_use_case(self):
265        # this change introduced forward incompatibility in the past
266        old_slice_schema = parse_schema(
267            "slice(Tensor(a) self, int dim=0, int start=0, int end=0, int step=1) -> Tensor(a)"
268        )
269        new_slice_schema = parse_schema(
270            "slice(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)"
271        )
272        is_fc, reason = new_slice_schema.check_forward_compatible_with(old_slice_schema)
273        self.assertFalse(is_fc)
274        self.assertEqual(
275            reason,
276            "'start' is not forward compatible with the older version of the schema",
277        )
278
279    def test_forward_compatible_arguments_with_out(self):
280        old_schema = parse_schema(
281            "any(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)"
282        )
283        new_schema = parse_schema(
284            "any(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)"
285        )
286        is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
287        self.assertTrue(is_fc)
288        new_schema = parse_schema(
289            "any(Tensor self, *, int a, int b=1, int c=1, Tensor(a!) out) -> Tensor(a!)"
290        )
291        is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
292        self.assertTrue(is_fc)
293        new_schema = parse_schema(
294            "any(Tensor self, *, int a, Tensor(d!) d, int b=1, Tensor(a!) out) -> Tensor(a!)"
295        )
296        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
297        self.assertFalse(is_fc)
298        self.assertEqual(
299            reason, "Function schema should have the same number of out arguments"
300        )
301
302    def test_schema_error(self):
303        with self.assertRaisesRegex(
304            RuntimeError, r"schemas with vararg \(...\) can't have default value args"
305        ):
306            schema = parse_schema("any.foo(int arg1, int arg2=0, ...)")
307
308    def test_tensor_list_alias_annotation_properly_parsed(self):
309        schema_str = "foo(Tensor self, *, Tensor(a!)[] out) -> ()"
310        schema = parse_schema(schema_str)
311        self.assertTrue(schema.arguments[-1].alias_info.is_write)
312        self.assertEqual(str(schema), schema_str)
313
314    def test_tensor_option_arguments_properly_parsed(self):
315        schema_str = (
316            "_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, "
317            "bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor"
318        )
319        schema = parse_schema(schema_str)
320        # fake type of MemoryFormat? is int?
321        self.assertEqual(schema.arguments[-1].type.str(), "int?")
322        # fake type of Layout? is int?
323        self.assertEqual(schema.arguments[2].type.str(), "int?")
324        # fake type of Device? is Device?
325        self.assertEqual(schema.arguments[3].type.str(), "Device?")
326        # print real types in FunctionSchema
327        self.assertEqual(str(schema), schema_str)
328
329    def test_sym_int_argument_properly_parsed(self):
330        schema_str = "sym_size.int(Tensor self, int dim) -> SymInt"
331        schema = parse_schema(schema_str)
332        # fake type of SymInt is int
333        self.assertEqual(schema.returns[-1].type.str(), "int")
334        # real type of SymInt is SymInt
335        self.assertEqual(schema.returns[-1].real_type.str(), "SymInt")
336        # print real types in FunctionSchema
337        self.assertEqual(str(schema), schema_str)
338
339
# When executed as a script, hand control to run_tests from
# torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
342