xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.framework.ops."""
16
17import gc
18import os
19import threading
20import weakref
21
22from absl.testing import parameterized
23import numpy as np
24
25from tensorflow.core.framework import attr_value_pb2
26from tensorflow.core.framework import full_type_pb2
27from tensorflow.core.framework import tensor_shape_pb2
28from tensorflow.core.protobuf import config_pb2
29from tensorflow.python.autograph.core import ag_ctx
30from tensorflow.python.client import session
31from tensorflow.python.data.ops import dataset_ops
32from tensorflow.python.eager import backprop
33from tensorflow.python.eager import context
34from tensorflow.python.eager import def_function
35from tensorflow.python.eager import function as eager_function
36from tensorflow.python.eager import wrap_function
37from tensorflow.python.framework import composite_tensor
38from tensorflow.python.framework import config
39from tensorflow.python.framework import constant_op
40from tensorflow.python.framework import device as pydev
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import errors
43from tensorflow.python.framework import function
44from tensorflow.python.framework import indexed_slices
45from tensorflow.python.framework import ops
46from tensorflow.python.framework import sparse_tensor
47from tensorflow.python.framework import tensor_shape
48from tensorflow.python.framework import tensor_spec
49from tensorflow.python.framework import tensor_util
50from tensorflow.python.framework import test_ops
51from tensorflow.python.framework import test_util
52from tensorflow.python.framework import type_spec
53from tensorflow.python.framework import versions
54from tensorflow.python.ops import array_ops
55from tensorflow.python.ops import control_flow_ops
56from tensorflow.python.ops import math_ops
57from tensorflow.python.ops import resource_variable_ops
58from tensorflow.python.ops import resources
59from tensorflow.python.ops import special_math_ops
60from tensorflow.python.ops import variable_scope
61from tensorflow.python.ops import variables
62import tensorflow.python.ops.gradients  # pylint: disable=unused-import
63from tensorflow.python.platform import googletest
64from tensorflow.python.util import compat
65
66
class ResourceTest(test_util.TensorFlowTestCase):
  """Tests graph-mode creation and initialization of stub resources."""

  @test_util.run_deprecated_v1
  def testBuildGraph(self):
    with self.cached_session():
      stub_handle = test_ops.stub_resource_handle_op(
          container="a", shared_name="b")
      create_op = test_ops.resource_create_op(stub_handle)
      create_op.run()

  @test_util.run_deprecated_v1
  def testInitialize(self):
    with self.cached_session():
      stub_handle = test_ops.stub_resource_handle_op(
          container="a", shared_name="b")
      resources.register_resource(
          handle=stub_handle,
          create_op=test_ops.resource_create_op(stub_handle),
          is_initialized_op=test_ops.resource_initialized_op(stub_handle))

      def _num_uninitialized():
        # Builds a fresh report op over the shared resources and runs it.
        report = resources.report_uninitialized_resources(
            resources.shared_resources())
        return len(report.eval())

      # Registered but not yet initialized: exactly one resource is reported.
      self.assertEqual(_num_uninitialized(), 1)
      resources.initialize_resources(resources.shared_resources()).run()
      self.assertEqual(_num_uninitialized(), 0)
92
93
class TensorAndShapeTest(test_util.TensorFlowTestCase):
  """Tests for `Tensor` shapes, Python protocol hooks, refs and bitwise ops."""

  def testShape(self):
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    self.assertEqual(tensor_shape.unknown_shape(), t.get_shape())
    t.set_shape([1, 2, 3])
    self.assertEqual([1, 2, 3], t.get_shape())

  def testIterable(self):
    if not context.executing_eagerly():
      self.skipTest("Eager-mode test")
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    # `t` is a graph tensor whose shape is unknown, so it is not iterable.
    with self.assertRaisesRegex(TypeError, "Cannot iterate"):
      iter(t)

  def testIterableGraph(self):
    if context.executing_eagerly():
      self.skipTest("Graph-mode test")

    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    # The error message is expected to reflect the AutoGraph status in effect.
    with self.assertRaisesRegex(TypeError, "Iterating.*not allowed in Graph"):
      next(iter(t))
    with self.assertRaisesRegex(TypeError, "Iterating.*AutoGraph did convert"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        next(iter(t))
    with self.assertRaisesRegex(TypeError, "Iterating.*AutoGraph is disabled"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        next(iter(t))

  def testImplicitBool(self):
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.bool])
    t = op.outputs[0]
    # As with iteration, the message depends on the AutoGraph status.
    with self.assertRaisesRegex(TypeError,
                                "Using.*as a.*bool.*not allowed in Graph"):
      bool(t)
    with self.assertRaisesRegex(TypeError,
                                "Using.*as a.*bool.*AutoGraph did convert"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        bool(t)
    with self.assertRaisesRegex(TypeError,
                                "Using.*as a.*bool.*AutoGraph is disabled"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        bool(t)

  def testAddShape(self):
    with self.cached_session():
      a = array_ops.zeros([2, 3])
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual([2, 3], c.shape)

  @test_util.run_deprecated_v1
  def testUnknownDim(self):
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      c = a + b
      self.assertEqual([2, None, 3], c.shape.as_list())

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual(tensor_shape.unknown_shape(), c.shape)

  @test_util.run_deprecated_v1
  def testScalarShape(self):
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      b = array_ops.ones([])
      c = a + b
      self.assertEqual(tensor_shape.TensorShape([]), c.shape)

  @test_util.run_deprecated_v1
  def testShapeFunctionError(self):
    with self.cached_session():
      a = array_ops.ones([1, 2, 3])
      b = array_ops.ones([4, 5, 6])
      with self.assertRaisesRegex(
          ValueError, r"Dimensions must be equal, but are 2 and 5 for .*add"
          r".*Add(V2)?.* with input shapes: \[1,2,3\], \[4,5,6\]."):
        _ = a + b

  def testNumpyArray(self):
    with ops.Graph().as_default():
      x = array_ops.ones((3, 4), name="test_ones")

    with self.assertRaisesRegex(NotImplementedError,
                                r"Cannot convert a symbolic.+test_ones"):
      np.array(x)

    with self.assertRaisesRegex(TypeError, "not well defined.+test_ones"):
      len(x)

    # EagerTensors should still behave as numpy arrays.
    with context.eager_mode():
      x = array_ops.ones((3, 4))

    self.assertAllEqual(x, np.ones((3, 4)))
    self.assertAllEqual(np.array(x), np.ones((3, 4)))
    self.assertLen(x, 3)

  def testConstructor(self):
    # Checks that numpy-style attributes raise an AttributeError that hints at
    # numpy, while other unknown attributes raise the plain error.
    a = array_ops.ones([])
    for name in ["T", "astype", "ravel", "transpose", "reshape", "clip", "size",
                 "tolist", "data"]:
      with self.assertRaisesRegex(
          AttributeError, r"If you are looking for numpy-related methods"):
        getattr(a, name)
    with self.assertRaisesRegex(
        AttributeError, r"object has no attribute"):
      a.foo_bar()

  def testRef(self):
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertEqual(x1.ref(), x1.ref())
    self.assertEqual(x2.ref(), x2.ref())
    self.assertEqual(x1.ref(), x2.ref())
    self.assertEqual(y.ref(), y.ref())
    self.assertEqual(z.ref(), z.ref())
    self.assertEqual(w.ref(), w.ref())

    # Refs compare by object identity, so equal-valued tensors differ.
    self.assertNotEqual(x1.ref(), y.ref())
    self.assertNotEqual(x1.ref(), z.ref())
    self.assertNotEqual(x1.ref(), w.ref())
    self.assertNotEqual(y.ref(), z.ref())
    self.assertNotEqual(y.ref(), w.ref())
    self.assertNotEqual(z.ref(), w.ref())

  def testRefDeref(self):
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertIs(x1, x1.ref().deref())
    self.assertIs(x2, x2.ref().deref())
    self.assertIs(x1, x2.ref().deref())
    self.assertIs(x2, x1.ref().deref())
    self.assertIs(y, y.ref().deref())
    self.assertIs(z, z.ref().deref())
    # Variables must round-trip as well, matching the coverage in testRef.
    self.assertIs(w, w.ref().deref())

    self.assertIsNot(x1, y.ref().deref())
    self.assertIsNot(x1, z.ref().deref())
    self.assertIsNot(x1, w.ref().deref())
    self.assertIsNot(y, z.ref().deref())
    self.assertIsNot(y, w.ref().deref())
    self.assertIsNot(z, w.ref().deref())

  def testRefInSet(self):
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertEqual(x1.ref(), x2.ref())

    tensor_set = {
        x1.ref(),
        x2.ref(),
        y.ref(),
        z.ref(),
        w.ref(),
    }

    # x1 and x2 are the same object, so they collapse to one set entry.
    self.assertLen(tensor_set, 4)
    self.assertIn(x1.ref(), tensor_set)
    self.assertIn(x2.ref(), tensor_set)
    self.assertIn(y.ref(), tensor_set)
    self.assertIn(z.ref(), tensor_set)
    self.assertIn(w.ref(), tensor_set)

  def testRefInDict(self):
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertEqual(x1.ref(), x2.ref())

    tensor_dict = {
        x1.ref(): "x1",
        y.ref(): "y",
        z.ref(): "z",
        w.ref(): "w",
    }

    self.assertLen(tensor_dict, 4)

    # Overwriting x1
    tensor_dict[x2.ref()] = "x2"
    self.assertLen(tensor_dict, 4)

    self.assertEqual(tensor_dict[x1.ref()], "x2")
    self.assertEqual(tensor_dict[x2.ref()], "x2")
    self.assertEqual(tensor_dict[y.ref()], "y")
    self.assertEqual(tensor_dict[z.ref()], "z")
    self.assertEqual(tensor_dict[w.ref()], "w")

  def testTensorRefStrong(self):
    x = constant_op.constant(1.)
    x_ref = x.ref()
    del x
    # The ref keeps the tensor alive after the last name is deleted.
    self.assertIsNotNone(x_ref.deref())

  def testVariableRefStrong(self):
    x = variables.Variable(1.)
    x_ref = x.ref()
    del x
    self.assertIsNotNone(x_ref.deref())

  def _bitwiseErrorType(self):
    """Returns the error type expected from an invalid bitwise operation.

    Eager execution reports the failure as `errors.InvalidArgumentError`,
    while graph construction raises `TypeError`.
    """
    if context.executing_eagerly():  # :(
      return errors.InvalidArgumentError
    return TypeError

  def _assertBinaryBitwiseOpErrors(self, op):
    """Checks that a binary bitwise operator rejects incompatible operands.

    Covers int/bool mixes in both orders, int and bool combined with a
    string, and string/string.

    Args:
      op: Two-argument callable applying the operator under test, e.g.
        `lambda a, b: a & b`.
    """
    expected_errtype = self._bitwiseErrorType()
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)
    x_str = constant_op.constant("a")
    y_str = constant_op.constant("b")
    invalid_pairs = [
        (x_int, x_bool),
        (x_int, x_str),
        (x_bool, x_int),
        (x_bool, x_str),
        (x_str, y_str),
    ]
    for lhs, rhs in invalid_pairs:
      with self.assertRaises(expected_errtype):
        _ = op(lhs, rhs)

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndNumeric(self):
    x = constant_op.constant([0, 1, 3])
    y = constant_op.constant([1, 1, 1])

    z = x & y

    self.assertAllEqual(z, [0, 1, 1])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])

    z = x & y

    self.assertAllEqual(z, [False, False, False, True])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndErrors(self):
    self._assertBinaryBitwiseOpErrors(lambda a, b: a & b)

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrNumeric(self):
    x = constant_op.constant([0, 1, 2])
    y = constant_op.constant([1, 1, 1])

    z = x | y

    self.assertAllEqual(z, [1, 1, 3])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])

    z = x | y

    self.assertAllEqual(z, [False, True, True, True])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrErrors(self):
    self._assertBinaryBitwiseOpErrors(lambda a, b: a | b)

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorNumeric(self):
    x = constant_op.constant([0, 1, 3])
    y = constant_op.constant([1, 1, 1])

    z = x ^ y

    self.assertAllEqual(z, [1, 0, 2])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])

    z = x ^ y

    self.assertAllEqual(z, [False, True, True, False])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorErrors(self):
    self._assertBinaryBitwiseOpErrors(lambda a, b: a ^ b)

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotNumeric(self):
    x = constant_op.constant([0, dtypes.int32.min, 1])

    # pylint: disable=invalid-unary-operand-type
    y = ~x

    self.assertAllEqual(y, [-1, dtypes.int32.max, -2])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotBool(self):
    x = constant_op.constant([False, True])

    # pylint: disable=invalid-unary-operand-type
    y = ~x

    self.assertAllEqual(y, [True, False])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotErrors(self):
    expected_errtype = self._bitwiseErrorType()

    # pylint: disable=invalid-unary-operand-type
    with self.assertRaises(expected_errtype):
      _ = ~constant_op.constant("a")
473
474
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(test_util.TensorFlowTestCase):
  """Tests for dense conversion and arithmetic on `IndexedSlices`."""

  def testToTensor(self):
    values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    indices = constant_op.constant([0, 2])

    # Without a dense_shape, conversion is expected to fail and the static
    # shape is fully unknown.
    shapeless = indexed_slices.IndexedSlices(values, indices)
    with self.assertRaises(ValueError):
      ops.convert_to_tensor(shapeless, name="tensor")
    self.assertEqual(tensor_shape.TensorShape(None), shapeless.shape)

    # With a dense_shape of [3, 2], the missing row 1 is filled with zeros.
    shaped = indexed_slices.IndexedSlices(
        values, indices, constant_op.constant([3, 2]))
    dense = ops.convert_to_tensor(shaped, name="tensor")
    self.assertAllEqual(dense.shape, shaped.shape)
    self.assertAllEqual(self.evaluate(dense), [[2, 3], [0, 0], [5, 7]])

  @test_util.run_gpu_only
  def testEagerCopy(self):
    with context.eager_mode():
      var = variables.Variable([[0.0], [0.0], [0.0], [0.0]], name="tensor")
      with backprop.GradientTape() as tape:
        lhs = array_ops.gather(array_ops.gather(var, [0, 1]), [0, 1])
        rhs = array_ops.gather(array_ops.gather(var, [2, 3]), [0, 1])
        product = special_math_ops.einsum("ij,ij->i", lhs, rhs)
      grad, = tape.gradient(product, [var])
      # The gradient may come back sparse or dense; compare the value shape.
      if isinstance(grad, indexed_slices.IndexedSlices):
        grad = grad.values
      self.assertAllEqual(grad.get_shape(), [4, 1])

  def testNegation(self):
    slices = indexed_slices.IndexedSlices(
        constant_op.constant([2, 3, 5, 7], shape=[2, 2]),
        constant_op.constant([0, 2]))
    negated = -slices
    self.assertAllEqual(negated.values, [[-2, -3], [-5, -7]])
    self.assertAllEqual(negated.indices, [0, 2])

  def testScalarMul(self):
    slices = indexed_slices.IndexedSlices(
        constant_op.constant([2, 3, 5, 7], shape=[2, 2]),
        constant_op.constant([0, 2]))
    scaled = math_ops.scalar_mul(-2, slices)
    self.assertAllEqual(scaled.values, [[-4, -6], [-10, -14]])
    self.assertAllEqual(scaled.indices, [0, 2])
517
518
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):
  """Tests for `IndexedSlicesSpec` construction and (de)composition."""

  def assertAllTensorsEqual(self, list1, list2):
    """Asserts two sequences have equal length and element-wise equal values."""
    self.assertLen(list1, len(list2))
    for (t1, t2) in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertIsNone(spec1._shape.rank)
    self.assertEqual(spec1._values_dtype, dtypes.float32)
    self.assertEqual(spec1._indices_dtype, dtypes.int64)
    self.assertIsNone(spec1._dense_shape_dtype)
    self.assertEqual(spec1._indices_shape.as_list(), [None])

    spec2 = indexed_slices.IndexedSlicesSpec([None, None], dtypes.string,
                                             dtypes.int32, dtypes.int64, [10])
    self.assertEqual(spec2._shape.as_list(), [None, None])
    self.assertEqual(spec2._values_dtype, dtypes.string)
    self.assertEqual(spec2._indices_dtype, dtypes.int32)
    self.assertEqual(spec2._dense_shape_dtype, dtypes.int64)
    self.assertEqual(spec2._indices_shape.as_list(), [10])

  def testValueType(self):
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertEqual(spec1.value_type, indexed_slices.IndexedSlices)

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(shape=[5, None, None]),
       (tensor_shape.TensorShape([5, None, None]), dtypes.float32,
        dtypes.int64, None, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.int32, dense_shape_dtype=dtypes.int64),
       (tensor_shape.TensorShape(None), dtypes.int32, dtypes.int64,
        dtypes.int64, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(indices_shape=[100]),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([100]))),
  ])  # pyformat: disable
  def testSerialize(self, spec, expected):
    serialization = spec._serialize()
    # TensorShape has an unconventional definition of equality, so we can't use
    # assertEqual directly here.  But repr() is deterministic and lossless for
    # the expected values, so we can use that instead.
    self.assertEqual(repr(serialization), repr(expected))

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(dtype=dtypes.string), (
          tensor_spec.TensorSpec(None, dtypes.string),
          tensor_spec.TensorSpec([None], dtypes.int64),
      )),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.string, dense_shape_dtype=dtypes.int32), (
              tensor_spec.TensorSpec(None, dtypes.string),
              tensor_spec.TensorSpec([None], dtypes.int64),
              tensor_spec.TensorSpec([None], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32), (
              tensor_spec.TensorSpec([None, 10, 15], dtypes.float32),
              tensor_spec.TensorSpec([None], dtypes.int64),
              tensor_spec.TensorSpec([3], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32,
          indices_shape=[20]), (
              tensor_spec.TensorSpec([20, 10, 15], dtypes.float32),
              tensor_spec.TensorSpec([20], dtypes.int64),
              tensor_spec.TensorSpec([3], dtypes.int32),
          )),
  ])
  def testComponentSpecs(self, spec, expected):
    self.assertEqual(spec._component_specs, expected)

  @parameterized.parameters([
      {
          "spec": indexed_slices.IndexedSlicesSpec(),
          "values": [3.0, 5.0],
          "indices": [5, 10]
      },
      {
          "spec":
              indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32),
          "values": [3.0, 5.0],
          "indices": [5, 10],
          "dense_shape": [100]
      },
  ])
  def testToFromComponents(self, spec, indices, values, dense_shape=None):
    # IndexedSlices takes (values, indices, dense_shape) positionally, and the
    # spec's components use the same order (see testComponentSpecs and
    # testFromNumpyComponents). Passing indices first would silently build
    # slices with float "indices".
    x = indexed_slices.IndexedSlices(values, indices, dense_shape)
    expected_components = [values, indices]
    if dense_shape is not None:
      expected_components.append(dense_shape)

    actual_components = spec._to_components(x)
    self.assertAllTensorsEqual(actual_components, expected_components)

    reconstructed = spec._from_components(actual_components)
    self.assertAllEqual(x.values, reconstructed.values)
    self.assertAllEqual(x.indices, reconstructed.indices)
    if dense_shape is None:
      self.assertIsNone(reconstructed.dense_shape)
    else:
      self.assertAllEqual(x.dense_shape, reconstructed.dense_shape)

  @test_util.run_v1_only("IndexedSlicesValue is deprecated in v2")
  def testFromNumpyComponents(self):
    indices = np.array([3, 8])
    values = np.array([1.0, 9.0])
    dense_shape = np.array([100])

    spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32)
    st1 = spec1._from_components((values, indices, dense_shape))
    self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st1.indices, indices)
    self.assertAllEqual(st1.values, values)
    self.assertAllEqual(st1.dense_shape, dense_shape)

    spec2 = indexed_slices.IndexedSlicesSpec()
    st2 = spec2._from_components((values, indices))
    self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st2.indices, indices)
    self.assertAllEqual(st2.values, values)
    self.assertIsNone(st2.dense_shape)
647
648
class NodeDefConstructorTest(test_util.TensorFlowTestCase):
  """Tests for the private `ops._NodeDef` constructor helper."""

  def testNoArgs(self):
    built = ops._NodeDef("None", "bar")
    expected_text = "op: 'None' name: 'bar'"
    self.assertProtoEquals(expected_text, built)
654
655
656def _apply_op(g, *args, **kwargs):
657  op = g.create_op(*args, **kwargs)
658  if len(op.outputs) == 1:
659    return op.outputs[0]
660  else:
661    return op.outputs
662
663
664class OperationTest(test_util.TensorFlowTestCase):
665
666  def testTraceback(self):
667    g = ops.Graph()
668    op1 = ops.Operation(
669        ops._NodeDef("None", "op1"), g, [],
670        [dtypes.float32_ref, dtypes.float32])
671    self.assertIn("testTraceback", op1.traceback[-1])
672
673  @test_util.run_deprecated_v1
674  def testNoInputs(self):
675    op = test_ops.float_output_string_output(name="myop").a.op
676    self.assertEqual(2, len(op.values()))
677    self.assertEqual(0, len(op.inputs))
678    self.assertEqual("myop", op.name)
679
680    float_t, label_str_t = op.values()
681    self.assertEqual(dtypes.float32, float_t.dtype)
682    self.assertEqual(op, float_t.op)
683    self.assertEqual(0, float_t._value_index)
684    self.assertEqual(0, len(float_t.consumers()))
685    self.assertEqual("myop", float_t._as_node_def_input())
686
687    self.assertEqual(dtypes.string, label_str_t.dtype)
688    self.assertEqual(op, label_str_t.op)
689    self.assertEqual(1, label_str_t._value_index)
690    self.assertEqual(0, len(label_str_t.consumers()))
691    self.assertEqual("myop:1", label_str_t._as_node_def_input())
692
693    self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
694                           op.node_def)
695
696  @test_util.run_deprecated_v1
697  def testNoOutputs(self):
698    op1 = test_ops.float_output(name="myop1").op
699    float_t, = op1.values()
700    op2 = test_ops.float_input(float_t, name="myop2")
701    self.assertEqual(0, len(op2.values()))
702    self.assertEqual(1, len(op2.inputs))
703    self.assertIs(float_t, op2.inputs[0])
704
705    self.assertEqual(1, len(float_t.consumers()))
706    self.assertEqual(op2, float_t.consumers()[0])
707
708    self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def)
709    self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
710                           op2.node_def)
711
712  @test_util.run_deprecated_v1
713  def testInputsAndOutputs(self):
714    op1 = test_ops.float_output(name="myop1").op
715    self.assertEqual(1, len(op1.values()))
716    float1_t, = op1.values()
717
718    op2 = test_ops.float_output_string_output(name="myop2").a.op
719    self.assertEqual(2, len(op2.values()))
720    float2_t, label2_str_t = op2.values()
721
722    # Note that we consume label2_str_t twice here.
723    op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op
724    self.assertEqual(2, len(op3.values()))
725
726    self.assertEqual(1, len(float1_t.consumers()))
727    self.assertEqual(op3, float1_t.consumers()[0])
728
729    self.assertEqual(0, len(float2_t.consumers()))
730
731    self.assertEqual(2, len(label2_str_t.consumers()))
732    self.assertEqual(op3, label2_str_t.consumers()[0])
733    self.assertEqual(op3, label2_str_t.consumers()[1])
734
735    self.assertProtoEquals("""
736    op:'Foo2' name:'myop3'
737    input:'myop1' input:'myop2:1' input:'myop2:1'
738    """, op3.node_def)
739
740  def testDeviceObject(self):
741    op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], [])
742    op._set_device("/job:goo/device:GPU:0")
743    self.assertProtoEquals(
744        "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def)
745    op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], [])
746    op._set_device(
747        pydev.DeviceSpec(
748            job="muu", device_type="CPU", device_index=0))
749    self.assertProtoEquals(
750        "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def)
751
752  def testReferenceInput(self):
753    g = ops.Graph()
754    op1 = ops.Operation(
755        ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
756        [dtypes.float32_ref, dtypes.float32])
757    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
758    self.assertEqual([], list(op1.inputs))
759    ref_t, nonref_t = op1.values()
760    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
761    op2 = ops.Operation(
762        ops._NodeDef("RefInputFloatInput", "op2"),
763        g, [ref_t, nonref_t], [],
764        input_types=[dtypes.float32_ref, dtypes.float32])
765    self.assertProtoEquals(
766        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
767        op2.node_def)
768    self.assertEqual([ref_t, nonref_t], list(op2.inputs))
769    op3 = ops.Operation(
770        ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
771    self.assertProtoEquals(
772        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
773        op3.node_def)
774
775  def testInvalidNames(self):
776    g = ops.Graph()
777    with self.assertRaises(ValueError):
778      ops.Operation(ops._NodeDef("op", ""), g)
779    with self.assertRaises(ValueError):
780      ops.Operation(ops._NodeDef("op", "_invalid"), g)
781    with self.assertRaises(ValueError):
782      ops.Operation(ops._NodeDef("op", "-invalid"), g)
783    with self.assertRaises(ValueError):
784      ops.Operation(ops._NodeDef("op", "/invalid"), g)
785    with self.assertRaises(ValueError):
786      ops.Operation(ops._NodeDef("op", "invalid:0"), g)
787
788  @test_util.run_deprecated_v1
789  def testNoShapeFunction(self):
790    op = test_ops.a()
791    self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
792
793  @test_util.run_in_graph_and_eager_modes
794  def testConvertToTensorNestedArray(self):
795    values = [[2], [3], [5], [7]]
796    tensor = ops.convert_to_tensor(values)
797    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
798    self.assertAllEqual(values, self.evaluate(tensor))
799
800  def testShapeTuple(self):
801    with self.cached_session():
802      c = constant_op.constant(1)
803      self.assertEqual(c._shape_tuple(), ())  # pylint: disable=protected-access
804
805  def testConvertToTensorEager(self):
806    with context.eager_mode():
807      t = constant_op.constant(1)
808      self.assertTrue(isinstance(t, ops.EagerTensor))
809      converted = ops.convert_to_tensor(t)
810      self.assertTrue(isinstance(converted, ops.EagerTensor))
811      converted = ops.convert_to_tensor(1)
812      self.assertTrue(isinstance(converted, ops.EagerTensor))
813
814  @test_util.run_in_graph_and_eager_modes
815  def testConvertToTensorNestedTuple(self):
816    values = ((2,), (3,), (5,), (7,))
817    tensor = ops.convert_to_tensor(values)
818    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
819    self.assertAllEqual(values, self.evaluate(ops.convert_to_tensor(values)))
820
821  @test_util.run_in_graph_and_eager_modes
822  def testConvertToTensorNestedTensors(self):
823    values = ((2,), (3,), (5,), (7,))
824    tensor = ops.convert_to_tensor(
825        [constant_op.constant(row) for row in values])
826    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
827    self.assertAllEqual(values, self.evaluate(tensor))
828    tensor = ops.convert_to_tensor(
829        [[constant_op.constant(v) for v in row] for row in values])
830    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
831    self.assertAllEqual(values, self.evaluate(tensor))
832
833  @test_util.run_in_graph_and_eager_modes
834  def testConvertToTensorNestedMix(self):
835    values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
836    tensor = ops.convert_to_tensor(values)
837    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
838    self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor))
839
840  @test_util.run_in_graph_and_eager_modes
841  def testConvertToTensorPreferred(self):
842    values = [2, 3, 5, 7]
843    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
844    self.assertEqual(dtypes.float32, tensor.dtype)
845
846    # Convert empty tensor to anything.
847    values = []
848    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
849    self.assertEqual(dtypes.int64, tensor.dtype)
850
851    # The preferred dtype is a type error and will convert to
852    # float32 instead.
853    values = [1.23]
854    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
855    self.assertEqual(dtypes.float32, tensor.dtype)
856
857  @test_util.run_in_graph_and_eager_modes
858  def testConvertToInvalidTensorType(self):
859    with self.assertRaises(TypeError):
860      # Forcing an invalid dtype should fail with a type error.
861      values = [1.23]
862      ops.convert_to_tensor(values, dtype=dtypes.int64)
863
864  @test_util.run_in_graph_and_eager_modes
865  def testConvertToLongLongTensorType(self):
866    tensor = ops.convert_to_tensor(
867        # Get a numpy array of dtype NPY_LONGLONG
868        np.prod(constant_op.constant([1])._shape_tuple()),
869        dtype=dtypes.int64)
870    self.assertEqual(dtypes.int64, tensor.dtype)
871
872  @test_util.run_in_graph_and_eager_modes
873  def testConvertToTensorFromInvalidTensor(self):
874    tensor = constant_op.constant(42.0, dtype=dtypes.float32)
875    with self.assertRaises(ValueError):
876      ops.convert_to_tensor(tensor, dtype=dtypes.int32)
877
878  @test_util.run_in_graph_and_eager_modes
879  def testConvertToTensorProtocol(self):
880    class TensorCompatible:
881
882      def __tf_tensor__(self, dtype=None, name=None):
883        return constant_op.constant((1, 2, 3), dtype=dtype, name=name)
884
885    tc = TensorCompatible()
886
887    tensor = ops.convert_to_tensor(tc, dtype=dtypes.int32)
888    self.assertEqual(tensor.dtype, dtypes.int32)
889    self.assertAllEqual((1, 2, 3), self.evaluate(tensor))
890
891  @test_util.run_deprecated_v1
892  def testNoConvert(self):
893    # Operation cannot be converted to Tensor.
894    op = control_flow_ops.no_op()
895    with self.assertRaisesRegex(TypeError,
896                                "can't convert Operation '.+' to Tensor"):
897      ops.convert_to_tensor(op)
898
899  def testStr(self):
900    node_def = ops._NodeDef("None", "op1")
901    op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
902    self.assertEqual(str(node_def), str(op))
903
904  def testRepr(self):
905    op = ops.Operation(
906        ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32])
907    self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
908
909  @test_util.run_deprecated_v1
910  def testGetAttr(self):
911    op = test_ops.default_attrs()
912    self.assertEqual(op.get_attr("string_val"), b"abc")
913    self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
914    self.assertEqual(op.get_attr("int_val"), 123)
915    self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
916    self.assertEqual(op.get_attr("float_val"), 10.0)
917    self.assertEqual(op.get_attr("float_list_val"), [10.0])
918    self.assertEqual(op.get_attr("bool_val"), True)
919    self.assertEqual(op.get_attr("bool_list_val"), [True, False])
920    self.assertEqual(op.get_attr("shape_val"),
921                     tensor_shape.as_shape([2, 1]).as_proto())
922    self.assertEqual(op.get_attr("shape_list_val"),
923                     [tensor_shape.as_shape([]).as_proto(),
924                      tensor_shape.as_shape([1]).as_proto()])
925    self.assertEqual(op.get_attr("tensor_val"),
926                     tensor_util.make_tensor_proto(1, dtypes.int32))
927    self.assertEqual(op.get_attr("tensor_list_val"),
928                     [tensor_util.make_tensor_proto(1, dtypes.int32)])
929
930    type_val = op.get_attr("type_val")
931    # First check that type_val is a DType, because the assertEqual will work
932    # no matter what since DType overrides __eq__
933    self.assertIsInstance(type_val, dtypes.DType)
934    self.assertEqual(type_val, dtypes.int32)
935
936    type_list_val = op.get_attr("type_list_val")
937    self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
938    self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])
939
940    @function.Defun(dtypes.float32, func_name="MyFunc")
941    def func(x):
942      return x
943
944    op = test_ops.func_attr(func)
945    self.assertEqual(op.get_attr("f"),
946                     attr_value_pb2.NameAttrList(name="MyFunc"))
947
948    # Try fetching missing attr
949    with self.assertRaisesRegex(
950        ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."):
951      op.get_attr("FakeAttr")
952
953  # TODO(b/65162920): remove this test when users who are directly mutating the
954  # node_def have been updated to proper usage.
955  @test_util.run_deprecated_v1
956  def testSetAttr(self):
957    op = test_ops.int_attr().op
958    op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
959    # TODO(skyewm): add node_def check
960    self.assertEqual(op.get_attr("foo"), 2)
961
962  @test_util.run_v2_only
963  def testSetFullType(self):
964    @def_function.function
965    def test_fn():
966      ds = dataset_ops.Dataset.range(3)._variant_tensor
967
968      ds.op.experimental_set_type(
969          full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_PRODUCT))
970
971      self.assertEqual(ds.op.node_def.experimental_type.type_id,
972                       full_type_pb2.TFT_PRODUCT)
973
974    test_fn()
975
976  # TODO(nolivia): test all error cases
977  def testAddControlInput(self):
978    with ops.Graph().as_default():
979      x = constant_op.constant(1).op
980      y = constant_op.constant(2).op
981      z = constant_op.constant(3).op
982    z._add_control_input(x)  # pylint: disable=protected-access
983    self.assertEqual(z.control_inputs, [x])
984    z._add_control_input(x)  # pylint: disable=protected-access
985    self.assertEqual(z.control_inputs, [x])
986    z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
987    self.assertEqual(z.control_inputs, [x, y])
988    self.assertEqual(x._control_outputs, [z])
989
990  @test_util.run_deprecated_v1
991  def testRemoveAllControlInputs(self):
992    a = constant_op.constant(1)
993    with ops.control_dependencies([a]):
994      b = constant_op.constant(2)
995    c = constant_op.constant(3)
996    d = constant_op.constant(4)
997    e = constant_op.constant(5)
998    with ops.control_dependencies([a, c]):
999      f = d + e
1000
1001    self.assertEqual(a.op.control_inputs, [])
1002    self.assertEqual(b.op.control_inputs, [a.op])
1003    self.assertEqual(f.op.control_inputs, [a.op, c.op])
1004
1005    a.op._remove_all_control_inputs()  # pylint: disable=protected-access
1006    self.assertEqual(a.op.control_inputs, [])
1007
1008    b.op._remove_all_control_inputs()  # pylint: disable=protected-access
1009    self.assertEqual(b.op.control_inputs, [])
1010
1011    f.op._remove_all_control_inputs()  # pylint: disable=protected-access
1012    self.assertEqual(f.op.control_inputs, [])
1013    self.assertEqual(list(f.op.inputs), [d, e])
1014
1015  @test_util.run_deprecated_v1
1016  def testControlInputCycle(self):
1017    graph = ops.Graph()
1018    with graph.as_default():
1019      z = constant_op.constant(0)
1020      x = constant_op.constant(1)
1021      y = constant_op.constant(2)
1022      y.op._add_control_input(z.op)  # pylint: disable=protected-access
1023      y.op._add_control_input(x.op)  # pylint: disable=protected-access
1024      x.op._add_control_input(y.op)  # pylint: disable=protected-access
1025    with self.session(graph=graph) as sess:
1026      with self.assertRaisesRegex(
1027          errors.InvalidArgumentError,
1028          "Graph is invalid, contains a cycle with 2 nodes"):
1029        self.evaluate(x)
1030
1031  def testUpdateInput(self):
1032    g = ops.Graph()
1033    with g.as_default():
1034      x = constant_op.constant(1)
1035      y = constant_op.constant(2)
1036      z = x + y
1037
1038    z.op._update_input(0, y)  # pylint: disable=protected-access
1039    self.assertEqual(list(z.op.inputs), [y, y])
1040    self.assertEqual(x.consumers(), [])
1041    self.assertEqual(y.consumers(), [z.op, z.op])
1042    with session.Session(graph=g) as sess:
1043      self.assertEqual(self.evaluate(z), 4)
1044
1045    z.op._update_input(0, x)  # pylint: disable=protected-access
1046    self.assertEqual(list(z.op.inputs), [x, y])
1047    self.assertEqual(x.consumers(), [z.op])
1048    self.assertEqual(y.consumers(), [z.op])
1049    with session.Session(graph=g) as sess:
1050      self.assertEqual(self.evaluate(z), 3)
1051
1052    z.op._update_input(1, y)  # pylint: disable=protected-access
1053    self.assertEqual(list(z.op.inputs), [x, y])
1054    self.assertEqual(x.consumers(), [z.op])
1055    self.assertEqual(y.consumers(), [z.op])
1056    with session.Session(graph=g) as sess:
1057      self.assertEqual(self.evaluate(z), 3)
1058
1059  def testUpdateInputGraphError(self):
1060    g_0 = ops.Graph()
1061    g_1 = ops.Graph()
1062    with g_0.as_default():
1063      x = constant_op.constant(1)
1064    with g_1.as_default():
1065      y = constant_op.constant(2)
1066      z = y * 2
1067      with self.assertRaisesRegex(ValueError, "must be from the same graph"):
1068        z.op._update_input(0, x)  # pylint: disable=protected-access
1069
1070  def testUpdateInputTypeError(self):
1071    g = ops.Graph()
1072    with g.as_default():
1073      w = constant_op.constant(0)
1074      x = constant_op.constant("")
1075      y = constant_op.constant(1)
1076      z = y + w
1077      z.op._update_input(0, x)  # pylint: disable=protected-access
1078    with session.Session(graph=g) as sess:
1079      with self.assertRaisesRegex(
1080          errors.InvalidArgumentError,
1081          "Input 0 of node add was passed string from Const_1:0 incompatible "
1082          "with expected int32"):
1083        self.evaluate(z)
1084
1085  def testUpdateInputShapeError(self):
1086    g = ops.Graph()
1087    with g.as_default():
1088      w = constant_op.constant(2, shape=[3, 1])
1089      x = constant_op.constant(0, shape=[3, 1])
1090      y = constant_op.constant(1, shape=[2, 2])
1091      z = w + x
1092    with self.assertRaisesRegex(
1093        errors.InvalidArgumentError,
1094        r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
1095      z.op._update_input(0, y)  # pylint: disable=protected-access
1096
1097  def testUpdateInputOutOfRange(self):
1098    g = ops.Graph()
1099    with g.as_default():
1100      x = constant_op.constant(1)
1101    with self.assertRaisesRegex(
1102        errors.OutOfRangeError,
1103        r"Cannot update edge. Input index \[1\] is greater than the number of "
1104        r"total inputs \[0\]."):
1105      x.op._update_input(1, x)  # pylint: disable=protected-access
1106
1107  @test_util.enable_control_flow_v2
1108  @test_util.run_v1_only("b/120545219")
1109  def testAddWhileInput(self):
1110
1111    @eager_function.defun
1112    def test():
1113      output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
1114                                           [1])
1115      while_op = output.op
1116      self.assertEqual(while_op.type, "StatelessWhile")
1117      orig_num_inputs = len(while_op.inputs)
1118
1119      # Make sure we can handle the while op having a control input.
1120      while_op._add_control_input(constant_op.constant(0).op)
1121
1122      new_input1 = constant_op.constant(1.0)
1123      new_input2 = constant_op.constant(True)
1124
1125      # Clear output shapes to bypass shape checking.
1126      while_op._set_shape_list_attr("output_shapes", [])
1127      while_op._set_type_list_attr("T", [t.dtype for t in while_op.inputs] +
1128                                   [new_input1.dtype, new_input2.dtype])
1129
1130      while_op._add_while_inputs([new_input1, new_input2])
1131      # Can't add an edge beyond what's specified by "T"
1132      with self.assertRaises(errors.OutOfRangeError):
1133        while_op._add_while_inputs([new_input2])
1134      self.assertLen(while_op.inputs, orig_num_inputs + 2)  # pylint: disable=g-deprecated-assert
1135
1136      test()
1137
1138  @test_util.run_deprecated_v1
1139  def testOpDef(self):
1140    x = constant_op.constant(0)
1141    y = constant_op.constant(1)
1142    z = x + y
1143
1144    self.assertEqual(x.op.op_def.name, "Const")
1145    self.assertLen(x.op.op_def.input_arg, 0)
1146    self.assertLen(x.op.op_def.output_arg, 1)
1147
1148    self.assertRegex(z.op.op_def.name, "Add(V2)?")
1149    self.assertLen(z.op.op_def.input_arg, 2)
1150    self.assertLen(z.op.op_def.output_arg, 1)
1151
1152  def testInputFromDifferentGraphError(self):
1153    g_0 = ops.Graph()
1154    g_1 = ops.Graph()
1155    with g_0.as_default():
1156      x = constant_op.constant(1)
1157    with g_1.as_default():
1158      y = constant_op.constant(2)
1159      with self.assertRaisesRegex(ValueError, "must be from the same graph"):
1160        y * x  # pylint: disable=pointless-statement
1161
1162  def testInputsAreImmutable(self):
1163    g = ops.Graph()
1164    with g.as_default():
1165      x = test_ops.int_output()
1166      op = test_ops.int_input_int_output(x, name="myop").op
1167    with self.assertRaisesRegex(AttributeError,
1168                                "'tuple' object has no attribute 'append'"):
1169      op.inputs.append(None)
1170
1171
class CreateOpTest(test_util.TensorFlowTestCase):
  """Tests for Graph.create_op."""

  def testNodeDefArgs(self):
    """create_op records type, name, inputs, and device in each NodeDef."""
    g = ops.Graph()
    op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
    with g.device("/device:GPU:0"):
      op2 = g.create_op(
          "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None,
          name="myop2")
    op1_out = list(op1.values())[0]
    op2_float, op2_string = list(op2.values())
    # op2's outputs are passed in reverse order, so the NodeDef must list
    # 'myop2:1' before 'myop2'.
    op3 = g.create_op(
        "Foo3", [op1_out, op2_string, op2_float],
        [dtypes.float32, dtypes.int32], None, name="myop3")

    # Only op2 was created inside the device scope.
    expected_devices = ((None, op1), ("/device:GPU:0", op2), (None, op3))
    for device, op in expected_devices:
      self.assertDeviceEqual(device, op.device)

    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
        op2.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
        op3.node_def)

  def testReferenceInput(self):
    """A ref-typed output can feed both ref-typed and value-typed inputs."""
    g = ops.Graph()
    producer = g.create_op(
        "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           producer.node_def)
    ref_t, nonref_t = producer.values()

    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    ref_consumer = g.create_op(
        "RefInputFloatInput", [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        ref_consumer.node_def)

    value_consumer = g.create_op(
        "TwoFloatInputs", [ref_t, nonref_t], [], name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        value_consumer.node_def)

  def testFinalized(self):
    """A finalized graph rejects new ops until it is unfinalized."""
    g = ops.Graph()
    g.finalize()

    def _create():
      return g.create_op("FloatOutput", [], [dtypes.float32], None,
                         name="myop1")

    with self.assertRaises(RuntimeError):
      _create()

    # Test unfinalize.
    g._unsafe_unfinalize()  # pylint: disable=protected-access
    _create()
1227
1228
1229# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
1230# method. Arguably we should only test the public APIs that depend on this
1231# method. However, this logic is complex and tricky, and it can be difficult to
1232# ascertain if we have adequate coverage (e.g. a graph may run successfully if
1233# the control flow context isn't set properly, but a more complicated use case
1234# that might not be obvious to test will fail). Thus we instead explicitly test
1235# the low-level behavior.
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
  """Low-level tests for wrapping raw C ops into Python Operations.

  Each test creates a raw C op with ops._create_c_op. It then checks the
  Python Operation the graph builds for it, via _create_op_from_tf_operation
  or _add_new_tf_operations.
  """

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Wraps a raw C op and checks the Python-level Operation metadata."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c_op = ops._create_c_op(
          g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
      op = g._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "IntInputIntOutput")
    self.assertLen(op.outputs, 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape())
    self.assertEqual(list(op.inputs), [x])
    self.assertEqual(op.control_inputs, [])
    self.assertEqual(op.graph, g)
    self.assertEqual(x.consumers(), [op])
    self.assertIsNotNone(op.traceback)
    # The last traceback entry is expected to mention this test method.
    self.assertIn("testBasic", op.traceback[-1])
    self.assertEqual(g.get_operation_by_name("myop"), op)
    self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0])

  def testShape(self):
    """The wrapped Identity op's output carries its input's static shape."""
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], [])
      op = g._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "Identity")
    self.assertLen(op.outputs, 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.TensorShape([2, 3]))

  def testUniqueName(self):
    """Names taken by raw C ops are skipped by later name uniquification."""
    g = ops.Graph()
    with g.as_default():
      c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
      c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
      op = g._create_op_from_tf_operation(c_op)
      op2 = g._create_op_from_tf_operation(c_op2)

      # Create ops with same names as op1 and op2. We expect the new names to be
      # uniquified.
      op3 = test_ops.int_output(name="myop").op
      op4 = test_ops.int_output(name="myop_1").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op2.name, "myop_1")
    self.assertEqual(op3.name, "myop_2")
    self.assertEqual(op4.name, "myop_1_1")

  @test_util.run_v1_only("b/120545219")
  def testCond(self):
    """A raw op created inside a cond branch is given the cond's context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def true_fn():
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "cond/myop"), [x], [])
        # Register the raw C op with the Python graph; exactly one new op is
        # expected.
        new_ops = g._add_new_tf_operations()
        self.assertLen(new_ops, 1)
        return x

      control_flow_ops.cond(x < 10, true_fn, lambda: x)

    op = g.get_operation_by_name("cond/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "cond/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The external input x is expected to arrive through a Switch op.
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Switch")
    self.assertEqual(op_input.inputs[0], x)
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "cond/cond_text")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoop(self):
    """A raw op created in a while body is given the loop's context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        # Register the raw C op with the Python graph; exactly one new op is
        # expected.
        new_ops = g._add_new_tf_operations()
        self.assertLen(new_ops, 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "myloop/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The external input x is expected to arrive through an Enter op.
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Enter")
    self.assertEqual(list(op_input.inputs), [x])
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "myloop/while_context")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithInternalControlDep(self):
    """A control dep on an op created inside the loop body is kept as-is."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        c = constant_op.constant(1.0, name="c")
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        with ops.control_dependencies([c]):
          new_ops = g._add_new_tf_operations()
          self.assertLen(new_ops, 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    c = g.get_operation_by_name("myloop/c")
    self.assertIsNotNone(c)
    # Internal control dep is preserved
    self.assertEqual(op.control_inputs, [c])

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithExternalControlDep(self):
    """A control dep on an op outside the loop is not kept directly."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c = constant_op.constant(1.0)

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        with ops.control_dependencies([c]):
          new_ops = g._add_new_tf_operations()
          self.assertLen(new_ops, 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    # External control dep is removed and replaced with internal control dep
    self.assertNotEqual(op.control_inputs[0], c.op)
    self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
1396
1397
class ApplyOpTest(test_util.TensorFlowTestCase):
  """Tests for building ops through the module-level _apply_op helper."""

  def testNodeDefArgs(self):
    """_apply_op returns tensors and records names, inputs, and devices.

    Uses assertIsInstance rather than assertTrue(isinstance(...)) so a failure
    reports the actual type.
    """
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    with g.device("/device:GPU:0"):
      t2 = _apply_op(
          g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
    t3 = _apply_op(
        g,
        "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
        name="myop3")
    # Here the single-output op yields a bare Tensor, while the multi-output
    # ops yield lists of Tensors.
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, list)
    self.assertIsInstance(t3, list)
    self.assertIsInstance(t3[0], ops.Tensor)
    self.assertEqual("myop1", t1._as_node_def_input())
    self.assertEqual("myop2", t2[0]._as_node_def_input())
    self.assertEqual("myop2:1", t2[1]._as_node_def_input())
    self.assertEqual("myop3", t3[0]._as_node_def_input())
    # Validate that we got the right ops as well
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
        t2[0].op.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
        t3[0].op.node_def)

  def testReferenceInput(self):
    """A ref-typed output can feed both ref-typed and value-typed inputs."""
    g = ops.Graph()
    ref_t, nonref_t = _apply_op(
        g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           ref_t.op.node_def)
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    out_2 = _apply_op(
        g,
        "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
        out_2.op.node_def)
    out_3 = _apply_op(
        g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
        name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
        out_3.op.node_def)
1449
1450
class NameStackTest(test_util.TensorFlowTestCase):
  """Tests for Graph.unique_name and Graph.name_scope naming rules."""

  def testBasics(self):
    """unique_name suffixes repeated names and honors nested name scopes.

    Every call mutates the graph's name state, so the assertions below are
    order-dependent.
    """
    g = ops.Graph()
    # mark_as_used=False only previews a name; the next plain call claims it.
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_1", g.unique_name("foo"))
    self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_2", g.unique_name("foo"))
    self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_1", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_2", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo", g.unique_name("foo"))
      self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo_1", g.unique_name("foo"))
      # name_scope(None) drops back to the root scope.
      with g.name_scope(None):
        self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False))
        self.assertEqual("foo_3", g.unique_name("foo"))
      with g.name_scope("baz"):
        self.assertEqual(
            "bar/baz/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz/foo", g.unique_name("foo"))
        self.assertEqual(
            "bar/baz/foo_1", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz/foo_1", g.unique_name("foo"))
      # Re-entering an already used scope name yields a suffixed scope.
      with g.name_scope("baz"):
        self.assertEqual(
            "bar/baz_1/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz_1/foo", g.unique_name("foo"))
        self.assertEqual(
            "bar/baz_1/foo_1", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo"))
    with g.name_scope("quux"):
      self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("quux/foo", g.unique_name("foo"))
    with g.name_scope("bar"):
      with g.name_scope("baz"):
        self.assertEqual(
            "bar_1/baz/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar_1/baz/foo", g.unique_name("foo"))
    self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_4", g.unique_name("foo"))
    # Scope names share the namespace checked by unique_name.
    self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False))
    self.assertEqual("bar_2", g.unique_name("bar"))

  def testBackslashAndDashRegex(self):
    """Scope names containing backslashes and dashes are accepted."""
    # GitHub issue 39019, all should pass
    g = ops.Graph()
    with g.name_scope("n_CatCntc-campaign\\c_campaign"):
      pass
    with g.name_scope("foo"):
      with g.name_scope("n_CatCntc-campaign\\c_campaign"):
        pass
    with g.name_scope("n_CatCntc-campaign\\c_campaign"):
      with g.name_scope("foo"):
        pass

  @test_util.run_deprecated_v1
  def testNameAndVariableScope(self):
    """Graph name scopes nest under the scope opened by a variable_scope."""
    with self.cached_session() as sess:
      with sess.graph.name_scope("l0"):
        with variable_scope.variable_scope("l1"):
          with sess.graph.name_scope("l1") as scope:
            self.assertEqual("l0/l1/l1/", scope)
            self.assertEqual(
                "l0/l1/l1/foo",
                sess.graph.unique_name(
                    "foo", mark_as_used=False))
            self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo"))
          with sess.graph.name_scope("l2") as scope:
            self.assertEqual("l0/l1/l2/", scope)
            self.assertEqual(
                "l0/l1/l2/foo",
                sess.graph.unique_name(
                    "foo", mark_as_used=False))
            self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo"))

  def testOutOfOrderUniqueName(self):
    """Claiming a suffixed name first makes later uniquification skip it."""
    g = ops.Graph()
    self.assertEqual("foo_2", g.unique_name("foo_2"))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo"))
    self.assertEqual("foo_3", g.unique_name("foo"))

  def testUniqueNameCaseInsensitivity(self):
    """Names differing only in letter case are treated as clashes."""
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("Foo_1", g.unique_name("Foo"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo"))
    with g.name_scope("Bar"):
      self.assertEqual("Bar_1/foo", g.unique_name("foo"))

  def testInvalidNameRaisesError(self):
    """Checks which name_scope arguments are accepted and which are rejected."""
    g = ops.Graph()
    with g.name_scope(""):  # Should not raise
      pass
    with g.name_scope("foo/"):  # Should not raise
      with g.name_scope("_bar"):  # Should not raise
        pass
    with self.assertRaises(ValueError):
      with g.name_scope("foo:0"):
        pass
    # A leading underscore is accepted for a nested scope (above) but rejected
    # at the root scope.
    with self.assertRaises(ValueError):
      with g.name_scope("_bar"):
        pass

  def testEmptyScopeEdgeCases(self):
    """An empty or None scope resets to the root, even inside another scope."""
    g = ops.Graph()
    self.assertEqual("", g.get_name_scope())
    with g.name_scope("") as scope:
      self.assertEqual("", scope)
      self.assertEqual("", g.get_name_scope())
    with g.name_scope(None) as scope:
      self.assertEqual("", scope)
      self.assertEqual("", g.get_name_scope())
    with g.name_scope("foo") as scope:
      self.assertEqual("foo/", scope)
      self.assertEqual("foo", g.get_name_scope())
      with g.name_scope("") as scope:
        self.assertEqual("", scope)
        self.assertEqual("", g.get_name_scope())
      with g.name_scope(None) as scope:
        self.assertEqual("", scope)
        self.assertEqual("", g.get_name_scope())
1588
1589
class NameTest(test_util.TensorFlowTestCase):
  """Tests automatic op naming and explicit name_scope() handling."""

  def testGenerateName(self):
    graph = ops.Graph()

    def float_op(**kwargs):
      # Single-output "FloatOutput" op; kwargs are forwarded (e.g. name=).
      return graph.create_op("FloatOutput", [], [dtypes.float32], **kwargs)

    pair = graph.create_op(
        "TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    self.assertEqual("TwoFloatOutputs", pair.name)
    # Output tensors are named "<op name>:<output index>".
    self.assertEqual("TwoFloatOutputs:0", pair.outputs[0].name)
    self.assertEqual("TwoFloatOutputs:1", pair.outputs[1].name)

    first = float_op()
    self.assertEqual("FloatOutput", first.name)
    self.assertEqual("FloatOutput:0", first.outputs[0].name)

    # A second unnamed op of the same type gets a numeric suffix.
    second = float_op()
    self.assertEqual("FloatOutput_1", second.name)
    self.assertEqual("FloatOutput_1:0", second.outputs[0].name)

    # An explicit name is used verbatim.
    named = float_op(name="my_op")
    self.assertEqual("my_op", named.name)
    self.assertEqual("my_op:0", named.outputs[0].name)

  def testNameScope(self):
    graph = ops.Graph()

    def float_op_name(**kwargs):
      # Creates a "FloatOutput" op and returns only its resulting name.
      return graph.create_op(
          "FloatOutput", [], [dtypes.float32], **kwargs).name

    with graph.name_scope("foo") as outer:
      self.assertEqual("foo/", outer)
      with graph.name_scope("foo2") as nested:
        self.assertEqual("foo/foo2/", nested)
      # A None scope resets to the root, so scopes opened inside it start
      # from the top level again.
      with graph.name_scope(None) as reset_none:
        self.assertEqual("", reset_none)
        with graph.name_scope("foo3") as from_root:
          self.assertEqual("foo3/", from_root)
      with graph.name_scope("") as reset_empty:
        self.assertEqual("", reset_empty)

    self.assertEqual("FloatOutput", float_op_name())
    with graph.name_scope("bar") as bar_scope:
      self.assertEqual("bar/FloatOutput", float_op_name())
      self.assertEqual("bar/FloatOutput_1", float_op_name())
      # If you use the value from "with .. as", that value is used as-is
      # (without its trailing slash).
      self.assertEqual("bar", float_op_name(name=bar_scope))
    with graph.name_scope("baz") as baz_scope:
      with graph.name_scope("quux"):
        self.assertEqual("baz/quux/FloatOutput", float_op_name())
      # If you use the value from the enclosing "with .. as", nothing is pushed.
      with graph.name_scope(baz_scope):
        self.assertEqual("baz/FloatOutput", float_op_name())
        self.assertEqual("baz", float_op_name(name=baz_scope))
        self.assertEqual("trailing", float_op_name(name="trailing/"))
    # Re-opening "bar" by plain name yields a fresh, uniquified scope...
    with graph.name_scope("bar"):
      self.assertEqual("bar_1/FloatOutput", float_op_name())
    # ...whereas "bar/" re-enters the existing "bar" scope.
    with graph.name_scope("bar/"):
      self.assertEqual("bar/FloatOutput_2", float_op_name())
1657
1658
class DeviceTest(test_util.TensorFlowTestCase):
  """Tests device placement through Graph.device() and ops.device() scopes."""

  def testNoDevice(self):
    g = ops.Graph()
    op = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertDeviceEqual(None, op.device)
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput" }
    """, gd)

  def testEagerBackingDevice(self):
    with context.eager_mode():
      with ops.device("/device:CPU:0"):
        t = constant_op.constant(1.0)
        self.assertRegex(t.device, "/device:CPU:0")
        self.assertRegex(t.backing_device, "/device:CPU:0")

  def testDevicePartialString(self):
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testDeviceFull(self):
    g = ops.Graph()
    with g.device(
        pydev.DeviceSpec(
            job="worker", replica=2, task=0, device_type="CPU",
            device_index=3)):
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/task:0/device:CPU:3" }
    """, gd)

  def testNesting(self):
    # Same placement scenario as testNestingString, but the scopes are given
    # as DeviceSpec objects instead of strings. Both forms must produce the
    # same placements.
    g = ops.Graph()
    with g.device(pydev.DeviceSpec(job="worker", replica=2)):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.DeviceSpec(job="worker", replica=3, task=0)):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:3/task:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testNestingString(self):
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:3/task:0"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:3/task:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testNestingOverrideGpuCpu(self):
    g = ops.Graph()
    with g.device("/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:2/device:GPU:2"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1"  }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:2/device:GPU:2" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def testNestingWithMergeDeviceFunction(self):
    g = ops.Graph()

    with g.device(pydev.merge_device("/device:GPU:0")):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.merge_device("/job:worker")):
        g.create_op("FloatOutput", [], [dtypes.float32])
        with g.device(pydev.merge_device("/device:CPU:0")):
          g.create_op("FloatOutput", [], [dtypes.float32])
          with g.device(pydev.merge_device("/job:ps")):
            g.create_op("FloatOutput", [], [dtypes.float32])
            with g.device(pydev.merge_device(None)):
              g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/device:GPU:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/device:CPU:0" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
    """, gd)

  def testNestingWithDeviceStrings(self):
    g = ops.Graph()

    with g.device("/device:GPU:0"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker"):
        g.create_op("FloatOutput", [], [dtypes.float32])
        with g.device("/device:CPU:0"):
          g.create_op("FloatOutput", [], [dtypes.float32])
          with g.device("/job:ps"):
            g.create_op("FloatOutput", [], [dtypes.float32])
            with g.device(""):
              g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/device:GPU:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/device:CPU:0" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
    """, gd)

  def testNestingWithDeviceStringWildcard(self):
    g = ops.Graph()

    with g.device("/device:GPU:7"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/device:GPU:*"):
        g.create_op("FloatOutput", [], [dtypes.float32])

    with g.device("/device:CPU:*"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/device:CPU:5"):
        g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:7" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/device:GPU:7" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/device:CPU:*" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/device:CPU:5" }
    """, gd)

  def testNestingErrorGraph(self):
    g = ops.Graph()
    scope = g.device("/device:GPU:8")
    scope.__enter__()
    # Exiting the outer scope while an inner scope is still active is
    # improper nesting and must raise. The graph is local to this test, so
    # the half-exited scope does not affect anything else.
    with g.device("/device:GPU:9"):
      with self.assertRaises(RuntimeError):
        scope.__exit__(None, None, None)

  def testNestingErrorEager(self):
    with context.eager_mode():
      scope = ops.device("/device:CPU:0")
      scope.__enter__()
      with ops.device(None):
        with self.assertRaises(RuntimeError):
          scope.__exit__(None, None, None)
      # The rejected __exit__ above left `scope` entered. The inner scope has
      # now closed, so nesting is proper again. Exit `scope` for real so this
      # test does not leave a device scope open behind it.
      scope.__exit__(None, None, None)

  def testNoneClearsDefault(self):
    g = ops.Graph()
    with g.device("/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
      node { name: "FloatOutput_1" op: "FloatOutput" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def testNoneIgnoresOuterDeviceFunction(self):
    g = ops.Graph()
    with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
      node { name: "FloatOutput_1" op: "FloatOutput" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def _overwritingDeviceFunction(self, unused_op):
    # This device function unconditionally overwrites the device of ops.
    #
    # NOTE(mrry): Writing device functions like this is not
    # recommended. Instead, in most cases you should use
    # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the
    # argument to `tf.device()` and the device component will be merged in.
    return "/job:overwrite"

  def testOverwritingBehavior(self):
    g = ops.Graph()
    with g.device(self._overwritingDeviceFunction):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:ps"):  # Will be overwritten.
        g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.merge_device("/job:ps")):  # Will be overwritten.
        g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):  # Disables overwriting device function
        with g.device("/job:ps"):
          g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):  # Disables overwriting device function
        with g.device(pydev.merge_device("/job:ps")):
          g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps" }
    """, gd)
1916
1917
class MultithreadedGraphStateTest(test_util.TensorFlowTestCase):
  """Tests how graph stacks (device, colocation, deps, names) behave per thread."""

  class TestThread(threading.Thread):
    """Thread that mutates a graph stack, pauses, then adds an op.

    Subclasses implement `run`. The test drives each thread through two
    hand-offs: it waits on `has_mutated_graph`, then later sets
    `should_continue`.
    """

    def __init__(self, graph, replica_id):
      super(MultithreadedGraphStateTest.TestThread, self).__init__()
      # The graph this thread mutates. Subclasses read it from here rather
      # than closing over a variable of the enclosing test method.
      self._graph = graph
      self._replica_id = replica_id
      # This thread sets this event when it mutated the graph.  The caller can
      # wait for that.
      self.has_mutated_graph = threading.Event()
      # This thread waits for when it should continue.  The caller can set this
      # event.
      self.should_continue = threading.Event()

    def run(self):
      # Mutate a graph's stack, then set `has_mutated_graph`, then wait for
      # `should_continue`, then add an op to the graph affected by the graph's
      # stack.
      raise NotImplementedError("must be implemented in descendants")

  def testDeviceFunctionStack(self):

    class DeviceSettingThread(self.TestThread):

      def run(self):
        g = self._graph
        with g.device("/job:worker/replica:{}".format(self._replica_id)):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    # If `switch_to_thread_local` isn't called, then device placement of the
    # ops below is not deterministic.
    g.switch_to_thread_local()
    threads = [DeviceSettingThread(g, i) for i in range(3)]
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()
    for t in threads:
      t.should_continue.set()
      t.join()

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput_0" op: "FloatOutput"
             device: "/job:worker/replica:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:1" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testColocateWith(self):

    class ColocatingThread(self.TestThread):

      def __init__(self, graph, replica_id, op_to_colocate_with):
        super(ColocatingThread, self).__init__(graph, replica_id)
        self._op_to_colocate_with = op_to_colocate_with

      def run(self):
        g = self._graph
        with g.colocate_with(self._op_to_colocate_with):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    ops_to_colocate_with = []
    for i in range(3):
      with g.device("/job:worker/replica:{}".format(i)):
        ops_to_colocate_with.append(
            g.create_op(
                "FloatOutput", [], [dtypes.float32],
                name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local` isn't called, then `device` and `attr`
    # values for the ops below are not deterministic.
    g.switch_to_thread_local()
    threads = [
        ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3)
    ]
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()
    for t in threads:
      t.should_continue.set()
      t.join()

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "ColocateWithMe_0" op: "FloatOutput"
             device: "/job:worker/replica:0" }
      node { name: "ColocateWithMe_1" op: "FloatOutput"
             device: "/job:worker/replica:1" }
      node { name: "ColocateWithMe_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_0" op: "FloatOutput"
             device: "/job:worker/replica:0"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_0"}}}}
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:1"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_1"}}}}
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_2"}}}}
    """, gd)

  def testControlDependencies(self):

    class DependingThread(self.TestThread):

      def __init__(self, graph, replica_id, dependency_op):
        super(DependingThread, self).__init__(graph, replica_id)
        self._dependency_op = dependency_op

      def run(self):
        g = self._graph
        with g.control_dependencies([self._dependency_op]):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    dependency_ops = []
    for i in range(3):
      dependency_ops.append(
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local` isn't called, then `input` values for the
    # ops below are not deterministic.
    g.switch_to_thread_local()
    threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)]
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()
    for t in threads:
      t.should_continue.set()
      t.join()

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion(
        """
      node { name: "ColocateWithMe_0" op: "FloatOutput"
             attr { key: "_has_manual_control_dependencies"
                    value { b: true } } }
      node { name: "ColocateWithMe_1" op: "FloatOutput"
             attr { key: "_has_manual_control_dependencies"
                    value { b: true } } }
      node { name: "ColocateWithMe_2" op: "FloatOutput"
             attr { key: "_has_manual_control_dependencies"
                    value { b: true } } }
      node { name: "FloatOutput_0" op: "FloatOutput"
             input: "^ColocateWithMe_0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             input: "^ColocateWithMe_1" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             input: "^ColocateWithMe_2" }
    """, gd)

  def testNameStack(self):

    class NameSettingThread(self.TestThread):

      def run(self):
        g = self._graph
        with g.name_scope("foo"):
          op1 = g.create_op("FloatOutput", [], [dtypes.float32])
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          op2 = g.create_op("FloatOutput", [], [dtypes.float32])
          self.result = (op1, op2)

    g = ops.Graph()
    threads = [NameSettingThread(g, i) for i in range(3)]
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()

    for t in threads:
      t.should_continue.set()
      t.join()

    # Each thread opened its own "foo" scope, so the scopes are uniquified
    # across threads in start order.
    suffixes = ["", "_1", "_2"]
    for t, s in zip(threads, suffixes):
      self.assertEqual("foo" + s + "/FloatOutput", t.result[0].name)
      self.assertEqual("foo" + s + "/FloatOutput_1", t.result[1].name)
2125
2126
class ObjectWithName(object):
  """Minimal collection item that exposes only a read-only `name`."""

  def __init__(self, name):
    self._name = name

  # Getter-only property: there is no setter, so assigning `.name` fails.
  name = property(lambda self: self._name)
2135
2136
class CollectionTest(test_util.TensorFlowTestCase):
  """Tests Graph collection APIs: add/get/get_ref and the default graph."""

  def test_get_collections(self):
    g = ops.Graph()
    self.assertSequenceEqual(g.collections, [])
    g.add_to_collection("key", 12)
    g.add_to_collection("key", 15)
    # Two values under one key still list that key only once.
    self.assertSequenceEqual(g.collections, ["key"])
    g.add_to_collection("other", "foo")
    self.assertSequenceEqual(sorted(g.collections), ["key", "other"])
    self.assertSequenceEqual(
        sorted(g.get_all_collection_keys()), ["key", "other"])

  def test_add_to_collection(self):
    g = ops.Graph()
    g.add_to_collection("key", 12)
    g.add_to_collection("other", "foo")
    g.add_to_collection("key", 34)

    g.add_to_collection("blah", 27)
    blank1 = ObjectWithName("prefix/foo")
    g.add_to_collection("blah", blank1)
    blank2 = ObjectWithName("junk/foo")
    g.add_to_collection("blah", blank2)

    self.assertEqual([12, 34], g.get_collection("key"))
    self.assertEqual([], g.get_collection("nothing"))
    self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
    # With a scope argument, only items whose `name` matches the scope pattern
    # are returned. 27 has no `name` and "junk/foo" does not match, so only
    # blank1 is returned.
    self.assertEqual([blank1], g.get_collection("blah", "prefix"))
    self.assertEqual([blank1], g.get_collection("blah", ".*x"))

    # Make sure that get_collection() returns a first-level
    # copy of the collection, while get_collection_ref() returns
    # the original list.
    other_collection_snapshot = g.get_collection("other")
    other_collection_ref = g.get_collection_ref("other")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo"], other_collection_ref)
    g.add_to_collection("other", "bar")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo", "bar"], other_collection_ref)
    self.assertEqual(["foo", "bar"], g.get_collection("other"))
    self.assertIs(other_collection_ref, g.get_collection_ref("other"))

    # Verify that getting an empty collection ref returns a modifiable list.
    empty_coll_ref = g.get_collection_ref("empty")
    self.assertEqual([], empty_coll_ref)
    empty_coll = g.get_collection("empty")
    self.assertEqual([], empty_coll)
    self.assertIsNot(empty_coll, empty_coll_ref)
    empty_coll_ref2 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref2, empty_coll_ref)
    # Add to the collection.
    empty_coll_ref.append("something")
    self.assertEqual(["something"], empty_coll_ref)
    self.assertEqual(["something"], empty_coll_ref2)
    self.assertEqual([], empty_coll)
    self.assertEqual(["something"], g.get_collection("empty"))
    empty_coll_ref3 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref3, empty_coll_ref)

  def test_add_to_collections_uniquify(self):
    g = ops.Graph()
    g.add_to_collections([1, 2, 1], "key")
    # Make sure "key" is not added twice
    self.assertEqual(["key"], g.get_collection(1))

  def test_add_to_collections_from_list(self):
    g = ops.Graph()
    g.add_to_collections(["abc", "123"], "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_tuple(self):
    g = ops.Graph()
    g.add_to_collections(("abc", "123"), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_generator(self):
    g = ops.Graph()

    def generator():
      yield "abc"
      yield "123"

    g.add_to_collections(generator(), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_set(self):
    g = ops.Graph()
    g.add_to_collections({"abc", "123"}, "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_string(self):
    g = ops.Graph()
    # A bare string is treated as a single collection name, not iterated
    # character by character.
    g.add_to_collections("abc", "key")
    self.assertEqual(["key"], g.get_collection("abc"))

  def test_default_graph(self):
    with ops.Graph().as_default():
      ops.add_to_collection("key", 90)
      ops.add_to_collection("key", 100)
      # Collections are ordered.
      self.assertEqual([90, 100], ops.get_collection("key"))

  def test_defun(self):
    with context.eager_mode():

      @eager_function.defun
      def defun():
        ops.add_to_collection("int", 1)
        ops.add_to_collection("tensor", constant_op.constant(2))

        @eager_function.defun
        def inner_defun():
          # The inner function sees the outer function's collections...
          self.assertEqual(ops.get_collection("int"), [1])
          three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
          ops.add_to_collection("int", 2)
          self.assertEqual(ops.get_collection("int"), [1, 2])
          ops.add_to_collection("foo", "bar")
          self.assertEqual(ops.get_collection("foo"), ["bar"])
          return three

        self.assertEqual(ops.get_collection("int"), [1])
        three = inner_defun()
        # ...but its additions are not visible back in the outer function.
        self.assertEqual(ops.get_collection("int"), [1])
        self.assertEqual(ops.get_collection("foo"), [])
        return three

      three = defun()
      self.assertEqual(three.numpy(), 3)
2272
2273
# Module-level registration: declare the test-only "FloatOutput" op
# non-differentiable. This runs once when the module is imported, before any
# test in this file executes.
ops.NotDifferentiable("FloatOutput")
2275
2276
@ops.RegisterGradient("CopyOp")
def _CopyGrad(op, x_grad):  # pylint: disable=invalid-name
  """Gradient registered for "CopyOp": returns the incoming gradient as-is."""
  del op  # Unused; the result does not depend on the forward op.
  return x_grad
2281
2282
@ops.RegisterGradient("copy_override")
def _CopyOverrideGrad(op, x_grad):  # pylint: disable=invalid-name
  """Gradient registered as "copy_override": returns the incoming gradient."""
  del op  # Unused; the result does not depend on the forward op.
  return x_grad
2287
2288
class RegistrationTest(test_util.TensorFlowTestCase):
  """Tests gradient-function lookup and gradient_override_map() routing."""

  @test_util.run_deprecated_v1
  def testRegisterGradients(self):
    source = test_ops.float_output()
    copied = test_ops.copy_op(source)
    # Without an override, CopyOp resolves to its registered gradient.
    self.assertEqual(_CopyGrad, ops.get_gradient_function(copied.op))

  def testOverrideGradients(self):
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.float_output()
      # Ops created under the map use the "copy_override" registration.
      with graph.gradient_override_map({"CopyOp": "copy_override"}):
        copied = test_ops.copy_op(source)
      self.assertEqual(_CopyOverrideGrad,
                       ops.get_gradient_function(copied.op))

  def testNonExistentOverride(self):
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.float_output()
      # The override target was never registered, so the lookup must fail
      # and name the missing registration.
      with graph.gradient_override_map({"CopyOp": "unknown_override"}):
        copied = test_ops.copy_op(source)
      with self.assertRaisesRegex(LookupError, "unknown_override"):
        ops.get_gradient_function(copied.op)
2315
2316
class ComparisonTest(test_util.TensorFlowTestCase):
  """Tests that graph tensors work with Python container membership checks."""

  def testMembershipAllowed(self):
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2")
    # assertIsInstance/assertIn/assertNotIn report the offending values on
    # failure, unlike assertTrue(<bool expression>).
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, ops.Tensor)
    # `in` / `not in` must be usable on lists of graph tensors.
    self.assertIn(t1, [t1])
    self.assertNotIn(t1, [t2])
2327
2328
class ControlDependenciesTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.control_dependencies` and `ops.control_dependencies`."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Ops built inside a control-deps scope get the scope's control inputs."""
    g = ops.Graph()
    with g.as_default():
      # Creating unregistered ops with _apply_op() doesn't work with the C API
      # TODO(skyewm): address this more consistently. Possible solutions are
      # to use registered ops in all tests, create a way to register ops in
      # Python tests, or conditionally disable the op registration check in
      # the C API.
      a = constant_op.constant(1.0)
      b = constant_op.constant(1.0)
      with g.control_dependencies([a]):
        c = constant_op.constant(1.0)
        d = array_ops.identity(b)
        e = array_ops.identity(c)

    self.assertEqual(c.op.control_inputs, [a.op])
    self.assertEqual(d.op.control_inputs, [a.op])
    # e should be dominated by c.
    # (c already carries the control edge on a, and e reads c, so no
    # redundant control input is added to e.)
    self.assertEqual(e.op.control_inputs, [])

  @test_util.run_in_graph_and_eager_modes
  def testEager(self):
    """A callable control dependency is invoked exactly once.

    In eager mode the callable itself is passed to `control_dependencies`.
    In graph mode it is called explicitly, and its result becomes a control
    input alongside `a`.
    """
    def future():
      future.calls += 1
      return constant_op.constant(2.0)
    future.calls = 0

    if context.executing_eagerly():
      a = constant_op.constant(1.0)
      # b is the function object, not a tensor.
      b = future
      with ops.control_dependencies([a, b]):
        c = constant_op.constant(3.0)
      self.assertEqual(future.calls, 1)
    else:
      g = ops.Graph()
      with g.as_default():
        a = constant_op.constant(1.0)
        b = future()
        with g.control_dependencies([a, b]):
          c = constant_op.constant(3.0)
      self.assertEqual(c.op.control_inputs, [a.op, b.op])
      self.assertEqual(future.calls, 1)

  def testBasicWithConversion(self):
    """Objects exposing `_as_graph_element` are accepted as dependencies."""
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    class ConvertibleObj(object):

      def _as_graph_element(self):
        return a

    with g.control_dependencies([ConvertibleObj()]):
      c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(c.op.control_inputs, [a.op])

  def testNested(self):
    """Nested scopes accumulate the same deps as one flat scope."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1, a_2, a_3, a_4]):
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies([a_3]):
          with g.control_dependencies([a_4]):
            b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
                          b_1.op.control_inputs)
    self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)

  def testClear(self):
    """`control_dependencies(None)` hides outer deps until its scope exits."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies(None):
          with g.control_dependencies([a_3]):
            with g.control_dependencies([a_4]):
              # deps [a_3, a_4]
              b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps = [a_3]
            b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to None
          b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1, a_2]
        b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      # deps back to [a_1]
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies(None):
        # deps are None again
        b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testComplex(self):
    """Control inputs already implied by data inputs are not added again."""
    g = ops.Graph()

    # Usage pattern:
    # * Nodes a_i are constants defined at the outermost scope, and are used
    #   as control inputs for the ith nested scope.
    # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
    # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
    # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
    # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.

    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                      [dtypes.float32])
      c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                      [dtypes.float32])
      d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
                      [dtypes.float32])
      e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_2]):
        b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                        [dtypes.float32])
        c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                        [dtypes.float32])
        d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
                        [dtypes.float32])
        e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
                        [dtypes.float32])
        with g.control_dependencies([a_3]):
          b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                          [dtypes.float32])
          c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                          [dtypes.float32])
          d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
                          [dtypes.float32])
          e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
                          [dtypes.float32])
          with g.control_dependencies([a_4]):
            b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                            [dtypes.float32])
            c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                            [dtypes.float32])
            d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
                            [dtypes.float32])
            e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
                            [dtypes.float32])

    # b_i reads a_3/a_4 as data, so those never appear as control inputs.
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)

    # c_i reads a_1 and b_1 (which already depends on a_1), so a_1 is dropped.
    self.assertItemsEqual([], c_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)

    # d_i's inputs were built in the same scope, so they carry every dep.
    self.assertItemsEqual([], d_1.op.control_inputs)
    self.assertItemsEqual([], d_2.op.control_inputs)
    self.assertItemsEqual([], d_3.op.control_inputs)
    self.assertItemsEqual([], d_4.op.control_inputs)

    # e_i only needs the dep added by its own scope; outer ones come via e_i-1.
    self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
    self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
    self.assertItemsEqual([a_4.op], e_4.op.control_inputs)

  def testRepeatedDependency(self):
    """Two outputs of one op yield a single control input on that op."""
    g = ops.Graph()
    a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    a_0, a_1 = a.outputs
    with g.control_dependencies([a_0]):
      b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_1]):
        c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [a])
    self.assertEqual(c.op.control_inputs, [a])

  def testNoControlDependencyWithDataDependency(self):
    """A control dep on an op that is already a data input is dropped."""
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    with g.control_dependencies([a]):
      b = _apply_op(g, "Identity", [a], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [])

  def testMonitoringAttributeAddedWhenUsingManualControlDep(self):
    """`_has_manual_control_dependencies` is set on the dependency op.

    The op created under the scope does not get the attr. The dependency op
    gets it even when its control edge was pruned because of a data edge.
    """
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    with g.control_dependencies([a]):
      c = _apply_op(g, "Identity", [b], [dtypes.float32])

    with g.control_dependencies([b]):
      d = _apply_op(g, "Identity", [b], [dtypes.float32])

    # Validate that the monitoring attribute is set to track usage of the
    # `control_dependencies(...)` API.
    self.assertEqual(c.op.control_inputs, [a.op])
    with self.assertRaises(ValueError):
      c.op.get_attr("_has_manual_control_dependencies")
    self.assertEqual(a.op.get_attr("_has_manual_control_dependencies"), True)

    # Validate that the monitoring attribute is set to track usage of the
    # `control_dependencies(...)` API even when the manual control deps actually
    # happened to be pruned at runtime.
    self.assertEqual(d.op.control_inputs, [])
    with self.assertRaises(ValueError):
      d.op.get_attr("_has_manual_control_dependencies")
    self.assertEqual(b.op.get_attr("_has_manual_control_dependencies"), True)
2558
2559
class OpScopeTest(test_util.TensorFlowTestCase):
  """Tests for `ops.name_scope`, `ops.name_scope_v2` and related helpers."""

  @test_util.run_in_graph_and_eager_modes
  def testNames(self):
    """Checks how relative, None, empty and slash-terminated names compose."""
    with ops.name_scope("foo", skip_on_eager=False) as foo:
      self.assertEqual("foo/", foo)
      with ops.name_scope("foo2", skip_on_eager=False) as foo2:
        self.assertEqual("foo/foo2/", foo2)
      # None resets to the root scope; children are then relative to the root.
      with ops.name_scope(None, skip_on_eager=False) as empty1:
        self.assertEqual("", empty1)
        with ops.name_scope("foo3", skip_on_eager=False) as foo3:
          self.assertEqual("foo3/", foo3)
      with ops.name_scope("", skip_on_eager=False) as empty2:
        self.assertEqual("", empty2)
    # A name ending in "/" is used verbatim rather than appended to the
    # current scope.
    with ops.name_scope("foo/", skip_on_eager=False) as outer_foo:
      self.assertEqual("foo/", outer_foo)
      with ops.name_scope("", skip_on_eager=False) as empty3:
        self.assertEqual("", empty3)
      with ops.name_scope("foo4", skip_on_eager=False) as foo4:
        self.assertEqual("foo/foo4/", foo4)
      with ops.name_scope("foo5//", skip_on_eager=False) as foo5:
        self.assertEqual("foo5//", foo5)
        with ops.name_scope("foo6", skip_on_eager=False) as foo6:
          self.assertEqual("foo5//foo6/", foo6)
      with ops.name_scope("/", skip_on_eager=False) as foo7:
        self.assertEqual("/", foo7)
      with ops.name_scope("//", skip_on_eager=False) as foo8:
        self.assertEqual("//", foo8)
      # Internal "//" without a trailing slash is still a relative name.
      with ops.name_scope("a//b/c", skip_on_eager=False) as foo9:
        self.assertEqual("foo/a//b/c/", foo9)
    with ops.name_scope("a//b/c", skip_on_eager=False) as foo10:
      self.assertEqual("a//b/c/", foo10)

  @test_util.run_in_graph_and_eager_modes
  def testEagerDefaultScopeName(self):
    """With name=None, `default_name` is used and nests like a normal name."""
    with ops.name_scope(None, "default", skip_on_eager=False) as scope:
      self.assertEqual(scope, "default/")
      with ops.name_scope(None, "default2", skip_on_eager=False) as scope2:
        self.assertEqual(scope2, "default/default2/")

  @test_util.run_in_graph_and_eager_modes
  def testNameScopeV2IsReEntrant(self):
    """One `name_scope_v2` object can be entered repeatedly and nested."""
    foo = ops.name_scope_v2("foo")
    bar = ops.name_scope_v2("bar")
    with foo as scope_name:
      self.assertEqual("foo/", scope_name)
      with foo as scope_name:
        self.assertEqual("foo/foo/", scope_name)
      with bar as scope_name:
        self.assertEqual("foo/bar/", scope_name)
        with foo as scope_name:
          self.assertEqual("foo/bar/foo/", scope_name)
    with bar as scope_name:
      self.assertEqual("bar/", scope_name)

  @test_util.run_deprecated_v1
  def testNoScopeName(self):
    """name=None with `values` but no `default_name` raises ValueError."""
    g0 = ops.Graph()
    values = [
        g0.create_op("A", [], [dtypes.float32]),
        g0.create_op("B", [], [dtypes.float32])
    ]
    with self.assertRaises(ValueError):
      with ops.name_scope(None, values=values):
        pass
    with self.assertRaises(ValueError):
      with ops.name_scope(None, None, values):
        pass

  @test_util.run_deprecated_v1
  def testEmptyScopeName(self):
    """An empty name yields the root scope and enters the values' graph."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    with ops.name_scope("", values=[a, b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())
    with ops.name_scope("", "my_default_scope", [a, b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())

  @test_util.run_deprecated_v1
  def testDefaultScopeName(self):
    """`default_name` is used only when `name` is None.

    Passing the values list positionally as `default_name` is a TypeError.
    """
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    scope_name = "my_scope"
    default_scope_name = "my_default_scope"
    with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope:
      self.assertEqual("%s/" % scope_name, scope)
      self.assertEqual(g0, ops.get_default_graph())
    with ops.name_scope(None, default_scope_name, [a, b]) as scope:
      self.assertEqual("%s/" % default_scope_name, scope)
      self.assertEqual(g0, ops.get_default_graph())
    with self.assertRaises(TypeError):
      with ops.name_scope(scope_name, [a, b]):
        pass

  def _testGraphElements(self, graph_elements):
    """Shared check for graph elements passed as `values` to `name_scope`.

    Asserts that name_scope enters the graph owning `graph_elements`. Also
    asserts that it raises ValueError when an element from another graph is
    mixed in.

    Args:
      graph_elements: list of graph elements. The first element's `.graph`
        is used as the expected default graph.
    """
    scope_name = "my_scope"
    with ops.name_scope(scope_name, values=graph_elements) as scope:
      self.assertEqual("%s/" % scope_name, scope)
      self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
    g1 = ops.Graph()
    a = g1.create_op("A", [], [dtypes.float32])
    with self.assertRaises(ValueError):
      with ops.name_scope(scope_name, values=graph_elements + [a]):
        pass

  @test_util.run_in_graph_and_eager_modes
  def testGetCurrentNameScope(self):
    """`get_current_name_scope` returns the scope without a trailing slash."""
    self.assertEqual(ops.get_current_name_scope(), "")
    with ops.name_scope_v2("aaa"):
      self.assertEqual(ops.get_current_name_scope(), "aaa")
      with ops.name_scope_v2("bbb"):
        self.assertEqual(ops.get_current_name_scope(), "aaa/bbb")
      self.assertEqual(ops.get_current_name_scope(), "aaa")
    self.assertEqual(ops.get_current_name_scope(), "")

  @test_util.run_deprecated_v1
  def testTensor(self):
    """Plain ops created by `create_op` are accepted as `values`."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([a, b])

  @test_util.run_deprecated_v1
  def testSparseTensor(self):
    """A SparseTensor mixed with ops is accepted as a `values` element."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    sparse = sparse_tensor.SparseTensor(
        _apply_op(g0, "Int64Output", [], [dtypes.int64]),
        _apply_op(g0, "FloatOutput", [], [dtypes.float32]),
        _apply_op(g0, "Int64Output", [], [dtypes.int64]))
    self._testGraphElements([a, sparse, b])

  @test_util.run_deprecated_v1
  def testVariable(self):
    """A Variable mixed with ops is accepted as a `values` element."""
    g0 = ops.Graph()
    with g0.as_default():
      variable = variables.Variable([1.0])
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([a, variable, b])
2705
2706
class InitScopeTest(test_util.TensorFlowTestCase):
  """Tests for `ops.init_scope`."""

  def testClearsControlDependencies(self):
    """init_scope hides outer control deps, like control_dependencies(None)."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.as_default():
      with g.control_dependencies([a_1]):
        with g.control_dependencies([a_2]):
          with ops.init_scope():
            with g.control_dependencies([a_3]):
              with g.control_dependencies([a_4]):
                # deps [a_3, a_4]
                b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
              # deps = [a_3]
              b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps back to None
            b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to [a_1, a_2]
          b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1]
        b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        with ops.init_scope():
          # deps are None again
          b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testLiftsOpsFromFunctions(self):
    """Ops are lifted past every function-building graph to the outer one."""
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with g2.as_default():
          with ops.init_scope():
            _ = constant_op.constant(1.0)

    self.assertLen(g2.get_operations(), 0)
    self.assertLen(g1.get_operations(), 0)
    self.assertLen(g0.get_operations(), 1)

  def testPreservesDevices(self):
    """A device scope active when init_scope is entered carries over.

    Without such a scope, the outer graph's device scope applies.
    """
    g0 = ops.Graph()
    with g0.as_default(), ops.device("CPU:0"):
      g1 = ops.Graph()
      g1._building_function = True  # pylint: disable=protected-access
      with g1.as_default():
        with ops.device("GPU:0"):
          with ops.init_scope():
            # init_scope should preserve device set under `g1`.
            on_gpu = constant_op.constant(1.0)
            self.assertEqual(on_gpu.device, "/device:GPU:0")
          still_on_gpu = constant_op.constant(1.0)
          self.assertEqual(still_on_gpu.device, "/device:GPU:0")
        blank = constant_op.constant(1.0)
        self.assertEqual(blank.device, "")
        with ops.init_scope():
          now_on_cpu = constant_op.constant(1.0)
          self.assertEqual(now_on_cpu.device, "/device:CPU:0")
      on_cpu = constant_op.constant(1.0)
      self.assertEqual(on_cpu.device, "/device:CPU:0")

  def testComposes(self):
    """Nested init_scopes lift to the innermost non-function graph."""
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access
    g3 = ops.Graph()
    g3._building_function = False  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with ops.init_scope():
          # This op should be lifted into g0.
          _ = constant_op.constant(1.0)
          self.assertIs(g0, ops.get_default_graph())
          self.assertLen(g2.get_operations(), 0)
          self.assertLen(g1.get_operations(), 0)
          self.assertLen(g0.get_operations(), 1)
        with g2.as_default():
          with ops.init_scope():
            # This op should be lifted into g0.
            _ = constant_op.constant(1.0)
            self.assertIs(g0, ops.get_default_graph())
            with g3.as_default():
              with ops.init_scope():
                # This op should be lifted into g3, because g3 is not building a
                # function.
                _ = constant_op.constant(1.0)
                self.assertIs(g3, ops.get_default_graph())

    self.assertLen(g3.get_operations(), 1)
    self.assertLen(g2.get_operations(), 0)
    self.assertLen(g1.get_operations(), 0)
    self.assertLen(g0.get_operations(), 2)

  def testEscapesToEagerContext(self):
    """From a function graph under an outer eager context, lift to eager."""
    g = ops.Graph()
    g._building_function = True  # pylint: disable=protected-access
    with context.eager_mode():
      with context.graph_mode():
        with g.as_default():
          with ops.init_scope():
            # Because g is building a function, init_scope should
            # escape out to the eager context.
            self.assertTrue(context.executing_eagerly())
          # g should be reinstated as the default graph, and the
          # graph context should be re-entered.
          self.assertIs(g, ops.get_default_graph())
          self.assertFalse(context.executing_eagerly())

  def testStaysInEagerWhenOnlyEagerContextActive(self):
    """With no graph active, init_scope leaves eager execution enabled."""
    with context.eager_mode():
      # Query the execution mode. `context.eager_mode()` returns a context
      # manager object, which is always truthy, so asserting on it would
      # never fail.
      with ops.init_scope():
        self.assertTrue(context.executing_eagerly())
      self.assertTrue(context.executing_eagerly())

  def testEscapesDefunWhenInEagerMode(self):
    """Variables created under init_scope outlive a single defun call."""

    def function_with_variables():
      with ops.init_scope():
        self.v = resource_variable_ops.ResourceVariable(3)
      return self.v.assign_add(1)

    with context.eager_mode():
      # Each invocation of function_with_variables recreates a variable.
      self.assertEqual(4, int(function_with_variables()))
      self.assertEqual(4, int(function_with_variables()))

      compiled = eager_function.defun(function_with_variables)
      # The init_scope in function_with_variables lifts the variable out
      # of the graph function constructed by defun; hence,
      # compiled now appears to be stateful.
      self.assertEqual(4, int(compiled()))
      self.assertEqual(5, int(compiled()))

  def testEscapesDefunWhenInGraphMode(self):
    """In graph mode, init_scope lifts defun variable creation once."""
    def function_with_variables(name):
      with ops.init_scope():
        _ = variable_scope.get_variable(name, shape=(1,))

    g = ops.Graph()
    with g.as_default():
      with self.cached_session():
        # First ensure that graphs that are not building functions are
        # not escaped.
        function_with_variables("foo")
        with self.assertRaisesRegex(ValueError,
                                    r"Variable foo already exists.*"):
          # This will fail because reuse is not set to True.
          function_with_variables("foo")

        compiled = eager_function.defun(function_with_variables)
        compiled("bar")
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)

        # The second call to `compiled` should not create variables: the
        # init_scope has lifted the variable creation code out of the defun.
        compiled("bar")
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)

  def testEscapesNestedDefun(self):
    """init_scope lifts variables out of nested defuns as well."""

    def inner_function():
      with ops.init_scope():
        self.v = resource_variable_ops.ResourceVariable(1)
      return self.v.assign_add(2)

    def outer_function(inner=None):
      with ops.init_scope():
        self.v0 = resource_variable_ops.ResourceVariable(0)
      return self.v0.assign_add(1) + inner()

    with context.eager_mode():
      # Each invocation of outer_function recreates variables.
      self.assertEqual(4, int(outer_function(inner=inner_function)))
      self.assertEqual(4, int(outer_function(inner=inner_function)))

      compiled_inner = eager_function.defun(inner_function)
      compiled_outer = eager_function.defun(outer_function)
      # The init_scope lifts variables out of the graph functions
      # constructed by defun; hence, compiled_outer should now appear to be
      # stateful.
      self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
      self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))

  @test_util.run_v1_only("b/120545219")
  def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self):
    """If every stacked graph builds a function, lift into the global graph."""
    with context.graph_mode():
      ops.reset_default_graph()
      # This doesn't push anything onto the graph stack, but it does
      # set the stack's global graph.
      global_graph = ops.get_default_graph()
      fn_graph = ops.Graph()

      # pylint: disable=protected-access
      fn_graph._building_function = True
      self.assertLen(ops._default_graph_stack.stack, 0)
      with fn_graph.as_default():
        self.assertLen(ops._default_graph_stack.stack, 1)
        with ops.init_scope():
          self.assertGreater(len(ops._default_graph_stack.stack), 1)
          dummy = constant_op.constant(1.0)
        self.assertLen(ops._default_graph_stack.stack, 1)
      # Note that the global graph is _not_ on the graph stack.
      self.assertLen(ops._default_graph_stack.stack, 0)
      # Ensure that `dummy` was added to the global graph.
      self.assertEqual(global_graph, dummy.graph)
      # pylint: enable=protected-access

  def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self):
    """With an empty stack, init_scope pushes a graph for its duration."""
    with context.graph_mode():
      # pylint: disable=protected-access
      self.assertLen(ops._default_graph_stack.stack, 0)
      with ops.init_scope():
        self.assertGreater(len(ops._default_graph_stack.stack), 0)
      self.assertLen(ops._default_graph_stack.stack, 0)
      # pylint: enable=protected-access

  def testPreservesNameScopeInGraphConstruction(self):
    """The active name scope survives entering init_scope."""
    with ops.Graph().as_default():
      function_graph = ops.Graph()
      with function_graph.as_default():
        with ops.name_scope("inner", skip_on_eager=False), ops.init_scope():
          self.assertEqual(ops.get_name_scope(), "inner")
      self.assertEqual(ops.get_name_scope(), "")

  def testEnteringGraphFromEagerIsSticky(self):
    """A non-function graph entered from eager stays active in init_scope."""
    with context.eager_mode():
      g = ops.Graph()
      with g.as_default():
        with ops.init_scope():
          self.assertFalse(context.executing_eagerly())
          self.assertEqual(g, ops.get_default_graph())

  def testMixGraphEager(self):
    """Eager and graph tensors cannot be combined in either direction."""
    with context.eager_mode():
      c = constant_op.constant(1.0)
      with ops.Graph().as_default():
        with self.assertRaisesRegex(RuntimeError,
                                    "Attempting to capture an EagerTensor"):
          math_ops.add(c, c)
        c2 = constant_op.constant(2.0)
      with self.assertRaises(TypeError):
        math_ops.add(c2, c2)

  def testPreservesNameScopeInEagerExecution(self):
    """The name scope survives init_scope, both eagerly and inside defun."""
    with context.eager_mode():
      def foo():
        with ops.name_scope("inner", skip_on_eager=False), ops.init_scope():
          if context.executing_eagerly():
            # A trailing slash is always appended when eager execution is
            # enabled.
            self.assertEqual(context.context().scope_name, "inner/")
          else:
            self.assertEqual(ops.get_name_scope(), "inner")

      foo()
      self.assertEqual(ops.get_name_scope(), "")
      foo_compiled = eager_function.defun(foo)
      foo_compiled()
      self.assertEqual(ops.get_name_scope(), "")

  def testExecutingEagerlyOutsideFunctions(self):
    """Checks `executing_eagerly_outside_functions` in graph and eager modes."""

    @def_function.function
    def f():
      return ops.executing_eagerly_outside_functions()

    with context.graph_mode():
      self.assertFalse(ops.executing_eagerly_outside_functions())
      with session.Session():
        # Need self.evaluate for these as the return type of functions is
        # tensors.
        self.assertFalse(self.evaluate(f()))

    with context.eager_mode():
      self.assertTrue(ops.executing_eagerly_outside_functions())
      self.assertTrue(f())

      with ops.Graph().as_default():
        self.assertFalse(ops.executing_eagerly_outside_functions())
        with session.Session():
          self.assertFalse(self.evaluate(f()))
3006
3007
class GraphTest(test_util.TensorFlowTestCase):
  """Tests for `ops.Graph` and the default-graph machinery."""

  def setUp(self):
    # Delegate to the base class first so whatever per-test setup
    # TensorFlowTestCase performs still runs. The previous override skipped
    # it entirely.
    super().setUp()
    ops.reset_default_graph()

  def _AssertDefault(self, expected):
    """Asserts that `expected` is the current default graph (by identity)."""
    self.assertIs(expected, ops.get_default_graph())

  def testResetDefaultGraphNesting(self):
    """reset_default_graph is rejected while a graph is explicitly entered."""
    g0 = ops.Graph()
    with self.assertRaises(AssertionError):
      with g0.as_default():
        ops.reset_default_graph()

  def testGraphContextManagerCancelsEager(self):
    """Entering a Graph disables eager execution inside it."""
    with context.eager_mode():
      with ops.Graph().as_default():
        self.assertFalse(context.executing_eagerly())

  def testGraphContextManager(self):
    """`as_default()` yields the graph itself."""
    g0 = ops.Graph()
    with g0.as_default() as g1:
      self.assertIs(g0, g1)

  def testDefaultGraph(self):
    """The default graph tracks `as_default()` nesting.

    `has_default_graph` is False outside every explicit scope.
    """
    orig = ops.get_default_graph()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    g0 = ops.Graph()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    # Creating the context manager alone must not change the default graph.
    context_manager_0 = g0.as_default()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    with context_manager_0 as g0:
      self._AssertDefault(g0)
      with ops.Graph().as_default() as g1:
        self.assertTrue(ops.has_default_graph())
        self._AssertDefault(g1)
      self._AssertDefault(g0)
    self._AssertDefault(orig)
    self.assertFalse(ops.has_default_graph())

  def testPreventFeeding(self):
    """prevent_feeding marks a tensor as no longer feedable."""
    g = ops.Graph()
    # `a` is not created inside `g`. The test only checks what is_feedable
    # reports before and after prevent_feeding.
    a = constant_op.constant(2.0)
    self.assertTrue(g.is_feedable(a))
    g.prevent_feeding(a)
    self.assertFalse(g.is_feedable(a))

  @test_util.run_deprecated_v1
  def testPreventFetching(self):
    """Preventing fetching of an op makes its output tensor unfetchable."""
    g = ops.Graph()
    a = constant_op.constant(2.0)
    self.assertTrue(g.is_fetchable(a))
    g.prevent_fetching(a.op)
    self.assertFalse(g.is_fetchable(a))

  def testAsGraphElementConversions(self):
    """as_graph_element honours `_as_graph_element` and rejects other types."""

    class ConvertibleObj(object):

      def _as_graph_element(self):
        return "FloatOutput:0"

    class NonConvertibleObj(object):

      pass

    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
    with self.assertRaises(TypeError):
      g.as_graph_element(NonConvertibleObj())

  # Regression test against creating custom __del__ functions in classes
  # involved in cyclic references, e.g. Graph and Operation. Before Python 3.4
  # (PEP 442), the cycle collector would not free cycles that require calling
  # a __del__ method. The __del__ method can theoretically increase the
  # object's refcount to "save" it from gc, and any already-deleted objects in
  # the cycle would have to be restored.
  def testGarbageCollected(self):
    """A Graph is collectable once every reference to it is dropped."""
    # Create a graph we can delete and a weak reference to monitor if it's gc'd
    g = ops.Graph()
    g_ref = weakref.ref(g)
    # Create some ops
    with g.as_default():
      a = constant_op.constant(2.0)
      b = constant_op.constant(3.0)
      c = math_ops.add(a, b)
    # Create a session we can delete
    with session.Session(graph=g) as sess:
      self.evaluate(c)
    # Delete all references and trigger gc
    del g
    del a
    del b
    del c
    del sess
    gc.collect()
    self.assertIsNone(g_ref())

  def testRunnableAfterInvalidShape(self):
    """A failed op construction leaves the graph runnable."""
    with ops.Graph().as_default():
      with self.assertRaises(ValueError):
        math_ops.add([1, 2], [1, 2, 3])
      a = constant_op.constant(1)
      # The session only needs to be the default while evaluating `a`,
      # so it is not bound to a name.
      with session.Session():
        self.evaluate(a)

  def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
    """Same as above, with the failure raised under a kernel label map."""
    g = ops.Graph()
    with g.as_default():
      with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
        with self.assertRaises(ValueError):
          test_ops.kernel_label_required(1)
      a = constant_op.constant(1)
      with session.Session():
        self.evaluate(a)
3126
3127
class AttrScopeTest(test_util.TensorFlowTestCase):
  """Tests for nesting and overriding attrs with `Graph._attr_scope`."""

  def _get_test_attrs(self):
    """Builds a no-op and returns its ("_A", "_B") attrs as text or None."""
    op = control_flow_ops.no_op()
    values = []
    for attr_name in ("_A", "_B"):
      try:
        values.append(compat.as_text(op.get_attr(attr_name)))
      except ValueError:
        # An attr that is not set is reported as None.
        values.append(None)
    return tuple(values)

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    with self.cached_session():
      observed = self._get_test_attrs()
      self.assertAllEqual((None, None), observed)

  @test_util.run_deprecated_v1
  def testLabelMap(self):

    def string_attr(text):
      return attr_value_pb2.AttrValue(s=compat.as_bytes(text))

    with self.cached_session() as sess:
      graph = sess.graph
      # pylint: disable=protected-access
      seen = [self._get_test_attrs()]
      with graph._attr_scope({"_A": string_attr("foo")}):
        seen.append(self._get_test_attrs())
        # Mapping "_A" to None removes it for the inner scope only.
        with graph._attr_scope({"_A": None, "_B": string_attr("bar")}):
          seen.append(self._get_test_attrs())
          with graph._attr_scope({"_A": string_attr("baz")}):
            seen.append(self._get_test_attrs())
          seen.append(self._get_test_attrs())
        seen.append(self._get_test_attrs())
      seen.append(self._get_test_attrs())
      # pylint: enable=protected-access

      expected = [
          (None, None),
          ("foo", None),
          (None, "bar"),
          ("baz", "bar"),
          (None, "bar"),
          ("foo", None),
          (None, None),
      ]
      for want, got in zip(expected, seen):
        self.assertAllEqual(want, got)
3175
3176
class KernelLabelTest(test_util.TensorFlowTestCase):
  """Tests for kernel selection through `Graph._kernel_label_map`."""

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    with self.cached_session():
      label = test_ops.kernel_label()
      self.assertAllEqual(b"My label is: default", label.eval())

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    with self.cached_session() as sess:
      graph = sess.graph
      outside_before = test_ops.kernel_label()
      # pylint: disable=protected-access
      with graph._kernel_label_map({"KernelLabel": "overload_1"}):
        outer_first = test_ops.kernel_label()
        with graph._kernel_label_map({"KernelLabel": "overload_2"}):
          inner = test_ops.kernel_label()
          # An empty label selects the default kernel again.
          with graph._kernel_label_map({"KernelLabel": ""}):
            reset = test_ops.kernel_label()
        outer_second = test_ops.kernel_label()
      # pylint: enable=protected-access
      outside_after = test_ops.kernel_label()

      default = b"My label is: default"
      checks = [
          (default, outside_before),
          (default, reset),
          (default, outside_after),
          (b"My label is: overload_1", outer_first),
          (b"My label is: overload_1", outer_second),
          (b"My label is: overload_2", inner),
      ]
      for expected, tensor in checks:
        self.assertAllEqual(expected, self.evaluate(tensor))
3208
3209
class AsGraphDefTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.as_graph_def` and the metadata it exports."""

  def testGraphDefVersion(self):
    """The graph's producer version is what the kernel reports back."""
    with ops.Graph().as_default() as graph:
      producer = graph.graph_def_versions.producer
      with self.session(graph=graph):
        reported = test_ops.graph_def_version().eval()
        self.assertEqual(producer, reported)

  def testAddShapes(self):
    """`add_shapes=True` records each output's static shape in the GraphDef."""
    with ops.Graph().as_default() as g:
      unknown, scalar, vector, matrix, partial = _apply_op(
          g, "FiveFloatOutputs", [], [dtypes.float32] * 5)
      unknown.set_shape(None)
      scalar.set_shape([])
      vector.set_shape([None])
      matrix.set_shape([43, 37])
      partial.set_shape([43, None])

      # Adds the scalar "Const" node that the expected proto also lists.
      constant_op.constant(1.0)

      gd = g.as_graph_def(add_shapes=True)
      self.assertProtoEqualsVersion("""
      node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
        attr {
          key: "_output_shapes"
          value {
            list {
              shape { unknown_rank: true }
              shape { }
              shape { dim { size: -1 } }
              shape { dim { size: 43 } dim { size: 37 } }
              shape { dim { size: 43 } dim { size: -1 } }
            }
          }
        }
      }
    node { name: "Const" op: "Const"
      attr {
        key: "_output_shapes"
        value {
          list {
            shape { }
          }
        }
      }
      attr {
        key: "dtype"
        value { type: DT_FLOAT }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape { }
         float_val: 1.0  } } } }
      """, gd)
3269
3270
@ops.RegisterStatistics("a", "flops")
def _calc_a_forward_flops(unused_graph, unused_node):
  """Registered "flops" statistic for op type "a": always 20."""
  stats = ops.OpStats("flops", 20)
  return stats
3274
3275
class StatisticsTest(test_util.TensorFlowTestCase):
  """Tests for `RegisterStatistics` lookups and `OpStats` accumulation."""

  def testRegisteredNode(self):
    # Op type "a" has a "flops" function registered in this module.
    graph = ops.Graph()
    node = ops._NodeDef("a", "an_a")
    flops = ops.get_stats_for_node_def(graph, node, "flops")
    self.assertEqual(20, flops.value)
    # A statistic with no registered function yields an OpStats without value.
    missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
    self.assertIsNone(missing_stat.value)

  def testUnregisteredNode(self):
    graph = ops.Graph()
    node = ops._NodeDef("b", "a_b")
    weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
    self.assertIsNone(weight_params.value)

  def testAccumulateStatistics(self):
    # An OpStats created without a value starts as None; += adopts the
    # other operand's value.
    flops_total = ops.OpStats("flops")
    self.assertIsNone(flops_total.value)
    second_flops = ops.OpStats("flops", 3)
    flops_total += second_flops
    self.assertEqual(3, flops_total.value)
3298
3299
class DeviceStackTest(test_util.TensorFlowTestCase):
  """Tests for the device-assignment metadata recorded on each op."""

  @test_util.run_deprecated_v1
  def testBasicDeviceAssignmentMetadata(self):

    def device_func(unused_op):
      return "/cpu:*"

    const_zero = constant_op.constant([0.0], name="zero")
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
      with ops.device("/cpu:0"):
        const_two = constant_op.constant([2.0], name="two")
    with ops.device(device_func):
      const_three = constant_op.constant(3.0, name="three")

    # No device scope was active when "zero" was created.
    self.assertEmpty(const_zero.op._device_assignments)

    one_list = const_one.op._device_assignments
    self.assertLen(one_list, 1)
    self.assertEqual("/cpu", one_list[0].obj)
    # The assignment is attributed to this test file.
    self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))

    # Nested scopes contribute one entry each.
    two_list = const_two.op._device_assignments
    self.assertLen(two_list, 2)
    devices = [t.obj for t in two_list]
    self.assertEqual({"/cpu", "/cpu:0"}, set(devices))

    # A device function is recorded as a description naming the function and
    # its defining file and line.
    three_list = const_three.op._device_assignments
    self.assertLen(three_list, 1)
    func_description = three_list[0].obj
    expected_regex = r"device_func<.*ops_test.py, [0-9]+"
    self.assertRegex(func_description, expected_regex)

  @test_util.run_deprecated_v1
  def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
    # NOTE: the two `with` blocks below must stay exactly two lines apart;
    # the final assertion checks that line offset.
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
    with ops.get_default_graph().device("/cpu"):
      const_two = constant_op.constant([2.0], name="two")

    one_metadata = const_one.op._device_assignments[0]
    two_metadata = const_two.op._device_assignments[0]

    # Verify both types of device assignment return the right stack info.
    # (The filename is compared for equality; it must not be used as a regex.)
    self.assertEqual("ops_test.py", os.path.basename(one_metadata.filename))
    self.assertEqual(one_metadata.filename, two_metadata.filename)
    self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)
3349
3350
class ColocationGroupTest(test_util.TensorFlowTestCase):
  """Tests for `colocate_with` and the colocation groups it records."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      b = constant_op.constant(3.0)
    c = constant_op.constant(4.0)
    # An op outside any colocation scope reports its own name as its group.
    self.assertEqual([b"loc:@a"], a.op.colocation_groups())
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    # An op never placed in a colocation scope has no "_class" attr.
    with self.assertRaises(ValueError):
      c.op.get_attr("_class")

  @test_util.run_deprecated_v1
  def testBasicColocationMetadata(self):
    const_two = constant_op.constant([2.0], name="two")
    with ops.colocate_with(const_two.op):
      const_three = constant_op.constant(3.0, name="three")
    locations_dict = const_three.op._colocation_dict
    self.assertIn("two", locations_dict)
    metadata = locations_dict["two"]
    self.assertIsNone(metadata.obj)
    # Check that this test's filename is recorded as the file containing the
    # colocation statement.
    self.assertEqual("ops_test.py", os.path.basename(metadata.filename))

  @test_util.run_deprecated_v1
  def testColocationDeviceInteraction(self):
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        a = constant_op.constant([2.0], name="a")
      with ops.colocate_with(a.op):
        # 'b' is created in the scope of /cpu:0, but it is
        # colocated with 'a', which is on '/device:GPU:0'.  colocate_with
        # overrides devices because it is a stronger constraint.
        b = constant_op.constant(3.0)
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual(a.op.device, b.op.device)

  @test_util.run_deprecated_v1
  def testColocationCanonicalization(self):
    with ops.device("/device:GPU:0"):
      _ = constant_op.constant(2.0)
    with ops.device(lambda op: "/device:GPU:0"):
      b = constant_op.constant(3.0)
    with ops.get_default_graph().colocate_with(b):
      with ops.device("/device:GPU:0"):
        c = constant_op.constant(4.0)

    # The unnamed first constant and `b` are both placed on /device:GPU:0
    # (`b` through a device function). `c` is colocated with `b`, so it
    # inherits b's device name once the names are canonicalized.
    self.assertEqual(b.op.device, c.op.device)

  @test_util.run_deprecated_v1
  def testLocationOverrides(self):
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        a = constant_op.constant([2.0], name="a")
        # Note that this colocation is "redundant", since we are
        # within the scope of "/device:GPU:0".  However, we would like to
        # preserve in the GraphDef that these two ops should be
        # colocated in a portable way.
        with ops.colocate_with(a.op):
          b = constant_op.constant(3.0)
        c = constant_op.constant(4.0)
      d = constant_op.constant(5.0)

    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual("/device:GPU:0", a.op.device)
    self.assertEqual(a.op.device, b.op.device)

    # Test that device function stack is restored.
    self.assertEqual("/device:GPU:0", c.op.device)
    self.assertEqual("/device:CPU:0", d.op.device)

  @test_util.run_deprecated_v1
  def testNestedColocateWith(self):
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      b = constant_op.constant(3.0)
      with ops.colocate_with(b.op):
        c = constant_op.constant(4.0)
    # Colocating with `b` resolves to b's own group, i.e. "a".
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual([b"loc:@a"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testMultiColocationGroups(self):
    a = constant_op.constant([2.0], name="a")
    b = constant_op.constant(3.0, name="b")
    with ops.colocate_with(a.op):
      with ops.colocate_with(b.op):
        c = constant_op.constant(4.0)
    self.assertEqual({b"loc:@a", b"loc:@b"}, set(c.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocationIgnoreStack(self):
    a = constant_op.constant([2.0], name="a")
    b = constant_op.constant(3.0, name="b")
    with ops.colocate_with(a.op):
      # ignore_existing=True drops the enclosing "a" group.
      with ops.colocate_with(b.op, ignore_existing=True):
        c = constant_op.constant(4.0)
    self.assertEqual({b"loc:@b"}, set(c.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocateWithReset(self):
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      b = constant_op.constant(3.0, name="b")
      # Colocating with None while ignoring existing groups clears the stack,
      # so "c" ends up only in its own group.
      with ops.colocate_with(None, ignore_existing=True):
        c = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual([b"loc:@c"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateWithInitialNoneThenNested(self):
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      with ops.colocate_with(None, ignore_existing=True):
        b = constant_op.constant(3.0, name="b")
        with ops.colocate_with(b.op):
          c = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@b"], b.op.colocation_groups())
    self.assertEqual([b"loc:@b"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateVariables(self):
    a = variables.Variable([2.0], name="a")
    with ops.colocate_with(a.op):
      b = variables.Variable([3.0], name="b")
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateResourceVariablesInFunction(self):
    with ops.device("/device:CPU:0"):
      a = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def f():
      with ops.colocate_with(a):
        b = array_ops.ones([], name="output")
        # Checked at trace time: the op inherits the variable's device.
        self.assertEqual("/device:CPU:0", b.op.device)
    f()

  def testColocateWithVariableInFunction(self):
    v = variables.Variable(1.)

    @def_function.function
    def f():
      with ops.colocate_with(v):
        return array_ops.ones([], name="output")

    f()
    # The traced graph must be importable back into a wrapped function.
    graph_def = f.get_concrete_function().graph.as_graph_def()
    wrap_function.function_from_graph_def(graph_def, [], ["output"])
3507
3508
class DeadlineTest(test_util.TensorFlowTestCase):
  """Tests for `RunOptions.timeout_in_ms` as observed by kernels."""

  def testNoDeadlineSet(self):
    with ops.Graph().as_default() as graph:
      deadline_op = test_ops.get_deadline()
      with self.session(graph=graph) as sess:
        # No timeout_in_ms means no deadline; get_deadline is expected to fail.
        options = config_pb2.RunOptions()
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(deadline_op, options=options)

  def testDeadlineSetTimesOut(self):
    with ops.Graph().as_default() as graph:
      sleeper = test_ops.sleep_op(10)
      with self.session(graph=graph) as sess:
        # The 3000 ms timeout should expire before sleep_op(10) finishes
        # (assumed to sleep ~10 s; confirm against the op definition).
        options = config_pb2.RunOptions(timeout_in_ms=3000)
        with self.assertRaises(errors.DeadlineExceededError):
          sess.run(sleeper, options=options)
3526
3527
class DeprecatedTest(test_util.TensorFlowTestCase):
  """Tests for the test op `Old`, removed at GraphDef version 8."""

  def testSuccess(self):
    # A producer version of 7 predates the removal, so `Old` still runs.
    with ops.Graph().as_default() as graph:
      test_util.set_producer_version(graph, 7)
      old_op = test_ops.old()
      with self.session(graph=graph):
        old_op.run()

  def _error(self):
    """Regex for the error raised when `Old` is built at the current version."""
    pattern = (r"Op Old is not available in GraphDef version %d\. "
               r"It has been removed in version 8\. For reasons\.")
    return pattern % versions.GRAPH_DEF_VERSION

  def testGraphConstructionFail(self):
    with ops.Graph().as_default():
      expected = self._error()
      with self.assertRaisesRegex(NotImplementedError, expected):
        test_ops.old()
3546
3547
class NameScopeTest(test_util.TensorFlowTestCase):
  """Tests for name-scope helpers and `Graph.get_name_scope`."""

  def testStripAndPrependScope(self):
    strs = [
        "hidden1/hidden1/weights",  # Same prefix. Should strip.
        "hidden1///hidden1/weights",  # Extra "/". Should strip.
        "^hidden1/hidden1/weights",  # "^" marker kept; prefix stripped.
        "loc:@hidden1/hidden1/weights",  # "loc:@" marker kept; prefix stripped.
        "hhidden1/hidden1/weights",  # Different prefix. Should keep.
        "hidden1",  # Not a prefix (no trailing "/"). Should keep.
    ]
    expected_stripped = [
        "hidden1/weights", "hidden1/weights", "^hidden1/weights",
        "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1"
    ]
    expected_prepended = [
        "hidden2/hidden1/weights", "hidden2/hidden1/weights",
        "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights",
        "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1"
    ]
    name_scope_to_strip = "hidden1"
    name_scope_to_add = "hidden2"
    for es, ep, s in zip(expected_stripped, expected_prepended, strs):
      stripped = ops.strip_name_scope(s, name_scope_to_strip)
      self.assertEqual(es, stripped)
      # Prepending is applied to the stripped name, not the original.
      self.assertEqual(ep, ops.prepend_name_scope(stripped, name_scope_to_add))

  def testGetNameScope(self):
    with ops.Graph().as_default() as g:
      with ops.name_scope("scope1"):
        with ops.name_scope("scope2"):
          with ops.name_scope("scope3"):
            self.assertEqual("scope1/scope2/scope3", g.get_name_scope())
          self.assertEqual("scope1/scope2", g.get_name_scope())
        self.assertEqual("scope1", g.get_name_scope())
      self.assertEqual("", g.get_name_scope())

  def testTwoGraphs(self):
    # "_" is rejected as a scope name even with two graphs nested as default.
    g1 = ops.Graph()
    g2 = ops.Graph()
    with self.assertRaisesRegex(ValueError,
                                "'_' is not a valid (?:root )?scope name"):
      with g1.as_default():
        with g2.as_default():
          with ops.name_scope("_"):
            pass
3597
3598
class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
  """Tests argument validation in `ops.enable_eager_execution`."""

  @test_util.run_v1_only("b/120545219")
  def testBadArgumentsToEnableEagerExecution(self):
    # A device-placement constant passed where the config is expected.
    with self.assertRaisesRegex(TypeError, "config must be a tf.ConfigProto"):
      ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)

    # Configs are built outside the assertRaises blocks so that only the call
    # under test can satisfy the expected error.
    config_for_policy = config_pb2.ConfigProto()
    with self.assertRaisesRegex(ValueError, "device_policy must be one of"):
      ops.enable_eager_execution(config_for_policy, config_for_policy)

    config_for_mode = config_pb2.ConfigProto()
    with self.assertRaisesRegex(ValueError, "execution_mode must be one of"):
      ops.enable_eager_execution(config_for_mode,
                                 execution_mode=config_for_mode)
3611
3612
class _TupleTensor(composite_tensor.CompositeTensor):
  """`Tensor`-like `tuple`-like for custom `Tensor` conversion masquerading.

  Each element passed in is converted with `ops.convert_to_tensor` and stored
  in a tuple; the object then behaves like a read-only sequence of tensors.
  """

  def __init__(self, components):
    super(_TupleTensor, self).__init__()
    self._components = tuple(ops.convert_to_tensor(c) for c in components)

  @property
  def _type_spec(self):
    # Materialize the specs: a bare generator can be iterated only once and
    # compares by identity, which breaks repeated use of the TypeSpec.
    return _TupleTensorSpec(
        tuple(type_spec.from_value(c) for c in self._components))

  def __getitem__(self, key):
    return self._components[key]

  def __len__(self):
    return len(self._components)

  def __iter__(self):
    return iter(self._components)
3632
3633
class _TupleTensorSpec(type_spec.TypeSpec):
  """`TypeSpec` for `_TupleTensor`: one component spec per tuple element."""

  def __init__(self, specs):
    # Store a tuple so the specs survive repeated iteration (callers may pass
    # a generator) and serialize by value.
    self._specs = tuple(specs)

  value_type = property(lambda self: _TupleTensor)
  _component_specs = property(lambda self: self._specs)

  def _to_components(self, value):
    return value._components

  def _from_components(self, components):
    # `_TupleTensor.__init__` takes the components as a single iterable;
    # unpacking them with `*` would raise TypeError for more than one.
    return _TupleTensor(components)

  def _serialize(self):
    return (self._specs,)
3650
3651
3652class _MyTuple(object):
3653  """Pretend user-side class for `ConvertToCompositeTensorTest ."""
3654
3655  def __init__(self, components):
3656    super(_MyTuple, self).__init__()
3657    self._components = tuple(components)
3658
3659  def __getitem__(self, key):
3660    return self._components[key]
3661
3662  def __len__(self):
3663    return len(self._components)
3664
3665  def __iter__(self):
3666    return iter(self._components)
3667
3668
def _my_tuple_to_tuple_tensor(value, *unused_args, **unused_kwargs):
  """Converts a `_MyTuple` to `_TupleTensor`, ignoring extra conversion args."""
  return _TupleTensor(value)


ops.register_tensor_conversion_function(
    _MyTuple, conversion_func=_my_tuple_to_tuple_tensor)
3671
3672
class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
  """Checks user-registered conversions that yield a CompositeTensor."""

  @test_util.disable_tfrt("TODO(kkb): This makes Kokoro tests fail.")
  def testCompositeTensorConversion(self):
    """Tests that a user can register a CompositeTensor converter."""
    source = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
    converted = ops.convert_to_tensor_or_composite(source)
    # The result is the composite wrapper itself, not a single tensor.
    self.assertFalse(tensor_util.is_tf_type(converted))
    self.assertIsInstance(converted, _TupleTensor)
    self.assertLen(converted, len(source))
    # Every element, however, is a real tensor with the original value.
    for original, component in zip(source, converted):
      self.assertIsInstance(component, ops.Tensor)
      self.assertTrue(tensor_util.is_tf_type(component))
      self.assertAllEqual(original, tensor_util.constant_value(component))
3687
3688
@test_util.disable_tfrt("Packing EagerTensors is not supported yet.")
class PackEagerTensorTest(test_util.TensorFlowTestCase):
  """Tests for `ops.pack_eager_tensors` across two logical CPUs."""

  def setUp(self):
    super(PackEagerTensorTest, self).setUp()
    context._reset_context()
    # Split the first physical CPU into two logical devices (CPU:0, CPU:1).
    first_cpu = config.list_physical_devices("CPU")[0]
    config.set_logical_device_configuration(
        first_cpu, [context.LogicalDeviceConfiguration() for _ in range(2)])

  def testPack(self):
    with context.eager_mode():
      with ops.device("CPU:0"):
        scalar_var_0 = resource_variable_ops.ResourceVariable(1.0)
        matrix = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      with ops.device("CPU:1"):
        scalar_var_1 = resource_variable_ops.ResourceVariable(2.0)
        vector_var = resource_variable_ops.ResourceVariable([3.0])
        vector = constant_op.constant([9.0])

      handle_0 = scalar_var_0.handle
      packed = ops.pack_eager_tensors([handle_0, scalar_var_1.handle])
      # The packed handle mirrors its inputs' metadata but sits on a
      # COMPOSITE device and cannot be read back as a value.
      self.assertTrue(packed.is_packed)
      self.assertEqual(packed.dtype, handle_0.dtype)
      self.assertEqual(packed.shape, handle_0.shape)
      self.assertEqual(packed._handle_data, handle_0._handle_data)
      self.assertIn("COMPOSITE:0", packed.device)
      self.assertIn("COMPOSITE:0", packed.backing_device)
      with self.assertRaises(errors.InvalidArgumentError):
        packed.numpy()

      mismatched_inputs = [
          [handle_0, vector],  # Different dtypes.
          [matrix, vector],  # Different shapes.
          [handle_0, vector_var.handle],  # Different handle data.
      ]
      for inputs in mismatched_inputs:
        with self.assertRaises(ValueError):
          ops.pack_eager_tensors(inputs)
3733
3734
class GraphDefInputShapesTest(test_util.TensorFlowTestCase):
  """Checks the `_input_shapes` attr on an exported FunctionDef."""

  def setUpInputShapes(self, pre_add_input_shapes):
    """Exports a traced function and returns its FunctionDef `_input_shapes`.

    Args:
      pre_add_input_shapes: If True, attach `_input_shapes` to the concrete
        function explicitly before adding it to the graph; otherwise leave it
        to the export path.

    Returns:
      The `_input_shapes` AttrValue of the single library function.
    """
    shape_dims = [None, 1, 1, 1]

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=shape_dims, dtype=dtypes.float32)
    ])
    def f(x):
      return array_ops.identity(x, name="output")

    f(array_ops.ones([2, 1, 1, 1], dtype=dtypes.float32))

    concrete_function = f.get_concrete_function()
    if pre_add_input_shapes:
      # Unknown dimensions are written as size -1.
      dim_cls = tensor_shape_pb2.TensorShapeProto.Dim
      shape_proto = tensor_shape_pb2.TensorShapeProto(
          dim=[dim_cls(size=-1 if d is None else d) for d in shape_dims])
      shapes_attr = attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(shape=[shape_proto]))
      concrete_function = eager_function.ConcreteFunction(
          concrete_function.graph,
          attrs={"_input_shapes": shapes_attr},
          spec=concrete_function._pre_initialized_function_spec)

    export_graph = ops.Graph()
    with export_graph.as_default():
      concrete_function.add_to_graph(g=export_graph)
    graph_def = export_graph.as_graph_def(add_shapes=True)
    self.assertLen(graph_def.library.function, 1)
    return graph_def.library.function[0].attr["_input_shapes"]

  def testGraphDefInputShapes(self):
    # Explicitly attached and export-populated shapes must agree.
    explicit = self.setUpInputShapes(pre_add_input_shapes=True)
    implicit = self.setUpInputShapes(pre_add_input_shapes=False)
    self.assertProtoEquals(explicit, implicit)
3776
3777
class TensorTest(test_util.TensorFlowTestCase):
  """Tests for converting Tensors to NumPy arrays with `np.array`."""

  def testToArrayEagerMode(self):
    with context.eager_mode():
      # An explicit dtype wins; otherwise the tensor's own dtype is kept.
      as_float = np.array(constant_op.constant(32), dtype=np.float32)
      as_int64 = np.array(constant_op.constant(32, dtype=dtypes.int64))

      self.assertEqual(as_float.dtype, np.dtype(np.float32))
      self.assertEqual(as_int64.dtype, np.dtype(np.int64))

  def testToArrayFunctionMode(self):
    """Converting a symbolic tensor inside tf.function must raise."""

    @def_function.function
    def f():
      return np.array(constant_op.constant(32), dtype=np.int32)

    @def_function.function
    def g():
      return np.array(constant_op.constant(32))

    # Both raise while being traced, with or without an explicit dtype.
    for traced in (f, g):
      with self.assertRaisesRegex(NotImplementedError,
                                  "Cannot convert a symbolic tf.Tensor"):
        traced()
3808
3809
if __name__ == "__main__":
  # Entry point when this module is executed directly as a script.
  googletest.main()
3812