xref: /aosp_15_r20/external/pytorch/torch/nn/_reduction.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import warnings
2from typing import Optional
3
4
5# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h
6
7
def get_enum(reduction: str) -> int:
    """Translate a reduction mode name into its integer enum value.

    ``"none"`` maps to 0, ``"mean"`` maps to 1 and ``"sum"`` maps to 2. The
    older spelling ``"elementwise_mean"`` is still accepted as an alias for
    ``"mean"``, but a warning is emitted when it is used.

    Args:
        reduction: Name of the reduction mode.

    Returns:
        The integer code for the requested reduction.

    Raises:
        ValueError: If ``reduction`` is not one of the names listed above.
    """
    # Bind the result up front so it is defined on every path, including the
    # one that raises.
    code = -1  # TODO: remove once JIT exceptions support control flow

    # Handle the deprecated alias first: warn, then treat it exactly like "mean".
    if reduction == "elementwise_mean":
        warnings.warn(
            "reduction='elementwise_mean' is deprecated. "
            "Please use reduction='mean' instead."
        )
        reduction = "mean"

    if reduction == "sum":
        code = 2
    elif reduction == "mean":
        code = 1
    elif reduction == "none":
        code = 0
    else:
        raise ValueError(f"{reduction} is not a valid value for reduction")
    return code
25
26
27# In order to support previous versions, accept boolean size_average and reduce
28# and convert them into the new constants for now
29
30
31# We use these functions in torch/legacy as well, in which case we'll silence the warning
def legacy_get_string(
    size_average: Optional[bool],
    reduce: Optional[bool],
    emit_warning: bool = True,
) -> str:
    """Convert the legacy ``size_average``/``reduce`` flags to a reduction name.

    A flag passed as ``None`` is treated as ``True``. The resulting name is:

    * ``"none"`` when ``reduce`` is false (``size_average`` is then ignored),
    * ``"mean"`` when both flags are true,
    * ``"sum"`` when ``reduce`` is true and ``size_average`` is false.

    Args:
        size_average: Legacy averaging flag, or ``None`` for the default.
        reduce: Legacy reduction flag, or ``None`` for the default.
        emit_warning: When true, warn that the legacy arguments will be
            deprecated and name the equivalent ``reduction`` value.

    Returns:
        One of ``"mean"``, ``"sum"`` or ``"none"``.
    """
    # Resolve the defaults into locals instead of rebinding the parameters.
    averaging = True
    if size_average is not None:
        averaging = size_average
    reducing = True
    if reduce is not None:
        reducing = reduce

    if not reducing:
        mode = "none"
    elif averaging:
        mode = "mean"
    else:
        mode = "sum"

    if emit_warning:
        template = "size_average and reduce args will be deprecated, please use reduction='{}' instead."
        warnings.warn(template.format(mode))
    return mode
53
54
def legacy_get_enum(
    size_average: Optional[bool],
    reduce: Optional[bool],
    emit_warning: bool = True,
) -> int:
    """Convert the legacy ``size_average``/``reduce`` flags to a reduction enum.

    The flags are first converted to a reduction name with
    ``legacy_get_string`` (forwarding ``emit_warning``), and that name is then
    converted to its integer code with ``get_enum``.

    Args:
        size_average: Legacy averaging flag, or ``None`` for the default.
        reduce: Legacy reduction flag, or ``None`` for the default.
        emit_warning: Passed through to ``legacy_get_string``.

    Returns:
        The integer code that ``get_enum`` returns for the resolved name.
    """
    reduction_name = legacy_get_string(size_average, reduce, emit_warning)
    return get_enum(reduction_name)
61