"""Custom pytest shard plugin.

Adapted from:
https://github.com/AdamGleave/pytest-shard/blob/64610a08dac6b0511b6d51cf895d0e1040d162ad/pytest_shard/pytest_shard.py#L1

Modifications:
* shards are now 1 indexed instead of 0 indexed
* option for printing items in shard
"""

import hashlib

from _pytest.config.argparsing import Parser


def pytest_addoptions(parser: Parser):
    """Register the sharding command-line options on ``parser``.

    Three options are added to the ``shard`` option group, in this order:
    ``--shard-id`` (int, default 1), ``--num-shards`` (int, default 1) and
    ``--print-items`` (boolean flag, default False).
    """
    shard_group = parser.getgroup("shard")
    # Each entry is (flag, keyword arguments); registration order is kept.
    option_specs = (
        (
            "--shard-id",
            {
                "dest": "shard_id",
                "type": int,
                "default": 1,
                "help": "Number of this shard.",
            },
        ),
        (
            "--num-shards",
            {
                "dest": "num_shards",
                "type": int,
                "default": 1,
                "help": "Total number of shards.",
            },
        ),
        (
            "--print-items",
            {
                "dest": "print_items",
                "action": "store_true",
                "default": False,
                "help": "Print out the items being tested in this shard.",
            },
        ),
    )
    for flag, settings in option_specs:
        shard_group.addoption(flag, **settings)


class PytestShardPlugin:
    """Hooks and helpers that narrow a collected test list to a single shard."""

    def __init__(self, config):
        """Keep a reference to the ``config`` object passed in."""
        self.config = config

    def pytest_report_collectionfinish(self, config, items) -> str:
        """Return a summary line for the items in this shard.

        The line always states the item count. When the ``print_items``
        option is truthy, the comma-separated node ids are appended.
        """
        summary = f"Running {len(items)} items in this shard"
        if not config.getoption("print_items"):
            return summary
        node_ids = ", ".join(item.nodeid for item in items)
        return f"{summary}: {node_ids}"

    def sha256hash(self, x: str) -> int:
        """Return the SHA-256 digest of ``x`` read as a little-endian integer."""
        digest = hashlib.sha256(x.encode()).digest()
        return int.from_bytes(digest, byteorder="little")

    def filter_items_by_shard(self, items, shard_id: int, num_shards: int):
        """Return the items belonging to ``shard_id`` out of ``num_shards``.

        An item is kept when the hash of its ``nodeid`` modulo ``num_shards``
        equals ``shard_id - 1`` (shard ids are 1-indexed). The input order is
        preserved and a new list is returned.
        """
        selected = []
        for item in items:
            bucket = self.sha256hash(item.nodeid) % num_shards
            if bucket == shard_id - 1:
                selected.append(item)
        return selected

    def pytest_collection_modifyitems(self, config, items):
        """Reduce ``items`` in place to the tests assigned to this shard.

        Raises:
            ValueError: if ``shard_id`` is not within ``1..num_shards``.
        """
        shard_id = config.getoption("shard_id")
        shard_total = config.getoption("num_shards")
        if not 1 <= shard_id <= shard_total:
            raise ValueError(
                f"{shard_id} is not a valid shard ID out of {shard_total} total shards"
            )

        # Slice assignment mutates the caller's list rather than rebinding it.
        items[:] = self.filter_items_by_shard(items, shard_id, shard_total)