xref: /aosp_15_r20/external/pytorch/test/pytest_shard_custom.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker"""
2*da0073e9SAndroid Build Coastguard WorkerCustom pytest shard plugin
3*da0073e9SAndroid Build Coastguard Workerhttps://github.com/AdamGleave/pytest-shard/blob/64610a08dac6b0511b6d51cf895d0e1040d162ad/pytest_shard/pytest_shard.py#L1
4*da0073e9SAndroid Build Coastguard WorkerModifications:
5*da0073e9SAndroid Build Coastguard Worker* shards are now 1 indexed instead of 0 indexed
6*da0073e9SAndroid Build Coastguard Worker* option for printing items in shard
7*da0073e9SAndroid Build Coastguard Worker"""
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport hashlib
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerfrom _pytest.config.argparsing import Parser
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
def pytest_addoptions(parser: Parser):
    """Add the sharding options to the ``shard`` option group of ``parser``.

    Options (each is read back via ``config.getoption(<dest>)``):
        --shard-id (int, default 1): 1-indexed number of the shard to run;
            dest ``shard_id``.
        --num-shards (int, default 1): total number of shards; dest
            ``num_shards``.
        --print-items (flag, default False): also report the node IDs that
            were selected for this shard; dest ``print_items``.

    NOTE(review): pytest's own hook is spelled ``pytest_addoption`` (singular),
    so pytest will not call this function automatically; presumably a conftest
    calls it explicitly -- confirm before renaming it.
    """
    shard_group = parser.getgroup("shard")

    # (flag, keyword arguments) pairs, registered in this order.
    option_table = [
        (
            "--shard-id",
            dict(dest="shard_id", type=int, default=1, help="Number of this shard."),
        ),
        (
            "--num-shards",
            dict(
                dest="num_shards",
                type=int,
                default=1,
                help="Total number of shards.",
            ),
        ),
        (
            "--print-items",
            dict(
                dest="print_items",
                action="store_true",
                default=False,
                help="Print out the items being tested in this shard.",
            ),
        ),
    ]
    for flag, settings in option_table:
        shard_group.addoption(flag, **settings)
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
class PytestShardPlugin:
    """Pytest plugin that narrows a test session to a single shard.

    Each collected item is assigned to a shard by hashing its node ID with
    SHA-256. The assignment therefore depends only on the node ID and is the
    same in every process and on every run. Shard IDs are 1-indexed.

    The ``pytest_*`` methods are hook implementations. They only take effect
    once an instance is registered with pytest's plugin manager, which is not
    done in this class.
    """

    def __init__(self, config):
        # Stored for reference only: the hooks below read their options from
        # the ``config`` argument pytest passes to each hook call.
        self.config = config

    def pytest_report_collectionfinish(self, config, items) -> str:
        """Build the collection-finished report line for this shard.

        Args:
            config: pytest config; its ``print_items`` option decides whether
                the node IDs are appended.
            items: the collected items to report on.

        Returns:
            ``"Running N items in this shard"``. When ``print_items`` is set,
            this is followed by ``": "`` and the node IDs joined with ``", "``.
        """
        summary = f"Running {len(items)} items in this shard"
        if not config.getoption("print_items"):
            return summary
        listing = ", ".join(item.nodeid for item in items)
        return f"{summary}: {listing}"

    def sha256hash(self, x: str) -> int:
        """Map ``x`` to a stable non-negative integer.

        ``x`` is UTF-8 encoded and hashed with SHA-256. The 32-byte digest is
        then read as a little-endian unsigned integer. Unlike the built-in
        ``hash`` (which is salted per process for ``str``), the result does not
        change between interpreter runs.
        """
        digest = hashlib.sha256(x.encode()).digest()
        return int.from_bytes(digest, byteorder="little")

    def filter_items_by_shard(self, items, shard_id: int, num_shards: int):
        """Return the items of ``items`` that belong to shard ``shard_id``.

        An item belongs to shard ``shard_id`` (1-indexed) when the SHA-256
        hash of its node ID, modulo ``num_shards``, equals ``shard_id - 1``.
        Input order is preserved.

        ``shard_id`` is not range-checked here. With a positive
        ``num_shards``, an out-of-range ``shard_id`` simply selects nothing.
        A ``num_shards`` of 0 raises ``ZeroDivisionError`` for non-empty
        ``items``.
        """

        def in_this_shard(item) -> bool:
            # Shard IDs are 1-indexed while hash residues start at 0.
            return self.sha256hash(item.nodeid) % num_shards == shard_id - 1

        return list(filter(in_this_shard, items))

    def pytest_collection_modifyitems(self, config, items):
        """Shrink ``items`` in place to the tests assigned to this shard.

        Reads the ``shard_id`` and ``num_shards`` options from ``config``.

        Raises:
            ValueError: if ``shard_id`` is not within ``1..num_shards``.
        """
        requested = config.getoption("shard_id")
        total = config.getoption("num_shards")
        if not 1 <= requested <= total:
            raise ValueError(
                f"{requested} is not a valid shard ID out of {total} total shards"
            )
        # Slice assignment mutates the list pytest passed in. Rebinding the
        # name instead would leave pytest's collection untouched.
        items[:] = self.filter_items_by_shard(items, requested, total)
69