from collections import OrderedDict

import torch
from torch._C import _disabled_torch_function_impl


# Metaclass to combine _TensorMeta and the instance check override for Parameter.
class _ParameterMeta(torch._C._TensorMeta):
    # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag.
    def __instancecheck__(self, instance):
        if self is Parameter:
            if isinstance(instance, torch.Tensor) and getattr(
                instance, "_is_param", False
            ):
                return True
        # Otherwise defer to the standard metaclass check (true subclasses).
        return super().__instancecheck__(instance)


class Parameter(torch.Tensor, metaclass=_ParameterMeta):
    r"""A kind of Tensor that is to be considered a module parameter.

    Parameters are :class:`~torch.Tensor` subclasses, that have a
    very special property when used with :class:`Module` s - when they're
    assigned as Module attributes they are automatically added to the list of
    its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator.
    Assigning a Tensor doesn't have such effect. This is because one might
    want to cache some temporary state, like last hidden state of the RNN, in
    the model. If there was no such class as :class:`Parameter`, these
    temporaries would get registered too.

    Args:
        data (Tensor): parameter tensor.
        requires_grad (bool, optional): if the parameter requires gradient. Note that
            the torch.no_grad() context does NOT affect the default behavior of
            Parameter creation--the Parameter will still have `requires_grad=True` in
            :class:`~no_grad` mode. See :ref:`locally-disable-grad-doc` for more
            details. Default: `True`
    """

    def __new__(cls, data=None, requires_grad=True):
        # Default to an empty placeholder tensor when no data is given.
        if data is None:
            data = torch.empty(0)
        if type(data) is torch.Tensor or type(data) is Parameter:
            # For ease of BC maintenance, keep this path for standard Tensor.
            # Eventually (tm), we should change the behavior for standard Tensor to match.
            return torch.Tensor._make_subclass(cls, data, requires_grad)

        # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
        # The result keeps the custom tensor type; isinstance(t, Parameter) is made
        # true via _ParameterMeta.__instancecheck__ reading the _is_param flag.
        t = data.detach().requires_grad_(requires_grad)
        if type(t) is not type(data):
            raise RuntimeError(
                f"Creating a Parameter from an instance of type {type(data).__name__} "
                "requires that detach() returns an instance of the same type, but return "
                f"type {type(t).__name__} was found instead. To use the type as a "
                "Parameter, please correct the detach() semantics defined by "
                "its __torch_dispatch__() implementation."
            )
        t._is_param = True
        return t

    # Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types
    # are still considered that custom tensor type and these methods will not be called for them.
    def __deepcopy__(self, memo):
        # Return the previously created copy when this object was already
        # deep-copied in the same operation, preserving aliasing.
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(
                self.data.clone(memory_format=torch.preserve_format), self.requires_grad
            )
            memo[id(self)] = result
            return result

    def __repr__(self):
        return "Parameter containing:\n" + super().__repr__()

    def __reduce_ex__(self, proto):
        # Pickle support: rebuild via torch._utils helpers; extra object state
        # (e.g. attributes set on the instance) is serialized separately.
        state = torch._utils._get_obj_state(self)

        # See Note [Don't serialize hooks]
        hooks = OrderedDict()
        if not state:
            return (
                torch._utils._rebuild_parameter,
                (self.data, self.requires_grad, hooks),
            )

        return (
            torch._utils._rebuild_parameter_with_state,
            (self.data, self.requires_grad, hooks, state),
        )

    # Parameters do not override torch functions; dispatch behaves as for Tensor.
    __torch_function__ = _disabled_torch_function_impl
class UninitializedTensorMixin:
    # Tensor methods that are safe to call before materialization; any other
    # torch function on an uninitialized tensor raises in __torch_function__.
    _allowed_methods = [
        torch.Tensor.__hash__,
        torch.Tensor.size,
        torch.Tensor.copy_,
        torch.Tensor.is_complex,
        torch.Tensor.is_floating_point,
        torch.Tensor.half,
        torch.Tensor.float,
        torch.Tensor.double,
        torch.Tensor.char,
        torch.Tensor.short,
        torch.Tensor.int,
        torch.Tensor.long,
        torch.Tensor.cuda,
        torch.Tensor.cpu,
        torch.Tensor.to,
        torch.Tensor.get_device,
        torch._has_compatible_shallow_copy_type,
    ]

    def materialize(self, shape, device=None, dtype=None):
        r"""Create a Parameter or Tensor with the same properties of the uninitialized one.

        Given a shape, it materializes a parameter in the same device
        and with the same `dtype` as the current one or the specified ones in the
        arguments.

        Args:
            shape (tuple): the shape for the materialized tensor.
            device (:class:`torch.device`): the desired device of the parameters
                and buffers in this module. Optional.
            dtype (:class:`torch.dtype`): the desired floating point type of
                the floating point parameters and buffers in this module. Optional.
        """
        if device is None:
            device = self.data.device
        if dtype is None:
            dtype = self.data.dtype
        self.data = torch.empty(shape, device=device, dtype=dtype)
        # Swap in the concrete class (e.g. Parameter) now that storage exists;
        # the instance stops being "uninitialized" from this point on.
        self.__class__ = self.cls_to_become

    @property
    def shape(self):
        # Uninitialized tensors have no meaningful shape; fail loudly instead
        # of exposing the zero-element placeholder's shape.
        raise RuntimeError(
            "Can't access the shape of an uninitialized parameter or buffer. "
            "This error usually happens in `load_state_dict` when trying to load "
            "an uninitialized parameter into an initialized one. "
            "Call `forward` to initialize the parameters before accessing their attributes."
        )

    def share_memory_(self):
        raise RuntimeError(
            "Can't share memory on an uninitialized parameter or buffer. "
            "Call `forward` to initialize the parameters before calling "
            "`module.share_memory()`."
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}>"

    def __reduce_ex__(self, proto):
        # See Note [Don't serialize hooks]
        # Serialization keeps only requires_grad; the placeholder data carries
        # no information worth persisting.
        return (self.__class__, (self.requires_grad,))

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # method-wrapper is to detect access to Tensor properties that are
        # wrapped in descriptors
        if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
            if kwargs is None:
                kwargs = {}
            return super().__torch_function__(func, types, args, kwargs)
        raise ValueError(
            f"Attempted to use an uninitialized parameter in {func}. "
            "This error happens when you are using a `LazyModule` or "
            f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` "
            "objects. When using LazyModules Call `forward` with a dummy batch "
            "to initialize the parameters before calling torch functions"
        )


def is_lazy(param):
    r"""Return whether ``param`` is an uninitialized (lazy) parameter or buffer."""
    return isinstance(param, UninitializedTensorMixin)


class UninitializedParameter(UninitializedTensorMixin, Parameter):
    r"""A parameter that is not initialized.

    Uninitialized Parameters are a special case of :class:`torch.nn.Parameter`
    where the shape of the data is still unknown.

    Unlike a :class:`torch.nn.Parameter`, uninitialized parameters
    hold no data and attempting to access some properties, like their shape,
    will throw a runtime error. The only operations that can be performed on a uninitialized
    parameter are changing its datatype, moving it to a different device and
    converting it to a regular :class:`torch.nn.Parameter`.

    The default device or dtype to use when the parameter is materialized can be set
    during construction using e.g. ``device='cuda'``.
    """

    # Class this instance is converted to by materialize().
    cls_to_become = Parameter

    # Fix: the previous `-> None` return annotation was wrong — __new__ returns
    # the newly created instance, not None.
    def __new__(cls, requires_grad=True, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        # Zero-element placeholder carrying only the default device/dtype.
        data = torch.empty(0, **factory_kwargs)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(self.requires_grad, self.data.device, self.data.dtype)
            memo[id(self)] = result
            return result


# Metaclass to combine _TensorMeta and the instance check override for Buffer.
class _BufferMeta(torch._C._TensorMeta):
    # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag.
    def __instancecheck__(self, instance):
        if self is Buffer:
            if isinstance(instance, torch.Tensor) and getattr(
                instance, "_is_buffer", False
            ):
                return True
        # Otherwise defer to the standard metaclass check (true subclasses).
        return super().__instancecheck__(instance)


class Buffer(torch.Tensor, metaclass=_BufferMeta):
    r"""A kind of Tensor that should not be considered a model
    parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state.

    Buffers are :class:`~torch.Tensor` subclasses, that have a
    very special property when used with :class:`Module` s -- when they're
    assigned as Module attributes they are automatically added to the list of
    its buffers, and will appear e.g. in :meth:`~torch.nn.Module.buffers` iterator.
    Assigning a Tensor doesn't have such effect. One can still assign a Tensor explicitly by using
    the :meth:`~torch.nn.Module.register_buffer` function.

    Args:
        data (Tensor): buffer tensor.
        persistent (bool, optional): whether the buffer is part of the module's
            :attr:`state_dict`. Default: ``True``
    """

    def __new__(cls, data=None, *, persistent=True):
        # Default to an empty placeholder tensor when no data is given.
        if data is None:
            data = torch.empty(0)

        # Detach so the buffer does not share autograd history with `data`,
        # while preserving data's requires_grad setting.
        t = data.detach().requires_grad_(data.requires_grad)
        t.persistent = persistent
        # Flag read by _BufferMeta.__instancecheck__.
        t._is_buffer = True
        return t

    # Buffers do not override torch functions; dispatch behaves as for Tensor.
    __torch_function__ = _disabled_torch_function_impl


class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
    r"""A buffer that is not initialized.

    Uninitialized Buffer is a special case of :class:`torch.Tensor`
    where the shape of the data is still unknown.

    Unlike a :class:`torch.Tensor`, uninitialized parameters
    hold no data and attempting to access some properties, like their shape,
    will throw a runtime error. The only operations that can be performed on a uninitialized
    parameter are changing its datatype, moving it to a different device and
    converting it to a regular :class:`torch.Tensor`.

    The default device or dtype to use when the buffer is materialized can be set
    during construction using e.g. ``device='cuda'``.
    """

    # Class this instance is converted to by materialize().
    cls_to_become = torch.Tensor

    # Fix: the previous `-> None` return annotation was wrong — __new__ returns
    # the newly created instance, not None.
    def __new__(cls, requires_grad=False, device=None, dtype=None, persistent=True):
        factory_kwargs = {"device": device, "dtype": dtype}
        # Zero-element placeholder carrying only the default device/dtype.
        data = torch.empty(0, **factory_kwargs)
        ret = torch.Tensor._make_subclass(cls, data, requires_grad)
        ret.persistent = persistent
        ret._is_buffer = True
        return ret