"""
Python polyfills for common builtins.
"""

# NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports.
#       2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py.
#          Add it in the TYPE_CHECKING block below as well.

# mypy: allow-untyped-defs

from typing import Any, Callable, Sequence, TYPE_CHECKING

import torch


if TYPE_CHECKING:
    # Loaded by torch._dynamo.polyfills.loader
    # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py
    # Put the submodules here to avoid circular imports
    from . import (
        builtins as builtins,
        functools as functools,
        itertools as itertools,
        os as os,
        sys as sys,
    )


def index(iterator, item, start=0, end=None):
    """Return the position of the first element equal to ``item``.

    Only positions in ``[start, end)`` are searched; ``end=None`` means "until
    the input is exhausted". For inputs that define ``__len__``, negative
    bounds are interpreted relative to the end of the input and clamped at 0,
    matching ``list.index`` / ``tuple.index``. Inputs without a length are
    handed to ``islice`` unchanged, so negative bounds still raise
    ``ValueError`` for them.

    Raises:
        ValueError: if no element in the searched range equals ``item``.
    """
    from itertools import islice

    # islice() rejects negative bounds, so translate them up front when the
    # length is known. The common path (start=0, end=None) skips this entirely.
    if ((start is not None and start < 0) or (end is not None and end < 0)) and hasattr(
        iterator, "__len__"
    ):
        size = len(iterator)
        if start is not None and start < 0:
            start = max(start + size, 0)
        if end is not None and end < 0:
            end = max(end + size, 0)

    for i, elem in islice(enumerate(iterator), start, end):
        if item == elem:
            return i
    # This will not run in dynamo
    raise ValueError(f"{item} is not in {type(iterator)}")


def repeat(item, count):
    """Yield ``item`` once per element of ``range(count)``."""
    for _ in range(count):
        yield item


def radians(x):
    """Convert ``x`` from degrees to radians by multiplying with ``pi / 180``."""
    import math

    return math.pi / 180.0 * x


def accumulate_grad(x, new_grad):
    """Accumulate a clone of ``new_grad`` into ``x.grad``.

    The clone becomes ``x.grad`` when no gradient is set yet; otherwise it is
    added to the existing gradient in place. ``new_grad`` itself is never
    stored or mutated.
    """
    new_grad = torch.clone(new_grad)
    if x.grad is None:
        x.grad = new_grad
    else:
        x.grad.add_(new_grad)


def list_cmp(
    op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]
):
    """emulate `(1,2,3) > (1,2)` etc"""
    # Apply `op` to the first pair of elements that differ.
    for a, b in zip(left, right):
        if a != b:
            return op(a, b)
    # One sequence is a prefix of the other (or they are equal), so the
    # comparison is decided by the lengths.
    return op(len(left), len(right))


def set_isdisjoint(set1, set2):
    """Return True when no element of ``set1`` is contained in ``set2``."""
    for x in set1:
        if x in set2:
            return False
    return True


def set_intersection(set1, set2):
    """Return a new ``set`` with the elements of ``set1`` that are in ``set2``."""
    intersection_set = set()
    for x in set1:
        if x in set2:
            intersection_set.add(x)
    return intersection_set


def set_union(set1, set2):
    """Return a copy of ``set1`` extended with the elements of ``set2``.

    The result is built from ``set1.copy()``, so neither input is mutated.
    """
    union_set = set1.copy()
    for x in set2:
        if x not in union_set:
            union_set.add(x)
    return union_set


def set_difference(set1, set2):
    """Return a new ``set`` with the elements of ``set1`` that are not in ``set2``."""
    difference_set = set()
    for x in set1:
        if x not in set2:
            difference_set.add(x)
    return difference_set


def dropwhile(predicate, iterable):
    """Skip leading elements while ``predicate`` is true, then yield the rest."""
    # dropwhile(lambda x: x<5, [1,4,6,4,1]) -> 6 4 1
    iterable = iter(iterable)
    for x in iterable:
        if not predicate(x):
            # First element that fails the predicate: emit it and stop testing.
            yield x
            break
    # Everything after that element is passed through untested.
    yield from iterable


def zip_longest(*iterables, fillvalue=None):
    """Zip ``iterables`` together, padding exhausted ones with ``fillvalue``.

    Unlike ``itertools.zip_longest``, the rows are collected eagerly and
    returned as a list of tuples rather than produced lazily.
    """
    # Create a list of iterators from the input iterables
    iterators = [iter(it) for it in iterables]
    result = []
    while True:
        row = []
        # Becomes True as soon as any iterator contributes a real value.
        active = False
        for it in iterators:
            try:
                # Try to get the next item from the iterator
                value = next(it)
                row.append(value)
                active = True
            except StopIteration:
                # If the iterator is exhausted, use the fillvalue
                row.append(fillvalue)
        # A row made only of fill values means every iterator is exhausted.
        if not active:
            break
        result.append(tuple(row))
    return result


def getattr_and_trace(*args, **kwargs):
    """Look up ``args[1]`` on ``args[0]`` and call it with the remaining arguments.

    ``args[0]`` is the object, ``args[1]`` the attribute name; ``args[2:]`` and
    ``kwargs`` are forwarded to the attribute, whose result is returned.
    """
    wrapper_obj = args[0]
    attr_name = args[1]
    fn = getattr(wrapper_obj, attr_name)
    return fn(*args[2:], **kwargs)


def mapping_get(obj, key, value=None):
    """Return ``obj[key]``, or ``value`` when the lookup raises ``KeyError``."""
    try:
        return obj.__getitem__(key)
    except KeyError:
        return value


def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
    """Create an instance of ``cls`` by calling ``__new__`` and then ``__init__``.

    ``__init__`` is skipped when ``__new__`` returns something that is not an
    instance of ``cls``; that object is returned as is.
    """
    obj = cls.__new__(cls, *args, **kwargs)

    # Only call __init__ if the object is an instance of the class
    # Reference: https://github.com/python/cpython/blob/3.12/Objects/typeobject.c#L1670-L1673
    if isinstance(obj, cls):
        obj.__init__(*args, **kwargs)
    return obj


def foreach_lerp_inplace(self, end, weight):
    """In-place foreach lerp: add ``(end - self) * weight`` to ``self``.

    Returns whatever ``torch._foreach_add_`` returns.
    """
    # decompose foreach lerp into constituent ops, prevents a graph break due to
    # converting a value to a scalar when arg[2] is a single tensor
    result = torch._foreach_sub(end, self)
    result = torch._foreach_mul(result, weight)
    return torch._foreach_add_(self, result)


def foreach_pow_scalar(scalar, exps):
    """Call ``torch._foreach_pow`` with one copy of ``scalar`` per entry of ``exps``."""
    return torch._foreach_pow([scalar for _ in exps], exps)


def addcmul_inplace(self, tensor1, tensor2, value):
    """Add ``tensor1 * tensor2 * value`` to ``self`` in place via ``add_``."""
    return self.add_(tensor1 * tensor2 * value)