xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/tpu_outside_compilation_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for TPU outside compilation."""
16
17import os
18import tempfile
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorboard.plugins.histogram import summary_v2 as histogram_summary_v2
24from tensorboard.plugins.image import summary_v2 as image_summary_v2
25from tensorboard.plugins.scalar import summary_v2 as scalar_summary_v2
26from tensorflow.core.util import event_pb2
27from tensorflow.python.distribute import tpu_strategy as tpu_lib
28from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
29from tensorflow.python.eager import def_function
30from tensorflow.python.eager import remote
31from tensorflow.python.eager import test
32from tensorflow.python.framework import config
33from tensorflow.python.framework import constant_op
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import tensor_spec
37from tensorflow.python.lib.io import tf_record
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import control_flow_ops
40from tensorflow.python.ops import gradients_impl
41from tensorflow.python.ops import image_ops
42from tensorflow.python.ops import logging_ops
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import random_ops
45from tensorflow.python.ops import string_ops
46from tensorflow.python.ops import summary_ops_v2 as summary
47from tensorflow.python.ops import tensor_array_ops
48from tensorflow.python.platform import flags
49from tensorflow.python.platform import gfile
50from tensorflow.python.tpu import functional as tpu_functional
51from tensorflow.python.tpu import tpu
52from tensorflow.python.tpu import tpu_strategy_util
53from tensorflow.python.tpu.ops import tpu_ops
54
# Flags that select the Cloud TPU consumed by get_tpu_cluster_resolver().
FLAGS = flags.FLAGS
flags.DEFINE_string(name="tpu", default="", help="Name of TPU to connect to.")
flags.DEFINE_string(
    name="project", default=None, help="Name of GCP project with TPU.")
flags.DEFINE_string(
    name="zone", default=None, help="Name of GCP zone with TPU.")
59
60
def get_tpu_cluster_resolver():
  """Returns a TPUClusterResolver built from the --tpu/--zone/--project flags."""
  return tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
68
69
def get_tpu_strategy():
  """Connects to the flag-selected TPU and returns a `TPUStrategyV2` for it.

  Side effects: calls `remote.connect_to_cluster` and
  `tpu_strategy_util.initialize_tpu_system` on the resolver before the
  strategy is constructed.
  """
  cluster_resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(cluster_resolver)
  tpu_strategy_util.initialize_tpu_system(cluster_resolver)
  strategy = tpu_lib.TPUStrategyV2(cluster_resolver)
  return strategy
75
76
def computation_with_string_ops(x):
  """Formats `x` after a leading "1" and parses the string back to a number.

  Used as a computation containing string ops (e.g. an input of 0 yields 10).
  """
  return string_ops.string_to_number(string_ops.string_format("1{}", x))
80
81
def _events_from_logdir(test_case, logdir):
  """Reads every summary event from the single event file in `logdir`.

  Args:
    test_case: TestCase used to assert that `logdir` exists and holds exactly
      one file.
    logdir: Directory written by a summary file writer.

  Returns:
    A list of `event_pb2.Event` protos in file order.
  """
  test_case.assertTrue(gfile.Exists(logdir))
  filenames = gfile.ListDirectory(logdir)
  test_case.assertLen(filenames, 1)
  event_file = os.path.join(logdir, filenames[0])

  def _to_event(serialized):
    # Deserialize one TFRecord payload into an Event proto.
    parsed = event_pb2.Event()
    parsed.ParseFromString(serialized)
    return parsed

  return [_to_event(rec) for rec in tf_record.tf_record_iterator(event_file)]
94
95
def _rewrite_func_wrapper(tf_func):
  """Wraps `tf_func` so that each call runs it through `tpu.rewrite`.

  Args:
    tf_func: A `def_function.function`; a concrete function is obtained for
      the flattened call arguments.

  Returns:
    A `def_function.function` that passes the concrete function's `__call__`
    to `tpu.rewrite`.
  """

  def tpu_fn(*args, **kwargs):
    # tpu.rewrite only accepts a list of tensors as input, so keyword argument
    # values are appended (in insertion order) after the positional ones.
    flat_inputs = list(args) + list(kwargs.values())
    concrete = tf_func.get_concrete_function(*flat_inputs)
    return tpu.rewrite(concrete.__call__, flat_inputs)

  return def_function.function(tpu_fn)
106
107
def _tpu_partitioned_call_wrapper(tf_func):
  """Wraps a tensorflow Function so it is invoked through TPUPartitionedCall.

  Args:
    tf_func: A `def_function.function` to trace and call.

  Returns:
    A `def_function.function` that issues a `TPUPartitionedCall` of the
    concrete function on the device chosen by `tpu_ordinal_selector`.
  """

  def inner_func(*args, **kwargs):
    concrete = tf_func.get_concrete_function(*args, **kwargs)
    # TPUPartitionedCall takes a flat list of tensors, ordered as:
    # positional args, then keyword-arg values, then captured inputs.
    op_args = list(args)
    op_args.extend(kwargs.values())
    op_args.extend(concrete.captured_inputs)
    output_types = [
        out_arg.type for out_arg in concrete.function_def.signature.output_arg
    ]
    return tpu_functional.TPUPartitionedCall(
        args=op_args,
        device_ordinal=tpu_ops.tpu_ordinal_selector(),
        Tout=output_types,
        f=concrete)

  return def_function.function(inner_func)
124
125
class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
  """Tests explicit `tpu.outside_compilation` calls inside TPU computations.

  `setUp` disables soft device placement (whereas
  `OutsideCompilationOnUnsupportedOpTest` in this file enables it).
  """

  def setUp(self):
    super().setUp()
    config.set_soft_device_placement(False)

  def testHostNoInput(self):
    """Host function with no inputs and no outputs."""
    strategy = get_tpu_strategy()

    def outside_fn():
      logging_ops.print_v2("Outside compiled")

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        tpu.outside_compilation(outside_fn)
        return x2 + 5.0

      return strategy.run(tpu_fn, args=(25.0,))

    # 25 + 5 + 5 = 35 on every replica.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(35., shape=(strategy.num_replicas_in_sync,)))

  def testHostInputOnly(self):
    """Host function that consumes a TPU value but returns nothing."""
    strategy = get_tpu_strategy()

    def outside_fn(x):
      logging_ops.print_v2("Outside compiled", x)

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        tpu.outside_compilation(outside_fn, x2)
        return x2 + 5.0

      return strategy.run(tpu_fn, args=(25.0,))

    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(35., shape=(strategy.num_replicas_in_sync,)))

  def testHostInputOutput(self):
    """Host function whose result is returned from the TPU computation."""
    strategy = get_tpu_strategy()

    def outside_fn(x):
      logging_ops.print_v2("Outside compiled", x)
      return x + 6.0

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        output = tpu.outside_compilation(outside_fn, x2)
        return output

      return strategy.run(tpu_fn, args=(25.0,))

    # 25 + 5 + 6 = 36.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(36., shape=(strategy.num_replicas_in_sync,)))

  def testHostMultipleInputs(self):
    """Host function with several tensor inputs and several outputs."""
    strategy = get_tpu_strategy()
    val0 = np.arange(6).reshape((2, 3)).astype(np.float32)
    val1 = np.arange(6).reshape((3, 2)).astype(np.float32)

    def outside_fn(arg0, arg1):
      tmp = array_ops.reshape(arg1, array_ops.shape(arg0))
      ret0 = arg0 + tmp
      ret1 = math_ops.matmul(arg0, arg1)
      ret2 = array_ops.concat([arg0, tmp], 0)
      return ret0, ret1, ret2

    @def_function.function
    def train_step():

      def tpu_fn(x, y):
        a = x + 7.0
        b = y * 2.0
        c, d, e = tpu.outside_compilation(outside_fn, a, b)
        return (math_ops.reduce_max(c) + math_ops.reduce_min(d) +
                math_ops.reduce_sum(e))

      return strategy.run(tpu_fn, args=(val0, val1))

    # a = val0 + 7 = [[7, 8, 9], [10, 11, 12]], b = 2 * val1.
    # max(c) = 22, min(d) = min(a @ b) = 104, sum(e) = 57 + 30 = 87.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(213., shape=(strategy.num_replicas_in_sync,)))

  def testMultipleClusters(self):
    """Two separate outside-compiled regions in one TPU computation."""
    strategy = get_tpu_strategy()

    def outside_fn1(x):
      logging_ops.print_v2("Outside compiled", x)
      return x + 6.0

    def outside_fn2(x):
      logging_ops.print_v2("Outside compiled", x)
      return x - 18.0

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        output1 = tpu.outside_compilation(outside_fn1, x2)
        x3 = output1 + 3.0
        output2 = tpu.outside_compilation(outside_fn2, x3)
        return output2

      return strategy.run(tpu_fn, args=(25.0,))

    # ((25 + 5 + 6) + 3) - 18 = 21.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(21., shape=(strategy.num_replicas_in_sync,)))

  @parameterized.parameters((True,), (False,))
  def testOutsideCompilationControlFlowIf(self, take_true_branch):
    """Outside compilation inside the true branch of an `if` on a tensor."""
    strategy = get_tpu_strategy()

    def outside_fn(x):
      logging_ops.print_v2("Outside compiled", x)
      return x + 6.0

    # Only 25.0 satisfies `x < 50.0` below, i.e. takes the true branch that
    # contains the outside-compiled call.
    input_value = 25.0 if take_true_branch else 51.0

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        if x < 50.0:
          return tpu.outside_compilation(outside_fn, x2)
        else:
          return x2

      return strategy.run(tpu_fn, args=(input_value,))

    # True branch: 25 + 5 + 6 = 36.  False branch: 51 + 5 = 56.
    output_value = 36.0 if take_true_branch else 56.0
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(
            output_value, shape=(strategy.num_replicas_in_sync,)))

  def testOutsideCompilationControlFlowWhile(self):
    """Outside compilation inside a `while` loop conditioned on a tensor."""
    strategy = get_tpu_strategy()

    def outside_fn(x):
      logging_ops.print_v2("Outside compiled", x)
      return x + 6.0

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        while x2 < 50.0:
          x2 = tpu.outside_compilation(outside_fn, x2)
        return x2 + 4.0

      return strategy.run(tpu_fn, args=(25.0,))

    # x2: 30 -> 36 -> 42 -> 48 -> 54, then 54 + 4 = 58.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(58., shape=(strategy.num_replicas_in_sync,)))

  def testOutsideCompilationHostControlFlow(self):
    """Tests that control flow on host for outside_compilation works."""
    strategy = get_tpu_strategy()

    def outside_fn(x):
      n = 0
      while n < 4:
        x = x + 6.0
        n = n + 1
      return x

    @def_function.function
    def train_step():

      def tpu_fn(x):
        x2 = x + 5.0
        x2 = tpu.outside_compilation(outside_fn, x2)
        return x2 + 4.0

      return strategy.run(tpu_fn, args=(25.0,))

    # 30 + 4 * 6 = 54, then 54 + 4 = 58.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(58., shape=(strategy.num_replicas_in_sync,)))

  def testSummary(self):
    """Scalar summary written from an outside-compiled host function."""
    strategy = get_tpu_strategy()

    def host_computation(x):
      scalar_summary_v2.scalar("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step():

      def computation(x):
        x = x + 1.0
        # NOTE(review): the first result is overwritten; presumably the
        # repeated host call is intentional — confirm.
        y = tpu.outside_compilation(host_computation, x)
        y = tpu.outside_compilation(host_computation, x)
        return y + 1.0

      return strategy.run(computation, args=(2.0,))

    summary_writer = summary.create_file_writer(
        os.getenv("TEST_TMPDIR", "/tmp"), flush_millis=10000)
    # (2 + 1) * 2 + 1 = 7.
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(7., shape=(strategy.num_replicas_in_sync,)))

  @parameterized.parameters((True,), (False,))
  def testSummaryInCond(self, take_true_branch):
    """Outside-compiled summaries inside a tensor-dependent `if`."""
    strategy = get_tpu_strategy()

    def host_computation(x):
      scalar_summary_v2.scalar("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step(take_true_branch):

      def computation(x):
        x = x + 1.0
        if x < 5.0:
          # NOTE(review): the first result is overwritten; presumably the
          # repeated host call is intentional — confirm.
          y = tpu.outside_compilation(host_computation, x)
          y = tpu.outside_compilation(host_computation, x)
          x = y
        return x + 1.0

      if take_true_branch:
        return strategy.run(computation, args=(2.0,))
      else:
        return strategy.run(computation, args=(10.0,))

    summary_writer = summary.create_file_writer(
        os.getenv("TEST_TMPDIR", "/tmp"), flush_millis=10000)

    # True branch: (2 + 1) * 2 + 1 = 7.  False branch: 10 + 1 + 1 = 12.
    output_value = 7. if take_true_branch else 12.
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step(take_true_branch)),
          constant_op.constant(
              output_value, shape=(strategy.num_replicas_in_sync,)))

  def testSummaryInWhile(self):
    """Outside-compiled summaries inside a Python-bounded `while` loop."""
    strategy = get_tpu_strategy()

    def host_computation(x):
      scalar_summary_v2.scalar("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step():

      def computation(x):
        n = 0
        while n < 3:
          x = x + 1.0
          # NOTE(review): the first result is overwritten; presumably the
          # repeated host call is intentional — confirm.
          y = tpu.outside_compilation(host_computation, x)
          y = tpu.outside_compilation(host_computation, x)
          x = y
          n = n + 1
        return y + 1.0

      return strategy.run(computation, args=(2.0,))

    summary_writer = summary.create_file_writer(
        os.getenv("TEST_TMPDIR", "/tmp"), flush_millis=10000)
    # x: 2 -> 6 -> 14 -> 30, then 30 + 1 = 31.
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(31., shape=(strategy.num_replicas_in_sync,)))

  def testOutsideCompilationAtHeadAndTail(self):
    """Tests that outside_compilation at head/tail of TPU computation works."""
    strategy = get_tpu_strategy()

    def host_computation(x):
      return x * 2.0

    @def_function.function
    def train_step():

      def computation(x):
        w = tpu.outside_compilation(host_computation, x)
        y = w + 1.0
        z = tpu.outside_compilation(host_computation, y)
        return z + 5.0

      return strategy.run(computation, args=(2.0,))

    # ((2 * 2) + 1) * 2 + 5 = 15.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(15., shape=(strategy.num_replicas_in_sync,)))

  def testGradientAcrossOutsideCompilation(self):
    """Tests compiled gradients can contain host computations."""
    strategy = get_tpu_strategy()

    def host_computation(a):
      b = a * a
      c = b * b
      return c

    @def_function.function
    def train_step():
      def computation(x, y):
        a = x + 7.0
        b = tpu.outside_compilation(host_computation, a)
        c = b * y
        d = gradients_impl.gradients(
            [c], [x], colocate_gradients_with_ops=True)[0]
        return d

      return strategy.run(computation, args=(2.0, 3.0))

    # c = (x + 7)**4 * y, so dc/dx = 4 * 9**3 * 3 = 8748 at x=2, y=3.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(8748., shape=(strategy.num_replicas_in_sync,)))

  def testGradientOfGradientAcrossOutsideCompilation(self):
    """Tests compiled gradients of gradients can contain host computations."""
    strategy = get_tpu_strategy()

    def host_computation(a):
      b = a * a
      c = b * b
      return c

    @def_function.function
    def train_step():
      def computation(x, y):
        a = x + 7.0
        b = tpu.outside_compilation(host_computation, a)
        c = b * y
        d = gradients_impl.gradients(
            [c], [x], colocate_gradients_with_ops=True)[0]
        e = gradients_impl.gradients(
            [d], [x], colocate_gradients_with_ops=True)[0]
        return e

      return strategy.run(computation, args=(2.0, 3.0))

    # d2c/dx2 = 12 * (x + 7)**2 * y = 12 * 81 * 3 = 2916 at x=2, y=3.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(2916., shape=(strategy.num_replicas_in_sync,)))

  def testColocateGradientWithOutsideCompiledOp(self):
    """Gradient ops colocated with an outside-compiled op get a derived attr."""
    strategy = get_tpu_strategy()

    @def_function.function
    def train_step():

      @def_function.function
      def tpu_fn(x):
        x1 = tpu.outside_compilation(math_ops.sqrt, x)
        grad = gradients_impl.gradients([x1], [x],
                                        colocate_gradients_with_ops=True)[0]
        sqrt = [
            op for op in ops.get_default_graph().get_operations()
            if op.type == "Sqrt"
        ][0]
        sqrt_grad = [
            op for op in ops.get_default_graph().get_operations()
            if op.type == "SqrtGrad"
        ][0]
        # Trace-time checks; use test assertions rather than `assert` so they
        # are not stripped when Python runs with -O.
        self.assertEqual(sqrt.get_attr(tpu._OUTSIDE_COMPILATION_ATTR), b"0")
        self.assertEqual(
            sqrt_grad.get_attr(tpu._OUTSIDE_COMPILATION_ATTR),
            b"0.gradients/uid")
        return grad

      return strategy.run(tpu_fn, args=(25.0,))

    # d sqrt(x)/dx = 1 / (2 * sqrt(25)) = 0.1.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step()),
        constant_op.constant(.1, shape=(strategy.num_replicas_in_sync,)))
515
516
class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
                                            parameterized.TestCase):
  """Tests outside compilation of ops such as string ops and summaries.

  `setUp` enables soft device placement; presumably the *AutoOutsideCompilation
  tests rely on it to place these ops on the host — confirm.
  """

  def setUp(self):
    super().setUp()
    config.set_soft_device_placement(True)

  def testStringOpWithManualOutsideCompilation(self):
    """String ops wrapped explicitly in `tpu.outside_compilation`."""
    strategy = get_tpu_strategy()

    @def_function.function
    def train_step(x):

      def computation(x):
        return tpu.outside_compilation(computation_with_string_ops, x)

      return strategy.run(computation, args=(x,))

    # "1" followed by "0" parses to 10.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step(0)),
        constant_op.constant(10, shape=(strategy.num_replicas_in_sync,)))

  def testStringOpWithAutoOutsideCompilation(self):
    """String ops used directly in the TPU computation."""
    strategy = get_tpu_strategy()

    @def_function.function
    def train_step(x):

      def computation(x):
        return computation_with_string_ops(x)

      return strategy.run(computation, args=(x,))

    self.assertAllEqual(
        strategy.experimental_local_results(train_step(0)),
        constant_op.constant(10, shape=(strategy.num_replicas_in_sync,)))

  # Regression test for b/180509859.
  def testImageSummary(self):
    """Image summary built from a TensorArray filled in a while_loop."""
    strategy = get_tpu_strategy()

    def run():

      @def_function.function
      def sample_sequence():
        bsz = 3
        max_length = 32 * 32

        def f():

          def body(step, tokens):
            next_token = random_ops.random_uniform([bsz])
            tokens = tokens.write(step, next_token)
            return (step + 1, tokens)

          def cond(step, tokens):
            del tokens
            return math_ops.less(step, max_length)

          tokens_var = tensor_array_ops.TensorArray(
              dtype=dtypes.float32,
              size=max_length,
              dynamic_size=False,
              clear_after_read=False,
              element_shape=(bsz,),
              name="tokens_accumulator",
          )

          step = constant_op.constant(0)
          step, tokens_var = control_flow_ops.while_loop(
              cond, body, [step, tokens_var])

          image_flat = array_ops.transpose(tokens_var.stack(), [1, 0])
          image = array_ops.tile(
              array_ops.reshape(image_flat, [bsz, 32, 32, 1]), [1, 1, 1, 3])
          image_summary_v2.image("image_sample", image,
                                 constant_op.constant(5, dtype=dtypes.int64))

        return strategy.run(f)

      sample_sequence()

    logdir = tempfile.mkdtemp()
    summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
    with summary_writer.as_default(), summary.always_record_summaries():
      run()
    events = _events_from_logdir(self, logdir)
    # events[0] is the file header. NOTE(review): string_val[2] is presumably
    # the first encoded image (entries 0 and 1 hold width/height) — confirm.
    decoded_image = image_ops.decode_png(
        events[1].summary.value[0].tensor.string_val[2]).numpy()
    # Ensure that non-zero values were written to the image summary
    # (3072 = 32 * 32 * 3 values for one image).
    self.assertNotAllEqual(
        array_ops.zeros((3072,), dtype=dtypes.float32),
        list(decoded_image.flat))

  def testSummaryWithAutoOutsideCompilation(self):
    """Scalar summary called directly (no explicit outside_compilation)."""
    strategy = get_tpu_strategy()

    def host_computation(x):
      scalar_summary_v2.scalar("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step():

      def computation(x):
        x = x + 1.0
        y = host_computation(x)
        return y + 1.0

      return strategy.run(computation, args=(2.0,))

    logdir = tempfile.mkdtemp()
    summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(7., shape=(strategy.num_replicas_in_sync,)))
    events = _events_from_logdir(self, logdir)
    # There will be 2 entries: 1 summary file header entry, and 1 entry
    # written by host.
    self.assertLen(events, 2)
    self.assertEqual(events[1].summary.value[0].tag, "x")

  def testNestedFunctionScalarSummary(self):
    """Scalar summary inside a nested `def_function.function`."""
    strategy = get_tpu_strategy()

    def host_computation(x):
      scalar_summary_v2.scalar("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step():

      @def_function.function
      def computation(x):
        x = x + 1.0
        y = host_computation(x)
        return y + 1.0

      return strategy.run(computation, args=(2.0,))

    logdir = tempfile.mkdtemp()
    summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(7., shape=(strategy.num_replicas_in_sync,)))
    events = _events_from_logdir(self, logdir)
    # There will be 2 entries: 1 summary file header entry, and 1 entry
    # written by host.
    self.assertLen(events, 2)
    self.assertEqual(events[1].summary.value[0].tag, "x")

  def testHistogramSummaryWithAutoOutsideCompilation(self):
    """Histogram summary called directly (no explicit outside_compilation)."""
    strategy = get_tpu_strategy()

    def host_computation(x):
      histogram_summary_v2.histogram("x", x, step=0)
      return x * 2.0

    @def_function.function
    def step():

      def computation(x):
        x = x + 1.0
        y = host_computation(x)
        return y + 1.0

      return strategy.run(computation, args=(2.0,))

    logdir = tempfile.mkdtemp()
    summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(7., shape=(strategy.num_replicas_in_sync,)))
    events = _events_from_logdir(self, logdir)
    # There will be 2 entries: 1 summary file header entry, and 1 entry
    # written by host.
    self.assertLen(events, 2)
    self.assertEqual(events[1].summary.value[0].tag, "x")

  @parameterized.parameters((True,), (False,))
  def testSummaryControlFlowIfWithAutoOutsideCompilation(
      self, take_true_branch):
    """Scalar summary inside a tensor-dependent `if`, auto outside-compiled."""
    strategy = get_tpu_strategy()

    @def_function.function
    def step():

      def computation(x):
        x = x + 1.0
        if x < 5:
          scalar_summary_v2.scalar("x", x, step=0)
          x = x * 2.0
        return x + 1.0

      if take_true_branch:
        return strategy.run(computation, args=(2.0,))
      else:
        return strategy.run(computation, args=(10.0,))

    logdir = tempfile.mkdtemp()
    summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
    # True branch: (2 + 1) * 2 + 1 = 7.  False branch: 10 + 1 + 1 = 12.
    output_value = 7. if take_true_branch else 12.
    with summary_writer.as_default(), summary.always_record_summaries():
      self.assertAllEqual(
          strategy.experimental_local_results(step()),
          constant_op.constant(
              output_value, shape=(strategy.num_replicas_in_sync,)))
    if take_true_branch:
      events = _events_from_logdir(self, logdir)
      # There will be 2 entries: 1 summary file header entry, and 1 entry
      # written by host.
      self.assertLen(events, 2)
      self.assertEqual(events[1].summary.value[0].tag, "cond/x")

  def testAutoOutsideCompilationWithFunctionalNodes(self):
    """String ops inside both branches of a functional `cond`."""
    strategy = get_tpu_strategy()

    @def_function.function
    def train_step(a, b):

      def fn(a, b):
        fn1 = lambda: computation_with_string_ops(a * 100)
        fn2 = lambda: computation_with_string_ops(a)
        pred = math_ops.greater_equal(a, b)
        result = array_ops.identity(
            control_flow_ops.cond(pred, fn1, fn2),
            name="uncompilable_control_flow")
        return result

      return strategy.run(fn, args=(a, b))

    # 0.0 >= -1.0 selects fn1; "1" followed by the formatted 0.0 * 100 is
    # expected to parse to 10.
    self.assertAllEqual(
        strategy.experimental_local_results(train_step(0.0, -1.0)),
        constant_op.constant(10, shape=(strategy.num_replicas_in_sync,)))

  def testRandomOpsWithAutoOutsideCompilation(self):
    """Random op in a TPU computation produces a result of the right shape."""
    strategy = get_tpu_strategy()

    @def_function.function
    def train_step():

      def computation():
        return random_ops.random_normal(shape=[1, 2, 3])

      return strategy.run(computation, args=())

    self.assertAllEqual(
        strategy.experimental_local_results(train_step())[0].shape, [1, 2, 3])

  def testOutsideCompilationWithTPUPartitionedCallOp(self):
    """Tests outside_compilation inside a TPUPartitionedCall-wrapped rewrite."""
    # Called only for its side effects: connects to and initializes the TPU.
    get_tpu_strategy()

    def host_computation(x):
      return x + 1

    @def_function.function()
    def train_step(x):
      x2 = x + 5.0
      logging_ops.print_v2(x2)
      x2 = tpu.outside_compilation(host_computation, x2)
      return x2 + 4.0

    tpu_fn = _rewrite_func_wrapper(train_step)
    partitioned_tpu_fn = _tpu_partitioned_call_wrapper(tpu_fn)

    concrete = partitioned_tpu_fn.get_concrete_function(
        x=tensor_spec.TensorSpec(
            shape=(1,), dtype=dtypes.float32, name="input_tensor"))

    self.assertIsInstance(
        concrete(array_ops.ones((1,), dtype=dtypes.float32))[0], ops.Tensor)
795
796
if __name__ == "__main__":
  # Script entry point: hand control to the TensorFlow test runner.
  test.main()
799