1#!/usr/bin/env python3 2# Owner(s): ["oncall: mobile"] 3# mypy: allow-untyped-defs 4 5import io 6import textwrap 7from typing import Dict, List, Optional 8 9import torch 10import torch.utils.bundled_inputs 11from torch.testing._internal.common_utils import run_tests, TestCase 12 13 14def model_size(sm): 15 buffer = io.BytesIO() 16 torch.jit.save(sm, buffer) 17 return len(buffer.getvalue()) 18 19 20def save_and_load(sm): 21 buffer = io.BytesIO() 22 torch.jit.save(sm, buffer) 23 buffer.seek(0) 24 return torch.jit.load(buffer) 25 26 27class TestBundledInputs(TestCase): 28 def test_single_tensors(self): 29 class SingleTensorModel(torch.nn.Module): 30 def forward(self, arg): 31 return arg 32 33 sm = torch.jit.script(SingleTensorModel()) 34 original_size = model_size(sm) 35 get_expr: List[str] = [] 36 samples = [ 37 # Tensor with small numel and small storage. 38 (torch.tensor([1]),), 39 # Tensor with large numel and small storage. 40 (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],), 41 # Tensor with small numel and large storage. 42 (torch.tensor(range(1 << 16))[-8:],), 43 # Large zero tensor. 44 (torch.zeros(1 << 16),), 45 # Large channels-last ones tensor. 46 (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),), 47 # Special encoding of random tensor. 48 (torch.utils.bundled_inputs.bundle_randn(1 << 16),), 49 # Quantized uniform tensor. 50 (torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0, torch.qint8),), 51 ] 52 torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 53 sm, samples, get_expr 54 ) 55 # print(get_expr[0]) 56 # print(sm._generate_bundled_inputs.code) 57 58 # Make sure the model only grew a little bit, 59 # despite having nominally large bundled inputs. 60 augmented_size = model_size(sm) 61 self.assertLess(augmented_size, original_size + (1 << 12)) 62 63 loaded = save_and_load(sm) 64 inflated = loaded.get_all_bundled_inputs() 65 self.assertEqual(loaded.get_num_bundled_inputs(), len(samples)) 66 self.assertEqual(len(inflated), len(samples)) 67 self.assertTrue(loaded(*inflated[0]) is inflated[0][0]) 68 69 for idx, inp in enumerate(inflated): 70 self.assertIsInstance(inp, tuple) 71 self.assertEqual(len(inp), 1) 72 self.assertIsInstance(inp[0], torch.Tensor) 73 if idx != 5: 74 # Strides might be important for benchmarking. 75 self.assertEqual(inp[0].stride(), samples[idx][0].stride()) 76 self.assertEqual(inp[0], samples[idx][0], exact_dtype=True) 77 78 # This tensor is random, but with 100,000 trials, 79 # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105). 80 self.assertEqual(inflated[5][0].shape, (1 << 16,)) 81 self.assertEqual(inflated[5][0].mean().item(), 0, atol=0.025, rtol=0) 82 self.assertEqual(inflated[5][0].std().item(), 1, atol=0.02, rtol=0) 83 84 def test_large_tensor_with_inflation(self): 85 class SingleTensorModel(torch.nn.Module): 86 def forward(self, arg): 87 return arg 88 89 sm = torch.jit.script(SingleTensorModel()) 90 sample_tensor = torch.randn(1 << 16) 91 # We can store tensors with custom inflation functions regardless 92 # of size, even if inflation is just the identity. 93 sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor) 94 torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, [(sample,)]) 95 96 loaded = save_and_load(sm) 97 inflated = loaded.get_all_bundled_inputs() 98 self.assertEqual(len(inflated), 1) 99 100 self.assertEqual(inflated[0][0], sample_tensor) 101 102 def test_rejected_tensors(self): 103 def check_tensor(sample): 104 # Need to define the class in this scope to get a fresh type for each run. 105 class SingleTensorModel(torch.nn.Module): 106 def forward(self, arg): 107 return arg 108 109 sm = torch.jit.script(SingleTensorModel()) 110 with self.assertRaisesRegex(Exception, "Bundled input argument"): 111 torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 112 sm, [(sample,)] 113 ) 114 115 # Plain old big tensor. 116 check_tensor(torch.randn(1 << 16)) 117 # This tensor has two elements, but they're far apart in memory. 118 # We currently cannot represent this compactly while preserving 119 # the strides. 120 small_sparse = torch.randn(2, 1 << 16)[:, 0:1] 121 self.assertEqual(small_sparse.numel(), 2) 122 check_tensor(small_sparse) 123 124 def test_non_tensors(self): 125 class StringAndIntModel(torch.nn.Module): 126 def forward(self, fmt: str, num: int): 127 return fmt.format(num) 128 129 sm = torch.jit.script(StringAndIntModel()) 130 samples = [ 131 ("first {}", 1), 132 ("second {}", 2), 133 ] 134 torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, samples) 135 136 loaded = save_and_load(sm) 137 inflated = loaded.get_all_bundled_inputs() 138 self.assertEqual(inflated, samples) 139 self.assertTrue(loaded(*inflated[0]) == "first 1") 140 141 def test_multiple_methods_with_inputs(self): 142 class MultipleMethodModel(torch.nn.Module): 143 def forward(self, arg): 144 return arg 145 146 @torch.jit.export 147 def foo(self, arg): 148 return arg 149 150 mm = torch.jit.script(MultipleMethodModel()) 151 samples = [ 152 # Tensor with small numel and small storage. 153 (torch.tensor([1]),), 154 # Tensor with large numel and small storage. 155 (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],), 156 # Tensor with small numel and large storage. 157 (torch.tensor(range(1 << 16))[-8:],), 158 # Large zero tensor. 159 (torch.zeros(1 << 16),), 160 # Large channels-last ones tensor. 161 (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),), 162 ] 163 info = [ 164 "Tensor with small numel and small storage.", 165 "Tensor with large numel and small storage.", 166 "Tensor with small numel and large storage.", 167 "Large zero tensor.", 168 "Large channels-last ones tensor.", 169 "Special encoding of random tensor.", 170 ] 171 torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs( 172 mm, 173 inputs={mm.forward: samples, mm.foo: samples}, 174 info={mm.forward: info, mm.foo: info}, 175 ) 176 loaded = save_and_load(mm) 177 inflated = loaded.get_all_bundled_inputs() 178 179 # Make sure these functions are all consistent. 180 self.assertEqual(inflated, samples) 181 self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_forward()) 182 self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo()) 183 184 # Check running and size helpers 185 self.assertTrue(loaded(*inflated[0]) is inflated[0][0]) 186 self.assertEqual(loaded.get_num_bundled_inputs(), len(samples)) 187 188 # Check helper that work on all functions 189 all_info = loaded.get_bundled_inputs_functions_and_info() 190 self.assertEqual(set(all_info.keys()), {"forward", "foo"}) 191 self.assertEqual( 192 all_info["forward"]["get_inputs_function_name"], 193 ["get_all_bundled_inputs_for_forward"], 194 ) 195 self.assertEqual( 196 all_info["foo"]["get_inputs_function_name"], 197 ["get_all_bundled_inputs_for_foo"], 198 ) 199 self.assertEqual(all_info["forward"]["info"], info) 200 self.assertEqual(all_info["foo"]["info"], info) 201 202 # example of how to turn the 'get_inputs_function_name' into the actual list of bundled inputs 203 for func_name in all_info.keys(): 204 input_func_name = all_info[func_name]["get_inputs_function_name"][0] 205 func_to_run = getattr(loaded, input_func_name) 206 self.assertEqual(func_to_run(), samples) 207 208 def test_multiple_methods_with_inputs_both_defined_failure(self): 209 class MultipleMethodModel(torch.nn.Module): 210 def forward(self, arg): 211 return arg 212 213 @torch.jit.export 214 def foo(self, arg): 215 return arg 216 217 samples = [(torch.tensor([1]),)] 218 219 # inputs defined 2 ways so should fail 220 with self.assertRaises(Exception): 221 mm = torch.jit.script(MultipleMethodModel()) 222 definition = textwrap.dedent( 223 """ 224 def _generate_bundled_inputs_for_forward(self): 225 return [] 226 """ 227 ) 228 mm.define(definition) 229 torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs( 230 mm, 231 inputs={ 232 mm.forward: samples, 233 mm.foo: samples, 234 }, 235 ) 236 237 def test_multiple_methods_with_inputs_neither_defined_failure(self): 238 class MultipleMethodModel(torch.nn.Module): 239 def forward(self, arg): 240 return arg 241 242 @torch.jit.export 243 def foo(self, arg): 244 return arg 245 246 samples = [(torch.tensor([1]),)] 247 248 # inputs not defined so should fail 249 with self.assertRaises(Exception): 250 mm = torch.jit.script(MultipleMethodModel()) 251 mm._generate_bundled_inputs_for_forward() 252 torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs( 253 mm, 254 inputs={ 255 mm.forward: None, 256 mm.foo: samples, 257 }, 258 ) 259 260 def test_bad_inputs(self): 261 class SingleTensorModel(torch.nn.Module): 262 def forward(self, arg): 263 return arg 264 265 # Non list for input list 266 with self.assertRaises(TypeError): 267 m = torch.jit.script(SingleTensorModel()) 268 torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 269 m, 270 inputs="foo", # type: ignore[arg-type] 271 ) 272 273 # List of non tuples. Most common error using the api. 274 with self.assertRaises(TypeError): 275 m = torch.jit.script(SingleTensorModel()) 276 torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 277 m, 278 inputs=[torch.ones(1, 2)], # type: ignore[list-item] 279 ) 280 281 def test_double_augment_fail(self): 282 class SingleTensorModel(torch.nn.Module): 283 def forward(self, arg): 284 return arg 285 286 m = torch.jit.script(SingleTensorModel()) 287 torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 288 m, inputs=[(torch.ones(1),)] 289 ) 290 with self.assertRaisesRegex( 291 Exception, "Models can only be augmented with bundled inputs once." 292 ): 293 torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 294 m, inputs=[(torch.ones(1),)] 295 ) 296 297 def test_double_augment_non_mutator(self): 298 class SingleTensorModel(torch.nn.Module): 299 def forward(self, arg): 300 return arg 301 302 m = torch.jit.script(SingleTensorModel()) 303 bundled_model = torch.utils.bundled_inputs.bundle_inputs( 304 m, inputs=[(torch.ones(1),)] 305 ) 306 with self.assertRaises(AttributeError): 307 m.get_all_bundled_inputs() 308 self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)]) 309 self.assertEqual(bundled_model.forward(torch.ones(1)), torch.ones(1)) 310 311 def test_double_augment_success(self): 312 class SingleTensorModel(torch.nn.Module): 313 def forward(self, arg): 314 return arg 315 316 m = torch.jit.script(SingleTensorModel()) 317 bundled_model = torch.utils.bundled_inputs.bundle_inputs( 318 m, inputs={m.forward: [(torch.ones(1),)]} 319 ) 320 self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)]) 321 322 bundled_model2 = torch.utils.bundled_inputs.bundle_inputs( 323 bundled_model, inputs=[(torch.ones(2),)] 324 ) 325 self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)]) 326 327 def test_dict_args(self): 328 class MyModel(torch.nn.Module): 329 def forward( 330 self, 331 arg1: Optional[Dict[str, torch.Tensor]], 332 arg2: Optional[List[torch.Tensor]], 333 arg3: torch.Tensor, 334 ): 335 if arg1 is None: 336 return arg3 337 elif arg2 is None: 338 return arg1["a"] + arg1["b"] 339 else: 340 return arg1["a"] + arg1["b"] + arg2[0] 341 342 small_sample = dict( 343 a=torch.zeros([10, 20]), 344 b=torch.zeros([1, 1]), 345 c=torch.zeros([10, 20]), 346 ) 347 small_list = [torch.zeros([10, 20])] 348 349 big_sample = dict( 350 a=torch.zeros([1 << 5, 1 << 8, 1 << 10]), 351 b=torch.zeros([1 << 5, 1 << 8, 1 << 10]), 352 c=torch.zeros([1 << 5, 1 << 8, 1 << 10]), 353 ) 354 big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])] 355 356 def condensed(t): 357 ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape) 358 assert ret.storage().size() == 1 359 # ret.storage()[0] = 0 360 return ret 361 362 def bundle_optional_dict_of_randn(template): 363 return torch.utils.bundled_inputs.InflatableArg( 364 value=( 365 None 366 if template is None 367 else {k: condensed(v) for (k, v) in template.items()} 368 ), 369 fmt="{}", 370 fmt_fn=""" 371 def {}(self, value: Optional[Dict[str, Tensor]]): 372 if value is None: 373 return None 374 output = {{}} 375 for k, v in value.items(): 376 output[k] = torch.randn_like(v) 377 return output 378 """, 379 ) 380 381 def bundle_optional_list_of_randn(template): 382 return torch.utils.bundled_inputs.InflatableArg( 383 value=(None if template is None else [condensed(v) for v in template]), 384 fmt="{}", 385 fmt_fn=""" 386 def {}(self, value: Optional[List[Tensor]]): 387 if value is None: 388 return None 389 output = [] 390 for v in value: 391 output.append(torch.randn_like(v)) 392 return output 393 """, 394 ) 395 396 out: List[str] = [] 397 sm = torch.jit.script(MyModel()) 398 original_size = model_size(sm) 399 small_inputs = ( 400 bundle_optional_dict_of_randn(small_sample), 401 bundle_optional_list_of_randn(small_list), 402 torch.zeros([3, 4]), 403 ) 404 big_inputs = ( 405 bundle_optional_dict_of_randn(big_sample), 406 bundle_optional_list_of_randn(big_list), 407 torch.zeros([1 << 5, 1 << 8, 1 << 10]), 408 ) 409 410 torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 411 sm, 412 [big_inputs, small_inputs], 413 _receive_inflate_expr=out, 414 ) 415 augmented_size = model_size(sm) 416 # assert the size has not increased more than 8KB 417 self.assertLess(augmented_size, original_size + (1 << 13)) 418 419 loaded = save_and_load(sm) 420 inflated = loaded.get_all_bundled_inputs() 421 self.assertEqual(len(inflated[0]), len(small_inputs)) 422 423 methods, _ = ( 424 torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods( 425 loaded 426 ) 427 ) 428 429 # One Function (forward) 430 # two bundled inputs (big_inputs and small_inputs) 431 # two args which have InflatableArg with fmt_fn 432 # 1 * 2 * 2 = 4 433 self.assertEqual( 434 sum(method.startswith("_inflate_helper") for method in methods), 4 435 ) 436 437 438if __name__ == "__main__": 439 run_tests() 440