xref: /aosp_15_r20/external/executorch/exir/tests/control_flow_models.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import torch
8from torch.nn import Module  # @manual
9
10
11class FTCondBasic(Module):
12    def __init__(self):
13        super().__init__()
14
15    def forward(self, inp):
16        def true_branch(x):
17            return x + x
18
19        def false_branch(x):
20            return x * x
21
22        return torch.ops.higher_order.cond(
23            inp.sum() > 4, true_branch, false_branch, [inp]
24        )
25
26    def get_random_inputs(self):
27        return (torch.rand(5),)
28
29
30class FTCondDynShape(Module):
31    def __init__(self):
32        super().__init__()
33
34    def forward(self, inp):
35        def true_branch(x):
36            return x + x + x
37
38        def false_branch(x):
39            return x * x * x
40
41        return torch.ops.higher_order.cond(
42            inp.sum() > 4, true_branch, false_branch, [inp]
43        )
44
45    def get_upper_bound_inputs(self):
46        return (torch.rand(8),)
47
48    def get_random_inputs(self):
49        return (torch.rand(5),)
50
51
52class FTCondDeadCode(Module):
53    """
54    A toy model used to test DCE on sub modules.
55
56    The graph generated for torch.inverse will contain a node:
57      torch.ops.aten._linalg_check_errors.default
58    to check for errors. There are no out variants for this op and executorch
59    runtime does not support it. For now, we simply erase this node by DCE
60    since the Fx code does not consider this node as having side effect.
61    """
62
63    def __init__(self):
64        super().__init__()
65
66    def forward(self, inp):
67        def true_branch(x):
68            x - 1
69            return x + 1
70
71        def false_branch(x):
72            return x * 2
73
74        return torch.ops.higher_order.cond(
75            inp.sum() > 4, true_branch, false_branch, [inp]
76        )
77
78    def get_random_inputs(self):
79        return (torch.eye(5) * 2,)
80
81
class FTMapBasic(Module):
    """Module that runs an addition body through the higher-order ``map`` op.

    ``forward`` hands ``xs`` and ``y`` to ``torch.ops.higher_order.map`` with
    a body computing ``x + y`` and then adds ``xs`` to the mapped result.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, xs, y):
        """Return ``map(body, xs, y) + xs`` where ``body(x, y)`` is ``x + y``."""

        def add_pair(item, extra):
            return item + extra

        mapped = torch.ops.higher_order.map(add_pair, xs, y)
        return mapped + xs

    def get_random_inputs(self):
        """Return random tensors of shapes ``(2, 4)`` and ``(4,)`` as a tuple."""
        xs = torch.rand(2, 4)
        y = torch.rand(4)
        return xs, y
94
95
class FTMapDynShape(Module):
    """``map``-based module that ships two differently sized sample inputs.

    ``forward`` hands ``xs`` and ``y`` to ``torch.ops.higher_order.map`` with
    a body computing ``x + y`` and then adds ``xs`` to the mapped result.

    NOTE(review): the ``(4, 4)`` "upper bound" input next to the ``(2, 4)``
    random input is presumably consumed by dynamic-shape tests; confirm
    against the callers of this model.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, xs, y):
        """Return ``map(body, xs, y) + xs`` where ``body(x, y)`` is ``x + y``."""

        def add_pair(item, extra):
            return item + extra

        mapped = torch.ops.higher_order.map(add_pair, xs, y)
        return mapped + xs

    def get_upper_bound_inputs(self):
        """Return random tensors of shapes ``(4, 4)`` and ``(4,)`` as a tuple."""
        xs = torch.rand(4, 4)
        y = torch.rand(4)
        return xs, y

    def get_random_inputs(self):
        """Return random tensors of shapes ``(2, 4)`` and ``(4,)`` as a tuple."""
        xs = torch.rand(2, 4)
        y = torch.rand(4)
        return xs, y
111