# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.ops."""

import gc
import os
import threading
import weakref

from absl.testing import parameterized
import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import full_type_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.gradients  # pylint: disable=unused-import
from tensorflow.python.platform import googletest
from tensorflow.python.util import compat


class ResourceTest(test_util.TensorFlowTestCase):

  @test_util.run_deprecated_v1
  def testBuildGraph(self):
    with self.cached_session():
      pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
      test_ops.resource_create_op(pt).run()

  @test_util.run_deprecated_v1
  def testInitialize(self):
    with self.cached_session():
      handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
      resources.register_resource(
          handle=handle,
          create_op=test_ops.resource_create_op(handle),
          is_initialized_op=test_ops.resource_initialized_op(handle))
      self.assertEqual(
          len(
              resources.report_uninitialized_resources(
                  resources.shared_resources()).eval()), 1)
      resources.initialize_resources(resources.shared_resources()).run()
      self.assertEqual(
          len(
              resources.report_uninitialized_resources(
                  resources.shared_resources()).eval()), 0)


class TensorAndShapeTest(test_util.TensorFlowTestCase):

  def testShape(self):
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    self.assertEqual(tensor_shape.unknown_shape(), t.get_shape())
    t.set_shape([1, 2, 3])
    self.assertEqual([1, 2, 3], t.get_shape())

  def testIterable(self):
    if not context.executing_eagerly():
      self.skipTest("Eager-mode test")
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    with self.assertRaisesRegex(TypeError, "Cannot iterate"):
      iter(t)

  def testIterableGraph(self):
    if context.executing_eagerly():
      self.skipTest("Graph-mode test")

    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    with self.assertRaisesRegex(TypeError, "Iterating.*not allowed in Graph"):
      next(iter(t))
    with self.assertRaisesRegex(TypeError, "Iterating.*AutoGraph did convert"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        next(iter(t))
    with self.assertRaisesRegex(TypeError, "Iterating.*AutoGraph is disabled"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        next(iter(t))

  def testImplicitBool(self):
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.bool])
    t = op.outputs[0]
    with self.assertRaisesRegex(TypeError,
                                "Using.*as a.*bool.*not allowed in Graph"):
      bool(t)
    with self.assertRaisesRegex(TypeError,
                                "Using.*as a.*bool.*AutoGraph did convert"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        bool(t)
    with self.assertRaisesRegex(TypeError,
                                "Using.*as a.*bool.*AutoGraph is disabled"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        bool(t)

  def testAddShape(self):
    with self.cached_session():
      a = array_ops.zeros([2, 3])
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual([2, 3], c.shape)

  @test_util.run_deprecated_v1
  def testUnknownDim(self):
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      c = a + b
      self.assertEqual([2, None, 3], c.shape.as_list())

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual(tensor_shape.unknown_shape(), c.shape)

  @test_util.run_deprecated_v1
  def testScalarShape(self):
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      b = array_ops.ones([])
      c = a + b
      self.assertEqual(tensor_shape.TensorShape([]), c.shape)

  @test_util.run_deprecated_v1
  def testShapeFunctionError(self):
    with self.cached_session():
      a = array_ops.ones([1, 2, 3])
      b = array_ops.ones([4, 5, 6])
      with self.assertRaisesRegex(
          ValueError, r"Dimensions must be equal, but are 2 and 5 for .*add"
r".*Add(V2)?.* with input shapes: \[1,2,3\], \[4,5,6\]."): 184 _ = a + b 185 186 def testNumpyArray(self): 187 with ops.Graph().as_default(): 188 x = array_ops.ones((3, 4), name="test_ones") 189 190 with self.assertRaisesRegex(NotImplementedError, 191 r"Cannot convert a symbolic.+test_ones"): 192 np.array(x) 193 194 with self.assertRaisesRegex(TypeError, "not well defined.+test_ones"): 195 len(x) 196 197 # EagerTensors should still behave as numpy arrays. 198 with context.eager_mode(): 199 x = array_ops.ones((3, 4)) 200 201 self.assertAllEqual(x, np.ones((3, 4))) 202 self.assertAllEqual(np.array(x), np.ones((3, 4))) 203 self.assertLen(x, 3) 204 205 def testConstructor(self): 206 a = array_ops.ones([]) 207 for name in ["T", "astype", "ravel", "transpose", "reshape", "clip", "size", 208 "tolist", "data"]: 209 with self.assertRaisesRegex( 210 AttributeError, r"If you are looking for numpy-related methods"): 211 getattr(a, name) 212 with self.assertRaisesRegex( 213 AttributeError, r"object has no attribute"): 214 a.foo_bar() 215 216 def testRef(self): 217 x1 = constant_op.constant(3) 218 x2 = x1 219 y = constant_op.constant(3) 220 z = constant_op.constant([6, 10]) 221 w = variables.Variable(5) 222 223 self.assertEqual(x1.ref(), x1.ref()) 224 self.assertEqual(x2.ref(), x2.ref()) 225 self.assertEqual(x1.ref(), x2.ref()) 226 self.assertEqual(y.ref(), y.ref()) 227 self.assertEqual(z.ref(), z.ref()) 228 self.assertEqual(w.ref(), w.ref()) 229 230 self.assertNotEqual(x1.ref(), y.ref()) 231 self.assertNotEqual(x1.ref(), z.ref()) 232 self.assertNotEqual(x1.ref(), w.ref()) 233 self.assertNotEqual(y.ref(), z.ref()) 234 self.assertNotEqual(y.ref(), w.ref()) 235 self.assertNotEqual(z.ref(), w.ref()) 236 237 def testRefDeref(self): 238 x1 = constant_op.constant(3) 239 x2 = x1 240 y = constant_op.constant(3) 241 z = constant_op.constant([6, 10]) 242 w = variables.Variable(5) 243 244 self.assertIs(x1, x1.ref().deref()) 245 self.assertIs(x2, x2.ref().deref()) 246 self.assertIs(x1, x2.ref().deref()) 247 self.assertIs(x2, x1.ref().deref()) 248 self.assertIs(y, y.ref().deref()) 249 self.assertIs(z, z.ref().deref()) 250 251 self.assertIsNot(x1, y.ref().deref()) 252 self.assertIsNot(x1, z.ref().deref()) 253 self.assertIsNot(x1, w.ref().deref()) 254 self.assertIsNot(y, z.ref().deref()) 255 self.assertIsNot(y, w.ref().deref()) 256 self.assertIsNot(z, w.ref().deref()) 257 258 def testRefInSet(self): 259 x1 = constant_op.constant(3) 260 x2 = x1 261 y = constant_op.constant(3) 262 z = constant_op.constant([6, 10]) 263 w = variables.Variable(5) 264 265 self.assertEqual(x1.ref(), x2.ref()) 266 267 tensor_set = { 268 x1.ref(), 269 x2.ref(), 270 y.ref(), 271 z.ref(), 272 w.ref(), 273 } 274 275 self.assertLen(tensor_set, 4) 276 self.assertIn(x1.ref(), tensor_set) 277 self.assertIn(x2.ref(), tensor_set) 278 self.assertIn(y.ref(), tensor_set) 279 self.assertIn(z.ref(), tensor_set) 280 self.assertIn(w.ref(), tensor_set) 281 282 def testRefInDict(self): 283 x1 = constant_op.constant(3) 284 x2 = x1 285 y = constant_op.constant(3) 286 z = constant_op.constant([6, 10]) 287 w = variables.Variable(5) 288 289 self.assertEqual(x1.ref(), x2.ref()) 290 291 tensor_dict = { 292 x1.ref(): "x1", 293 y.ref(): "y", 294 z.ref(): "z", 295 w.ref(): "w", 296 } 297 298 self.assertLen(tensor_dict, 4) 299 300 # Overwriting x1 301 tensor_dict[x2.ref()] = "x2" 302 self.assertLen(tensor_dict, 4) 303 304 self.assertEqual(tensor_dict[x1.ref()], "x2") 305 self.assertEqual(tensor_dict[x2.ref()], "x2") 306 self.assertEqual(tensor_dict[y.ref()], "y") 307 
self.assertEqual(tensor_dict[z.ref()], "z") 308 self.assertEqual(tensor_dict[w.ref()], "w") 309 310 def testTensorRefStrong(self): 311 x = constant_op.constant(1.) 312 x_ref = x.ref() 313 del x 314 self.assertIsNotNone(x_ref.deref()) 315 316 def testVariableRefStrong(self): 317 x = variables.Variable(1.) 318 x_ref = x.ref() 319 del x 320 self.assertIsNotNone(x_ref.deref()) 321 322 @test_util.run_in_graph_and_eager_modes 323 def testBitwiseAndNumeric(self): 324 x = constant_op.constant([0, 1, 3]) 325 y = constant_op.constant([1, 1, 1]) 326 327 z = x & y 328 329 self.assertAllEqual(z, [0, 1, 1]) 330 331 @test_util.run_in_graph_and_eager_modes 332 def testBitwiseAndBool(self): 333 x = constant_op.constant([False, False, True, True]) 334 y = constant_op.constant([False, True, False, True]) 335 336 z = x & y 337 338 self.assertAllEqual(z, [False, False, False, True]) 339 340 @test_util.run_in_graph_and_eager_modes 341 def testBitwiseAndErrors(self): 342 x_int = constant_op.constant(0) 343 x_bool = constant_op.constant(True) 344 345 if context.executing_eagerly(): # :( 346 expected_errtype = errors.InvalidArgumentError 347 else: 348 expected_errtype = TypeError 349 350 with self.assertRaises(expected_errtype): 351 _ = x_int & x_bool 352 with self.assertRaises(expected_errtype): 353 _ = x_int & constant_op.constant("a") 354 355 with self.assertRaises(expected_errtype): 356 _ = x_bool & x_int 357 with self.assertRaises(expected_errtype): 358 _ = x_bool & constant_op.constant("a") 359 360 with self.assertRaises(expected_errtype): 361 _ = constant_op.constant("a") & constant_op.constant("b") 362 363 @test_util.run_in_graph_and_eager_modes 364 def testBitwiseOrNumeric(self): 365 x = constant_op.constant([0, 1, 2]) 366 y = constant_op.constant([1, 1, 1]) 367 368 z = x | y 369 370 self.assertAllEqual(z, [1, 1, 3]) 371 372 @test_util.run_in_graph_and_eager_modes 373 def testBitwiseOrBool(self): 374 x = constant_op.constant([False, False, True, True]) 375 y = constant_op.constant([False, True, False, True]) 376 377 z = x | y 378 379 self.assertAllEqual(z, [False, True, True, True]) 380 381 @test_util.run_in_graph_and_eager_modes 382 def testBitwiseOrErrors(self): 383 x_int = constant_op.constant(0) 384 x_bool = constant_op.constant(True) 385 386 if context.executing_eagerly(): # :( 387 expected_errtype = errors.InvalidArgumentError 388 else: 389 expected_errtype = TypeError 390 391 with self.assertRaises(expected_errtype): 392 _ = x_int | x_bool 393 with self.assertRaises(expected_errtype): 394 _ = x_int | constant_op.constant("a") 395 396 with self.assertRaises(expected_errtype): 397 _ = x_bool | x_int 398 with self.assertRaises(expected_errtype): 399 _ = x_bool | constant_op.constant("a") 400 401 with self.assertRaises(expected_errtype): 402 _ = constant_op.constant("a") | constant_op.constant("b") 403 404 @test_util.run_in_graph_and_eager_modes 405 def testBitwiseXorNumeric(self): 406 x = constant_op.constant([0, 1, 3]) 407 y = constant_op.constant([1, 1, 1]) 408 409 z = x ^ y 410 411 self.assertAllEqual(z, [1, 0, 2]) 412 413 @test_util.run_in_graph_and_eager_modes 414 def testBitwiseXorBool(self): 415 x = constant_op.constant([False, False, True, True]) 416 y = constant_op.constant([False, True, False, True]) 417 418 z = x ^ y 419 420 self.assertAllEqual(z, [False, True, True, False]) 421 422 @test_util.run_in_graph_and_eager_modes 423 def testBitwiseXorErrors(self): 424 x_int = constant_op.constant(0) 425 x_bool = constant_op.constant(True) 426 427 if context.executing_eagerly(): # :( 428 
expected_errtype = errors.InvalidArgumentError 429 else: 430 expected_errtype = TypeError 431 432 with self.assertRaises(expected_errtype): 433 _ = x_int ^ x_bool 434 with self.assertRaises(expected_errtype): 435 _ = x_int ^ constant_op.constant("a") 436 437 with self.assertRaises(expected_errtype): 438 _ = x_bool ^ x_int 439 with self.assertRaises(expected_errtype): 440 _ = x_bool ^ constant_op.constant("a") 441 442 with self.assertRaises(expected_errtype): 443 _ = constant_op.constant("a") ^ constant_op.constant("b") 444 445 @test_util.run_in_graph_and_eager_modes 446 def testBitwiseNotNumeric(self): 447 x = constant_op.constant([0, dtypes.int32.min, 1]) 448 449 # pylint: disable=invalid-unary-operand-type 450 y = ~x 451 452 self.assertAllEqual(y, [-1, dtypes.int32.max, -2]) 453 454 @test_util.run_in_graph_and_eager_modes 455 def testBitwiseNotBool(self): 456 x = constant_op.constant([False, True]) 457 458 # pylint: disable=invalid-unary-operand-type 459 y = ~x 460 461 self.assertAllEqual(y, [True, False]) 462 463 @test_util.run_in_graph_and_eager_modes 464 def testBitwiseNotErrors(self): 465 if context.executing_eagerly(): # :( 466 expected_errtype = errors.InvalidArgumentError 467 else: 468 expected_errtype = TypeError 469 470 # pylint: disable=invalid-unary-operand-type 471 with self.assertRaises(expected_errtype): 472 _ = ~constant_op.constant("a") 473 474 475@test_util.run_all_in_graph_and_eager_modes 476class IndexedSlicesTest(test_util.TensorFlowTestCase): 477 478 def testToTensor(self): 479 values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) 480 indices = constant_op.constant([0, 2]) 481 x = indexed_slices.IndexedSlices(values, indices) 482 with self.assertRaises(ValueError): 483 tensor = ops.convert_to_tensor(x, name="tensor") 484 self.assertEqual(tensor_shape.TensorShape(None), x.shape) 485 486 dense_shape = constant_op.constant([3, 2]) 487 y = indexed_slices.IndexedSlices(values, indices, dense_shape) 488 tensor = ops.convert_to_tensor(y, name="tensor") 489 self.assertAllEqual(tensor.shape, y.shape) 490 self.assertAllEqual(self.evaluate(tensor), [[2, 3], [0, 0], [5, 7]]) 491 492 @test_util.run_gpu_only 493 def testEagerCopy(self): 494 with context.eager_mode(): 495 var = variables.Variable([[0.0], [0.0], [0.0], [0.0]], name="tensor") 496 with backprop.GradientTape() as tape: 497 a = array_ops.gather(array_ops.gather(var, [0, 1]), [0, 1]) 498 b = array_ops.gather(array_ops.gather(var, [2, 3]), [0, 1]) 499 r = special_math_ops.einsum("ij,ij->i", a, b) 500 g = tape.gradient(r, [var])[0] 501 values = g.values if isinstance(g, indexed_slices.IndexedSlices) else g 502 self.assertAllEqual(values.get_shape(), [4, 1]) 503 504 def testNegation(self): 505 values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) 506 indices = constant_op.constant([0, 2]) 507 x = -indexed_slices.IndexedSlices(values, indices) 508 self.assertAllEqual(x.values, [[-2, -3], [-5, -7]]) 509 self.assertAllEqual(x.indices, [0, 2]) 510 511 def testScalarMul(self): 512 values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) 513 indices = constant_op.constant([0, 2]) 514 x = math_ops.scalar_mul(-2, indexed_slices.IndexedSlices(values, indices)) 515 self.assertAllEqual(x.values, [[-4, -6], [-10, -14]]) 516 self.assertAllEqual(x.indices, [0, 2]) 517 518 519@test_util.run_all_in_graph_and_eager_modes 520class IndexedSlicesSpecTest(test_util.TensorFlowTestCase, 521 parameterized.TestCase): 522 523 def assertAllTensorsEqual(self, list1, list2): 524 self.assertLen(list1, len(list2)) 525 for (t1, t2) in zip(list1, 
    for (t1, t2) in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertIsNone(spec1._shape.rank)
    self.assertEqual(spec1._values_dtype, dtypes.float32)
    self.assertEqual(spec1._indices_dtype, dtypes.int64)
    self.assertIsNone(spec1._dense_shape_dtype)
    self.assertEqual(spec1._indices_shape.as_list(), [None])

    spec2 = indexed_slices.IndexedSlicesSpec([None, None], dtypes.string,
                                             dtypes.int32, dtypes.int64, [10])
    self.assertEqual(spec2._shape.as_list(), [None, None])
    self.assertEqual(spec2._values_dtype, dtypes.string)
    self.assertEqual(spec2._indices_dtype, dtypes.int32)
    self.assertEqual(spec2._dense_shape_dtype, dtypes.int64)
    self.assertEqual(spec2._indices_shape.as_list(), [10])

  def testValueType(self):
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertEqual(spec1.value_type, indexed_slices.IndexedSlices)

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(shape=[5, None, None]),
       (tensor_shape.TensorShape([5, None, None]), dtypes.float32,
        dtypes.int64, None, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.int32, dense_shape_dtype=dtypes.int64),
       (tensor_shape.TensorShape(None), dtypes.int32, dtypes.int64,
        dtypes.int64, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(indices_shape=[100]),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([100]))),
  ])  # pyformat: disable
  def testSerialize(self, spec, expected):
    serialization = spec._serialize()
    # TensorShape has an unconventional definition of equality, so we can't use
    # assertEqual directly here. But repr() is deterministic and lossless for
    # the expected values, so we can use that instead.
    self.assertEqual(repr(serialization), repr(expected))

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(dtype=dtypes.string), (
          tensor_spec.TensorSpec(None, dtypes.string),
          tensor_spec.TensorSpec([None], dtypes.int64),
      )),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.string, dense_shape_dtype=dtypes.int32), (
              tensor_spec.TensorSpec(None, dtypes.string),
              tensor_spec.TensorSpec([None], dtypes.int64),
              tensor_spec.TensorSpec([None], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32), (
              tensor_spec.TensorSpec([None, 10, 15], dtypes.float32),
              tensor_spec.TensorSpec([None], dtypes.int64),
              tensor_spec.TensorSpec([3], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32,
          indices_shape=[20]), (
              tensor_spec.TensorSpec([20, 10, 15], dtypes.float32),
              tensor_spec.TensorSpec([20], dtypes.int64),
              tensor_spec.TensorSpec([3], dtypes.int32),
          )),
  ])
  def testComponentSpecs(self, spec, expected):
    self.assertEqual(spec._component_specs, expected)

  @parameterized.parameters([
      {
          "spec": indexed_slices.IndexedSlicesSpec(),
          "values": [3.0, 5.0],
          "indices": [5, 10]
      },
      {
          "spec":
              indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32),
          "values": [3.0, 5.0],
          "indices": [5, 10],
          "dense_shape": [100]
      },
  ])
  def testToFromComponents(self, spec, indices, values, dense_shape=None):
    x = indexed_slices.IndexedSlices(indices, values, dense_shape)
    actual_components = spec._to_components(x)
    if dense_shape is None:
      self.assertAllTensorsEqual(actual_components, [indices, values])
    else:
      self.assertAllTensorsEqual(actual_components,
                                 [indices, values, dense_shape])
    st_reconstructed = spec._from_components(actual_components)
    self.assertAllEqual(x.indices, st_reconstructed.indices)
    self.assertAllEqual(x.values, st_reconstructed.values)
    if dense_shape is None:
      self.assertIsNone(st_reconstructed.dense_shape)
    else:
      self.assertAllEqual(x.dense_shape, st_reconstructed.dense_shape)

  @test_util.run_v1_only("IndexedSlicesValue is deprecated in v2")
  def testFromNumpyComponents(self):
    indices = np.array([3, 8])
    values = np.array([1.0, 9.0])
    dense_shape = np.array([100])

    spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32)
    st1 = spec1._from_components((values, indices, dense_shape))
    self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st1.indices, indices)
    self.assertAllEqual(st1.values, values)
    self.assertAllEqual(st1.dense_shape, dense_shape)

    spec2 = indexed_slices.IndexedSlicesSpec()
    st2 = spec2._from_components((values, indices))
    self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st2.indices, indices)
    self.assertAllEqual(st2.values, values)
    self.assertIsNone(st2.dense_shape)


class NodeDefConstructorTest(test_util.TensorFlowTestCase):

  def testNoArgs(self):
    nodedef = ops._NodeDef("None", "bar")
    self.assertProtoEquals("op: 'None' name: 'bar'", nodedef)


def _apply_op(g, *args, **kwargs):
  op = g.create_op(*args, **kwargs)
  if len(op.outputs) == 1:
    return op.outputs[0]
  else:
    return op.outputs


class OperationTest(test_util.TensorFlowTestCase):

  def testTraceback(self):
    g = ops.Graph()
    op1 = ops.Operation(
        ops._NodeDef("None", "op1"), g, [],
        [dtypes.float32_ref, dtypes.float32])
    self.assertIn("testTraceback", op1.traceback[-1])

  @test_util.run_deprecated_v1
  def testNoInputs(self):
    op = test_ops.float_output_string_output(name="myop").a.op
    self.assertEqual(2, len(op.values()))
    self.assertEqual(0, len(op.inputs))
    self.assertEqual("myop", op.name)

    float_t, label_str_t = op.values()
    self.assertEqual(dtypes.float32, float_t.dtype)
    self.assertEqual(op, float_t.op)
    self.assertEqual(0, float_t._value_index)
    self.assertEqual(0, len(float_t.consumers()))
    self.assertEqual("myop", float_t._as_node_def_input())

    self.assertEqual(dtypes.string, label_str_t.dtype)
    self.assertEqual(op, label_str_t.op)
    self.assertEqual(1, label_str_t._value_index)
    self.assertEqual(0, len(label_str_t.consumers()))
    self.assertEqual("myop:1", label_str_t._as_node_def_input())

    self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
                           op.node_def)

  @test_util.run_deprecated_v1
  def testNoOutputs(self):
    op1 = test_ops.float_output(name="myop1").op
    float_t, = op1.values()
    op2 = test_ops.float_input(float_t, name="myop2")
    self.assertEqual(0, len(op2.values()))
    self.assertEqual(1, len(op2.inputs))
    self.assertIs(float_t, op2.inputs[0])

    self.assertEqual(1, len(float_t.consumers()))
    self.assertEqual(op2, float_t.consumers()[0])

    self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def)
    self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
                           op2.node_def)

  @test_util.run_deprecated_v1
  def testInputsAndOutputs(self):
    op1 = test_ops.float_output(name="myop1").op
    self.assertEqual(1, len(op1.values()))
    float1_t, = op1.values()

    op2 = test_ops.float_output_string_output(name="myop2").a.op
    self.assertEqual(2, len(op2.values()))
    float2_t, label2_str_t = op2.values()

    # Note that we consume label2_str_t twice here.
    op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op
    self.assertEqual(2, len(op3.values()))

    self.assertEqual(1, len(float1_t.consumers()))
    self.assertEqual(op3, float1_t.consumers()[0])

    self.assertEqual(0, len(float2_t.consumers()))

    self.assertEqual(2, len(label2_str_t.consumers()))
    self.assertEqual(op3, label2_str_t.consumers()[0])
    self.assertEqual(op3, label2_str_t.consumers()[1])

    self.assertProtoEquals("""
    op:'Foo2' name:'myop3'
    input:'myop1' input:'myop2:1' input:'myop2:1'
    """, op3.node_def)

  def testDeviceObject(self):
    op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], [])
    op._set_device("/job:goo/device:GPU:0")
    self.assertProtoEquals(
        "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def)
    op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], [])
    op._set_device(
        pydev.DeviceSpec(
            job="muu", device_type="CPU", device_index=0))
    self.assertProtoEquals(
        "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def)

  def testReferenceInput(self):
    g = ops.Graph()
    op1 = ops.Operation(
        ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
        [dtypes.float32_ref, dtypes.float32])
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
    self.assertEqual([], list(op1.inputs))
    ref_t, nonref_t = op1.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    op2 = ops.Operation(
        ops._NodeDef("RefInputFloatInput", "op2"),
        g, [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32])
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        op2.node_def)
    self.assertEqual([ref_t, nonref_t], list(op2.inputs))
    op3 = ops.Operation(
        ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        op3.node_def)

  def testInvalidNames(self):
    g = ops.Graph()
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", ""), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "_invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "-invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "/invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "invalid:0"), g)

  @test_util.run_deprecated_v1
  def testNoShapeFunction(self):
    op = test_ops.a()
    self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedArray(self):
    values = [[2], [3], [5], [7]]
    tensor = ops.convert_to_tensor(values)
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(tensor))

  def testShapeTuple(self):
    with self.cached_session():
      c = constant_op.constant(1)
      self.assertEqual(c._shape_tuple(), ())  # pylint: disable=protected-access

  def testConvertToTensorEager(self):
    with context.eager_mode():
      t = constant_op.constant(1)
      self.assertTrue(isinstance(t, ops.EagerTensor))
      converted = ops.convert_to_tensor(t)
      self.assertTrue(isinstance(converted, ops.EagerTensor))
      converted = ops.convert_to_tensor(1)
      self.assertTrue(isinstance(converted, ops.EagerTensor))

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedTuple(self):
    values = ((2,), (3,), (5,), (7,))
    tensor = ops.convert_to_tensor(values)
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(ops.convert_to_tensor(values)))

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedTensors(self):
    values = ((2,), (3,), (5,), (7,))
    tensor = ops.convert_to_tensor(
        [constant_op.constant(row) for row in values])
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(tensor))
    tensor = ops.convert_to_tensor(
        [[constant_op.constant(v) for v in row] for row in values])
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(tensor))

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedMix(self):
    values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
    tensor = ops.convert_to_tensor(values)
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor))

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorPreferred(self):
    values = [2, 3, 5, 7]
    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
    self.assertEqual(dtypes.float32, tensor.dtype)

    # Convert empty tensor to anything.
    values = []
    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
    self.assertEqual(dtypes.int64, tensor.dtype)

    # The preferred dtype is a type error and will convert to
    # float32 instead.
    values = [1.23]
    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
    self.assertEqual(dtypes.float32, tensor.dtype)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToInvalidTensorType(self):
    with self.assertRaises(TypeError):
      # Forcing an invalid dtype should fail with a type error.
      values = [1.23]
      ops.convert_to_tensor(values, dtype=dtypes.int64)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToLongLongTensorType(self):
    tensor = ops.convert_to_tensor(
        # Get a numpy array of dtype NPY_LONGLONG
        np.prod(constant_op.constant([1])._shape_tuple()),
        dtype=dtypes.int64)
    self.assertEqual(dtypes.int64, tensor.dtype)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorFromInvalidTensor(self):
    tensor = constant_op.constant(42.0, dtype=dtypes.float32)
    with self.assertRaises(ValueError):
      ops.convert_to_tensor(tensor, dtype=dtypes.int32)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorProtocol(self):
    class TensorCompatible:

      def __tf_tensor__(self, dtype=None, name=None):
        return constant_op.constant((1, 2, 3), dtype=dtype, name=name)

    tc = TensorCompatible()

    tensor = ops.convert_to_tensor(tc, dtype=dtypes.int32)
    self.assertEqual(tensor.dtype, dtypes.int32)
    self.assertAllEqual((1, 2, 3), self.evaluate(tensor))

  @test_util.run_deprecated_v1
  def testNoConvert(self):
    # Operation cannot be converted to Tensor.
    op = control_flow_ops.no_op()
    with self.assertRaisesRegex(TypeError,
                                "can't convert Operation '.+' to Tensor"):
      ops.convert_to_tensor(op)

  def testStr(self):
    node_def = ops._NodeDef("None", "op1")
    op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
    self.assertEqual(str(node_def), str(op))

  def testRepr(self):
    op = ops.Operation(
        ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32])
    self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))

  @test_util.run_deprecated_v1
  def testGetAttr(self):
    op = test_ops.default_attrs()
    self.assertEqual(op.get_attr("string_val"), b"abc")
    self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
    self.assertEqual(op.get_attr("int_val"), 123)
    self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
    self.assertEqual(op.get_attr("float_val"), 10.0)
    self.assertEqual(op.get_attr("float_list_val"), [10.0])
    self.assertEqual(op.get_attr("bool_val"), True)
    self.assertEqual(op.get_attr("bool_list_val"), [True, False])
    self.assertEqual(op.get_attr("shape_val"),
                     tensor_shape.as_shape([2, 1]).as_proto())
    self.assertEqual(op.get_attr("shape_list_val"),
                     [tensor_shape.as_shape([]).as_proto(),
                      tensor_shape.as_shape([1]).as_proto()])
    self.assertEqual(op.get_attr("tensor_val"),
                     tensor_util.make_tensor_proto(1, dtypes.int32))
    self.assertEqual(op.get_attr("tensor_list_val"),
                     [tensor_util.make_tensor_proto(1, dtypes.int32)])

    type_val = op.get_attr("type_val")
    # First check that type_val is a DType, because the assertEqual will work
    # no matter what since DType overrides __eq__
    self.assertIsInstance(type_val, dtypes.DType)
    self.assertEqual(type_val, dtypes.int32)

    type_list_val = op.get_attr("type_list_val")
    self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
    self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])

    @function.Defun(dtypes.float32, func_name="MyFunc")
    def func(x):
      return x

    op = test_ops.func_attr(func)
    self.assertEqual(op.get_attr("f"),
                     attr_value_pb2.NameAttrList(name="MyFunc"))

    # Try fetching missing attr
    with self.assertRaisesRegex(
        ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."):
      op.get_attr("FakeAttr")

  # TODO(b/65162920): remove this test when users who are directly mutating the
  # node_def have been updated to proper usage.
  @test_util.run_deprecated_v1
  def testSetAttr(self):
    op = test_ops.int_attr().op
    op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
    # TODO(skyewm): add node_def check
    self.assertEqual(op.get_attr("foo"), 2)

  @test_util.run_v2_only
  def testSetFullType(self):
    @def_function.function
    def test_fn():
      ds = dataset_ops.Dataset.range(3)._variant_tensor

      ds.op.experimental_set_type(
          full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_PRODUCT))

      self.assertEqual(ds.op.node_def.experimental_type.type_id,
                       full_type_pb2.TFT_PRODUCT)

    test_fn()

  # TODO(nolivia): test all error cases
  def testAddControlInput(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1).op
      y = constant_op.constant(2).op
      z = constant_op.constant(3).op
      z._add_control_input(x)  # pylint: disable=protected-access
      self.assertEqual(z.control_inputs, [x])
      z._add_control_input(x)  # pylint: disable=protected-access
      self.assertEqual(z.control_inputs, [x])
      z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
      self.assertEqual(z.control_inputs, [x, y])
      self.assertEqual(x._control_outputs, [z])

  @test_util.run_deprecated_v1
  def testRemoveAllControlInputs(self):
    a = constant_op.constant(1)
    with ops.control_dependencies([a]):
      b = constant_op.constant(2)
    c = constant_op.constant(3)
    d = constant_op.constant(4)
    e = constant_op.constant(5)
    with ops.control_dependencies([a, c]):
      f = d + e

    self.assertEqual(a.op.control_inputs, [])
    self.assertEqual(b.op.control_inputs, [a.op])
    self.assertEqual(f.op.control_inputs, [a.op, c.op])

    a.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(a.op.control_inputs, [])

    b.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(b.op.control_inputs, [])

    f.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(f.op.control_inputs, [])
    self.assertEqual(list(f.op.inputs), [d, e])

  @test_util.run_deprecated_v1
  def testControlInputCycle(self):
    graph = ops.Graph()
    with graph.as_default():
      z = constant_op.constant(0)
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      y.op._add_control_input(z.op)  # pylint: disable=protected-access
      y.op._add_control_input(x.op)  # pylint: disable=protected-access
      x.op._add_control_input(y.op)  # pylint: disable=protected-access
    with self.session(graph=graph) as sess:
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          "Graph is invalid, contains a cycle with 2 nodes"):
        self.evaluate(x)

  def testUpdateInput(self):
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      z = x + y

    z.op._update_input(0, y)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [y, y])
    self.assertEqual(x.consumers(), [])
    self.assertEqual(y.consumers(), [z.op, z.op])
    with session.Session(graph=g) as sess:
      self.assertEqual(self.evaluate(z), 4)

    z.op._update_input(0, x)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [x, y])
    self.assertEqual(x.consumers(), [z.op])
    self.assertEqual(y.consumers(), [z.op])
    with session.Session(graph=g) as sess:
      self.assertEqual(self.evaluate(z), 3)

    z.op._update_input(1, y)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [x, y])
    self.assertEqual(x.consumers(), [z.op])
    self.assertEqual(y.consumers(), [z.op])
    with session.Session(graph=g) as sess:
      self.assertEqual(self.evaluate(z), 3)

  def testUpdateInputGraphError(self):
    g_0 = ops.Graph()
    g_1 = ops.Graph()
    with g_0.as_default():
      x = constant_op.constant(1)
    with g_1.as_default():
      y = constant_op.constant(2)
      z = y * 2
      with self.assertRaisesRegex(ValueError, "must be from the same graph"):
        z.op._update_input(0, x)  # pylint: disable=protected-access

  def testUpdateInputTypeError(self):
    g = ops.Graph()
    with g.as_default():
      w = constant_op.constant(0)
      x = constant_op.constant("")
      y = constant_op.constant(1)
      z = y + w
      z.op._update_input(0, x)  # pylint: disable=protected-access
    with session.Session(graph=g) as sess:
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          "Input 0 of node add was passed string from Const_1:0 incompatible "
          "with expected int32"):
        self.evaluate(z)

  def testUpdateInputShapeError(self):
    g = ops.Graph()
    with g.as_default():
      w = constant_op.constant(2, shape=[3, 1])
      x = constant_op.constant(0, shape=[3, 1])
      y = constant_op.constant(1, shape=[2, 2])
      z = w + x
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
      z.op._update_input(0, y)  # pylint: disable=protected-access

  def testUpdateInputOutOfRange(self):
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant(1)
    with self.assertRaisesRegex(
        errors.OutOfRangeError,
        r"Cannot update edge. Input index \[1\] is greater than the number of "
        r"total inputs \[0\]."):
      x.op._update_input(1, x)  # pylint: disable=protected-access

  @test_util.enable_control_flow_v2
  @test_util.run_v1_only("b/120545219")
  def testAddWhileInput(self):

    @eager_function.defun
    def test():
      output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
                                           [1])
      while_op = output.op
      self.assertEqual(while_op.type, "StatelessWhile")
      orig_num_inputs = len(while_op.inputs)

      # Make sure we can handle the while op having a control input.
      while_op._add_control_input(constant_op.constant(0).op)

      new_input1 = constant_op.constant(1.0)
      new_input2 = constant_op.constant(True)

      # Clear output shapes to bypass shape checking.
      while_op._set_shape_list_attr("output_shapes", [])
      while_op._set_type_list_attr("T", [t.dtype for t in while_op.inputs] +
                                   [new_input1.dtype, new_input2.dtype])

      while_op._add_while_inputs([new_input1, new_input2])
      # Can't add an edge beyond what's specified by "T"
      with self.assertRaises(errors.OutOfRangeError):
        while_op._add_while_inputs([new_input2])
      self.assertLen(while_op.inputs, orig_num_inputs + 2)  # pylint: disable=g-deprecated-assert

    test()

  @test_util.run_deprecated_v1
  def testOpDef(self):
    x = constant_op.constant(0)
    y = constant_op.constant(1)
    z = x + y

    self.assertEqual(x.op.op_def.name, "Const")
    self.assertLen(x.op.op_def.input_arg, 0)
    self.assertLen(x.op.op_def.output_arg, 1)

    self.assertRegex(z.op.op_def.name, "Add(V2)?")
    self.assertLen(z.op.op_def.input_arg, 2)
    self.assertLen(z.op.op_def.output_arg, 1)

  def testInputFromDifferentGraphError(self):
    g_0 = ops.Graph()
    g_1 = ops.Graph()
    with g_0.as_default():
      x = constant_op.constant(1)
    with g_1.as_default():
      y = constant_op.constant(2)
      with self.assertRaisesRegex(ValueError, "must be from the same graph"):
        y * x  # pylint: disable=pointless-statement

  def testInputsAreImmutable(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      op = test_ops.int_input_int_output(x, name="myop").op
    with self.assertRaisesRegex(AttributeError,
                                "'tuple' object has no attribute 'append'"):
      op.inputs.append(None)


class CreateOpTest(test_util.TensorFlowTestCase):

  def testNodeDefArgs(self):
    g = ops.Graph()
    op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
    with g.device("/device:GPU:0"):
      op2 = g.create_op(
          "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None,
          name="myop2")
    op3 = g.create_op(
        "Foo3",
        [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]],
        [dtypes.float32, dtypes.int32],
        None,
        name="myop3")
    self.assertDeviceEqual(None, op1.device)
    self.assertDeviceEqual("/device:GPU:0", op2.device)
    self.assertDeviceEqual(None, op3.device)
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
        op2.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
        op3.node_def)

  def testReferenceInput(self):
    g = ops.Graph()
    op1 = g.create_op(
        "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
    ref_t, nonref_t = op1.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    op2 = g.create_op(
        "RefInputFloatInput", [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        op2.node_def)
    op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        op3.node_def)

  def testFinalized(self):
    g = ops.Graph()
    g.finalize()
    with self.assertRaises(RuntimeError):
      g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")

    # Test unfinalize.
    g._unsafe_unfinalize()
    g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")


# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
# method. Arguably we should only test the public APIs that depend on this
# method. However, this logic is complex and tricky, and it can be difficult to
# ascertain if we have adequate coverage (e.g. a graph may run successfully if
# the control flow context isn't set properly, but a more complicated use case
# that might not be obvious to test will fail). Thus we instead explicitly test
# the low-level behavior.
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):

  @test_util.run_deprecated_v1
  def testBasic(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c_op = ops._create_c_op(
          g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
      op = g._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "IntInputIntOutput")
    self.assertLen(op.outputs, 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape())
    self.assertEqual(list(op.inputs), [x])
    self.assertEqual(op.control_inputs, [])
    self.assertEqual(op.graph, g)
    self.assertEqual(x.consumers(), [op])
    self.assertIsNotNone(op.traceback)
    self.assertIn("testBasic", op.traceback[-1])
    self.assertEqual(g.get_operation_by_name("myop"), op)
    self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0])

  def testShape(self):
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], [])
      op = g._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "Identity")
    self.assertLen(op.outputs, 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.TensorShape([2, 3]))

  def testUniqueName(self):
    g = ops.Graph()
    with g.as_default():
      c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
      c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
      op = g._create_op_from_tf_operation(c_op)
      op2 = g._create_op_from_tf_operation(c_op2)

      # Create ops with same names as op1 and op2. We expect the new names to be
      # uniquified.
      op3 = test_ops.int_output(name="myop").op
      op4 = test_ops.int_output(name="myop_1").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op2.name, "myop_1")
    self.assertEqual(op3.name, "myop_2")
    self.assertEqual(op4.name, "myop_1_1")

  @test_util.run_v1_only("b/120545219")
  def testCond(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def true_fn():
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "cond/myop"), [x], [])
        new_ops = g._add_new_tf_operations()
        self.assertLen(new_ops, 1)
        return x

      control_flow_ops.cond(x < 10, true_fn, lambda: x)

    op = g.get_operation_by_name("cond/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "cond/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Switch")
    self.assertEqual(op_input.inputs[0], x)
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "cond/cond_text")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoop(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        new_ops = g._add_new_tf_operations()
        self.assertLen(new_ops, 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "myloop/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Enter")
    self.assertEqual(list(op_input.inputs), [x])
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "myloop/while_context")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithInternalControlDep(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        c = constant_op.constant(1.0, name="c")
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        with ops.control_dependencies([c]):
          new_ops = g._add_new_tf_operations()
          self.assertLen(new_ops, 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    c = g.get_operation_by_name("myloop/c")
    self.assertIsNotNone(c)
    # Internal control dep is preserved
    self.assertEqual(op.control_inputs, [c])

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithExternalControlDep(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c = constant_op.constant(1.0)

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        with ops.control_dependencies([c]):
          new_ops = g._add_new_tf_operations()
          self.assertLen(new_ops, 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    # External control dep is removed and replaced with internal control dep
    self.assertNotEqual(op.control_inputs[0], c.op)
    self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())


class ApplyOpTest(test_util.TensorFlowTestCase):

  def testNodeDefArgs(self):
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    with g.device("/device:GPU:0"):
      t2 = _apply_op(
          g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
    t3 = _apply_op(
        g,
        "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
        name="myop3")
    self.assertTrue(isinstance(t1, ops.Tensor))
    self.assertTrue(isinstance(t2, list))
    self.assertTrue(isinstance(t3, list))
    self.assertTrue(isinstance(t3[0], ops.Tensor))
    self.assertEqual("myop1", t1._as_node_def_input())
    self.assertEqual("myop2", t2[0]._as_node_def_input())
    self.assertEqual("myop2:1", t2[1]._as_node_def_input())
    self.assertEqual("myop3", t3[0]._as_node_def_input())
    # Validate that we got the right ops as well
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
        t2[0].op.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
        t3[0].op.node_def)

  def testReferenceInput(self):
    g = ops.Graph()
    ref_t, nonref_t = _apply_op(
        g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           ref_t.op.node_def)
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    out_2 = _apply_op(
        g,
        "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
        out_2.op.node_def)
    out_3 = _apply_op(
        g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
        name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
        out_3.op.node_def)


class NameStackTest(test_util.TensorFlowTestCase):

  def testBasics(self):
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_1", g.unique_name("foo"))
    self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_2", g.unique_name("foo"))
    self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_1", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_2", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo", g.unique_name("foo"))
      self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo_1", g.unique_name("foo"))
      with g.name_scope(None):
        self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False))
        self.assertEqual("foo_3", g.unique_name("foo"))
      with g.name_scope("baz"):
        self.assertEqual(
            "bar/baz/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz/foo", g.unique_name("foo"))
        self.assertEqual(
            "bar/baz/foo_1", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz/foo_1", g.unique_name("foo"))
      with g.name_scope("baz"):
        self.assertEqual(
            "bar/baz_1/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz_1/foo", g.unique_name("foo"))
        self.assertEqual(
            "bar/baz_1/foo_1", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo"))
    with g.name_scope("quux"):
      self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("quux/foo", g.unique_name("foo"))
    with g.name_scope("bar"):
      with g.name_scope("baz"):
        self.assertEqual(
            "bar_1/baz/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar_1/baz/foo", g.unique_name("foo"))
    self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_4", g.unique_name("foo"))
    self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False))
    self.assertEqual("bar_2", g.unique_name("bar"))

  def testBackslashAndDashRegex(self):
    # GitHub issue 39019, all should pass
    g = ops.Graph()
    with g.name_scope("n_CatCntc-campaign\\c_campaign"):
      pass
    with g.name_scope("foo"):
      with g.name_scope("n_CatCntc-campaign\\c_campaign"):
        pass
    with g.name_scope("n_CatCntc-campaign\\c_campaign"):
      with g.name_scope("foo"):
        pass

  @test_util.run_deprecated_v1
  def testNameAndVariableScope(self):
    with self.cached_session() as sess:
      with sess.graph.name_scope("l0"):
        with variable_scope.variable_scope("l1"):
          with sess.graph.name_scope("l1") as scope:
            self.assertEqual("l0/l1/l1/", scope)
            self.assertEqual(
                "l0/l1/l1/foo",
                sess.graph.unique_name(
                    "foo", mark_as_used=False))
            self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo"))
          with sess.graph.name_scope("l2") as scope:
            self.assertEqual("l0/l1/l2/", scope)
            self.assertEqual(
                "l0/l1/l2/foo",
                sess.graph.unique_name(
                    "foo", mark_as_used=False))
            self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo"))

  def testOutOfOrderUniqueName(self):
    g = ops.Graph()
    self.assertEqual("foo_2", g.unique_name("foo_2"))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo"))
    self.assertEqual("foo_3", g.unique_name("foo"))

  def testUniqueNameCaseInsensitivity(self):
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("Foo_1", g.unique_name("Foo"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo"))
    with g.name_scope("Bar"):
      self.assertEqual("Bar_1/foo", g.unique_name("foo"))

  def testInvalidNameRaisesError(self):
    g = ops.Graph()
    with g.name_scope(""):  # Should not raise
      pass
    with g.name_scope("foo/"):  # Should not raise
      with g.name_scope("_bar"):  # Should not raise
        pass
    with self.assertRaises(ValueError):
      with g.name_scope("foo:0"):
        pass
    with self.assertRaises(ValueError):
      with g.name_scope("_bar"):
        pass

  def testEmptyScopeEdgeCases(self):
    g = ops.Graph()
    self.assertEqual("", g.get_name_scope())
    with g.name_scope("") as scope:
      self.assertEqual("", scope)
      self.assertEqual("", g.get_name_scope())
    with g.name_scope(None) as scope:
      self.assertEqual("", scope)
      self.assertEqual("", g.get_name_scope())
    with g.name_scope("foo") as scope:
      self.assertEqual("foo/", scope)
      self.assertEqual("foo", g.get_name_scope())
      with g.name_scope("") as scope:
        self.assertEqual("", scope)
        self.assertEqual("", g.get_name_scope())
      with g.name_scope(None) as scope:
        self.assertEqual("", scope)
        self.assertEqual("", g.get_name_scope())


class NameTest(test_util.TensorFlowTestCase):

  def testGenerateName(self):
    g = ops.Graph()
    op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    self.assertEqual("TwoFloatOutputs", op0.name)
    self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name)
    self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name)

    op1 = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertEqual("FloatOutput", op1.name)
    self.assertEqual("FloatOutput:0", op1.outputs[0].name)

    op2 = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertEqual("FloatOutput_1", op2.name)
    self.assertEqual("FloatOutput_1:0", op2.outputs[0].name)

    op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op")
    self.assertEqual("my_op", op3.name)
    self.assertEqual("my_op:0", op3.outputs[0].name)

  def testNameScope(self):
    g = ops.Graph()

g.name_scope("foo") as foo: 1615 self.assertEqual("foo/", foo) 1616 with g.name_scope("foo2") as foo2: 1617 self.assertEqual("foo/foo2/", foo2) 1618 with g.name_scope(None) as empty1: 1619 self.assertEqual("", empty1) 1620 with g.name_scope("foo3") as foo3: 1621 self.assertEqual("foo3/", foo3) 1622 with g.name_scope("") as empty2: 1623 self.assertEqual("", empty2) 1624 1625 self.assertEqual("FloatOutput", 1626 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1627 with g.name_scope("bar") as scope: 1628 self.assertEqual("bar/FloatOutput", 1629 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1630 self.assertEqual("bar/FloatOutput_1", 1631 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1632 # If you use the value from "with .. as", that values is used as-is. 1633 self.assertEqual( 1634 "bar", g.create_op( 1635 "FloatOutput", [], [dtypes.float32], name=scope).name) 1636 with g.name_scope("baz") as scope: 1637 with g.name_scope("quux"): 1638 self.assertEqual("baz/quux/FloatOutput", 1639 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1640 # If you use the value from the enclosing "with .. as", nothing is pushed. 1641 with g.name_scope(scope): 1642 self.assertEqual("baz/FloatOutput", 1643 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1644 self.assertEqual( 1645 "baz", g.create_op( 1646 "FloatOutput", [], [dtypes.float32], name=scope).name) 1647 self.assertEqual( 1648 "trailing", 1649 g.create_op( 1650 "FloatOutput", [], [dtypes.float32], name="trailing/").name) 1651 with g.name_scope("bar"): 1652 self.assertEqual("bar_1/FloatOutput", 1653 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1654 with g.name_scope("bar/"): 1655 self.assertEqual("bar/FloatOutput_2", 1656 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1657 1658 1659class DeviceTest(test_util.TensorFlowTestCase): 1660 1661 def testNoDevice(self): 1662 g = ops.Graph() 1663 op = g.create_op("FloatOutput", [], [dtypes.float32]) 1664 self.assertDeviceEqual(None, op.device) 1665 gd = g.as_graph_def() 1666 self.assertProtoEqualsVersion(""" 1667 node { name: "FloatOutput" op: "FloatOutput" } 1668 """, gd) 1669 1670 def testEagerBackingDevice(self): 1671 with context.eager_mode(): 1672 with ops.device("/device:CPU:0"): 1673 t = constant_op.constant(1.0) 1674 self.assertRegex(t.device, "/device:CPU:0") 1675 self.assertRegex(t.backing_device, "/device:CPU:0") 1676 1677 def testDevicePartialString(self): 1678 g = ops.Graph() 1679 with g.device("/job:worker/replica:2"): 1680 g.create_op("FloatOutput", [], [dtypes.float32]) 1681 gd = g.as_graph_def() 1682 self.assertProtoEqualsVersion(""" 1683 node { name: "FloatOutput" op: "FloatOutput" 1684 device: "/job:worker/replica:2" } 1685 """, gd) 1686 1687 def testDeviceFull(self): 1688 g = ops.Graph() 1689 with g.device( 1690 pydev.DeviceSpec( 1691 job="worker", replica=2, task=0, device_type="CPU", 1692 device_index=3)): 1693 g.create_op("FloatOutput", [], [dtypes.float32]) 1694 gd = g.as_graph_def() 1695 self.assertProtoEqualsVersion(""" 1696 node { name: "FloatOutput" op: "FloatOutput" 1697 device: "/job:worker/replica:2/task:0/device:CPU:3" } 1698 """, gd) 1699 1700 def testNesting(self): 1701 g = ops.Graph() 1702 with g.device("/job:worker/replica:2"): 1703 g.create_op("FloatOutput", [], [dtypes.float32]) 1704 with g.device("/job:worker/replica:3/task:0"): 1705 g.create_op("FloatOutput", [], [dtypes.float32]) 1706 g.create_op("FloatOutput", [], [dtypes.float32]) 1707 gd = g.as_graph_def() 1708 self.assertProtoEqualsVersion(""" 1709 node { name: 
"FloatOutput" op: "FloatOutput" 1710 device: "/job:worker/replica:2" } 1711 node { name: "FloatOutput_1" op: "FloatOutput" 1712 device: "/job:worker/replica:3/task:0" } 1713 node { name: "FloatOutput_2" op: "FloatOutput" 1714 device: "/job:worker/replica:2" } 1715 """, gd) 1716 1717 def testNestingString(self): 1718 g = ops.Graph() 1719 with g.device("/job:worker/replica:2"): 1720 g.create_op("FloatOutput", [], [dtypes.float32]) 1721 with g.device("/job:worker/replica:3/task:0"): 1722 g.create_op("FloatOutput", [], [dtypes.float32]) 1723 g.create_op("FloatOutput", [], [dtypes.float32]) 1724 gd = g.as_graph_def() 1725 self.assertProtoEqualsVersion(""" 1726 node { name: "FloatOutput" op: "FloatOutput" 1727 device: "/job:worker/replica:2" } 1728 node { name: "FloatOutput_1" op: "FloatOutput" 1729 device: "/job:worker/replica:3/task:0" } 1730 node { name: "FloatOutput_2" op: "FloatOutput" 1731 device: "/job:worker/replica:2" } 1732 """, gd) 1733 1734 def testNestingOverrideGpuCpu(self): 1735 g = ops.Graph() 1736 with g.device("/job:worker/replica:2/device:CPU:1"): 1737 g.create_op("FloatOutput", [], [dtypes.float32]) 1738 with g.device("/job:worker/replica:2/device:GPU:2"): 1739 g.create_op("FloatOutput", [], [dtypes.float32]) 1740 g.create_op("FloatOutput", [], [dtypes.float32]) 1741 gd = g.as_graph_def() 1742 self.assertProtoEqualsVersion(""" 1743 node { name: "FloatOutput" op: "FloatOutput" 1744 device: "/job:worker/replica:2/device:CPU:1" } 1745 node { name: "FloatOutput_1" op: "FloatOutput" 1746 device: "/job:worker/replica:2/device:GPU:2" } 1747 node { name: "FloatOutput_2" op: "FloatOutput" 1748 device: "/job:worker/replica:2/device:CPU:1" } 1749 """, gd) 1750 1751 def testNestingWithMergeDeviceFunction(self): 1752 g = ops.Graph() 1753 1754 with g.device(pydev.merge_device("/device:GPU:0")): 1755 g.create_op("FloatOutput", [], [dtypes.float32]) 1756 with g.device(pydev.merge_device("/job:worker")): 1757 g.create_op("FloatOutput", [], [dtypes.float32]) 1758 with g.device(pydev.merge_device("/device:CPU:0")): 1759 g.create_op("FloatOutput", [], [dtypes.float32]) 1760 with g.device(pydev.merge_device("/job:ps")): 1761 g.create_op("FloatOutput", [], [dtypes.float32]) 1762 with g.device(pydev.merge_device(None)): 1763 g.create_op("FloatOutput", [], [dtypes.float32]) 1764 1765 gd = g.as_graph_def() 1766 self.assertProtoEqualsVersion(""" 1767 node { name: "FloatOutput" op: "FloatOutput" 1768 device: "/device:GPU:0" } 1769 node { name: "FloatOutput_1" op: "FloatOutput" 1770 device: "/job:worker/device:GPU:0" } 1771 node { name: "FloatOutput_2" op: "FloatOutput" 1772 device: "/job:worker/device:CPU:0" } 1773 node { name: "FloatOutput_3" op: "FloatOutput" 1774 device: "/job:ps/device:CPU:0" } 1775 node { name: "FloatOutput_4" op: "FloatOutput" 1776 device: "/job:ps/device:CPU:0" } 1777 """, gd) 1778 1779 def testNestingWithDeviceStrings(self): 1780 g = ops.Graph() 1781 1782 with g.device("/device:GPU:0"): 1783 g.create_op("FloatOutput", [], [dtypes.float32]) 1784 with g.device("/job:worker"): 1785 g.create_op("FloatOutput", [], [dtypes.float32]) 1786 with g.device("/device:CPU:0"): 1787 g.create_op("FloatOutput", [], [dtypes.float32]) 1788 with g.device("/job:ps"): 1789 g.create_op("FloatOutput", [], [dtypes.float32]) 1790 with g.device(""): 1791 g.create_op("FloatOutput", [], [dtypes.float32]) 1792 1793 gd = g.as_graph_def() 1794 self.assertProtoEqualsVersion(""" 1795 node { name: "FloatOutput" op: "FloatOutput" 1796 device: "/device:GPU:0" } 1797 node { name: "FloatOutput_1" op: "FloatOutput" 
1798 device: "/job:worker/device:GPU:0" } 1799 node { name: "FloatOutput_2" op: "FloatOutput" 1800 device: "/job:worker/device:CPU:0" } 1801 node { name: "FloatOutput_3" op: "FloatOutput" 1802 device: "/job:ps/device:CPU:0" } 1803 node { name: "FloatOutput_4" op: "FloatOutput" 1804 device: "/job:ps/device:CPU:0" } 1805 """, gd) 1806 1807 def testNestingWithDeviceStringWildcard(self): 1808 g = ops.Graph() 1809 1810 with g.device("/device:GPU:7"): 1811 g.create_op("FloatOutput", [], [dtypes.float32]) 1812 with g.device("/device:GPU:*"): 1813 g.create_op("FloatOutput", [], [dtypes.float32]) 1814 1815 with g.device("/device:CPU:*"): 1816 g.create_op("FloatOutput", [], [dtypes.float32]) 1817 with g.device("/device:CPU:5"): 1818 g.create_op("FloatOutput", [], [dtypes.float32]) 1819 1820 gd = g.as_graph_def() 1821 self.assertProtoEqualsVersion(""" 1822 node { name: "FloatOutput" op: "FloatOutput" 1823 device: "/device:GPU:7" } 1824 node { name: "FloatOutput_1" op: "FloatOutput" 1825 device: "/device:GPU:7" } 1826 node { name: "FloatOutput_2" op: "FloatOutput" 1827 device: "/device:CPU:*" } 1828 node { name: "FloatOutput_3" op: "FloatOutput" 1829 device: "/device:CPU:5" } 1830 """, gd) 1831 1832 def testNestingErrorGraph(self): 1833 g = ops.Graph() 1834 scope = g.device("/device:GPU:8") 1835 scope.__enter__() 1836 with g.device("/device:GPU:9"): 1837 with self.assertRaises(RuntimeError): 1838 scope.__exit__(None, None, None) 1839 1840 def testNestingErrorEager(self): 1841 with context.eager_mode(): 1842 scope = ops.device("/device:CPU:0") 1843 scope.__enter__() 1844 with ops.device(None): 1845 with self.assertRaises(RuntimeError): 1846 scope.__exit__(None, None, None) 1847 1848 def testNoneClearsDefault(self): 1849 g = ops.Graph() 1850 with g.device("/job:worker/replica:2/device:CPU:1"): 1851 g.create_op("FloatOutput", [], [dtypes.float32]) 1852 with g.device(None): 1853 g.create_op("FloatOutput", [], [dtypes.float32]) 1854 g.create_op("FloatOutput", [], [dtypes.float32]) 1855 gd = g.as_graph_def() 1856 self.assertProtoEqualsVersion(""" 1857 node { name: "FloatOutput" op: "FloatOutput" 1858 device: "/job:worker/replica:2/device:CPU:1" } 1859 node { name: "FloatOutput_1" op: "FloatOutput" } 1860 node { name: "FloatOutput_2" op: "FloatOutput" 1861 device: "/job:worker/replica:2/device:CPU:1" } 1862 """, gd) 1863 1864 def testNoneIgnoresOuterDeviceFunction(self): 1865 g = ops.Graph() 1866 with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"): 1867 g.create_op("FloatOutput", [], [dtypes.float32]) 1868 with g.device(None): 1869 g.create_op("FloatOutput", [], [dtypes.float32]) 1870 g.create_op("FloatOutput", [], [dtypes.float32]) 1871 gd = g.as_graph_def() 1872 self.assertProtoEqualsVersion(""" 1873 node { name: "FloatOutput" op: "FloatOutput" 1874 device: "/job:worker/replica:2/device:CPU:1" } 1875 node { name: "FloatOutput_1" op: "FloatOutput" } 1876 node { name: "FloatOutput_2" op: "FloatOutput" 1877 device: "/job:worker/replica:2/device:CPU:1" } 1878 """, gd) 1879 1880 def _overwritingDeviceFunction(self, unused_op): 1881 # This device function unconditionally overwrites the device of ops. 1882 # 1883 # NOTE(mrry): Writing device functions like this is not 1884 # recommended. Instead, in most cases you should use 1885 # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the 1886 # argument to `tf.device()` and the device component will be merged in. 
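    # A minimal sketch of the recommended alternative (not executed here, and
    # assuming the same `g` / `pydev` names used throughout this file):
    #
    #   with g.device("/job:worker/device:GPU:0"):
    #     with g.device(pydev.merge_device("/job:ps")):
    #       g.create_op("FloatOutput", [], [dtypes.float32])
    #
    # `merge_device` only overrides the fields it specifies, so the op above
    # would land on "/job:ps/device:GPU:0" (as testNestingWithMergeDeviceFunction
    # above demonstrates). A device function that unconditionally returns a
    # string, like this one, discards the enclosing assignment instead of
    # merging with it.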
1887 return "/job:overwrite" 1888 1889 def testOverwritingBehavior(self): 1890 g = ops.Graph() 1891 with g.device(self._overwritingDeviceFunction): 1892 g.create_op("FloatOutput", [], [dtypes.float32]) 1893 with g.device("/job:ps"): # Will be overwritten. 1894 g.create_op("FloatOutput", [], [dtypes.float32]) 1895 with g.device(pydev.merge_device("/job:ps")): # Will be overwritten. 1896 g.create_op("FloatOutput", [], [dtypes.float32]) 1897 with g.device(None): # Disables overwriting device function 1898 with g.device("/job:ps"): 1899 g.create_op("FloatOutput", [], [dtypes.float32]) 1900 with g.device(None): # Disables overwriting device function 1901 with g.device(pydev.merge_device("/job:ps")): 1902 g.create_op("FloatOutput", [], [dtypes.float32]) 1903 gd = g.as_graph_def() 1904 self.assertProtoEqualsVersion(""" 1905 node { name: "FloatOutput" op: "FloatOutput" 1906 device: "/job:overwrite" } 1907 node { name: "FloatOutput_1" op: "FloatOutput" 1908 device: "/job:overwrite" } 1909 node { name: "FloatOutput_2" op: "FloatOutput" 1910 device: "/job:overwrite" } 1911 node { name: "FloatOutput_3" op: "FloatOutput" 1912 device: "/job:ps" } 1913 node { name: "FloatOutput_4" op: "FloatOutput" 1914 device: "/job:ps" } 1915 """, gd) 1916 1917 1918class MultithreadedGraphStateTest(test_util.TensorFlowTestCase): 1919 1920 class TestThread(threading.Thread): 1921 1922 def __init__(self, graph, replica_id): 1923 super(MultithreadedGraphStateTest.TestThread, self).__init__() 1924 self._graph = graph 1925 self._replica_id = replica_id 1926 # This thread sets this event when it mutated the graph. The caller can 1927 # wait for that. 1928 self.has_mutated_graph = threading.Event() 1929 # This thread waits for when it should continue. The caller can set this 1930 # event. 1931 self.should_continue = threading.Event() 1932 1933 def run(self): 1934 # Mutate a graph's stack, then set `has_mutated_graph`, then wait for 1935 # `should_continue`, then add an op to the graph affected by the graph's 1936 # stack. 1937 raise NotImplementedError("must be implemented in descendants") 1938 1939 def testDeviceFunctionStack(self): 1940 1941 class DeviceSettingThread(self.TestThread): 1942 1943 def run(self): 1944 with g.device("/job:worker/replica:{}".format(self._replica_id)): 1945 self.has_mutated_graph.set() 1946 self.should_continue.wait() 1947 self.should_continue.clear() 1948 g.create_op( 1949 "FloatOutput", [], [dtypes.float32], 1950 name="FloatOutput_{}".format(self._replica_id)) 1951 1952 g = ops.Graph() 1953 # If `switch_to_thread` isn't called, then device placement of the ops 1954 # below is not deterministic. 
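    # `Graph.switch_to_thread_local()` makes the graph's device, colocation,
    # and control-dependency stacks thread-local, so a `with g.device(...)`
    # entered on one thread no longer affects ops created concurrently by
    # other threads. With the default shared stacks, the device recorded for
    # each `FloatOutput_{i}` below would depend on thread scheduling rather
    # than on the thread that created the op.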
1955 g.switch_to_thread_local() 1956 threads = [DeviceSettingThread(g, i) for i in range(3)] 1957 for t in threads: 1958 t.start() 1959 t.has_mutated_graph.wait() 1960 t.has_mutated_graph.clear() 1961 for t in threads: 1962 t.should_continue.set() 1963 t.join() 1964 1965 gd = g.as_graph_def() 1966 self.assertProtoEqualsVersion(""" 1967 node { name: "FloatOutput_0" op: "FloatOutput" 1968 device: "/job:worker/replica:0" } 1969 node { name: "FloatOutput_1" op: "FloatOutput" 1970 device: "/job:worker/replica:1" } 1971 node { name: "FloatOutput_2" op: "FloatOutput" 1972 device: "/job:worker/replica:2" } 1973 """, gd) 1974 1975 def testColocateWith(self): 1976 1977 class ColocatingThread(self.TestThread): 1978 1979 def __init__(self, graph, replica_id, op_to_colocate_with): 1980 super(ColocatingThread, self).__init__(graph, replica_id) 1981 self._op_to_colocate_with = op_to_colocate_with 1982 1983 def run(self): 1984 with g.colocate_with(self._op_to_colocate_with): 1985 self.has_mutated_graph.set() 1986 self.should_continue.wait() 1987 self.should_continue.clear() 1988 g.create_op( 1989 "FloatOutput", [], [dtypes.float32], 1990 name="FloatOutput_{}".format(self._replica_id)) 1991 1992 g = ops.Graph() 1993 ops_to_colocate_with = [] 1994 for i in range(3): 1995 with g.device("/job:worker/replica:{}".format(i)): 1996 ops_to_colocate_with.append( 1997 g.create_op( 1998 "FloatOutput", [], [dtypes.float32], 1999 name="ColocateWithMe_{}".format(i))) 2000 2001 # If `switch_to_thread` isn't called, then `device` and `attr` values for 2002 # the ops below are not deterministic. 2003 g.switch_to_thread_local() 2004 threads = [ 2005 ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3) 2006 ] 2007 for t in threads: 2008 t.start() 2009 t.has_mutated_graph.wait() 2010 t.has_mutated_graph.clear() 2011 for t in threads: 2012 t.should_continue.set() 2013 t.join() 2014 2015 gd = g.as_graph_def() 2016 self.assertProtoEqualsVersion(""" 2017 node { name: "ColocateWithMe_0" op: "FloatOutput" 2018 device: "/job:worker/replica:0" } 2019 node { name: "ColocateWithMe_1" op: "FloatOutput" 2020 device: "/job:worker/replica:1" } 2021 node { name: "ColocateWithMe_2" op: "FloatOutput" 2022 device: "/job:worker/replica:2" } 2023 node { name: "FloatOutput_0" op: "FloatOutput" 2024 device: "/job:worker/replica:0" 2025 attr { key: "_class" 2026 value { list { 2027 s: "loc:@ColocateWithMe_0"}}}} 2028 node { name: "FloatOutput_1" op: "FloatOutput" 2029 device: "/job:worker/replica:1" 2030 attr { key: "_class" 2031 value { list { 2032 s: "loc:@ColocateWithMe_1"}}}} 2033 node { name: "FloatOutput_2" op: "FloatOutput" 2034 device: "/job:worker/replica:2" 2035 attr { key: "_class" 2036 value { list { 2037 s: "loc:@ColocateWithMe_2"}}}} 2038 """, gd) 2039 2040 def testControlDependencies(self): 2041 2042 class DependingThread(self.TestThread): 2043 2044 def __init__(self, graph, replica_id, dependency_op): 2045 super(DependingThread, self).__init__(graph, replica_id) 2046 self._dependency_op = dependency_op 2047 2048 def run(self): 2049 with g.control_dependencies([self._dependency_op]): 2050 self.has_mutated_graph.set() 2051 self.should_continue.wait() 2052 self.should_continue.clear() 2053 g.create_op( 2054 "FloatOutput", [], [dtypes.float32], 2055 name="FloatOutput_{}".format(self._replica_id)) 2056 2057 g = ops.Graph() 2058 dependency_ops = [] 2059 for i in range(3): 2060 dependency_ops.append( 2061 g.create_op( 2062 "FloatOutput", [], [dtypes.float32], 2063 name="ColocateWithMe_{}".format(i))) 2064 2065 # If 
`switch_to_thread` isn't called, then `input` values for the ops below 2066 # are not deterministic. 2067 g.switch_to_thread_local() 2068 threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)] 2069 for t in threads: 2070 t.start() 2071 t.has_mutated_graph.wait() 2072 t.has_mutated_graph.clear() 2073 for t in threads: 2074 t.should_continue.set() 2075 t.join() 2076 2077 gd = g.as_graph_def() 2078 self.assertProtoEqualsVersion( 2079 """ 2080 node { name: "ColocateWithMe_0" op: "FloatOutput" 2081 attr { key: "_has_manual_control_dependencies" 2082 value { b: true } } } 2083 node { name: "ColocateWithMe_1" op: "FloatOutput" 2084 attr { key: "_has_manual_control_dependencies" 2085 value { b: true } } } 2086 node { name: "ColocateWithMe_2" op: "FloatOutput" 2087 attr { key: "_has_manual_control_dependencies" 2088 value { b: true } } } 2089 node { name: "FloatOutput_0" op: "FloatOutput" 2090 input: "^ColocateWithMe_0" } 2091 node { name: "FloatOutput_1" op: "FloatOutput" 2092 input: "^ColocateWithMe_1" } 2093 node { name: "FloatOutput_2" op: "FloatOutput" 2094 input: "^ColocateWithMe_2" } 2095 """, gd) 2096 2097 def testNameStack(self): 2098 2099 class NameSettingThread(self.TestThread): 2100 2101 def run(self): 2102 with g.name_scope("foo"): 2103 op1 = g.create_op("FloatOutput", [], [dtypes.float32]) 2104 self.has_mutated_graph.set() 2105 self.should_continue.wait() 2106 self.should_continue.clear() 2107 op2 = g.create_op("FloatOutput", [], [dtypes.float32]) 2108 self.result = (op1, op2) 2109 2110 g = ops.Graph() 2111 threads = [NameSettingThread(g, i) for i in range(3)] 2112 for t in threads: 2113 t.start() 2114 t.has_mutated_graph.wait() 2115 t.has_mutated_graph.clear() 2116 2117 for t in threads: 2118 t.should_continue.set() 2119 t.join() 2120 2121 suffixes = ["", "_1", "_2"] 2122 for t, s in zip(threads, suffixes): 2123 self.assertEqual("foo" + s + "/FloatOutput", t.result[0].name) 2124 self.assertEqual("foo" + s + "/FloatOutput_1", t.result[1].name) 2125 2126 2127class ObjectWithName(object): 2128 2129 def __init__(self, name): 2130 self._name = name 2131 2132 @property 2133 def name(self): 2134 return self._name 2135 2136 2137class CollectionTest(test_util.TensorFlowTestCase): 2138 2139 def test_get_collections(self): 2140 g = ops.Graph() 2141 self.assertSequenceEqual(g.collections, []) 2142 g.add_to_collection("key", 12) 2143 g.add_to_collection("key", 15) 2144 self.assertSequenceEqual(g.collections, ["key"]) 2145 g.add_to_collection("other", "foo") 2146 self.assertSequenceEqual(sorted(g.collections), ["key", "other"]) 2147 self.assertSequenceEqual( 2148 sorted(g.get_all_collection_keys()), ["key", "other"]) 2149 2150 def test_add_to_collection(self): 2151 g = ops.Graph() 2152 g.add_to_collection("key", 12) 2153 g.add_to_collection("other", "foo") 2154 g.add_to_collection("key", 34) 2155 2156 # Note that only blank1 is returned. 
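    # (When a `scope` argument is supplied, `get_collection` keeps only items
    # that have a `name` attribute matching the scope via `re.match`, i.e.
    # anchored at the start of the name; unnamed items such as the int 27
    # below are dropped. That is why only `blank1` comes back from the scoped
    # lookups asserted further down.)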
2157 g.add_to_collection("blah", 27) 2158 blank1 = ObjectWithName("prefix/foo") 2159 g.add_to_collection("blah", blank1) 2160 blank2 = ObjectWithName("junk/foo") 2161 g.add_to_collection("blah", blank2) 2162 2163 self.assertEqual([12, 34], g.get_collection("key")) 2164 self.assertEqual([], g.get_collection("nothing")) 2165 self.assertEqual([27, blank1, blank2], g.get_collection("blah")) 2166 self.assertEqual([blank1], g.get_collection("blah", "prefix")) 2167 self.assertEqual([blank1], g.get_collection("blah", ".*x")) 2168 2169 # Make sure that get_collection() returns a first-level 2170 # copy of the collection, while get_collection_ref() returns 2171 # the original list. 2172 other_collection_snapshot = g.get_collection("other") 2173 other_collection_ref = g.get_collection_ref("other") 2174 self.assertEqual(["foo"], other_collection_snapshot) 2175 self.assertEqual(["foo"], other_collection_ref) 2176 g.add_to_collection("other", "bar") 2177 self.assertEqual(["foo"], other_collection_snapshot) 2178 self.assertEqual(["foo", "bar"], other_collection_ref) 2179 self.assertEqual(["foo", "bar"], g.get_collection("other")) 2180 self.assertTrue(other_collection_ref is g.get_collection_ref("other")) 2181 2182 # Verify that getting an empty collection ref returns a modifiable list. 2183 empty_coll_ref = g.get_collection_ref("empty") 2184 self.assertEqual([], empty_coll_ref) 2185 empty_coll = g.get_collection("empty") 2186 self.assertEqual([], empty_coll) 2187 self.assertFalse(empty_coll is empty_coll_ref) 2188 empty_coll_ref2 = g.get_collection_ref("empty") 2189 self.assertTrue(empty_coll_ref2 is empty_coll_ref) 2190 # Add to the collection. 2191 empty_coll_ref.append("something") 2192 self.assertEqual(["something"], empty_coll_ref) 2193 self.assertEqual(["something"], empty_coll_ref2) 2194 self.assertEqual([], empty_coll) 2195 self.assertEqual(["something"], g.get_collection("empty")) 2196 empty_coll_ref3 = g.get_collection_ref("empty") 2197 self.assertTrue(empty_coll_ref3 is empty_coll_ref) 2198 2199 def test_add_to_collections_uniquify(self): 2200 g = ops.Graph() 2201 g.add_to_collections([1, 2, 1], "key") 2202 # Make sure "key" is not added twice 2203 self.assertEqual(["key"], g.get_collection(1)) 2204 2205 def test_add_to_collections_from_list(self): 2206 g = ops.Graph() 2207 g.add_to_collections(["abc", "123"], "key") 2208 self.assertEqual(["key"], g.get_collection("abc")) 2209 self.assertEqual(["key"], g.get_collection("123")) 2210 2211 def test_add_to_collections_from_tuple(self): 2212 g = ops.Graph() 2213 g.add_to_collections(("abc", "123"), "key") 2214 self.assertEqual(["key"], g.get_collection("abc")) 2215 self.assertEqual(["key"], g.get_collection("123")) 2216 2217 def test_add_to_collections_from_generator(self): 2218 g = ops.Graph() 2219 2220 def generator(): 2221 yield "abc" 2222 yield "123" 2223 2224 g.add_to_collections(generator(), "key") 2225 self.assertEqual(["key"], g.get_collection("abc")) 2226 self.assertEqual(["key"], g.get_collection("123")) 2227 2228 def test_add_to_collections_from_set(self): 2229 g = ops.Graph() 2230 g.add_to_collections(set(["abc", "123"]), "key") 2231 self.assertEqual(["key"], g.get_collection("abc")) 2232 self.assertEqual(["key"], g.get_collection("123")) 2233 2234 def test_add_to_collections_from_string(self): 2235 g = ops.Graph() 2236 g.add_to_collections("abc", "key") 2237 self.assertEqual(["key"], g.get_collection("abc")) 2238 2239 def test_default_graph(self): 2240 with ops.Graph().as_default(): 2241 ops.add_to_collection("key", 90) 2242 
ops.add_to_collection("key", 100) 2243 # Collections are ordered. 2244 self.assertEqual([90, 100], ops.get_collection("key")) 2245 2246 def test_defun(self): 2247 with context.eager_mode(): 2248 2249 @eager_function.defun 2250 def defun(): 2251 ops.add_to_collection("int", 1) 2252 ops.add_to_collection("tensor", constant_op.constant(2)) 2253 2254 @eager_function.defun 2255 def inner_defun(): 2256 self.assertEqual(ops.get_collection("int"), [1]) 2257 three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0] 2258 ops.add_to_collection("int", 2) 2259 self.assertEqual(ops.get_collection("int"), [1, 2]) 2260 ops.add_to_collection("foo", "bar") 2261 self.assertEqual(ops.get_collection("foo"), ["bar"]) 2262 return three 2263 2264 self.assertEqual(ops.get_collection("int"), [1]) 2265 three = inner_defun() 2266 self.assertEqual(ops.get_collection("int"), [1]) 2267 self.assertEqual(ops.get_collection("foo"), []) 2268 return three 2269 2270 three = defun() 2271 self.assertEqual(three.numpy(), 3) 2272 2273 2274ops.NotDifferentiable("FloatOutput") 2275 2276 2277@ops.RegisterGradient("CopyOp") 2278def _CopyGrad(op, x_grad): # pylint: disable=invalid-name 2279 _ = op 2280 return x_grad 2281 2282 2283@ops.RegisterGradient("copy_override") 2284def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name 2285 _ = op 2286 return x_grad 2287 2288 2289class RegistrationTest(test_util.TensorFlowTestCase): 2290 2291 @test_util.run_deprecated_v1 2292 def testRegisterGradients(self): 2293 x = test_ops.float_output() 2294 y = test_ops.copy_op(x) 2295 fn = ops.get_gradient_function(y.op) 2296 self.assertEqual(_CopyGrad, fn) 2297 2298 def testOverrideGradients(self): 2299 g = ops.Graph() 2300 with g.as_default(): 2301 x = test_ops.float_output() 2302 with g.gradient_override_map({"CopyOp": "copy_override"}): 2303 y = test_ops.copy_op(x) 2304 fn = ops.get_gradient_function(y.op) 2305 self.assertEqual(_CopyOverrideGrad, fn) 2306 2307 def testNonExistentOverride(self): 2308 g = ops.Graph() 2309 with g.as_default(): 2310 x = test_ops.float_output() 2311 with g.gradient_override_map({"CopyOp": "unknown_override"}): 2312 y = test_ops.copy_op(x) 2313 with self.assertRaisesRegex(LookupError, "unknown_override"): 2314 ops.get_gradient_function(y.op) 2315 2316 2317class ComparisonTest(test_util.TensorFlowTestCase): 2318 2319 def testMembershipAllowed(self): 2320 g = ops.Graph() 2321 t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") 2322 t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2") 2323 self.assertTrue(isinstance(t1, ops.Tensor)) 2324 self.assertTrue(isinstance(t2, ops.Tensor)) 2325 self.assertTrue(t1 in [t1]) 2326 self.assertTrue(t1 not in [t2]) 2327 2328 2329class ControlDependenciesTest(test_util.TensorFlowTestCase): 2330 2331 @test_util.run_deprecated_v1 2332 def testBasic(self): 2333 g = ops.Graph() 2334 with g.as_default(): 2335 # Creating unregistered ops with _apply_op() doesn't work with the C API 2336 # TODO(skyewm): address this more consistently. Possible solutions are 2337 # to use registered ops in all tests, create a way to register ops in 2338 # Python tests, or conditionally disable the op registration check in 2339 # the C API. 
2340 a = constant_op.constant(1.0) 2341 b = constant_op.constant(1.0) 2342 with g.control_dependencies([a]): 2343 c = constant_op.constant(1.0) 2344 d = array_ops.identity(b) 2345 e = array_ops.identity(c) 2346 2347 self.assertEqual(c.op.control_inputs, [a.op]) 2348 self.assertEqual(d.op.control_inputs, [a.op]) 2349 # e should be dominated by c. 2350 self.assertEqual(e.op.control_inputs, []) 2351 2352 @test_util.run_in_graph_and_eager_modes 2353 def testEager(self): 2354 def future(): 2355 future.calls += 1 2356 return constant_op.constant(2.0) 2357 future.calls = 0 2358 2359 if context.executing_eagerly(): 2360 a = constant_op.constant(1.0) 2361 b = future 2362 with ops.control_dependencies([a, b]): 2363 c = constant_op.constant(3.0) 2364 self.assertEqual(future.calls, 1) 2365 else: 2366 g = ops.Graph() 2367 with g.as_default(): 2368 a = constant_op.constant(1.0) 2369 b = future() 2370 with g.control_dependencies([a, b]): 2371 c = constant_op.constant(3.0) 2372 self.assertEqual(c.op.control_inputs, [a.op, b.op]) 2373 self.assertEqual(future.calls, 1) 2374 2375 def testBasicWithConversion(self): 2376 g = ops.Graph() 2377 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2378 2379 class ConvertibleObj(object): 2380 2381 def _as_graph_element(self): 2382 return a 2383 2384 with g.control_dependencies([ConvertibleObj()]): 2385 c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2386 2387 self.assertEqual(c.op.control_inputs, [a.op]) 2388 2389 def testNested(self): 2390 g = ops.Graph() 2391 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2392 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2393 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2394 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2395 2396 with g.control_dependencies([a_1, a_2, a_3, a_4]): 2397 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2398 2399 with g.control_dependencies([a_1]): 2400 with g.control_dependencies([a_2]): 2401 with g.control_dependencies([a_3]): 2402 with g.control_dependencies([a_4]): 2403 b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2404 2405 self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op], 2406 b_1.op.control_inputs) 2407 self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) 2408 2409 def testClear(self): 2410 g = ops.Graph() 2411 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2412 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2413 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2414 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2415 2416 with g.control_dependencies([a_1]): 2417 with g.control_dependencies([a_2]): 2418 with g.control_dependencies(None): 2419 with g.control_dependencies([a_3]): 2420 with g.control_dependencies([a_4]): 2421 # deps [a_3, a_4] 2422 b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2423 # deps = [a_3] 2424 b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2425 # deps back to None 2426 b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2427 # deps back to [a_1, a_2] 2428 b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2429 # deps back to [a_1] 2430 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2431 with g.control_dependencies(None): 2432 # deps are None again 2433 b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2434 2435 self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) 2436 self.assertItemsEqual([a_3.op], b_3.op.control_inputs) 2437 self.assertItemsEqual([], b_none.op.control_inputs) 2438 
self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) 2439 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 2440 self.assertItemsEqual([], b_none2.op.control_inputs) 2441 2442 def testComplex(self): 2443 g = ops.Graph() 2444 2445 # Usage pattern: 2446 # * Nodes a_i are constants defined at the outermost scope, and are used 2447 # as control inputs for the ith nested scope. 2448 # * Nodes b_i are defined as Mul(a_3, a_4) at each scope. 2449 # * Nodes c_i are defined as Mul(a_1, b_1) at each scope. 2450 # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. 2451 # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. 2452 2453 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2454 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2455 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2456 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2457 2458 with g.control_dependencies([a_1]): 2459 b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 2460 [dtypes.float32]) 2461 c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 2462 [dtypes.float32]) 2463 d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1], 2464 [dtypes.float32]) 2465 e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2466 with g.control_dependencies([a_2]): 2467 b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 2468 [dtypes.float32]) 2469 c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 2470 [dtypes.float32]) 2471 d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2], 2472 [dtypes.float32]) 2473 e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1], 2474 [dtypes.float32]) 2475 with g.control_dependencies([a_3]): 2476 b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 2477 [dtypes.float32]) 2478 c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 2479 [dtypes.float32]) 2480 d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3], 2481 [dtypes.float32]) 2482 e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2], 2483 [dtypes.float32]) 2484 with g.control_dependencies([a_4]): 2485 b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 2486 [dtypes.float32]) 2487 c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 2488 [dtypes.float32]) 2489 d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4], 2490 [dtypes.float32]) 2491 e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3], 2492 [dtypes.float32]) 2493 2494 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 2495 self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) 2496 self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) 2497 self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) 2498 2499 self.assertItemsEqual([], c_1.op.control_inputs) 2500 self.assertItemsEqual([a_2.op], c_2.op.control_inputs) 2501 self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) 2502 self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) 2503 2504 self.assertItemsEqual([], d_1.op.control_inputs) 2505 self.assertItemsEqual([], d_2.op.control_inputs) 2506 self.assertItemsEqual([], d_3.op.control_inputs) 2507 self.assertItemsEqual([], d_4.op.control_inputs) 2508 2509 self.assertItemsEqual([a_1.op], e_1.op.control_inputs) 2510 self.assertItemsEqual([a_2.op], e_2.op.control_inputs) 2511 self.assertItemsEqual([a_3.op], e_3.op.control_inputs) 2512 self.assertItemsEqual([a_4.op], e_4.op.control_inputs) 2513 2514 def testRepeatedDependency(self): 2515 g = ops.Graph() 2516 a = 
g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) 2517 a_0, a_1 = a.outputs 2518 with g.control_dependencies([a_0]): 2519 b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2520 with g.control_dependencies([a_1]): 2521 c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2522 2523 self.assertEqual(b.op.control_inputs, [a]) 2524 self.assertEqual(c.op.control_inputs, [a]) 2525 2526 def testNoControlDependencyWithDataDependency(self): 2527 g = ops.Graph() 2528 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2529 with g.control_dependencies([a]): 2530 b = _apply_op(g, "Identity", [a], [dtypes.float32]) 2531 2532 self.assertEqual(b.op.control_inputs, []) 2533 2534 def testMonitoringAttributeAddedWhenUsingManualControlDep(self): 2535 g = ops.Graph() 2536 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2537 b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2538 with g.control_dependencies([a]): 2539 c = _apply_op(g, "Identity", [b], [dtypes.float32]) 2540 2541 with g.control_dependencies([b]): 2542 d = _apply_op(g, "Identity", [b], [dtypes.float32]) 2543 2544 # Validate that the monitoring attribute is set to track usage of the 2545 # `control_dependencies(...)` API. 2546 self.assertEqual(c.op.control_inputs, [a.op]) 2547 with self.assertRaises(ValueError): 2548 c.op.get_attr("_has_manual_control_dependencies") 2549 self.assertEqual(a.op.get_attr("_has_manual_control_dependencies"), True) 2550 2551 # Validate that the monitoring attribute is set to track usage of the 2552 # `control_dependencies(...)` API even when the manual control deps actually 2553 # happened to be pruned at runtime. 2554 self.assertEqual(d.op.control_inputs, []) 2555 with self.assertRaises(ValueError): 2556 d.op.get_attr("_has_manual_control_dependencies") 2557 self.assertEqual(b.op.get_attr("_has_manual_control_dependencies"), True) 2558 2559 2560class OpScopeTest(test_util.TensorFlowTestCase): 2561 2562 @test_util.run_in_graph_and_eager_modes 2563 def testNames(self): 2564 with ops.name_scope("foo", skip_on_eager=False) as foo: 2565 self.assertEqual("foo/", foo) 2566 with ops.name_scope("foo2", skip_on_eager=False) as foo2: 2567 self.assertEqual("foo/foo2/", foo2) 2568 with ops.name_scope(None, skip_on_eager=False) as empty1: 2569 self.assertEqual("", empty1) 2570 with ops.name_scope("foo3", skip_on_eager=False) as foo3: 2571 self.assertEqual("foo3/", foo3) 2572 with ops.name_scope("", skip_on_eager=False) as empty2: 2573 self.assertEqual("", empty2) 2574 with ops.name_scope("foo/", skip_on_eager=False) as outer_foo: 2575 self.assertEqual("foo/", outer_foo) 2576 with ops.name_scope("", skip_on_eager=False) as empty3: 2577 self.assertEqual("", empty3) 2578 with ops.name_scope("foo4", skip_on_eager=False) as foo4: 2579 self.assertEqual("foo/foo4/", foo4) 2580 with ops.name_scope("foo5//", skip_on_eager=False) as foo5: 2581 self.assertEqual("foo5//", foo5) 2582 with ops.name_scope("foo6", skip_on_eager=False) as foo6: 2583 self.assertEqual("foo5//foo6/", foo6) 2584 with ops.name_scope("/", skip_on_eager=False) as foo7: 2585 self.assertEqual("/", foo7) 2586 with ops.name_scope("//", skip_on_eager=False) as foo8: 2587 self.assertEqual("//", foo8) 2588 with ops.name_scope("a//b/c", skip_on_eager=False) as foo9: 2589 self.assertEqual("foo/a//b/c/", foo9) 2590 with ops.name_scope("a//b/c", skip_on_eager=False) as foo10: 2591 self.assertEqual("a//b/c/", foo10) 2592 2593 @test_util.run_in_graph_and_eager_modes 2594 def testEagerDefaultScopeName(self): 2595 with ops.name_scope(None, "default", 
skip_on_eager=False) as scope: 2596 self.assertEqual(scope, "default/") 2597 with ops.name_scope(None, "default2", skip_on_eager=False) as scope2: 2598 self.assertEqual(scope2, "default/default2/") 2599 2600 @test_util.run_in_graph_and_eager_modes 2601 def testNameScopeV2IsReEntrant(self): 2602 foo = ops.name_scope_v2("foo") 2603 bar = ops.name_scope_v2("bar") 2604 with foo as scope_name: 2605 self.assertEqual("foo/", scope_name) 2606 with foo as scope_name: 2607 self.assertEqual("foo/foo/", scope_name) 2608 with bar as scope_name: 2609 self.assertEqual("foo/bar/", scope_name) 2610 with foo as scope_name: 2611 self.assertEqual("foo/bar/foo/", scope_name) 2612 with bar as scope_name: 2613 self.assertEqual("bar/", scope_name) 2614 2615 @test_util.run_deprecated_v1 2616 def testNoScopeName(self): 2617 g0 = ops.Graph() 2618 values = [ 2619 g0.create_op("A", [], [dtypes.float32]), 2620 g0.create_op("B", [], [dtypes.float32]) 2621 ] 2622 with self.assertRaises(ValueError): 2623 with ops.name_scope(None, values=values): 2624 pass 2625 with self.assertRaises(ValueError): 2626 with ops.name_scope(None, None, values): 2627 pass 2628 2629 @test_util.run_deprecated_v1 2630 def testEmptyScopeName(self): 2631 g0 = ops.Graph() 2632 a = g0.create_op("A", [], [dtypes.float32]) 2633 b = g0.create_op("B", [], [dtypes.float32]) 2634 with ops.name_scope("", values=[a, b]) as scope: 2635 self.assertEqual("", scope) 2636 self.assertEqual(g0, ops.get_default_graph()) 2637 with ops.name_scope("", "my_default_scope", [a, b]) as scope: 2638 self.assertEqual("", scope) 2639 self.assertEqual(g0, ops.get_default_graph()) 2640 2641 @test_util.run_deprecated_v1 2642 def testDefaultScopeName(self): 2643 g0 = ops.Graph() 2644 a = g0.create_op("A", [], [dtypes.float32]) 2645 b = g0.create_op("B", [], [dtypes.float32]) 2646 scope_name = "my_scope" 2647 default_scope_name = "my_default_scope" 2648 with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope: 2649 self.assertEqual("%s/" % scope_name, scope) 2650 self.assertEqual(g0, ops.get_default_graph()) 2651 with ops.name_scope(None, default_scope_name, [a, b]) as scope: 2652 self.assertEqual("%s/" % default_scope_name, scope) 2653 self.assertEqual(g0, ops.get_default_graph()) 2654 with self.assertRaises(TypeError): 2655 with ops.name_scope(scope_name, [a, b]): 2656 pass 2657 2658 def _testGraphElements(self, graph_elements): 2659 scope_name = "my_scope" 2660 with ops.name_scope(scope_name, values=graph_elements) as scope: 2661 self.assertEqual("%s/" % scope_name, scope) 2662 self.assertEqual(graph_elements[0].graph, ops.get_default_graph()) 2663 g1 = ops.Graph() 2664 a = g1.create_op("A", [], [dtypes.float32]) 2665 with self.assertRaises(ValueError): 2666 with ops.name_scope(scope_name, values=graph_elements + [a]): 2667 pass 2668 2669 @test_util.run_in_graph_and_eager_modes 2670 def testGetCurrentNameScope(self): 2671 self.assertEqual(ops.get_current_name_scope(), "") 2672 with ops.name_scope_v2("aaa"): 2673 self.assertEqual(ops.get_current_name_scope(), "aaa") 2674 with ops.name_scope_v2("bbb"): 2675 self.assertEqual(ops.get_current_name_scope(), "aaa/bbb") 2676 self.assertEqual(ops.get_current_name_scope(), "aaa") 2677 self.assertEqual(ops.get_current_name_scope(), "") 2678 2679 @test_util.run_deprecated_v1 2680 def testTensor(self): 2681 g0 = ops.Graph() 2682 a = g0.create_op("A", [], [dtypes.float32]) 2683 b = g0.create_op("B", [], [dtypes.float32]) 2684 self._testGraphElements([a, b]) 2685 2686 @test_util.run_deprecated_v1 2687 def testSparseTensor(self): 
2688 g0 = ops.Graph() 2689 a = g0.create_op("A", [], [dtypes.float32]) 2690 b = g0.create_op("B", [], [dtypes.float32]) 2691 sparse = sparse_tensor.SparseTensor( 2692 _apply_op(g0, "Int64Output", [], [dtypes.int64]), 2693 _apply_op(g0, "FloatOutput", [], [dtypes.float32]), 2694 _apply_op(g0, "Int64Output", [], [dtypes.int64])) 2695 self._testGraphElements([a, sparse, b]) 2696 2697 @test_util.run_deprecated_v1 2698 def testVariable(self): 2699 g0 = ops.Graph() 2700 with g0.as_default(): 2701 variable = variables.Variable([1.0]) 2702 a = g0.create_op("A", [], [dtypes.float32]) 2703 b = g0.create_op("B", [], [dtypes.float32]) 2704 self._testGraphElements([a, variable, b]) 2705 2706 2707class InitScopeTest(test_util.TensorFlowTestCase): 2708 2709 def testClearsControlDependencies(self): 2710 g = ops.Graph() 2711 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2712 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2713 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2714 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2715 2716 with g.as_default(): 2717 with g.control_dependencies([a_1]): 2718 with g.control_dependencies([a_2]): 2719 with ops.init_scope(): 2720 with g.control_dependencies([a_3]): 2721 with g.control_dependencies([a_4]): 2722 # deps [a_3, a_4] 2723 b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2724 # deps = [a_3] 2725 b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2726 # deps back to None 2727 b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2728 # deps back to [a_1, a_2] 2729 b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2730 # deps back to [a_1] 2731 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2732 with ops.init_scope(): 2733 # deps are None again 2734 b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2735 2736 self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) 2737 self.assertItemsEqual([a_3.op], b_3.op.control_inputs) 2738 self.assertItemsEqual([], b_none.op.control_inputs) 2739 self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) 2740 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 2741 self.assertItemsEqual([], b_none2.op.control_inputs) 2742 2743 def testLiftsOpsFromFunctions(self): 2744 g0 = ops.Graph() 2745 g1 = ops.Graph() 2746 g1._building_function = True # pylint: disable=protected-access 2747 g2 = ops.Graph() 2748 g2._building_function = True # pylint: disable=protected-access 2749 2750 with g0.as_default(): 2751 with g1.as_default(): 2752 with g2.as_default(): 2753 with ops.init_scope(): 2754 _ = constant_op.constant(1.0) 2755 2756 self.assertLen(g2.get_operations(), 0) 2757 self.assertLen(g1.get_operations(), 0) 2758 self.assertLen(g0.get_operations(), 1) 2759 2760 def testPreservesDevices(self): 2761 g0 = ops.Graph() 2762 with g0.as_default(), ops.device("CPU:0"): 2763 g1 = ops.Graph() 2764 g1._building_function = True # pylint: disable=protected-access 2765 with g1.as_default(): 2766 with ops.device("GPU:0"): 2767 with ops.init_scope(): 2768 # init_scope should preserve device set under `g1`. 
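            # Concretely: the constant below is lifted out of the function
            # graph `g1` into `g0`, but the device explicitly set inside `g1`
            # ("GPU:0") travels with it and overrides `g0`'s ambient "CPU:0"
            # scope; when no device is set inside `g1` (see `now_on_cpu`
            # further down), the lifted op simply picks up `g0`'s
            # "/device:CPU:0".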
2769 on_gpu = constant_op.constant(1.0) 2770 self.assertEqual(on_gpu.device, "/device:GPU:0") 2771 still_on_gpu = constant_op.constant(1.0) 2772 self.assertEqual(still_on_gpu.device, "/device:GPU:0") 2773 blank = constant_op.constant(1.0) 2774 self.assertEqual(blank.device, "") 2775 with ops.init_scope(): 2776 now_on_cpu = constant_op.constant(1.0) 2777 self.assertEqual(now_on_cpu.device, "/device:CPU:0") 2778 on_cpu = constant_op.constant(1.0) 2779 self.assertEqual(on_cpu.device, "/device:CPU:0") 2780 2781 def testComposes(self): 2782 g0 = ops.Graph() 2783 g1 = ops.Graph() 2784 g1._building_function = True # pylint: disable=protected-access 2785 g2 = ops.Graph() 2786 g2._building_function = True # pylint: disable=protected-access 2787 g3 = ops.Graph() 2788 g3._building_function = False # pylint: disable=protected-access 2789 2790 with g0.as_default(): 2791 with g1.as_default(): 2792 with ops.init_scope(): 2793 # This op should be lifted into g0. 2794 _ = constant_op.constant(1.0) 2795 self.assertIs(g0, ops.get_default_graph()) 2796 self.assertLen(g2.get_operations(), 0) 2797 self.assertLen(g1.get_operations(), 0) 2798 self.assertLen(g0.get_operations(), 1) 2799 with g2.as_default(): 2800 with ops.init_scope(): 2801 # This op should be lifted into g0. 2802 _ = constant_op.constant(1.0) 2803 self.assertIs(g0, ops.get_default_graph()) 2804 with g3.as_default(): 2805 with ops.init_scope(): 2806 # This op should be lifted into g3, because g3 is not building a 2807 # function. 2808 _ = constant_op.constant(1.0) 2809 self.assertIs(g3, ops.get_default_graph()) 2810 2811 self.assertLen(g3.get_operations(), 1) 2812 self.assertLen(g2.get_operations(), 0) 2813 self.assertLen(g1.get_operations(), 0) 2814 self.assertLen(g0.get_operations(), 2) 2815 2816 def testEscapesToEagerContext(self): 2817 g = ops.Graph() 2818 g._building_function = True # pylint: disable=protected-access 2819 with context.eager_mode(): 2820 with context.graph_mode(): 2821 with g.as_default(): 2822 with ops.init_scope(): 2823 # Because g is building a function, init_scope should 2824 # escape out to the eager context. 2825 self.assertTrue(context.executing_eagerly()) 2826 # g should be reinstated as the default graph, and the 2827 # graph context should be re-entered. 2828 self.assertIs(g, ops.get_default_graph()) 2829 self.assertFalse(context.executing_eagerly()) 2830 2831 def testStaysInEagerWhenOnlyEagerContextActive(self): 2832 with context.eager_mode(): 2833 with ops.init_scope(): 2834 self.assertTrue(context.eager_mode()) 2835 self.assertTrue(context.eager_mode()) 2836 2837 def testEscapesDefunWhenInEagerMode(self): 2838 2839 def function_with_variables(): 2840 with ops.init_scope(): 2841 self.v = resource_variable_ops.ResourceVariable(3) 2842 return self.v.assign_add(1) 2843 2844 with context.eager_mode(): 2845 # Each invocation of function_with_variables recreates a variable. 2846 self.assertEqual(4, int(function_with_variables())) 2847 self.assertEqual(4, int(function_with_variables())) 2848 2849 compiled = eager_function.defun(function_with_variables) 2850 # The init_scope in function_with_variables lifts the variable out 2851 # of the graph function constructed by defun; hence, 2852 # compiled now appears to be stateful. 
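      # Because the variable creation was lifted out of the traced function,
      # the variable is created only once (when `compiled` is first traced),
      # so each call keeps incrementing the same variable: 4 on the first
      # call, 5 on the second. The plain eager calls above recreated the
      # variable every time and therefore always returned 4.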
2853 self.assertEqual(4, int(compiled())) 2854 self.assertEqual(5, int(compiled())) 2855 2856 def testEscapesDefunWhenInGraphMode(self): 2857 def function_with_variables(name): 2858 with ops.init_scope(): 2859 _ = variable_scope.get_variable(name, shape=(1,)) 2860 2861 g = ops.Graph() 2862 with g.as_default(): 2863 with self.cached_session(): 2864 # First ensure that graphs that are not building functions are 2865 # not escaped. 2866 function_with_variables("foo") 2867 with self.assertRaisesRegex(ValueError, 2868 r"Variable foo already exists.*"): 2869 # This will fail because reuse is not set to True. 2870 function_with_variables("foo") 2871 2872 compiled = eager_function.defun(function_with_variables) 2873 compiled("bar") 2874 self.assertEqual( 2875 len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) 2876 2877 # The second call to `compiled` should not create variables: the 2878 # init_scope has lifted the variable creation code out of the defun. 2879 compiled("bar") 2880 self.assertEqual( 2881 len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) 2882 2883 def testEscapesNestedDefun(self): 2884 2885 def inner_function(): 2886 with ops.init_scope(): 2887 self.v = resource_variable_ops.ResourceVariable(1) 2888 return self.v.assign_add(2) 2889 2890 def outer_function(inner=None): 2891 with ops.init_scope(): 2892 self.v0 = resource_variable_ops.ResourceVariable(0) 2893 return self.v0.assign_add(1) + inner() 2894 2895 with context.eager_mode(): 2896 # Each invocation of outer_function recreates variables. 2897 self.assertEqual(4, int(outer_function(inner=inner_function))) 2898 self.assertEqual(4, int(outer_function(inner=inner_function))) 2899 2900 compiled_inner = eager_function.defun(inner_function) 2901 compiled_outer = eager_function.defun(outer_function) 2902 # The init_scope lifts variables out of the graph functions 2903 # constructed by defun; hence, compiled_outer should now appear to be 2904 # stateful. 2905 self.assertEqual(4, int(compiled_outer(inner=compiled_inner))) 2906 self.assertEqual(7, int(compiled_outer(inner=compiled_inner))) 2907 2908 @test_util.run_v1_only("b/120545219") 2909 def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self): 2910 with context.graph_mode(): 2911 ops.reset_default_graph() 2912 # This doesn't push anything onto the graph stack, but it does 2913 # set the stack's global graph. 2914 global_graph = ops.get_default_graph() 2915 fn_graph = ops.Graph() 2916 2917 # pylint: disable=protected-access 2918 fn_graph._building_function = True 2919 self.assertLen(ops._default_graph_stack.stack, 0) 2920 with fn_graph.as_default(): 2921 self.assertLen(ops._default_graph_stack.stack, 1) 2922 with ops.init_scope(): 2923 self.assertGreater(len(ops._default_graph_stack.stack), 1) 2924 dummy = constant_op.constant(1.0) 2925 self.assertLen(ops._default_graph_stack.stack, 1) 2926 # Note that the global graph is _not_ on the graph stack. 2927 self.assertLen(ops._default_graph_stack.stack, 0) 2928 # Ensure that `dummy` was added to the global graph. 
2929 self.assertEqual(global_graph, dummy.graph) 2930 # pylint: enable=protected-access 2931 2932 def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self): 2933 with context.graph_mode(): 2934 # pylint: disable=protected-access 2935 self.assertLen(ops._default_graph_stack.stack, 0) 2936 with ops.init_scope(): 2937 self.assertGreater(len(ops._default_graph_stack.stack), 0) 2938 self.assertLen(ops._default_graph_stack.stack, 0) 2939 # pylint: enable=protected-access 2940 2941 def testPreservesNameScopeInGraphConstruction(self): 2942 with ops.Graph().as_default(): 2943 function_graph = ops.Graph() 2944 with function_graph.as_default(): 2945 with ops.name_scope("inner", skip_on_eager=False), ops.init_scope(): 2946 self.assertEqual(ops.get_name_scope(), "inner") 2947 self.assertEqual(ops.get_name_scope(), "") 2948 2949 def testEnteringGraphFromEagerIsSticky(self): 2950 with context.eager_mode(): 2951 g = ops.Graph() 2952 with g.as_default(): 2953 with ops.init_scope(): 2954 self.assertFalse(context.executing_eagerly()) 2955 self.assertEqual(g, ops.get_default_graph()) 2956 2957 def testMixGraphEager(self): 2958 with context.eager_mode(): 2959 c = constant_op.constant(1.0) 2960 with ops.Graph().as_default(): 2961 with self.assertRaisesRegex(RuntimeError, 2962 "Attempting to capture an EagerTensor"): 2963 math_ops.add(c, c) 2964 c2 = constant_op.constant(2.0) 2965 with self.assertRaises(TypeError): 2966 math_ops.add(c2, c2) 2967 2968 def testPreservesNameScopeInEagerExecution(self): 2969 with context.eager_mode(): 2970 def foo(): 2971 with ops.name_scope("inner", skip_on_eager=False), ops.init_scope(): 2972 if context.executing_eagerly(): 2973 # A trailing slash is always appended when eager execution is 2974 # enabled. 2975 self.assertEqual(context.context().scope_name, "inner/") 2976 else: 2977 self.assertEqual(ops.get_name_scope(), "inner") 2978 2979 foo() 2980 self.assertEqual(ops.get_name_scope(), "") 2981 foo_compiled = eager_function.defun(foo) 2982 foo_compiled() 2983 self.assertEqual(ops.get_name_scope(), "") 2984 2985 def testExecutingEagerlyOutsideFunctions(self): 2986 2987 @def_function.function 2988 def f(): 2989 return ops.executing_eagerly_outside_functions() 2990 2991 with context.graph_mode(): 2992 self.assertFalse(ops.executing_eagerly_outside_functions()) 2993 with session.Session(): 2994 # Need self.evaluate for these as the return type of functions is 2995 # tensors. 
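        # `executing_eagerly_outside_functions()` reports on the context that
        # encloses any tf.function, not on the function body itself: here,
        # under graph mode, `f()` returns False, while under eager mode below
        # it returns True even though `f`'s body is traced as a graph.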
2996 self.assertFalse(self.evaluate(f())) 2997 2998 with context.eager_mode(): 2999 self.assertTrue(ops.executing_eagerly_outside_functions()) 3000 self.assertTrue(f()) 3001 3002 with ops.Graph().as_default(): 3003 self.assertFalse(ops.executing_eagerly_outside_functions()) 3004 with session.Session(): 3005 self.assertFalse(self.evaluate(f())) 3006 3007 3008class GraphTest(test_util.TensorFlowTestCase): 3009 3010 def setUp(self): 3011 ops.reset_default_graph() 3012 3013 def _AssertDefault(self, expected): 3014 self.assertIs(expected, ops.get_default_graph()) 3015 3016 def testResetDefaultGraphNesting(self): 3017 g0 = ops.Graph() 3018 with self.assertRaises(AssertionError): 3019 with g0.as_default(): 3020 ops.reset_default_graph() 3021 3022 def testGraphContextManagerCancelsEager(self): 3023 with context.eager_mode(): 3024 with ops.Graph().as_default(): 3025 self.assertFalse(context.executing_eagerly()) 3026 3027 def testGraphContextManager(self): 3028 g0 = ops.Graph() 3029 with g0.as_default() as g1: 3030 self.assertIs(g0, g1) 3031 3032 def testDefaultGraph(self): 3033 orig = ops.get_default_graph() 3034 self.assertFalse(ops.has_default_graph()) 3035 self._AssertDefault(orig) 3036 g0 = ops.Graph() 3037 self.assertFalse(ops.has_default_graph()) 3038 self._AssertDefault(orig) 3039 context_manager_0 = g0.as_default() 3040 self.assertFalse(ops.has_default_graph()) 3041 self._AssertDefault(orig) 3042 with context_manager_0 as g0: 3043 self._AssertDefault(g0) 3044 with ops.Graph().as_default() as g1: 3045 self.assertTrue(ops.has_default_graph()) 3046 self._AssertDefault(g1) 3047 self._AssertDefault(g0) 3048 self._AssertDefault(orig) 3049 self.assertFalse(ops.has_default_graph()) 3050 3051 def testPreventFeeding(self): 3052 g = ops.Graph() 3053 a = constant_op.constant(2.0) 3054 self.assertTrue(g.is_feedable(a)) 3055 g.prevent_feeding(a) 3056 self.assertFalse(g.is_feedable(a)) 3057 3058 @test_util.run_deprecated_v1 3059 def testPreventFetching(self): 3060 g = ops.Graph() 3061 a = constant_op.constant(2.0) 3062 self.assertTrue(g.is_fetchable(a)) 3063 g.prevent_fetching(a.op) 3064 self.assertFalse(g.is_fetchable(a)) 3065 3066 def testAsGraphElementConversions(self): 3067 3068 class ConvertibleObj(object): 3069 3070 def _as_graph_element(self): 3071 return "FloatOutput:0" 3072 3073 class NonConvertibleObj(object): 3074 3075 pass 3076 3077 g = ops.Graph() 3078 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 3079 self.assertEqual(a, g.as_graph_element(ConvertibleObj())) 3080 with self.assertRaises(TypeError): 3081 g.as_graph_element(NonConvertibleObj()) 3082 3083 # Regression test against creating custom __del__ functions in classes 3084 # involved in cyclic references, e.g. Graph and Operation. (Python won't gc 3085 # cycles that require calling a __del__ method, because the __del__ method can 3086 # theoretically increase the object's refcount to "save" it from gc, and any 3087 # already-deleted objects in the cycle would have be to restored.) 
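  # (Background note: since Python 3.4 / PEP 442 the interpreter can collect
  # cycles whose objects define __del__, but the regression below is still
  # meaningful -- it verifies that the Graph/Operation/Session cycle actually
  # gets collected, by holding a weak reference to the graph and asserting it
  # is dead after gc.collect().)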
3088 def testGarbageCollected(self): 3089 # Create a graph we can delete and a weak reference to monitor if it's gc'd 3090 g = ops.Graph() 3091 g_ref = weakref.ref(g) 3092 # Create some ops 3093 with g.as_default(): 3094 a = constant_op.constant(2.0) 3095 b = constant_op.constant(3.0) 3096 c = math_ops.add(a, b) 3097 # Create a session we can delete 3098 with session.Session(graph=g) as sess: 3099 self.evaluate(c) 3100 # Delete all references and trigger gc 3101 del g 3102 del a 3103 del b 3104 del c 3105 del sess 3106 gc.collect() 3107 self.assertIsNone(g_ref()) 3108 3109 def testRunnableAfterInvalidShape(self): 3110 with ops.Graph().as_default(): 3111 with self.assertRaises(ValueError): 3112 math_ops.add([1, 2], [1, 2, 3]) 3113 a = constant_op.constant(1) 3114 with session.Session() as sess: 3115 self.evaluate(a) 3116 3117 def testRunnableAfterInvalidShapeWithKernelLabelMap(self): 3118 g = ops.Graph() 3119 with g.as_default(): 3120 with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): 3121 with self.assertRaises(ValueError): 3122 test_ops.kernel_label_required(1) 3123 a = constant_op.constant(1) 3124 with session.Session() as sess: 3125 self.evaluate(a) 3126 3127 3128class AttrScopeTest(test_util.TensorFlowTestCase): 3129 3130 def _get_test_attrs(self): 3131 x = control_flow_ops.no_op() 3132 try: 3133 a = compat.as_text(x.get_attr("_A")) 3134 except ValueError: 3135 a = None 3136 try: 3137 b = compat.as_text(x.get_attr("_B")) 3138 except ValueError: 3139 b = None 3140 return (a, b) 3141 3142 @test_util.run_deprecated_v1 3143 def testNoLabel(self): 3144 with self.cached_session(): 3145 self.assertAllEqual((None, None), self._get_test_attrs()) 3146 3147 @test_util.run_deprecated_v1 3148 def testLabelMap(self): 3149 with self.cached_session() as sess: 3150 a1 = self._get_test_attrs() 3151 with sess.graph._attr_scope({ 3152 "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo")) 3153 }): 3154 a2 = self._get_test_attrs() 3155 with sess.graph._attr_scope({ 3156 "_A": None, 3157 "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar")) 3158 }): 3159 a3 = self._get_test_attrs() 3160 with sess.graph._attr_scope({ 3161 "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz")) 3162 }): 3163 a4 = self._get_test_attrs() 3164 a5 = self._get_test_attrs() 3165 a6 = self._get_test_attrs() 3166 a7 = self._get_test_attrs() 3167 3168 self.assertAllEqual((None, None), a1) 3169 self.assertAllEqual(("foo", None), a2) 3170 self.assertAllEqual((None, "bar"), a3) 3171 self.assertAllEqual(("baz", "bar"), a4) 3172 self.assertAllEqual((None, "bar"), a5) 3173 self.assertAllEqual(("foo", None), a6) 3174 self.assertAllEqual((None, None), a7) 3175 3176 3177class KernelLabelTest(test_util.TensorFlowTestCase): 3178 3179 @test_util.run_deprecated_v1 3180 def testNoLabel(self): 3181 with self.cached_session(): 3182 self.assertAllEqual(b"My label is: default", 3183 test_ops.kernel_label().eval()) 3184 3185 @test_util.run_deprecated_v1 3186 def testLabelMap(self): 3187 with self.cached_session() as sess: 3188 default_1 = test_ops.kernel_label() 3189 # pylint: disable=protected-access 3190 with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): 3191 overload_1_1 = test_ops.kernel_label() 3192 with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}): 3193 overload_2 = test_ops.kernel_label() 3194 with sess.graph._kernel_label_map({"KernelLabel": ""}): 3195 default_2 = test_ops.kernel_label() 3196 overload_1_2 = test_ops.kernel_label() 3197 # pylint: enable=protected-access 3198 default_3 = 

      self.assertAllEqual(b"My label is: default", self.evaluate(default_1))
      self.assertAllEqual(b"My label is: default", self.evaluate(default_2))
      self.assertAllEqual(b"My label is: default", self.evaluate(default_3))
      self.assertAllEqual(b"My label is: overload_1",
                          self.evaluate(overload_1_1))
      self.assertAllEqual(b"My label is: overload_1",
                          self.evaluate(overload_1_2))
      self.assertAllEqual(b"My label is: overload_2", self.evaluate(overload_2))


class AsGraphDefTest(test_util.TensorFlowTestCase):

  def testGraphDefVersion(self):
    """Test that the graphdef version is plumbed through to kernels."""
    with ops.Graph().as_default() as g:
      version = g.graph_def_versions.producer
      with self.session(graph=g):
        v = test_ops.graph_def_version().eval()
        self.assertEqual(version, v)

  def testAddShapes(self):
    with ops.Graph().as_default() as g:
      t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [],
                                     [dtypes.float32] * 5)
      t1.set_shape(None)
      t2.set_shape([])
      t3.set_shape([None])
      t4.set_shape([43, 37])
      t5.set_shape([43, None])

      b = constant_op.constant(1.0)  # pylint: disable=unused-variable

      gd = g.as_graph_def(add_shapes=True)
      self.assertProtoEqualsVersion("""
      node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
        attr {
          key: "_output_shapes"
          value {
            list {
              shape { unknown_rank: true }
              shape { }
              shape { dim { size: -1 } }
              shape { dim { size: 43 } dim { size: 37 } }
              shape { dim { size: 43 } dim { size: -1 } }
            }
          }
        }
      }
      node { name: "Const" op: "Const"
        attr {
          key: "_output_shapes"
          value {
            list {
              shape { }
            }
          }
        }
        attr {
          key: "dtype"
          value { type: DT_FLOAT }
        }
        attr {
          key: "value"
          value {
            tensor {
              dtype: DT_FLOAT
              tensor_shape { }
              float_val: 1.0 } } } }
      """, gd)


@ops.RegisterStatistics("a", "flops")
def _calc_a_forward_flops(unused_graph, unused_node):
  return ops.OpStats("flops", 20)


class StatisticsTest(test_util.TensorFlowTestCase):

  def testRegisteredNode(self):
    graph = ops.Graph()
    node = ops._NodeDef("a", "an_a")
    flops = ops.get_stats_for_node_def(graph, node, "flops")
    self.assertEqual(20, flops.value)
    missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
    self.assertEqual(None, missing_stat.value)

  def testUnregisteredNode(self):
    graph = ops.Graph()
    node = ops._NodeDef("b", "a_b")
    weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
    self.assertEqual(None, weight_params.value)

  def testAccumulateStatistics(self):
    flops_total = ops.OpStats("flops")
    self.assertEqual(None, flops_total.value)
    second_flops = ops.OpStats("flops", 3)
    flops_total += second_flops
    self.assertEqual(3, flops_total.value)


class DeviceStackTest(test_util.TensorFlowTestCase):

  @test_util.run_deprecated_v1
  def testBasicDeviceAssignmentMetadata(self):

    def device_func(unused_op):
      return "/cpu:*"

    const_zero = constant_op.constant([0.0], name="zero")
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
      with ops.device("/cpu:0"):
        const_two = constant_op.constant([2.0], name="two")
    with ops.device(device_func):
      const_three = constant_op.constant(3.0, name="three")

    self.assertEqual(0, len(const_zero.op._device_assignments))

    one_list = const_one.op._device_assignments
    self.assertEqual(1, len(one_list))
    self.assertEqual("/cpu", one_list[0].obj)
    self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))

    two_list = const_two.op._device_assignments
    self.assertEqual(2, len(two_list))
    devices = [t.obj for t in two_list]
    self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices))

    three_list = const_three.op._device_assignments
    self.assertEqual(1, len(three_list))
    func_description = three_list[0].obj
    expected_regex = r"device_func<.*ops_test.py, [0-9]+"
    self.assertRegex(func_description, expected_regex)

  @test_util.run_deprecated_v1
  def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):

    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
    with ops.get_default_graph().device("/cpu"):
      const_two = constant_op.constant([2.0], name="two")

    one_metadata = const_one.op._device_assignments[0]
    two_metadata = const_two.op._device_assignments[0]

    # Verify both types of device assignment return the right stack info.
    self.assertRegex("ops_test.py", os.path.basename(one_metadata.filename))
    self.assertEqual(one_metadata.filename, two_metadata.filename)
    self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)


class ColocationGroupTest(test_util.TensorFlowTestCase):

  @test_util.run_deprecated_v1
  def testBasic(self):
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      b = constant_op.constant(3.0)
    c = constant_op.constant(4.0)
    self.assertEqual([b"loc:@a"], a.op.colocation_groups())
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    with self.assertRaises(ValueError):
      c.op.get_attr("_class")

  @test_util.run_deprecated_v1
  def testBasicColocationMetadata(self):
    const_two = constant_op.constant([2.0], name="two")
    with ops.colocate_with(const_two.op):
      const_three = constant_op.constant(3.0, name="three")
    locations_dict = const_three.op._colocation_dict
    self.assertIn("two", locations_dict)
    metadata = locations_dict["two"]
    self.assertIsNone(metadata.obj)
    # Check that this test's filename is recorded as the file containing the
    # colocation statement.
    self.assertEqual("ops_test.py", os.path.basename(metadata.filename))

  @test_util.run_deprecated_v1
  def testColocationDeviceInteraction(self):
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        a = constant_op.constant([2.0], name="a")
      with ops.colocate_with(a.op):
        # 'b' is created in the scope of /cpu:0, but it is
        # colocated with 'a', which is on '/device:GPU:0'. colocate_with
        # overrides devices because it is a stronger constraint.
        b = constant_op.constant(3.0)
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual(a.op.device, b.op.device)

  @test_util.run_deprecated_v1
  def testColocationCanonicalization(self):
    with ops.device("/device:GPU:0"):
      _ = constant_op.constant(2.0)
    with ops.device(lambda op: "/device:GPU:0"):
      b = constant_op.constant(3.0)
    with ops.get_default_graph().colocate_with(b):
      with ops.device("/device:GPU:0"):
        c = constant_op.constant(4.0)

    # A's device will be /device:GPU:0
    # B's device will be /device:GPU:0
    # C's device will be /device:GPU:0 because it
    # inherits B's device name, after canonicalizing the names.
    self.assertEqual(b.op.device, c.op.device)

  @test_util.run_deprecated_v1
  def testLocationOverrides(self):
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        a = constant_op.constant([2.0], name="a")
        # Note that this colocation is "redundant", since we are
        # within the scope of "/device:GPU:0". However, we would like to
        # preserve in the GraphDef that these two ops should be
        # colocated in a portable way.
        with ops.colocate_with(a.op):
          b = constant_op.constant(3.0)
        c = constant_op.constant(4.0)
      d = constant_op.constant(5.0)

    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual("/device:GPU:0", a.op.device)
    self.assertEqual(a.op.device, b.op.device)

    # Test that device function stack is restored.
    self.assertEqual("/device:GPU:0", c.op.device)
    self.assertEqual("/device:CPU:0", d.op.device)

  @test_util.run_deprecated_v1
  def testNestedColocateWith(self):
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      b = constant_op.constant(3.0)
      with ops.colocate_with(b.op):
        c = constant_op.constant(4.0)
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual([b"loc:@a"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testMultiColocationGroups(self):
    a = constant_op.constant([2.0], name="a")
    b = constant_op.constant(3.0, name="b")
    with ops.colocate_with(a.op):
      with ops.colocate_with(b.op):
        c = constant_op.constant(4.0)
    self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocationIgnoreStack(self):
    a = constant_op.constant([2.0], name="a")
    b = constant_op.constant(3.0, name="b")
    with ops.colocate_with(a.op):
      with ops.colocate_with(b.op, ignore_existing=True):
        c = constant_op.constant(4.0)
    self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocateWithReset(self):
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      b = constant_op.constant(3.0, name="b")
      with ops.colocate_with(None, ignore_existing=True):
        c = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual([b"loc:@c"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateWithInitialNoneThenNested(self):
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      with ops.colocate_with(None, ignore_existing=True):
        b = constant_op.constant(3.0, name="b")
        with ops.colocate_with(b.op):
          c = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@b"], b.op.colocation_groups())
    self.assertEqual([b"loc:@b"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateVariables(self):
    a = variables.Variable([2.0], name="a")
    with ops.colocate_with(a.op):
      b = variables.Variable([3.0], name="b")
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateResourceVariablesInFunction(self):
    with ops.device("/device:CPU:0"):
      a = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def f():
      with ops.colocate_with(a):
        b = array_ops.ones([], name="output")
        self.assertEqual("/device:CPU:0", b.op.device)
    f()

  def testColocateWithVariableInFunction(self):
    v = variables.Variable(1.)

    @def_function.function
    def f():
      with ops.colocate_with(v):
        return array_ops.ones([], name="output")

    f()
    graph_def = f.get_concrete_function().graph.as_graph_def()
    wrap_function.function_from_graph_def(graph_def, [], ["output"])


class DeadlineTest(test_util.TensorFlowTestCase):

  def testNoDeadlineSet(self):
    with ops.Graph().as_default() as g:
      get_deadline = test_ops.get_deadline()
      with self.session(graph=g) as sess:
        run_options = config_pb2.RunOptions()
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(get_deadline, options=run_options)

  def testDeadlineSetTimesOut(self):
    with ops.Graph().as_default() as g:
      sleep_op = test_ops.sleep_op(10)
      with self.session(graph=g) as sess:
        run_options = config_pb2.RunOptions(timeout_in_ms=3_000)
        with self.assertRaises(errors.DeadlineExceededError):
          sess.run(sleep_op, options=run_options)


class DeprecatedTest(test_util.TensorFlowTestCase):

  def testSuccess(self):
    with ops.Graph().as_default() as g:
      test_util.set_producer_version(g, 7)
      old = test_ops.old()
      with self.session(graph=g):
        old.run()

  def _error(self):
    return ((r"Op Old is not available in GraphDef version %d\. "
             r"It has been removed in version 8\. For reasons\.") %
            versions.GRAPH_DEF_VERSION)

  def testGraphConstructionFail(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(NotImplementedError, self._error()):
        test_ops.old()


class NameScopeTest(test_util.TensorFlowTestCase):

  def testStripAndPrependScope(self):
    strs = [
        "hidden1/hidden1/weights",  # Same prefix. Should strip.
        "hidden1///hidden1/weights",  # Extra "/". Should strip.
        "^hidden1/hidden1/weights",  # Same prefix. Should strip.
        "loc:@hidden1/hidden1/weights",  # Same prefix. Should strip.
        "hhidden1/hidden1/weights",  # Different prefix. Should keep.
        "hidden1",  # Not a prefix. Should keep.
    ]
    expected_stripped = [
        "hidden1/weights", "hidden1/weights", "^hidden1/weights",
        "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1"
    ]
    expected_prepended = [
        "hidden2/hidden1/weights", "hidden2/hidden1/weights",
        "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights",
        "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1"
    ]
    name_scope_to_strip = "hidden1"
    name_scope_to_add = "hidden2"
    for es, ep, s in zip(expected_stripped, expected_prepended, strs):
      stripped = ops.strip_name_scope(s, name_scope_to_strip)
      self.assertEqual(es, stripped)
      self.assertEqual(ep, ops.prepend_name_scope(stripped, name_scope_to_add))

  def testGetNameScope(self):
    with ops.Graph().as_default() as g:
      with ops.name_scope("scope1"):
        with ops.name_scope("scope2"):
          with ops.name_scope("scope3"):
            self.assertEqual("scope1/scope2/scope3", g.get_name_scope())
          self.assertEqual("scope1/scope2", g.get_name_scope())
        self.assertEqual("scope1", g.get_name_scope())
      self.assertEqual("", g.get_name_scope())

  def testTwoGraphs(self):

    def f():
      g1 = ops.Graph()
      g2 = ops.Graph()
      with g1.as_default():
        with g2.as_default():
          with ops.name_scope("_"):
            pass

    self.assertRaisesRegex(ValueError,
                           "'_' is not a valid (?:root )?scope name", f)


class EnableEagerExecutionTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testBadArgumentsToEnableEagerExecution(self):
    with self.assertRaisesRegex(TypeError, "config must be a tf.ConfigProto"):
      ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)
    with self.assertRaisesRegex(ValueError, "device_policy must be one of"):
      c = config_pb2.ConfigProto()
      ops.enable_eager_execution(c, c)
    with self.assertRaisesRegex(ValueError, "execution_mode must be one of"):
      c = config_pb2.ConfigProto()
      ops.enable_eager_execution(c, execution_mode=c)


class _TupleTensor(composite_tensor.CompositeTensor):
  """A `tuple`-like `CompositeTensor` for testing custom `Tensor` conversion."""

  def __init__(self, components):
    super(_TupleTensor, self).__init__()
    self._components = tuple(ops.convert_to_tensor(c) for c in components)

  @property
  def _type_spec(self):
    return _TupleTensorSpec(type_spec.from_value(c) for c in self._components)

  def __getitem__(self, key):
    return self._components[key]

  def __len__(self):
    return len(self._components)

  def __iter__(self):
    return iter(self._components)


class _TupleTensorSpec(type_spec.TypeSpec):

  def __init__(self, specs):
    self._specs = specs

  value_type = property(lambda self: _TupleTensor)
  _component_specs = property(lambda self: self._specs)

  def _to_components(self, value):
    return value._components

  def _from_components(self, components):
    return _TupleTensor(*components)

  def _serialize(self):
    return (self._specs,)


class _MyTuple(object):
  """Pretend user-side class for `CustomConvertToCompositeTensorTest`."""

  def __init__(self, components):
    super(_MyTuple, self).__init__()
    self._components = tuple(components)

  def __getitem__(self, key):
    return self._components[key]

  def __len__(self):
    return len(self._components)

  def __iter__(self):
    return iter(self._components)

ops.register_tensor_conversion_function(
    _MyTuple, conversion_func=lambda x, *_, **__: _TupleTensor(x))


class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):

  @test_util.disable_tfrt("TODO(kkb): This makes Kokoro tests fail.")
  def testCompositeTensorConversion(self):
    """Tests that a user can register a CompositeTensor converter."""
    x = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
    y = ops.convert_to_tensor_or_composite(x)
    self.assertFalse(tensor_util.is_tf_type(y))
    self.assertIsInstance(y, _TupleTensor)
    self.assertLen(y, len(x))
    for x_, y_ in zip(x, y):
      self.assertIsInstance(y_, ops.Tensor)
      self.assertTrue(tensor_util.is_tf_type(y_))
      self.assertAllEqual(x_, tensor_util.constant_value(y_))


@test_util.disable_tfrt("Packing EagerTensors is not supported yet.")
class PackEagerTensorTest(test_util.TensorFlowTestCase):

  def setUp(self):
    super(PackEagerTensorTest, self).setUp()
    context._reset_context()
    cpus = config.list_physical_devices("CPU")
    # Set 2 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])

  def testPack(self):
    with context.eager_mode():
      with ops.device("CPU:0"):
        var0 = resource_variable_ops.ResourceVariable(1.0)
        c0 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      with ops.device("CPU:1"):
        var1 = resource_variable_ops.ResourceVariable(2.0)
        var2 = resource_variable_ops.ResourceVariable([3.0])
        c1 = constant_op.constant([9.0])

      packed_var0 = ops.pack_eager_tensors([var0.handle, var1.handle])
      self.assertTrue(packed_var0.is_packed)
      self.assertEqual(packed_var0.dtype, var0.handle.dtype)
      self.assertEqual(packed_var0.shape, var0.handle.shape)
      self.assertEqual(packed_var0._handle_data, var0.handle._handle_data)
      self.assertIn("COMPOSITE:0", packed_var0.device)
      self.assertIn("COMPOSITE:0", packed_var0.backing_device)
      with self.assertRaises(errors.InvalidArgumentError):
        packed_var0.numpy()

      # Different dtypes
      with self.assertRaises(ValueError):
        ops.pack_eager_tensors([var0.handle, c1])

      # Different shapes
      with self.assertRaises(ValueError):
        ops.pack_eager_tensors([c0, c1])

      # Different handle data
      with self.assertRaises(ValueError):
        ops.pack_eager_tensors([var0.handle, var2.handle])


class GraphDefInputShapesTest(test_util.TensorFlowTestCase):

  def setUpInputShapes(self, pre_add_input_shapes):

    test_tensor_shape = [None, 1, 1, 1]

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=test_tensor_shape, dtype=dtypes.float32)
    ])
    def f(x):
      return array_ops.identity(x, name="output")

    x = array_ops.ones([2, 1, 1, 1], dtype=dtypes.float32)
    f(x)

    tensor_shape_proto = tensor_shape_pb2.TensorShapeProto(dim=[
        tensor_shape_pb2.TensorShapeProto.Dim(size=-1 if d is None else d)
        for d in test_tensor_shape
    ])
    list_proto = attr_value_pb2.AttrValue.ListValue(shape=[tensor_shape_proto])
    concrete_function = f.get_concrete_function()
    if pre_add_input_shapes:
      attr_value = attr_value_pb2.AttrValue(list=list_proto)
      concrete_function = eager_function.ConcreteFunction(
          concrete_function.graph,
          attrs={"_input_shapes": attr_value},
          spec=concrete_function._pre_initialized_function_spec)

    test_graph = ops.Graph()
    with test_graph.as_default():
      concrete_function.add_to_graph(g=test_graph)
    graph_def = test_graph.as_graph_def(add_shapes=True)
    self.assertLen(graph_def.library.function, 1)
    function_def = graph_def.library.function[0]
    input_shapes = function_def.attr["_input_shapes"]
    return input_shapes

  def testGraphDefInputShapes(self):
    pre_added_input_shapes = self.setUpInputShapes(pre_add_input_shapes=True)
    post_added_input_shapes = self.setUpInputShapes(pre_add_input_shapes=False)
    self.assertProtoEquals(pre_added_input_shapes, post_added_input_shapes)


class TensorTest(test_util.TensorFlowTestCase):

  def testToArrayEagerMode(self):

    with context.eager_mode():
      a = np.array(constant_op.constant(32), dtype=np.float32)
      b = np.array(constant_op.constant(32, dtype=dtypes.int64))

      self.assertEqual(a.dtype, np.dtype(np.float32))
      self.assertEqual(b.dtype, np.dtype(np.int64))

  def testToArrayFunctionMode(self):

    @def_function.function
    def f():
      # Raises during trace compilation.
      return np.array(constant_op.constant(32), dtype=np.int32)

    @def_function.function
    def g():
      # Raises during trace compilation.
      return np.array(constant_op.constant(32))

    with self.assertRaisesRegex(NotImplementedError,
                                "Cannot convert a symbolic tf.Tensor"):
      f()

    with self.assertRaisesRegex(NotImplementedError,
                                "Cannot convert a symbolic tf.Tensor"):
      g()


if __name__ == "__main__":
  googletest.main()