# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import executorch.exir as exir
import executorch.exir.tests.models as models
import torch

from parameterized import parameterized


class TestCapture(unittest.TestCase):
    """Tests that programs captured with ``exir.capture`` match eager execution.

    Outputs are compared with ``torch.allclose`` using its default tolerances.
    """

    # pyre-ignore
    @parameterized.expand(models.MODELS)
    def test_module_call(self, model_name: str, model: torch.nn.Module) -> None:
        """Capture ``model`` and check the captured program against eager mode.

        ``parameterized.expand`` generates one test case per entry of
        ``models.MODELS``; each entry is unpacked into ``model_name`` (not used
        in the body) and ``model``.
        """
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        random_inputs = model.get_random_inputs()

        # Reference result: run the module eagerly on the same inputs.
        eager_output = model(*random_inputs)

        # TODO(ycao): Replace it with capture_multiple
        captured_program = exir.capture(
            model, random_inputs, exir.CaptureConfig()
        )
        # Invoke the captured program directly on the same positional inputs.
        captured_output = captured_program(*random_inputs)

        self.assertTrue(torch.allclose(eager_output, captured_output))