# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPU outside compilation."""

import os
import tempfile

from absl.testing import parameterized
import numpy as np

from tensorboard.plugins.histogram import summary_v2 as histogram_summary_v2
from tensorboard.plugins.image import summary_v2 as image_summary_v2
from tensorboard.plugins.scalar import summary_v2 as scalar_summary_v2
from tensorflow.core.util import event_pb2
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import summary_ops_v2 as summary
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.tpu import functional as tpu_functional
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu.ops import tpu_ops

FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")


def get_tpu_cluster_resolver():
  """Returns a TPUClusterResolver built from the --tpu/--zone/--project flags."""
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return resolver


def get_tpu_strategy():
  """Connects to and initializes the flag-selected TPU.

  Returns:
    A `TPUStrategyV2` for the resolved TPU cluster.
  """
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
  return tpu_lib.TPUStrategyV2(resolver)


def computation_with_string_ops(x):
  """Formats `x` as the string "1<x>" and parses it back into a number.

  Used by the unsupported-op tests below as a computation built from string
  ops; e.g. an input of 0 yields 10.
  """
  output = string_ops.string_format("1{}", x)
  return string_ops.string_to_number(output)


def _events_from_logdir(test_case, logdir):
  """Reads summary events from log directory.

  Asserts that `logdir` exists and contains exactly one file, then parses
  every record of that file into an `Event` proto.

  Args:
    test_case: TestCase used for the existence / file-count assertions.
    logdir: Directory a summary file writer wrote into.

  Returns:
    A list of `event_pb2.Event` protos in file order.
  """
  test_case.assertTrue(gfile.Exists(logdir))
  files = gfile.ListDirectory(logdir)
  test_case.assertLen(files, 1)
  path = os.path.join(logdir, files[0])
  result = []
  for record in tf_record.tf_record_iterator(path):
    event = event_pb2.Event()
    event.ParseFromString(record)
    result.append(event)
  return result


def _rewrite_func_wrapper(tf_func):
  """Wraps a tensorflow Function so that its body goes through `tpu.rewrite`.

  Args:
    tf_func: A `def_function.function` to wrap.

  Returns:
    A `def_function.function` that traces a concrete function of `tf_func`
    and passes it to `tpu.rewrite`.
  """

  def tpu_fn(*args, **kwargs):
    # tpu.rewrite only accepts list of tensors as input. We need to flatten
    # keyword arguments to meet this requirement.
    flat_inputs = list(args) + list(kwargs.values())
    concrete = tf_func.get_concrete_function(*flat_inputs)
    return tpu.rewrite(concrete.__call__, flat_inputs)

  return def_function.function(tpu_fn)


def _tpu_partitioned_call_wrapper(tf_func):
  """Wrap a tensorflow Function with TPUPartitionedCall.

  Args:
    tf_func: A `def_function.function` to wrap.

  Returns:
    A `def_function.function` that invokes a concrete function of `tf_func`
    through a `TPUPartitionedCall` op on the device chosen by
    `tpu_ordinal_selector`.
  """

  def inner_func(*args, **kwargs):
    concrete = tf_func.get_concrete_function(*args, **kwargs)
    # TPUPartitionedCall only accepts list of tensors as input args.
    # Flatten keyword arguments and do some basic ordering:
    # Positional args + Flattened keyword args + Captured args.
    op_args = list(args) + list(kwargs.values()) + concrete.captured_inputs
    return tpu_functional.TPUPartitionedCall(
        args=op_args,
        device_ordinal=tpu_ops.tpu_ordinal_selector(),
        Tout=[o.type for o in concrete.function_def.signature.output_arg],
        f=concrete)

  return def_function.function(inner_func)


class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
  """Tests explicit `tpu.outside_compilation` with soft placement disabled."""

  def setUp(self):
    super(TpuOutsideCompilationTest, self).setUp()
    config.set_soft_device_placement(False)

  def testHostNoInput(self):
    strategy = get_tpu_strategy()

    def outside_fn():
      logging_ops.print_v2("Outside compiled")

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        tpu.outside_compilation(outside_fn)
        return x2 + 5.0

      return strategy.run(tpu_fn, args=(25.0,))

    # 25 + 5 + 5 = 35 on every replica.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(35., shape=(strategy.num_replicas_in_sync,)))

  def testHostInputOnly(self):
    strategy = get_tpu_strategy()

    def outside_fn(x):
      logging_ops.print_v2("Outside compiled", x)

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        tpu.outside_compilation(outside_fn, x2)
        return x2 + 5.0

      return strategy.run(tpu_fn, args=(25.0,))

    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(35., shape=(strategy.num_replicas_in_sync,)))

  def testHostInputOutput(self):
    strategy = get_tpu_strategy()

    def outside_fn(x):
      logging_ops.print_v2("Outside compiled", x)
      return x + 6.0

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        output = tpu.outside_compilation(outside_fn, x2)
        return output

      return strategy.run(tpu_fn, args=(25.0,))

    # 25 + 5 on device, + 6 on host = 36.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(36., shape=(strategy.num_replicas_in_sync,)))

  def testHostMultipleInputs(self):
    strategy = get_tpu_strategy()
    val0 = np.arange(6).reshape((2, 3)).astype(np.float32)
    val1 = np.arange(6).reshape((3, 2)).astype(np.float32)

    def outside_fn(arg0, arg1):
      tmp = array_ops.reshape(arg1, array_ops.shape(arg0))
      ret0 = arg0 + tmp
      ret1 = math_ops.matmul(arg0, arg1)
      ret2 = array_ops.concat([arg0, tmp], 0)
      return ret0, ret1, ret2

    @def_function.function
    def train_step():

      def tpu_fn(x, y):
        a = x + 7.0
        b = y * 2.0
        c, d, e = tpu.outside_compilation(outside_fn, a, b)
        return (math_ops.reduce_max(c) + math_ops.reduce_min(d) +
                math_ops.reduce_sum(e))

      return strategy.run(tpu_fn, args=(val0, val1))

    # max(c) = 22, min(d) = 104, sum(e) = 87  ->  213.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(213., shape=(strategy.num_replicas_in_sync,)))

  def testMultipleClusters(self):
    strategy = get_tpu_strategy()

    def outside_fn1(x):
      logging_ops.print_v2("Outside compiled", x)
      return x + 6.0

    def outside_fn2(x):
      logging_ops.print_v2("Outside compiled", x)
      return x - 18.0

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        output1 = tpu.outside_compilation(outside_fn1, x2)
        x3 = output1 + 3.0
        output2 = tpu.outside_compilation(outside_fn2, x3)
        return output2

      return strategy.run(tpu_fn, args=(25.0,))

    # ((25 + 5) + 6 + 3) - 18 = 21.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(21., shape=(strategy.num_replicas_in_sync,)))

  @parameterized.parameters(True, False)
  def testOutsideCompilationControlFlowIf(self, take_true_branch):
    strategy = get_tpu_strategy()

    def outside_fn(x):
      logging_ops.print_v2("Outside compiled", x)
      return x + 6.0

    # NOTE: with take_true_branch=True the input is 51, so `x < 50.0` below is
    # False and the device-only `else` path runs (51 + 5 = 56); otherwise the
    # outside-compiled path runs (25 + 5 + 6 = 36).
    input_value = 51.0 if take_true_branch else 25.0

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        if x < 50.0:
          return tpu.outside_compilation(outside_fn, x2)
        else:
          return x2

      return strategy.run(tpu_fn, args=(input_value,))

    output_value = 36.0
    if take_true_branch:
      output_value = 56.0
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(
            output_value, shape=(strategy.num_replicas_in_sync,)))

  def testOutsideCompilationControlFlowWhile(self):
    strategy = get_tpu_strategy()

    def outside_fn(x):
      logging_ops.print_v2("Outside compiled", x)
      return x + 6.0

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        while x2 < 50.0:
          x2 = tpu.outside_compilation(outside_fn, x2)
        return x2 + 4.0

      return strategy.run(tpu_fn, args=(25.0,))

    # 30 -> 36 -> 42 -> 48 -> 54, then + 4 = 58.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(58., shape=(strategy.num_replicas_in_sync,)))

  def testOutsideCompilationHostControlFlow(self):
    """Tests that control flow on host for outside_compilation works."""
    strategy = get_tpu_strategy()

    def outside_fn(x):
      n = 0
      while n < 4:
        x = x + 6.0
        n = n + 1
      return x

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        x2 = tpu.outside_compilation(outside_fn, x2)
        return x2 + 4.0

      return strategy.run(tpu_fn, args=(25.0,))

    # 30 + 4 * 6 = 54 on host, then + 4 = 58.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(58., shape=(strategy.num_replicas_in_sync,)))

  def testSummary(self):
    strategy = get_tpu_strategy()

    def host_computation(x):
      scalar_summary_v2.scalar("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step():

      def computation(x):
        x = x + 1.0
        # Two separate outside-compiled clusters are created on purpose.
        y = tpu.outside_compilation(host_computation, x)
        y = tpu.outside_compilation(host_computation, x)
        return y + 1.0

      return strategy.run(computation, args=(2.0,))

    # Write into this test's own temp dir rather than the shared tmp root.
    summary_writer = summary.create_file_writer(
        self.get_temp_dir(), flush_millis=10000)
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(7., shape=(strategy.num_replicas_in_sync,)))

  @parameterized.parameters(True, False)
  def testSummaryInCond(self, take_true_branch):
    strategy = get_tpu_strategy()

    def host_computation(x):
      scalar_summary_v2.scalar("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step(take_true_branch):

      def computation(x):
        x = x + 1.0
        if x < 5.0:
          y = tpu.outside_compilation(host_computation, x)
          y = tpu.outside_compilation(host_computation, x)
          x = y
        return x + 1.0

      if take_true_branch:
        return strategy.run(computation, args=(2.0,))
      else:
        return strategy.run(computation, args=(10.0,))

    summary_writer = summary.create_file_writer(
        self.get_temp_dir(), flush_millis=10000)

    # True: (2 + 1) * 2 + 1 = 7.  False: 10 + 1 + 1 = 12.
    output_value = 12.
    if take_true_branch:
      output_value = 7.
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step(take_true_branch)),
          constant_op.constant(
              output_value, shape=(strategy.num_replicas_in_sync,)))

  def testSummaryInWhile(self):
    strategy = get_tpu_strategy()

    def host_computation(x):
      scalar_summary_v2.scalar("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step():

      def computation(x):
        n = 0
        while n < 3:
          x = x + 1.0
          y = tpu.outside_compilation(host_computation, x)
          y = tpu.outside_compilation(host_computation, x)
          x = y
          n = n + 1
        return y + 1.0

      return strategy.run(computation, args=(2.0,))

    summary_writer = summary.create_file_writer(
        self.get_temp_dir(), flush_millis=10000)
    # x: 2 -> 6 -> 14 -> 30, then + 1 = 31.
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(31., shape=(strategy.num_replicas_in_sync,)))

  def testOutsideCompilationAtHeadAndTail(self):
    """Tests that outside_compilation at head/tail of TPU computation works."""
    strategy = get_tpu_strategy()

    def host_computation(x):
      return x * 2.0

    @def_function.function
    def train_step():

      def computation(x):
        w = tpu.outside_compilation(host_computation, x)
        y = w + 1.0
        z = tpu.outside_compilation(host_computation, y)
        return z + 5.0

      return strategy.run(computation, args=(2.0,))

    # ((2 * 2) + 1) * 2 + 5 = 15.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(15., shape=(strategy.num_replicas_in_sync,)))

  def testGradientAcrossOutsideCompilation(self):
    """Tests compiled gradients can contain host computations."""
    strategy = get_tpu_strategy()

    def host_computation(a):
      b = a * a
      c = b * b
      return c

    @def_function.function
    def train_step():

      def computation(x, y):
        a = x + 7.0
        b = tpu.outside_compilation(host_computation, a)
        c = b * y
        d = gradients_impl.gradients(
            [c], [x], colocate_gradients_with_ops=True)[0]
        return d

      return strategy.run(computation, args=(2.0, 3.0))

    # c = y * (x + 7)^4, so dc/dx = 4 * y * a^3 = 12 * 9^3 = 8748.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(8748., shape=(strategy.num_replicas_in_sync,)))

  def testGradientOfGradientAcrossOutsideCompilation(self):
    """Tests compiled gradients of gradients can contain host computations."""
    strategy = get_tpu_strategy()

    def host_computation(a):
      b = a * a
      c = b * b
      return c

    @def_function.function
    def train_step():

      def computation(x, y):
        a = x + 7.0
        b = tpu.outside_compilation(host_computation, a)
        c = b * y
        d = gradients_impl.gradients(
            [c], [x], colocate_gradients_with_ops=True)[0]
        e = gradients_impl.gradients(
            [d], [x], colocate_gradients_with_ops=True)[0]
        return e

      return strategy.run(computation, args=(2.0, 3.0))

    # d^2c/dx^2 = 12 * y * a^2 = 36 * 81 = 2916.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(2916., shape=(strategy.num_replicas_in_sync,)))

  def testColocateGradientWithOutsideCompiledOp(self):
    strategy = get_tpu_strategy()

    @def_function.function
    def train_step():

      @def_function.function
      def tpu_fn(x):
        x1 = tpu.outside_compilation(math_ops.sqrt, x)
        grad = gradients_impl.gradients([x1], [x],
                                        colocate_gradients_with_ops=True)[0]
        sqrt = [
            op for op in ops.get_default_graph().get_operations()
            if op.type == "Sqrt"
        ][0]
        sqrt_grad = [
            op for op in ops.get_default_graph().get_operations()
            if op.type == "SqrtGrad"
        ][0]
        # The gradient op must land in an outside-compilation cluster derived
        # from the forward op's cluster.
        assert sqrt.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) == b"0"
        assert (sqrt_grad.get_attr(
            tpu._OUTSIDE_COMPILATION_ATTR) == b"0.gradients/uid")
        return grad

      return strategy.run(tpu_fn, args=(25.0,))

    # d/dx sqrt(x) at 25 = 1 / (2 * 5) = 0.1.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(.1, shape=(strategy.num_replicas_in_sync,)))


class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
                                            parameterized.TestCase):
  """Tests manual and automatic outside compilation of unsupported ops.

  Soft device placement is enabled for these tests (see `setUp`).
  """

  def setUp(self):
    super(OutsideCompilationOnUnsupportedOpTest, self).setUp()
    config.set_soft_device_placement(True)

  def testStringOpWithManualOutsideCompilation(self):
    strategy = get_tpu_strategy()

    @def_function.function
    def train_step(x):

      def computation(x):
        return tpu.outside_compilation(computation_with_string_ops, x)

      return strategy.run(computation, args=(x,))

    self.assertAllEqual(
        strategy.experimental_local_results(train_step(0)),
        constant_op.constant(10, shape=(strategy.num_replicas_in_sync,)))

  def testStringOpWithAutoOutsideCompilation(self):
    strategy = get_tpu_strategy()

    @def_function.function
    def train_step(x):

      def computation(x):
        return computation_with_string_ops(x)

      return strategy.run(computation, args=(x,))

    self.assertAllEqual(
        strategy.experimental_local_results(train_step(0)),
        constant_op.constant(10, shape=(strategy.num_replicas_in_sync,)))

  # Regression test for b/180509859.
  def testImageSummary(self):
    strategy = get_tpu_strategy()

    def run():

      @def_function.function
      def sample_sequence():
        bsz = 3
        max_length = 32 * 32

        def f():

          def body(step, tokens):
            next_token = random_ops.random_uniform([bsz])
            tokens = tokens.write(step, next_token)
            return (step + 1, tokens)

          def cond(step, tokens):
            del tokens
            return math_ops.less(step, max_length)

          tokens_var = tensor_array_ops.TensorArray(
              dtype=dtypes.float32,
              size=max_length,
              dynamic_size=False,
              clear_after_read=False,
              element_shape=(bsz,),
              name="tokens_accumulator",
          )

          step = constant_op.constant(0)
          step, tokens_var = control_flow_ops.while_loop(
              cond, body, [step, tokens_var])

          # [max_length, bsz] -> [bsz, 32, 32, 1] -> tiled to 3 channels.
          image_flat = array_ops.transpose(tokens_var.stack(), [1, 0])
          image = array_ops.tile(
              array_ops.reshape(image_flat, [bsz, 32, 32, 1]), [1, 1, 1, 3])
          image_summary_v2.image("image_sample", image,
                                 constant_op.constant(5, dtype=dtypes.int64))

        return strategy.run(f)

      sample_sequence()

    # A fresh directory under the test's temp dir, so that exactly one event
    # file is present when it is read back.
    logdir = tempfile.mkdtemp(dir=self.get_temp_dir())
    summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
    with summary_writer.as_default(), summary.always_record_summaries():
      run()
    events = _events_from_logdir(self, logdir)
    # events[0] is the file header. In the image summary tensor, entries 0 and
    # 1 are presumably width/height, so index 2 is the first encoded PNG.
    decoded_image = image_ops.decode_png(
        events[1].summary.value[0].tensor.string_val[2]).numpy()
    # Ensure that non-zero values were written to the image summary.
    # 3072 = 32 * 32 pixels * 3 channels.
    self.assertNotAllEqual(
        array_ops.zeros((3072,), dtype=dtypes.float32),
        list(decoded_image.flat))

  def testSummaryWithAutoOutsideCompilation(self):
    strategy = get_tpu_strategy()

    def host_computation(x):
      scalar_summary_v2.scalar("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step():

      def computation(x):
        x = x + 1.0
        y = host_computation(x)
        return y + 1.0

      return strategy.run(computation, args=(2.0,))

    logdir = tempfile.mkdtemp(dir=self.get_temp_dir())
    summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(7., shape=(strategy.num_replicas_in_sync,)))
    events = _events_from_logdir(self, logdir)
    # There will be 2 entries: 1 summary file header entry, and 1 entry
    # written by host.
    self.assertLen(events, 2)
    self.assertEqual(events[1].summary.value[0].tag, "x")

  def testNestedFunctionScalarSummary(self):
    strategy = get_tpu_strategy()

    def host_computation(x):
      scalar_summary_v2.scalar("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step():

      @def_function.function
      def computation(x):
        x = x + 1.0
        y = host_computation(x)
        return y + 1.0

      return strategy.run(computation, args=(2.0,))

    logdir = tempfile.mkdtemp(dir=self.get_temp_dir())
    summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(7., shape=(strategy.num_replicas_in_sync,)))
    events = _events_from_logdir(self, logdir)
    # There will be 2 entries: 1 summary file header entry, and 1 entry
    # written by host.
    self.assertLen(events, 2)
    self.assertEqual(events[1].summary.value[0].tag, "x")

  def testHistogramSummaryWithAutoOutsideCompilation(self):
    strategy = get_tpu_strategy()

    def host_computation(x):
      histogram_summary_v2.histogram("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step():

      def computation(x):
        x = x + 1.0
        y = host_computation(x)
        return y + 1.0

      return strategy.run(computation, args=(2.0,))

    logdir = tempfile.mkdtemp(dir=self.get_temp_dir())
    summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(7., shape=(strategy.num_replicas_in_sync,)))
    events = _events_from_logdir(self, logdir)
    # There will be 2 entries: 1 summary file header entry, and 1 entry
    # written by host.
    self.assertLen(events, 2)
    self.assertEqual(events[1].summary.value[0].tag, "x")

  @parameterized.parameters(True, False)
  def testSummaryControlFlowIfWithAutoOutsideCompilation(
      self, take_true_branch):
    strategy = get_tpu_strategy()

    @def_function.function
    def step():

      def computation(x):
        x = x + 1.0
        if x < 5:
          scalar_summary_v2.scalar("x", x, step=0)
          x = x * 2.0
        return x + 1.0

      if take_true_branch:
        return strategy.run(computation, args=(2.0,))
      else:
        return strategy.run(computation, args=(10.0,))

    logdir = tempfile.mkdtemp(dir=self.get_temp_dir())
    summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
    # True: (2 + 1) * 2 + 1 = 7.  False: 10 + 1 + 1 = 12 (no summary written).
    output_value = 12.
    if take_true_branch:
      output_value = 7.
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(
              output_value, shape=(strategy.num_replicas_in_sync,)))
    if take_true_branch:
      events = _events_from_logdir(self, logdir)
      # There will be 2 entries: 1 summary file header entry, and 1 entry
      # written by host. The tag is expected to carry a "cond/" prefix,
      # presumably from the scope of the converted `if`.
      self.assertLen(events, 2)
      self.assertEqual(events[1].summary.value[0].tag, "cond/x")

  def testAutoOutsideCompilationWithFunctionalNodes(self):
    strategy = get_tpu_strategy()

    @def_function.function
    def train_step(a, b):

      def fn(a, b):
        fn1 = lambda: computation_with_string_ops(a * 100)
        fn2 = lambda: computation_with_string_ops(a)
        pred = math_ops.greater_equal(a, b)
        result = array_ops.identity(
            control_flow_ops.cond(pred, fn1, fn2),
            name="uncompilable_control_flow")
        return result

      return strategy.run(fn, args=(a, b))

    # 0.0 >= -1.0, so fn1 runs on 0 * 100 = 0, giving "10" -> 10.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step(0.0, -1.0)),
        constant_op.constant(10, shape=(strategy.num_replicas_in_sync,)))

  def testRandomOpsWithAutoOutsideCompilation(self):
    strategy = get_tpu_strategy()

    @def_function.function
    def train_step():

      def computation():
        return random_ops.random_normal(shape=[1, 2, 3])

      return strategy.run(computation, args=())

    self.assertAllEqual(
        strategy.experimental_local_results(train_step())[0].shape, [1, 2, 3])

  def testOutsideCompilationWithTPUPartitionedCallOp(self):
    """Tests that control flow with TPUPartitionedCall including outside_compilation works."""
    # Called only to connect to and initialize the TPU system; the returned
    # strategy is not needed by this test.
    get_tpu_strategy()

    def host_computation(x):
      return x + 1

    @def_function.function()
    def train_step(x):
      x2 = x + 5.0
      logging_ops.print_v2(x2)
      x2 = tpu.outside_compilation(host_computation, x2)
      return x2 + 4.0

    tpu_fn = _rewrite_func_wrapper(train_step)
    partitioned_tpu_fn = _tpu_partitioned_call_wrapper(tpu_fn)

    concrete = partitioned_tpu_fn.get_concrete_function(
        x=tensor_spec.TensorSpec(
            shape=(1,), dtype=dtypes.float32, name="input_tensor"))

    self.assertIsInstance(
        concrete(array_ops.ones((1,), dtype=dtypes.float32))[0], ops.Tensor)


if __name__ == "__main__":
  test.main()