# xref: /aosp_15_r20/external/pytorch/tools/testing/test_run.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from __future__ import annotations

from copy import copy
from functools import total_ordering
from typing import Any, Iterable


class TestRun:
    """
    TestRun defines the set of tests that should be run together in a single pytest invocation.
    It'll either be a whole test file or a subset of a test file.

    This class assumes that we won't always know the full set of TestClasses in a test file.
    So it's designed to include or exclude explicitly requested TestClasses, while accepting
    that there will be an ambiguous set of "unknown" test classes that are not explicitly called out.
    Those manifest as tests that haven't been explicitly excluded.

    A run takes one of these shapes:

    * empty: ``test_file`` is ``""``; nothing runs. It is the identity for ``|``.
    * full file: no included and no excluded names; every test in the file runs.
    * inclusion: only the tests named in the included set run.
    * exclusion: every test except those named in the excluded set runs.

    The constructor guarantees the included and excluded sets are never both non-empty.
    """

    test_file: str
    _excluded: frozenset[str]  # Tests that should be excluded from this test run
    _included: frozenset[str]  # If non-empty, only these tests should be run in this test run

    def __init__(
        self,
        name: str,
        excluded: Iterable[str] | None = None,
        included: Iterable[str] | None = None,
    ) -> None:
        """
        Build a run for a single test file.

        Args:
            name: Path of the test file. The form ``"file::TestClass"`` selects just
                that class by adding it to the included set. ``""`` builds an empty run.
            excluded: Test names to skip. Mutually exclusive with ``included``.
            included: The only test names to run. Mutually exclusive with ``excluded``.

        Raises:
            ValueError: If both ``excluded`` and ``included`` contain names.
            AssertionError: If ``name`` uses the ``::`` form and a filter is also given.
        """
        # Materialize before testing emptiness: an iterator/generator is always
        # truthy, so checking the raw arguments would treat an empty iterator as
        # "specified" and raise spuriously.
        ins = set(included) if included is not None else set()
        exs = set(excluded) if excluded is not None else set()

        if ins and exs:
            raise ValueError("Can't specify both included and excluded")

        if "::" in name:
            assert (
                not ins and not exs
            ), "Can't specify included or excluded tests when specifying a test class in the file name"
            self.test_file, test_class = name.split("::")
            ins.add(test_class)
        else:
            self.test_file = name

        self._excluded = frozenset(exs)
        self._included = frozenset(ins)

    @staticmethod
    def empty() -> TestRun:
        """Return a run with no test file, i.e. one that runs nothing."""
        return TestRun("")

    def is_empty(self) -> bool:
        # Lack of a test_file means that this is an empty run,
        # which means there is nothing to run. It's the zero.
        return not self.test_file

    def is_full_file(self) -> bool:
        """Return True when no include/exclude filter narrows the run."""
        return not self._included and not self._excluded

    def included(self) -> frozenset[str]:
        """Return the names this run is restricted to (empty if unrestricted)."""
        return self._included

    def excluded(self) -> frozenset[str]:
        """Return the names this run skips."""
        return self._excluded

    def get_pytest_filter(self) -> str:
        """
        Return a pytest ``-k`` expression selecting this run's tests, or ``""``
        when there is no filter. Names are sorted so the output is deterministic.
        """
        if self._included:
            return " or ".join(sorted(self._included))
        elif self._excluded:
            return f"not ({' or '.join(sorted(self._excluded))})"
        else:
            return ""

    def contains(self, test: TestRun) -> bool:
        """Return True if every test that ``test`` runs is also run by ``self``."""
        if self.test_file != test.test_file:
            return False

        if self.is_full_file():
            return True  # self contains all tests

        if test.is_full_file():
            return False  # test contains all tests, but self doesn't

        if test._excluded:
            # test runs everything outside test._excluded, including tests nobody
            # named. An inclusion-style self can never cover that; an
            # exclusion-style self covers it only if everything self skips is
            # also skipped by test, i.e. self excludes a subset of what test excludes.
            if self._included:
                return False
            return self._excluded.issubset(test._excluded)

        # From here on, test is inclusion-style.
        # Does self include everything test includes?
        if self._included:
            return test._included.issubset(self._included)

        # Getting to here means that test includes and self excludes.
        # Does self exclude anything test includes? If not, we're good
        return not self._excluded.intersection(test._included)

    def __copy__(self) -> TestRun:
        return TestRun(self.test_file, excluded=self._excluded, included=self._included)

    def __bool__(self) -> bool:
        return not self.is_empty()

    def __repr__(self) -> str:
        r: str = f"RunTest({self.test_file}"
        r += f", included: {self._included}" if self._included else ""
        r += f", excluded: {self._excluded}" if self._excluded else ""
        r += ")"
        return r

    def __str__(self) -> str:
        if self.is_empty():
            return "Empty"

        pytest_filter = self.get_pytest_filter()
        if pytest_filter:
            return self.test_file + ", " + pytest_filter
        return self.test_file

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TestRun):
            return False

        ret = self.test_file == other.test_file
        ret = ret and self._included == other._included
        ret = ret and self._excluded == other._excluded
        return ret

    def __hash__(self) -> int:
        # Hashes exactly the fields compared by __eq__.
        return hash((self.test_file, self._included, self._excluded))

    def __or__(self, other: TestRun) -> TestRun:
        """
        To OR/Union test runs means to run all the tests that either of the two runs specify.

        Raises:
            AssertionError: If both runs are non-empty and target different test files.
        """

        # An empty run is the identity for union.
        if self.is_empty():
            return copy(other)
        if other.is_empty():
            return copy(self)

        # If not, ensure we have the same file
        assert (
            self.test_file == other.test_file
        ), f"Can't union {other} with {self} because they're not the same test file"

        # 4 possible cases:

        # 1. Either file is the full file, so union is everything
        if self.is_full_file() or other.is_full_file():
            # The union is the whole file
            return TestRun(self.test_file)

        # 2. Both files only run what's in _included, so union is the union of the two sets
        if self._included and other._included:
            return TestRun(
                self.test_file, included=self._included.union(other._included)
            )

        # 3. Both files only exclude what's in _excluded, so union is the intersection of the two sets
        if self._excluded and other._excluded:
            return TestRun(
                self.test_file, excluded=self._excluded.intersection(other._excluded)
            )

        # 4. One file includes and the other excludes, so we then continue excluding the _excluded set minus
        #    whatever is in the _included set
        included = self._included | other._included
        excluded = self._excluded | other._excluded
        return TestRun(self.test_file, excluded=excluded - included)

    def __sub__(self, other: TestRun) -> TestRun:
        """
        To subtract test runs means to run all the tests in the first run except for what the second run specifies.

        Returns a copy of ``self`` when ``other`` is empty or targets a different
        file, and an empty run when nothing is left.
        """

        # Is any file empty?
        if self.is_empty():
            return TestRun.empty()
        if other.is_empty():
            return copy(self)

        # Are you trying to subtract tests that don't even exist in this test run?
        if self.test_file != other.test_file:
            return copy(self)

        # You're subtracting everything?
        if other.is_full_file():
            return TestRun.empty()

        def return_inclusions_or_empty(inclusions: frozenset[str]) -> TestRun:
            # An empty inclusion set means nothing remains to run.
            if inclusions:
                return TestRun(self.test_file, included=inclusions)
            return TestRun.empty()

        if other._included:
            if self._included:
                return return_inclusions_or_empty(self._included - other._included)
            else:
                # self is full-file or exclusion-style: additionally skip other's tests.
                return TestRun(
                    self.test_file, excluded=self._excluded | other._included
                )
        else:
            # other is exclusion-style, so it runs everything except other._excluded;
            # what survives the subtraction is at most other._excluded.
            if self._included:
                return return_inclusions_or_empty(self._included & other._excluded)
            else:
                return return_inclusions_or_empty(other._excluded - self._excluded)

    def __and__(self, other: TestRun) -> TestRun:
        """
        Return the tests run by both ``self`` and ``other``.

        Uses A & B == (A | B) - (A - B) - (B - A) so that only union and difference
        need to reason about the include/exclude representations.
        """
        if self.test_file != other.test_file:
            return TestRun.empty()

        return (self | other) - (self - other) - (other - self)

    def to_json(self) -> dict[str, Any]:
        """
        Serialize to a dict accepted by ``from_json``. Name lists are sorted so
        the output is deterministic (frozenset iteration order is not).
        """
        r: dict[str, Any] = {
            "test_file": self.test_file,
        }
        if self._included:
            r["included"] = sorted(self._included)
        if self._excluded:
            r["excluded"] = sorted(self._excluded)
        return r

    @staticmethod
    def from_json(json: dict[str, Any]) -> TestRun:
        """Inverse of ``to_json``; missing ``included``/``excluded`` keys mean no filter."""
        return TestRun(
            json["test_file"],
            included=json.get("included", []),
            excluded=json.get("excluded", []),
        )
232
233
@total_ordering
class ShardedTest:
    """
    A TestRun assigned to shard ``shard`` of ``num_shards``, optionally with its
    duration in seconds.

    Instances sort by test file name, then shard, then num_shards, then time, with
    an unknown (None) time sorting first. Ordering looks only at the test file name,
    while equality compares the whole TestRun including its filters.
    """

    test: TestRun
    shard: int
    num_shards: int
    time: float | None  # In seconds

    def __init__(
        self,
        test: TestRun | str,
        shard: int,
        num_shards: int,
        time: float | None = None,
    ) -> None:
        """
        Args:
            test: The run to shard; a plain string is wrapped in ``TestRun``.
            shard: Shard number this test is assigned to.
            num_shards: Total number of shards.
            time: Duration in seconds, or None if unknown.
        """
        if isinstance(test, str):
            test = TestRun(test)
        self.test = test
        self.shard = shard
        self.num_shards = num_shards
        self.time = time

    @property
    def name(self) -> str:
        """The test file this shard runs."""
        return self.test.test_file

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ShardedTest):
            return False
        return (
            self.test == other.test
            and self.shard == other.shard
            and self.num_shards == other.num_shards
            and self.time == other.time
        )

    def __repr__(self) -> str:
        ret = f"{self.test} {self.shard}/{self.num_shards}"
        # Compare against None so that a recorded time of 0.0 is still shown.
        if self.time is not None:
            ret += f" ({self.time}s)"

        return ret

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, ShardedTest):
            # Let Python try the reflected operation and raise TypeError;
            # functools.total_ordering also relies on this sentinel.
            return NotImplemented

        # This is how the list was implicitly sorted when it was a NamedTuple
        if self.name != other.name:
            return self.name < other.name
        if self.shard != other.shard:
            return self.shard < other.shard
        if self.num_shards != other.num_shards:
            return self.num_shards < other.num_shards

        # None is the smallest value, but None is not less than None:
        # a strict ordering must never report x < x.
        if self.time is None:
            return other.time is not None
        if other.time is None:
            return False
        return self.time < other.time

    def __str__(self) -> str:
        return f"{self.test} {self.shard}/{self.num_shards}"

    def get_time(self, default: float = 0) -> float:
        """Return the known time, or ``default`` when it is None."""
        return self.time if self.time is not None else default

    def get_pytest_args(self) -> list[str]:
        """Return ``["-k", <filter>]`` for this shard's run, or ``[]`` if it has no filter."""
        pytest_filter = self.test.get_pytest_filter()
        if pytest_filter:
            return ["-k", pytest_filter]
        return []
306