# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLA tests for pfor."""
# pylint: disable=g-direct-tensorflow-import

from tensorflow.compiler.tf2xla.python import xla as xla_ops
from tensorflow.python.compiler.xla import jit
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class PForTest(PForTestCase):
  """Tests that pfor-vectorized code compiles and runs correctly under XLA."""

  def __init__(self, method_name="runTest"):
    super(PForTest, self).__init__(method_name)
    # Ensure XLA devices are registered before any test body runs.
    context.context().enable_xla_devices()

  def test_xla_einsum(self):
    """Vectorizes xla einsum over loop-variant and loop-invariant operands."""
    num_loop = 10
    x_series = random_ops.random_uniform([num_loop, 9, 9])
    y_series = random_ops.random_uniform([num_loop, 9, 1])

    def loop_fn(i):
      x = array_ops.gather(x_series, 0)  # invariant.
      y = array_ops.gather(y_series, 0)  # invariant.
      x_i = array_ops.gather(x_series, i)
      y_i = array_ops.gather(y_series, i)
      # Cover all combinations of stacked/unstacked inputs.
      z1 = xla_ops.einsum(x_i, y, "ab,bc->ac")
      z2 = xla_ops.einsum(x, y_i, "ab,bc->ac")
      z3 = xla_ops.einsum(x, y, "ab,bc->ac")
      z4 = xla_ops.einsum(x_i, y_i, "ab,bc->ac")
      z5 = xla_ops.einsum(y_i, x_i, "cd,ce->de")  # Includes transpose.
      outputs = [z1, z2, z3, z4, z5]
      return outputs

    self._test_loop_fn(loop_fn, num_loop)

  def test_xla(self):
    """vectorized_map output under xla.compile matches the expected value."""

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    def vectorized_compute(x):
      return pfor_control_flow_ops.vectorized_map(compute, x)

    result = xla.compile(
        vectorized_compute, inputs=[array_ops.ones((10, 5, 3))])
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_function_jit_compile(self):
    """vectorized_map inside a jit_compile=True tf.function works."""

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    @def_function.function(jit_compile=True)
    def vectorized_compute(x):
      return pfor_control_flow_ops.vectorized_map(compute, x)

    result = vectorized_compute(array_ops.ones((10, 5, 3)))
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_xla_while_loop(self):
    """vectorized_map nested inside a while_loop compiles with XLA."""

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    def vectorized_compute(x, i):
      inp = array_ops.gather(x, i)
      output = pfor_control_flow_ops.vectorized_map(compute, inp)
      # Static shape is needed for the while_loop's shape invariant.
      output.set_shape([5, 1])
      return output

    def while_compute(x):
      return control_flow_ops.while_loop_v2(
          lambda i, _: i < 10,
          lambda i, y: (i + 1, y + vectorized_compute(x, i)),
          (0, array_ops.zeros([5, 1])))[1]

    result = xla.compile(while_compute, inputs=[array_ops.ones((10, 5, 3))])
    expected = array_ops.ones([5, 1]) * 10
    self.run_and_assert_equal(expected, result)

  def test_reduce_mean(self):
    """pfor_config.reduce_mean works inside a jit-compiled pfor."""
    x = random_ops.random_uniform([8, 3])

    @def_function.function(jit_compile=True)
    def f():

      def loop_fn(i, pfor_config):
        x_i = array_ops.gather(x, i)
        return x_i - pfor_config.reduce_mean(x_i)

      return pfor_control_flow_ops.pfor(loop_fn, 8)

    output = f()
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)


def _make_unstacked(cond, body, pfor_config):
  """Wraps `cond`/`body` so the while_loop condition is pfor loop invariant.

  The returned condition reduces the per-iteration "not done" flags across the
  pfor loop with reduce_any, so all pfor iterations agree on whether to keep
  looping; finished iterations keep their previous values via where_v2.

  Args:
    cond: original per-iteration condition function.
    body: original per-iteration body function.
    pfor_config: the PForConfig passed to the pfor loop_fn.

  Returns:
    A `(cond, body)` pair whose loop state carries a leading boolean
    "not done" flag.
  """

  def _cond(*args):
    # Loop while any pfor iteration is still active.
    return math_ops.reduce_any(pfor_config.reduce_concat(args[0]))

  def _body(*args):
    not_done = args[0]
    args = args[1:]
    not_done = math_ops.logical_and(not_done, cond(*args))
    outputs = body(*args)
    # Finished iterations keep their old values.
    return (not_done,) + tuple(
        array_ops.where_v2(not_done, x, y) for x, y in zip(outputs, args))

  return _cond, _body


@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):
  """Tests vectorization of while_loop_v2 under XLA compilation."""

  def setUp(self):
    # Force control flow v2 for the duration of the test; restored in
    # tearDown if it was not already enabled.
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV2Test, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV2Test, self).tearDown()

  def _test_loop_fn(self, loop_fn, iters, force_xla=False):
    """Checks pfor output is unchanged under jit scope and (optionally) XLA."""

    def f():
      return pfor_control_flow_ops.pfor(loop_fn, iters)

    @def_function.function
    def jit_f():
      with jit.experimental_jit_scope():
        return f()

    out = f()
    jit_out = jit_f()
    self.run_and_assert_equal(out, jit_out)
    # TODO(agarwal): The following may complain about uncompilable nodes. Hence
    # these are currently not enabled for all tests.
    if force_xla:
      out_exp_compile_f = def_function.function(jit_compile=True)(f)()
      self.run_and_assert_equal(out, out_exp_compile_f)
      out_xla_compile_f = xla.compile(f, inputs=[])
      self.run_and_assert_equal(out, out_xla_compile_f)

  def test_stateless_while(self):
    """while_loop with a loop-variant trip count vectorizes correctly."""
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      return control_flow_ops.while_loop(
          lambda j, _: j < lengths_i,
          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)),
          [0, 0.])

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_variable(self):
    """while_loop reading a ResourceVariable vectorizes correctly."""
    if not context.executing_eagerly():
      self.skipTest("Flaky with tf.Session")

    v = resource_variable_ops.ResourceVariable(5.)

    def loop_fn(_):
      _, output = control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + v),
          [0, 0.])
      return output

    self._test_loop_fn(loop_fn, 3)

  def test_while_unstacked_condition(self):
    """while_loop with a loop-invariant condition compiles with XLA."""

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3, force_xla=True)

  def test_while_force_unstacked_condition(self):
    # The while_loop in this setup is similar to the one in test_stateless_while
    # whose condition is loop variant. However here we wrap the cond and body of
    # the loop in a way that makes the while_loop condition pfor loop invariant.
    # This allows xla compilation to work since the vectorized code no longer
    # needs to perform dynamic partitioning of the inputs.
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      def _cond(j, _):
        return j < lengths_i

      def _body(j, t):
        return (j + 1, t + array_ops.gather(x_i, j))

      cond, body = _make_unstacked(_cond, _body, pfor_config)
      return control_flow_ops.while_loop(
          cond,
          body,
          [True, 0, 0.])

    self._test_loop_fn(loop_fn, 3, force_xla=True)


if __name__ == "__main__":
  test.main()