xref: /aosp_15_r20/external/pytorch/test/test_bundled_inputs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"]
3*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport io
6*da0073e9SAndroid Build Coastguard Workerimport textwrap
7*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, List, Optional
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerimport torch.utils.bundled_inputs
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
def model_size(sm):
    """Return the size in bytes of the TorchScript serialization of `sm`."""
    with io.BytesIO() as buf:
        torch.jit.save(sm, buf)
        return len(buf.getvalue())
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
def save_and_load(sm):
    """Round-trip `sm` through TorchScript serialization in memory.

    Saves the scripted module to a fresh byte buffer and returns the
    module produced by loading those bytes back.
    """
    buf = io.BytesIO()
    torch.jit.save(sm, buf)
    return torch.jit.load(io.BytesIO(buf.getvalue()))
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
class TestBundledInputs(TestCase):
    """Tests for torch.utils.bundled_inputs.

    Exercises attaching compact ("deflated") sample inputs to TorchScript
    modules, verifying that they survive save/load, that they inflate back
    to the expected values, and that the various failure modes (oversized
    tensors, double augmentation, malformed input lists) raise as intended.
    """

    def test_single_tensors(self):
        """Bundle a variety of single-tensor inputs on one model.

        Checks that each representable tensor round-trips exactly (values,
        strides, dtype) and that the bundled inputs add only a small amount
        to the serialized model size despite nominally large tensors.
        """

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
        # Receives the generated inflation expression for debugging.
        get_expr: List[str] = []
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
            # Special encoding of random tensor.
            (torch.utils.bundled_inputs.bundle_randn(1 << 16),),
            # Quantized uniform tensor.
            (torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0, torch.qint8),),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples, get_expr
        )
        # print(get_expr[0])
        # print(sm._generate_bundled_inputs.code)

        # Make sure the model only grew a little bit,
        # despite having nominally large bundled inputs.
        augmented_size = model_size(sm)
        self.assertLess(augmented_size, original_size + (1 << 12))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
        self.assertEqual(len(inflated), len(samples))
        self.assertTrue(loaded(*inflated[0]) is inflated[0][0])

        for idx, inp in enumerate(inflated):
            self.assertIsInstance(inp, tuple)
            self.assertEqual(len(inp), 1)
            self.assertIsInstance(inp[0], torch.Tensor)
            # Index 5 is the bundle_randn sample: its values are random, so
            # it is checked statistically below rather than element-wise here.
            if idx != 5:
                # Strides might be important for benchmarking.
                self.assertEqual(inp[0].stride(), samples[idx][0].stride())
                self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)

        # This tensor is random, but with 100,000 trials,
        # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
        self.assertEqual(inflated[5][0].shape, (1 << 16,))
        self.assertEqual(inflated[5][0].mean().item(), 0, atol=0.025, rtol=0)
        self.assertEqual(inflated[5][0].std().item(), 1, atol=0.02, rtol=0)

    def test_large_tensor_with_inflation(self):
        """A tensor of any size may be bundled via bundle_large_tensor.

        Custom inflation bypasses the size limit; here inflation is the
        identity, so the loaded input must equal the original tensor.
        """

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        sample_tensor = torch.randn(1 << 16)
        # We can store tensors with custom inflation functions regardless
        # of size, even if inflation is just the identity.
        sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor)
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, [(sample,)])

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated), 1)

        self.assertEqual(inflated[0][0], sample_tensor)

    def test_rejected_tensors(self):
        """Tensors that cannot be stored compactly must be rejected."""

        def check_tensor(sample):
            # Need to define the class in this scope to get a fresh type for each run.
            class SingleTensorModel(torch.nn.Module):
                def forward(self, arg):
                    return arg

            sm = torch.jit.script(SingleTensorModel())
            with self.assertRaisesRegex(Exception, "Bundled input argument"):
                torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                    sm, [(sample,)]
                )

        # Plain old big tensor.
        check_tensor(torch.randn(1 << 16))
        # This tensor has two elements, but they're far apart in memory.
        # We currently cannot represent this compactly while preserving
        # the strides.
        small_sparse = torch.randn(2, 1 << 16)[:, 0:1]
        self.assertEqual(small_sparse.numel(), 2)
        check_tensor(small_sparse)

    def test_non_tensors(self):
        """Non-tensor arguments (str, int) can be bundled verbatim."""

        class StringAndIntModel(torch.nn.Module):
            def forward(self, fmt: str, num: int):
                return fmt.format(num)

        sm = torch.jit.script(StringAndIntModel())
        samples = [
            ("first {}", 1),
            ("second {}", 2),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, samples)

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(inflated, samples)
        self.assertTrue(loaded(*inflated[0]) == "first 1")

    def test_multiple_methods_with_inputs(self):
        """Bundle inputs (plus info strings) for several exported methods.

        Verifies the per-method accessor functions, the aggregate
        get_bundled_inputs_functions_and_info helper, and that the
        accessor names it reports are callable.
        """

        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        mm = torch.jit.script(MultipleMethodModel())
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
        ]
        # NOTE(review): `info` has six entries while `samples` has five; the
        # stored info is compared verbatim below so the test still passes —
        # confirm the trailing "Special encoding..." entry is intentional.
        info = [
            "Tensor with small numel and small storage.",
            "Tensor with large numel and small storage.",
            "Tensor with small numel and large storage.",
            "Large zero tensor.",
            "Large channels-last ones tensor.",
            "Special encoding of random tensor.",
        ]
        torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
            mm,
            inputs={mm.forward: samples, mm.foo: samples},
            info={mm.forward: info, mm.foo: info},
        )
        loaded = save_and_load(mm)
        inflated = loaded.get_all_bundled_inputs()

        # Make sure these functions are all consistent.
        self.assertEqual(inflated, samples)
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_forward())
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())

        # Check running and size helpers
        self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

        # Check helper that work on all functions
        all_info = loaded.get_bundled_inputs_functions_and_info()
        self.assertEqual(set(all_info.keys()), {"forward", "foo"})
        self.assertEqual(
            all_info["forward"]["get_inputs_function_name"],
            ["get_all_bundled_inputs_for_forward"],
        )
        self.assertEqual(
            all_info["foo"]["get_inputs_function_name"],
            ["get_all_bundled_inputs_for_foo"],
        )
        self.assertEqual(all_info["forward"]["info"], info)
        self.assertEqual(all_info["foo"]["info"], info)

        # example of how to turn the 'get_inputs_function_name' into the actual list of bundled inputs
        for func_name in all_info.keys():
            input_func_name = all_info[func_name]["get_inputs_function_name"][0]
            func_to_run = getattr(loaded, input_func_name)
            self.assertEqual(func_to_run(), samples)

    def test_multiple_methods_with_inputs_both_defined_failure(self):
        """Augmenting must fail if a generator method is already defined."""

        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # inputs defined 2 ways so should fail
        with self.assertRaises(Exception):
            mm = torch.jit.script(MultipleMethodModel())
            definition = textwrap.dedent(
                """
                def _generate_bundled_inputs_for_forward(self):
                    return []
                """
            )
            mm.define(definition)
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: samples,
                    mm.foo: samples,
                },
            )

    def test_multiple_methods_with_inputs_neither_defined_failure(self):
        """Augmenting must fail when a method has no inputs from any source."""

        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # inputs not defined so should fail
        with self.assertRaises(Exception):
            mm = torch.jit.script(MultipleMethodModel())
            mm._generate_bundled_inputs_for_forward()
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: None,
                    mm.foo: samples,
                },
            )

    def test_bad_inputs(self):
        """Malformed `inputs` arguments must raise TypeError."""

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        # Non list for input list
        with self.assertRaises(TypeError):
            m = torch.jit.script(SingleTensorModel())
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs="foo",  # type: ignore[arg-type]
            )

        # List of non tuples. Most common error using the api.
        with self.assertRaises(TypeError):
            m = torch.jit.script(SingleTensorModel())
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs=[torch.ones(1, 2)],  # type: ignore[list-item]
            )

    def test_double_augment_fail(self):
        """Augmenting the same model twice in place must raise."""

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            m, inputs=[(torch.ones(1),)]
        )
        with self.assertRaisesRegex(
            Exception, "Models can only be augmented with bundled inputs once."
        ):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m, inputs=[(torch.ones(1),)]
            )

    def test_double_augment_non_mutator(self):
        """bundle_inputs returns a new model and leaves the original untouched."""

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m, inputs=[(torch.ones(1),)]
        )
        with self.assertRaises(AttributeError):
            m.get_all_bundled_inputs()
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])
        self.assertEqual(bundled_model.forward(torch.ones(1)), torch.ones(1))

    def test_double_augment_success(self):
        """bundle_inputs can be applied repeatedly, replacing prior inputs."""

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m, inputs={m.forward: [(torch.ones(1),)]}
        )
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])

        bundled_model2 = torch.utils.bundled_inputs.bundle_inputs(
            bundled_model, inputs=[(torch.ones(2),)]
        )
        self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)])

    def test_dict_args(self):
        """Bundle Optional[Dict]/Optional[List] args via custom fmt_fn inflaters.

        Stores "condensed" (single-storage-element) stand-ins plus TorchScript
        inflation functions, then checks model-size growth stays under 8KB and
        that one `_inflate_helper` method is generated per (function, input,
        fmt_fn-arg) combination.
        """

        class MyModel(torch.nn.Module):
            def forward(
                self,
                arg1: Optional[Dict[str, torch.Tensor]],
                arg2: Optional[List[torch.Tensor]],
                arg3: torch.Tensor,
            ):
                if arg1 is None:
                    return arg3
                elif arg2 is None:
                    return arg1["a"] + arg1["b"]
                else:
                    return arg1["a"] + arg1["b"] + arg2[0]

        small_sample = dict(
            a=torch.zeros([10, 20]),
            b=torch.zeros([1, 1]),
            c=torch.zeros([10, 20]),
        )
        small_list = [torch.zeros([10, 20])]

        big_sample = dict(
            a=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            b=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            c=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )
        big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])]

        def condensed(t):
            # Collapse `t` to a view backed by a single storage element so it
            # is cheap to serialize while keeping the original shape.
            ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape)
            assert ret.storage().size() == 1
            # ret.storage()[0] = 0
            return ret

        def bundle_optional_dict_of_randn(template):
            # Deflate to condensed tensors; inflate by randn_like per value.
            return torch.utils.bundled_inputs.InflatableArg(
                value=(
                    None
                    if template is None
                    else {k: condensed(v) for (k, v) in template.items()}
                ),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[Dict[str, Tensor]]):
                    if value is None:
                        return None
                    output = {{}}
                    for k, v in value.items():
                        output[k] = torch.randn_like(v)
                    return output
                """,
            )

        def bundle_optional_list_of_randn(template):
            # Same scheme as the dict variant, for an optional tensor list.
            return torch.utils.bundled_inputs.InflatableArg(
                value=(None if template is None else [condensed(v) for v in template]),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[List[Tensor]]):
                    if value is None:
                        return None
                    output = []
                    for v in value:
                        output.append(torch.randn_like(v))
                    return output
                """,
            )

        out: List[str] = []
        sm = torch.jit.script(MyModel())
        original_size = model_size(sm)
        small_inputs = (
            bundle_optional_dict_of_randn(small_sample),
            bundle_optional_list_of_randn(small_list),
            torch.zeros([3, 4]),
        )
        big_inputs = (
            bundle_optional_dict_of_randn(big_sample),
            bundle_optional_list_of_randn(big_list),
            torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm,
            [big_inputs, small_inputs],
            _receive_inflate_expr=out,
        )
        augmented_size = model_size(sm)
        # assert the size has not increased more than 8KB
        self.assertLess(augmented_size, original_size + (1 << 13))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated[0]), len(small_inputs))

        methods, _ = (
            torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods(
                loaded
            )
        )

        # One Function (forward)
        # two bundled inputs (big_inputs and small_inputs)
        # two args which have InflatableArg with fmt_fn
        # 1 * 2 * 2 = 4
        self.assertEqual(
            sum(method.startswith("_inflate_helper") for method in methods), 4
        )
436*da0073e9SAndroid Build Coastguard Worker
437*da0073e9SAndroid Build Coastguard Worker
# Standard PyTorch test entry point: runs all TestCase classes in this module.
if __name__ == "__main__":
    run_tests()
440