import warnings
from typing import Optional


# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h


def get_enum(reduction: str) -> int:
    """Translate a reduction name into its integer code.

    Args:
        reduction: One of ``"none"``, ``"mean"``, ``"sum"`` or the deprecated
            alias ``"elementwise_mean"``.

    Returns:
        ``0`` for ``"none"``, ``1`` for ``"mean"`` (and its deprecated alias),
        ``2`` for ``"sum"``.

    Raises:
        ValueError: If ``reduction`` is not one of the names listed above.
    """
    # Fold the deprecated alias into "mean" up front so the lookup below only
    # deals with the three current names.
    name = reduction
    if name == "elementwise_mean":
        warnings.warn(
            "reduction='elementwise_mean' is deprecated. "
            "Please use reduction='mean' instead."
        )
        name = "mean"

    if name == "none":
        code = 0
    elif name == "mean":
        code = 1
    elif name == "sum":
        code = 2
    else:
        code = -1  # TODO: remove once JIT exceptions support control flow
        raise ValueError(f"{reduction} is not a valid value for reduction")
    return code


# In order to support previous versions, accept boolean size_average and reduce
# and convert them into the new constants for now


# We use these functions in torch/legacy as well, in which case we'll silence the warning
def legacy_get_string(
    size_average: Optional[bool],
    reduce: Optional[bool],
    emit_warning: bool = True,
) -> str:
    """Convert the legacy ``size_average`` / ``reduce`` flags to a reduction name.

    Args:
        size_average: Legacy averaging flag; ``None`` is treated as ``True``.
        reduce: Legacy reduction flag; ``None`` is treated as ``True``.
        emit_warning: When true, warn that the legacy arguments will be
            deprecated and name the equivalent ``reduction`` string.

    Returns:
        ``"none"`` when ``reduce`` is false, otherwise ``"mean"`` when
        ``size_average`` is true and ``"sum"`` when it is false.
    """
    template = "size_average and reduce args will be deprecated, please use reduction='{}' instead."

    # Both legacy flags default to True when the caller leaves them unset.
    averaging = True
    if size_average is not None:
        averaging = size_average
    reducing = True
    if reduce is not None:
        reducing = reduce

    # Without a reduction the averaging flag is irrelevant.
    if not reducing:
        name = "none"
    elif averaging:
        name = "mean"
    else:
        name = "sum"

    if emit_warning:
        warnings.warn(template.format(name))
    return name


def legacy_get_enum(
    size_average: Optional[bool],
    reduce: Optional[bool],
    emit_warning: bool = True,
) -> int:
    """Convert the legacy ``size_average`` / ``reduce`` flags to an integer code.

    Resolves the flags to a reduction name with :func:`legacy_get_string` and
    then maps that name to its code with :func:`get_enum`.
    """
    name = legacy_get_string(size_average, reduce, emit_warning)
    return get_enum(name)