xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/rpc/parameter_server/trainer/preprocess_data.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
def preprocess_dummy_data(rank, data):
    r"""
    Move DummyData training examples from CPU to the GPU of ``rank``.

    Only the first two entries of every example are moved. These are
    presumably the input and the target; confirm against the DummyData
    loader. Any further entries in an example are left untouched.

    Args:
        rank (int): worker rank. It is also passed as the device argument
            to ``.cuda()``, so it doubles as the CUDA device index.
        data (list): training examples. Each example must support item
            assignment (e.g. a list), and its entries at index 0 and 1 must
            provide a ``.cuda(device)`` method.

    Returns:
        list: ``data`` itself, modified in place.
    """
    for example in data:
        # ``.cuda()`` hands back the moved object, so write the result back
        # into the example instead of discarding it.
        example[0] = example[0].cuda(rank)
        example[1] = example[1].cuda(rank)
    return data
13