xref: /aosp_15_r20/external/pytorch/torch/utils/_sympy/numbers.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import mpmath.libmp as mlib  # type: ignore[import-untyped]
3import sympy
4from sympy import Expr
5from sympy.core.decorators import _sympifyit
6from sympy.core.expr import AtomicExpr
7from sympy.core.numbers import Number
8from sympy.core.parameters import global_parameters
9from sympy.core.singleton import S, Singleton
10
11
class IntInfinity(Number, metaclass=Singleton):
    r"""Positive integer infinity.

    A member of the extended integers that compares greater than every
    ordinary integer.  Unlike sympy's own ``oo`` it reports ``is_integer``,
    and in comparisons it is placed strictly below ``sympy.oo``.

    This is a singleton: obtain it as ``S.IntInfinity`` or through the
    module-level alias ``int_oo``.
    """

    # NB: sympy treats "integer" and "infinite" as contradictory assumptions,
    # so this value cannot be marked infinite.  Being an integer, it is also
    # reported as complex, which differs from sympy.oo.

    is_integer = True
    is_commutative = True
    is_number = True
    is_extended_real = True
    is_comparable = True
    is_extended_positive = True
    is_prime = False

    # High priority so binary operations are dispatched to this class before
    # plain numbers get to handle them.
    _op_priority = 100.0

    __slots__ = ()

    def __new__(cls):
        return AtomicExpr.__new__(cls)

    def _sympystr(self, printer):
        return "int_oo"

    def _eval_subs(self, old, new):
        # Only a substitution that targets this exact singleton is handled;
        # anything else returns None.
        return new if self == old else None

    # Possible additions, intentionally left disabled for now:
    #
    #     def _eval_evalf(self, prec=None):
    #         return Float('inf')
    #
    #     def evalf(self, prec=None, **options):
    #         return self._eval_evalf(prec)

    @_sympifyit("other", NotImplemented)
    def __add__(self, other):
        """Return ``int_oo + other``.

        A float infinity of either sign is returned unchanged; ``-int_oo``
        or NaN yields NaN; any other number is absorbed into ``int_oo``.
        Non-numbers (or evaluation disabled) defer to ``Number.__add__``.
        """
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__add__(self, other)
        if other in (S.Infinity, S.NegativeInfinity):
            return other
        if other in (S.NegativeIntInfinity, S.NaN):
            return S.NaN
        return self

    __radd__ = __add__

    @_sympifyit("other", NotImplemented)
    def __sub__(self, other):
        """Return ``int_oo - other``; float infinities win, ``int_oo - int_oo`` is NaN."""
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__sub__(self, other)
        if other is S.Infinity:
            return S.NegativeInfinity
        if other is S.NegativeInfinity:
            return S.Infinity
        if other in (S.IntInfinity, S.NaN):
            return S.NaN
        return self

    @_sympifyit("other", NotImplemented)
    def __rsub__(self, other):
        """Return ``other - int_oo``, computed as ``(-int_oo) + other``."""
        negated = -self
        return negated.__add__(other)

    @_sympifyit("other", NotImplemented)
    def __mul__(self, other):
        """Return ``int_oo * other``: zero/NaN give NaN, otherwise the sign of ``other`` decides."""
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__mul__(self, other)
        if other.is_zero or other is S.NaN:
            return S.NaN
        return self if other.is_extended_positive else S.NegativeIntInfinity

    __rmul__ = __mul__

    @_sympifyit("other", NotImplemented)
    def __truediv__(self, other):
        """Return ``int_oo / other``.

        Dividing by any infinity or NaN is NaN.  Otherwise true division
        produces a float, so the result is ``oo`` or ``-oo`` depending on
        whether ``other`` is nonnegative.
        """
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__truediv__(self, other)
        undefined = (
            S.Infinity,
            S.IntInfinity,
            S.NegativeInfinity,
            S.NegativeIntInfinity,
            S.NaN,
        )
        if other in undefined:
            return S.NaN
        return S.Infinity if other.is_extended_nonnegative else S.NegativeInfinity

    def __abs__(self):
        return S.IntInfinity

    def __neg__(self):
        return S.NegativeIntInfinity

    def _eval_power(self, expt):
        """Evaluate ``int_oo ** expt``; return None when no rule applies."""
        if expt.is_extended_positive:
            return S.IntInfinity
        if expt.is_extended_negative:
            return S.Zero
        if expt is S.NaN or expt is S.ComplexInfinity:
            return S.NaN
        if expt.is_extended_real is False and expt.is_number:
            from sympy.functions.elementary.complexes import re

            # For a non-real numeric exponent, the real part decides.
            real_part = re(expt)
            if real_part.is_positive:
                return S.ComplexInfinity
            if real_part.is_negative:
                return S.Zero
            if real_part.is_zero:
                return S.NaN
            return self ** expt.evalf()
        return None

    def _as_mpf_val(self, prec):
        # Raw mpmath value ``mlib.finf``; ``prec`` is ignored.
        return mlib.finf

    def __hash__(self):
        # Needed because defining __eq__ would otherwise disable hashing.
        return super().__hash__()

    def __eq__(self, other):
        return other is S.IntInfinity

    def __ne__(self, other):
        return other is not S.IntInfinity

    # Ordering used below: everything except int_oo and sympy.oo is smaller
    # than int_oo, and sympy.oo is larger than int_oo.

    def __gt__(self, other):
        if other is S.Infinity or other is S.IntInfinity:
            return sympy.false  # sympy.oo > int_oo; int_oo is not > itself
        return sympy.true

    def __ge__(self, other):
        if other is S.Infinity:
            return sympy.false  # sympy.oo > int_oo
        return sympy.true

    def __lt__(self, other):
        if other is S.Infinity:
            return sympy.true  # sympy.oo > int_oo
        return sympy.false

    def __le__(self, other):
        if other is S.Infinity or other is S.IntInfinity:
            return sympy.true  # sympy.oo > int_oo; int_oo <= itself
        return sympy.false

    @_sympifyit("other", NotImplemented)
    def __mod__(self, other):
        """Any modulo involving int_oo is NaN."""
        return S.NaN if isinstance(other, Expr) else NotImplemented

    __rmod__ = __mod__

    def floor(self):
        return self

    def ceiling(self):
        return self


# Public alias for the IntInfinity singleton.
int_oo = S.IntInfinity
204
205
class NegativeIntInfinity(Number, metaclass=Singleton):
    """Negative integer infinite quantity.

    The extended-integer counterpart of sympy's ``-oo``: it reports
    ``is_integer`` and compares less than every ordinary integer.  In
    comparisons it is placed strictly above ``-sympy.oo``.

    NegativeIntInfinity is a singleton, and can be accessed by
    ``S.NegativeIntInfinity``; it is also what ``-int_oo`` evaluates to.

    See Also
    ========

    IntInfinity
    """

    # Ensure we get dispatched to before plain numbers
    _op_priority = 100.0

    is_integer = True
    is_extended_real = True
    is_commutative = True
    is_comparable = True
    is_extended_negative = True
    is_number = True
    is_prime = False

    __slots__ = ()

    def __new__(cls):
        return AtomicExpr.__new__(cls)

    def _eval_subs(self, old, new):
        # Only a substitution targeting this exact singleton is handled here.
        if self == old:
            return new

    def _sympystr(self, printer):
        return "-int_oo"

    # We could do these, not sure about it:
    #
    #     def _eval_evalf(self, prec=None):
    #         return Float('-inf')
    #
    #     def evalf(self, prec=None, **options):
    #         return self._eval_evalf(prec)

    @_sympifyit("other", NotImplemented)
    def __add__(self, other):
        """Return ``-int_oo + other``.

        A float infinity of either sign is returned unchanged; ``int_oo``
        or NaN yields NaN; any other number is absorbed into ``-int_oo``.
        Non-numbers (or evaluation disabled) defer to ``Number.__add__``.
        """
        if isinstance(other, Number) and global_parameters.evaluate:
            # Float infinities dominate, exactly as in IntInfinity.__add__,
            # so -int_oo + -oo is -oo rather than -int_oo.
            if other in (S.Infinity, S.NegativeInfinity):
                return other
            if other in (S.IntInfinity, S.NaN):
                return S.NaN
            return self
        return Number.__add__(self, other)

    __radd__ = __add__

    @_sympifyit("other", NotImplemented)
    def __sub__(self, other):
        """Return ``-int_oo - other`` (mirror image of ``IntInfinity.__sub__``).

        Subtracting a float infinity yields the opposite float infinity,
        ``-int_oo - (-int_oo)`` or NaN yields NaN, and any other number
        leaves ``-int_oo``.
        """
        if isinstance(other, Number) and global_parameters.evaluate:
            if other is S.Infinity:
                # -int_oo - oo == -int_oo + (-oo): the float infinity wins.
                return S.NegativeInfinity
            if other is S.NegativeInfinity:
                return S.Infinity
            if other in (S.NegativeIntInfinity, S.NaN):
                return S.NaN
            return self
        return Number.__sub__(self, other)

    @_sympifyit("other", NotImplemented)
    def __rsub__(self, other):
        """Return ``other - (-int_oo)``, computed as ``int_oo + other``."""
        return (-self).__add__(other)

    @_sympifyit("other", NotImplemented)
    def __mul__(self, other):
        """Return ``-int_oo * other``: zero/NaN give NaN, otherwise the sign of ``other`` decides."""
        if isinstance(other, Number) and global_parameters.evaluate:
            if other.is_zero or other is S.NaN:
                return S.NaN
            if other.is_extended_positive:
                return self
            return S.IntInfinity
        return Number.__mul__(self, other)

    __rmul__ = __mul__

    @_sympifyit("other", NotImplemented)
    def __truediv__(self, other):
        """Return ``-int_oo / other``.

        Dividing by any infinity or NaN is NaN.  Otherwise true division
        produces a float: ``-oo`` for a nonnegative divisor and ``oo`` for
        a negative one (the mirror of ``IntInfinity.__truediv__``).
        """
        if isinstance(other, Number) and global_parameters.evaluate:
            if other in (
                S.Infinity,
                S.IntInfinity,
                S.NegativeInfinity,
                S.NegativeIntInfinity,
                S.NaN,
            ):
                return S.NaN
            if other.is_extended_nonnegative:
                # Previously returned ``self`` (an integer infinity), which
                # contradicted the float result of every other branch.
                return S.NegativeInfinity  # truediv returns float
            return S.Infinity  # truediv returns float
        return Number.__truediv__(self, other)

    def __abs__(self):
        return S.IntInfinity

    def __neg__(self):
        return S.IntInfinity

    def _eval_power(self, expt):
        """Evaluate ``(-int_oo) ** expt`` for numeric exponents.

        Returns None (no evaluation) when ``expt`` is not a number.
        """
        if expt.is_number:
            if expt in (
                S.NaN,
                S.Infinity,
                S.NegativeInfinity,
                S.IntInfinity,
                S.NegativeIntInfinity,
            ):
                return S.NaN

            # Positive integer powers: the parity decides the sign.
            if isinstance(expt, sympy.Integer) and expt.is_extended_positive:
                if expt.is_odd:
                    return S.NegativeIntInfinity
                else:
                    return S.IntInfinity

            # General case: (-int_oo)**e == (-1)**e * int_oo**e.
            inf_part = S.IntInfinity**expt
            s_part = S.NegativeOne**expt
            if inf_part == 0 and s_part.is_finite:
                return inf_part
            if (
                inf_part is S.ComplexInfinity
                and s_part.is_finite
                and not s_part.is_zero
            ):
                return S.ComplexInfinity
            return s_part * inf_part

    def _as_mpf_val(self, prec):
        # Raw mpmath value ``mlib.fninf``; ``prec`` is ignored.
        return mlib.fninf

    def __hash__(self):
        # Needed because defining __eq__ would otherwise disable hashing.
        return super().__hash__()

    def __eq__(self, other):
        return other is S.NegativeIntInfinity

    def __ne__(self, other):
        return other is not S.NegativeIntInfinity

    # Ordering used below: -sympy.oo < -int_oo < anything else.

    def __gt__(self, other):
        if other is S.NegativeInfinity:
            return sympy.true  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.false  # consistency with sympy.oo
        else:
            return sympy.false

    def __ge__(self, other):
        if other is S.NegativeInfinity:
            return sympy.true  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.true  # consistency with sympy.oo
        else:
            return sympy.false

    def __lt__(self, other):
        if other is S.NegativeInfinity:
            return sympy.false  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.false  # consistency with sympy.oo
        else:
            return sympy.true

    def __le__(self, other):
        if other is S.NegativeInfinity:
            return sympy.false  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.true  # consistency with sympy.oo
        else:
            return sympy.true

    @_sympifyit("other", NotImplemented)
    def __mod__(self, other):
        """Any modulo involving -int_oo is NaN."""
        if not isinstance(other, Expr):
            return NotImplemented
        return S.NaN

    __rmod__ = __mod__

    def floor(self):
        return self

    def ceiling(self):
        return self

    def as_powers_dict(self):
        # -int_oo is represented as (-1)**1 * int_oo**1.
        return {S.NegativeOne: 1, S.IntInfinity: 1}
398