1# Owner(s): ["module: unknown"] 2 3import torch 4from torch._C import parse_schema 5from torch.testing._internal.common_utils import run_tests, TestCase 6 7 8class TestFunctionSchema(TestCase): 9 def test_serialize_and_deserialize(self): 10 schemas = torch._C._jit_get_all_schemas() 11 # so far we have around 1700 registered schemas 12 self.assertGreater(len(schemas), 1000) 13 for schema in schemas: 14 parsed_schema = parse_schema(str(schema)) 15 self.assertEqual(parsed_schema, schema) 16 self.assertTrue(parsed_schema.is_backward_compatible_with(schema)) 17 18 def test_out_schema(self): 19 schema_with_out = parse_schema( 20 "any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)" 21 ) 22 self.assertTrue(schema_with_out.arguments[-1].is_out) 23 schema_without_out = parse_schema( 24 "any.not_out(Tensor self, Tensor b) -> Tensor" 25 ) 26 self.assertFalse(schema_without_out.arguments[-1].is_out) 27 28 def test_hash_schema(self): 29 schema1 = parse_schema("any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)") 30 schema2 = parse_schema("any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)") 31 self.assertEqual(hash(schema1), hash(schema2)) 32 33 schema3 = parse_schema( 34 "any.not_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)" 35 ) 36 self.assertNotEqual(hash(schema2), hash(schema3)) 37 38 schema4 = parse_schema( 39 "foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)" 40 ) 41 self.assertNotEqual(hash(schema2), hash(schema4)) 42 43 # schemas with different default value, or different kw-only arg, should have different hash 44 default_val_schema0 = parse_schema("foo(Tensor self, int a = 2) -> Tensor(a!)") 45 default_val_schema1 = parse_schema("foo(Tensor self, int a = 3) -> Tensor(a!)") 46 default_val_schema2 = parse_schema( 47 "foo(Tensor self, *, int a = 2) -> Tensor(a!)" 48 ) 49 self.assertNotEqual(hash(default_val_schema0), hash(default_val_schema1)) 50 self.assertNotEqual(hash(default_val_schema0), hash(default_val_schema2)) 51 52 # schema with different alias annotation should have different hash 53 alias_schema = parse_schema("foo(Tensor(a!) self, int a = 2) -> Tensor(a!)") 54 self.assertNotEqual(hash(default_val_schema0), hash(alias_schema)) 55 alias_schema2 = parse_schema("foo(Tensor(b!) self, int a = 2) -> Tensor(a!)") 56 self.assertNotEqual(hash(alias_schema), hash(alias_schema2)) 57 58 # schema with different alias infos 59 alias_schema3 = parse_schema( 60 "foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)" 61 ) 62 alias_schema4 = parse_schema( 63 "foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(b!)" 64 ) 65 alias_schema5 = parse_schema( 66 "foo(Tensor self, *, int a, int b=1, Tensor(b!) out, Tensor(a!) b) -> Tensor(a!)" 67 ) 68 self.assertNotEqual(hash(alias_schema3), hash(alias_schema4)) 69 self.assertNotEqual(hash(alias_schema3), hash(alias_schema5)) 70 71 def test_backward_compatible_structure(self): 72 old_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor") 73 # BC: A new schema without changes. 74 new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor") 75 self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) 76 # No-BC: A new schema with different name. 77 new_schema = parse_schema("any_.over(Tensor self, *, Tensor b) -> Tensor") 78 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 79 # No-BC: A new schema with different overload name. 80 new_schema = parse_schema("any.other(Tensor self, *, Tensor b) -> Tensor") 81 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 82 # No-BC: A new schema that adds vararg. 83 new_schema = parse_schema("any.over(Tensor self, *, Tensor b, ...) -> Tensor") 84 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 85 # No-BC: A new schema with different number of outputs. 86 new_schema = parse_schema( 87 "any.over(Tensor self, *, Tensor b) -> (Tensor, Tensor)" 88 ) 89 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 90 91 def test_backward_compatible_outputs(self): 92 old_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor") 93 # No-BC: A new schema with output becoming of optional type. 94 new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor?") 95 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 96 # BC: (the opposite case) An schema where the output is not of optional type anymore. 97 self.assertTrue(old_schema.is_backward_compatible_with(new_schema)) 98 # No-BC: A new schema with a different output type. 99 new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> int") 100 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 101 # No-BC: A new schema with a different output type. 102 new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor out") 103 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 104 105 def test_backward_compatible_arguments(self): 106 old_schema = parse_schema("any(Tensor self, *, Tensor b, int c) -> Tensor") 107 # No-BC: A new schema with less arguments. 108 new_schema = parse_schema("any(Tensor self, *, Tensor b) -> Tensor") 109 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 110 # No-BC: A new schema with more arguments, appended, but no default value. 111 new_schema = parse_schema( 112 "any(Tensor self, *, Tensor b, int c, int d) -> Tensor" 113 ) 114 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 115 # BC: A new schema with more arguments, appended, that have a default value. 116 new_schema = parse_schema( 117 "any(Tensor self, *, Tensor b, int c, int d=1) -> Tensor" 118 ) 119 self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) 120 # No-BC: A new schema with more arguments, not-appended, that have a default value. 121 new_schema = parse_schema( 122 "any(Tensor self, int d=1, *, Tensor b, int c) -> Tensor" 123 ) 124 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 125 # BC: A new schema where old kwargs becomes positional. 126 new_schema = parse_schema("any(Tensor self, Tensor b, *, int c) -> Tensor") 127 self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) 128 # BC: (the opposite case) A new schema where an old positional argument becomes kwarg. 129 self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) 130 # BC: A new schema where all old kwargs become positional. 131 new_schema = parse_schema("any(Tensor self, Tensor b, int c) -> Tensor") 132 self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) 133 # BC: (the opposite case) A new schema where all old positional arguments become kwarg. 134 self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) 135 # No-BC: A new schema where old kwargs appear in different order. 136 new_schema = parse_schema("any(Tensor self, *, int c, Tensor b) -> Tensor") 137 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 138 # BC: A new schema where argument becomes of type optional. 139 new_schema = parse_schema("any(Tensor self, *, Tensor b, int? c) -> Tensor") 140 self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) 141 # BC: A new schema where argument gains a default value. 142 new_schema = parse_schema("any(Tensor self, *, Tensor b, int c=1) -> Tensor") 143 self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) 144 # No-BC: A new schema where argument is "renamed". 145 new_schema = parse_schema( 146 "any(Tensor self, *, Tensor b, int renamed) -> Tensor" 147 ) 148 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 149 # No-BC: A new schema where argument type changes to an incompatible type. 150 new_schema = parse_schema("any(Tensor self, *, Tensor b, int[] c) -> Tensor") 151 self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) 152 153 def test_backward_compatible_with_smart_serialization(self): 154 # cases where out arg is provided 155 old_schema = parse_schema( 156 "foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)" 157 ) 158 new_schema_same_out = parse_schema( 159 "foo(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)" 160 ) 161 new_schema_wrong_default = parse_schema( 162 "foo(Tensor self, *, int b=1, int a, Tensor(a!) out) -> Tensor(a!)" 163 ) 164 new_schema_more_out = parse_schema( 165 "foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)" 166 ) 167 new_schema_wrong_pos = parse_schema( 168 "foo(Tensor self, *, int a, int b=1, Tensor(b!) b, Tensor(a!) out) -> Tensor(a!)" 169 ) 170 self.assertTrue(new_schema_same_out.is_backward_compatible_with(old_schema)) 171 self.assertTrue(new_schema_more_out.is_backward_compatible_with(old_schema)) 172 self.assertFalse( 173 new_schema_wrong_default.is_backward_compatible_with(old_schema) 174 ) 175 self.assertFalse(new_schema_wrong_pos.is_backward_compatible_with(old_schema)) 176 177 # cases where out arg is not provided 178 old_schema_without_arg = parse_schema("foo(Tensor self, int a, int b=1) -> int") 179 new_schema_without_arg = parse_schema( 180 "foo(Tensor self, int a, int b=1, int c=2) -> int" 181 ) 182 new_schema_without_arg_multiple_default = parse_schema( 183 "foo(Tensor self, int a, int b=1, int c=2, int d=3) -> int" 184 ) 185 new_schema_without_arg_wrong_pos = parse_schema( 186 "foo(Tensor self, int a, int c=2, int b=1) -> int" 187 ) 188 self.assertTrue( 189 new_schema_without_arg.is_backward_compatible_with(old_schema_without_arg) 190 ) 191 self.assertTrue( 192 new_schema_without_arg_multiple_default.is_backward_compatible_with( 193 old_schema_without_arg 194 ) 195 ) 196 self.assertFalse( 197 new_schema_without_arg_wrong_pos.is_backward_compatible_with( 198 old_schema_without_arg 199 ) 200 ) 201 202 def test_string_optional_parameter_default_value(self): 203 schema_a = parse_schema('example::op(str? order="NCHW") -> (Tensor)') 204 schema_b = parse_schema(str(schema_a)) 205 self.assertEqual(schema_a, schema_b) 206 207 def test_forward_compatible_arguments_without_out(self): 208 old_schema = parse_schema("any(Tensor self, int a, int b=1) -> Tensor") 209 # deleting default arg is FC compatible 210 new_schema = parse_schema("any(Tensor self, int a) -> Tensor") 211 is_fc, _ = new_schema.check_forward_compatible_with(old_schema) 212 self.assertTrue(is_fc) 213 # adding default arg is FC compatible 214 new_schema = parse_schema("any(Tensor self, int a, int b=1, int c=1) -> Tensor") 215 is_fc, _ = new_schema.check_forward_compatible_with(old_schema) 216 self.assertTrue(is_fc) 217 # adding default arg with container type is NOT FC compatible 218 new_schema = parse_schema( 219 "any(Tensor self, int a, int b=1, int[2] c=1) -> Tensor" 220 ) 221 is_fc, reason = new_schema.check_forward_compatible_with(old_schema) 222 self.assertFalse(is_fc) 223 self.assertEqual( 224 reason, 225 "Function schema is not forward compatible since the new argument" 226 " 'c' of type int[] has a container type as its default value.", 227 ) 228 # updating the default value of a default arg is NOT FC compatible 229 new_schema = parse_schema("any(Tensor self, int a, int b=4) -> Tensor") 230 is_fc, reason = new_schema.check_forward_compatible_with(old_schema) 231 self.assertFalse(is_fc) 232 self.assertEqual( 233 reason, "'b' is not forward compatible with the older version of the schema" 234 ) 235 # updating the arg name of a default arg is NOT FC compatible 236 new_schema = parse_schema("any(Tensor self, int a, int c=1) -> Tensor") 237 is_fc, reason = new_schema.check_forward_compatible_with(old_schema) 238 self.assertFalse(is_fc) 239 self.assertEqual( 240 reason, "'c' is not forward compatible with the older version of the schema" 241 ) 242 # not adding default arg in the end is NOT FC compatible 243 new_schema = parse_schema("any(Tensor self, int a, int c=1, int b=1) -> Tensor") 244 is_fc, reason = new_schema.check_forward_compatible_with(old_schema) 245 self.assertFalse(is_fc) 246 self.assertEqual( 247 reason, "'c' is not forward compatible with the older version of the schema" 248 ) 249 # making default arg into positional arg is NOT FC compatible 250 new_schema = parse_schema("any(Tensor self, int a, int b) -> Tensor") 251 is_fc, reason = new_schema.check_forward_compatible_with(old_schema) 252 self.assertFalse(is_fc) 253 self.assertEqual( 254 reason, "'b' is not forward compatible with the older version of the schema" 255 ) 256 # making positional arg into default arg is NOT FC compatible 257 new_schema = parse_schema("any(Tensor self, int a=1, int b=1) -> Tensor") 258 is_fc, reason = new_schema.check_forward_compatible_with(old_schema) 259 self.assertFalse(is_fc) 260 self.assertEqual( 261 reason, "'a' is not forward compatible with the older version of the schema" 262 ) 263 264 def test_forward_compatible_arguments_real_use_case(self): 265 # this change introduced forward incompatibility in the past 266 old_slice_schema = parse_schema( 267 "slice(Tensor(a) self, int dim=0, int start=0, int end=0, int step=1) -> Tensor(a)" 268 ) 269 new_slice_schema = parse_schema( 270 "slice(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)" 271 ) 272 is_fc, reason = new_slice_schema.check_forward_compatible_with(old_slice_schema) 273 self.assertFalse(is_fc) 274 self.assertEqual( 275 reason, 276 "'start' is not forward compatible with the older version of the schema", 277 ) 278 279 def test_forward_compatible_arguments_with_out(self): 280 old_schema = parse_schema( 281 "any(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)" 282 ) 283 new_schema = parse_schema( 284 "any(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)" 285 ) 286 is_fc, _ = new_schema.check_forward_compatible_with(old_schema) 287 self.assertTrue(is_fc) 288 new_schema = parse_schema( 289 "any(Tensor self, *, int a, int b=1, int c=1, Tensor(a!) out) -> Tensor(a!)" 290 ) 291 is_fc, _ = new_schema.check_forward_compatible_with(old_schema) 292 self.assertTrue(is_fc) 293 new_schema = parse_schema( 294 "any(Tensor self, *, int a, Tensor(d!) d, int b=1, Tensor(a!) out) -> Tensor(a!)" 295 ) 296 is_fc, reason = new_schema.check_forward_compatible_with(old_schema) 297 self.assertFalse(is_fc) 298 self.assertEqual( 299 reason, "Function schema should have the same number of out arguments" 300 ) 301 302 def test_schema_error(self): 303 with self.assertRaisesRegex( 304 RuntimeError, r"schemas with vararg \(...\) can't have default value args" 305 ): 306 schema = parse_schema("any.foo(int arg1, int arg2=0, ...)") 307 308 def test_tensor_list_alias_annotation_properly_parsed(self): 309 schema_str = "foo(Tensor self, *, Tensor(a!)[] out) -> ()" 310 schema = parse_schema(schema_str) 311 self.assertTrue(schema.arguments[-1].alias_info.is_write) 312 self.assertEqual(str(schema), schema_str) 313 314 def test_tensor_option_arguments_properly_parsed(self): 315 schema_str = ( 316 "_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, " 317 "bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor" 318 ) 319 schema = parse_schema(schema_str) 320 # fake type of MemoryFormat? is int? 321 self.assertEqual(schema.arguments[-1].type.str(), "int?") 322 # fake type of Layout? is int? 323 self.assertEqual(schema.arguments[2].type.str(), "int?") 324 # fake type of Device? is Device? 325 self.assertEqual(schema.arguments[3].type.str(), "Device?") 326 # print real types in FunctionSchema 327 self.assertEqual(str(schema), schema_str) 328 329 def test_sym_int_argument_properly_parsed(self): 330 schema_str = "sym_size.int(Tensor self, int dim) -> SymInt" 331 schema = parse_schema(schema_str) 332 # fake type of SymInt is int 333 self.assertEqual(schema.returns[-1].type.str(), "int") 334 # real type of SymInt is SymInt 335 self.assertEqual(schema.returns[-1].real_type.str(), "SymInt") 336 # print real types in FunctionSchema 337 self.assertEqual(str(schema), schema_str) 338 339 340if __name__ == "__main__": 341 run_tests() 342