xref: /aosp_15_r20/external/pytorch/torch/distributed/pipelining/_debug.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates
3import torch
4
5
def friendly_debug_info(v):
    """
    Render ``v`` as a short, human-readable string for debug output.

    A ``torch.Tensor`` (or subclass) is summarized by its shape, its
    ``requires_grad`` flag and its dtype instead of its element values;
    any other object falls back to ``str(v)``.
    """
    if not isinstance(v, torch.Tensor):
        return str(v)
    # Summarize the tensor's metadata rather than printing its elements.
    shape, needs_grad, dtype = v.shape, v.requires_grad, v.dtype
    return f"Tensor({shape}, grad={needs_grad}, dtype={dtype})"
14
15
def map_debug_info(a):
    """
    Convert ``a`` into a debug-friendly structure.

    ``a`` may be a list, tuple, or dict. Each leaf value reached by
    ``torch.fx.node.map_aggregate`` is replaced by the string produced by
    ``friendly_debug_info``, so tensors appear as compact metadata summaries.
    """
    aggregate_mapper = torch.fx.node.map_aggregate
    return aggregate_mapper(a, friendly_debug_info)
22