xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/backend.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=redefined-outer-name
17# pylint: disable=redefined-builtin
18# pylint: disable=g-classes-have-attributes
19"""Keras backend API."""
20
21import collections
22import itertools
23import json
24import os
25import sys
26import threading
27import warnings
28import weakref
29
30import numpy as np
31
32from tensorflow.core.protobuf import config_pb2
33from tensorflow.python import tf2
34from tensorflow.python.checkpoint import checkpoint as tracking_util
35from tensorflow.python.client import session as session_module
36from tensorflow.python.distribute import distribution_strategy_context
37from tensorflow.python.eager import context
38from tensorflow.python.eager.context import get_config
39from tensorflow.python.framework import composite_tensor
40from tensorflow.python.framework import config
41from tensorflow.python.framework import constant_op
42from tensorflow.python.framework import device_spec
43from tensorflow.python.framework import dtypes as dtypes_module
44from tensorflow.python.framework import func_graph
45from tensorflow.python.framework import ops
46from tensorflow.python.framework import sparse_tensor
47from tensorflow.python.framework import tensor_shape
48from tensorflow.python.framework import tensor_spec
49from tensorflow.python.framework import tensor_util
50from tensorflow.python.keras import backend_config
51from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc
52from tensorflow.python.keras.engine import keras_tensor
53from tensorflow.python.keras.utils import control_flow_util
54from tensorflow.python.keras.utils import object_identity
55from tensorflow.python.keras.utils import tf_contextlib
56from tensorflow.python.keras.utils import tf_inspect
57from tensorflow.python.ops import array_ops
58from tensorflow.python.ops import clip_ops
59from tensorflow.python.ops import control_flow_ops
60from tensorflow.python.ops import ctc_ops as ctc
61from tensorflow.python.ops import functional_ops
62from tensorflow.python.ops import gradients as gradients_module
63from tensorflow.python.ops import image_ops
64from tensorflow.python.ops import init_ops
65from tensorflow.python.ops import linalg_ops
66from tensorflow.python.ops import logging_ops
67from tensorflow.python.ops import map_fn as map_fn_lib
68from tensorflow.python.ops import math_ops
69from tensorflow.python.ops import nn
70from tensorflow.python.ops import random_ops
71from tensorflow.python.ops import sparse_ops
72from tensorflow.python.ops import state_ops
73from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
74from tensorflow.python.ops import tensor_array_ops
75from tensorflow.python.ops import variables as variables_module
76from tensorflow.python.ops.ragged import ragged_tensor
77from tensorflow.python.platform import tf_logging as logging
78from tensorflow.python.training import moving_averages
79from tensorflow.python.util import dispatch
80from tensorflow.python.util import keras_deps
81from tensorflow.python.util import nest
82from tensorflow.python.util.tf_export import keras_export
83from tensorflow.tools.docs import doc_controls
84
85py_all = all
86py_sum = sum
87py_any = any
88
89# INTERNAL UTILS
90
91# The internal graph maintained by Keras and used by the symbolic Keras APIs
92# while executing eagerly (such as the functional API for model-building).
93# This is thread-local to allow building separate models in different threads
94# concurrently, but comes at the cost of not being able to build one model
95# across threads.
96_GRAPH = threading.local()
97
98# A graph which is used for constructing functions in eager mode.
99_CURRENT_SCRATCH_GRAPH = threading.local()
100
101# This is a thread local object that will hold the default internal TF session
102# used by Keras. It can be set manually via `set_session(sess)`.
103_SESSION = threading.local()
104
105
106# A global dictionary mapping graph objects to an index of counters used
107# for various layer/optimizer names in each graph.
108# Allows to give unique autogenerated names to layers, in a graph-specific way.
109PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary()
110
111
112# A global set tracking what object names have been seen so far.
113# Optionally used as an avoid-list when generating names
114OBSERVED_NAMES = set()
115
116
117# _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
118# We keep a separate reference to it to make sure it does not get removed from
119# _GRAPH_LEARNING_PHASES.
120# _DummyEagerGraph inherits from threading.local to make its `key` attribute
121# thread local. This is needed to make set_learning_phase affect only the
122# current thread during eager execution (see b/123096885 for more details).
123class _DummyEagerGraph(threading.local):
124  """_DummyEagerGraph provides a thread local `key` attribute.
125
126  We can't use threading.local directly, i.e. without subclassing, because
127  gevent monkey patches threading.local and its version does not support
128  weak references.
129  """
130
131  class _WeakReferencableClass:
132    """This dummy class is needed for two reasons.
133
134    - We need something that supports weak references. Basic types like string
135    and ints don't.
136    - We need something whose hash and equality are based on object identity
137    to make sure they are treated as different keys to _GRAPH_LEARNING_PHASES.
138
139    An empty Python class satisfies both of these requirements.
140    """
141    pass
142
143  def __init__(self):
144    # Constructors for classes subclassing threading.local run once
145    # per thread accessing something in the class. Thus, each thread will
146    # get a different key.
147    super(_DummyEagerGraph, self).__init__()
148    self.key = _DummyEagerGraph._WeakReferencableClass()
149    self.learning_phase_is_set = False
150
151
152_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
153
154# This boolean flag can be set to True to leave variable initialization
155# up to the user.
156# Change its value via `manual_variable_initialization(value)`.
157_MANUAL_VAR_INIT = False
158
159# This list holds the available devices.
160# It is populated when `_get_available_gpus()` is called for the first time.
161# We assume our devices don't change henceforth.
162_LOCAL_DEVICES = None
163
164# The below functions are kept accessible from backend for compatibility.
165epsilon = backend_config.epsilon
166floatx = backend_config.floatx
167image_data_format = backend_config.image_data_format
168set_epsilon = backend_config.set_epsilon
169set_floatx = backend_config.set_floatx
170set_image_data_format = backend_config.set_image_data_format
171
172
173@keras_export('keras.backend.backend')
174@doc_controls.do_not_generate_docs
175def backend():
176  """Publicly accessible method for determining the current backend.
177
178  Only exists for API compatibility with multi-backend Keras.
179
180  Returns:
181      The string "tensorflow".
182  """
183  return 'tensorflow'
184
185
186@keras_export('keras.backend.cast_to_floatx')
187@dispatch.add_dispatch_support
188@doc_controls.do_not_generate_docs
189def cast_to_floatx(x):
190  """Cast a Numpy array to the default Keras float type.
191
192  Args:
193      x: Numpy array or TensorFlow tensor.
194
195  Returns:
196      The same array (Numpy array if `x` was a Numpy array, or TensorFlow tensor
197      if `x` was a tensor), cast to its new type.
198
199  Example:
200
201  >>> tf.keras.backend.floatx()
202  'float32'
203  >>> arr = np.array([1.0, 2.0], dtype='float64')
204  >>> arr.dtype
205  dtype('float64')
206  >>> new_arr = cast_to_floatx(arr)
207  >>> new_arr
208  array([1.,  2.], dtype=float32)
209  >>> new_arr.dtype
210  dtype('float32')
211
212  """
213  if isinstance(x, (ops.Tensor,
214                    variables_module.Variable,
215                    sparse_tensor.SparseTensor)):
216    return math_ops.cast(x, dtype=floatx())
217  return np.asarray(x, dtype=floatx())
218
219
220@keras_export('keras.backend.get_uid')
221def get_uid(prefix=''):
222  """Associates a string prefix with an integer counter in a TensorFlow graph.
223
224  Args:
225    prefix: String prefix to index.
226
227  Returns:
228    Unique integer ID.
229
230  Example:
231
232  >>> get_uid('dense')
233  1
234  >>> get_uid('dense')
235  2
236
237  """
238  graph = get_graph()
239  if graph not in PER_GRAPH_OBJECT_NAME_UIDS:
240    PER_GRAPH_OBJECT_NAME_UIDS[graph] = collections.defaultdict(int)
241  layer_name_uids = PER_GRAPH_OBJECT_NAME_UIDS[graph]
242  layer_name_uids[prefix] += 1
243  return layer_name_uids[prefix]
244
245
246@keras_export('keras.backend.reset_uids')
247def reset_uids():
248  """Resets graph identifiers.
249  """
250
251  PER_GRAPH_OBJECT_NAME_UIDS.clear()
252  OBSERVED_NAMES.clear()
253
254
255@keras_export('keras.backend.clear_session')
256def clear_session():
257  """Resets all state generated by Keras.
258
259  Keras manages a global state, which it uses to implement the Functional
260  model-building API and to uniquify autogenerated layer names.
261
262  If you are creating many models in a loop, this global state will consume
263  an increasing amount of memory over time, and you may want to clear it.
264  Calling `clear_session()` releases the global state: this helps avoid clutter
265  from old models and layers, especially when memory is limited.
266
267  Example 1: calling `clear_session()` when creating models in a loop
268
269  ```python
270  for _ in range(100):
271    # Without `clear_session()`, each iteration of this loop will
272    # slightly increase the size of the global state managed by Keras
273    model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])
274
275  for _ in range(100):
276    # With `clear_session()` called at the beginning,
277    # Keras starts with a blank state at each iteration
278    # and memory consumption is constant over time.
279    tf.keras.backend.clear_session()
280    model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])
281  ```
282
283  Example 2: resetting the layer name generation counter
284
285  >>> import tensorflow as tf
286  >>> layers = [tf.keras.layers.Dense(10) for _ in range(10)]
287  >>> new_layer = tf.keras.layers.Dense(10)
288  >>> print(new_layer.name)
289  dense_10
290  >>> tf.keras.backend.set_learning_phase(1)
291  >>> print(tf.keras.backend.learning_phase())
292  1
293  >>> tf.keras.backend.clear_session()
294  >>> new_layer = tf.keras.layers.Dense(10)
295  >>> print(new_layer.name)
296  dense
297  """
298  global _SESSION
299  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
300  global _GRAPH_VARIABLES  # pylint: disable=global-variable-not-assigned
301  global _GRAPH_TF_OPTIMIZERS  # pylint: disable=global-variable-not-assigned
302  global _GRAPH
303  _GRAPH.graph = None
304  ops.reset_default_graph()
305  reset_uids()
306  _SESSION.session = None
307  graph = get_graph()
308  with graph.as_default():
309    _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
310    _GRAPH_LEARNING_PHASES.clear()
311    # Create the learning phase placeholder in graph using the default factory.
312    _GRAPH_LEARNING_PHASES.setdefault(graph)
313    _GRAPH_VARIABLES.pop(graph, None)
314    _GRAPH_TF_OPTIMIZERS.pop(graph, None)
315  if context.executing_eagerly():
316    # Clear pending nodes in eager executors, kernel caches and step_containers.
317    context.context().clear_kernel_cache()
318
319# Inject the clear_session function to keras_deps to remove the dependency
320# from TFLite to Keras.
321keras_deps.register_clear_session_function(clear_session)
322
323
324@keras_export('keras.backend.manual_variable_initialization')
325@doc_controls.do_not_generate_docs
326def manual_variable_initialization(value):
327  """Sets the manual variable initialization flag.
328
329  This boolean flag determines whether
330  variables should be initialized
331  as they are instantiated (default), or if
332  the user should handle the initialization
333  (e.g. via `tf.compat.v1.initialize_all_variables()`).
334
335  Args:
336      value: Python boolean.
337  """
338  global _MANUAL_VAR_INIT
339  _MANUAL_VAR_INIT = value
340
341
342@keras_export('keras.backend.learning_phase')
343@doc_controls.do_not_generate_docs
344def learning_phase():
345  """Returns the learning phase flag.
346
347  The learning phase flag is a bool tensor (0 = test, 1 = train)
348  to be passed as input to any Keras function
349  that uses a different behavior at train time and test time.
350
351  Returns:
352      Learning phase (scalar integer tensor or Python integer).
353  """
354  graph = ops.get_default_graph()
355  if graph is getattr(_GRAPH, 'graph', None):
356    # Don't enter an init_scope for the learning phase if eager execution
357    # is enabled but we're inside the Keras workspace graph.
358    learning_phase = symbolic_learning_phase()
359  else:
360    with ops.init_scope():
361      # We always check & set the learning phase inside the init_scope,
362      # otherwise the wrong default_graph will be used to look up the learning
363      # phase inside of functions & defuns.
364      #
365      # This is because functions & defuns (both in graph & in eager mode)
366      # will always execute non-eagerly using a function-specific default
367      # subgraph.
368      learning_phase = _GRAPH_LEARNING_PHASES[None]
369  _mark_func_graph_as_unsaveable(graph, learning_phase)
370  return learning_phase
371
372
373def global_learning_phase_is_set():
374  return _DUMMY_EAGER_GRAPH.learning_phase_is_set
375
376
377def _mark_func_graph_as_unsaveable(graph, learning_phase):
378  """Mark func graph as unsaveable due to use of symbolic keras learning phase.
379
380  Functions that capture the symbolic learning phase cannot be exported to
381  SavedModel. Mark the funcgraph as unsaveable, so that an error will be raised
382  if it is exported.
383
384  Args:
385    graph: Graph or FuncGraph object.
386    learning_phase: Learning phase placeholder or int defined in the graph.
387  """
388  if graph.building_function and is_placeholder(learning_phase):
389    graph.mark_as_unsaveable(
390        'The keras learning phase placeholder was used inside a function. '
391        'Exporting placeholders is not supported when saving out a SavedModel. '
392        'Please call `tf.keras.backend.set_learning_phase(0)` in the function '
393        'to set the learning phase to a constant value.')
394
395
396def symbolic_learning_phase():
397  graph = get_graph()
398  with graph.as_default():
399    return _GRAPH_LEARNING_PHASES[graph]
400
401
402def _default_learning_phase():
403  if context.executing_eagerly():
404    return 0
405  else:
406    with name_scope(''):
407      return array_ops.placeholder_with_default(
408          False, shape=(), name='keras_learning_phase')
409
410
411@keras_export('keras.backend.set_learning_phase')
412@doc_controls.do_not_generate_docs
413def set_learning_phase(value):
414  """Sets the learning phase to a fixed value.
415
416  The backend learning phase affects any code that calls
417  `backend.learning_phase()`
418  In particular, all Keras built-in layers use the learning phase as the default
419  for the `training` arg to `Layer.__call__`.
420
421  User-written layers and models can achieve the same behavior with code that
422  looks like:
423
424  ```python
425    def call(self, inputs, training=None):
426      if training is None:
427        training = backend.learning_phase()
428  ```
429
430  Args:
431      value: Learning phase value, either 0 or 1 (integers).
432             0 = test, 1 = train
433
434  Raises:
435      ValueError: if `value` is neither `0` nor `1`.
436  """
437  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
438                'will be removed after 2020-10-11. To update it, simply '
439                'pass a True/False value to the `training` argument of the '
440                '`__call__` method of your layer or model.')
441  deprecated_internal_set_learning_phase(value)
442
443
444def deprecated_internal_set_learning_phase(value):
445  """A deprecated internal implementation of set_learning_phase.
446
447  This method is an internal-only version of `set_learning_phase` that
448  does not raise a deprecation error. It is required because
449  saved_model needs to keep working with user code that uses the deprecated
450  learning phase methods until those APIs are fully removed from the public API.
451
452  Specifically SavedModel saving needs to make sure the learning phase is 0
453  during tracing even if users overwrote it to a different value.
454
455  But, we don't want to raise deprecation warnings for users when savedmodel
456  sets learning phase just for compatibility with code that relied on
457  explicitly setting the learning phase for other values.
458
459  Args:
460      value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
461
462  Raises:
463      ValueError: if `value` is neither `0` nor `1`.
464  """
465  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
466  if value not in {0, 1}:
467    raise ValueError('Expected learning phase to be 0 or 1.')
468  with ops.init_scope():
469    if context.executing_eagerly():
470      # In an eager context, the learning phase values applies to both the eager
471      # context and the internal Keras graph.
472      _DUMMY_EAGER_GRAPH.learning_phase_is_set = True
473      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
474    _GRAPH_LEARNING_PHASES[get_graph()] = value
475
476
477@keras_export('keras.backend.learning_phase_scope')
478@tf_contextlib.contextmanager
479@doc_controls.do_not_generate_docs
480def learning_phase_scope(value):
481  """Provides a scope within which the learning phase is equal to `value`.
482
483  The learning phase gets restored to its original value upon exiting the scope.
484
485  Args:
486     value: Learning phase value, either 0 or 1 (integers).
487            0 = test, 1 = train
488
489  Yields:
490    None.
491
492  Raises:
493     ValueError: if `value` is neither `0` nor `1`.
494  """
495  warnings.warn('`tf.keras.backend.learning_phase_scope` is deprecated and '
496                'will be removed after 2020-10-11. To update it, simply '
497                'pass a True/False value to the `training` argument of the '
498                '`__call__` method of your layer or model.')
499  with deprecated_internal_learning_phase_scope(value):
500    try:
501      yield
502    finally:
503      pass
504
505
506@tf_contextlib.contextmanager
507def deprecated_internal_learning_phase_scope(value):
508  """An internal-only version of `learning_phase_scope`.
509
510  Unlike the public method, this method does not raise a deprecation warning.
511  This is needed because saved model saving needs to set learning phase
512  to maintain compatibility
513  with code that sets/gets the learning phase, but saved model
514  saving itself shouldn't raise a deprecation warning.
515
516  We can get rid of this method and its usages when the public API is
517  removed.
518
519  Args:
520     value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
521
522  Yields:
523    None.
524
525  Raises:
526     ValueError: if `value` is neither `0` nor `1`.
527  """
528  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
529  if value not in {0, 1}:
530    raise ValueError('Expected learning phase to be 0 or 1.')
531
532  with ops.init_scope():
533    if context.executing_eagerly():
534      previous_eager_value = _GRAPH_LEARNING_PHASES.get(
535          _DUMMY_EAGER_GRAPH.key, None)
536    previous_graph_value = _GRAPH_LEARNING_PHASES.get(get_graph(), None)
537
538  learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set
539  try:
540    deprecated_internal_set_learning_phase(value)
541    yield
542  finally:
543    # Restore learning phase to initial value.
544    if not learning_phase_previously_set:
545      _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
546    with ops.init_scope():
547      if context.executing_eagerly():
548        if previous_eager_value is not None:
549          _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_eager_value
550        elif _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES:
551          del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
552
553      graph = get_graph()
554      if previous_graph_value is not None:
555        _GRAPH_LEARNING_PHASES[graph] = previous_graph_value
556      elif graph in _GRAPH_LEARNING_PHASES:
557        del _GRAPH_LEARNING_PHASES[graph]
558
559
560@tf_contextlib.contextmanager
561def eager_learning_phase_scope(value):
562  """Internal scope that sets the learning phase in eager / tf.function only.
563
564  Args:
565      value: Learning phase value, either 0 or 1 (integers).
566             0 = test, 1 = train
567
568  Yields:
569    None.
570
571  Raises:
572     ValueError: if `value` is neither `0` nor `1`.
573  """
574  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
575  assert value in {0, 1}
576  assert ops.executing_eagerly_outside_functions()
577  global_learning_phase_was_set = global_learning_phase_is_set()
578  if global_learning_phase_was_set:
579    previous_value = learning_phase()
580  try:
581    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
582    yield
583  finally:
584    # Restore learning phase to initial value or unset.
585    if global_learning_phase_was_set:
586      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_value
587    else:
588      del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
589
590
591def _as_graph_element(obj):
592  """Convert `obj` to a graph element if possible, otherwise return `None`.
593
594  Args:
595    obj: Object to convert.
596
597  Returns:
598    The result of `obj._as_graph_element()` if that method is available;
599        otherwise `None`.
600  """
601  conv_fn = getattr(obj, '_as_graph_element', None)
602  if conv_fn and callable(conv_fn):
603    return conv_fn()
604  return None
605
606
607def _assert_same_graph(original_item, item):
608  """Fail if the 2 items are from different graphs.
609
610  Args:
611    original_item: Original item to check against.
612    item: Item to check.
613
614  Raises:
615    ValueError: if graphs do not match.
616  """
617  original_graph = getattr(original_item, 'graph', None)
618  graph = getattr(item, 'graph', None)
619  if original_graph and graph and original_graph is not graph:
620    raise ValueError(
621        '%s must be from the same graph as %s (graphs are %s and %s).' %
622        (item, original_item, graph, original_graph))
623
624
625def _current_graph(op_input_list, graph=None):
626  """Returns the appropriate graph to use for the given inputs.
627
628  This library method provides a consistent algorithm for choosing the graph
629  in which an Operation should be constructed:
630
631  1. If the default graph is being used to construct a function, we
632     use the default graph.
633  2. If the "graph" is specified explicitly, we validate that all of the inputs
634     in "op_input_list" are compatible with that graph.
635  3. Otherwise, we attempt to select a graph from the first Operation-
636     or Tensor-valued input in "op_input_list", and validate that all other
637     such inputs are in the same graph.
638  4. If the graph was not specified and it could not be inferred from
639     "op_input_list", we attempt to use the default graph.
640
641  Args:
642    op_input_list: A list of inputs to an operation, which may include `Tensor`,
643      `Operation`, and other objects that may be converted to a graph element.
644    graph: (Optional) The explicit graph to use.
645
646  Raises:
647    TypeError: If op_input_list is not a list or tuple, or if graph is not a
648      Graph.
649    ValueError: If a graph is explicitly passed and not all inputs are from it,
650      or if the inputs are from multiple graphs, or we could not find a graph
651      and there was no default graph.
652
653  Returns:
654    The appropriate graph to use for the given inputs.
655
656  """
657  current_default_graph = ops.get_default_graph()
658  if current_default_graph.building_function:
659    return current_default_graph
660
661  op_input_list = tuple(op_input_list)  # Handle generators correctly
662  if graph and not isinstance(graph, ops.Graph):
663    raise TypeError('Input graph needs to be a Graph: %s' % (graph,))
664
665  # 1. We validate that all of the inputs are from the same graph. This is
666  #    either the supplied graph parameter, or the first one selected from one
667  #    the graph-element-valued inputs. In the latter case, we hold onto
668  #    that input in original_graph_element so we can provide a more
669  #    informative error if a mismatch is found.
670  original_graph_element = None
671  for op_input in op_input_list:
672    # Determine if this is a valid graph_element.
673    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
674    # up.
675    if (isinstance(op_input, (
676        ops.Operation, ops.Tensor, composite_tensor.CompositeTensor)) and
677        ((not isinstance(op_input, ops.Tensor))
678         or type(op_input) == ops.Tensor)):  # pylint: disable=unidiomatic-typecheck
679      graph_element = op_input
680    else:
681      graph_element = _as_graph_element(op_input)
682
683    if graph_element is not None:
684      if not graph:
685        original_graph_element = graph_element
686        graph = getattr(graph_element, 'graph', None)
687      elif original_graph_element is not None:
688        _assert_same_graph(original_graph_element, graph_element)
689      elif graph_element.graph is not graph:
690        raise ValueError('%s is not from the passed-in graph.' % graph_element)
691
692  # 2. If all else fails, we use the default graph, which is always there.
693  return graph or current_default_graph
694
695
696def _get_session(op_input_list=()):
697  """Returns the session object for the current thread."""
698  global _SESSION
699  default_session = ops.get_default_session()
700  if default_session is not None:
701    session = default_session
702  else:
703    if ops.inside_function():
704      raise RuntimeError('Cannot get session inside Tensorflow graph function.')
705    # If we don't have a session, or that session does not match the current
706    # graph, create and cache a new session.
707    if (getattr(_SESSION, 'session', None) is None or
708        _SESSION.session.graph is not _current_graph(op_input_list)):
709      # If we are creating the Session inside a tf.distribute.Strategy scope,
710      # we ask the strategy for the right session options to use.
711      if distribution_strategy_context.has_strategy():
712        configure_and_create_distributed_session(
713            distribution_strategy_context.get_strategy())
714      else:
715        _SESSION.session = session_module.Session(
716            config=get_default_session_config())
717    session = _SESSION.session
718  return session
719
720
721@keras_export(v1=['keras.backend.get_session'])
722def get_session(op_input_list=()):
723  """Returns the TF session to be used by the backend.
724
725  If a default TensorFlow session is available, we will return it.
726
727  Else, we will return the global Keras session assuming it matches
728  the current graph.
729
730  If no global Keras session exists at this point:
731  we will create a new global session.
732
733  Note that you can manually set the global session
734  via `K.set_session(sess)`.
735
736  Args:
737      op_input_list: An option sequence of tensors or ops, which will be used
738        to determine the current graph. Otherwise the default graph will be
739        used.
740
741  Returns:
742      A TensorFlow session.
743  """
744  session = _get_session(op_input_list)
745  if not _MANUAL_VAR_INIT:
746    with session.graph.as_default():
747      _initialize_variables(session)
748  return session
749
750# Inject the get_session function to keras_deps to remove the dependency
751# from TFLite to Keras.
752keras_deps.register_get_session_function(get_session)
753
754# Inject the get_session function to tracking_util to avoid the backward
755# dependency from TF to Keras.
756tracking_util.register_session_provider(get_session)
757
758
759def get_graph():
760  if context.executing_eagerly():
761    global _GRAPH
762    if not getattr(_GRAPH, 'graph', None):
763      _GRAPH.graph = func_graph.FuncGraph('keras_graph')
764    return _GRAPH.graph
765  else:
766    return ops.get_default_graph()
767
768
769@tf_contextlib.contextmanager
770def _scratch_graph(graph=None):
771  """Retrieve a shared and temporary func graph.
772
773  The eager execution path lifts a subgraph from the keras global graph into
774  a scratch graph in order to create a function. DistributionStrategies, in
775  turn, constructs multiple functions as well as a final combined function. In
776  order for that logic to work correctly, all of the functions need to be
777  created on the same scratch FuncGraph.
778
779  Args:
780    graph: A graph to be used as the current scratch graph. If not set then
781      a scratch graph will either be retrieved or created:
782
783  Yields:
784    The current scratch graph.
785  """
786  global _CURRENT_SCRATCH_GRAPH
787  scratch_graph = getattr(_CURRENT_SCRATCH_GRAPH, 'graph', None)
788  # If scratch graph and `graph` are both configured, they must match.
789  if (scratch_graph is not None and graph is not None and
790      scratch_graph is not graph):
791    raise ValueError('Multiple scratch graphs specified.')
792
793  if scratch_graph:
794    yield scratch_graph
795    return
796
797  graph = graph or func_graph.FuncGraph('keras_scratch_graph')
798  try:
799    _CURRENT_SCRATCH_GRAPH.graph = graph
800    yield graph
801  finally:
802    _CURRENT_SCRATCH_GRAPH.graph = None
803
804
805@keras_export(v1=['keras.backend.set_session'])
806def set_session(session):
807  """Sets the global TensorFlow session.
808
809  Args:
810      session: A TF Session.
811  """
812  global _SESSION
813  _SESSION.session = session
814
815
816def get_default_session_config():
817  if os.environ.get('OMP_NUM_THREADS'):
818    logging.warning(
819        'OMP_NUM_THREADS is no longer used by the default Keras config. '
820        'To configure the number of threads, use tf.config.threading APIs.')
821
822  config = get_config()
823  config.allow_soft_placement = True
824
825  return config
826
827
828def get_default_graph_uid_map():
829  graph = ops.get_default_graph()
830  name_uid_map = PER_GRAPH_OBJECT_NAME_UIDS.get(graph, None)
831  if name_uid_map is None:
832    name_uid_map = collections.defaultdict(int)
833    PER_GRAPH_OBJECT_NAME_UIDS[graph] = name_uid_map
834  return name_uid_map
835
836
837# DEVICE MANIPULATION
838
839
840class _TfDeviceCaptureOp:
841  """Class for capturing the TF device scope."""
842
843  def __init__(self):
844    self.device = None
845
846  def _set_device(self, device):
847    """This method captures TF's explicit device scope setting."""
848    if isinstance(device, device_spec.DeviceSpecV2):
849      device = device.to_string()
850    self.device = device
851
852  def _set_device_from_string(self, device_str):
853    self.device = device_str
854
855
856def _get_current_tf_device():
857  """Return explicit device of current context, otherwise returns `None`.
858
859  Returns:
860      If the current device scope is explicitly set, it returns a string with
861      the device (`CPU` or `GPU`). If the scope is not explicitly set, it will
862      return `None`.
863  """
864  graph = get_graph()
865  op = _TfDeviceCaptureOp()
866  graph._apply_device_functions(op)
867  if tf2.enabled():
868    return device_spec.DeviceSpecV2.from_string(op.device)
869  else:
870    return device_spec.DeviceSpecV1.from_string(op.device)
871
872
873def _is_current_explicit_device(device_type):
874  """Check if the current device is explicitly set on the device type specified.
875
876  Args:
877      device_type: A string containing `GPU` or `CPU` (case-insensitive).
878
879  Returns:
880      A boolean indicating if the current device scope is explicitly set on the
881      device type.
882
883  Raises:
884      ValueError: If the `device_type` string indicates an unsupported device.
885  """
886  device_type = device_type.upper()
887  if device_type not in ['CPU', 'GPU']:
888    raise ValueError('`device_type` should be either "CPU" or "GPU".')
889  device = _get_current_tf_device()
890  return device is not None and device.device_type == device_type.upper()
891
892
893def _get_available_gpus():
894  """Get a list of available GPU devices (formatted as strings).
895
896  Returns:
897      A list of available GPU devices.
898  """
899  if ops.executing_eagerly_outside_functions():
900    # Returns names of devices directly.
901    return [d.name for d in config.list_logical_devices('GPU')]
902
903  global _LOCAL_DEVICES
904  if _LOCAL_DEVICES is None:
905    _LOCAL_DEVICES = get_session().list_devices()
906  return [x.name for x in _LOCAL_DEVICES if x.device_type == 'GPU']
907
908
909def _has_nchw_support():
910  """Check whether the current scope supports NCHW ops.
911
912  TensorFlow does not support NCHW on CPU. Therefore we check if we are not
913  explicitly put on
914  CPU, and have GPUs available. In this case there will be soft-placing on the
915  GPU device.
916
917  Returns:
918      bool: if the current scope device placement would support nchw
919  """
920  explicitly_on_cpu = _is_current_explicit_device('CPU')
921  gpus_available = bool(_get_available_gpus())
922  return not explicitly_on_cpu and gpus_available
923
924
925# VARIABLE MANIPULATION
926
927
928def _constant_to_tensor(x, dtype):
929  """Convert the input `x` to a tensor of type `dtype`.
930
931  This is slightly faster than the _to_tensor function, at the cost of
932  handling fewer cases.
933
934  Args:
935      x: An object to be converted (numpy arrays, floats, ints and lists of
936        them).
937      dtype: The destination type.
938
939  Returns:
940      A tensor.
941  """
942  return constant_op.constant(x, dtype=dtype)
943
944
945def _to_tensor(x, dtype):
946  """Convert the input `x` to a tensor of type `dtype`.
947
948  Args:
949      x: An object to be converted (numpy array, list, tensors).
950      dtype: The destination type.
951
952  Returns:
953      A tensor.
954  """
955  return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype)
956
957
958@keras_export('keras.backend.is_sparse')
959@doc_controls.do_not_generate_docs
960def is_sparse(tensor):
961  """Returns whether a tensor is a sparse tensor.
962
963  Args:
964      tensor: A tensor instance.
965
966  Returns:
967      A boolean.
968
969  Example:
970
971
972  >>> a = tf.keras.backend.placeholder((2, 2), sparse=False)
973  >>> print(tf.keras.backend.is_sparse(a))
974  False
975  >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
976  >>> print(tf.keras.backend.is_sparse(b))
977  True
978
979  """
980  spec = getattr(tensor, '_type_spec', None)
981  if spec is not None:
982    return isinstance(spec, sparse_tensor.SparseTensorSpec)
983  return isinstance(tensor, sparse_tensor.SparseTensor)
984
985
986@keras_export('keras.backend.to_dense')
987@dispatch.add_dispatch_support
988@doc_controls.do_not_generate_docs
989def to_dense(tensor):
990  """Converts a sparse tensor into a dense tensor and returns it.
991
992  Args:
993      tensor: A tensor instance (potentially sparse).
994
995  Returns:
996      A dense tensor.
997
998  Examples:
999
1000
1001  >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
1002  >>> print(tf.keras.backend.is_sparse(b))
1003  True
1004  >>> c = tf.keras.backend.to_dense(b)
1005  >>> print(tf.keras.backend.is_sparse(c))
1006  False
1007
1008  """
1009  if is_sparse(tensor):
1010    return sparse_ops.sparse_tensor_to_dense(tensor)
1011  else:
1012    return tensor
1013
1014
1015@keras_export('keras.backend.name_scope', v1=[])
1016@doc_controls.do_not_generate_docs
1017def name_scope(name):
1018  """A context manager for use when defining a Python op.
1019
1020  This context manager pushes a name scope, which will make the name of all
1021  operations added within it have a prefix.
1022
1023  For example, to define a new Python op called `my_op`:
1024
1025
1026  def my_op(a):
1027    with tf.name_scope("MyOp") as scope:
1028      a = tf.convert_to_tensor(a, name="a")
1029      # Define some computation that uses `a`.
1030      return foo_op(..., name=scope)
1031
1032
1033  When executed, the Tensor `a` will have the name `MyOp/a`.
1034
1035  Args:
1036    name: The prefix to use on all names created within the name scope.
1037
1038  Returns:
1039    Name scope context manager.
1040  """
1041  return ops.name_scope_v2(name)
1042
1043# Export V1 version.
1044_v1_name_scope = ops.name_scope_v1
1045keras_export(v1=['keras.backend.name_scope'])(_v1_name_scope)
1046
1047
1048@keras_export('keras.backend.variable')
1049@doc_controls.do_not_generate_docs
1050def variable(value, dtype=None, name=None, constraint=None):
1051  """Instantiates a variable and returns it.
1052
1053  Args:
1054      value: Numpy array, initial value of the tensor.
1055      dtype: Tensor type.
1056      name: Optional name string for the tensor.
1057      constraint: Optional projection function to be
1058          applied to the variable after an optimizer update.
1059
1060  Returns:
1061      A variable instance (with Keras metadata included).
1062
1063  Examples:
1064
1065  >>> val = np.array([[1, 2], [3, 4]])
1066  >>> kvar = tf.keras.backend.variable(value=val, dtype='float64',
1067  ...                                  name='example_var')
1068  >>> tf.keras.backend.dtype(kvar)
1069  'float64'
1070  >>> print(kvar)
1071  <tf.Variable 'example_var:...' shape=(2, 2) dtype=float64, numpy=
1072    array([[1., 2.],
1073           [3., 4.]])>
1074
1075  """
1076  if dtype is None:
1077    dtype = floatx()
1078  if hasattr(value, 'tocoo'):
1079    sparse_coo = value.tocoo()
1080    indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(
1081        sparse_coo.col, 1)), 1)
1082    v = sparse_tensor.SparseTensor(
1083        indices=indices, values=sparse_coo.data, dense_shape=sparse_coo.shape)
1084    v._keras_shape = sparse_coo.shape
1085    return v
1086  v = variables_module.Variable(
1087      value,
1088      dtype=dtypes_module.as_dtype(dtype),
1089      name=name,
1090      constraint=constraint)
1091  if isinstance(value, np.ndarray):
1092    v._keras_shape = value.shape
1093  elif hasattr(value, 'shape'):
1094    v._keras_shape = int_shape(value)
1095  track_variable(v)
1096  return v
1097
1098
1099def track_tf_optimizer(tf_optimizer):
1100  """Tracks the given TF optimizer for initialization of its variables."""
1101  if context.executing_eagerly():
1102    return
1103  optimizers = _GRAPH_TF_OPTIMIZERS[None]
1104  optimizers.add(tf_optimizer)
1105
1106
1107@keras_export('keras.__internal__.backend.track_variable', v1=[])
1108def track_variable(v):
1109  """Tracks the given variable for initialization."""
1110  if context.executing_eagerly():
1111    return
1112  graph = v.graph if hasattr(v, 'graph') else get_graph()
1113  _GRAPH_VARIABLES[graph].add(v)
1114
1115
1116def observe_object_name(name):
1117  """Observe a name and make sure it won't be used by `unique_object_name`."""
1118  OBSERVED_NAMES.add(name)
1119
1120
1121def unique_object_name(name,
1122                       name_uid_map=None,
1123                       avoid_names=None,
1124                       namespace='',
1125                       zero_based=False,
1126                       avoid_observed_names=False):
1127  """Makes a object name (or arbitrary string) unique within a TensorFlow graph.
1128
1129  Args:
1130    name: String name to make unique.
1131    name_uid_map: An optional defaultdict(int) to use when creating unique
1132      names. If None (default), uses a per-Graph dictionary.
1133    avoid_names: An optional set or dict with names which should not be used. If
1134      None (default), don't avoid any names unless `avoid_observed_names` is
1135      True.
1136    namespace: Gets a name which is unique within the (graph, namespace). Layers
1137      which are not Networks use a blank namespace and so get graph-global
1138      names.
1139    zero_based: If True, name sequences start with no suffix (e.g. "dense",
1140      "dense_1"). If False, naming is one-based ("dense_1", "dense_2").
1141    avoid_observed_names: If True, avoid any names that have been observed by
1142      `backend.observe_object_name`.
1143
1144  Returns:
1145    Unique string name.
1146
1147  Example:
1148
1149
1150  unique_object_name('dense')  # dense_1
1151  unique_object_name('dense')  # dense_2
1152
1153  """
1154  if name_uid_map is None:
1155    name_uid_map = get_default_graph_uid_map()
1156  if avoid_names is None:
1157    if avoid_observed_names:
1158      avoid_names = OBSERVED_NAMES
1159    else:
1160      avoid_names = set()
1161  proposed_name = None
1162  while proposed_name is None or proposed_name in avoid_names:
1163    name_key = (namespace, name)
1164    if zero_based:
1165      number = name_uid_map[name_key]
1166      if number:
1167        proposed_name = name + '_' + str(number)
1168      else:
1169        proposed_name = name
1170      name_uid_map[name_key] += 1
1171    else:
1172      name_uid_map[name_key] += 1
1173      proposed_name = name + '_' + str(name_uid_map[name_key])
1174  return proposed_name
1175
1176
1177def _get_variables(graph=None):
1178  """Returns variables corresponding to the given graph for initialization."""
1179  assert not context.executing_eagerly()
1180  variables = _GRAPH_VARIABLES[graph]
1181  for opt in _GRAPH_TF_OPTIMIZERS[graph]:
1182    variables.update(opt.optimizer.variables())
1183  return variables
1184
1185
1186@keras_export('keras.__internal__.backend.initialize_variables', v1=[])
1187def _initialize_variables(session):
1188  """Utility to initialize uninitialized variables on the fly."""
1189  variables = _get_variables(get_graph())
1190  candidate_vars = []
1191  for v in variables:
1192    if not getattr(v, '_keras_initialized', False):
1193      candidate_vars.append(v)
1194  if candidate_vars:
1195    # This step is expensive, so we only run it on variables not already
1196    # marked as initialized.
1197    is_initialized = session.run(
1198        [variables_module.is_variable_initialized(v) for v in candidate_vars])
1199    # TODO(kathywu): Some metric variables loaded from SavedModel are never
1200    # actually used, and do not have an initializer.
1201    should_be_initialized = [
1202        (not is_initialized[n]) and v.initializer is not None
1203        for n, v in enumerate(candidate_vars)]
1204    uninitialized_vars = []
1205    for flag, v in zip(should_be_initialized, candidate_vars):
1206      if flag:
1207        uninitialized_vars.append(v)
1208      v._keras_initialized = True
1209    if uninitialized_vars:
1210      session.run(variables_module.variables_initializer(uninitialized_vars))
1211
1212
1213@keras_export('keras.backend.constant')
1214@dispatch.add_dispatch_support
1215@doc_controls.do_not_generate_docs
1216def constant(value, dtype=None, shape=None, name=None):
1217  """Creates a constant tensor.
1218
1219  Args:
1220      value: A constant value (or list)
1221      dtype: The type of the elements of the resulting tensor.
1222      shape: Optional dimensions of resulting tensor.
1223      name: Optional name for the tensor.
1224
1225  Returns:
1226      A Constant Tensor.
1227  """
1228  if dtype is None:
1229    dtype = floatx()
1230
1231  return constant_op.constant(value, dtype=dtype, shape=shape, name=name)
1232
1233
1234@keras_export('keras.backend.is_keras_tensor')
1235def is_keras_tensor(x):
1236  """Returns whether `x` is a Keras tensor.
1237
1238  A "Keras tensor" is a tensor that was returned by a Keras layer,
1239  (`Layer` class) or by `Input`.
1240
1241  Args:
1242      x: A candidate tensor.
1243
1244  Returns:
1245      A boolean: Whether the argument is a Keras tensor.
1246
1247  Raises:
1248      ValueError: In case `x` is not a symbolic tensor.
1249
1250  Examples:
1251
1252  >>> np_var = np.array([1, 2])
1253  >>> # A numpy array is not a symbolic tensor.
1254  >>> tf.keras.backend.is_keras_tensor(np_var)
1255  Traceback (most recent call last):
1256  ...
1257  ValueError: Unexpectedly found an instance of type `<class 'numpy.ndarray'>`.
1258  Expected a symbolic tensor instance.
1259  >>> keras_var = tf.keras.backend.variable(np_var)
1260  >>> # A variable created with the keras backend is not a Keras tensor.
1261  >>> tf.keras.backend.is_keras_tensor(keras_var)
1262  False
1263  >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
1264  >>> # A placeholder is a Keras tensor.
1265  >>> tf.keras.backend.is_keras_tensor(keras_placeholder)
1266  True
1267  >>> keras_input = tf.keras.layers.Input([10])
1268  >>> # An Input is a Keras tensor.
1269  >>> tf.keras.backend.is_keras_tensor(keras_input)
1270  True
1271  >>> keras_layer_output = tf.keras.layers.Dense(10)(keras_input)
1272  >>> # Any Keras layer output is a Keras tensor.
1273  >>> tf.keras.backend.is_keras_tensor(keras_layer_output)
1274  True
1275
1276  """
1277  if not isinstance(x,
1278                    (ops.Tensor, variables_module.Variable,
1279                     sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor,
1280                     keras_tensor.KerasTensor)):
1281    raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
1282                     '`. Expected a symbolic tensor instance.')
1283  if ops.executing_eagerly_outside_functions():
1284    return isinstance(x, keras_tensor.KerasTensor)
1285  return hasattr(x, '_keras_history')
1286
1287
1288@keras_export('keras.backend.placeholder')
1289@doc_controls.do_not_generate_docs
1290def placeholder(shape=None,
1291                ndim=None,
1292                dtype=None,
1293                sparse=False,
1294                name=None,
1295                ragged=False):
1296  """Instantiates a placeholder tensor and returns it.
1297
1298  Args:
1299      shape: Shape of the placeholder
1300          (integer tuple, may include `None` entries).
1301      ndim: Number of axes of the tensor.
1302          At least one of {`shape`, `ndim`} must be specified.
1303          If both are specified, `shape` is used.
1304      dtype: Placeholder type.
1305      sparse: Boolean, whether the placeholder should have a sparse type.
1306      name: Optional name string for the placeholder.
1307      ragged: Boolean, whether the placeholder should have a ragged type.
1308          In this case, values of 'None' in the 'shape' argument represent
1309          ragged dimensions. For more information about RaggedTensors, see this
1310          [guide](https://www.tensorflow.org/guide/ragged_tensors).
1311
1312  Raises:
1313      ValueError: If called with sparse = True and ragged = True.
1314
1315  Returns:
1316      Tensor instance (with Keras metadata included).
1317
1318  Examples:
1319
1320
1321  >>> input_ph = tf.keras.backend.placeholder(shape=(2, 4, 5))
1322  >>> input_ph
1323  <KerasTensor: shape=(2, 4, 5) dtype=float32 (created by layer ...)>
1324
1325  """
1326  if sparse and ragged:
1327    raise ValueError(
1328        'Cannot set both sparse and ragged to True when creating a placeholder.'
1329    )
1330  if dtype is None:
1331    dtype = floatx()
1332  if not shape:
1333    if ndim:
1334      shape = (None,) * ndim
1335  if ops.executing_eagerly_outside_functions():
1336    if sparse:
1337      spec = sparse_tensor.SparseTensorSpec(
1338          shape=shape, dtype=dtype)
1339    elif ragged:
1340      ragged_rank = 0
1341      for i in range(1, len(shape)):
1342        # Hacky because could be tensorshape or tuple maybe?
1343        # Or just tensorshape?
1344        if shape[i] is None or (
1345            hasattr(shape[i], 'value') and
1346            shape[i].value is None):
1347          ragged_rank = i
1348      spec = ragged_tensor.RaggedTensorSpec(
1349          shape=shape, dtype=dtype, ragged_rank=ragged_rank)
1350    else:
1351      spec = tensor_spec.TensorSpec(
1352          shape=shape, dtype=dtype, name=name)
1353    x = keras_tensor.keras_tensor_from_type_spec(spec, name=name)
1354  else:
1355    with get_graph().as_default():
1356      if sparse:
1357        x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
1358      elif ragged:
1359        ragged_rank = 0
1360        for i in range(1, len(shape)):
1361          if shape[i] is None:
1362            ragged_rank = i
1363        type_spec = ragged_tensor.RaggedTensorSpec(
1364            shape=shape, dtype=dtype, ragged_rank=ragged_rank)
1365        def tensor_spec_to_placeholder(tensorspec):
1366          return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
1367        x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
1368                               expand_composites=True)
1369      else:
1370        x = array_ops.placeholder(dtype, shape=shape, name=name)
1371
1372  if context.executing_eagerly():
1373    # Add keras_history connectivity information to the placeholder
1374    # when the placeholder is built in a top-level eager context
1375    # (intended to be used with keras.backend.function)
1376    from tensorflow.python.keras.engine import input_layer  # pylint: disable=g-import-not-at-top
1377    x = input_layer.Input(tensor=x)
1378    x._is_backend_placeholder = True
1379
1380  return x
1381
1382
1383def is_placeholder(x):
1384  """Returns whether `x` is a placeholder.
1385
1386  Args:
1387      x: A candidate placeholder.
1388
1389  Returns:
1390      Boolean.
1391  """
1392  try:
1393    if ops.executing_eagerly_outside_functions():
1394      return hasattr(x, '_is_backend_placeholder')
1395    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
1396    if tf_utils.is_extension_type(x):
1397      flat_components = nest.flatten(x, expand_composites=True)
1398      return py_any(is_placeholder(c) for c in flat_components)
1399    else:
1400      return x.op.type == 'Placeholder'
1401  except AttributeError:
1402    return False
1403
1404
1405@keras_export('keras.backend.shape')
1406@dispatch.add_dispatch_support
1407@doc_controls.do_not_generate_docs
1408def shape(x):
1409  """Returns the symbolic shape of a tensor or variable.
1410
1411  Args:
1412      x: A tensor or variable.
1413
1414  Returns:
1415      A symbolic shape (which is itself a tensor).
1416
1417  Examples:
1418
1419  >>> val = np.array([[1, 2], [3, 4]])
1420  >>> kvar = tf.keras.backend.variable(value=val)
1421  >>> tf.keras.backend.shape(kvar)
1422  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>
1423  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1424  >>> tf.keras.backend.shape(input)
1425  <KerasTensor: shape=(3,) dtype=int32 inferred_value=[2, 4, 5] ...>
1426
1427  """
1428  return array_ops.shape(x)
1429
1430
1431@keras_export('keras.backend.int_shape')
1432@doc_controls.do_not_generate_docs
1433def int_shape(x):
1434  """Returns the shape of tensor or variable as a tuple of int or None entries.
1435
1436  Args:
1437      x: Tensor or variable.
1438
1439  Returns:
1440      A tuple of integers (or None entries).
1441
1442  Examples:
1443
1444  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1445  >>> tf.keras.backend.int_shape(input)
1446  (2, 4, 5)
1447  >>> val = np.array([[1, 2], [3, 4]])
1448  >>> kvar = tf.keras.backend.variable(value=val)
1449  >>> tf.keras.backend.int_shape(kvar)
1450  (2, 2)
1451
1452  """
1453  try:
1454    shape = x.shape
1455    if not isinstance(shape, tuple):
1456      shape = tuple(shape.as_list())
1457    return shape
1458  except ValueError:
1459    return None
1460
1461
1462@keras_export('keras.backend.ndim')
1463@doc_controls.do_not_generate_docs
1464def ndim(x):
1465  """Returns the number of axes in a tensor, as an integer.
1466
1467  Args:
1468      x: Tensor or variable.
1469
1470  Returns:
1471      Integer (scalar), number of axes.
1472
1473  Examples:
1474
1475
1476  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1477  >>> val = np.array([[1, 2], [3, 4]])
1478  >>> kvar = tf.keras.backend.variable(value=val)
1479  >>> tf.keras.backend.ndim(input)
1480  3
1481  >>> tf.keras.backend.ndim(kvar)
1482  2
1483
1484  """
1485  return x.shape.rank
1486
1487
1488@keras_export('keras.backend.dtype')
1489@dispatch.add_dispatch_support
1490@doc_controls.do_not_generate_docs
1491def dtype(x):
1492  """Returns the dtype of a Keras tensor or variable, as a string.
1493
1494  Args:
1495      x: Tensor or variable.
1496
1497  Returns:
1498      String, dtype of `x`.
1499
1500  Examples:
1501
1502  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5)))
1503  'float32'
1504  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1505  ...                                                     dtype='float32'))
1506  'float32'
1507  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1508  ...                                                     dtype='float64'))
1509  'float64'
1510  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]))
1511  >>> tf.keras.backend.dtype(kvar)
1512  'float32'
1513  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1514  ...                                  dtype='float32')
1515  >>> tf.keras.backend.dtype(kvar)
1516  'float32'
1517
1518  """
1519  return x.dtype.base_dtype.name
1520
1521
1522@doc_controls.do_not_generate_docs
1523def dtype_numpy(x):
1524  """Returns the numpy dtype of a Keras tensor or variable.
1525
1526  Args:
1527      x: Tensor or variable.
1528
1529  Returns:
1530      numpy.dtype, dtype of `x`.
1531  """
1532  return dtypes_module.as_dtype(x.dtype).as_numpy_dtype
1533
1534
1535@keras_export('keras.backend.eval')
1536@doc_controls.do_not_generate_docs
1537def eval(x):
1538  """Evaluates the value of a variable.
1539
1540  Args:
1541      x: A variable.
1542
1543  Returns:
1544      A Numpy array.
1545
1546  Examples:
1547
1548  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1549  ...                                  dtype='float32')
1550  >>> tf.keras.backend.eval(kvar)
1551  array([[1.,  2.],
1552         [3.,  4.]], dtype=float32)
1553
1554  """
1555  return get_value(to_dense(x))
1556
1557
1558@keras_export('keras.backend.zeros')
1559@doc_controls.do_not_generate_docs
1560def zeros(shape, dtype=None, name=None):
1561  """Instantiates an all-zeros variable and returns it.
1562
1563  Args:
1564      shape: Tuple or list of integers, shape of returned Keras variable
1565      dtype: data type of returned Keras variable
1566      name: name of returned Keras variable
1567
1568  Returns:
1569      A variable (including Keras metadata), filled with `0.0`.
1570      Note that if `shape` was symbolic, we cannot return a variable,
1571      and will return a dynamically-shaped tensor instead.
1572
1573  Example:
1574
1575  >>> kvar = tf.keras.backend.zeros((3,4))
1576  >>> tf.keras.backend.eval(kvar)
1577  array([[0.,  0.,  0.,  0.],
1578         [0.,  0.,  0.,  0.],
1579         [0.,  0.,  0.,  0.]], dtype=float32)
1580  >>> A = tf.constant([1,2,3])
1581  >>> kvar2 = tf.keras.backend.zeros(A.shape) # [0., 0., 0.]
1582  >>> tf.keras.backend.eval(kvar2)
1583  array([0., 0., 0.], dtype=float32)
1584  >>> kvar3 = tf.keras.backend.zeros(A.shape,dtype=tf.int32)
1585  >>> tf.keras.backend.eval(kvar3)
1586  array([0, 0, 0], dtype=int32)
1587  >>> kvar4 = tf.keras.backend.zeros([2,3])
1588  >>> tf.keras.backend.eval(kvar4)
1589  array([[0., 0., 0.],
1590         [0., 0., 0.]], dtype=float32)
1591
1592  """
1593  with ops.init_scope():
1594    if dtype is None:
1595      dtype = floatx()
1596    tf_dtype = dtypes_module.as_dtype(dtype)
1597    v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
1598    if py_all(v.shape.as_list()):
1599      return variable(v, dtype=dtype, name=name)
1600    return v
1601
1602
1603@keras_export('keras.backend.ones')
1604@dispatch.add_dispatch_support
1605@doc_controls.do_not_generate_docs
1606def ones(shape, dtype=None, name=None):
1607  """Instantiates an all-ones variable and returns it.
1608
1609  Args:
1610      shape: Tuple of integers, shape of returned Keras variable.
1611      dtype: String, data type of returned Keras variable.
1612      name: String, name of returned Keras variable.
1613
1614  Returns:
1615      A Keras variable, filled with `1.0`.
1616      Note that if `shape` was symbolic, we cannot return a variable,
1617      and will return a dynamically-shaped tensor instead.
1618
1619  Example:
1620
1621
1622  >>> kvar = tf.keras.backend.ones((3,4))
1623  >>> tf.keras.backend.eval(kvar)
1624  array([[1.,  1.,  1.,  1.],
1625         [1.,  1.,  1.,  1.],
1626         [1.,  1.,  1.,  1.]], dtype=float32)
1627
1628  """
1629  with ops.init_scope():
1630    if dtype is None:
1631      dtype = floatx()
1632    tf_dtype = dtypes_module.as_dtype(dtype)
1633    v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
1634    if py_all(v.shape.as_list()):
1635      return variable(v, dtype=dtype, name=name)
1636    return v
1637
1638
1639@keras_export('keras.backend.eye')
1640@dispatch.add_dispatch_support
1641@doc_controls.do_not_generate_docs
1642def eye(size, dtype=None, name=None):
1643  """Instantiate an identity matrix and returns it.
1644
1645  Args:
1646      size: Integer, number of rows/columns.
1647      dtype: String, data type of returned Keras variable.
1648      name: String, name of returned Keras variable.
1649
1650  Returns:
1651      A Keras variable, an identity matrix.
1652
1653  Example:
1654
1655
1656  >>> kvar = tf.keras.backend.eye(3)
1657  >>> tf.keras.backend.eval(kvar)
1658  array([[1.,  0.,  0.],
1659         [0.,  1.,  0.],
1660         [0.,  0.,  1.]], dtype=float32)
1661
1662
1663  """
1664  if dtype is None:
1665    dtype = floatx()
1666  tf_dtype = dtypes_module.as_dtype(dtype)
1667  return variable(linalg_ops.eye(size, dtype=tf_dtype), dtype, name)
1668
1669
1670@keras_export('keras.backend.zeros_like')
1671@doc_controls.do_not_generate_docs
1672def zeros_like(x, dtype=None, name=None):
1673  """Instantiates an all-zeros variable of the same shape as another tensor.
1674
1675  Args:
1676      x: Keras variable or Keras tensor.
1677      dtype: dtype of returned Keras variable.
1678             `None` uses the dtype of `x`.
1679      name: name for the variable to create.
1680
1681  Returns:
1682      A Keras variable with the shape of `x` filled with zeros.
1683
1684  Example:
1685
1686  ```python
1687  from tensorflow.keras import backend as K
1688  kvar = K.variable(np.random.random((2,3)))
1689  kvar_zeros = K.zeros_like(kvar)
1690  K.eval(kvar_zeros)
1691  # array([[ 0.,  0.,  0.], [ 0.,  0.,  0.]], dtype=float32)
1692  ```
1693  """
1694  return array_ops.zeros_like(x, dtype=dtype, name=name)
1695
1696
1697@keras_export('keras.backend.ones_like')
1698@dispatch.add_dispatch_support
1699@doc_controls.do_not_generate_docs
1700def ones_like(x, dtype=None, name=None):
1701  """Instantiates an all-ones variable of the same shape as another tensor.
1702
1703  Args:
1704      x: Keras variable or tensor.
1705      dtype: String, dtype of returned Keras variable.
1706           None uses the dtype of x.
1707      name: String, name for the variable to create.
1708
1709  Returns:
1710      A Keras variable with the shape of x filled with ones.
1711
1712  Example:
1713
1714  >>> kvar = tf.keras.backend.variable(np.random.random((2,3)))
1715  >>> kvar_ones = tf.keras.backend.ones_like(kvar)
1716  >>> tf.keras.backend.eval(kvar_ones)
1717  array([[1.,  1.,  1.],
1718         [1.,  1.,  1.]], dtype=float32)
1719
1720  """
1721  return array_ops.ones_like(x, dtype=dtype, name=name)
1722
1723
1724def identity(x, name=None):
1725  """Returns a tensor with the same content as the input tensor.
1726
1727  Args:
1728      x: The input tensor.
1729      name: String, name for the variable to create.
1730
1731  Returns:
1732      A tensor of the same shape, type and content.
1733  """
1734  return array_ops.identity(x, name=name)
1735
1736
1737@keras_export('keras.backend.random_uniform_variable')
1738@doc_controls.do_not_generate_docs
1739def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
1740  """Instantiates a variable with values drawn from a uniform distribution.
1741
1742  Args:
1743      shape: Tuple of integers, shape of returned Keras variable.
1744      low: Float, lower boundary of the output interval.
1745      high: Float, upper boundary of the output interval.
1746      dtype: String, dtype of returned Keras variable.
1747      name: String, name of returned Keras variable.
1748      seed: Integer, random seed.
1749
1750  Returns:
1751      A Keras variable, filled with drawn samples.
1752
1753  Example:
1754
1755  >>> kvar = tf.keras.backend.random_uniform_variable(shape=(2,3),
1756  ... low=0.0, high=1.0)
1757  >>> kvar
1758  <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
1759  dtype=float32)>
1760  """
1761  if dtype is None:
1762    dtype = floatx()
1763  tf_dtype = dtypes_module.as_dtype(dtype)
1764  if seed is None:
1765    # ensure that randomness is conditioned by the Numpy RNG
1766    seed = np.random.randint(10e8)
1767  value = init_ops.random_uniform_initializer(
1768      low, high, dtype=tf_dtype, seed=seed)(shape)
1769  return variable(value, dtype=dtype, name=name)
1770
1771
1772@keras_export('keras.backend.random_normal_variable')
1773@doc_controls.do_not_generate_docs
1774def random_normal_variable(shape, mean, scale, dtype=None, name=None,
1775                           seed=None):
1776  """Instantiates a variable with values drawn from a normal distribution.
1777
1778  Args:
1779      shape: Tuple of integers, shape of returned Keras variable.
1780      mean: Float, mean of the normal distribution.
1781      scale: Float, standard deviation of the normal distribution.
1782      dtype: String, dtype of returned Keras variable.
1783      name: String, name of returned Keras variable.
1784      seed: Integer, random seed.
1785
1786  Returns:
1787      A Keras variable, filled with drawn samples.
1788
1789  Example:
1790
1791  >>> kvar = tf.keras.backend.random_normal_variable(shape=(2,3),
1792  ... mean=0.0, scale=1.0)
1793  >>> kvar
1794  <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
1795  dtype=float32)>
1796  """
1797  if dtype is None:
1798    dtype = floatx()
1799  tf_dtype = dtypes_module.as_dtype(dtype)
1800  if seed is None:
1801    # ensure that randomness is conditioned by the Numpy RNG
1802    seed = np.random.randint(10e8)
1803  value = init_ops.random_normal_initializer(
1804      mean, scale, dtype=tf_dtype, seed=seed)(shape)
1805  return variable(value, dtype=dtype, name=name)
1806
1807
1808@keras_export('keras.backend.count_params')
1809@doc_controls.do_not_generate_docs
1810def count_params(x):
1811  """Returns the static number of elements in a variable or tensor.
1812
1813  Args:
1814      x: Variable or tensor.
1815
1816  Returns:
1817      Integer, the number of scalars in `x`.
1818
1819  Example:
1820
1821  >>> kvar = tf.keras.backend.zeros((2,3))
1822  >>> tf.keras.backend.count_params(kvar)
1823  6
1824  >>> tf.keras.backend.eval(kvar)
1825  array([[0.,  0.,  0.],
1826         [0.,  0.,  0.]], dtype=float32)
1827
1828  """
1829  return np.prod(x.shape.as_list())
1830
1831
1832@keras_export('keras.backend.cast')
1833@dispatch.add_dispatch_support
1834@doc_controls.do_not_generate_docs
1835def cast(x, dtype):
1836  """Casts a tensor to a different dtype and returns it.
1837
1838  You can cast a Keras variable but it still returns a Keras tensor.
1839
1840  Args:
1841      x: Keras tensor (or variable).
1842      dtype: String, either (`'float16'`, `'float32'`, or `'float64'`).
1843
1844  Returns:
1845      Keras tensor with dtype `dtype`.
1846
1847  Examples:
1848      Cast a float32 variable to a float64 tensor
1849
1850  >>> input = tf.keras.backend.ones(shape=(1,3))
1851  >>> print(input)
1852  <tf.Variable 'Variable:0' shape=(1, 3) dtype=float32,
1853  numpy=array([[1., 1., 1.]], dtype=float32)>
1854  >>> cast_input = tf.keras.backend.cast(input, dtype='float64')
1855  >>> print(cast_input)
1856  tf.Tensor([[1. 1. 1.]], shape=(1, 3), dtype=float64)
1857
1858  """
1859  return math_ops.cast(x, dtype)
1860
1861
1862# UPDATES OPS
1863
1864
1865@keras_export('keras.backend.update')
1866@doc_controls.do_not_generate_docs
1867def update(x, new_x):
1868  return state_ops.assign(x, new_x)
1869
1870
1871@keras_export('keras.backend.update_add')
1872@doc_controls.do_not_generate_docs
1873def update_add(x, increment):
1874  """Update the value of `x` by adding `increment`.
1875
1876  Args:
1877      x: A Variable.
1878      increment: A tensor of same shape as `x`.
1879
1880  Returns:
1881      The variable `x` updated.
1882  """
1883  return state_ops.assign_add(x, increment)
1884
1885
1886@keras_export('keras.backend.update_sub')
1887@doc_controls.do_not_generate_docs
1888def update_sub(x, decrement):
1889  """Update the value of `x` by subtracting `decrement`.
1890
1891  Args:
1892      x: A Variable.
1893      decrement: A tensor of same shape as `x`.
1894
1895  Returns:
1896      The variable `x` updated.
1897  """
1898  return state_ops.assign_sub(x, decrement)
1899
1900
1901@keras_export('keras.backend.moving_average_update')
1902@doc_controls.do_not_generate_docs
1903def moving_average_update(x, value, momentum):
1904  """Compute the exponential moving average of a value.
1905
1906  The moving average 'x' is updated with 'value' following:
1907
1908  ```
1909  x = x * momentum + value * (1 - momentum)
1910  ```
1911
1912  For example:
1913
1914  >>> x = tf.Variable(0.0)
1915  >>> momentum=0.9
1916  >>> moving_average_update(x, value = 2.0, momentum=momentum).numpy()
1917  >>> x.numpy()
1918  0.2
1919
1920  The result will be biased towards the initial value of the variable.
1921
1922  If the variable was initialized to zero, you can divide by
1923  `1 - momentum ** num_updates` to debias it (Section 3 of
1924  [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)):
1925
1926  >>> num_updates = 1.0
1927  >>> x_zdb = x/(1 - momentum**num_updates)
1928  >>> x_zdb.numpy()
1929  2.0
1930
1931  Args:
1932      x: A Variable, the moving average.
1933      value: A tensor with the same shape as `x`, the new value to be
1934        averaged in.
1935      momentum: The moving average momentum.
1936
1937  Returns:
1938      The updated variable.
1939  """
1940  if tf2.enabled():
1941    momentum = math_ops.cast(momentum, x.dtype)
1942    value = math_ops.cast(value, x.dtype)
1943    return x.assign(x * momentum + value * (1 - momentum))
1944  else:
1945    return moving_averages.assign_moving_average(
1946        x, value, momentum, zero_debias=True)
1947
1948
1949# LINEAR ALGEBRA
1950
1951
1952@keras_export('keras.backend.dot')
1953@dispatch.add_dispatch_support
1954@doc_controls.do_not_generate_docs
1955def dot(x, y):
1956  """Multiplies 2 tensors (and/or variables) and returns a tensor.
1957
1958  This operation corresponds to `numpy.dot(a, b, out=None)`.
1959
1960  Args:
1961      x: Tensor or variable.
1962      y: Tensor or variable.
1963
1964  Returns:
1965      A tensor, dot product of `x` and `y`.
1966
1967  Examples:
1968
1969  If inputs `x` and `y` are 2-D arrays, then it is equivalent to `tf.matmul`.
1970  >>> x = tf.keras.backend.placeholder(shape=(2, 3))
1971  >>> y = tf.keras.backend.placeholder(shape=(3, 4))
1972  >>> xy = tf.keras.backend.dot(x, y)
1973  >>> xy
1974  <KerasTensor: shape=(2, 4) dtype=float32 ...>
1975
1976  >>> x = tf.keras.backend.placeholder(shape=(32, 28, 3))
1977  >>> y = tf.keras.backend.placeholder(shape=(3, 4))
1978  >>> xy = tf.keras.backend.dot(x, y)
1979  >>> xy
1980  <KerasTensor: shape=(32, 28, 4) dtype=float32 ...>
1981
1982  If `x` is an N-D array and `y` is an M-D array (where M>=2), it is a sum
1983  product over the last axis of `x` and the second-to-last axis of `y`.
1984  >>> x = tf.keras.backend.random_uniform_variable(shape=(2, 3), low=0, high=1)
1985  >>> y = tf.keras.backend.ones((4, 3, 5))
1986  >>> xy = tf.keras.backend.dot(x, y)
1987  >>> tf.keras.backend.int_shape(xy)
1988  (2, 4, 5)
1989  """
1990  if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
1991    x_shape = []
1992    for i, s in zip(int_shape(x), array_ops.unstack(array_ops.shape(x))):
1993      if i is not None:
1994        x_shape.append(i)
1995      else:
1996        x_shape.append(s)
1997    x_shape = tuple(x_shape)
1998    y_shape = []
1999    for i, s in zip(int_shape(y), array_ops.unstack(array_ops.shape(y))):
2000      if i is not None:
2001        y_shape.append(i)
2002      else:
2003        y_shape.append(s)
2004    y_shape = tuple(y_shape)
2005    y_permute_dim = list(range(ndim(y)))
2006    y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
2007    xt = array_ops.reshape(x, [-1, x_shape[-1]])
2008    yt = array_ops.reshape(
2009        array_ops.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
2010    return array_ops.reshape(
2011        math_ops.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:])
2012  if is_sparse(x):
2013    out = sparse_ops.sparse_tensor_dense_matmul(x, y)
2014  else:
2015    out = math_ops.matmul(x, y)
2016  return out
2017
2018
2019@keras_export('keras.backend.batch_dot')
2020@dispatch.add_dispatch_support
2021@doc_controls.do_not_generate_docs
2022def batch_dot(x, y, axes=None):
2023  """Batchwise dot product.
2024
2025  `batch_dot` is used to compute dot product of `x` and `y` when
2026  `x` and `y` are data in batch, i.e. in a shape of
2027  `(batch_size, :)`.
2028  `batch_dot` results in a tensor or variable with less dimensions
2029  than the input. If the number of dimensions is reduced to 1,
2030  we use `expand_dims` to make sure that ndim is at least 2.
2031
2032  Args:
2033    x: Keras tensor or variable with `ndim >= 2`.
2034    y: Keras tensor or variable with `ndim >= 2`.
2035    axes: Tuple or list of integers with target dimensions, or single integer.
2036      The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]` should be equal.
2037
2038  Returns:
2039    A tensor with shape equal to the concatenation of `x`'s shape
2040    (less the dimension that was summed over) and `y`'s shape
2041    (less the batch dimension and the dimension that was summed over).
2042    If the final rank is 1, we reshape it to `(batch_size, 1)`.
2043
2044  Examples:
2045
2046  >>> x_batch = tf.keras.backend.ones(shape=(32, 20, 1))
2047  >>> y_batch = tf.keras.backend.ones(shape=(32, 30, 20))
2048  >>> xy_batch_dot = tf.keras.backend.batch_dot(x_batch, y_batch, axes=(1, 2))
2049  >>> tf.keras.backend.int_shape(xy_batch_dot)
2050  (32, 1, 30)
2051
2052  Shape inference:
2053    Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
2054    If `axes` is (1, 2), to find the output shape of resultant tensor,
2055        loop through each dimension in `x`'s shape and `y`'s shape:
2056    * `x.shape[0]` : 100 : append to output shape
2057    * `x.shape[1]` : 20 : do not append to output shape,
2058        dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
2059    * `y.shape[0]` : 100 : do not append to output shape,
2060        always ignore first dimension of `y`
2061    * `y.shape[1]` : 30 : append to output shape
2062    * `y.shape[2]` : 20 : do not append to output shape,
2063        dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
2064    `output_shape` = `(100, 30)`
2065  """
2066  x_shape = int_shape(x)
2067  y_shape = int_shape(y)
2068
2069  x_ndim = len(x_shape)
2070  y_ndim = len(y_shape)
2071
2072  if x_ndim < 2 or y_ndim < 2:
2073    raise ValueError('Cannot do batch_dot on inputs '
2074                     'with rank < 2. '
2075                     'Received inputs with shapes ' +
2076                     str(x_shape) + ' and ' +
2077                     str(y_shape) + '.')
2078
2079  x_batch_size = x_shape[0]
2080  y_batch_size = y_shape[0]
2081
2082  if x_batch_size is not None and y_batch_size is not None:
2083    if x_batch_size != y_batch_size:
2084      raise ValueError('Cannot do batch_dot on inputs '
2085                       'with different batch sizes. '
2086                       'Received inputs with shapes ' +
2087                       str(x_shape) + ' and ' +
2088                       str(y_shape) + '.')
2089  if isinstance(axes, int):
2090    axes = [axes, axes]
2091
2092  if axes is None:
2093    if y_ndim == 2:
2094      axes = [x_ndim - 1, y_ndim - 1]
2095    else:
2096      axes = [x_ndim - 1, y_ndim - 2]
2097
2098  if py_any(isinstance(a, (list, tuple)) for a in axes):
2099    raise ValueError('Multiple target dimensions are not supported. ' +
2100                     'Expected: None, int, (int, int), ' +
2101                     'Provided: ' + str(axes))
2102
2103  # if tuple, convert to list.
2104  axes = list(axes)
2105
2106  # convert negative indices.
2107  if axes[0] < 0:
2108    axes[0] += x_ndim
2109  if axes[1] < 0:
2110    axes[1] += y_ndim
2111
2112  # sanity checks
2113  if 0 in axes:
2114    raise ValueError('Cannot perform batch_dot over axis 0. '
2115                     'If your inputs are not batched, '
2116                     'add a dummy batch dimension to your '
2117                     'inputs using K.expand_dims(x, 0)')
2118  a0, a1 = axes
2119  d1 = x_shape[a0]
2120  d2 = y_shape[a1]
2121
2122  if d1 is not None and d2 is not None and d1 != d2:
2123    raise ValueError('Cannot do batch_dot on inputs with shapes ' +
2124                     str(x_shape) + ' and ' + str(y_shape) +
2125                     ' with axes=' + str(axes) + '. x.shape[%d] != '
2126                     'y.shape[%d] (%d != %d).' % (axes[0], axes[1], d1, d2))
2127
2128  # backup ndims. Need them later.
2129  orig_x_ndim = x_ndim
2130  orig_y_ndim = y_ndim
2131
2132  # if rank is 2, expand to 3.
2133  if x_ndim == 2:
2134    x = array_ops.expand_dims(x, 1)
2135    a0 += 1
2136    x_ndim += 1
2137  if y_ndim == 2:
2138    y = array_ops.expand_dims(y, 2)
2139    y_ndim += 1
2140
2141  # bring x's dimension to be reduced to last axis.
2142  if a0 != x_ndim - 1:
2143    pattern = list(range(x_ndim))
2144    for i in range(a0, x_ndim - 1):
2145      pattern[i] = pattern[i + 1]
2146    pattern[-1] = a0
2147    x = array_ops.transpose(x, pattern)
2148
2149  # bring y's dimension to be reduced to axis 1.
2150  if a1 != 1:
2151    pattern = list(range(y_ndim))
2152    for i in range(a1, 1, -1):
2153      pattern[i] = pattern[i - 1]
2154    pattern[1] = a1
2155    y = array_ops.transpose(y, pattern)
2156
2157  # normalize both inputs to rank 3.
2158  if x_ndim > 3:
2159    # squash middle dimensions of x.
2160    x_shape = shape(x)
2161    x_mid_dims = x_shape[1:-1]
2162    x_squashed_shape = array_ops.stack(
2163        [x_shape[0], -1, x_shape[-1]])
2164    x = array_ops.reshape(x, x_squashed_shape)
2165    x_squashed = True
2166  else:
2167    x_squashed = False
2168
2169  if y_ndim > 3:
2170    # squash trailing dimensions of y.
2171    y_shape = shape(y)
2172    y_trail_dims = y_shape[2:]
2173    y_squashed_shape = array_ops.stack(
2174        [y_shape[0], y_shape[1], -1])
2175    y = array_ops.reshape(y, y_squashed_shape)
2176    y_squashed = True
2177  else:
2178    y_squashed = False
2179
2180  result = math_ops.matmul(x, y)
2181
2182  # if inputs were squashed, we have to reshape the matmul output.
2183  output_shape = array_ops.shape(result)
2184  do_reshape = False
2185
2186  if x_squashed:
2187    output_shape = array_ops.concat(
2188        [output_shape[:1],
2189         x_mid_dims,
2190         output_shape[-1:]], 0)
2191    do_reshape = True
2192
2193  if y_squashed:
2194    output_shape = array_ops.concat([output_shape[:-1], y_trail_dims], 0)
2195    do_reshape = True
2196
2197  if do_reshape:
2198    result = array_ops.reshape(result, output_shape)
2199
2200  # if the inputs were originally rank 2, we remove the added 1 dim.
2201  if orig_x_ndim == 2:
2202    result = array_ops.squeeze(result, 1)
2203  elif orig_y_ndim == 2:
2204    result = array_ops.squeeze(result, -1)
2205
2206  return result
2207
2208
2209@keras_export('keras.backend.transpose')
2210@dispatch.add_dispatch_support
2211@doc_controls.do_not_generate_docs
2212def transpose(x):
2213  """Transposes a tensor and returns it.
2214
2215  Args:
2216      x: Tensor or variable.
2217
2218  Returns:
2219      A tensor.
2220
2221  Examples:
2222
2223  >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
2224  >>> tf.keras.backend.eval(var)
2225  array([[1.,  2.,  3.],
2226         [4.,  5.,  6.]], dtype=float32)
2227  >>> var_transposed = tf.keras.backend.transpose(var)
2228  >>> tf.keras.backend.eval(var_transposed)
2229  array([[1.,  4.],
2230         [2.,  5.],
2231         [3.,  6.]], dtype=float32)
2232  >>> input = tf.keras.backend.placeholder((2, 3))
2233  >>> input
2234  <KerasTensor: shape=(2, 3) dtype=float32 ...>
2235  >>> input_transposed = tf.keras.backend.transpose(input)
2236  >>> input_transposed
2237  <KerasTensor: shape=(3, 2) dtype=float32 ...>
2238  """
2239  return array_ops.transpose(x)
2240
2241
2242@keras_export('keras.backend.gather')
2243@dispatch.add_dispatch_support
2244@doc_controls.do_not_generate_docs
2245def gather(reference, indices):
2246  """Retrieves the elements of indices `indices` in the tensor `reference`.
2247
2248  Args:
2249      reference: A tensor.
2250      indices: An integer tensor of indices.
2251
2252  Returns:
2253      A tensor of same type as `reference`.
2254
2255  Examples:
2256
2257  >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
2258  >>> tf.keras.backend.eval(var)
2259  array([[1., 2., 3.],
2260         [4., 5., 6.]], dtype=float32)
2261  >>> var_gathered = tf.keras.backend.gather(var, [0])
2262  >>> tf.keras.backend.eval(var_gathered)
2263  array([[1., 2., 3.]], dtype=float32)
2264  >>> var_gathered = tf.keras.backend.gather(var, [1])
2265  >>> tf.keras.backend.eval(var_gathered)
2266  array([[4., 5., 6.]], dtype=float32)
2267  >>> var_gathered = tf.keras.backend.gather(var, [0,1,0])
2268  >>> tf.keras.backend.eval(var_gathered)
2269  array([[1., 2., 3.],
2270         [4., 5., 6.],
2271         [1., 2., 3.]], dtype=float32)
2272  """
2273  return array_ops.gather(reference, indices)
2274
2275
2276# ELEMENT-WISE OPERATIONS
2277
2278
2279@keras_export('keras.backend.max')
2280@dispatch.add_dispatch_support
2281@doc_controls.do_not_generate_docs
2282def max(x, axis=None, keepdims=False):
2283  """Maximum value in a tensor.
2284
2285  Args:
2286      x: A tensor or variable.
2287      axis: An integer, the axis to find maximum values.
2288      keepdims: A boolean, whether to keep the dimensions or not.
2289          If `keepdims` is `False`, the rank of the tensor is reduced
2290          by 1. If `keepdims` is `True`,
2291          the reduced dimension is retained with length 1.
2292
2293  Returns:
2294      A tensor with maximum values of `x`.
2295  """
2296  return math_ops.reduce_max(x, axis, keepdims)
2297
2298
2299@keras_export('keras.backend.min')
2300@dispatch.add_dispatch_support
2301@doc_controls.do_not_generate_docs
2302def min(x, axis=None, keepdims=False):
2303  """Minimum value in a tensor.
2304
2305  Args:
2306      x: A tensor or variable.
2307      axis: An integer, the axis to find minimum values.
2308      keepdims: A boolean, whether to keep the dimensions or not.
2309          If `keepdims` is `False`, the rank of the tensor is reduced
2310          by 1. If `keepdims` is `True`,
2311          the reduced dimension is retained with length 1.
2312
2313  Returns:
2314      A tensor with minimum values of `x`.
2315  """
2316  return math_ops.reduce_min(x, axis, keepdims)
2317
2318
2319@keras_export('keras.backend.sum')
2320@dispatch.add_dispatch_support
2321@doc_controls.do_not_generate_docs
2322def sum(x, axis=None, keepdims=False):
2323  """Sum of the values in a tensor, alongside the specified axis.
2324
2325  Args:
2326      x: A tensor or variable.
2327      axis: An integer, the axis to sum over.
2328      keepdims: A boolean, whether to keep the dimensions or not.
2329          If `keepdims` is `False`, the rank of the tensor is reduced
2330          by 1. If `keepdims` is `True`,
2331          the reduced dimension is retained with length 1.
2332
2333  Returns:
2334      A tensor with sum of `x`.
2335  """
2336  return math_ops.reduce_sum(x, axis, keepdims)
2337
2338
2339@keras_export('keras.backend.prod')
2340@dispatch.add_dispatch_support
2341@doc_controls.do_not_generate_docs
2342def prod(x, axis=None, keepdims=False):
2343  """Multiplies the values in a tensor, alongside the specified axis.
2344
2345  Args:
2346      x: A tensor or variable.
2347      axis: An integer, the axis to compute the product.
2348      keepdims: A boolean, whether to keep the dimensions or not.
2349          If `keepdims` is `False`, the rank of the tensor is reduced
2350          by 1. If `keepdims` is `True`,
2351          the reduced dimension is retained with length 1.
2352
2353  Returns:
2354      A tensor with the product of elements of `x`.
2355  """
2356  return math_ops.reduce_prod(x, axis, keepdims)
2357
2358
2359@keras_export('keras.backend.cumsum')
2360@dispatch.add_dispatch_support
2361@doc_controls.do_not_generate_docs
2362def cumsum(x, axis=0):
2363  """Cumulative sum of the values in a tensor, alongside the specified axis.
2364
2365  Args:
2366      x: A tensor or variable.
2367      axis: An integer, the axis to compute the sum.
2368
2369  Returns:
2370      A tensor of the cumulative sum of values of `x` along `axis`.
2371  """
2372  return math_ops.cumsum(x, axis=axis)
2373
2374
2375@keras_export('keras.backend.cumprod')
2376@dispatch.add_dispatch_support
2377@doc_controls.do_not_generate_docs
2378def cumprod(x, axis=0):
2379  """Cumulative product of the values in a tensor, alongside the specified axis.
2380
2381  Args:
2382      x: A tensor or variable.
2383      axis: An integer, the axis to compute the product.
2384
2385  Returns:
2386      A tensor of the cumulative product of values of `x` along `axis`.
2387  """
2388  return math_ops.cumprod(x, axis=axis)
2389
2390
2391@keras_export('keras.backend.var')
2392@doc_controls.do_not_generate_docs
2393def var(x, axis=None, keepdims=False):
2394  """Variance of a tensor, alongside the specified axis.
2395
2396  Args:
2397      x: A tensor or variable.
2398      axis: An integer, the axis to compute the variance.
2399      keepdims: A boolean, whether to keep the dimensions or not.
2400          If `keepdims` is `False`, the rank of the tensor is reduced
2401          by 1. If `keepdims` is `True`,
2402          the reduced dimension is retained with length 1.
2403
2404  Returns:
2405      A tensor with the variance of elements of `x`.
2406  """
2407  if x.dtype.base_dtype == dtypes_module.bool:
2408    x = math_ops.cast(x, floatx())
2409  return math_ops.reduce_variance(x, axis=axis, keepdims=keepdims)
2410
2411
2412@keras_export('keras.backend.std')
2413@dispatch.add_dispatch_support
2414@doc_controls.do_not_generate_docs
2415def std(x, axis=None, keepdims=False):
2416  """Standard deviation of a tensor, alongside the specified axis.
2417
2418  It is an alias to `tf.math.reduce_std`.
2419
2420  Args:
2421      x: A tensor or variable. It should have numerical dtypes. Boolean type
2422        inputs will be converted to float.
2423      axis: An integer, the axis to compute the standard deviation. If `None`
2424        (the default), reduces all dimensions. Must be in the range
2425        `[-rank(x), rank(x))`.
2426      keepdims: A boolean, whether to keep the dimensions or not.
2427          If `keepdims` is `False`, the rank of the tensor is reduced
2428          by 1. If `keepdims` is `True`, the reduced dimension is retained with
2429          length 1.
2430
2431  Returns:
2432      A tensor with the standard deviation of elements of `x` with same dtype.
2433      Boolean type input will be converted to float.
2434  """
2435  if x.dtype.base_dtype == dtypes_module.bool:
2436    x = math_ops.cast(x, floatx())
2437  return math_ops.reduce_std(x, axis=axis, keepdims=keepdims)
2438
2439
2440@keras_export('keras.backend.mean')
2441@dispatch.add_dispatch_support
2442@doc_controls.do_not_generate_docs
2443def mean(x, axis=None, keepdims=False):
2444  """Mean of a tensor, alongside the specified axis.
2445
2446  Args:
2447      x: A tensor or variable.
2448      axis: A list of integer. Axes to compute the mean.
2449      keepdims: A boolean, whether to keep the dimensions or not.
2450          If `keepdims` is `False`, the rank of the tensor is reduced
2451          by 1 for each entry in `axis`. If `keepdims` is `True`,
2452          the reduced dimensions are retained with length 1.
2453
2454  Returns:
2455      A tensor with the mean of elements of `x`.
2456  """
2457  if x.dtype.base_dtype == dtypes_module.bool:
2458    x = math_ops.cast(x, floatx())
2459  return math_ops.reduce_mean(x, axis, keepdims)
2460
2461
2462@keras_export('keras.backend.any')
2463@dispatch.add_dispatch_support
2464@doc_controls.do_not_generate_docs
2465def any(x, axis=None, keepdims=False):
2466  """Bitwise reduction (logical OR).
2467
2468  Args:
2469      x: Tensor or variable.
2470      axis: axis along which to perform the reduction.
2471      keepdims: whether the drop or broadcast the reduction axes.
2472
2473  Returns:
2474      A uint8 tensor (0s and 1s).
2475  """
2476  x = math_ops.cast(x, dtypes_module.bool)
2477  return math_ops.reduce_any(x, axis, keepdims)
2478
2479
2480@keras_export('keras.backend.all')
2481@dispatch.add_dispatch_support
2482@doc_controls.do_not_generate_docs
2483def all(x, axis=None, keepdims=False):
2484  """Bitwise reduction (logical AND).
2485
2486  Args:
2487      x: Tensor or variable.
2488      axis: axis along which to perform the reduction.
2489      keepdims: whether the drop or broadcast the reduction axes.
2490
2491  Returns:
2492      A uint8 tensor (0s and 1s).
2493  """
2494  x = math_ops.cast(x, dtypes_module.bool)
2495  return math_ops.reduce_all(x, axis, keepdims)
2496
2497
2498@keras_export('keras.backend.argmax')
2499@dispatch.add_dispatch_support
2500@doc_controls.do_not_generate_docs
2501def argmax(x, axis=-1):
2502  """Returns the index of the maximum value along an axis.
2503
2504  Args:
2505      x: Tensor or variable.
2506      axis: axis along which to perform the reduction.
2507
2508  Returns:
2509      A tensor.
2510  """
2511  return math_ops.argmax(x, axis)
2512
2513
2514@keras_export('keras.backend.argmin')
2515@dispatch.add_dispatch_support
2516@doc_controls.do_not_generate_docs
2517def argmin(x, axis=-1):
2518  """Returns the index of the minimum value along an axis.
2519
2520  Args:
2521      x: Tensor or variable.
2522      axis: axis along which to perform the reduction.
2523
2524  Returns:
2525      A tensor.
2526  """
2527  return math_ops.argmin(x, axis)
2528
2529
2530@keras_export('keras.backend.square')
2531@dispatch.add_dispatch_support
2532@doc_controls.do_not_generate_docs
2533def square(x):
2534  """Element-wise square.
2535
2536  Args:
2537      x: Tensor or variable.
2538
2539  Returns:
2540      A tensor.
2541  """
2542  return math_ops.square(x)
2543
2544
2545@keras_export('keras.backend.abs')
2546@dispatch.add_dispatch_support
2547@doc_controls.do_not_generate_docs
2548def abs(x):
2549  """Element-wise absolute value.
2550
2551  Args:
2552      x: Tensor or variable.
2553
2554  Returns:
2555      A tensor.
2556  """
2557  return math_ops.abs(x)
2558
2559
2560@keras_export('keras.backend.sqrt')
2561@dispatch.add_dispatch_support
2562@doc_controls.do_not_generate_docs
2563def sqrt(x):
2564  """Element-wise square root.
2565
2566     This function clips negative tensor values to 0 before computing the
2567     square root.
2568
2569  Args:
2570      x: Tensor or variable.
2571
2572  Returns:
2573      A tensor.
2574  """
2575  zero = _constant_to_tensor(0., x.dtype.base_dtype)
2576  x = math_ops.maximum(x, zero)
2577  return math_ops.sqrt(x)
2578
2579
2580@keras_export('keras.backend.exp')
2581@dispatch.add_dispatch_support
2582@doc_controls.do_not_generate_docs
2583def exp(x):
2584  """Element-wise exponential.
2585
2586  Args:
2587      x: Tensor or variable.
2588
2589  Returns:
2590      A tensor.
2591  """
2592  return math_ops.exp(x)
2593
2594
2595@keras_export('keras.backend.log')
2596@dispatch.add_dispatch_support
2597@doc_controls.do_not_generate_docs
2598def log(x):
2599  """Element-wise log.
2600
2601  Args:
2602      x: Tensor or variable.
2603
2604  Returns:
2605      A tensor.
2606  """
2607  return math_ops.log(x)
2608
2609
2610def logsumexp(x, axis=None, keepdims=False):
2611  """Computes log(sum(exp(elements across dimensions of a tensor))).
2612
2613  This function is more numerically stable than log(sum(exp(x))).
2614  It avoids overflows caused by taking the exp of large inputs and
2615  underflows caused by taking the log of small inputs.
2616
2617  Args:
2618      x: A tensor or variable.
2619      axis: An integer, the axis to reduce over.
2620      keepdims: A boolean, whether to keep the dimensions or not.
2621          If `keepdims` is `False`, the rank of the tensor is reduced
2622          by 1. If `keepdims` is `True`, the reduced dimension is
2623          retained with length 1.
2624
2625  Returns:
2626      The reduced tensor.
2627  """
2628  return math_ops.reduce_logsumexp(x, axis, keepdims)
2629
2630
2631@keras_export('keras.backend.round')
2632@dispatch.add_dispatch_support
2633@doc_controls.do_not_generate_docs
2634def round(x):
2635  """Element-wise rounding to the closest integer.
2636
2637  In case of tie, the rounding mode used is "half to even".
2638
2639  Args:
2640      x: Tensor or variable.
2641
2642  Returns:
2643      A tensor.
2644  """
2645  return math_ops.round(x)
2646
2647
2648@keras_export('keras.backend.sign')
2649@dispatch.add_dispatch_support
2650@doc_controls.do_not_generate_docs
2651def sign(x):
2652  """Element-wise sign.
2653
2654  Args:
2655      x: Tensor or variable.
2656
2657  Returns:
2658      A tensor.
2659  """
2660  return math_ops.sign(x)
2661
2662
2663@keras_export('keras.backend.pow')
2664@dispatch.add_dispatch_support
2665@doc_controls.do_not_generate_docs
2666def pow(x, a):
2667  """Element-wise exponentiation.
2668
2669  Args:
2670      x: Tensor or variable.
2671      a: Python integer.
2672
2673  Returns:
2674      A tensor.
2675  """
2676  return math_ops.pow(x, a)
2677
2678
2679@keras_export('keras.backend.clip')
2680@dispatch.add_dispatch_support
2681@doc_controls.do_not_generate_docs
2682def clip(x, min_value, max_value):
2683  """Element-wise value clipping.
2684
2685  Args:
2686      x: Tensor or variable.
2687      min_value: Python float, integer, or tensor.
2688      max_value: Python float, integer, or tensor.
2689
2690  Returns:
2691      A tensor.
2692  """
2693  if (isinstance(min_value, (int, float)) and
2694      isinstance(max_value, (int, float))):
2695    if max_value < min_value:
2696      max_value = min_value
2697  if min_value is None:
2698    min_value = -np.inf
2699  if max_value is None:
2700    max_value = np.inf
2701  return clip_ops.clip_by_value(x, min_value, max_value)
2702
2703
2704@keras_export('keras.backend.equal')
2705@dispatch.add_dispatch_support
2706@doc_controls.do_not_generate_docs
2707def equal(x, y):
2708  """Element-wise equality between two tensors.
2709
2710  Args:
2711      x: Tensor or variable.
2712      y: Tensor or variable.
2713
2714  Returns:
2715      A bool tensor.
2716  """
2717  return math_ops.equal(x, y)
2718
2719
2720@keras_export('keras.backend.not_equal')
2721@dispatch.add_dispatch_support
2722@doc_controls.do_not_generate_docs
2723def not_equal(x, y):
2724  """Element-wise inequality between two tensors.
2725
2726  Args:
2727      x: Tensor or variable.
2728      y: Tensor or variable.
2729
2730  Returns:
2731      A bool tensor.
2732  """
2733  return math_ops.not_equal(x, y)
2734
2735
2736@keras_export('keras.backend.greater')
2737@dispatch.add_dispatch_support
2738@doc_controls.do_not_generate_docs
2739def greater(x, y):
2740  """Element-wise truth value of (x > y).
2741
2742  Args:
2743      x: Tensor or variable.
2744      y: Tensor or variable.
2745
2746  Returns:
2747      A bool tensor.
2748  """
2749  return math_ops.greater(x, y)
2750
2751
2752@keras_export('keras.backend.greater_equal')
2753@dispatch.add_dispatch_support
2754@doc_controls.do_not_generate_docs
2755def greater_equal(x, y):
2756  """Element-wise truth value of (x >= y).
2757
2758  Args:
2759      x: Tensor or variable.
2760      y: Tensor or variable.
2761
2762  Returns:
2763      A bool tensor.
2764  """
2765  return math_ops.greater_equal(x, y)
2766
2767
2768@keras_export('keras.backend.less')
2769@dispatch.add_dispatch_support
2770@doc_controls.do_not_generate_docs
2771def less(x, y):
2772  """Element-wise truth value of (x < y).
2773
2774  Args:
2775      x: Tensor or variable.
2776      y: Tensor or variable.
2777
2778  Returns:
2779      A bool tensor.
2780  """
2781  return math_ops.less(x, y)
2782
2783
2784@keras_export('keras.backend.less_equal')
2785@dispatch.add_dispatch_support
2786@doc_controls.do_not_generate_docs
2787def less_equal(x, y):
2788  """Element-wise truth value of (x <= y).
2789
2790  Args:
2791      x: Tensor or variable.
2792      y: Tensor or variable.
2793
2794  Returns:
2795      A bool tensor.
2796  """
2797  return math_ops.less_equal(x, y)
2798
2799
2800@keras_export('keras.backend.maximum')
2801@dispatch.add_dispatch_support
2802@doc_controls.do_not_generate_docs
2803def maximum(x, y):
2804  """Element-wise maximum of two tensors.
2805
2806  Args:
2807      x: Tensor or variable.
2808      y: Tensor or variable.
2809
2810  Returns:
2811      A tensor with the element wise maximum value(s) of `x` and `y`.
2812
2813  Examples:
2814
2815  >>> x = tf.Variable([[1, 2], [3, 4]])
2816  >>> y = tf.Variable([[2, 1], [0, -1]])
2817  >>> m = tf.keras.backend.maximum(x, y)
2818  >>> m
2819  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2820  array([[2, 2],
2821         [3, 4]], dtype=int32)>
2822  """
2823  return math_ops.maximum(x, y)
2824
2825
2826@keras_export('keras.backend.minimum')
2827@dispatch.add_dispatch_support
2828@doc_controls.do_not_generate_docs
2829def minimum(x, y):
2830  """Element-wise minimum of two tensors.
2831
2832  Args:
2833      x: Tensor or variable.
2834      y: Tensor or variable.
2835
2836  Returns:
2837      A tensor.
2838  """
2839  return math_ops.minimum(x, y)
2840
2841
2842@keras_export('keras.backend.sin')
2843@dispatch.add_dispatch_support
2844@doc_controls.do_not_generate_docs
2845def sin(x):
2846  """Computes sin of x element-wise.
2847
2848  Args:
2849      x: Tensor or variable.
2850
2851  Returns:
2852      A tensor.
2853  """
2854  return math_ops.sin(x)
2855
2856
2857@keras_export('keras.backend.cos')
2858@dispatch.add_dispatch_support
2859@doc_controls.do_not_generate_docs
2860def cos(x):
2861  """Computes cos of x element-wise.
2862
2863  Args:
2864      x: Tensor or variable.
2865
2866  Returns:
2867      A tensor.
2868  """
2869  return math_ops.cos(x)
2870
2871
2872def _regular_normalize_batch_in_training(x,
2873                                         gamma,
2874                                         beta,
2875                                         reduction_axes,
2876                                         epsilon=1e-3):
2877  """Non-fused version of `normalize_batch_in_training`.
2878
2879  Args:
2880      x: Input tensor or variable.
2881      gamma: Tensor by which to scale the input.
2882      beta: Tensor with which to center the input.
2883      reduction_axes: iterable of integers,
2884          axes over which to normalize.
2885      epsilon: Fuzz factor.
2886
2887  Returns:
2888      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2889  """
2890  mean, var = nn.moments(x, reduction_axes, None, None, False)
2891  normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
2892  return normed, mean, var
2893
2894
2895def _broadcast_normalize_batch_in_training(x,
2896                                           gamma,
2897                                           beta,
2898                                           reduction_axes,
2899                                           epsilon=1e-3):
2900  """Non-fused, broadcast version of `normalize_batch_in_training`.
2901
2902  Args:
2903      x: Input tensor or variable.
2904      gamma: Tensor by which to scale the input.
2905      beta: Tensor with which to center the input.
2906      reduction_axes: iterable of integers,
2907          axes over which to normalize.
2908      epsilon: Fuzz factor.
2909
2910  Returns:
2911      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2912  """
2913  mean, var = nn.moments(x, reduction_axes, None, None, False)
2914  target_shape = []
2915  for axis in range(ndim(x)):
2916    if axis in reduction_axes:
2917      target_shape.append(1)
2918    else:
2919      target_shape.append(array_ops.shape(x)[axis])
2920  target_shape = array_ops.stack(target_shape)
2921
2922  broadcast_mean = array_ops.reshape(mean, target_shape)
2923  broadcast_var = array_ops.reshape(var, target_shape)
2924  if gamma is None:
2925    broadcast_gamma = None
2926  else:
2927    broadcast_gamma = array_ops.reshape(gamma, target_shape)
2928  if beta is None:
2929    broadcast_beta = None
2930  else:
2931    broadcast_beta = array_ops.reshape(beta, target_shape)
2932
2933  normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
2934                                  broadcast_beta, broadcast_gamma, epsilon)
2935  return normed, mean, var
2936
2937
2938def _fused_normalize_batch_in_training(x,
2939                                       gamma,
2940                                       beta,
2941                                       reduction_axes,
2942                                       epsilon=1e-3):
2943  """Fused version of `normalize_batch_in_training`.
2944
2945  Args:
2946      x: Input tensor or variable.
2947      gamma: Tensor by which to scale the input.
2948      beta: Tensor with which to center the input.
2949      reduction_axes: iterable of integers,
2950          axes over which to normalize.
2951      epsilon: Fuzz factor.
2952
2953  Returns:
2954      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2955  """
2956  if list(reduction_axes) == [0, 1, 2]:
2957    normalization_axis = 3
2958    tf_data_format = 'NHWC'
2959  else:
2960    normalization_axis = 1
2961    tf_data_format = 'NCHW'
2962
2963  if gamma is None:
2964    gamma = constant_op.constant(
2965        1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2966  if beta is None:
2967    beta = constant_op.constant(
2968        0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2969
2970  return nn.fused_batch_norm(
2971      x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
2972
2973
2974@keras_export('keras.backend.normalize_batch_in_training')
2975@doc_controls.do_not_generate_docs
2976def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
2977  """Computes mean and std for batch then apply batch_normalization on batch.
2978
2979  Args:
2980      x: Input tensor or variable.
2981      gamma: Tensor by which to scale the input.
2982      beta: Tensor with which to center the input.
2983      reduction_axes: iterable of integers,
2984          axes over which to normalize.
2985      epsilon: Fuzz factor.
2986
2987  Returns:
2988      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2989  """
2990  if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
2991    if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
2992      return _broadcast_normalize_batch_in_training(
2993          x, gamma, beta, reduction_axes, epsilon=epsilon)
2994    return _fused_normalize_batch_in_training(
2995        x, gamma, beta, reduction_axes, epsilon=epsilon)
2996  else:
2997    if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
2998      return _regular_normalize_batch_in_training(
2999          x, gamma, beta, reduction_axes, epsilon=epsilon)
3000    else:
3001      return _broadcast_normalize_batch_in_training(
3002          x, gamma, beta, reduction_axes, epsilon=epsilon)
3003
3004
3005@keras_export('keras.backend.batch_normalization')
3006@dispatch.add_dispatch_support
3007@doc_controls.do_not_generate_docs
3008def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
3009  """Applies batch normalization on x given mean, var, beta and gamma.
3010
3011  I.e. returns:
3012  `output = (x - mean) / (sqrt(var) + epsilon) * gamma + beta`
3013
3014  Args:
3015      x: Input tensor or variable.
3016      mean: Mean of batch.
3017      var: Variance of batch.
3018      beta: Tensor with which to center the input.
3019      gamma: Tensor by which to scale the input.
3020      axis: Integer, the axis that should be normalized.
3021          (typically the features axis).
3022      epsilon: Fuzz factor.
3023
3024  Returns:
3025      A tensor.
3026  """
3027  if ndim(x) == 4:
3028    # The CPU implementation of `fused_batch_norm` only supports NHWC
3029    if axis == 1 or axis == -3:
3030      tf_data_format = 'NCHW'
3031    elif axis == 3 or axis == -1:
3032      tf_data_format = 'NHWC'
3033    else:
3034      tf_data_format = None
3035
3036    if (tf_data_format == 'NHWC' or
3037        tf_data_format == 'NCHW' and _has_nchw_support()):
3038      # The mean / var / beta / gamma tensors may be broadcasted
3039      # so they may have extra axes of size 1, which should be squeezed.
3040      if ndim(mean) > 1:
3041        mean = array_ops.reshape(mean, [-1])
3042      if ndim(var) > 1:
3043        var = array_ops.reshape(var, [-1])
3044      if beta is None:
3045        beta = zeros_like(mean)
3046      elif ndim(beta) > 1:
3047        beta = array_ops.reshape(beta, [-1])
3048      if gamma is None:
3049        gamma = ones_like(mean)
3050      elif ndim(gamma) > 1:
3051        gamma = array_ops.reshape(gamma, [-1])
3052    y, _, _ = nn.fused_batch_norm(
3053        x,
3054        gamma,
3055        beta,
3056        epsilon=epsilon,
3057        mean=mean,
3058        variance=var,
3059        data_format=tf_data_format,
3060        is_training=False
3061    )
3062    return y
3063  return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
3064
3065
3066# SHAPE OPERATIONS
3067
3068
3069@keras_export('keras.backend.concatenate')
3070@dispatch.add_dispatch_support
3071@doc_controls.do_not_generate_docs
3072def concatenate(tensors, axis=-1):
3073  """Concatenates a list of tensors alongside the specified axis.
3074
3075  Args:
3076      tensors: list of tensors to concatenate.
3077      axis: concatenation axis.
3078
3079  Returns:
3080      A tensor.
3081
3082  Example:
3083
3084      >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
3085      >>> b = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
3086      >>> tf.keras.backend.concatenate((a, b), axis=-1)
3087      <tf.Tensor: shape=(3, 6), dtype=int32, numpy=
3088      array([[ 1,  2,  3, 10, 20, 30],
3089             [ 4,  5,  6, 40, 50, 60],
3090             [ 7,  8,  9, 70, 80, 90]], dtype=int32)>
3091
3092  """
3093  if axis < 0:
3094    rank = ndim(tensors[0])
3095    if rank:
3096      axis %= rank
3097    else:
3098      axis = 0
3099
3100  if py_all(is_sparse(x) for x in tensors):
3101    return sparse_ops.sparse_concat(axis, tensors)
3102  elif py_all(isinstance(x, ragged_tensor.RaggedTensor) for x in tensors):
3103    return array_ops.concat(tensors, axis)
3104  else:
3105    return array_ops.concat([to_dense(x) for x in tensors], axis)
3106
3107
3108@keras_export('keras.backend.reshape')
3109@dispatch.add_dispatch_support
3110@doc_controls.do_not_generate_docs
3111def reshape(x, shape):
3112  """Reshapes a tensor to the specified shape.
3113
3114  Args:
3115      x: Tensor or variable.
3116      shape: Target shape tuple.
3117
3118  Returns:
3119      A tensor.
3120
3121  Example:
3122
3123    >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
3124    >>> a
3125    <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
3126    array([[ 1,  2,  3],
3127           [ 4,  5,  6],
3128           [ 7,  8,  9],
3129           [10, 11, 12]], dtype=int32)>
3130    >>> tf.keras.backend.reshape(a, shape=(2, 6))
3131    <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
3132    array([[ 1,  2,  3,  4,  5,  6],
3133           [ 7,  8,  9, 10, 11, 12]], dtype=int32)>
3134
3135  """
3136  return array_ops.reshape(x, shape)
3137
3138
3139@keras_export('keras.backend.permute_dimensions')
3140@dispatch.add_dispatch_support
3141@doc_controls.do_not_generate_docs
3142def permute_dimensions(x, pattern):
3143  """Permutes axes in a tensor.
3144
3145  Args:
3146      x: Tensor or variable.
3147      pattern: A tuple of
3148          dimension indices, e.g. `(0, 2, 1)`.
3149
3150  Returns:
3151      A tensor.
3152
3153  Example:
3154
3155    >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
3156    >>> a
3157    <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
3158    array([[ 1,  2,  3],
3159           [ 4,  5,  6],
3160           [ 7,  8,  9],
3161           [10, 11, 12]], dtype=int32)>
3162    >>> tf.keras.backend.permute_dimensions(a, pattern=(1, 0))
3163    <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
3164    array([[ 1,  4,  7, 10],
3165           [ 2,  5,  8, 11],
3166           [ 3,  6,  9, 12]], dtype=int32)>
3167
3168  """
3169  return array_ops.transpose(x, perm=pattern)
3170
3171
3172@keras_export('keras.backend.resize_images')
3173@dispatch.add_dispatch_support
3174@doc_controls.do_not_generate_docs
3175def resize_images(x, height_factor, width_factor, data_format,
3176                  interpolation='nearest'):
3177  """Resizes the images contained in a 4D tensor.
3178
3179  Args:
3180      x: Tensor or variable to resize.
3181      height_factor: Positive integer.
3182      width_factor: Positive integer.
3183      data_format: One of `"channels_first"`, `"channels_last"`.
3184      interpolation: A string, one of `nearest` or `bilinear`.
3185
3186  Returns:
3187      A tensor.
3188
3189  Raises:
3190      ValueError: in case of incorrect value for
3191        `data_format` or `interpolation`.
3192  """
3193  if data_format == 'channels_first':
3194    rows, cols = 2, 3
3195  elif data_format == 'channels_last':
3196    rows, cols = 1, 2
3197  else:
3198    raise ValueError('Invalid `data_format` argument: %s' % (data_format,))
3199
3200  new_shape = x.shape[rows:cols + 1]
3201  if new_shape.is_fully_defined():
3202    new_shape = constant_op.constant(new_shape.as_list(), dtype='int32')
3203  else:
3204    new_shape = array_ops.shape_v2(x)[rows:cols + 1]
3205  new_shape *= constant_op.constant(
3206      np.array([height_factor, width_factor], dtype='int32'))
3207
3208  if data_format == 'channels_first':
3209    x = permute_dimensions(x, [0, 2, 3, 1])
3210  if interpolation == 'nearest':
3211    x = image_ops.resize_images_v2(
3212        x, new_shape, method=image_ops.ResizeMethod.NEAREST_NEIGHBOR)
3213  elif interpolation == 'bilinear':
3214    x = image_ops.resize_images_v2(x, new_shape,
3215                                   method=image_ops.ResizeMethod.BILINEAR)
3216  else:
3217    raise ValueError('interpolation should be one '
3218                     'of "nearest" or "bilinear".')
3219  if data_format == 'channels_first':
3220    x = permute_dimensions(x, [0, 3, 1, 2])
3221
3222  return x
3223
3224
3225@keras_export('keras.backend.resize_volumes')
3226@dispatch.add_dispatch_support
3227@doc_controls.do_not_generate_docs
3228def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
3229  """Resizes the volume contained in a 5D tensor.
3230
3231  Args:
3232      x: Tensor or variable to resize.
3233      depth_factor: Positive integer.
3234      height_factor: Positive integer.
3235      width_factor: Positive integer.
3236      data_format: One of `"channels_first"`, `"channels_last"`.
3237
3238  Returns:
3239      A tensor.
3240
3241  Raises:
3242      ValueError: if `data_format` is neither
3243          `channels_last` or `channels_first`.
3244  """
3245  if data_format == 'channels_first':
3246    output = repeat_elements(x, depth_factor, axis=2)
3247    output = repeat_elements(output, height_factor, axis=3)
3248    output = repeat_elements(output, width_factor, axis=4)
3249    return output
3250  elif data_format == 'channels_last':
3251    output = repeat_elements(x, depth_factor, axis=1)
3252    output = repeat_elements(output, height_factor, axis=2)
3253    output = repeat_elements(output, width_factor, axis=3)
3254    return output
3255  else:
3256    raise ValueError('Invalid data_format: ' + str(data_format))
3257
3258
3259@keras_export('keras.backend.repeat_elements')
3260@dispatch.add_dispatch_support
3261@doc_controls.do_not_generate_docs
3262def repeat_elements(x, rep, axis):
3263  """Repeats the elements of a tensor along an axis, like `np.repeat`.
3264
3265  If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output
3266  will have shape `(s1, s2 * rep, s3)`.
3267
3268  Args:
3269      x: Tensor or variable.
3270      rep: Python integer, number of times to repeat.
3271      axis: Axis along which to repeat.
3272
3273  Returns:
3274      A tensor.
3275
3276  Example:
3277
3278      >>> b = tf.constant([1, 2, 3])
3279      >>> tf.keras.backend.repeat_elements(b, rep=2, axis=0)
3280      <tf.Tensor: shape=(6,), dtype=int32,
3281          numpy=array([1, 1, 2, 2, 3, 3], dtype=int32)>
3282
3283  """
3284  x_shape = x.shape.as_list()
3285  # For static axis
3286  if x_shape[axis] is not None:
3287    # slices along the repeat axis
3288    splits = array_ops.split(value=x,
3289                             num_or_size_splits=x_shape[axis],
3290                             axis=axis)
3291    # repeat each slice the given number of reps
3292    x_rep = [s for s in splits for _ in range(rep)]
3293    return concatenate(x_rep, axis)
3294
3295  # Here we use tf.tile to mimic behavior of np.repeat so that
3296  # we can handle dynamic shapes (that include None).
3297  # To do that, we need an auxiliary axis to repeat elements along
3298  # it and then merge them along the desired axis.
3299
3300  # Repeating
3301  auxiliary_axis = axis + 1
3302  x_shape = array_ops.shape(x)
3303  x_rep = array_ops.expand_dims(x, axis=auxiliary_axis)
3304  reps = np.ones(len(x.shape) + 1)
3305  reps[auxiliary_axis] = rep
3306  x_rep = array_ops.tile(x_rep, reps)
3307
3308  # Merging
3309  reps = np.delete(reps, auxiliary_axis)
3310  reps[axis] = rep
3311  reps = array_ops.constant(reps, dtype='int32')
3312  x_shape *= reps
3313  x_rep = array_ops.reshape(x_rep, x_shape)
3314
3315  # Fix shape representation
3316  x_shape = x.shape.as_list()
3317  x_rep.set_shape(x_shape)
3318  x_rep._keras_shape = tuple(x_shape)
3319  return x_rep
3320
3321
3322@keras_export('keras.backend.repeat')
3323@dispatch.add_dispatch_support
3324@doc_controls.do_not_generate_docs
3325def repeat(x, n):
3326  """Repeats a 2D tensor.
3327
3328  if `x` has shape (samples, dim) and `n` is `2`,
3329  the output will have shape `(samples, 2, dim)`.
3330
3331  Args:
3332      x: Tensor or variable.
3333      n: Python integer, number of times to repeat.
3334
3335  Returns:
3336      A tensor.
3337
3338  Example:
3339
3340      >>> b = tf.constant([[1, 2], [3, 4]])
3341      >>> b
3342      <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3343      array([[1, 2],
3344             [3, 4]], dtype=int32)>
3345      >>> tf.keras.backend.repeat(b, n=2)
3346      <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3347      array([[[1, 2],
3348              [1, 2]],
3349             [[3, 4],
3350              [3, 4]]], dtype=int32)>
3351
3352  """
3353  assert ndim(x) == 2
3354  x = array_ops.expand_dims(x, 1)
3355  pattern = array_ops.stack([1, n, 1])
3356  return array_ops.tile(x, pattern)
3357
3358
3359@keras_export('keras.backend.arange')
3360@dispatch.add_dispatch_support
3361@doc_controls.do_not_generate_docs
3362def arange(start, stop=None, step=1, dtype='int32'):
3363  """Creates a 1D tensor containing a sequence of integers.
3364
3365  The function arguments use the same convention as
3366  Theano's arange: if only one argument is provided,
3367  it is in fact the "stop" argument and "start" is 0.
3368
3369  The default type of the returned tensor is `'int32'` to
3370  match TensorFlow's default.
3371
3372  Args:
3373      start: Start value.
3374      stop: Stop value.
3375      step: Difference between two successive values.
3376      dtype: Integer dtype to use.
3377
3378  Returns:
3379      An integer tensor.
3380
3381  Example:
3382
3383      >>> tf.keras.backend.arange(start=0, stop=10, step=1.5)
3384      <tf.Tensor: shape=(7,), dtype=float32,
3385          numpy=array([0. , 1.5, 3. , 4.5, 6. , 7.5, 9. ], dtype=float32)>
3386
3387
3388
3389  """
3390  # Match the behavior of numpy and Theano by returning an empty sequence.
3391  if stop is None and start < 0:
3392    start = 0
3393  result = math_ops.range(start, limit=stop, delta=step, name='arange')
3394  if dtype != 'int32':
3395    result = cast(result, dtype)
3396  return result
3397
3398
3399@keras_export('keras.backend.tile')
3400@dispatch.add_dispatch_support
3401@doc_controls.do_not_generate_docs
3402def tile(x, n):
3403  """Creates a tensor by tiling `x` by `n`.
3404
3405  Args:
3406      x: A tensor or variable
3407      n: A list of integer. The length must be the same as the number of
3408          dimensions in `x`.
3409
3410  Returns:
3411      A tiled tensor.
3412  """
3413  if isinstance(n, int):
3414    n = [n]
3415  return array_ops.tile(x, n)
3416
3417
3418@keras_export('keras.backend.flatten')
3419@dispatch.add_dispatch_support
3420@doc_controls.do_not_generate_docs
3421def flatten(x):
3422  """Flatten a tensor.
3423
3424  Args:
3425      x: A tensor or variable.
3426
3427  Returns:
3428      A tensor, reshaped into 1-D
3429
3430  Example:
3431
3432      >>> b = tf.constant([[1, 2], [3, 4]])
3433      >>> b
3434      <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3435      array([[1, 2],
3436             [3, 4]], dtype=int32)>
3437      >>> tf.keras.backend.flatten(b)
3438      <tf.Tensor: shape=(4,), dtype=int32,
3439          numpy=array([1, 2, 3, 4], dtype=int32)>
3440
3441  """
3442  return array_ops.reshape(x, [-1])
3443
3444
3445@keras_export('keras.backend.batch_flatten')
3446@dispatch.add_dispatch_support
3447@doc_controls.do_not_generate_docs
3448def batch_flatten(x):
3449  """Turn a nD tensor into a 2D tensor with same 0th dimension.
3450
3451  In other words, it flattens each data samples of a batch.
3452
3453  Args:
3454      x: A tensor or variable.
3455
3456  Returns:
3457      A tensor.
3458
3459  Examples:
3460    Flattening a 3D tensor to 2D by collapsing the last dimension.
3461
3462  >>> x_batch = tf.keras.backend.ones(shape=(2, 3, 4, 5))
3463  >>> x_batch_flatten = batch_flatten(x_batch)
3464  >>> tf.keras.backend.int_shape(x_batch_flatten)
3465  (2, 60)
3466
3467  """
3468  x = array_ops.reshape(x, array_ops.stack([-1, prod(shape(x)[1:])]))
3469  return x
3470
3471
3472@keras_export('keras.backend.expand_dims')
3473@dispatch.add_dispatch_support
3474@doc_controls.do_not_generate_docs
3475def expand_dims(x, axis=-1):
3476  """Adds a 1-sized dimension at index "axis".
3477
3478  Args:
3479      x: A tensor or variable.
3480      axis: Position where to add a new axis.
3481
3482  Returns:
3483      A tensor with expanded dimensions.
3484  """
3485  return array_ops.expand_dims(x, axis)
3486
3487
3488@keras_export('keras.backend.squeeze')
3489@dispatch.add_dispatch_support
3490@doc_controls.do_not_generate_docs
3491def squeeze(x, axis):
3492  """Removes a 1-dimension from the tensor at index "axis".
3493
3494  Args:
3495      x: A tensor or variable.
3496      axis: Axis to drop.
3497
3498  Returns:
3499      A tensor with the same data as `x` but reduced dimensions.
3500  """
3501  return array_ops.squeeze(x, [axis])
3502
3503
3504@keras_export('keras.backend.temporal_padding')
3505@dispatch.add_dispatch_support
3506@doc_controls.do_not_generate_docs
3507def temporal_padding(x, padding=(1, 1)):
3508  """Pads the middle dimension of a 3D tensor.
3509
3510  Args:
3511      x: Tensor or variable.
3512      padding: Tuple of 2 integers, how many zeros to
3513          add at the start and end of dim 1.
3514
3515  Returns:
3516      A padded 3D tensor.
3517  """
3518  assert len(padding) == 2
3519  pattern = [[0, 0], [padding[0], padding[1]], [0, 0]]
3520  return array_ops.pad(x, pattern)
3521
3522
3523@keras_export('keras.backend.spatial_2d_padding')
3524@dispatch.add_dispatch_support
3525@doc_controls.do_not_generate_docs
3526def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
3527  """Pads the 2nd and 3rd dimensions of a 4D tensor.
3528
3529  Args:
3530      x: Tensor or variable.
3531      padding: Tuple of 2 tuples, padding pattern.
3532      data_format: One of `channels_last` or `channels_first`.
3533
3534  Returns:
3535      A padded 4D tensor.
3536
3537  Raises:
3538      ValueError: if `data_format` is neither
3539          `channels_last` or `channels_first`.
3540  """
3541  assert len(padding) == 2
3542  assert len(padding[0]) == 2
3543  assert len(padding[1]) == 2
3544  if data_format is None:
3545    data_format = image_data_format()
3546  if data_format not in {'channels_first', 'channels_last'}:
3547    raise ValueError('Unknown data_format: ' + str(data_format))
3548
3549  if data_format == 'channels_first':
3550    pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
3551  else:
3552    pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
3553  return array_ops.pad(x, pattern)
3554
3555
3556@keras_export('keras.backend.spatial_3d_padding')
3557@dispatch.add_dispatch_support
3558@doc_controls.do_not_generate_docs
3559def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
3560  """Pads 5D tensor with zeros along the depth, height, width dimensions.
3561
3562  Pads these dimensions with respectively
3563  "padding[0]", "padding[1]" and "padding[2]" zeros left and right.
3564
3565  For 'channels_last' data_format,
3566  the 2nd, 3rd and 4th dimension will be padded.
3567  For 'channels_first' data_format,
3568  the 3rd, 4th and 5th dimension will be padded.
3569
3570  Args:
3571      x: Tensor or variable.
3572      padding: Tuple of 3 tuples, padding pattern.
3573      data_format: One of `channels_last` or `channels_first`.
3574
3575  Returns:
3576      A padded 5D tensor.
3577
3578  Raises:
3579      ValueError: if `data_format` is neither
3580          `channels_last` or `channels_first`.
3581
3582  """
3583  assert len(padding) == 3
3584  assert len(padding[0]) == 2
3585  assert len(padding[1]) == 2
3586  assert len(padding[2]) == 2
3587  if data_format is None:
3588    data_format = image_data_format()
3589  if data_format not in {'channels_first', 'channels_last'}:
3590    raise ValueError('Unknown data_format: ' + str(data_format))
3591
3592  if data_format == 'channels_first':
3593    pattern = [[0, 0], [0, 0], [padding[0][0], padding[0][1]],
3594               [padding[1][0], padding[1][1]], [padding[2][0], padding[2][1]]]
3595  else:
3596    pattern = [[0, 0], [padding[0][0], padding[0][1]],
3597               [padding[1][0], padding[1][1]], [padding[2][0],
3598                                                padding[2][1]], [0, 0]]
3599  return array_ops.pad(x, pattern)
3600
3601
3602@keras_export('keras.backend.stack')
3603@dispatch.add_dispatch_support
3604@doc_controls.do_not_generate_docs
3605def stack(x, axis=0):
3606  """Stacks a list of rank `R` tensors into a rank `R+1` tensor.
3607
3608  Args:
3609      x: List of tensors.
3610      axis: Axis along which to perform stacking.
3611
3612  Returns:
3613      A tensor.
3614
3615  Example:
3616
3617      >>> a = tf.constant([[1, 2],[3, 4]])
3618      >>> b = tf.constant([[10, 20],[30, 40]])
3619      >>> tf.keras.backend.stack((a, b))
3620      <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3621      array([[[ 1,  2],
3622              [ 3,  4]],
3623             [[10, 20],
3624              [30, 40]]], dtype=int32)>
3625
3626  """
3627  return array_ops.stack(x, axis=axis)
3628
3629
3630@keras_export('keras.backend.one_hot')
3631@dispatch.add_dispatch_support
3632@doc_controls.do_not_generate_docs
3633def one_hot(indices, num_classes):
3634  """Computes the one-hot representation of an integer tensor.
3635
3636  Args:
3637      indices: nD integer tensor of shape
3638          `(batch_size, dim1, dim2, ... dim(n-1))`
3639      num_classes: Integer, number of classes to consider.
3640
3641  Returns:
3642      (n + 1)D one hot representation of the input
3643      with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`
3644
3645  Returns:
3646      The one-hot tensor.
3647  """
3648  return array_ops.one_hot(indices, depth=num_classes, axis=-1)
3649
3650
3651@keras_export('keras.backend.reverse')
3652@dispatch.add_dispatch_support
3653@doc_controls.do_not_generate_docs
3654def reverse(x, axes):
3655  """Reverse a tensor along the specified axes.
3656
3657  Args:
3658      x: Tensor to reverse.
3659      axes: Integer or iterable of integers.
3660          Axes to reverse.
3661
3662  Returns:
3663      A tensor.
3664  """
3665  if isinstance(axes, int):
3666    axes = [axes]
3667  return array_ops.reverse(x, axes)
3668
3669
3670# VALUE MANIPULATION
3671_VALUE_SET_CODE_STRING = """
3672  >>> K = tf.keras.backend  # Common keras convention
3673  >>> v = K.variable(1.)
3674
3675  >>> # reassign
3676  >>> K.set_value(v, 2.)
3677  >>> print(K.get_value(v))
3678  2.0
3679
3680  >>> # increment
3681  >>> K.set_value(v, K.get_value(v) + 1)
3682  >>> print(K.get_value(v))
3683  3.0
3684
3685  Variable semantics in TensorFlow 2 are eager execution friendly. The above
3686  code is roughly equivalent to:
3687
3688  >>> v = tf.Variable(1.)
3689
3690  >>> v.assign(2.)
3691  >>> print(v.numpy())
3692  2.0
3693
3694  >>> v.assign_add(1.)
3695  >>> print(v.numpy())
3696  3.0"""[3:]  # Prune first newline and indent to match the docstring template.
3697
3698
3699@keras_export('keras.backend.get_value')
3700@doc_controls.do_not_generate_docs
3701def get_value(x):
3702  """Returns the value of a variable.
3703
3704  `backend.get_value` is the complement of `backend.set_value`, and provides
3705  a generic interface for reading from variables while abstracting away the
3706  differences between TensorFlow 1.x and 2.x semantics.
3707
3708  {snippet}
3709
3710  Args:
3711      x: input variable.
3712
3713  Returns:
3714      A Numpy array.
3715  """
3716  if not tensor_util.is_tf_type(x):
3717    return x
3718  if context.executing_eagerly() or isinstance(x, ops.EagerTensor):
3719    return x.numpy()
3720  if not getattr(x, '_in_graph_mode', True):
3721    # This is a variable which was created in an eager context, but is being
3722    # evaluated from a Graph.
3723    with context.eager_mode():
3724      return x.numpy()
3725
3726  if ops.executing_eagerly_outside_functions():
3727    # This method of evaluating works inside the Keras FuncGraph.
3728    with ops.init_scope():
3729      return x.numpy()
3730
3731  with x.graph.as_default():
3732    return x.eval(session=get_session((x,)))
3733
3734
3735@keras_export('keras.backend.batch_get_value')
3736@dispatch.add_dispatch_support
3737@doc_controls.do_not_generate_docs
3738def batch_get_value(tensors):
3739  """Returns the value of more than one tensor variable.
3740
3741  Args:
3742      tensors: list of ops to run.
3743
3744  Returns:
3745      A list of Numpy arrays.
3746
3747  Raises:
3748      RuntimeError: If this method is called inside defun.
3749  """
3750  if context.executing_eagerly():
3751    return [x.numpy() for x in tensors]
3752  elif ops.inside_function():  # pylint: disable=protected-access
3753    raise RuntimeError('Cannot get value inside Tensorflow graph function.')
3754  if tensors:
3755    return get_session(tensors).run(tensors)
3756  else:
3757    return []
3758
3759
3760@keras_export('keras.backend.set_value')
3761@doc_controls.do_not_generate_docs
3762def set_value(x, value):
3763  """Sets the value of a variable, from a Numpy array.
3764
3765  `backend.set_value` is the complement of `backend.get_value`, and provides
3766  a generic interface for assigning to variables while abstracting away the
3767  differences between TensorFlow 1.x and 2.x semantics.
3768
3769  {snippet}
3770
3771  Args:
3772      x: Variable to set to a new value.
3773      value: Value to set the tensor to, as a Numpy array
3774          (of the same shape).
3775  """
3776  value = np.asarray(value, dtype=dtype_numpy(x))
3777  if ops.executing_eagerly_outside_functions():
3778    x.assign(value)
3779  else:
3780    with get_graph().as_default():
3781      tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3782      if hasattr(x, '_assign_placeholder'):
3783        assign_placeholder = x._assign_placeholder
3784        assign_op = x._assign_op
3785      else:
3786        # In order to support assigning weights to resizable variables in
3787        # Keras, we make a placeholder with the correct number of dimensions
3788        # but with None in each dimension. This way, we can assign weights
3789        # of any size (as long as they have the correct dimensionality).
3790        placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3791        assign_placeholder = array_ops.placeholder(
3792            tf_dtype, shape=placeholder_shape)
3793        assign_op = x.assign(assign_placeholder)
3794        x._assign_placeholder = assign_placeholder
3795        x._assign_op = assign_op
3796      get_session().run(assign_op, feed_dict={assign_placeholder: value})
3797
3798
3799@keras_export('keras.backend.batch_set_value')
3800@dispatch.add_dispatch_support
3801@doc_controls.do_not_generate_docs
3802def batch_set_value(tuples):
3803  """Sets the values of many tensor variables at once.
3804
3805  Args:
3806      tuples: a list of tuples `(tensor, value)`.
3807          `value` should be a Numpy array.
3808  """
3809  if context.executing_eagerly() or ops.inside_function():
3810    for x, value in tuples:
3811      x.assign(np.asarray(value, dtype=dtype_numpy(x)))
3812  else:
3813    with get_graph().as_default():
3814      if tuples:
3815        assign_ops = []
3816        feed_dict = {}
3817        for x, value in tuples:
3818          value = np.asarray(value, dtype=dtype_numpy(x))
3819          tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3820          if hasattr(x, '_assign_placeholder'):
3821            assign_placeholder = x._assign_placeholder
3822            assign_op = x._assign_op
3823          else:
3824            # In order to support assigning weights to resizable variables in
3825            # Keras, we make a placeholder with the correct number of dimensions
3826            # but with None in each dimension. This way, we can assign weights
3827            # of any size (as long as they have the correct dimensionality).
3828            placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3829            assign_placeholder = array_ops.placeholder(
3830                tf_dtype, shape=placeholder_shape)
3831            assign_op = x.assign(assign_placeholder)
3832            x._assign_placeholder = assign_placeholder
3833            x._assign_op = assign_op
3834          assign_ops.append(assign_op)
3835          feed_dict[assign_placeholder] = value
3836        get_session().run(assign_ops, feed_dict=feed_dict)
3837
3838
3839get_value.__doc__ = get_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3840set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3841
3842
3843@keras_export('keras.backend.print_tensor')
3844@dispatch.add_dispatch_support
3845@doc_controls.do_not_generate_docs
3846def print_tensor(x, message='', summarize=3):
3847  """Prints `message` and the tensor value when evaluated.
3848
3849  Note that `print_tensor` returns a new tensor identical to `x`
3850  which should be used in the following code. Otherwise the
3851  print operation is not taken into account during evaluation.
3852
3853  Example:
3854
3855  >>> x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
3856  >>> _ = tf.keras.backend.print_tensor(x)
3857  [[1 2]
3858   [3 4]]
3859
3860  Args:
3861      x: Tensor to print.
3862      message: Message to print jointly with the tensor.
3863      summarize: The first and last `summarize` elements within each dimension
3864          are recursively printed per Tensor. If None, then the first 3 and last
3865          3 elements of each dimension are printed for each tensor. If set to
3866          -1, it will print all elements of every tensor.
3867
3868  Returns:
3869      The same tensor `x`, unchanged.
3870  """
3871  if isinstance(x, ops.Tensor) and hasattr(x, 'graph'):
3872    with get_graph().as_default():
3873      op = logging_ops.print_v2(
3874          message, x, output_stream=sys.stdout, summarize=summarize)
3875      with ops.control_dependencies([op]):
3876        return array_ops.identity(x)
3877  else:
3878    logging_ops.print_v2(
3879        message, x, output_stream=sys.stdout, summarize=summarize)
3880    return x
3881
3882# GRAPH MANIPULATION
3883
3884
3885class GraphExecutionFunction:
3886  """Runs a computation graph.
3887
3888  It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`.
3889  In particular additional operations via `fetches` argument and additional
3890  tensor substitutions via `feed_dict` arguments. Note that given
3891  substitutions are merged with substitutions from `inputs`. Even though
3892  `feed_dict` is passed once in the constructor (called in `model.compile()`)
3893  we can modify the values in the dictionary. Through this feed_dict we can
3894  provide additional substitutions besides Keras inputs.
3895
3896  Args:
3897      inputs: Feed placeholders to the computation graph.
3898      outputs: Output tensors to fetch.
3899      updates: Additional update ops to be run at function call.
3900      name: A name to help users identify what this function does.
3901      session_kwargs: Arguments to `tf.Session.run()`:
3902                      `fetches`, `feed_dict`, `options`, `run_metadata`.
3903  """
3904
3905  def __init__(self, inputs, outputs, updates=None, name=None,
3906               **session_kwargs):
3907    updates = updates or []
3908    if not isinstance(updates, (list, tuple)):
3909      raise TypeError('`updates` in a Keras backend function '
3910                      'should be a list or tuple.')
3911
3912    self._inputs_structure = inputs
3913    self.inputs = nest.flatten(inputs, expand_composites=True)
3914    self._outputs_structure = outputs
3915    self.outputs = cast_variables_to_tensor(
3916        nest.flatten(outputs, expand_composites=True))
3917    # TODO(b/127668432): Consider using autograph to generate these
3918    # dependencies in call.
3919    # Index 0 = total loss or model output for `predict`.
3920    with ops.control_dependencies([self.outputs[0]]):
3921      updates_ops = []
3922      for update in updates:
3923        if isinstance(update, tuple):
3924          p, new_p = update
3925          updates_ops.append(state_ops.assign(p, new_p))
3926        else:
3927          # assumed already an op
3928          updates_ops.append(update)
3929      self.updates_op = control_flow_ops.group(*updates_ops)
3930    self.name = name
3931    # additional tensor substitutions
3932    self.feed_dict = session_kwargs.pop('feed_dict', None)
3933    # additional operations
3934    self.fetches = session_kwargs.pop('fetches', [])
3935    if not isinstance(self.fetches, list):
3936      self.fetches = [self.fetches]
3937    self.run_options = session_kwargs.pop('options', None)
3938    self.run_metadata = session_kwargs.pop('run_metadata', None)
3939    # The main use case of `fetches` being passed to a model is the ability
3940    # to run custom updates
3941    # This requires us to wrap fetches in `identity` ops.
3942    self.fetches = [array_ops.identity(x) for x in self.fetches]
3943    self.session_kwargs = session_kwargs
3944    # This mapping keeps track of the function that should receive the
3945    # output from a fetch in `fetches`: { fetch: function(fetch_output) }
3946    # A Callback can use this to register a function with access to the
3947    # output values for a fetch it added.
3948    self.fetch_callbacks = {}
3949
3950    if session_kwargs:
3951      raise ValueError('Some keys in session_kwargs are not supported at this '
3952                       'time: %s' % (session_kwargs.keys(),))
3953
3954    self._callable_fn = None
3955    self._feed_arrays = None
3956    self._feed_symbols = None
3957    self._symbol_vals = None
3958    self._fetches = None
3959    self._session = None
3960
3961  def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
3962    """Generates a callable that runs the graph.
3963
3964    Args:
3965      feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
3966      feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
3967      symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
3968      session: Session to use to generate the callable.
3969
3970    Returns:
3971      Function that runs the graph according to the above options.
3972    """
3973    # Prepare callable options.
3974    callable_opts = config_pb2.CallableOptions()
3975    # Handle external-data feed.
3976    for x in feed_arrays:
3977      callable_opts.feed.append(x.name)
3978    if self.feed_dict:
3979      for key in sorted(self.feed_dict.keys()):
3980        callable_opts.feed.append(key.name)
3981    # Handle symbolic feed.
3982    for x, y in zip(feed_symbols, symbol_vals):
3983      connection = callable_opts.tensor_connection.add()
3984      if x.dtype != y.dtype:
3985        y = math_ops.cast(y, dtype=x.dtype)
3986      from_tensor = _as_graph_element(y)
3987      if from_tensor is None:
3988        from_tensor = y
3989      connection.from_tensor = from_tensor.name  # Data tensor
3990      connection.to_tensor = x.name  # Placeholder
3991    # Handle fetches.
3992    for x in self.outputs + self.fetches:
3993      callable_opts.fetch.append(x.name)
3994    # Handle updates.
3995    callable_opts.target.append(self.updates_op.name)
3996    # Handle run_options.
3997    if self.run_options:
3998      callable_opts.run_options.CopyFrom(self.run_options)
3999    # Create callable.
4000    callable_fn = session._make_callable_from_options(callable_opts)
4001    # Cache parameters corresponding to the generated callable, so that
4002    # we can detect future mismatches and refresh the callable.
4003    self._callable_fn = callable_fn
4004    self._feed_arrays = feed_arrays
4005    self._feed_symbols = feed_symbols
4006    self._symbol_vals = symbol_vals
4007    self._fetches = list(self.fetches)
4008    self._session = session
4009
4010  def _call_fetch_callbacks(self, fetches_output):
4011    for fetch, output in zip(self._fetches, fetches_output):
4012      if fetch in self.fetch_callbacks:
4013        self.fetch_callbacks[fetch](output)
4014
4015  def _eval_if_composite(self, tensor):
4016    """Helper method which evaluates any CompositeTensors passed to it."""
4017    # We need to evaluate any composite tensor objects that have been
4018    # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
4019    # actual CompositeTensor objects instead of the value(s) contained in the
4020    # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
4021    # this ensures that we return its value as a SparseTensorValue rather than
4022    # a SparseTensor.
4023    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
4024    if tf_utils.is_extension_type(tensor):
4025      return self._session.run(tensor)
4026    else:
4027      return tensor
4028
4029  def __call__(self, inputs):
4030    inputs = nest.flatten(inputs, expand_composites=True)
4031
4032    session = get_session(inputs)
4033    feed_arrays = []
4034    array_vals = []
4035    feed_symbols = []
4036    symbol_vals = []
4037    for tensor, value in zip(self.inputs, inputs):
4038      if value is None:
4039        continue
4040
4041      if tensor_util.is_tf_type(value):
4042        # Case: feeding symbolic tensor.
4043        feed_symbols.append(tensor)
4044        symbol_vals.append(value)
4045      else:
4046        # Case: feeding Numpy array.
4047        feed_arrays.append(tensor)
4048        # We need to do array conversion and type casting at this level, since
4049        # `callable_fn` only supports exact matches.
4050        tensor_type = dtypes_module.as_dtype(tensor.dtype)
4051        array_vals.append(np.asarray(value,
4052                                     dtype=tensor_type.as_numpy_dtype))
4053
4054    if self.feed_dict:
4055      for key in sorted(self.feed_dict.keys()):
4056        array_vals.append(
4057            np.asarray(self.feed_dict[key], dtype=key.dtype.as_numpy_dtype))
4058
4059    # Refresh callable if anything has changed.
4060    if (self._callable_fn is None or feed_arrays != self._feed_arrays or
4061        symbol_vals != self._symbol_vals or
4062        feed_symbols != self._feed_symbols or self.fetches != self._fetches or
4063        session != self._session):
4064      self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
4065
4066    fetched = self._callable_fn(*array_vals,
4067                                run_metadata=self.run_metadata)
4068    self._call_fetch_callbacks(fetched[-len(self._fetches):])
4069    output_structure = nest.pack_sequence_as(
4070        self._outputs_structure,
4071        fetched[:len(self.outputs)],
4072        expand_composites=True)
4073    # We need to evaluate any composite tensor objects that have been
4074    # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
4075    # actual CompositeTensor objects instead of the value(s) contained in the
4076    # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
4077    # this ensures that we return its value as a SparseTensorValue rather than
4078    # a SparseTensor.
4079    return nest.map_structure(self._eval_if_composite, output_structure)
4080
4081
4082@keras_export('keras.backend.function')
4083@doc_controls.do_not_generate_docs
4084def function(inputs, outputs, updates=None, name=None, **kwargs):
4085  """Instantiates a Keras function.
4086
4087  Args:
4088      inputs: List of placeholder tensors.
4089      outputs: List of output tensors.
4090      updates: List of update ops.
4091      name: String, name of function.
4092      **kwargs: Passed to `tf.Session.run`.
4093
4094  Returns:
4095      Output values as Numpy arrays.
4096
4097  Raises:
4098      ValueError: if invalid kwargs are passed in or if in eager execution.
4099  """
4100  if ops.executing_eagerly_outside_functions():
4101    if kwargs:
4102      raise ValueError('Session keyword arguments are not supported during '
4103                       'eager execution. You passed: %s' % (kwargs,))
4104    if updates:
4105      raise ValueError('`updates` argument is not supported during '
4106                       'eager execution. You passed: %s' % (updates,))
4107    from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
4108    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
4109    model = models.Model(inputs=inputs, outputs=outputs)
4110
4111    wrap_outputs = isinstance(outputs, list) and len(outputs) == 1
4112    def func(model_inputs):
4113      outs = model(model_inputs)
4114      if wrap_outputs:
4115        outs = [outs]
4116      return tf_utils.sync_to_numpy_or_python_type(outs)
4117
4118    return func
4119
4120  if kwargs:
4121    for key in kwargs:
4122      if (key not in tf_inspect.getfullargspec(session_module.Session.run)[0]
4123          and key not in ['inputs', 'outputs', 'updates', 'name']):
4124        msg = ('Invalid argument "%s" passed to K.function with TensorFlow '
4125               'backend') % key
4126        raise ValueError(msg)
4127  return GraphExecutionFunction(
4128      inputs, outputs, updates=updates, name=name, **kwargs)
4129
4130
4131@keras_export('keras.backend.gradients')
4132@doc_controls.do_not_generate_docs
4133def gradients(loss, variables):
4134  """Returns the gradients of `loss` w.r.t. `variables`.
4135
4136  Args:
4137      loss: Scalar tensor to minimize.
4138      variables: List of variables.
4139
4140  Returns:
4141      A gradients tensor.
4142  """
4143  return gradients_module.gradients(
4144      loss, variables, colocate_gradients_with_ops=True)
4145
4146
4147@keras_export('keras.backend.stop_gradient')
4148@dispatch.add_dispatch_support
4149@doc_controls.do_not_generate_docs
4150def stop_gradient(variables):
4151  """Returns `variables` but with zero gradient w.r.t. every other variable.
4152
4153  Args:
4154      variables: Tensor or list of tensors to consider constant with respect
4155        to any other variable.
4156
4157
4158  Returns:
4159      A single tensor or a list of tensors (depending on the passed argument)
4160      that has no gradient with respect to any other variable.
4161  """
4162  if isinstance(variables, (list, tuple)):
4163    return map(array_ops.stop_gradient, variables)
4164  return array_ops.stop_gradient(variables)
4165
4166
4167# CONTROL FLOW
4168
4169
4170@keras_export('keras.backend.rnn')
4171@dispatch.add_dispatch_support
4172def rnn(step_function,
4173        inputs,
4174        initial_states,
4175        go_backwards=False,
4176        mask=None,
4177        constants=None,
4178        unroll=False,
4179        input_length=None,
4180        time_major=False,
4181        zero_output_for_mask=False):
4182  """Iterates over the time dimension of a tensor.
4183
4184  Args:
4185      step_function: RNN step function.
4186          Args;
4187              input; Tensor with shape `(samples, ...)` (no time dimension),
4188                  representing input for the batch of samples at a certain
4189                  time step.
4190              states; List of tensors.
4191          Returns;
4192              output; Tensor with shape `(samples, output_dim)`
4193                  (no time dimension).
4194              new_states; List of tensors, same length and shapes
4195                  as 'states'. The first state in the list must be the
4196                  output tensor at the previous timestep.
4197      inputs: Tensor of temporal data of shape `(samples, time, ...)`
4198          (at least 3D), or nested tensors, and each of which has shape
4199          `(samples, time, ...)`.
4200      initial_states: Tensor with shape `(samples, state_size)`
4201          (no time dimension), containing the initial values for the states used
4202          in the step function. In the case that state_size is in a nested
4203          shape, the shape of initial_states will also follow the nested
4204          structure.
4205      go_backwards: Boolean. If True, do the iteration over the time
4206          dimension in reverse order and return the reversed sequence.
4207      mask: Binary tensor with shape `(samples, time, 1)`,
4208          with a zero for every element that is masked.
4209      constants: List of constant values passed at each step.
4210      unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
4211      input_length: An integer or a 1-D Tensor, depending on whether
4212          the time dimension is fixed-length or not. In case of variable length
4213          input, it is used for masking in case there's no mask specified.
4214      time_major: Boolean. If true, the inputs and outputs will be in shape
4215          `(timesteps, batch, ...)`, whereas in the False case, it will be
4216          `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
4217          efficient because it avoids transposes at the beginning and end of the
4218          RNN calculation. However, most TensorFlow data is batch-major, so by
4219          default this function accepts input and emits output in batch-major
4220          form.
4221      zero_output_for_mask: Boolean. If True, the output for masked timestep
4222          will be zeros, whereas in the False case, output from previous
4223          timestep is returned.
4224
4225  Returns:
4226      A tuple, `(last_output, outputs, new_states)`.
4227          last_output: the latest output of the rnn, of shape `(samples, ...)`
4228          outputs: tensor with shape `(samples, time, ...)` where each
4229              entry `outputs[s, t]` is the output of the step function
4230              at time `t` for sample `s`.
4231          new_states: list of tensors, latest states returned by
4232              the step function, of shape `(samples, ...)`.
4233
4234  Raises:
4235      ValueError: if input dimension is less than 3.
4236      ValueError: if `unroll` is `True` but input timestep is not a fixed
4237      number.
4238      ValueError: if `mask` is provided (not `None`) but states is not provided
4239          (`len(states)` == 0).
4240  """
4241
4242  def swap_batch_timestep(input_t):
4243    # Swap the batch and timestep dim for the incoming tensor.
4244    axes = list(range(len(input_t.shape)))
4245    axes[0], axes[1] = 1, 0
4246    return array_ops.transpose(input_t, axes)
4247
4248  if not time_major:
4249    inputs = nest.map_structure(swap_batch_timestep, inputs)
4250
4251  flatted_inputs = nest.flatten(inputs)
4252  time_steps = flatted_inputs[0].shape[0]
4253  batch = flatted_inputs[0].shape[1]
4254  time_steps_t = array_ops.shape(flatted_inputs[0])[0]
4255
4256  for input_ in flatted_inputs:
4257    input_.shape.with_rank_at_least(3)
4258
4259  if mask is not None:
4260    if mask.dtype != dtypes_module.bool:
4261      mask = math_ops.cast(mask, dtypes_module.bool)
4262    if len(mask.shape) == 2:
4263      mask = expand_dims(mask)
4264    if not time_major:
4265      mask = swap_batch_timestep(mask)
4266
4267  if constants is None:
4268    constants = []
4269
4270  # tf.where needs its condition tensor to be the same shape as its two
4271  # result tensors, but in our case the condition (mask) tensor is
4272  # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
4273  # So we need to broadcast the mask to match the shape of inputs.
4274  # That's what the tile call does, it just repeats the mask along its
4275  # second dimension n times.
4276  def _expand_mask(mask_t, input_t, fixed_dim=1):
4277    if nest.is_nested(mask_t):
4278      raise ValueError('mask_t is expected to be tensor, but got %s' % mask_t)
4279    if nest.is_nested(input_t):
4280      raise ValueError('input_t is expected to be tensor, but got %s' % input_t)
4281    rank_diff = len(input_t.shape) - len(mask_t.shape)
4282    for _ in range(rank_diff):
4283      mask_t = array_ops.expand_dims(mask_t, -1)
4284    multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
4285    return array_ops.tile(mask_t, multiples)
4286
4287  if unroll:
4288    if not time_steps:
4289      raise ValueError('Unrolling requires a fixed number of timesteps.')
4290    states = tuple(initial_states)
4291    successive_states = []
4292    successive_outputs = []
4293
4294    # Process the input tensors. The input tensor need to be split on the
4295    # time_step dim, and reverse if go_backwards is True. In the case of nested
4296    # input, the input is flattened and then transformed individually.
4297    # The result of this will be a tuple of lists, each of the item in tuple is
4298    # list of the tensor with shape (batch, feature)
4299    def _process_single_input_t(input_t):
4300      input_t = array_ops.unstack(input_t)  # unstack for time_step dim
4301      if go_backwards:
4302        input_t.reverse()
4303      return input_t
4304
4305    if nest.is_nested(inputs):
4306      processed_input = nest.map_structure(_process_single_input_t, inputs)
4307    else:
4308      processed_input = (_process_single_input_t(inputs),)
4309
4310    def _get_input_tensor(time):
4311      inp = [t_[time] for t_ in processed_input]
4312      return nest.pack_sequence_as(inputs, inp)
4313
4314    if mask is not None:
4315      mask_list = array_ops.unstack(mask)
4316      if go_backwards:
4317        mask_list.reverse()
4318
4319      for i in range(time_steps):
4320        inp = _get_input_tensor(i)
4321        mask_t = mask_list[i]
4322        output, new_states = step_function(inp,
4323                                           tuple(states) + tuple(constants))
4324        tiled_mask_t = _expand_mask(mask_t, output)
4325
4326        if not successive_outputs:
4327          prev_output = zeros_like(output)
4328        else:
4329          prev_output = successive_outputs[-1]
4330
4331        output = array_ops.where_v2(tiled_mask_t, output, prev_output)
4332
4333        flat_states = nest.flatten(states)
4334        flat_new_states = nest.flatten(new_states)
4335        tiled_mask_t = tuple(_expand_mask(mask_t, s) for s in flat_states)
4336        flat_final_states = tuple(
4337            array_ops.where_v2(m, s, ps)
4338            for m, s, ps in zip(tiled_mask_t, flat_new_states, flat_states))
4339        states = nest.pack_sequence_as(states, flat_final_states)
4340
4341        successive_outputs.append(output)
4342        successive_states.append(states)
4343      last_output = successive_outputs[-1]
4344      new_states = successive_states[-1]
4345      outputs = array_ops.stack(successive_outputs)
4346
4347      if zero_output_for_mask:
4348        last_output = array_ops.where_v2(
4349            _expand_mask(mask_list[-1], last_output), last_output,
4350            zeros_like(last_output))
4351        outputs = array_ops.where_v2(
4352            _expand_mask(mask, outputs, fixed_dim=2), outputs,
4353            zeros_like(outputs))
4354
4355    else:  # mask is None
4356      for i in range(time_steps):
4357        inp = _get_input_tensor(i)
4358        output, states = step_function(inp, tuple(states) + tuple(constants))
4359        successive_outputs.append(output)
4360        successive_states.append(states)
4361      last_output = successive_outputs[-1]
4362      new_states = successive_states[-1]
4363      outputs = array_ops.stack(successive_outputs)
4364
4365  else:  # Unroll == False
4366    states = tuple(initial_states)
4367
4368    # Create input tensor array, if the inputs is nested tensors, then it will
4369    # be flattened first, and tensor array will be created one per flattened
4370    # tensor.
4371    input_ta = tuple(
4372        tensor_array_ops.TensorArray(
4373            dtype=inp.dtype,
4374            size=time_steps_t,
4375            tensor_array_name='input_ta_%s' % i)
4376        for i, inp in enumerate(flatted_inputs))
4377    input_ta = tuple(
4378        ta.unstack(input_) if not go_backwards else ta
4379        .unstack(reverse(input_, 0))
4380        for ta, input_ in zip(input_ta, flatted_inputs))
4381
4382    # Get the time(0) input and compute the output for that, the output will be
4383    # used to determine the dtype of output tensor array. Don't read from
4384    # input_ta due to TensorArray clear_after_read default to True.
4385    input_time_zero = nest.pack_sequence_as(inputs,
4386                                            [inp[0] for inp in flatted_inputs])
4387    # output_time_zero is used to determine the cell output shape and its dtype.
4388    # the value is discarded.
4389    output_time_zero, _ = step_function(
4390        input_time_zero, tuple(initial_states) + tuple(constants))
4391    output_ta = tuple(
4392        tensor_array_ops.TensorArray(
4393            dtype=out.dtype,
4394            size=time_steps_t,
4395            element_shape=out.shape,
4396            tensor_array_name='output_ta_%s' % i)
4397        for i, out in enumerate(nest.flatten(output_time_zero)))
4398
4399    time = constant_op.constant(0, dtype='int32', name='time')
4400
4401    # We only specify the 'maximum_iterations' when building for XLA since that
4402    # causes slowdowns on GPU in TF.
4403    if (not context.executing_eagerly() and
4404        control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())):
4405      max_iterations = math_ops.reduce_max(input_length)
4406    else:
4407      max_iterations = None
4408
4409    while_loop_kwargs = {
4410        'cond': lambda time, *_: time < time_steps_t,
4411        'maximum_iterations': max_iterations,
4412        'parallel_iterations': 32,
4413        'swap_memory': True,
4414    }
4415    if mask is not None:
4416      if go_backwards:
4417        mask = reverse(mask, 0)
4418
4419      mask_ta = tensor_array_ops.TensorArray(
4420          dtype=dtypes_module.bool,
4421          size=time_steps_t,
4422          tensor_array_name='mask_ta')
4423      mask_ta = mask_ta.unstack(mask)
4424
4425      def masking_fn(time):
4426        return mask_ta.read(time)
4427
4428      def compute_masked_output(mask_t, flat_out, flat_mask):
4429        tiled_mask_t = tuple(
4430            _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
4431            for o in flat_out)
4432        return tuple(
4433            array_ops.where_v2(m, o, fm)
4434            for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask))
4435    elif isinstance(input_length, ops.Tensor):
4436      if go_backwards:
4437        max_len = math_ops.reduce_max(input_length, axis=0)
4438        rev_input_length = math_ops.subtract(max_len - 1, input_length)
4439
4440        def masking_fn(time):
4441          return math_ops.less(rev_input_length, time)
4442      else:
4443
4444        def masking_fn(time):
4445          return math_ops.greater(input_length, time)
4446
4447      def compute_masked_output(mask_t, flat_out, flat_mask):
4448        return tuple(
4449            array_ops.where(mask_t, o, zo)
4450            for (o, zo) in zip(flat_out, flat_mask))
4451    else:
4452      masking_fn = None
4453
4454    if masking_fn is not None:
4455      # Mask for the T output will be base on the output of T - 1. In the case
4456      # T = 0, a zero filled tensor will be used.
4457      flat_zero_output = tuple(array_ops.zeros_like(o)
4458                               for o in nest.flatten(output_time_zero))
4459      def _step(time, output_ta_t, prev_output, *states):
4460        """RNN step function.
4461
4462        Args:
4463            time: Current timestep value.
4464            output_ta_t: TensorArray.
4465            prev_output: tuple of outputs from time - 1.
4466            *states: List of states.
4467
4468        Returns:
4469            Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
4470        """
4471        current_input = tuple(ta.read(time) for ta in input_ta)
4472        # maybe set shape.
4473        current_input = nest.pack_sequence_as(inputs, current_input)
4474        mask_t = masking_fn(time)
4475        output, new_states = step_function(current_input,
4476                                           tuple(states) + tuple(constants))
4477        # mask output
4478        flat_output = nest.flatten(output)
4479        flat_mask_output = (flat_zero_output if zero_output_for_mask
4480                            else nest.flatten(prev_output))
4481        flat_new_output = compute_masked_output(mask_t, flat_output,
4482                                                flat_mask_output)
4483
4484        # mask states
4485        flat_state = nest.flatten(states)
4486        flat_new_state = nest.flatten(new_states)
4487        for state, new_state in zip(flat_state, flat_new_state):
4488          if isinstance(new_state, ops.Tensor):
4489            new_state.set_shape(state.shape)
4490        flat_final_state = compute_masked_output(mask_t, flat_new_state,
4491                                                 flat_state)
4492        new_states = nest.pack_sequence_as(new_states, flat_final_state)
4493
4494        output_ta_t = tuple(
4495            ta.write(time, out)
4496            for ta, out in zip(output_ta_t, flat_new_output))
4497        return (time + 1, output_ta_t,
4498                tuple(flat_new_output)) + tuple(new_states)
4499
4500      final_outputs = control_flow_ops.while_loop(
4501          body=_step,
4502          loop_vars=(time, output_ta, flat_zero_output) + states,
4503          **while_loop_kwargs)
4504      # Skip final_outputs[2] which is the output for final timestep.
4505      new_states = final_outputs[3:]
4506    else:
4507      def _step(time, output_ta_t, *states):
4508        """RNN step function.
4509
4510        Args:
4511            time: Current timestep value.
4512            output_ta_t: TensorArray.
4513            *states: List of states.
4514
4515        Returns:
4516            Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
4517        """
4518        current_input = tuple(ta.read(time) for ta in input_ta)
4519        current_input = nest.pack_sequence_as(inputs, current_input)
4520        output, new_states = step_function(current_input,
4521                                           tuple(states) + tuple(constants))
4522        flat_state = nest.flatten(states)
4523        flat_new_state = nest.flatten(new_states)
4524        for state, new_state in zip(flat_state, flat_new_state):
4525          if isinstance(new_state, ops.Tensor):
4526            new_state.set_shape(state.shape)
4527
4528        flat_output = nest.flatten(output)
4529        output_ta_t = tuple(
4530            ta.write(time, out) for ta, out in zip(output_ta_t, flat_output))
4531        new_states = nest.pack_sequence_as(initial_states, flat_new_state)
4532        return (time + 1, output_ta_t) + tuple(new_states)
4533
4534      final_outputs = control_flow_ops.while_loop(
4535          body=_step,
4536          loop_vars=(time, output_ta) + states,
4537          **while_loop_kwargs)
4538      new_states = final_outputs[2:]
4539
4540    output_ta = final_outputs[1]
4541
4542    outputs = tuple(o.stack() for o in output_ta)
4543    last_output = tuple(o[-1] for o in outputs)
4544
4545    outputs = nest.pack_sequence_as(output_time_zero, outputs)
4546    last_output = nest.pack_sequence_as(output_time_zero, last_output)
4547
4548  # static shape inference
4549  def set_shape(output_):
4550    if isinstance(output_, ops.Tensor):
4551      shape = output_.shape.as_list()
4552      shape[0] = time_steps
4553      shape[1] = batch
4554      output_.set_shape(shape)
4555    return output_
4556
4557  outputs = nest.map_structure(set_shape, outputs)
4558
4559  if not time_major:
4560    outputs = nest.map_structure(swap_batch_timestep, outputs)
4561
4562  return last_output, outputs, new_states
4563
4564
4565@keras_export('keras.backend.switch')
4566@dispatch.add_dispatch_support
4567@doc_controls.do_not_generate_docs
4568def switch(condition, then_expression, else_expression):
4569  """Switches between two operations depending on a scalar value.
4570
4571  Note that both `then_expression` and `else_expression`
4572  should be symbolic tensors of the *same shape*.
4573
4574  Args:
4575      condition: tensor (`int` or `bool`).
4576      then_expression: either a tensor, or a callable that returns a tensor.
4577      else_expression: either a tensor, or a callable that returns a tensor.
4578
4579  Returns:
4580      The selected tensor.
4581
4582  Raises:
4583      ValueError: If rank of `condition` is greater than rank of expressions.
4584  """
4585  if condition.dtype != dtypes_module.bool:
4586    condition = math_ops.cast(condition, 'bool')
4587  cond_ndim = ndim(condition)
4588  if not cond_ndim:
4589    if not callable(then_expression):
4590
4591      def then_expression_fn():
4592        return then_expression
4593    else:
4594      then_expression_fn = then_expression
4595    if not callable(else_expression):
4596
4597      def else_expression_fn():
4598        return else_expression
4599    else:
4600      else_expression_fn = else_expression
4601    x = control_flow_ops.cond(condition, then_expression_fn, else_expression_fn)
4602  else:
4603    # tf.where needs its condition tensor
4604    # to be the same shape as its two
4605    # result tensors
4606    if callable(then_expression):
4607      then_expression = then_expression()
4608    if callable(else_expression):
4609      else_expression = else_expression()
4610    expr_ndim = ndim(then_expression)
4611    if cond_ndim > expr_ndim:
4612      raise ValueError('Rank of `condition` should be less than or'
4613                       ' equal to rank of `then_expression` and '
4614                       '`else_expression`. ndim(condition)=' + str(cond_ndim) +
4615                       ', ndim(then_expression)'
4616                       '=' + str(expr_ndim))
4617    if cond_ndim > 1:
4618      ndim_diff = expr_ndim - cond_ndim
4619      cond_shape = array_ops.concat(
4620          [array_ops.shape(condition), [1] * ndim_diff], axis=0)
4621      condition = array_ops.reshape(condition, cond_shape)
4622      expr_shape = array_ops.shape(then_expression)
4623      shape_diff = expr_shape - cond_shape
4624      tile_shape = array_ops.where_v2(shape_diff > 0, expr_shape,
4625                                      array_ops.ones_like(expr_shape))
4626      condition = array_ops.tile(condition, tile_shape)
4627    x = array_ops.where_v2(condition, then_expression, else_expression)
4628  return x
4629
4630
4631@keras_export('keras.backend.in_train_phase')
4632@doc_controls.do_not_generate_docs
4633def in_train_phase(x, alt, training=None):
4634  """Selects `x` in train phase, and `alt` otherwise.
4635
4636  Note that `alt` should have the *same shape* as `x`.
4637
4638  Args:
4639      x: What to return in train phase
4640          (tensor or callable that returns a tensor).
4641      alt: What to return otherwise
4642          (tensor or callable that returns a tensor).
4643      training: Optional scalar tensor
4644          (or Python boolean, or Python integer)
4645          specifying the learning phase.
4646
4647  Returns:
4648      Either `x` or `alt` based on the `training` flag.
4649      the `training` flag defaults to `K.learning_phase()`.
4650  """
4651  from tensorflow.python.keras.engine import base_layer_utils  # pylint: disable=g-import-not-at-top
4652  if training is None:
4653    training = base_layer_utils.call_context().training
4654
4655  if training is None:
4656    training = learning_phase()
4657
4658  # TODO(b/138862903): Handle the case when training is tensor.
4659  if not tensor_util.is_tf_type(training):
4660    if training == 1 or training is True:
4661      if callable(x):
4662        return x()
4663      else:
4664        return x
4665
4666    elif training == 0 or training is False:
4667      if callable(alt):
4668        return alt()
4669      else:
4670        return alt
4671
4672  # else: assume learning phase is a placeholder tensor.
4673  x = switch(training, x, alt)
4674  return x
4675
4676
4677@keras_export('keras.backend.in_test_phase')
4678@doc_controls.do_not_generate_docs
4679def in_test_phase(x, alt, training=None):
4680  """Selects `x` in test phase, and `alt` otherwise.
4681
4682  Note that `alt` should have the *same shape* as `x`.
4683
4684  Args:
4685      x: What to return in test phase
4686          (tensor or callable that returns a tensor).
4687      alt: What to return otherwise
4688          (tensor or callable that returns a tensor).
4689      training: Optional scalar tensor
4690          (or Python boolean, or Python integer)
4691          specifying the learning phase.
4692
4693  Returns:
4694      Either `x` or `alt` based on `K.learning_phase`.
4695  """
4696  return in_train_phase(alt, x, training=training)
4697
4698
4699# NN OPERATIONS
4700
4701
4702@keras_export('keras.backend.relu')
4703@dispatch.add_dispatch_support
4704@doc_controls.do_not_generate_docs
4705def relu(x, alpha=0., max_value=None, threshold=0):
4706  """Rectified linear unit.
4707
4708  With default values, it returns element-wise `max(x, 0)`.
4709
4710  Otherwise, it follows:
4711  `f(x) = max_value` for `x >= max_value`,
4712  `f(x) = x` for `threshold <= x < max_value`,
4713  `f(x) = alpha * (x - threshold)` otherwise.
4714
4715  Args:
4716      x: A tensor or variable.
4717      alpha: A scalar, slope of negative section (default=`0.`).
4718      max_value: float. Saturation threshold.
4719      threshold: float. Threshold value for thresholded activation.
4720
4721  Returns:
4722      A tensor.
4723  """
4724  # While x can be a tensor or variable, we also see cases where
4725  # numpy arrays, lists, tuples are passed as well.
4726  # lists, tuples do not have 'dtype' attribute.
4727  dtype = getattr(x, 'dtype', floatx())
4728  if alpha != 0.:
4729    if max_value is None and threshold == 0:
4730      return nn.leaky_relu(x, alpha=alpha)
4731
4732    if threshold != 0:
4733      negative_part = nn.relu(-x + threshold)
4734    else:
4735      negative_part = nn.relu(-x)
4736
4737  clip_max = max_value is not None
4738
4739  if threshold != 0:
4740    # computes x for x > threshold else 0
4741    x = x * math_ops.cast(math_ops.greater(x, threshold), dtype=dtype)
4742  elif max_value == 6:
4743    # if no threshold, then can use nn.relu6 native TF op for performance
4744    x = nn.relu6(x)
4745    clip_max = False
4746  else:
4747    x = nn.relu(x)
4748
4749  if clip_max:
4750    max_value = _constant_to_tensor(max_value, x.dtype.base_dtype)
4751    zero = _constant_to_tensor(0, x.dtype.base_dtype)
4752    x = clip_ops.clip_by_value(x, zero, max_value)
4753
4754  if alpha != 0.:
4755    alpha = _to_tensor(alpha, x.dtype.base_dtype)
4756    x -= alpha * negative_part
4757  return x
4758
4759
4760@keras_export('keras.backend.elu')
4761@dispatch.add_dispatch_support
4762@doc_controls.do_not_generate_docs
4763def elu(x, alpha=1.):
4764  """Exponential linear unit.
4765
4766  Args:
4767      x: A tensor or variable to compute the activation function for.
4768      alpha: A scalar, slope of negative section.
4769
4770  Returns:
4771      A tensor.
4772  """
4773  res = nn.elu(x)
4774  if alpha == 1:
4775    return res
4776  else:
4777    return array_ops.where_v2(x > 0, res, alpha * res)
4778
4779
4780@keras_export('keras.backend.softmax')
4781@dispatch.add_dispatch_support
4782@doc_controls.do_not_generate_docs
4783def softmax(x, axis=-1):
4784  """Softmax of a tensor.
4785
4786  Args:
4787      x: A tensor or variable.
4788      axis: The dimension softmax would be performed on.
4789          The default is -1 which indicates the last dimension.
4790
4791  Returns:
4792      A tensor.
4793  """
4794  return nn.softmax(x, axis=axis)
4795
4796
4797@keras_export('keras.backend.softplus')
4798@dispatch.add_dispatch_support
4799@doc_controls.do_not_generate_docs
4800def softplus(x):
4801  """Softplus of a tensor.
4802
4803  Args:
4804      x: A tensor or variable.
4805
4806  Returns:
4807      A tensor.
4808  """
4809  return math_ops.softplus(x)
4810
4811
4812@keras_export('keras.backend.softsign')
4813@dispatch.add_dispatch_support
4814@doc_controls.do_not_generate_docs
4815def softsign(x):
4816  """Softsign of a tensor.
4817
4818  Args:
4819      x: A tensor or variable.
4820
4821  Returns:
4822      A tensor.
4823  """
4824  return nn.softsign(x)
4825
4826
4827@keras_export('keras.backend.categorical_crossentropy')
4828@dispatch.add_dispatch_support
4829@doc_controls.do_not_generate_docs
4830def categorical_crossentropy(target, output, from_logits=False, axis=-1):
4831  """Categorical crossentropy between an output tensor and a target tensor.
4832
4833  Args:
4834      target: A tensor of the same shape as `output`.
4835      output: A tensor resulting from a softmax
4836          (unless `from_logits` is True, in which
4837          case `output` is expected to be the logits).
4838      from_logits: Boolean, whether `output` is the
4839          result of a softmax, or is a tensor of logits.
4840      axis: Int specifying the channels axis. `axis=-1` corresponds to data
4841          format `channels_last`, and `axis=1` corresponds to data format
4842          `channels_first`.
4843
4844  Returns:
4845      Output tensor.
4846
4847  Raises:
4848      ValueError: if `axis` is neither -1 nor one of the axes of `output`.
4849
4850  Example:
4851
4852  >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3,3])
4853  >>> print(a)
4854  tf.Tensor(
4855    [[1. 0. 0.]
4856     [0. 1. 0.]
4857     [0. 0. 1.]], shape=(3, 3), dtype=float32)
4858  >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94], shape=[3,3])
4859  >>> print(b)
4860  tf.Tensor(
4861    [[0.9  0.05 0.05]
4862     [0.05 0.89 0.06]
4863     [0.05 0.01 0.94]], shape=(3, 3), dtype=float32)
4864  >>> loss = tf.keras.backend.categorical_crossentropy(a, b)
4865  >>> print(np.around(loss, 5))
4866  [0.10536 0.11653 0.06188]
4867  >>> loss = tf.keras.backend.categorical_crossentropy(a, a)
4868  >>> print(np.around(loss, 5))
4869  [0. 0. 0.]
4870
4871  """
4872  target = ops.convert_to_tensor_v2_with_dispatch(target)
4873  output = ops.convert_to_tensor_v2_with_dispatch(output)
4874  target.shape.assert_is_compatible_with(output.shape)
4875
4876  # Use logits whenever they are available. `softmax` and `sigmoid`
4877  # activations cache logits on the `output` Tensor.
4878  if hasattr(output, '_keras_logits'):
4879    output = output._keras_logits  # pylint: disable=protected-access
4880    if from_logits:
4881      warnings.warn(
4882          '"`categorical_crossentropy` received `from_logits=True`, but '
4883          'the `output` argument was produced by a sigmoid or softmax '
4884          'activation and thus does not represent logits. Was this intended?"')
4885    from_logits = True
4886
4887  if from_logits:
4888    return nn.softmax_cross_entropy_with_logits_v2(
4889        labels=target, logits=output, axis=axis)
4890
4891  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
4892      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
4893    # When softmax activation function is used for output operation, we
4894    # use logits from the softmax function directly to compute loss in order
4895    # to prevent collapsing zero when training.
4896    # See b/117284466
4897    assert len(output.op.inputs) == 1
4898    output = output.op.inputs[0]
4899    return nn.softmax_cross_entropy_with_logits_v2(
4900        labels=target, logits=output, axis=axis)
4901
4902  # scale preds so that the class probas of each sample sum to 1
4903  output = output / math_ops.reduce_sum(output, axis, True)
4904  # Compute cross entropy from probabilities.
4905  epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4906  output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
4907  return -math_ops.reduce_sum(target * math_ops.log(output), axis)
4908
4909
4910@keras_export('keras.backend.sparse_categorical_crossentropy')
4911@dispatch.add_dispatch_support
4912@doc_controls.do_not_generate_docs
4913def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
4914  """Categorical crossentropy with integer targets.
4915
4916  Args:
4917      target: An integer tensor.
4918      output: A tensor resulting from a softmax
4919          (unless `from_logits` is True, in which
4920          case `output` is expected to be the logits).
4921      from_logits: Boolean, whether `output` is the
4922          result of a softmax, or is a tensor of logits.
4923      axis: Int specifying the channels axis. `axis=-1` corresponds to data
4924          format `channels_last`, and `axis=1` corresponds to data format
4925          `channels_first`.
4926
4927  Returns:
4928      Output tensor.
4929
4930  Raises:
4931      ValueError: if `axis` is neither -1 nor one of the axes of `output`.
4932  """
4933  target = ops.convert_to_tensor_v2_with_dispatch(target)
4934  output = ops.convert_to_tensor_v2_with_dispatch(output)
4935
4936  # Use logits whenever they are available. `softmax` and `sigmoid`
4937  # activations cache logits on the `output` Tensor.
4938  if hasattr(output, '_keras_logits'):
4939    output = output._keras_logits  # pylint: disable=protected-access
4940    if from_logits:
4941      warnings.warn(
4942          '"`sparse_categorical_crossentropy` received `from_logits=True`, but '
4943          'the `output` argument was produced by a sigmoid or softmax '
4944          'activation and thus does not represent logits. Was this intended?"')
4945    from_logits = True
4946  elif (not from_logits and
4947        not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
4948        output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
4949    # When softmax activation function is used for output operation, we
4950    # use logits from the softmax function directly to compute loss in order
4951    # to prevent collapsing zero when training.
4952    # See b/117284466
4953    assert len(output.op.inputs) == 1
4954    output = output.op.inputs[0]
4955    from_logits = True
4956  elif not from_logits:
4957    epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4958    output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
4959    output = math_ops.log(output)
4960
4961  if isinstance(output.shape, (tuple, list)):
4962    output_rank = len(output.shape)
4963  else:
4964    output_rank = output.shape.ndims
4965  if output_rank is not None:
4966    axis %= output_rank
4967    if axis != output_rank - 1:
4968      permutation = list(
4969          itertools.chain(range(axis), range(axis + 1, output_rank), [axis]))
4970      output = array_ops.transpose(output, perm=permutation)
4971  elif axis != -1:
4972    raise ValueError(
4973        'Cannot compute sparse categorical crossentropy with `axis={}` on an '
4974        'output tensor with unknown rank'.format(axis))
4975
4976  target = cast(target, 'int64')
4977
4978  # Try to adjust the shape so that rank of labels = rank of logits - 1.
4979  output_shape = array_ops.shape_v2(output)
4980  target_rank = target.shape.ndims
4981
4982  update_shape = (
4983      target_rank is not None and output_rank is not None and
4984      target_rank != output_rank - 1)
4985  if update_shape:
4986    target = flatten(target)
4987    output = array_ops.reshape(output, [-1, output_shape[-1]])
4988
4989  if py_any(_is_symbolic_tensor(v) for v in [target, output]):
4990    with get_graph().as_default():
4991      res = nn.sparse_softmax_cross_entropy_with_logits_v2(
4992          labels=target, logits=output)
4993  else:
4994    res = nn.sparse_softmax_cross_entropy_with_logits_v2(
4995        labels=target, logits=output)
4996
4997  if update_shape and output_rank >= 3:
4998    # If our output includes timesteps or spatial dimensions we need to reshape
4999    return array_ops.reshape(res, output_shape[:-1])
5000  else:
5001    return res
5002
5003
5004@keras_export('keras.backend.binary_crossentropy')
5005@dispatch.add_dispatch_support
5006@doc_controls.do_not_generate_docs
5007def binary_crossentropy(target, output, from_logits=False):
5008  """Binary crossentropy between an output tensor and a target tensor.
5009
5010  Args:
5011      target: A tensor with the same shape as `output`.
5012      output: A tensor.
5013      from_logits: Whether `output` is expected to be a logits tensor.
5014          By default, we consider that `output`
5015          encodes a probability distribution.
5016
5017  Returns:
5018      A tensor.
5019  """
5020  target = ops.convert_to_tensor_v2_with_dispatch(target)
5021  output = ops.convert_to_tensor_v2_with_dispatch(output)
5022
5023  # Use logits whenever they are available. `softmax` and `sigmoid`
5024  # activations cache logits on the `output` Tensor.
5025  if hasattr(output, '_keras_logits'):
5026    output = output._keras_logits  # pylint: disable=protected-access
5027    if from_logits:
5028      warnings.warn(
5029          '"`binary_crossentropy` received `from_logits=True`, but the `output`'
5030          ' argument was produced by a sigmoid or softmax activation and thus '
5031          'does not represent logits. Was this intended?"')
5032    from_logits = True
5033
5034  if from_logits:
5035    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
5036
5037  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
5038      output.op.type == 'Sigmoid') and not hasattr(output, '_keras_history'):
5039    # When sigmoid activation function is used for output operation, we
5040    # use logits from the sigmoid function directly to compute loss in order
5041    # to prevent collapsing zero when training.
5042    assert len(output.op.inputs) == 1
5043    output = output.op.inputs[0]
5044    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
5045
5046  epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
5047  output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
5048
5049  # Compute cross entropy from probabilities.
5050  bce = target * math_ops.log(output + epsilon())
5051  bce += (1 - target) * math_ops.log(1 - output + epsilon())
5052  return -bce
5053
5054
5055@keras_export('keras.backend.sigmoid')
5056@dispatch.add_dispatch_support
5057@doc_controls.do_not_generate_docs
5058def sigmoid(x):
5059  """Element-wise sigmoid.
5060
5061  Args:
5062      x: A tensor or variable.
5063
5064  Returns:
5065      A tensor.
5066  """
5067  return nn.sigmoid(x)
5068
5069
5070@keras_export('keras.backend.hard_sigmoid')
5071@dispatch.add_dispatch_support
5072@doc_controls.do_not_generate_docs
5073def hard_sigmoid(x):
5074  """Segment-wise linear approximation of sigmoid.
5075
5076  Faster than sigmoid.
5077  Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
5078  In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
5079
5080  Args:
5081      x: A tensor or variable.
5082
5083  Returns:
5084      A tensor.
5085  """
5086  point_two = _constant_to_tensor(0.2, x.dtype.base_dtype)
5087  point_five = _constant_to_tensor(0.5, x.dtype.base_dtype)
5088  x = math_ops.multiply(x, point_two)
5089  x = math_ops.add(x, point_five)
5090  x = clip_ops.clip_by_value(x, 0., 1.)
5091  return x
5092
5093
5094@keras_export('keras.backend.tanh')
5095@dispatch.add_dispatch_support
5096@doc_controls.do_not_generate_docs
5097def tanh(x):
5098  """Element-wise tanh.
5099
5100  Args:
5101      x: A tensor or variable.
5102
5103  Returns:
5104      A tensor.
5105  """
5106  return nn.tanh(x)
5107
5108
5109@keras_export('keras.backend.dropout')
5110@dispatch.add_dispatch_support
5111@doc_controls.do_not_generate_docs
5112def dropout(x, level, noise_shape=None, seed=None):
5113  """Sets entries in `x` to zero at random, while scaling the entire tensor.
5114
5115  Args:
5116      x: tensor
5117      level: fraction of the entries in the tensor
5118          that will be set to 0.
5119      noise_shape: shape for randomly generated keep/drop flags,
5120          must be broadcastable to the shape of `x`
5121      seed: random seed to ensure determinism.
5122
5123  Returns:
5124      A tensor.
5125  """
5126  if seed is None:
5127    seed = np.random.randint(10e6)
5128  return nn.dropout_v2(x, rate=level, noise_shape=noise_shape, seed=seed)
5129
5130
5131@keras_export('keras.backend.l2_normalize')
5132@dispatch.add_dispatch_support
5133@doc_controls.do_not_generate_docs
5134def l2_normalize(x, axis=None):
5135  """Normalizes a tensor wrt the L2 norm alongside the specified axis.
5136
5137  Args:
5138      x: Tensor or variable.
5139      axis: axis along which to perform normalization.
5140
5141  Returns:
5142      A tensor.
5143  """
5144  return nn.l2_normalize(x, axis=axis)
5145
5146
5147@keras_export('keras.backend.in_top_k')
5148@dispatch.add_dispatch_support
5149@doc_controls.do_not_generate_docs
5150def in_top_k(predictions, targets, k):
5151  """Returns whether the `targets` are in the top `k` `predictions`.
5152
5153  Args:
5154      predictions: A tensor of shape `(batch_size, classes)` and type `float32`.
5155      targets: A 1D tensor of length `batch_size` and type `int32` or `int64`.
5156      k: An `int`, number of top elements to consider.
5157
5158  Returns:
5159      A 1D tensor of length `batch_size` and type `bool`.
5160      `output[i]` is `True` if `predictions[i, targets[i]]` is within top-`k`
5161      values of `predictions[i]`.
5162  """
5163  return nn.in_top_k(predictions, targets, k)
5164
5165
5166# CONVOLUTIONS
5167
5168
5169def _preprocess_conv1d_input(x, data_format):
5170  """Transpose and cast the input before the conv1d.
5171
5172  Args:
5173      x: input tensor.
5174      data_format: string, `"channels_last"` or `"channels_first"`.
5175
5176  Returns:
5177      A tensor.
5178  """
5179  tf_data_format = 'NWC'  # to pass TF Conv2dNative operations
5180  if data_format == 'channels_first':
5181    if not _has_nchw_support():
5182      x = array_ops.transpose(x, (0, 2, 1))  # NCW -> NWC
5183    else:
5184      tf_data_format = 'NCW'
5185  return x, tf_data_format
5186
5187
5188def _preprocess_conv2d_input(x, data_format, force_transpose=False):
5189  """Transpose and cast the input before the conv2d.
5190
5191  Args:
5192      x: input tensor.
5193      data_format: string, `"channels_last"` or `"channels_first"`.
5194      force_transpose: Boolean. If True, the input will always be transposed
5195          from NCHW to NHWC if `data_format` is `"channels_first"`.
5196          If False, the transposition only occurs on CPU (GPU ops are
5197          assumed to support NCHW).
5198
5199  Returns:
5200      A tensor.
5201  """
5202  tf_data_format = 'NHWC'
5203  if data_format == 'channels_first':
5204    if not _has_nchw_support() or force_transpose:
5205      x = array_ops.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC
5206    else:
5207      tf_data_format = 'NCHW'
5208  return x, tf_data_format
5209
5210
5211def _preprocess_conv3d_input(x, data_format):
5212  """Transpose and cast the input before the conv3d.
5213
5214  Args:
5215      x: input tensor.
5216      data_format: string, `"channels_last"` or `"channels_first"`.
5217
5218  Returns:
5219      A tensor.
5220  """
5221  tf_data_format = 'NDHWC'
5222  if data_format == 'channels_first':
5223    if not _has_nchw_support():
5224      x = array_ops.transpose(x, (0, 2, 3, 4, 1))
5225    else:
5226      tf_data_format = 'NCDHW'
5227  return x, tf_data_format
5228
5229
5230def _preprocess_padding(padding):
5231  """Convert keras' padding to TensorFlow's padding.
5232
5233  Args:
5234      padding: string, one of 'same' , 'valid'
5235
5236  Returns:
5237      a string, one of 'SAME', 'VALID'.
5238
5239  Raises:
5240      ValueError: if invalid `padding'`
5241  """
5242  if padding == 'same':
5243    padding = 'SAME'
5244  elif padding == 'valid':
5245    padding = 'VALID'
5246  else:
5247    raise ValueError('Invalid padding: ' + str(padding))
5248  return padding
5249
5250
5251@keras_export('keras.backend.conv1d')
5252@dispatch.add_dispatch_support
5253@doc_controls.do_not_generate_docs
5254def conv1d(x,
5255           kernel,
5256           strides=1,
5257           padding='valid',
5258           data_format=None,
5259           dilation_rate=1):
5260  """1D convolution.
5261
5262  Args:
5263      x: Tensor or variable.
5264      kernel: kernel tensor.
5265      strides: stride integer.
5266      padding: string, `"same"`, `"causal"` or `"valid"`.
5267      data_format: string, one of "channels_last", "channels_first".
5268      dilation_rate: integer dilate rate.
5269
5270  Returns:
5271      A tensor, result of 1D convolution.
5272
5273  Raises:
5274      ValueError: if `data_format` is neither `channels_last` or
5275      `channels_first`.
5276  """
5277  if data_format is None:
5278    data_format = image_data_format()
5279  if data_format not in {'channels_first', 'channels_last'}:
5280    raise ValueError('Unknown data_format: ' + str(data_format))
5281
5282  kernel_shape = kernel.shape.as_list()
5283  if padding == 'causal':
5284    # causal (dilated) convolution:
5285    left_pad = dilation_rate * (kernel_shape[0] - 1)
5286    x = temporal_padding(x, (left_pad, 0))
5287    padding = 'valid'
5288  padding = _preprocess_padding(padding)
5289
5290  x, tf_data_format = _preprocess_conv1d_input(x, data_format)
5291  x = nn.convolution(
5292      input=x,
5293      filter=kernel,
5294      dilation_rate=dilation_rate,
5295      strides=strides,
5296      padding=padding,
5297      data_format=tf_data_format)
5298  if data_format == 'channels_first' and tf_data_format == 'NWC':
5299    x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
5300  return x
5301
5302
5303@keras_export('keras.backend.conv2d')
5304@dispatch.add_dispatch_support
5305@doc_controls.do_not_generate_docs
5306def conv2d(x,
5307           kernel,
5308           strides=(1, 1),
5309           padding='valid',
5310           data_format=None,
5311           dilation_rate=(1, 1)):
5312  """2D convolution.
5313
5314  Args:
5315      x: Tensor or variable.
5316      kernel: kernel tensor.
5317      strides: strides tuple.
5318      padding: string, `"same"` or `"valid"`.
5319      data_format: `"channels_last"` or `"channels_first"`.
5320      dilation_rate: tuple of 2 integers.
5321
5322  Returns:
5323      A tensor, result of 2D convolution.
5324
5325  Raises:
5326      ValueError: if `data_format` is neither `channels_last` or
5327      `channels_first`.
5328  """
5329  if data_format is None:
5330    data_format = image_data_format()
5331  if data_format not in {'channels_first', 'channels_last'}:
5332    raise ValueError('Unknown data_format: ' + str(data_format))
5333
5334  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5335  padding = _preprocess_padding(padding)
5336  x = nn.convolution(
5337      input=x,
5338      filter=kernel,
5339      dilation_rate=dilation_rate,
5340      strides=strides,
5341      padding=padding,
5342      data_format=tf_data_format)
5343  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5344    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5345  return x
5346
5347
5348@keras_export('keras.backend.conv2d_transpose')
5349@dispatch.add_dispatch_support
5350@doc_controls.do_not_generate_docs
5351def conv2d_transpose(x,
5352                     kernel,
5353                     output_shape,
5354                     strides=(1, 1),
5355                     padding='valid',
5356                     data_format=None,
5357                     dilation_rate=(1, 1)):
5358  """2D deconvolution (i.e.
5359
5360  transposed convolution).
5361
5362  Args:
5363      x: Tensor or variable.
5364      kernel: kernel tensor.
5365      output_shape: 1D int tensor for the output shape.
5366      strides: strides tuple.
5367      padding: string, `"same"` or `"valid"`.
5368      data_format: string, `"channels_last"` or `"channels_first"`.
5369      dilation_rate: Tuple of 2 integers.
5370
5371  Returns:
5372      A tensor, result of transposed 2D convolution.
5373
5374  Raises:
5375      ValueError: if `data_format` is neither `channels_last` or
5376      `channels_first`.
5377  """
5378  if data_format is None:
5379    data_format = image_data_format()
5380  if data_format not in {'channels_first', 'channels_last'}:
5381    raise ValueError('Unknown data_format: ' + str(data_format))
5382
5383  # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
5384  if data_format == 'channels_first' and dilation_rate != (1, 1):
5385    force_transpose = True
5386  else:
5387    force_transpose = False
5388
5389  x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
5390
5391  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5392    output_shape = (output_shape[0], output_shape[2], output_shape[3],
5393                    output_shape[1])
5394  if output_shape[0] is None:
5395    output_shape = (shape(x)[0],) + tuple(output_shape[1:])
5396
5397  if isinstance(output_shape, (tuple, list)):
5398    output_shape = array_ops.stack(list(output_shape))
5399
5400  padding = _preprocess_padding(padding)
5401  if tf_data_format == 'NHWC':
5402    strides = (1,) + strides + (1,)
5403  else:
5404    strides = (1, 1) + strides
5405
5406  if dilation_rate == (1, 1):
5407    x = nn.conv2d_transpose(x, kernel, output_shape, strides,
5408                            padding=padding,
5409                            data_format=tf_data_format)
5410  else:
5411    assert dilation_rate[0] == dilation_rate[1]
5412    x = nn.atrous_conv2d_transpose(
5413        x,
5414        kernel,
5415        output_shape,
5416        rate=dilation_rate[0],
5417        padding=padding)
5418  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5419    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5420  return x
5421
5422
5423def separable_conv1d(x,
5424                     depthwise_kernel,
5425                     pointwise_kernel,
5426                     strides=1,
5427                     padding='valid',
5428                     data_format=None,
5429                     dilation_rate=1):
5430  """1D convolution with separable filters.
5431
5432  Args:
5433      x: input tensor
5434      depthwise_kernel: convolution kernel for the depthwise convolution.
5435      pointwise_kernel: kernel for the 1x1 convolution.
5436      strides: stride integer.
5437      padding: string, `"same"` or `"valid"`.
5438      data_format: string, `"channels_last"` or `"channels_first"`.
5439      dilation_rate: integer dilation rate.
5440
5441  Returns:
5442      Output tensor.
5443
5444  Raises:
5445      ValueError: if `data_format` is neither `channels_last` or
5446      `channels_first`.
5447  """
5448  if data_format is None:
5449    data_format = image_data_format()
5450  if data_format not in {'channels_first', 'channels_last'}:
5451    raise ValueError('Unknown data_format: ' + str(data_format))
5452
5453  if isinstance(strides, int):
5454    strides = (strides,)
5455  if isinstance(dilation_rate, int):
5456    dilation_rate = (dilation_rate,)
5457
5458  x, tf_data_format = _preprocess_conv1d_input(x, data_format)
5459  padding = _preprocess_padding(padding)
5460  if not isinstance(strides, tuple):
5461    strides = tuple(strides)
5462  if tf_data_format == 'NWC':
5463    spatial_start_dim = 1
5464    strides = (1,) + strides * 2 + (1,)
5465  else:
5466    spatial_start_dim = 2
5467    strides = (1, 1) + strides * 2
5468  x = array_ops.expand_dims(x, spatial_start_dim)
5469  depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0)
5470  pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0)
5471  dilation_rate = (1,) + dilation_rate
5472
5473  x = nn.separable_conv2d(
5474      x,
5475      depthwise_kernel,
5476      pointwise_kernel,
5477      strides=strides,
5478      padding=padding,
5479      rate=dilation_rate,
5480      data_format=tf_data_format)
5481
5482  x = array_ops.squeeze(x, [spatial_start_dim])
5483
5484  if data_format == 'channels_first' and tf_data_format == 'NWC':
5485    x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
5486
5487  return x
5488
5489
5490@keras_export('keras.backend.separable_conv2d')
5491@dispatch.add_dispatch_support
5492@doc_controls.do_not_generate_docs
5493def separable_conv2d(x,
5494                     depthwise_kernel,
5495                     pointwise_kernel,
5496                     strides=(1, 1),
5497                     padding='valid',
5498                     data_format=None,
5499                     dilation_rate=(1, 1)):
5500  """2D convolution with separable filters.
5501
5502  Args:
5503      x: input tensor
5504      depthwise_kernel: convolution kernel for the depthwise convolution.
5505      pointwise_kernel: kernel for the 1x1 convolution.
5506      strides: strides tuple (length 2).
5507      padding: string, `"same"` or `"valid"`.
5508      data_format: string, `"channels_last"` or `"channels_first"`.
5509      dilation_rate: tuple of integers,
5510          dilation rates for the separable convolution.
5511
5512  Returns:
5513      Output tensor.
5514
5515  Raises:
5516      ValueError: if `data_format` is neither `channels_last` or
5517      `channels_first`.
5518      ValueError: if `strides` is not a tuple of 2 integers.
5519  """
5520  if data_format is None:
5521    data_format = image_data_format()
5522  if data_format not in {'channels_first', 'channels_last'}:
5523    raise ValueError('Unknown data_format: ' + str(data_format))
5524  if len(strides) != 2:
5525    raise ValueError('`strides` must be a tuple of 2 integers.')
5526
5527  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5528  padding = _preprocess_padding(padding)
5529  if not isinstance(strides, tuple):
5530    strides = tuple(strides)
5531  if tf_data_format == 'NHWC':
5532    strides = (1,) + strides + (1,)
5533  else:
5534    strides = (1, 1) + strides
5535
5536  x = nn.separable_conv2d(
5537      x,
5538      depthwise_kernel,
5539      pointwise_kernel,
5540      strides=strides,
5541      padding=padding,
5542      rate=dilation_rate,
5543      data_format=tf_data_format)
5544  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5545    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5546  return x
5547
5548
5549@keras_export('keras.backend.depthwise_conv2d')
5550@dispatch.add_dispatch_support
5551@doc_controls.do_not_generate_docs
5552def depthwise_conv2d(x,
5553                     depthwise_kernel,
5554                     strides=(1, 1),
5555                     padding='valid',
5556                     data_format=None,
5557                     dilation_rate=(1, 1)):
5558  """2D convolution with separable filters.
5559
5560  Args:
5561      x: input tensor
5562      depthwise_kernel: convolution kernel for the depthwise convolution.
5563      strides: strides tuple (length 2).
5564      padding: string, `"same"` or `"valid"`.
5565      data_format: string, `"channels_last"` or `"channels_first"`.
5566      dilation_rate: tuple of integers,
5567          dilation rates for the separable convolution.
5568
5569  Returns:
5570      Output tensor.
5571
5572  Raises:
5573      ValueError: if `data_format` is neither `channels_last` or
5574      `channels_first`.
5575  """
5576  if data_format is None:
5577    data_format = image_data_format()
5578  if data_format not in {'channels_first', 'channels_last'}:
5579    raise ValueError('Unknown data_format: ' + str(data_format))
5580
5581  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5582  padding = _preprocess_padding(padding)
5583  if tf_data_format == 'NHWC':
5584    strides = (1,) + strides + (1,)
5585  else:
5586    strides = (1, 1) + strides
5587
5588  x = nn.depthwise_conv2d(
5589      x,
5590      depthwise_kernel,
5591      strides=strides,
5592      padding=padding,
5593      rate=dilation_rate,
5594      data_format=tf_data_format)
5595  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5596    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5597  return x
5598
5599
5600@keras_export('keras.backend.conv3d')
5601@dispatch.add_dispatch_support
5602@doc_controls.do_not_generate_docs
5603def conv3d(x,
5604           kernel,
5605           strides=(1, 1, 1),
5606           padding='valid',
5607           data_format=None,
5608           dilation_rate=(1, 1, 1)):
5609  """3D convolution.
5610
5611  Args:
5612      x: Tensor or variable.
5613      kernel: kernel tensor.
5614      strides: strides tuple.
5615      padding: string, `"same"` or `"valid"`.
5616      data_format: string, `"channels_last"` or `"channels_first"`.
5617      dilation_rate: tuple of 3 integers.
5618
5619  Returns:
5620      A tensor, result of 3D convolution.
5621
5622  Raises:
5623      ValueError: if `data_format` is neither `channels_last` or
5624      `channels_first`.
5625  """
5626  if data_format is None:
5627    data_format = image_data_format()
5628  if data_format not in {'channels_first', 'channels_last'}:
5629    raise ValueError('Unknown data_format: ' + str(data_format))
5630
5631  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5632  padding = _preprocess_padding(padding)
5633  x = nn.convolution(
5634      input=x,
5635      filter=kernel,
5636      dilation_rate=dilation_rate,
5637      strides=strides,
5638      padding=padding,
5639      data_format=tf_data_format)
5640  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5641    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5642  return x
5643
5644
5645def conv3d_transpose(x,
5646                     kernel,
5647                     output_shape,
5648                     strides=(1, 1, 1),
5649                     padding='valid',
5650                     data_format=None):
5651  """3D deconvolution (i.e.
5652
5653  transposed convolution).
5654
5655  Args:
5656      x: input tensor.
5657      kernel: kernel tensor.
5658      output_shape: 1D int tensor for the output shape.
5659      strides: strides tuple.
5660      padding: string, "same" or "valid".
5661      data_format: string, `"channels_last"` or `"channels_first"`.
5662
5663  Returns:
5664      A tensor, result of transposed 3D convolution.
5665
5666  Raises:
5667      ValueError: if `data_format` is neither `channels_last` or
5668      `channels_first`.
5669  """
5670  if data_format is None:
5671    data_format = image_data_format()
5672  if data_format not in {'channels_first', 'channels_last'}:
5673    raise ValueError('Unknown data_format: ' + str(data_format))
5674  if isinstance(output_shape, (tuple, list)):
5675    output_shape = array_ops.stack(output_shape)
5676
5677  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5678
5679  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5680    output_shape = (output_shape[0], output_shape[2], output_shape[3],
5681                    output_shape[4], output_shape[1])
5682  if output_shape[0] is None:
5683    output_shape = (array_ops.shape(x)[0],) + tuple(output_shape[1:])
5684    output_shape = array_ops.stack(list(output_shape))
5685
5686  padding = _preprocess_padding(padding)
5687  if tf_data_format == 'NDHWC':
5688    strides = (1,) + strides + (1,)
5689  else:
5690    strides = (1, 1) + strides
5691
5692  x = nn.conv3d_transpose(
5693      x,
5694      kernel,
5695      output_shape,
5696      strides,
5697      padding=padding,
5698      data_format=tf_data_format)
5699  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5700    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5701  return x
5702
5703
5704@keras_export('keras.backend.pool2d')
5705@dispatch.add_dispatch_support
5706@doc_controls.do_not_generate_docs
5707def pool2d(x,
5708           pool_size,
5709           strides=(1, 1),
5710           padding='valid',
5711           data_format=None,
5712           pool_mode='max'):
5713  """2D Pooling.
5714
5715  Args:
5716      x: Tensor or variable.
5717      pool_size: tuple of 2 integers.
5718      strides: tuple of 2 integers.
5719      padding: string, `"same"` or `"valid"`.
5720      data_format: string, `"channels_last"` or `"channels_first"`.
5721      pool_mode: string, `"max"` or `"avg"`.
5722
5723  Returns:
5724      A tensor, result of 2D pooling.
5725
5726  Raises:
5727      ValueError: if `data_format` is neither `"channels_last"` or
5728      `"channels_first"`.
5729      ValueError: if `pool_size` is not a tuple of 2 integers.
5730      ValueError: if `strides` is not a tuple of 2 integers.
5731      ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
5732  """
5733  if data_format is None:
5734    data_format = image_data_format()
5735  if data_format not in {'channels_first', 'channels_last'}:
5736    raise ValueError('Unknown data_format: ' + str(data_format))
5737  if len(pool_size) != 2:
5738    raise ValueError('`pool_size` must be a tuple of 2 integers.')
5739  if len(strides) != 2:
5740    raise ValueError('`strides` must be a tuple of 2 integers.')
5741
5742  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5743  padding = _preprocess_padding(padding)
5744  if tf_data_format == 'NHWC':
5745    strides = (1,) + strides + (1,)
5746    pool_size = (1,) + pool_size + (1,)
5747  else:
5748    strides = (1, 1) + strides
5749    pool_size = (1, 1) + pool_size
5750
5751  if pool_mode == 'max':
5752    x = nn.max_pool(
5753        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5754  elif pool_mode == 'avg':
5755    x = nn.avg_pool(
5756        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5757  else:
5758    raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5759
5760  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5761    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5762  return x
5763
5764
5765@keras_export('keras.backend.pool3d')
5766@dispatch.add_dispatch_support
5767@doc_controls.do_not_generate_docs
5768def pool3d(x,
5769           pool_size,
5770           strides=(1, 1, 1),
5771           padding='valid',
5772           data_format=None,
5773           pool_mode='max'):
5774  """3D Pooling.
5775
5776  Args:
5777      x: Tensor or variable.
5778      pool_size: tuple of 3 integers.
5779      strides: tuple of 3 integers.
5780      padding: string, `"same"` or `"valid"`.
5781      data_format: string, `"channels_last"` or `"channels_first"`.
5782      pool_mode: string, `"max"` or `"avg"`.
5783
5784  Returns:
5785      A tensor, result of 3D pooling.
5786
5787  Raises:
5788      ValueError: if `data_format` is neither `"channels_last"` or
5789      `"channels_first"`.
5790      ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
5791  """
5792  if data_format is None:
5793    data_format = image_data_format()
5794  if data_format not in {'channels_first', 'channels_last'}:
5795    raise ValueError('Unknown data_format: ' + str(data_format))
5796
5797  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5798  padding = _preprocess_padding(padding)
5799  if tf_data_format == 'NDHWC':
5800    strides = (1,) + strides + (1,)
5801    pool_size = (1,) + pool_size + (1,)
5802  else:
5803    strides = (1, 1) + strides
5804    pool_size = (1, 1) + pool_size
5805
5806  if pool_mode == 'max':
5807    x = nn.max_pool3d(
5808        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5809  elif pool_mode == 'avg':
5810    x = nn.avg_pool3d(
5811        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5812  else:
5813    raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5814
5815  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5816    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5817  return x
5818
5819
5820def local_conv(inputs,
5821               kernel,
5822               kernel_size,
5823               strides,
5824               output_shape,
5825               data_format=None):
5826  """Apply N-D convolution with un-shared weights.
5827
5828  Args:
5829      inputs: (N+2)-D tensor with shape
5830          (batch_size, channels_in, d_in1, ..., d_inN)
5831          if data_format='channels_first', or
5832          (batch_size, d_in1, ..., d_inN, channels_in)
5833          if data_format='channels_last'.
5834      kernel: the unshared weight for N-D convolution,
5835          with shape (output_items, feature_dim, channels_out), where
5836          feature_dim = np.prod(kernel_size) * channels_in,
5837          output_items = np.prod(output_shape).
5838      kernel_size: a tuple of N integers, specifying the
5839          spatial dimensions of the N-D convolution window.
5840      strides: a tuple of N integers, specifying the strides
5841          of the convolution along the spatial dimensions.
5842      output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
5843          dimensionality of the output.
5844      data_format: string, "channels_first" or "channels_last".
5845
5846  Returns:
5847      An (N+2)-D tensor with shape:
5848      (batch_size, channels_out) + output_shape
5849      if data_format='channels_first', or:
5850      (batch_size,) + output_shape + (channels_out,)
5851      if data_format='channels_last'.
5852
5853  Raises:
5854      ValueError: if `data_format` is neither
5855      `channels_last` nor `channels_first`.
5856  """
5857  if data_format is None:
5858    data_format = image_data_format()
5859  if data_format not in {'channels_first', 'channels_last'}:
5860    raise ValueError('Unknown data_format: ' + str(data_format))
5861
5862  kernel_shape = int_shape(kernel)
5863  feature_dim = kernel_shape[1]
5864  channels_out = kernel_shape[-1]
5865  ndims = len(output_shape)
5866  spatial_dimensions = list(range(ndims))
5867
5868  xs = []
5869  output_axes_ticks = [range(axis_max) for axis_max in output_shape]
5870  for position in itertools.product(*output_axes_ticks):
5871    slices = [slice(None)]
5872
5873    if data_format == 'channels_first':
5874      slices.append(slice(None))
5875
5876    slices.extend(
5877        slice(position[d] * strides[d], position[d] * strides[d] +
5878              kernel_size[d]) for d in spatial_dimensions)
5879
5880    if data_format == 'channels_last':
5881      slices.append(slice(None))
5882
5883    xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
5884
5885  x_aggregate = concatenate(xs, axis=0)
5886  output = batch_dot(x_aggregate, kernel)
5887  output = reshape(output, output_shape + (-1, channels_out))
5888
5889  if data_format == 'channels_first':
5890    permutation = [ndims, ndims + 1] + spatial_dimensions
5891  else:
5892    permutation = [ndims] + spatial_dimensions + [ndims + 1]
5893
5894  return permute_dimensions(output, permutation)
5895
5896
5897@keras_export('keras.backend.local_conv1d')
5898@dispatch.add_dispatch_support
5899@doc_controls.do_not_generate_docs
5900def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
5901  """Apply 1D conv with un-shared weights.
5902
5903  Args:
5904      inputs: 3D tensor with shape:
5905          (batch_size, steps, input_dim)
5906          if data_format is "channels_last" or
5907          (batch_size, input_dim, steps)
5908          if data_format is "channels_first".
5909      kernel: the unshared weight for convolution,
5910          with shape (output_length, feature_dim, filters).
5911      kernel_size: a tuple of a single integer,
5912          specifying the length of the 1D convolution window.
5913      strides: a tuple of a single integer,
5914          specifying the stride length of the convolution.
5915      data_format: the data format, channels_first or channels_last.
5916
5917  Returns:
5918      A 3d tensor with shape:
5919      (batch_size, output_length, filters)
5920      if data_format='channels_first'
5921      or 3D tensor with shape:
5922      (batch_size, filters, output_length)
5923      if data_format='channels_last'.
5924  """
5925  output_shape = (kernel.shape[0],)
5926  return local_conv(inputs,
5927                    kernel,
5928                    kernel_size,
5929                    strides,
5930                    output_shape,
5931                    data_format)
5932
5933
5934@keras_export('keras.backend.local_conv2d')
5935@dispatch.add_dispatch_support
5936@doc_controls.do_not_generate_docs
5937def local_conv2d(inputs,
5938                 kernel,
5939                 kernel_size,
5940                 strides,
5941                 output_shape,
5942                 data_format=None):
5943  """Apply 2D conv with un-shared weights.
5944
5945  Args:
5946      inputs: 4D tensor with shape:
5947          (batch_size, filters, new_rows, new_cols)
5948          if data_format='channels_first'
5949          or 4D tensor with shape:
5950          (batch_size, new_rows, new_cols, filters)
5951          if data_format='channels_last'.
5952      kernel: the unshared weight for convolution,
5953          with shape (output_items, feature_dim, filters).
5954      kernel_size: a tuple of 2 integers, specifying the
5955          width and height of the 2D convolution window.
5956      strides: a tuple of 2 integers, specifying the strides
5957          of the convolution along the width and height.
5958      output_shape: a tuple with (output_row, output_col).
5959      data_format: the data format, channels_first or channels_last.
5960
5961  Returns:
5962      A 4D tensor with shape:
5963      (batch_size, filters, new_rows, new_cols)
5964      if data_format='channels_first'
5965      or 4D tensor with shape:
5966      (batch_size, new_rows, new_cols, filters)
5967      if data_format='channels_last'.
5968  """
5969  return local_conv(inputs,
5970                    kernel,
5971                    kernel_size,
5972                    strides,
5973                    output_shape,
5974                    data_format)
5975
5976
5977@keras_export('keras.backend.bias_add')
5978@dispatch.add_dispatch_support
5979@doc_controls.do_not_generate_docs
5980def bias_add(x, bias, data_format=None):
5981  """Adds a bias vector to a tensor.
5982
5983  Args:
5984      x: Tensor or variable.
5985      bias: Bias tensor to add.
5986      data_format: string, `"channels_last"` or `"channels_first"`.
5987
5988  Returns:
5989      Output tensor.
5990
5991  Raises:
5992      ValueError: In one of the two cases below:
5993                  1. invalid `data_format` argument.
5994                  2. invalid bias shape.
5995                     the bias should be either a vector or
5996                     a tensor with ndim(x) - 1 dimension
5997  """
5998  if data_format is None:
5999    data_format = image_data_format()
6000  if data_format not in {'channels_first', 'channels_last'}:
6001    raise ValueError('Unknown data_format: ' + str(data_format))
6002  bias_shape = int_shape(bias)
6003  if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
6004    raise ValueError(
6005        'Unexpected bias dimensions %d, expect to be 1 or %d dimensions' %
6006        (len(bias_shape), ndim(x) - 1))
6007
6008  if len(bias_shape) == 1:
6009    if data_format == 'channels_first':
6010      return nn.bias_add(x, bias, data_format='NCHW')
6011    return nn.bias_add(x, bias, data_format='NHWC')
6012  if ndim(x) in (3, 4, 5):
6013    if data_format == 'channels_first':
6014      bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]
6015      return x + reshape(bias, bias_reshape_axis)
6016    return x + reshape(bias, (1,) + bias_shape)
6017  return nn.bias_add(x, bias)
6018
6019
6020# RANDOMNESS
6021
6022
6023@keras_export('keras.backend.random_normal')
6024@dispatch.add_dispatch_support
6025@doc_controls.do_not_generate_docs
6026def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
6027  """Returns a tensor with normal distribution of values.
6028
6029  It is an alias to `tf.random.normal`.
6030
6031  Args:
6032      shape: A tuple of integers, the shape of tensor to create.
6033      mean: A float, the mean value of the normal distribution to draw samples.
6034        Default to 0.0.
6035      stddev: A float, the standard deviation of the normal distribution
6036        to draw samples. Default to 1.0.
6037      dtype: `tf.dtypes.DType`, dtype of returned tensor. Default to use Keras
6038        backend dtype which is float32.
6039      seed: Integer, random seed. Will use a random numpy integer when not
6040        specified.
6041
6042  Returns:
6043      A tensor with normal distribution of values.
6044
6045  Example:
6046
6047  >>> random_normal_tensor = tf.keras.backend.random_normal(shape=(2,3),
6048  ... mean=0.0, stddev=1.0)
6049  >>> random_normal_tensor
6050  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6051  dtype=float32)>
6052  """
6053  if dtype is None:
6054    dtype = floatx()
6055  if seed is None:
6056    seed = np.random.randint(10e6)
6057  return random_ops.random_normal(
6058      shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
6059
6060
6061@keras_export('keras.backend.random_uniform')
6062@dispatch.add_dispatch_support
6063@doc_controls.do_not_generate_docs
6064def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
6065  """Returns a tensor with uniform distribution of values.
6066
6067  Args:
6068      shape: A tuple of integers, the shape of tensor to create.
6069      minval: A float, lower boundary of the uniform distribution
6070          to draw samples.
6071      maxval: A float, upper boundary of the uniform distribution
6072          to draw samples.
6073      dtype: String, dtype of returned tensor.
6074      seed: Integer, random seed.
6075
6076  Returns:
6077      A tensor.
6078
6079  Example:
6080
6081  >>> random_uniform_tensor = tf.keras.backend.random_uniform(shape=(2,3),
6082  ... minval=0.0, maxval=1.0)
6083  >>> random_uniform_tensor
6084  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6085  dtype=float32)>
6086  """
6087  if dtype is None:
6088    dtype = floatx()
6089  if seed is None:
6090    seed = np.random.randint(10e6)
6091  return random_ops.random_uniform(
6092      shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)
6093
6094
6095@keras_export('keras.backend.random_binomial')
6096@dispatch.add_dispatch_support
6097@doc_controls.do_not_generate_docs
6098def random_binomial(shape, p=0.0, dtype=None, seed=None):
6099  """Returns a tensor with random binomial distribution of values.
6100
6101  DEPRECATED, use `tf.keras.backend.random_bernoulli` instead.
6102
6103  The binomial distribution with parameters `n` and `p` is the probability
6104  distribution of the number of successful Bernoulli process. Only supports
6105  `n` = 1 for now.
6106
6107  Args:
6108      shape: A tuple of integers, the shape of tensor to create.
6109      p: A float, `0. <= p <= 1`, probability of binomial distribution.
6110      dtype: String, dtype of returned tensor.
6111      seed: Integer, random seed.
6112
6113  Returns:
6114      A tensor.
6115
6116  Example:
6117
6118  >>> random_binomial_tensor = tf.keras.backend.random_binomial(shape=(2,3),
6119  ... p=0.5)
6120  >>> random_binomial_tensor
6121  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6122  dtype=float32)>
6123  """
6124  warnings.warn('`tf.keras.backend.random_binomial` is deprecated, '
6125                'and will be removed in a future version.'
6126                'Please use `tf.keras.backend.random_bernoulli` instead.')
6127  return random_bernoulli(shape, p, dtype, seed)
6128
6129
6130@keras_export('keras.backend.random_bernoulli')
6131@dispatch.add_dispatch_support
6132@doc_controls.do_not_generate_docs
6133def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
6134  """Returns a tensor with random bernoulli distribution of values.
6135
6136  Args:
6137      shape: A tuple of integers, the shape of tensor to create.
6138      p: A float, `0. <= p <= 1`, probability of bernoulli distribution.
6139      dtype: String, dtype of returned tensor.
6140      seed: Integer, random seed.
6141
6142  Returns:
6143      A tensor.
6144  """
6145  if dtype is None:
6146    dtype = floatx()
6147  if seed is None:
6148    seed = np.random.randint(10e6)
6149  return array_ops.where_v2(
6150      random_ops.random_uniform(shape, dtype=dtype, seed=seed) <= p,
6151      array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
6152
6153
6154@keras_export('keras.backend.truncated_normal')
6155@dispatch.add_dispatch_support
6156@doc_controls.do_not_generate_docs
6157def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
6158  """Returns a tensor with truncated random normal distribution of values.
6159
6160  The generated values follow a normal distribution
6161  with specified mean and standard deviation,
6162  except that values whose magnitude is more than
6163  two standard deviations from the mean are dropped and re-picked.
6164
6165  Args:
6166      shape: A tuple of integers, the shape of tensor to create.
6167      mean: Mean of the values.
6168      stddev: Standard deviation of the values.
6169      dtype: String, dtype of returned tensor.
6170      seed: Integer, random seed.
6171
6172  Returns:
6173      A tensor.
6174  """
6175  if dtype is None:
6176    dtype = floatx()
6177  if seed is None:
6178    seed = np.random.randint(10e6)
6179  return random_ops.truncated_normal(
6180      shape, mean, stddev, dtype=dtype, seed=seed)
6181
6182
6183# CTC
6184# TensorFlow has a native implementation, but it uses sparse tensors
6185# and therefore requires a wrapper for Keras. The functions below convert
6186# dense to sparse tensors and also wraps up the beam search code that is
6187# in TensorFlow's CTC implementation
6188
6189
6190@keras_export('keras.backend.ctc_label_dense_to_sparse')
6191@dispatch.add_dispatch_support
6192@doc_controls.do_not_generate_docs
6193def ctc_label_dense_to_sparse(labels, label_lengths):
6194  """Converts CTC labels from dense to sparse.
6195
6196  Args:
6197      labels: dense CTC labels.
6198      label_lengths: length of the labels.
6199
6200  Returns:
6201      A sparse tensor representation of the labels.
6202  """
6203  label_shape = array_ops.shape(labels)
6204  num_batches_tns = array_ops.stack([label_shape[0]])
6205  max_num_labels_tns = array_ops.stack([label_shape[1]])
6206
6207  def range_less_than(old_input, current_input):
6208    return array_ops.expand_dims(
6209        math_ops.range(array_ops.shape(old_input)[1]), 0) < array_ops.fill(
6210            max_num_labels_tns, current_input)
6211
6212  init = math_ops.cast(
6213      array_ops.fill([1, label_shape[1]], 0), dtypes_module.bool)
6214  dense_mask = functional_ops.scan(
6215      range_less_than, label_lengths, initializer=init, parallel_iterations=1)
6216  dense_mask = dense_mask[:, 0, :]
6217
6218  label_array = array_ops.reshape(
6219      array_ops.tile(math_ops.range(0, label_shape[1]), num_batches_tns),
6220      label_shape)
6221  label_ind = array_ops.boolean_mask(label_array, dense_mask)
6222
6223  batch_array = array_ops.transpose(
6224      array_ops.reshape(
6225          array_ops.tile(math_ops.range(0, label_shape[0]), max_num_labels_tns),
6226          reverse(label_shape, 0)))
6227  batch_ind = array_ops.boolean_mask(batch_array, dense_mask)
6228  indices = array_ops.transpose(
6229      array_ops.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1]))
6230
6231  vals_sparse = array_ops.gather_nd(labels, indices)
6232
6233  return sparse_tensor.SparseTensor(
6234      math_ops.cast(indices, dtypes_module.int64), vals_sparse,
6235      math_ops.cast(label_shape, dtypes_module.int64))
6236
6237
6238@keras_export('keras.backend.ctc_batch_cost')
6239@dispatch.add_dispatch_support
6240@doc_controls.do_not_generate_docs
6241def ctc_batch_cost(y_true, y_pred, input_length, label_length):
6242  """Runs CTC loss algorithm on each batch element.
6243
6244  Args:
6245      y_true: tensor `(samples, max_string_length)`
6246          containing the truth labels.
6247      y_pred: tensor `(samples, time_steps, num_categories)`
6248          containing the prediction, or output of the softmax.
6249      input_length: tensor `(samples, 1)` containing the sequence length for
6250          each batch item in `y_pred`.
6251      label_length: tensor `(samples, 1)` containing the sequence length for
6252          each batch item in `y_true`.
6253
6254  Returns:
6255      Tensor with shape (samples,1) containing the
6256          CTC loss of each element.
6257  """
6258  label_length = math_ops.cast(
6259      array_ops.squeeze(label_length, axis=-1), dtypes_module.int32)
6260  input_length = math_ops.cast(
6261      array_ops.squeeze(input_length, axis=-1), dtypes_module.int32)
6262  sparse_labels = math_ops.cast(
6263      ctc_label_dense_to_sparse(y_true, label_length), dtypes_module.int32)
6264
6265  y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
6266
6267  return array_ops.expand_dims(
6268      ctc.ctc_loss(
6269          inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1)
6270
6271
6272@keras_export('keras.backend.ctc_decode')
6273@dispatch.add_dispatch_support
6274@doc_controls.do_not_generate_docs
6275def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
6276  """Decodes the output of a softmax.
6277
6278  Can use either greedy search (also known as best path)
6279  or a constrained dictionary search.
6280
6281  Args:
6282      y_pred: tensor `(samples, time_steps, num_categories)`
6283          containing the prediction, or output of the softmax.
6284      input_length: tensor `(samples, )` containing the sequence length for
6285          each batch item in `y_pred`.
6286      greedy: perform much faster best-path search if `true`.
6287          This does not use a dictionary.
6288      beam_width: if `greedy` is `false`: a beam search decoder will be used
6289          with a beam of this width.
6290      top_paths: if `greedy` is `false`,
6291          how many of the most probable paths will be returned.
6292
6293  Returns:
6294      Tuple:
6295          List: if `greedy` is `true`, returns a list of one element that
6296              contains the decoded sequence.
6297              If `false`, returns the `top_paths` most probable
6298              decoded sequences.
6299              Each decoded sequence has shape (samples, time_steps).
6300              Important: blank labels are returned as `-1`.
6301          Tensor `(top_paths, )` that contains
6302              the log probability of each decoded sequence.
6303  """
6304  input_shape = shape(y_pred)
6305  num_samples, num_steps = input_shape[0], input_shape[1]
6306  y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
6307  input_length = math_ops.cast(input_length, dtypes_module.int32)
6308
6309  if greedy:
6310    (decoded, log_prob) = ctc.ctc_greedy_decoder(
6311        inputs=y_pred, sequence_length=input_length)
6312  else:
6313    (decoded, log_prob) = ctc.ctc_beam_search_decoder(
6314        inputs=y_pred,
6315        sequence_length=input_length,
6316        beam_width=beam_width,
6317        top_paths=top_paths)
6318  decoded_dense = []
6319  for st in decoded:
6320    st = sparse_tensor.SparseTensor(
6321        st.indices, st.values, (num_samples, num_steps))
6322    decoded_dense.append(
6323        sparse_ops.sparse_tensor_to_dense(sp_input=st, default_value=-1))
6324  return (decoded_dense, log_prob)
6325
6326
6327# HIGH ORDER FUNCTIONS
6328
6329
6330@keras_export('keras.backend.map_fn')
6331@doc_controls.do_not_generate_docs
6332def map_fn(fn, elems, name=None, dtype=None):
6333  """Map the function fn over the elements elems and return the outputs.
6334
6335  Args:
6336      fn: Callable that will be called upon each element in elems
6337      elems: tensor
6338      name: A string name for the map node in the graph
6339      dtype: Output data type.
6340
6341  Returns:
6342      Tensor with dtype `dtype`.
6343  """
6344  return map_fn_lib.map_fn(fn, elems, name=name, dtype=dtype)
6345
6346
6347@keras_export('keras.backend.foldl')
6348@doc_controls.do_not_generate_docs
6349def foldl(fn, elems, initializer=None, name=None):
6350  """Reduce elems using fn to combine them from left to right.
6351
6352  Args:
6353      fn: Callable that will be called upon each element in elems and an
6354          accumulator, for instance `lambda acc, x: acc + x`
6355      elems: tensor
6356      initializer: The first value used (`elems[0]` in case of None)
6357      name: A string name for the foldl node in the graph
6358
6359  Returns:
6360      Tensor with same type and shape as `initializer`.
6361  """
6362  return functional_ops.foldl(fn, elems, initializer=initializer, name=name)
6363
6364
6365@keras_export('keras.backend.foldr')
6366@doc_controls.do_not_generate_docs
6367def foldr(fn, elems, initializer=None, name=None):
6368  """Reduce elems using fn to combine them from right to left.
6369
6370  Args:
6371      fn: Callable that will be called upon each element in elems and an
6372          accumulator, for instance `lambda acc, x: acc + x`
6373      elems: tensor
6374      initializer: The first value used (`elems[-1]` in case of None)
6375      name: A string name for the foldr node in the graph
6376
6377  Returns:
6378      Same type and shape as initializer
6379  """
6380  return functional_ops.foldr(fn, elems, initializer=initializer, name=name)
6381
6382# Load Keras default configuration from config file if present.
6383# Set Keras base dir path given KERAS_HOME env variable, if applicable.
6384# Otherwise either ~/.keras or /tmp.
6385if 'KERAS_HOME' in os.environ:
6386  _keras_dir = os.environ.get('KERAS_HOME')
6387else:
6388  _keras_base_dir = os.path.expanduser('~')
6389  _keras_dir = os.path.join(_keras_base_dir, '.keras')
6390_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
6391if os.path.exists(_config_path):
6392  try:
6393    with open(_config_path) as fh:
6394      _config = json.load(fh)
6395  except ValueError:
6396    _config = {}
6397  _floatx = _config.get('floatx', floatx())
6398  assert _floatx in {'float16', 'float32', 'float64'}
6399  _epsilon = _config.get('epsilon', epsilon())
6400  assert isinstance(_epsilon, float)
6401  _image_data_format = _config.get('image_data_format', image_data_format())
6402  assert _image_data_format in {'channels_last', 'channels_first'}
6403  set_floatx(_floatx)
6404  set_epsilon(_epsilon)
6405  set_image_data_format(_image_data_format)
6406
6407# Save config file.
6408if not os.path.exists(_keras_dir):
6409  try:
6410    os.makedirs(_keras_dir)
6411  except OSError:
6412    # Except permission denied and potential race conditions
6413    # in multi-threaded environments.
6414    pass
6415
6416if not os.path.exists(_config_path):
6417  _config = {
6418      'floatx': floatx(),
6419      'epsilon': epsilon(),
6420      'backend': 'tensorflow',
6421      'image_data_format': image_data_format()
6422  }
6423  try:
6424    with open(_config_path, 'w') as f:
6425      f.write(json.dumps(_config, indent=4))
6426  except IOError:
6427    # Except permission denied.
6428    pass
6429
6430
6431def configure_and_create_distributed_session(distribution_strategy):
6432  """Configure session config and create a session with it."""
6433
6434  def _create_session(distribution_strategy):
6435    """Create the Distributed Strategy session."""
6436    session_config = get_default_session_config()
6437
6438    # If a session already exists, merge in its config; in the case there is a
6439    # conflict, take values of the existing config.
6440    global _SESSION
6441    if getattr(_SESSION, 'session', None) and _SESSION.session._config:
6442      session_config.MergeFrom(_SESSION.session._config)
6443
6444    if is_tpu_strategy(distribution_strategy):
6445      # TODO(priyag, yuefengz): Remove this workaround when Distribute
6446      # Coordinator is integrated with keras and we can create a session from
6447      # there.
6448      distribution_strategy.configure(session_config)
6449      master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
6450      session = session_module.Session(config=session_config, target=master)
6451    else:
6452      worker_context = dc.get_current_worker_context()
6453      if worker_context:
6454        dc_session_config = worker_context.session_config
6455        # Merge the default session config to the one from distribute
6456        # coordinator, which is fine for now since they don't have
6457        # conflicting configurations.
6458        dc_session_config.MergeFrom(session_config)
6459        session = session_module.Session(
6460            config=dc_session_config, target=worker_context.master_target)
6461      else:
6462        distribution_strategy.configure(session_config)
6463        session = session_module.Session(config=session_config)
6464
6465    set_session(session)
6466
6467  if distribution_strategy.extended._in_multi_worker_mode():
6468    dc.run_distribute_coordinator(
6469        _create_session,
6470        distribution_strategy)
6471  else:
6472    _create_session(distribution_strategy)
6473
6474
6475def _is_tpu_strategy_class(clz):
6476  is_tpu_strat = lambda k: k.__name__.startswith('TPUStrategy')
6477  if is_tpu_strat(clz):
6478    return True
6479  return py_any(map(_is_tpu_strategy_class, clz.__bases__))
6480
6481
6482def is_tpu_strategy(strategy):
6483  """Returns whether input is a TPUStrategy instance or subclass instance."""
6484  return _is_tpu_strategy_class(strategy.__class__)
6485
6486
6487def cast_variables_to_tensor(tensors):
6488
6489  def _cast_variables_to_tensor(tensor):
6490    if isinstance(tensor, variables_module.Variable):
6491      return array_ops.identity(tensor)
6492    return tensor
6493
6494  return nest.map_structure(_cast_variables_to_tensor, tensors)
6495
6496
6497def _is_symbolic_tensor(x):
6498  return tensor_util.is_tf_type(x) and not isinstance(x, ops.EagerTensor)
6499
6500
6501def convert_inputs_if_ragged(inputs):
6502  """Converts any ragged tensors to dense."""
6503
6504  def _convert_ragged_input(inputs):
6505    if isinstance(inputs, ragged_tensor.RaggedTensor):
6506      return inputs.to_tensor()
6507    return inputs
6508
6509  flat_inputs = nest.flatten(inputs)
6510  contains_ragged = py_any(
6511      isinstance(i, ragged_tensor.RaggedTensor) for i in flat_inputs)
6512
6513  if not contains_ragged:
6514    return inputs, None
6515
6516  inputs = nest.map_structure(_convert_ragged_input, inputs)
6517  # Multiple mask are not yet supported, so one mask is used on all inputs.
6518  # We approach this similarly when using row lengths to ignore steps.
6519  nested_row_lengths = math_ops.cast(flat_inputs[0].nested_row_lengths()[0],
6520                                     'int32')
6521  return inputs, nested_row_lengths
6522
6523
6524def maybe_convert_to_ragged(is_ragged_input, output, nested_row_lengths,
6525                            go_backwards=False):
6526  """Converts any ragged input back to its initial structure."""
6527  if not is_ragged_input:
6528    return output
6529
6530  if go_backwards:
6531    # Reverse based on the timestep dim, so that nested_row_lengths will mask
6532    # from the correct direction. Return the reverse ragged tensor.
6533    output = reverse(output, [1])
6534    ragged = ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
6535    return reverse(ragged, [1])
6536  else:
6537    return ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
6538
6539
6540class ContextValueCache(weakref.WeakKeyDictionary):
6541  """Container that caches (possibly tensor) values based on the context.
6542
6543  This class is similar to defaultdict, where values may be produced by the
6544  default factory specified during initialization. This class also has a default
6545  value for the key (when key is `None`) -- the key is set to the current graph
6546  or eager context. The default factories for key and value are only used in
6547  `__getitem__` and `setdefault`. The `.get()` behavior remains the same.
6548
6549  This object will return the value of the current graph or closest parent graph
6550  if the current graph is a function. This is to reflect the fact that if a
6551  tensor is created in eager/graph, child functions may capture that tensor.
6552
6553  The default factory method may accept keyword arguments (unlike defaultdict,
6554  which only accepts callables with 0 arguments). To pass keyword arguments to
6555  `default_factory`, use the `setdefault` method instead of `__getitem__`.
6556
6557  An example of how this class can be used in different contexts:
6558
6559  ```
6560  cache = ContextValueCache(int)
6561
6562  # Eager mode
6563  cache[None] += 2
6564  cache[None] += 4
6565  assert cache[None] == 6
6566
6567  # Graph mode
6568  with tf.Graph().as_default() as g:
6569    cache[None] += 5
6570    cache[g] += 3
6571  assert cache[g] == 8
6572  ```
6573
6574  Example of a default factory with arguments:
6575
6576  ```
6577  cache = ContextValueCache(lambda x: x + 1)
6578  g = tf.get_default_graph()
6579
6580  # Example with keyword argument.
6581  value = cache.setdefault(key=g, kwargs={'x': 3})
6582  assert cache[g] == 4
6583  ```
6584  """
6585
6586  def __init__(self, default_factory):
6587    self.default_factory = default_factory
6588    weakref.WeakKeyDictionary.__init__(self)
6589
6590  def _key(self):
6591    if context.executing_eagerly():
6592      return _DUMMY_EAGER_GRAPH.key
6593    else:
6594      return ops.get_default_graph()
6595
6596  def _get_parent_graph(self, graph):
6597    """Returns the parent graph or dummy eager object."""
6598    # TODO(b/149317164): Currently FuncGraphs use ops.get_default_graph() as the
6599    # outer graph. This results in outer_graph always being a Graph,
6600    # even in eager mode (get_default_graph will create a new Graph if there
6601    # isn't a default graph). Because of this bug, we have to specially set the
6602    # key when eager execution is enabled.
6603    parent_graph = graph.outer_graph
6604    if (not isinstance(parent_graph, func_graph.FuncGraph) and
6605        ops.executing_eagerly_outside_functions()):
6606      return _DUMMY_EAGER_GRAPH.key
6607    return parent_graph
6608
6609  def _get_recursive(self, key):
6610    """Gets the value at key or the closest parent graph."""
6611    value = self.get(key)
6612    if value is not None:
6613      return value
6614
6615    # Since FuncGraphs are able to capture tensors and variables from their
6616    # parent graphs, recursively search to see if there is a value stored for
6617    # one of the parent graphs.
6618    if isinstance(key, func_graph.FuncGraph):
6619      return self._get_recursive(self._get_parent_graph(key))
6620    return None
6621
6622  def __getitem__(self, key):
6623    """Gets the value at key (or current context), or sets default value.
6624
6625    Args:
6626      key: May be `None` or `Graph`object. When `None`, the key is set to the
6627        current context.
6628
6629    Returns:
6630      Either the cached or default value.
6631    """
6632    if key is None:
6633      key = self._key()
6634
6635    value = self._get_recursive(key)
6636    if value is None:
6637      value = self[key] = self.default_factory()  # pylint:disable=not-callable
6638    return value
6639
6640  def setdefault(self, key=None, default=None, kwargs=None):
6641    """Sets the default value if key is not in dict, and returns the value."""
6642    if key is None:
6643      key = self._key()
6644    kwargs = kwargs or {}
6645
6646    if default is None and key not in self:
6647      default = self.default_factory(**kwargs)
6648    return weakref.WeakKeyDictionary.setdefault(self, key, default)
6649
6650# This dictionary holds a mapping {graph: learning_phase}. In eager mode, a
6651# dummy object is used.
6652# A learning phase is a bool tensor used to run Keras models in
6653# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
6654_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
6655
6656# This dictionary holds a mapping between a graph and variables to initialize
6657# in the graph.
6658_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
6659
6660# This dictionary holds a mapping between a graph and TF optimizers created in
6661# the graph.
6662_GRAPH_TF_OPTIMIZERS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
6663