# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from typing import Callable, Dict

import torch
from torch._decomp import decomposition_table
from torch.utils._pytree import tree_map_only


HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}

aten = torch._ops.ops.aten
# __torch_function__ runs before the pydispatcher, so we need to manually use the same
# decompositions, indexed by their torch-level equivalents
expanded_weights_rnn_decomps = {
    # func: (input_decomp, data_decomp)
    torch.rnn_relu: (
        decomposition_table[aten.rnn_relu.input],
        decomposition_table[aten.rnn_relu.data],
    ),
    torch.rnn_tanh: (
        decomposition_table[aten.rnn_tanh.input],
        decomposition_table[aten.rnn_tanh.data],
    ),
    torch.lstm: (
        decomposition_table[aten.lstm.input],
        decomposition_table[aten.lstm.data],
    ),
    torch.gru: (
        decomposition_table[aten.gru.input],
        decomposition_table[aten.gru.data],
    ),
}


# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set
@contextmanager
def batch_second(args, kwargs):
    def set_batch_second(ew):
        ew.set_batch_first(False)

    def reset_batch_first(ew):
        ew.set_batch_first(True)

    tree_map_only(ExpandedWeight, set_batch_second, args)
    tree_map_only(ExpandedWeight, set_batch_second, kwargs)
    try:
        yield
    finally:
        tree_map_only(ExpandedWeight, reset_batch_first, args)
        tree_map_only(ExpandedWeight, reset_batch_first, kwargs)


# to support packed sequences, we need to allow for smaller batches. An ExpandedWeight represents the largest batch size
@contextmanager
def allow_smaller_batches(args, kwargs):
    def allow(ew):
        ew.set_allow_smaller_batches(True)

    def reset(ew):
        ew.set_allow_smaller_batches(False)

    tree_map_only(ExpandedWeight, allow, args)
    tree_map_only(ExpandedWeight, allow, kwargs)
    try:
        yield
    finally:
        tree_map_only(ExpandedWeight, reset, args)
        tree_map_only(ExpandedWeight, reset, kwargs)


@contextmanager
def setup_rnn(use_input_variant, args, kwargs):
    with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches(
        args, kwargs
    ):
        yield


# Registers an autograd.Function as the handler that computes per-sample gradients
# for torch_function; used as a class decorator on that autograd.Function.
def implements_per_sample_grads(torch_function):
    @functools.wraps(torch_function)
    def decorator(autograd_func):
        HANDLED_FUNCTIONS[torch_function] = autograd_func
        return autograd_func

    return decorator
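
# A minimal usage sketch (illustrative only; the real registrations live in the
# sibling *_expanded_weights.py modules in this package), mapping a torch function
# to the autograd.Function that computes its per-sample gradients:
#
#     @implements_per_sample_grads(F.linear)
#     class LinearPerSampleGrad(torch.autograd.Function):
#         ...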


# ExpandedWeight represents a weight (parameter) Tensor that has an expanded
# batch dimension. Operations on the ExpandedWeight Tensor act exactly like
# those without an expanded batch dimension, but a call to .backward() populates
# the original (unexpanded) tensor with per-sample gradients in the grad_sample field
#
# ExpandedWeight has a fallback that always fails since we cannot know what the batch
# dimension of the input tensor is and therefore cannot know if this is a valid call
#
# This is a __torch_function__ object but it could have also been a Tensor Extension
# with a dispatch key.
#
# Needs to be a tensor subclass to allow reparametrization
class ExpandedWeight(torch.Tensor):
    def __init__(self, orig_weight, batch_size, loss_reduction):
        self.batch_size = batch_size
        self.batch_first = True
        self.allow_smaller_batches = False
        self.orig_weight = orig_weight
        self.loss_reduction = loss_reduction

    handled_functions = HANDLED_FUNCTIONS

    def __new__(cls, orig_weight, batch_size, loss_reduction):
        if not isinstance(orig_weight, torch.Tensor):
            raise RuntimeError(
                f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}"
            )
        if not orig_weight.requires_grad:
            raise RuntimeError(
                "Can only build ExpandedWeights objects of tensors that require_grad"
            )
        ret = torch.Tensor._make_subclass(cls, orig_weight, True)
        return ret

    @classmethod
    def __torch_function__(cls, func, _, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in expanded_weights_rnn_decomps:
            # in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
            decomp_opts = expanded_weights_rnn_decomps[func]
            use_input_variant = isinstance(
                args[2], list
            )  # the input variant passes the list of flat weights in this position
            decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]

            if decomp is not None:
                with setup_rnn(use_input_variant, args, kwargs):
                    return decomp(*args, **kwargs)
        if func == torch._cudnn_rnn_flatten_weight:
            # since we aren't using the fused cuda kernels for RNNs, don't do this
            return
        if func in cls.handled_functions:
            return cls.handled_functions[func].apply(
                tuple(kwargs.keys()), func, *(args + tuple(kwargs.values()))
            )
        # We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs,
        # i.e. torch.add(torch.Tensor, ExpandedWeight)
        raise RuntimeError(
            f"Expanded Weights encountered but cannot handle function {func.__name__}"
        )

    @property
    def dtype(self):
        return self.orig_weight.dtype

    @property
    def data(self):
        return self.orig_weight.data

    @property
    def shape(self):
        return self.orig_weight.shape

    @property
    def device(self):
        return self.orig_weight.device

    @property
    def is_cuda(self):
        return self.orig_weight.is_cuda

    def data_ptr(self):
        return self.orig_weight.data_ptr()

    def get_device(self):
        return self.orig_weight.get_device()

    def set_allow_smaller_batches(self, is_allow_smaller_batches):
        self.allow_smaller_batches = is_allow_smaller_batches

    def set_batch_first(self, is_batch_first=True):
        self.batch_first = is_batch_first
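
# A rough end-to-end sketch (for illustration only, not an API defined in this file):
# the usual entry point is torch.nn.utils._per_sample_grad.call_for_per_sample_grads,
# which wraps a module's parameters in ExpandedWeight before running the forward pass.
#
#     weight = torch.randn(4, 3, requires_grad=True)
#     ew = ExpandedWeight(weight, batch_size=8, loss_reduction="sum")
#     out = torch.nn.functional.linear(torch.randn(8, 3), ew)  # routed via __torch_function__
#     out.sum().backward()
#     per_sample_grads = weight.grad_sample  # expected shape: (8, 4, 3)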