# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch


def friendly_debug_info(v):
    """
    Render ``v`` as a short, human-readable string for debug output.

    A ``torch.Tensor`` is summarized by its shape, ``requires_grad`` flag and
    dtype rather than its contents. Any other value falls back to ``str(v)``.
    """
    # Non-tensors need no special formatting; handle them first.
    if not isinstance(v, torch.Tensor):
        return str(v)
    return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})"


def map_debug_info(a):
    """
    Apply ``friendly_debug_info`` to the items of ``a``.

    ``a`` may be a list, tuple, or dict. The traversal is delegated to
    ``torch.fx.node.map_aggregate``, which rebuilds the container around the
    converted items. A non-container ``a`` is passed to the formatter directly.
    """
    formatter = friendly_debug_info
    return torch.fx.node.map_aggregate(a, formatter)