xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/core/ag_ctx.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Thread-local context managers for AutoGraph."""
16
17import enum
18import inspect
19import threading
20
21from tensorflow.python.autograph.utils import ag_logging
22from tensorflow.python.util.tf_export import tf_export
23
24
# Thread-local holder. Each thread lazily gets its own `control_status` list.
stacks = threading.local()


def _control_ctx():
  """Returns the calling thread's stack of control-status contexts.

  The stack is created the first time a given thread asks for it, seeded with
  the context produced by `_default_control_status_ctx()`.

  Returns:
    The mutable list stored on `stacks.control_status` for the current thread.
    Its last element is the innermost (most recently entered) context.
  """
  try:
    return stacks.control_status
  except AttributeError:
    # First access from this thread: install a stack holding the default ctx.
    stacks.control_status = [_default_control_status_ctx()]
    return stacks.control_status
32
33
@tf_export('__internal__.autograph.control_status_ctx', v1=[])
def control_status_ctx():
  """Returns the current control context for autograph.

  This method is useful when calling `tf.__internal__.autograph.tf_convert`.
  The context will be used by tf_convert to determine whether it should convert
  the input function. See the sample usage below:

  ```
  def foo(func):
    return tf.__internal__.autograph.tf_convert(
       func, ctx=tf.__internal__.autograph.control_status_ctx())()
  ```

  Returns:
    The current control context of autograph, i.e. the innermost (last)
    entry of the calling thread's control-status stack.
  """
  return _control_ctx()[-1]
53
54
@enum.unique
class Status(enum.Enum):
  """Tri-state flag recording whether autograph was turned on by the user."""

  UNSPECIFIED = 0  # No explicit choice; used by the default control context.
  ENABLED = 1  # Explicitly turned on.
  DISABLED = 2  # Explicitly turned off.
59
60
class ControlStatusCtx:
  """A context that tracks whether autograph is enabled by the user.

  Entering pushes this object onto the calling thread's control-status stack
  (obtained from `_control_ctx()`); exiting pops it again.

  Attributes:
    status: the status value given at construction (presumably a `Status`
      member).
    options: an optional, caller-supplied options object; None by default.
  """

  def __init__(self, status, options=None):
    self.status, self.options = status, options

  def __repr__(self):
    cls_name = type(self).__name__
    return f'{cls_name}[status={self.status}, options={self.options}]'

  def __enter__(self):
    stack = _control_ctx()
    stack.append(self)
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    stack = _control_ctx()
    # Contexts must be exited in strict LIFO order.
    assert stack[-1] is self
    stack.pop()
79
80
class NullCtx:
  """Do-nothing context manager, standing in for `contextlib.nullcontext`.

  `with NullCtx() as x` binds None to `x`, and exceptions raised inside the
  block are never suppressed.
  """

  def __enter__(self):
    return None

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # Returning a falsy value lets any in-flight exception propagate.
    return None
89
90
def _default_control_status_ctx():
  """Builds a fresh control-status context whose status is UNSPECIFIED."""
  unspecified = Status.UNSPECIFIED
  return ControlStatusCtx(unspecified)
93
94
# Probe once, at import time, whether `inspect.getsource` works here by asking
# for the source of `ag_logging.log`. An OSError means source is unavailable,
# in which case a warning is emitted and the flag is cleared.
try:
  inspect.getsource(ag_logging.log)
except OSError:
  INSPECT_SOURCE_SUPPORTED = False
  ag_logging.warning(
      'AutoGraph is not available in this environment: functions lack code'
      ' information. This is typical of some environments like the interactive'
      ' Python shell. See'
      ' https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code'
      ' for more information.')
else:
  INSPECT_SOURCE_SUPPORTED = True
106