# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation of the resilient backpropagation (Rprop) algorithm."""
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Rprop", "rprop"]


class Rprop(Optimizer):  # noqa: D101
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        etas: Tuple[float, float] = (0.5, 1.2),
        step_sizes: Tuple[float, float] = (1e-6, 50),
        *,
        capturable: bool = False,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 < etas[0] < 1.0 < etas[1]:
            raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")

        defaults = dict(
            lr=lr,
            etas=etas,
            step_sizes=step_sizes,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(self, group, params, grads, prevs, step_sizes, state_steps):
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params.append(p)
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError("Rprop does not support sparse gradients")

            grads.append(grad)
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.zeros((), dtype=_get_scalar_dtype())
                )

                state["prev"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                if p.dtype.is_complex:
                    # Complex numbers are treated as two independent real numbers,
                    # so the step size shouldn't be zero for the imaginary part.
                    state["step_size"] = torch.full_like(
                        grad, complex(group["lr"], group["lr"])
                    )
                else:
                    state["step_size"] = torch.full_like(grad, group["lr"])

            prevs.append(state["prev"])
            step_sizes.append(state["step_size"])
            state_steps.append(state["step"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
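
        Example of a closure (a minimal sketch; ``model``, ``loss_fn``, ``input``
        and ``target`` are placeholder names, not defined in this module)::

            def closure():
                optimizer.zero_grad()
                loss = loss_fn(model(input), target)
                loss.backward()
                return loss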
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params: List[Tensor] = []
            grads: List[Tensor] = []
            prevs: List[Tensor] = []
            step_sizes: List[Tensor] = []
            state_steps: List[Tensor] = []

            etaminus, etaplus = group["etas"]
            step_size_min, step_size_max = group["step_sizes"]
            foreach = group["foreach"]
            maximize = group["maximize"]

            has_complex = self._init_group(
                group, params, grads, prevs, step_sizes, state_steps
            )

            rprop(
                params,
                grads,
                prevs,
                step_sizes,
                state_steps,
                step_size_min=step_size_min,
                step_size_max=step_size_max,
                etaminus=etaminus,
                etaplus=etaplus,
                foreach=foreach,
                maximize=maximize,
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss


Rprop.__doc__ = (
    r"""Implements the resilient backpropagation algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \theta_0 \in \mathbf{R}^d \text{ (params)},f(\theta)
                \text{ (objective)},                                                             \\
            &\hspace{13mm}      \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min}
                \text{ (step sizes)}                                                             \\
            &\textbf{initialize} :   g^0_{prev} \leftarrow 0,
                \: \eta_0 \leftarrow \text{lr (learning rate)}                                   \\
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm} \textbf{for} \text{  } i = 0, 1, \ldots, d-1 \: \mathbf{do}            \\
            &\hspace{10mm}  \textbf{if} \:   g^i_{prev} g^i_t  > 0                               \\
            &\hspace{15mm}  \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
                \Gamma_{max})                                                                    \\
            &\hspace{10mm}  \textbf{else if}  \:  g^i_{prev} g^i_t < 0                           \\
            &\hspace{15mm}  \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
                \Gamma_{min})                                                                    \\
            &\hspace{15mm}  g^i_t \leftarrow 0                                                   \\
            &\hspace{10mm}  \textbf{else}  \:                                                    \\
            &\hspace{15mm}  \eta^i_t \leftarrow \eta^i_{t-1}                                     \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t)             \\
            &\hspace{5mm}g_{prev} \leftarrow  g_t                                                \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to the paper
    `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
    <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        etas (Tuple[float, float], optional): pair of (etaminus, etaplus), that
            are multiplicative increase and decrease factors
            (default: (0.5, 1.2))
        step_sizes (Tuple[float, float], optional): a pair of minimal and
            maximal allowed step sizes (default: (1e-6, 50))
        {_foreach_doc}
        {_capturable_doc}
        {_maximize_doc}
        {_differentiable_doc}
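
    Example:
        >>> # A minimal usage sketch; ``model``, ``loss_fn``, ``input`` and
        >>> # ``target`` are placeholder names, not defined in this module.
        >>> optimizer = torch.optim.Rprop(model.parameters(), lr=0.01)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()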

    """
)


def _single_tensor_rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        prev = prevs[i]
        step_size = step_sizes[i]
        step = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        step += 1

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            prev = torch.view_as_real(prev)
            param = torch.view_as_real(param)
            step_size = torch.view_as_real(step_size)
        if differentiable:
            sign = grad.mul(prev.clone()).sign()
        else:
            sign = grad.mul(prev).sign()

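        # Map sign(grad * prev) to a multiplicative step-size factor: agreement
        # with the previous gradient -> etaplus, a sign flip -> etaminus, and
        # zero -> 1 (step size unchanged).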
        if capturable:
            sign.copy_(torch.where(sign.gt(0), etaplus, sign))
            sign.copy_(torch.where(sign.lt(0), etaminus, sign))
            sign.copy_(torch.where(sign.eq(0), 1, sign))
        else:
            sign[sign.gt(0)] = etaplus
            sign[sign.lt(0)] = etaminus
            sign[sign.eq(0)] = 1

        # apply the multiplicative factors to the step sizes and clamp them
        # to [step_size_min, step_size_max]
        step_size.mul_(sign).clamp_(step_size_min, step_size_max)

        # where the sign flipped (dir < 0), zero out the gradient so no step is
        # taken along that coordinate; otherwise keep dfdx unchanged
        grad = grad.clone(memory_format=torch.preserve_format)
        if capturable:
            grad.copy_(torch.where(sign.eq(etaminus), 0, grad))
        else:
            grad[sign.eq(etaminus)] = 0

        # update parameters
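        # i.e. theta_t = theta_{t-1} - eta_t * sign(g_t), as in the class docstring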
        param.addcmul_(grad.sign(), step_size, value=-1)
        prev.copy_(grad)


def _multi_tensor_rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices()
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, prevs, step_sizes, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_prevs_,
        grouped_step_sizes_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_prevs = cast(List[Tensor], grouped_prevs_)
        grouped_step_sizes = cast(List[Tensor], grouped_step_sizes_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        # Handle complex params
        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_prevs, grouped_step_sizes
            )

        signs = torch._foreach_mul(grouped_grads, grouped_prevs)
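        # When maximizing, the effective gradient is -grad, so negating the
        # product grad * prev yields the same signs as using -grad directly.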
        if maximize:
            torch._foreach_neg_(signs)

        # At the end of the step, grouped_prevs will contain the current grads, so we reuse
        # grouped_prevs memory instead of creating a new buffer, but, for clarity, we reassign
        # to keep referring to the buffer as grouped_grads.
        torch._foreach_copy_(grouped_prevs, grouped_grads)
        if maximize:
            torch._foreach_neg_(grouped_prevs)
        grouped_grads = grouped_prevs

        torch._foreach_sign_(signs)
        if capturable:
            for sign in signs:
                sign.copy_(torch.where(sign.gt(0), etaplus, sign))
                sign.copy_(torch.where(sign.lt(0), etaminus, sign))
                sign.copy_(torch.where(sign.eq(0), 1, sign))
        else:
            for sign in signs:
                sign[sign.gt(0)] = etaplus
                sign[sign.lt(0)] = etaminus
                sign[sign.eq(0)] = 1

        # apply the multiplicative factors to the step sizes and clamp them
        # to [step_size_min, step_size_max]
        torch._foreach_mul_(grouped_step_sizes, signs)
        for step_size in grouped_step_sizes:
            step_size.clamp_(step_size_min, step_size_max)

        # where the sign flipped (dir < 0), zero out the gradient so no step is
        # taken along that coordinate; otherwise keep dfdx unchanged
        grouped_grads = list(grouped_grads)
        for i in range(len(grouped_grads)):
            grouped_grads[i].copy_(
                torch.where(signs[i].eq(etaminus), 0, grouped_grads[i])
            )

        # explicitly del signs as it's not used after here to save memory
        del signs

        # update parameters
        grad_signs = [grad.sign() for grad in grouped_grads]
        torch._foreach_addcmul_(
            grouped_params, grad_signs, grouped_step_sizes, value=-1
        )

        # Logically, you may expect grouped_prevs to get updated to grouped_grads, but that's
        # basically already happened since we've been using grouped_prevs' memory to store
        # updated grouped_grads!


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop)
def rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript (see issue #70627),
    # so these are kept as positional args with defaults for now, since the functional API is compiled
    # by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    maximize: bool = False,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
):
    r"""Functional API that performs rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rprop
    else:
        func = _single_tensor_rprop

    func(
        params,
        grads,
        prevs,
        step_sizes,
        state_steps,
        step_size_min=step_size_min,
        step_size_max=step_size_max,
        etaminus=etaminus,
        etaplus=etaplus,
        capturable=capturable,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
    )