1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"] 3*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport io 6*da0073e9SAndroid Build Coastguard Workerimport textwrap 7*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, List, Optional 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerimport torch.utils.bundled_inputs 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerdef model_size(sm): 15*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 16*da0073e9SAndroid Build Coastguard Worker torch.jit.save(sm, buffer) 17*da0073e9SAndroid Build Coastguard Worker return len(buffer.getvalue()) 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Workerdef save_and_load(sm): 21*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 22*da0073e9SAndroid Build Coastguard Worker torch.jit.save(sm, buffer) 23*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 24*da0073e9SAndroid Build Coastguard Worker return torch.jit.load(buffer) 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Workerclass TestBundledInputs(TestCase): 28*da0073e9SAndroid Build Coastguard Worker def test_single_tensors(self): 29*da0073e9SAndroid Build Coastguard Worker class SingleTensorModel(torch.nn.Module): 30*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 31*da0073e9SAndroid Build Coastguard Worker return arg 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(SingleTensorModel()) 34*da0073e9SAndroid Build Coastguard Worker original_size = model_size(sm) 35*da0073e9SAndroid Build Coastguard Worker get_expr: List[str] = [] 36*da0073e9SAndroid Build Coastguard Worker samples = [ 37*da0073e9SAndroid Build Coastguard Worker # Tensor with small numel and small storage. 38*da0073e9SAndroid Build Coastguard Worker (torch.tensor([1]),), 39*da0073e9SAndroid Build Coastguard Worker # Tensor with large numel and small storage. 40*da0073e9SAndroid Build Coastguard Worker (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],), 41*da0073e9SAndroid Build Coastguard Worker # Tensor with small numel and large storage. 42*da0073e9SAndroid Build Coastguard Worker (torch.tensor(range(1 << 16))[-8:],), 43*da0073e9SAndroid Build Coastguard Worker # Large zero tensor. 44*da0073e9SAndroid Build Coastguard Worker (torch.zeros(1 << 16),), 45*da0073e9SAndroid Build Coastguard Worker # Large channels-last ones tensor. 46*da0073e9SAndroid Build Coastguard Worker (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),), 47*da0073e9SAndroid Build Coastguard Worker # Special encoding of random tensor. 48*da0073e9SAndroid Build Coastguard Worker (torch.utils.bundled_inputs.bundle_randn(1 << 16),), 49*da0073e9SAndroid Build Coastguard Worker # Quantized uniform tensor. 50*da0073e9SAndroid Build Coastguard Worker (torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0, torch.qint8),), 51*da0073e9SAndroid Build Coastguard Worker ] 52*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 53*da0073e9SAndroid Build Coastguard Worker sm, samples, get_expr 54*da0073e9SAndroid Build Coastguard Worker ) 55*da0073e9SAndroid Build Coastguard Worker # print(get_expr[0]) 56*da0073e9SAndroid Build Coastguard Worker # print(sm._generate_bundled_inputs.code) 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker # Make sure the model only grew a little bit, 59*da0073e9SAndroid Build Coastguard Worker # despite having nominally large bundled inputs. 60*da0073e9SAndroid Build Coastguard Worker augmented_size = model_size(sm) 61*da0073e9SAndroid Build Coastguard Worker self.assertLess(augmented_size, original_size + (1 << 12)) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker loaded = save_and_load(sm) 64*da0073e9SAndroid Build Coastguard Worker inflated = loaded.get_all_bundled_inputs() 65*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded.get_num_bundled_inputs(), len(samples)) 66*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(inflated), len(samples)) 67*da0073e9SAndroid Build Coastguard Worker self.assertTrue(loaded(*inflated[0]) is inflated[0][0]) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker for idx, inp in enumerate(inflated): 70*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(inp, tuple) 71*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(inp), 1) 72*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(inp[0], torch.Tensor) 73*da0073e9SAndroid Build Coastguard Worker if idx != 5: 74*da0073e9SAndroid Build Coastguard Worker # Strides might be important for benchmarking. 75*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inp[0].stride(), samples[idx][0].stride()) 76*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inp[0], samples[idx][0], exact_dtype=True) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker # This tensor is random, but with 100,000 trials, 79*da0073e9SAndroid Build Coastguard Worker # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105). 80*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inflated[5][0].shape, (1 << 16,)) 81*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inflated[5][0].mean().item(), 0, atol=0.025, rtol=0) 82*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inflated[5][0].std().item(), 1, atol=0.02, rtol=0) 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker def test_large_tensor_with_inflation(self): 85*da0073e9SAndroid Build Coastguard Worker class SingleTensorModel(torch.nn.Module): 86*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 87*da0073e9SAndroid Build Coastguard Worker return arg 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(SingleTensorModel()) 90*da0073e9SAndroid Build Coastguard Worker sample_tensor = torch.randn(1 << 16) 91*da0073e9SAndroid Build Coastguard Worker # We can store tensors with custom inflation functions regardless 92*da0073e9SAndroid Build Coastguard Worker # of size, even if inflation is just the identity. 93*da0073e9SAndroid Build Coastguard Worker sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor) 94*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, [(sample,)]) 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker loaded = save_and_load(sm) 97*da0073e9SAndroid Build Coastguard Worker inflated = loaded.get_all_bundled_inputs() 98*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(inflated), 1) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inflated[0][0], sample_tensor) 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker def test_rejected_tensors(self): 103*da0073e9SAndroid Build Coastguard Worker def check_tensor(sample): 104*da0073e9SAndroid Build Coastguard Worker # Need to define the class in this scope to get a fresh type for each run. 105*da0073e9SAndroid Build Coastguard Worker class SingleTensorModel(torch.nn.Module): 106*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 107*da0073e9SAndroid Build Coastguard Worker return arg 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(SingleTensorModel()) 110*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "Bundled input argument"): 111*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 112*da0073e9SAndroid Build Coastguard Worker sm, [(sample,)] 113*da0073e9SAndroid Build Coastguard Worker ) 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker # Plain old big tensor. 116*da0073e9SAndroid Build Coastguard Worker check_tensor(torch.randn(1 << 16)) 117*da0073e9SAndroid Build Coastguard Worker # This tensor has two elements, but they're far apart in memory. 118*da0073e9SAndroid Build Coastguard Worker # We currently cannot represent this compactly while preserving 119*da0073e9SAndroid Build Coastguard Worker # the strides. 120*da0073e9SAndroid Build Coastguard Worker small_sparse = torch.randn(2, 1 << 16)[:, 0:1] 121*da0073e9SAndroid Build Coastguard Worker self.assertEqual(small_sparse.numel(), 2) 122*da0073e9SAndroid Build Coastguard Worker check_tensor(small_sparse) 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker def test_non_tensors(self): 125*da0073e9SAndroid Build Coastguard Worker class StringAndIntModel(torch.nn.Module): 126*da0073e9SAndroid Build Coastguard Worker def forward(self, fmt: str, num: int): 127*da0073e9SAndroid Build Coastguard Worker return fmt.format(num) 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(StringAndIntModel()) 130*da0073e9SAndroid Build Coastguard Worker samples = [ 131*da0073e9SAndroid Build Coastguard Worker ("first {}", 1), 132*da0073e9SAndroid Build Coastguard Worker ("second {}", 2), 133*da0073e9SAndroid Build Coastguard Worker ] 134*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, samples) 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker loaded = save_and_load(sm) 137*da0073e9SAndroid Build Coastguard Worker inflated = loaded.get_all_bundled_inputs() 138*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inflated, samples) 139*da0073e9SAndroid Build Coastguard Worker self.assertTrue(loaded(*inflated[0]) == "first 1") 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Worker def test_multiple_methods_with_inputs(self): 142*da0073e9SAndroid Build Coastguard Worker class MultipleMethodModel(torch.nn.Module): 143*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 144*da0073e9SAndroid Build Coastguard Worker return arg 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 147*da0073e9SAndroid Build Coastguard Worker def foo(self, arg): 148*da0073e9SAndroid Build Coastguard Worker return arg 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker mm = torch.jit.script(MultipleMethodModel()) 151*da0073e9SAndroid Build Coastguard Worker samples = [ 152*da0073e9SAndroid Build Coastguard Worker # Tensor with small numel and small storage. 153*da0073e9SAndroid Build Coastguard Worker (torch.tensor([1]),), 154*da0073e9SAndroid Build Coastguard Worker # Tensor with large numel and small storage. 155*da0073e9SAndroid Build Coastguard Worker (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],), 156*da0073e9SAndroid Build Coastguard Worker # Tensor with small numel and large storage. 157*da0073e9SAndroid Build Coastguard Worker (torch.tensor(range(1 << 16))[-8:],), 158*da0073e9SAndroid Build Coastguard Worker # Large zero tensor. 159*da0073e9SAndroid Build Coastguard Worker (torch.zeros(1 << 16),), 160*da0073e9SAndroid Build Coastguard Worker # Large channels-last ones tensor. 161*da0073e9SAndroid Build Coastguard Worker (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),), 162*da0073e9SAndroid Build Coastguard Worker ] 163*da0073e9SAndroid Build Coastguard Worker info = [ 164*da0073e9SAndroid Build Coastguard Worker "Tensor with small numel and small storage.", 165*da0073e9SAndroid Build Coastguard Worker "Tensor with large numel and small storage.", 166*da0073e9SAndroid Build Coastguard Worker "Tensor with small numel and large storage.", 167*da0073e9SAndroid Build Coastguard Worker "Large zero tensor.", 168*da0073e9SAndroid Build Coastguard Worker "Large channels-last ones tensor.", 169*da0073e9SAndroid Build Coastguard Worker "Special encoding of random tensor.", 170*da0073e9SAndroid Build Coastguard Worker ] 171*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs( 172*da0073e9SAndroid Build Coastguard Worker mm, 173*da0073e9SAndroid Build Coastguard Worker inputs={mm.forward: samples, mm.foo: samples}, 174*da0073e9SAndroid Build Coastguard Worker info={mm.forward: info, mm.foo: info}, 175*da0073e9SAndroid Build Coastguard Worker ) 176*da0073e9SAndroid Build Coastguard Worker loaded = save_and_load(mm) 177*da0073e9SAndroid Build Coastguard Worker inflated = loaded.get_all_bundled_inputs() 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker # Make sure these functions are all consistent. 180*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inflated, samples) 181*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_forward()) 182*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo()) 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Worker # Check running and size helpers 185*da0073e9SAndroid Build Coastguard Worker self.assertTrue(loaded(*inflated[0]) is inflated[0][0]) 186*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded.get_num_bundled_inputs(), len(samples)) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker # Check helper that work on all functions 189*da0073e9SAndroid Build Coastguard Worker all_info = loaded.get_bundled_inputs_functions_and_info() 190*da0073e9SAndroid Build Coastguard Worker self.assertEqual(set(all_info.keys()), {"forward", "foo"}) 191*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 192*da0073e9SAndroid Build Coastguard Worker all_info["forward"]["get_inputs_function_name"], 193*da0073e9SAndroid Build Coastguard Worker ["get_all_bundled_inputs_for_forward"], 194*da0073e9SAndroid Build Coastguard Worker ) 195*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 196*da0073e9SAndroid Build Coastguard Worker all_info["foo"]["get_inputs_function_name"], 197*da0073e9SAndroid Build Coastguard Worker ["get_all_bundled_inputs_for_foo"], 198*da0073e9SAndroid Build Coastguard Worker ) 199*da0073e9SAndroid Build Coastguard Worker self.assertEqual(all_info["forward"]["info"], info) 200*da0073e9SAndroid Build Coastguard Worker self.assertEqual(all_info["foo"]["info"], info) 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker # example of how to turn the 'get_inputs_function_name' into the actual list of bundled inputs 203*da0073e9SAndroid Build Coastguard Worker for func_name in all_info.keys(): 204*da0073e9SAndroid Build Coastguard Worker input_func_name = all_info[func_name]["get_inputs_function_name"][0] 205*da0073e9SAndroid Build Coastguard Worker func_to_run = getattr(loaded, input_func_name) 206*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func_to_run(), samples) 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker def test_multiple_methods_with_inputs_both_defined_failure(self): 209*da0073e9SAndroid Build Coastguard Worker class MultipleMethodModel(torch.nn.Module): 210*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 211*da0073e9SAndroid Build Coastguard Worker return arg 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 214*da0073e9SAndroid Build Coastguard Worker def foo(self, arg): 215*da0073e9SAndroid Build Coastguard Worker return arg 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker samples = [(torch.tensor([1]),)] 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker # inputs defined 2 ways so should fail 220*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 221*da0073e9SAndroid Build Coastguard Worker mm = torch.jit.script(MultipleMethodModel()) 222*da0073e9SAndroid Build Coastguard Worker definition = textwrap.dedent( 223*da0073e9SAndroid Build Coastguard Worker """ 224*da0073e9SAndroid Build Coastguard Worker def _generate_bundled_inputs_for_forward(self): 225*da0073e9SAndroid Build Coastguard Worker return [] 226*da0073e9SAndroid Build Coastguard Worker """ 227*da0073e9SAndroid Build Coastguard Worker ) 228*da0073e9SAndroid Build Coastguard Worker mm.define(definition) 229*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs( 230*da0073e9SAndroid Build Coastguard Worker mm, 231*da0073e9SAndroid Build Coastguard Worker inputs={ 232*da0073e9SAndroid Build Coastguard Worker mm.forward: samples, 233*da0073e9SAndroid Build Coastguard Worker mm.foo: samples, 234*da0073e9SAndroid Build Coastguard Worker }, 235*da0073e9SAndroid Build Coastguard Worker ) 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker def test_multiple_methods_with_inputs_neither_defined_failure(self): 238*da0073e9SAndroid Build Coastguard Worker class MultipleMethodModel(torch.nn.Module): 239*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 240*da0073e9SAndroid Build Coastguard Worker return arg 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 243*da0073e9SAndroid Build Coastguard Worker def foo(self, arg): 244*da0073e9SAndroid Build Coastguard Worker return arg 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker samples = [(torch.tensor([1]),)] 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker # inputs not defined so should fail 249*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 250*da0073e9SAndroid Build Coastguard Worker mm = torch.jit.script(MultipleMethodModel()) 251*da0073e9SAndroid Build Coastguard Worker mm._generate_bundled_inputs_for_forward() 252*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs( 253*da0073e9SAndroid Build Coastguard Worker mm, 254*da0073e9SAndroid Build Coastguard Worker inputs={ 255*da0073e9SAndroid Build Coastguard Worker mm.forward: None, 256*da0073e9SAndroid Build Coastguard Worker mm.foo: samples, 257*da0073e9SAndroid Build Coastguard Worker }, 258*da0073e9SAndroid Build Coastguard Worker ) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker def test_bad_inputs(self): 261*da0073e9SAndroid Build Coastguard Worker class SingleTensorModel(torch.nn.Module): 262*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 263*da0073e9SAndroid Build Coastguard Worker return arg 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker # Non list for input list 266*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 267*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(SingleTensorModel()) 268*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 269*da0073e9SAndroid Build Coastguard Worker m, 270*da0073e9SAndroid Build Coastguard Worker inputs="foo", # type: ignore[arg-type] 271*da0073e9SAndroid Build Coastguard Worker ) 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker # List of non tuples. Most common error using the api. 274*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 275*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(SingleTensorModel()) 276*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 277*da0073e9SAndroid Build Coastguard Worker m, 278*da0073e9SAndroid Build Coastguard Worker inputs=[torch.ones(1, 2)], # type: ignore[list-item] 279*da0073e9SAndroid Build Coastguard Worker ) 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker def test_double_augment_fail(self): 282*da0073e9SAndroid Build Coastguard Worker class SingleTensorModel(torch.nn.Module): 283*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 284*da0073e9SAndroid Build Coastguard Worker return arg 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(SingleTensorModel()) 287*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 288*da0073e9SAndroid Build Coastguard Worker m, inputs=[(torch.ones(1),)] 289*da0073e9SAndroid Build Coastguard Worker ) 290*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 291*da0073e9SAndroid Build Coastguard Worker Exception, "Models can only be augmented with bundled inputs once." 292*da0073e9SAndroid Build Coastguard Worker ): 293*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 294*da0073e9SAndroid Build Coastguard Worker m, inputs=[(torch.ones(1),)] 295*da0073e9SAndroid Build Coastguard Worker ) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker def test_double_augment_non_mutator(self): 298*da0073e9SAndroid Build Coastguard Worker class SingleTensorModel(torch.nn.Module): 299*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 300*da0073e9SAndroid Build Coastguard Worker return arg 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(SingleTensorModel()) 303*da0073e9SAndroid Build Coastguard Worker bundled_model = torch.utils.bundled_inputs.bundle_inputs( 304*da0073e9SAndroid Build Coastguard Worker m, inputs=[(torch.ones(1),)] 305*da0073e9SAndroid Build Coastguard Worker ) 306*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AttributeError): 307*da0073e9SAndroid Build Coastguard Worker m.get_all_bundled_inputs() 308*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)]) 309*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bundled_model.forward(torch.ones(1)), torch.ones(1)) 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker def test_double_augment_success(self): 312*da0073e9SAndroid Build Coastguard Worker class SingleTensorModel(torch.nn.Module): 313*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 314*da0073e9SAndroid Build Coastguard Worker return arg 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(SingleTensorModel()) 317*da0073e9SAndroid Build Coastguard Worker bundled_model = torch.utils.bundled_inputs.bundle_inputs( 318*da0073e9SAndroid Build Coastguard Worker m, inputs={m.forward: [(torch.ones(1),)]} 319*da0073e9SAndroid Build Coastguard Worker ) 320*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)]) 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker bundled_model2 = torch.utils.bundled_inputs.bundle_inputs( 323*da0073e9SAndroid Build Coastguard Worker bundled_model, inputs=[(torch.ones(2),)] 324*da0073e9SAndroid Build Coastguard Worker ) 325*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)]) 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker def test_dict_args(self): 328*da0073e9SAndroid Build Coastguard Worker class MyModel(torch.nn.Module): 329*da0073e9SAndroid Build Coastguard Worker def forward( 330*da0073e9SAndroid Build Coastguard Worker self, 331*da0073e9SAndroid Build Coastguard Worker arg1: Optional[Dict[str, torch.Tensor]], 332*da0073e9SAndroid Build Coastguard Worker arg2: Optional[List[torch.Tensor]], 333*da0073e9SAndroid Build Coastguard Worker arg3: torch.Tensor, 334*da0073e9SAndroid Build Coastguard Worker ): 335*da0073e9SAndroid Build Coastguard Worker if arg1 is None: 336*da0073e9SAndroid Build Coastguard Worker return arg3 337*da0073e9SAndroid Build Coastguard Worker elif arg2 is None: 338*da0073e9SAndroid Build Coastguard Worker return arg1["a"] + arg1["b"] 339*da0073e9SAndroid Build Coastguard Worker else: 340*da0073e9SAndroid Build Coastguard Worker return arg1["a"] + arg1["b"] + arg2[0] 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker small_sample = dict( 343*da0073e9SAndroid Build Coastguard Worker a=torch.zeros([10, 20]), 344*da0073e9SAndroid Build Coastguard Worker b=torch.zeros([1, 1]), 345*da0073e9SAndroid Build Coastguard Worker c=torch.zeros([10, 20]), 346*da0073e9SAndroid Build Coastguard Worker ) 347*da0073e9SAndroid Build Coastguard Worker small_list = [torch.zeros([10, 20])] 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker big_sample = dict( 350*da0073e9SAndroid Build Coastguard Worker a=torch.zeros([1 << 5, 1 << 8, 1 << 10]), 351*da0073e9SAndroid Build Coastguard Worker b=torch.zeros([1 << 5, 1 << 8, 1 << 10]), 352*da0073e9SAndroid Build Coastguard Worker c=torch.zeros([1 << 5, 1 << 8, 1 << 10]), 353*da0073e9SAndroid Build Coastguard Worker ) 354*da0073e9SAndroid Build Coastguard Worker big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])] 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker def condensed(t): 357*da0073e9SAndroid Build Coastguard Worker ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape) 358*da0073e9SAndroid Build Coastguard Worker assert ret.storage().size() == 1 359*da0073e9SAndroid Build Coastguard Worker # ret.storage()[0] = 0 360*da0073e9SAndroid Build Coastguard Worker return ret 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker def bundle_optional_dict_of_randn(template): 363*da0073e9SAndroid Build Coastguard Worker return torch.utils.bundled_inputs.InflatableArg( 364*da0073e9SAndroid Build Coastguard Worker value=( 365*da0073e9SAndroid Build Coastguard Worker None 366*da0073e9SAndroid Build Coastguard Worker if template is None 367*da0073e9SAndroid Build Coastguard Worker else {k: condensed(v) for (k, v) in template.items()} 368*da0073e9SAndroid Build Coastguard Worker ), 369*da0073e9SAndroid Build Coastguard Worker fmt="{}", 370*da0073e9SAndroid Build Coastguard Worker fmt_fn=""" 371*da0073e9SAndroid Build Coastguard Worker def {}(self, value: Optional[Dict[str, Tensor]]): 372*da0073e9SAndroid Build Coastguard Worker if value is None: 373*da0073e9SAndroid Build Coastguard Worker return None 374*da0073e9SAndroid Build Coastguard Worker output = {{}} 375*da0073e9SAndroid Build Coastguard Worker for k, v in value.items(): 376*da0073e9SAndroid Build Coastguard Worker output[k] = torch.randn_like(v) 377*da0073e9SAndroid Build Coastguard Worker return output 378*da0073e9SAndroid Build Coastguard Worker """, 379*da0073e9SAndroid Build Coastguard Worker ) 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker def bundle_optional_list_of_randn(template): 382*da0073e9SAndroid Build Coastguard Worker return torch.utils.bundled_inputs.InflatableArg( 383*da0073e9SAndroid Build Coastguard Worker value=(None if template is None else [condensed(v) for v in template]), 384*da0073e9SAndroid Build Coastguard Worker fmt="{}", 385*da0073e9SAndroid Build Coastguard Worker fmt_fn=""" 386*da0073e9SAndroid Build Coastguard Worker def {}(self, value: Optional[List[Tensor]]): 387*da0073e9SAndroid Build Coastguard Worker if value is None: 388*da0073e9SAndroid Build Coastguard Worker return None 389*da0073e9SAndroid Build Coastguard Worker output = [] 390*da0073e9SAndroid Build Coastguard Worker for v in value: 391*da0073e9SAndroid Build Coastguard Worker output.append(torch.randn_like(v)) 392*da0073e9SAndroid Build Coastguard Worker return output 393*da0073e9SAndroid Build Coastguard Worker """, 394*da0073e9SAndroid Build Coastguard Worker ) 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker out: List[str] = [] 397*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(MyModel()) 398*da0073e9SAndroid Build Coastguard Worker original_size = model_size(sm) 399*da0073e9SAndroid Build Coastguard Worker small_inputs = ( 400*da0073e9SAndroid Build Coastguard Worker bundle_optional_dict_of_randn(small_sample), 401*da0073e9SAndroid Build Coastguard Worker bundle_optional_list_of_randn(small_list), 402*da0073e9SAndroid Build Coastguard Worker torch.zeros([3, 4]), 403*da0073e9SAndroid Build Coastguard Worker ) 404*da0073e9SAndroid Build Coastguard Worker big_inputs = ( 405*da0073e9SAndroid Build Coastguard Worker bundle_optional_dict_of_randn(big_sample), 406*da0073e9SAndroid Build Coastguard Worker bundle_optional_list_of_randn(big_list), 407*da0073e9SAndroid Build Coastguard Worker torch.zeros([1 << 5, 1 << 8, 1 << 10]), 408*da0073e9SAndroid Build Coastguard Worker ) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 411*da0073e9SAndroid Build Coastguard Worker sm, 412*da0073e9SAndroid Build Coastguard Worker [big_inputs, small_inputs], 413*da0073e9SAndroid Build Coastguard Worker _receive_inflate_expr=out, 414*da0073e9SAndroid Build Coastguard Worker ) 415*da0073e9SAndroid Build Coastguard Worker augmented_size = model_size(sm) 416*da0073e9SAndroid Build Coastguard Worker # assert the size has not increased more than 8KB 417*da0073e9SAndroid Build Coastguard Worker self.assertLess(augmented_size, original_size + (1 << 13)) 418*da0073e9SAndroid Build Coastguard Worker 419*da0073e9SAndroid Build Coastguard Worker loaded = save_and_load(sm) 420*da0073e9SAndroid Build Coastguard Worker inflated = loaded.get_all_bundled_inputs() 421*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(inflated[0]), len(small_inputs)) 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker methods, _ = ( 424*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods( 425*da0073e9SAndroid Build Coastguard Worker loaded 426*da0073e9SAndroid Build Coastguard Worker ) 427*da0073e9SAndroid Build Coastguard Worker ) 428*da0073e9SAndroid Build Coastguard Worker 429*da0073e9SAndroid Build Coastguard Worker # One Function (forward) 430*da0073e9SAndroid Build Coastguard Worker # two bundled inputs (big_inputs and small_inputs) 431*da0073e9SAndroid Build Coastguard Worker # two args which have InflatableArg with fmt_fn 432*da0073e9SAndroid Build Coastguard Worker # 1 * 2 * 2 = 4 433*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 434*da0073e9SAndroid Build Coastguard Worker sum(method.startswith("_inflate_helper") for method in methods), 4 435*da0073e9SAndroid Build Coastguard Worker ) 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker 438*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 439*da0073e9SAndroid Build Coastguard Worker run_tests() 440