# xref: /aosp_15_r20/external/executorch/exir/tests/test_capture.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport unittest
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir as exir
12*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.tests.models as models
13*523fa7a6SAndroid Build Coastguard Workerimport torch
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workerfrom parameterized import parameterized
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Worker
class TestCapture(unittest.TestCase):
    """Checks that a program captured with ``exir.capture`` matches eager execution."""

    # pyre-ignore
    @parameterized.expand(models.MODELS)
    def test_module_call(self, model_name: str, model: torch.nn.Module) -> None:
        """Capture ``model`` and compare the captured program's output with eager output.

        Args:
            model_name: Case label supplied by ``models.MODELS``; unused in the body.
            model: Module under test; must expose ``get_random_inputs()``.
        """
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        example_inputs = model.get_random_inputs()
        # Run eagerly first so the reference is computed before capture touches the model.
        reference_output = model(*example_inputs)

        # TODO(ycao): Replace it with capture_multiple
        captured = exir.capture(model, example_inputs, exir.CaptureConfig())
        captured_output = captured(*example_inputs)

        # torch.allclose with its default rtol/atol.
        outputs_match = torch.allclose(reference_output, captured_output)
        self.assertTrue(outputs_match)
30