# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import executorch.exir as exir
import executorch.exir.tests.models as models
import torch

from parameterized import parameterized


class TestCapture(unittest.TestCase):
    """Tests for ``exir.capture`` over the shared ``models.MODELS`` suite."""

    # pyre-ignore
    @parameterized.expand(models.MODELS)
    def test_module_call(self, model_name: str, model: torch.nn.Module) -> None:
        """Capture ``model`` and check the captured program matches eager output.

        The model is first run eagerly on random inputs, then captured with
        ``exir.capture`` using a default ``CaptureConfig``; the captured
        program is called on the same inputs and its result is compared to
        the eager result with ``torch.allclose``.

        Args:
            model_name: Identifier of the parameterized case; used in the
                failure message so a divergence points at the offending model.
            model: Module under test; it must provide ``get_random_inputs()``.
        """
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        inputs = model.get_random_inputs()
        expected = model(*inputs)
        # TODO(ycao): Replace it with capture_multiple
        exported_program = exir.capture(model, inputs, exir.CaptureConfig())

        self.assertTrue(
            torch.allclose(expected, exported_program(*inputs)),
            msg=f"captured output of {model_name!r} diverges from eager output",
        )