# mypy: allow-untyped-defs
"""
This file contains helper functions that implement experimental functionality
for named tensors in python. All of these are experimental, unstable, and
subject to change or deletion.
"""

from collections import OrderedDict


def check_serializing_named_tensor(tensor):
    """Raise if `tensor` carries dimension names.

    Serialization of named tensors is not implemented, so callers must drop
    the names first.

    Raises:
        RuntimeError: if ``tensor.has_names()`` is true.
    """
    if tensor.has_names():
        raise RuntimeError(
            "NYI: Named tensors don't support serialization. Please drop "
            "names via `tensor = tensor.rename(None)` before serialization."
        )


def build_dim_map(tensor):
    """Return an ordered map ``{dim: dim_name}`` covering every dim of `tensor`.

    The key is the dim's name when the dim is named, and its positional index
    otherwise. The value is the name as stored in ``tensor.names``, which is
    ``None`` for unnamed dims. Insertion order follows dimension order.
    """
    return OrderedDict(
        [(idx if name is None else name, name) for idx, name in enumerate(tensor.names)]
    )


def unzip_namedshape(namedshape):
    """Split a named shape into ``(names, sizes)``.

    Args:
        namedshape: an ``OrderedDict`` mapping name -> size, or any iterable
            of ``(name, size)`` pairs (tuple, list, generator, ...).

    Returns:
        A ``zip`` object yielding the tuple of names, then the tuple of sizes.

    Raises:
        RuntimeError: if `namedshape` is not iterable or is empty.
    """
    if isinstance(namedshape, OrderedDict):
        namedshape = namedshape.items()
    if not hasattr(namedshape, "__iter__"):
        raise RuntimeError(
            f"Expected namedshape to be OrderedDict or iterable of tuples, got: {type(namedshape)}"
        )
    # Materialize so that one-shot iterables (e.g. generators) can be both
    # length-checked and unpacked; `len()` alone would raise TypeError on them.
    namedshape = tuple(namedshape)
    if len(namedshape) == 0:
        raise RuntimeError("Expected namedshape to be non-empty.")
    return zip(*namedshape)


def namer_api_name(inplace):
    """Return the user-facing API name, used as a prefix in error messages."""
    return "rename_" if inplace else "rename"


def is_ellipsis(item):
    """Return True if `item` is the ``Ellipsis`` object or the string ``'...'``."""
    return item is Ellipsis or item == "..."


def single_ellipsis_index(names, fn_name):
    """Return the index of the single ellipsis in `names`, or None if absent.

    Raises:
        RuntimeError: if `names` contains more than one ellipsis.
    """
    ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)]
    if len(ellipsis_indices) >= 2:
        raise RuntimeError(
            f"{fn_name}: More than one Ellipsis ('...') found in names ("
            f"{names}). This function supports up to one Ellipsis."
        )
    if len(ellipsis_indices) == 1:
        return ellipsis_indices[0]
    return None


def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names):
    """Return the run of `names` that an ellipsis stands for.

    The ellipsis covers everything after the first `numel_pre_glob` entries
    and before the last `numel_post_glob` entries of `names`. If those two
    counts together exceed ``len(names)``, the result is empty.
    """
    # Clamp the stop index: without it, a negative value would make Python
    # slice relative to the end and glob unrelated names.
    stop = max(numel_pre_glob, len(names) - numel_post_glob)
    return names[numel_pre_glob:stop]


def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names):
    """Return `names` with the entry at `ellipsis_idx` replaced by the
    corresponding run of `tensor_names`."""
    globbed_names = expand_single_ellipsis(
        ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names
    )
    return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1 :]


def resolve_ellipsis(names, tensor_names, fn_name):
    """
    Expands ... inside `names` to be equal to a list of names from `tensor_names`.

    `names` is returned unchanged when it contains no ellipsis. `fn_name` is
    used only to prefix error messages.
    """
    ellipsis_idx = single_ellipsis_index(names, fn_name)
    if ellipsis_idx is None:
        return names
    return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names)


def update_names_with_list(tensor, names, inplace):
    """Apply positional `names` to `tensor` (``tensor.rename(*names)``).

    ``(None,)`` clears all names; otherwise any ellipsis is expanded against
    ``tensor.names`` first. Delegates to ``tensor._update_names``.
    """
    # Special case for tensor.rename(None)
    if len(names) == 1 and names[0] is None:
        return tensor._update_names(None, inplace)

    return tensor._update_names(
        resolve_ellipsis(names, tensor.names, namer_api_name(inplace)), inplace
    )


def update_names_with_mapping(tensor, rename_map, inplace):
    """Rename dims of `tensor` per `rename_map` (``tensor.rename(**rename_map)``).

    Keys of `rename_map` are matched against the keys of `build_dim_map`
    (a dim's name, or its index when unnamed). Dims that are not mentioned
    keep their current name.

    Raises:
        RuntimeError: if a key of `rename_map` matches no dim of `tensor`.
    """
    dim_map = build_dim_map(tensor)
    for old_dim, new_dim in rename_map.items():
        if old_dim not in dim_map:
            raise RuntimeError(
                f"{namer_api_name(inplace)}: Tried to rename dim '{old_dim}' to dim "
                f"{new_dim} in Tensor[{tensor.names}] but dim '{old_dim}' does not exist"
            )
        dim_map[old_dim] = new_dim
    return tensor._update_names(tuple(dim_map.values()), inplace)


def update_names(tensor, names, rename_map, inplace):
    """There are two usages:

    tensor.rename(*names) returns a view on tensor with named dims `names`.
    `names` must be of length `tensor.dim()`; otherwise, if '...' is in `names`,
    then it is expanded greedily to be equal to the corresponding names from
    `tensor.names`.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename('...', 'height', 'width').names
    ('N', 'C', 'height', 'width')

    >>> # xdoctest: +SKIP
    >>> x.rename('batch', '...', 'width').names
    ('batch', 'C', 'H', 'width')

    ```

    tensor.rename(**rename_map) returns a view on tensor that has rename dims
    as specified in the mapping `rename_map`.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename(W='width', H='height').names
    ('N', 'C', 'height', 'width')

    ```

    Finally, tensor.rename has an in-place version called tensor.rename_.

    Raises:
        RuntimeError: if both positional `names` and a non-empty `rename_map`
            are given.
    """
    has_names = len(names) > 0
    has_rename_pairs = bool(rename_map)
    if has_names and has_rename_pairs:
        api_name = namer_api_name(inplace)
        raise RuntimeError(
            f"{api_name}: This function takes either positional "
            f"args or keyword args, but not both. Use tensor.{api_name}(*names) "
            f"to name dims and tensor.{api_name}(**rename_map) to rename "
            "dims."
        )

    # With neither names nor a mapping (tensor.rename(*[]), valid for a 0-dim
    # tensor), the empty `names` goes through the positional path as well.
    if has_names or not has_rename_pairs:
        return update_names_with_list(tensor, names, inplace)
    return update_names_with_mapping(tensor, rename_map, inplace)