#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]
# mypy: allow-untyped-defs

import io
import textwrap
from typing import Dict, List, Optional

import torch
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import run_tests, TestCase


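# Return the size in bytes of a scripted module serialized with torch.jit.save.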
def model_size(sm):
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    return len(buffer.getvalue())


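# Round-trip a scripted module through an in-memory buffer via torch.jit.save / torch.jit.load.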
def save_and_load(sm):
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)


class TestBundledInputs(TestCase):
    def test_single_tensors(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
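        # Collects the generated inflation expressions so they can be inspected
        # (see the commented-out prints below).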
        get_expr: List[str] = []
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
            # Special encoding of random tensor.
            (torch.utils.bundled_inputs.bundle_randn(1 << 16),),
            # Quantized uniform tensor.
            (torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0, torch.qint8),),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples, get_expr
        )
        # print(get_expr[0])
        # print(sm._generate_bundled_inputs.code)

        # Make sure the model only grew a little bit,
        # despite having nominally large bundled inputs.
        augmented_size = model_size(sm)
        self.assertLess(augmented_size, original_size + (1 << 12))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
        self.assertEqual(len(inflated), len(samples))
        self.assertTrue(loaded(*inflated[0]) is inflated[0][0])

        for idx, inp in enumerate(inflated):
            self.assertIsInstance(inp, tuple)
            self.assertEqual(len(inp), 1)
            self.assertIsInstance(inp[0], torch.Tensor)
            if idx != 5:
                # Strides might be important for benchmarking.
                self.assertEqual(inp[0].stride(), samples[idx][0].stride())
                self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)

        # This tensor is random, but with 100,000 trials,
        # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
        self.assertEqual(inflated[5][0].shape, (1 << 16,))
        self.assertEqual(inflated[5][0].mean().item(), 0, atol=0.025, rtol=0)
        self.assertEqual(inflated[5][0].std().item(), 1, atol=0.02, rtol=0)

    def test_large_tensor_with_inflation(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        sample_tensor = torch.randn(1 << 16)
        # We can store tensors with custom inflation functions regardless
        # of size, even if inflation is just the identity.
        sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor)
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, [(sample,)])

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated), 1)

        self.assertEqual(inflated[0][0], sample_tensor)

    def test_rejected_tensors(self):
        def check_tensor(sample):
            # Need to define the class in this scope to get a fresh type for each run.
            class SingleTensorModel(torch.nn.Module):
                def forward(self, arg):
                    return arg

            sm = torch.jit.script(SingleTensorModel())
            with self.assertRaisesRegex(Exception, "Bundled input argument"):
                torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                    sm, [(sample,)]
                )

        # Plain old big tensor.
        check_tensor(torch.randn(1 << 16))
        # This tensor has two elements, but they're far apart in memory.
        # We currently cannot represent this compactly while preserving
        # the strides.
        small_sparse = torch.randn(2, 1 << 16)[:, 0:1]
        self.assertEqual(small_sparse.numel(), 2)
        check_tensor(small_sparse)

    def test_non_tensors(self):
        class StringAndIntModel(torch.nn.Module):
            def forward(self, fmt: str, num: int):
                return fmt.format(num)

        sm = torch.jit.script(StringAndIntModel())
        samples = [
            ("first {}", 1),
            ("second {}", 2),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, samples)

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(inflated, samples)
        self.assertTrue(loaded(*inflated[0]) == "first 1")

    def test_multiple_methods_with_inputs(self):
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        mm = torch.jit.script(MultipleMethodModel())
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
        ]
        info = [
            "Tensor with small numel and small storage.",
            "Tensor with large numel and small storage.",
            "Tensor with small numel and large storage.",
            "Large zero tensor.",
            "Large channels-last ones tensor.",
            "Special encoding of random tensor.",
        ]
        torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
            mm,
            inputs={mm.forward: samples, mm.foo: samples},
            info={mm.forward: info, mm.foo: info},
        )
        loaded = save_and_load(mm)
        inflated = loaded.get_all_bundled_inputs()

        # Make sure these functions are all consistent.
        self.assertEqual(inflated, samples)
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_forward())
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())

        # Check running and size helpers
        self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

189        all_info = loaded.get_bundled_inputs_functions_and_info()
190        self.assertEqual(set(all_info.keys()), {"forward", "foo"})
191        self.assertEqual(
192            all_info["forward"]["get_inputs_function_name"],
193            ["get_all_bundled_inputs_for_forward"],
194        )
195        self.assertEqual(
196            all_info["foo"]["get_inputs_function_name"],
197            ["get_all_bundled_inputs_for_foo"],
198        )
199        self.assertEqual(all_info["forward"]["info"], info)
200        self.assertEqual(all_info["foo"]["info"], info)
201
        # Example of how to turn 'get_inputs_function_name' into the actual list of bundled inputs.
        for func_name in all_info.keys():
            input_func_name = all_info[func_name]["get_inputs_function_name"][0]
            func_to_run = getattr(loaded, input_func_name)
            self.assertEqual(func_to_run(), samples)

    def test_multiple_methods_with_inputs_both_defined_failure(self):
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # Inputs are defined in two ways, so this should fail.
        with self.assertRaises(Exception):
            mm = torch.jit.script(MultipleMethodModel())
            definition = textwrap.dedent(
                """
                def _generate_bundled_inputs_for_forward(self):
                    return []
                """
            )
            mm.define(definition)
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: samples,
                    mm.foo: samples,
                },
            )

    def test_multiple_methods_with_inputs_neither_defined_failure(self):
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # Inputs are not defined, so this should fail.
        with self.assertRaises(Exception):
            mm = torch.jit.script(MultipleMethodModel())
            mm._generate_bundled_inputs_for_forward()
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: None,
                    mm.foo: samples,
                },
            )

    def test_bad_inputs(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        # A non-list passed as the input list
        with self.assertRaises(TypeError):
            m = torch.jit.script(SingleTensorModel())
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs="foo",  # type: ignore[arg-type]
            )

        # A list of non-tuples. This is the most common error when using the API.
        with self.assertRaises(TypeError):
            m = torch.jit.script(SingleTensorModel())
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs=[torch.ones(1, 2)],  # type: ignore[list-item]
            )

    def test_double_augment_fail(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            m, inputs=[(torch.ones(1),)]
        )
        with self.assertRaisesRegex(
            Exception, "Models can only be augmented with bundled inputs once."
        ):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m, inputs=[(torch.ones(1),)]
            )

    def test_double_augment_non_mutator(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m, inputs=[(torch.ones(1),)]
        )
        with self.assertRaises(AttributeError):
            m.get_all_bundled_inputs()
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])
        self.assertEqual(bundled_model.forward(torch.ones(1)), torch.ones(1))

    def test_double_augment_success(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m, inputs={m.forward: [(torch.ones(1),)]}
        )
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])

        bundled_model2 = torch.utils.bundled_inputs.bundle_inputs(
            bundled_model, inputs=[(torch.ones(2),)]
        )
        self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)])

    def test_dict_args(self):
        class MyModel(torch.nn.Module):
            def forward(
                self,
                arg1: Optional[Dict[str, torch.Tensor]],
                arg2: Optional[List[torch.Tensor]],
                arg3: torch.Tensor,
            ):
                if arg1 is None:
                    return arg3
                elif arg2 is None:
                    return arg1["a"] + arg1["b"]
                else:
                    return arg1["a"] + arg1["b"] + arg2[0]

        small_sample = dict(
            a=torch.zeros([10, 20]),
            b=torch.zeros([1, 1]),
            c=torch.zeros([10, 20]),
        )
        small_list = [torch.zeros([10, 20])]

        big_sample = dict(
            a=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            b=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            c=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )
        big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])]

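        # Build a tensor with the same shape as `t` but backed by a single-element
        # storage (via expand), so it can be bundled compactly as a template.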
        def condensed(t):
            ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape)
            assert ret.storage().size() == 1
            # ret.storage()[0] = 0
            return ret

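        # Wrap an optional dict of tensors in an InflatableArg whose custom fmt_fn
        # re-inflates each entry with torch.randn_like at load time.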
        def bundle_optional_dict_of_randn(template):
            return torch.utils.bundled_inputs.InflatableArg(
                value=(
                    None
                    if template is None
                    else {k: condensed(v) for (k, v) in template.items()}
                ),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[Dict[str, Tensor]]):
                    if value is None:
                        return None
                    output = {{}}
                    for k, v in value.items():
                        output[k] = torch.randn_like(v)
                    return output
                """,
            )

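        # Same idea for an optional list of tensors: each condensed entry is
        # re-inflated with torch.randn_like at load time.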
        def bundle_optional_list_of_randn(template):
            return torch.utils.bundled_inputs.InflatableArg(
                value=(None if template is None else [condensed(v) for v in template]),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[List[Tensor]]):
                    if value is None:
                        return None
                    output = []
                    for v in value:
                        output.append(torch.randn_like(v))
                    return output
                """,
            )

        out: List[str] = []
        sm = torch.jit.script(MyModel())
        original_size = model_size(sm)
        small_inputs = (
            bundle_optional_dict_of_randn(small_sample),
            bundle_optional_list_of_randn(small_list),
            torch.zeros([3, 4]),
        )
        big_inputs = (
            bundle_optional_dict_of_randn(big_sample),
            bundle_optional_list_of_randn(big_list),
            torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm,
            [big_inputs, small_inputs],
            _receive_inflate_expr=out,
        )
        augmented_size = model_size(sm)
        # Assert that the size has not increased by more than 8 KiB (1 << 13 bytes).
        self.assertLess(augmented_size, original_size + (1 << 13))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated[0]), len(small_inputs))

        methods, _ = (
            torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods(
                loaded
            )
        )

        # One function (forward),
        # two bundled inputs (big_inputs and small_inputs),
        # and two args per input that use an InflatableArg with fmt_fn:
        # 1 * 2 * 2 = 4 inflation helpers.
        self.assertEqual(
            sum(method.startswith("_inflate_helper") for method in methods), 4
        )


if __name__ == "__main__":
    run_tests()