# mypy: allow-untyped-defs
"""Process-wide state shared by the internals of the ONNX exporter.

This module is private: only `torch.onnx` itself and its tests may import it.

Avoid adding new global state here unless it is absolutely necessary; think
twice before introducing any new global variable.
"""

import torch._C._onnx as _C_onnx

# Keep the dependency direction clean: apart from _constants, this module must
# not import anything else from torch.onnx.
from torch.onnx import _constants


class _InternalGlobals:
    """Mutable settings shared by the ONNX exporter's internals.

    Some settings live in underscore-prefixed fields and are exposed through
    properties whose setters validate the incoming value; the rest are plain
    public attributes with no validation.

    NOTE: Think twice before adding a field here. New global variables should
    only be introduced when they are absolutely necessary.
    """

    def __init__(self) -> None:
        # Validated state, accessed through the properties defined below.
        self._export_onnx_opset_version: int = _constants.ONNX_DEFAULT_OPSET
        self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
        self._in_onnx_export: bool = False
        self._autograd_inlining: bool = True

        # Public state, assigned directly without any validation.
        # Whether the user's model is training during export.
        self.export_training: bool = False
        # Defaults to OperatorExportTypes.ONNX.
        self.operator_export_type: _C_onnx.OperatorExportTypes = (
            _C_onnx.OperatorExportTypes.ONNX
        )
        # Defaults to True; presumably toggles ONNX shape inference during
        # export (NOTE(review): confirm against the consumers of this flag).
        self.onnx_shape_inference: bool = True

    @property
    def training_mode(self) -> _C_onnx.TrainingMode:
        """Training mode used by the exporter (``TrainingMode.EVAL`` initially)."""
        return self._training_mode

    @training_mode.setter
    def training_mode(self, training_mode: _C_onnx.TrainingMode):
        # Only real TrainingMode members are accepted. Per the error message,
        # anything else reaching this setter points at an internal bug rather
        # than bad user input.
        if isinstance(training_mode, _C_onnx.TrainingMode):
            self._training_mode = training_mode
            return
        raise TypeError(
            "training_mode must be of type 'torch.onnx.TrainingMode'. This is "
            "likely a bug in torch.onnx."
        )

    @property
    def export_onnx_opset_version(self) -> int:
        """ONNX opset version targeted by the export."""
        return self._export_onnx_opset_version

    @export_onnx_opset_version.setter
    def export_onnx_opset_version(self, value: int):
        # Accepts values in [ONNX_MIN_OPSET, ONNX_MAX_OPSET]; both ends are
        # inclusive, hence the ``+ 1`` on the exclusive ``range`` stop.
        # Raises ValueError for anything outside that range.
        lowest = _constants.ONNX_MIN_OPSET
        highest = _constants.ONNX_MAX_OPSET
        if value in range(lowest, highest + 1):
            self._export_onnx_opset_version = value
            return
        raise ValueError(f"Unsupported ONNX opset version: {value}")

    @property
    def in_onnx_export(self) -> bool:
        """True while an ONNX export is in progress."""
        return self._in_onnx_export

    @in_onnx_export.setter
    def in_onnx_export(self, value: bool):
        # ``bool`` cannot be subclassed, so this check is exactly as strict as
        # an exact-type test: ints such as 0 or 1 are rejected.
        if not isinstance(value, bool):
            raise TypeError("in_onnx_export must be a boolean")
        self._in_onnx_export = value

    @property
    def autograd_inlining(self) -> bool:
        """Whether Autograd must be inlined (True initially)."""
        return self._autograd_inlining

    @autograd_inlining.setter
    def autograd_inlining(self, value: bool):
        # Strict bool check; see the note on ``in_onnx_export`` above.
        if not isinstance(value, bool):
            raise TypeError("autograd_inlining must be a boolean")
        self._autograd_inlining = value


# The single shared instance used internally by the ONNX exporter.
GLOBALS = _InternalGlobals()