xref: /aosp_15_r20/external/tensorflow/tensorflow/python/util/tf_inspect.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TFDecorator-aware replacements for the inspect module."""
16import collections
17import functools
18import inspect as _inspect
19
20import six
21
22from tensorflow.python.util import tf_decorator
23
24
25# inspect.signature() is preferred over inspect.getfullargspec() in PY3.
26# Note that while it can handle TFDecorators, it will ignore a TFDecorator's
27# provided ArgSpec/FullArgSpec and instead return the signature of the
28# inner-most function.
def signature(obj, *, follow_wrapped=True):
  """TFDecorator-aware replacement for inspect.signature.

  All TFDecorators around `obj` are stripped first. Any argspec a decorator
  declares is ignored, and the signature of the innermost target is returned.

  Args:
    obj: A callable, possibly decorated with TFDecorators.
    follow_wrapped: Forwarded unchanged to `inspect.signature`.

  Returns:
    The `inspect.Signature` of the innermost undecorated target.
  """
  _, innermost = tf_decorator.unwrap(obj)
  return _inspect.signature(innermost, follow_wrapped=follow_wrapped)
33
34
# Re-export the `inspect` signature types so this module can be used as a
# drop-in replacement for `inspect`.
Parameter = _inspect.Parameter
Signature = _inspect.Signature

# `inspect.ArgSpec` (together with `inspect.getargspec`) was removed in
# Python 3.11. Accessing it unconditionally would make importing this module
# fail there, so fall back to a namedtuple with the same field names.
if hasattr(_inspect, 'ArgSpec'):
  ArgSpec = _inspect.ArgSpec  # pylint: disable=invalid-name
else:
  ArgSpec = collections.namedtuple(  # pylint: disable=invalid-name
      'ArgSpec', ['args', 'varargs', 'keywords', 'defaults'])

# `inspect.FullArgSpec` does not exist on Python 2; mirror its fields there.
if hasattr(_inspect, 'FullArgSpec'):
  FullArgSpec = _inspect.FullArgSpec  # pylint: disable=invalid-name
else:
  FullArgSpec = collections.namedtuple('FullArgSpec', [
      'args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults',
      'annotations'
  ])
48
49
def _convert_maybe_argspec_to_fullargspec(argspec):
  """Returns `argspec` as a `FullArgSpec`.

  A `FullArgSpec` is returned unchanged. Anything else is treated as an
  `ArgSpec`-like object: its `keywords` field becomes `varkw`, and the fields
  that `ArgSpec` lacks get empty values (`kwonlyargs=[]`,
  `kwonlydefaults=None`, `annotations={}`).

  Args:
    argspec: A `FullArgSpec`, or an object with `args`, `varargs`, `keywords`
      and `defaults` attributes.

  Returns:
    A `FullArgSpec`.
  """
  if not isinstance(argspec, FullArgSpec):
    argspec = FullArgSpec(
        args=argspec.args,
        varargs=argspec.varargs,
        varkw=argspec.keywords,
        defaults=argspec.defaults,
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})
  return argspec
61
if not hasattr(_inspect, 'getfullargspec'):
  # Python 2: the stdlib only offers `getargspec`. Build `_getfullargspec` on
  # top of this module's TFDecorator-aware `getargspec`.
  _getargspec = _inspect.getargspec

  def _getfullargspec(target):
    """Python 2 stand-in for `inspect.getfullargspec`.

    Args:
      target: the object to inspect.

    Returns:
      A `FullArgSpec` whose kwonlyargs, kwonlydefaults and annotations are
      empty.
    """
    legacy_spec = getargspec(target)
    return _convert_maybe_argspec_to_fullargspec(legacy_spec)
else:
  # Python 3: the stdlib provides `getfullargspec`. Build `_getargspec` on top
  # of this module's TFDecorator-aware `getfullargspec`.
  _getfullargspec = _inspect.getfullargspec  # pylint: disable=invalid-name

  def _getargspec(target):
    """Python 3 stand-in for `inspect.getargspec`.

    Calls `getfullargspec` and copies its `args`, `varargs`, `varkw` and
    `defaults` into a Python 2/3 compatible `ArgSpec`. `varkw` is stored
    under the `ArgSpec` field name `keywords`. Keyword-only arguments and
    annotations are not carried over.

    Args:
      target: the object to inspect.

    Returns:
      An `ArgSpec` built from the `FullArgSpec` of `target`.
    """
    full_spec = getfullargspec(target)
    return ArgSpec(
        args=full_spec.args,
        varargs=full_spec.varargs,
        keywords=full_spec.varkw,
        defaults=full_spec.defaults)
101
102
def currentframe():
  """TFDecorator-aware replacement for inspect.currentframe.

  Returns:
    The frame object of the function that called `currentframe()`.
  """
  # `inspect.currentframe()` returns this function's own frame, and `f_back`
  # is the caller's frame. This is the same frame as `inspect.stack()[1][0]`,
  # but it avoids building a FrameInfo (and reading source context) for every
  # frame on the stack. Not storing the frame in a local also avoids a
  # frame -> locals -> frame reference cycle.
  return _inspect.currentframe().f_back
106
107
def getargspec(obj):
  """TFDecorator-aware replacement for `inspect.getargspec`.

  Prefer `getfullargspec`, which is the Python 2/3 compatible replacement for
  this function.

  Args:
    obj: A function, partial function, or callable object, possibly decorated.

  Returns:
    The `ArgSpec` declared by the outermost decorator that overrides the
    callable's signature. If no decorator does, the `ArgSpec` of the
    undecorated object.

  Raises:
    ValueError: When the callable's signature cannot be expressed with
      `ArgSpec`.
    TypeError: For objects of unsupported types.
  """
  if isinstance(obj, functools.partial):
    return _get_argspec_for_partial(obj)

  decorators, target = tf_decorator.unwrap(obj)

  # The outermost decorator that declares an argspec takes precedence.
  declared = next(
      (decorator.decorator_argspec
       for decorator in decorators
       if decorator.decorator_argspec is not None),
      None)
  if declared:
    return declared

  # Most callables (everything except partials) are handled directly.
  try:
    return _getargspec(target)
  except TypeError:
    pass

  # A class that could not be inspected directly: try its constructors.
  if isinstance(target, type):
    for constructor_name in ('__init__', '__new__'):
      try:
        return _getargspec(getattr(target, constructor_name))
      except TypeError:
        pass

  # Look `__call__` up on `type(target)`. For a class, this prevents reporting
  # the class's own `__call__` (which applies to its instances).
  return _getargspec(type(target).__call__)
158
159
def _get_argspec_for_partial(obj):
  """Implements `getargspec` for `functools.partial` objects.

  Args:
    obj: The `functools.partial` object.

  Returns:
    An `ArgSpec` describing the arguments the partial still accepts.

  Raises:
    ValueError: When the callable's signature cannot be expressed with
      `ArgSpec`, or when the partial binds a keyword that the wrapped
      callable does not name and the callable has no `**kwargs` parameter.
  """
  # When callable is a functools.partial object, we construct its ArgSpec with
  # following strategy:
  # - If callable partial contains default value for positional arguments (ie.
  # object.args), then final ArgSpec doesn't contain those positional arguments.
  # - If callable partial contains default value for keyword arguments (ie.
  # object.keywords), then we merge them with wrapped target. Default values
  # from callable partial takes precedence over those from wrapped target.
  #
  # However, there is a case where it is impossible to construct a valid
  # ArgSpec. Python requires arguments that have no default values must be
  # defined before those with default values. ArgSpec structure is only valid
  # when this presumption holds true because default values are expressed as a
  # tuple of values without keywords and they are always assumed to belong to
  # last K arguments where K is number of default values present.
  #
  # Since functools.partial can give default value to any argument, this
  # presumption may no longer hold in some cases. For example:
  #
  # def func(m, n):
  #   return 2 * m + n
  # partialed = functools.partial(func, m=1)
  #
  # This example will result in m having a default value but n doesn't. This is
  # usually not allowed in Python and can not be expressed in ArgSpec correctly.
  #
  # Thus, we must detect cases like this by finding first argument with default
  # value and ensures all following arguments also have default values. When
  # this is not true, a ValueError is raised.

  n_prune_args = len(obj.args)
  partial_keywords = obj.keywords or {}

  args, varargs, keywords, defaults = getargspec(obj.func)

  # Partial function may give default value to any argument, therefore the
  # default list has one slot per argument. `defaults` belongs to the LAST
  # len(defaults) arguments of the wrapped callable, so align it against the
  # full argument list *before* pruning. Pruning first would misalign it (and
  # grow the list past len(args)) whenever the partial binds some of the
  # defaulted arguments positionally.
  no_default = object()
  all_defaults = [no_default] * len(args)
  if defaults:
    all_defaults[-len(defaults):] = defaults

  # Drop the first n_prune_args arguments, which the partial already binds
  # positionally, together with their default slots.
  args = args[n_prune_args:]
  all_defaults = all_defaults[n_prune_args:]

  # Fill in default values provided by the partial's keywords.
  arg_positions = {name: idx for idx, name in enumerate(args)}
  for kw, default in partial_keywords.items():
    if kw in arg_positions:
      all_defaults[arg_positions[kw]] = default
    elif not keywords:
      raise ValueError(f'{obj} does not have a **kwargs parameter, but '
                       f'contains an unknown partial keyword {kw}.')

  # Find first argument with default value set.
  first_default = next(
      (idx for idx, value in enumerate(all_defaults)
       if value is not no_default),
      None)

  # If no default values are found, return ArgSpec with defaults=None.
  if first_default is None:
    return ArgSpec(args, varargs, keywords, None)

  # Checks if all arguments have default value set after first one.
  invalid_default_values = [
      args[i] for i, value in enumerate(all_defaults)
      if value is no_default and i > first_default
  ]

  if invalid_default_values:
    raise ValueError(f'{obj} has some keyword-only arguments, which are not'
                     f' supported: {invalid_default_values}.')

  return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
245
246
def getfullargspec(obj):
  """TFDecorator-aware replacement for `inspect.getfullargspec`.

  On Python 2, where `inspect.getfullargspec` does not exist, this wrapper
  emulates it.

  Args:
    obj: A callable, possibly decorated.

  Returns:
    The `FullArgSpec` that describes the signature of the outermost decorator
    that changes the callable's signature. A legacy `ArgSpec` declared by a
    decorator is converted to a `FullArgSpec`. If no decorator declares an
    argspec, the `FullArgSpec` of the innermost unwrapped target is returned.
  """
  decorators, target = tf_decorator.unwrap(obj)

  # The first (outermost) decorator that declares an argspec wins.
  for d in decorators:
    if d.decorator_argspec is not None:
      return _convert_maybe_argspec_to_fullargspec(d.decorator_argspec)
  return _getfullargspec(target)
267
268
def getcallargs(*func_and_positional, **named):
  """TFDecorator-aware replacement for inspect.getcallargs.

  Args:
    *func_and_positional: A callable, possibly decorated, followed by any
      positional arguments that would be passed to `func`.
    **named: The named argument dictionary that would be passed to `func`.

  Returns:
    A dictionary mapping `func`'s named arguments to the values they would
    receive if `func(*positional, **named)` were called. Arguments that would
    get neither a supplied value nor a default are omitted. Positional values
    beyond the named parameters are ignored.

  `getcallargs` will use the argspec from the outermost decorator that provides
  it. If no attached decorators modify argspec, the final unwrapped target's
  argspec will be used.
  """
  func = func_and_positional[0]
  positional = func_and_positional[1:]
  argspec = getfullargspec(func)
  call_args = named.copy()

  # For bound methods the argspec includes `self`/`cls`, so prepend the bound
  # object. Compare against None rather than relying on truthiness: the bound
  # object may legitimately be falsy (e.g. an empty container instance).
  this = getattr(func, 'im_self', None)
  if this is None:
    this = getattr(func, '__self__', None)
  if ismethod(func) and this is not None:
    positional = (this,) + positional

  # Positional values fill, in order, the named parameters not given by name.
  remaining_positionals = [arg for arg in argspec.args if arg not in call_args]
  call_args.update(zip(remaining_positionals, positional))

  # Defaults apply to the last len(defaults) positional parameters.
  if argspec.defaults:
    default_count = len(argspec.defaults)
    for arg, value in zip(argspec.args[-default_count:], argspec.defaults):
      call_args.setdefault(arg, value)
  if argspec.kwonlydefaults is not None:
    for k, v in argspec.kwonlydefaults.items():
      call_args.setdefault(k, v)
  return call_args
304
305
def getframeinfo(*args, **kwargs):
  """Thin wrapper around `inspect.getframeinfo`.

  All positional and keyword arguments are forwarded unchanged, and the
  result of `inspect.getframeinfo` is returned as-is.
  """
  frame_info = _inspect.getframeinfo(*args, **kwargs)
  return frame_info
308
309
def getdoc(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.getdoc.

  The argument is deliberately NOT unwrapped: the outermost decorated object
  is expected to carry the most complete documentation.

  Args:
    object: An object, possibly decorated.

  Returns:
    The docstring associated with the object, as cleaned by
    `inspect.getdoc`.
  """
  docstring = _inspect.getdoc(object)
  return docstring
323
324
def getfile(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.getfile.

  TFDecorators are stripped before the file is looked up.
  """
  _, target = tf_decorator.unwrap(object)

  # Stack frames: when only .pyc files are present, `inspect.getfile` may
  # report an incorrect path. Prefer the frame's `__file__` global if it has
  # one.
  if hasattr(target, 'f_globals'):
    frame_globals = target.f_globals
    if '__file__' in frame_globals:
      return frame_globals['__file__']
  return _inspect.getfile(target)
337
338
def getmembers(object, predicate=None):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.getmembers.

  Delegates directly to `inspect.getmembers`; `object` is not unwrapped.
  """
  members = _inspect.getmembers(object, predicate)
  return members
342
343
def getmodule(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.getmodule.

  Delegates directly to `inspect.getmodule`; `object` is not unwrapped.
  """
  module = _inspect.getmodule(object)
  return module
347
348
def getmro(cls):
  """TFDecorator-aware replacement for inspect.getmro.

  Delegates directly to `inspect.getmro`; `cls` is not unwrapped.
  """
  resolution_order = _inspect.getmro(cls)
  return resolution_order
352
353
def getsource(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.getsource.

  Returns the source of the innermost target after stripping TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.getsource(target)
357
358
def getsourcefile(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.getsourcefile.

  Returns the source file of the innermost target after stripping
  TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.getsourcefile(target)
362
363
def getsourcelines(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.getsourcelines.

  Returns the source lines of the innermost target after stripping
  TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.getsourcelines(target)
367
368
def isbuiltin(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.isbuiltin.

  The check is applied to the innermost target after stripping TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.isbuiltin(target)
372
373
def isclass(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.isclass.

  The check is applied to the innermost target after stripping TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.isclass(target)
377
378
def isfunction(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.isfunction.

  The check is applied to the innermost target after stripping TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.isfunction(target)
382
383
def isframe(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.isframe.

  Args:
    object: An object, possibly decorated with TFDecorators.

  Returns:
    Whether the innermost target, after stripping TFDecorators, is a frame.
  """
  return _inspect.isframe(tf_decorator.unwrap(object)[1])
387
388
def isgenerator(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.isgenerator.

  The check is applied to the innermost target after stripping TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.isgenerator(target)
392
393
def isgeneratorfunction(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.isgeneratorfunction.

  The check is applied to the innermost target after stripping TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.isgeneratorfunction(target)
397
398
def ismethod(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.ismethod.

  The check is applied to the innermost target after stripping TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.ismethod(target)
402
403
def isanytargetmethod(object):  # pylint: disable=redefined-builtin
  # pylint: disable=g-doc-args,g-doc-return-or-yield
  """Checks if `object` or a TF Decorator wrapped target contains self or cls.

  This function could be used along with `tf_inspect.getfullargspec` to
  determine if the first argument of `object` argspec is self or cls. If the
  first argument is self or cls, it needs to be excluded from argspec when we
  compare the argspec to the input arguments and, if provided, the tf.function
  input_signature.

  Like `tf_inspect.getfullargspec` and python `inspect.getfullargspec`, it
  does not unwrap python decorators.

  Args:
    object: A method, function, or functools.partial, possibly decorated by
      TFDecorator.

  Returns:
    A bool indicating whether `object` or any target along the chain of TF
    decorators is a method. Note that the final check treats every callable
    that is not a plain Python function (e.g. classes, instances with
    `__call__`, builtins) as a method.
  """
  decorators, target = tf_decorator.unwrap(object)
  for decorator in decorators:
    if _inspect.ismethod(decorator.decorated_target):
      return True

  # TODO(b/194845243): Implement the long term solution with inspect.signature.
  # A functools.partial object is not a function or method. But if the wrapped
  # func is a method, the argspec will contain self/cls.
  while isinstance(target, functools.partial):
    target = target.func

  # `target` is a method or an instance with __call__
  return callable(target) and not _inspect.isfunction(target)
438
439
def ismodule(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.ismodule.

  The check is applied to the innermost target after stripping TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.ismodule(target)
443
444
def isroutine(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.isroutine.

  The check is applied to the innermost target after stripping TFDecorators.
  """
  _, target = tf_decorator.unwrap(object)
  return _inspect.isroutine(target)
448
449
def stack(context=1):
  """TFDecorator-aware replacement for inspect.stack.

  Args:
    context: Number of source context lines, forwarded to `inspect.stack`.

  Returns:
    The `inspect.stack` frame records starting at the caller of this
    function; this wrapper's own frame is dropped.
  """
  _, *callers = _inspect.stack(context)
  return callers
453