# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Backend-dependent tests for the Python XLA client."""

import collections
import functools
import itertools
import re
import threading
import unittest

from absl import flags
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.xla.python import xla_client

# pylint: disable=g-import-not-at-top
try:
  # This import is only used for GPU; the dependency is incompatible with TPU
  # so it results in an import error.
  from tensorflow.python.framework import test_util
except ImportError:
  test_util = None

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.compiler.xla.python import custom_call_for_test
except ImportError:
  custom_call_for_test = None

bfloat16 = xla_client.bfloat16
ops = xla_client.ops

FLAGS = flags.FLAGS

# We choose to ignore pylint's complaints about complex comprehensions, which we
# use widely for parameterizing tests.
# pylint: disable=g-complex-comprehension


def TestFactory(xla_backend,
                cloud_tpu=False,
                tfrt_tpu=False,
                external_tpu=False):
  tests = []

  if not cloud_tpu:
    int_dtypes = [np.int32, np.int64, np.uint32, np.uint64]
    # TODO(phawkins): test np.float16, where supported.
    float_dtypes = [bfloat16, np.float32, np.float64]
    complex_dtypes = [np.complex64, np.complex128]
    standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
  else:
    int_dtypes = [np.int32, np.uint32]
    float_dtypes = [np.float32]
    complex_dtypes = [np.complex64]
    standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
  dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes

  class ComputationTest(parameterized.TestCase):
    """Base class for running an XLA Computation through the local client."""

    def setUp(self):
      super(ComputationTest, self).setUp()
      self.backend = xla_backend()

    def _NewComputation(self, name=None):
      if name is None:
        name = self.id()
      return xla_client.XlaBuilder(name)

    def _Execute(self, c, arguments):
      compiled_c = self.backend.compile(c.build())
      return xla_client.execute_with_python_values(
          compiled_c, arguments, backend=self.backend)

    def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
      assert expected is not None
      results = self._Execute(c, arguments)
      self.assertLen(results, len(expected))
      for result, e in zip(results, expected):
        # Numpy's comparison methods are a bit too lenient by treating inputs as
        # "array-like", meaning that scalar 4 will be happily compared equal to
        # [[4]]. We'd like to be more strict so assert shapes as well.
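        # (For instance, np.testing.assert_allclose(4, [[4]]) succeeds because
        # the two inputs broadcast against each other.)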
        self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape)
        assert_func(result, e)

    def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
      self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments,
                                 expected)

    def _ExecuteAndCompareClose(self,
                                c,
                                arguments=(),
                                expected=None,
                                rtol=1e-4,
                                atol=0):
      self._ExecuteAndAssertWith(
          functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
          c, arguments, expected)

  def NumpyArrayF32(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.float32 dtype."""
    return np.array(*args, dtype=np.float32, **kwargs)

  def NumpyArrayF64(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.float64 dtype."""
    return np.array(*args, dtype=np.float64, **kwargs)

  def NumpyArrayS32(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.int32 dtype."""
    return np.array(*args, dtype=np.int32, **kwargs)

  def NumpyArrayBool(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.bool_ dtype."""
    return np.array(*args, dtype=np.bool_, **kwargs)

  class ComputationPrinting(absltest.TestCase):

    def setUp(self):
      super(ComputationPrinting, self).setUp()
      self.backend = xla_backend()

    def ExampleComputation(self):
      builder = xla_client.XlaBuilder("acomputation")
      p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0)))
      p1 = ops.Parameter(
          builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
      x = ops.Mul(p0, p1)
      ops.Add(x, x)
      return builder.build()

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCompiledHloModuleToHloText(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      hlo_modules = executable.hlo_modules()
      self.assertLen(hlo_modules, 1)
      hlo_text = hlo_modules[0].to_string()
      self.assertTrue(hlo_text.startswith("HloModule acomputation"))
      self.assertIn("fusion", hlo_text)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCompiledHloModuleAsSerializedProto(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      hlo_modules = executable.hlo_modules()
      self.assertLen(hlo_modules, 1)
      hlo_text = hlo_modules[0].to_string()
      proto = hlo_modules[0].as_serialized_hlo_module_proto()
      hlo_module_roundtrip = xla_client.XlaComputation(proto).get_hlo_module()
      hlo_text_roundtrip = hlo_module_roundtrip.to_string()
      self.assertEqual(hlo_text, hlo_text_roundtrip)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testStableComputationSerialization(self):
      # Ideally we would test identical computations produced in different
      # processes. For now we have this limited smoke test.
      computation = self.ExampleComputation()
      ref = computation.as_serialized_hlo_module_proto()
      for _ in range(10):
        self.assertEqual(computation.as_serialized_hlo_module_proto(), ref)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testFlopEstimate(self):
      computation = self.ExampleComputation()
      properties = xla_client._xla.hlo_module_cost_analysis(
          self.backend, computation.as_hlo_module())
      self.assertEqual(properties["flops"], 8.0)

    def testFingerprint(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      fingerprint = executable.fingerprint
      if self.backend.platform == "tpu" and not cloud_tpu:
        logging.info("fingerprint: %s", fingerprint)
        self.assertNotEmpty(fingerprint)
      else:
        self.assertIsNone(fingerprint)

  tests.append(ComputationPrinting)

  class ComputationsWithConstantsTest(ComputationTest):
    """Tests focusing on Constant ops."""

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes + float_dtypes)
    def testConstantScalarSum(self, dtype):
      if dtype == np.int8 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support int8")
      c = self._NewComputation()
      ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14)))
      self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorMul(self, dtype):
      c = self._NewComputation()
      ops.Mul(
          ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)),
          ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype)))
      self._ExecuteAndCompareClose(
          c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorScalarDiv(self, dtype):
      c = self._NewComputation()
      ops.Div(
          ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)),
          ops.Constant(c, dtype(2.0)))
      self._ExecuteAndCompareClose(
          c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorScalarPow(self, dtype):
      c = self._NewComputation()
      ops.Pow(
          ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)),
          ops.Constant(c, dtype(2.)))
      self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]])

    def testIota(self):
      c = self._NewComputation()
      ops.Iota(c, xla_client.PrimitiveType.F32, 10)
      self._ExecuteAndCompareExact(
          c, expected=[np.arange(10, dtype=np.float32)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes)
    def testBroadcastedIota(self, dtype):
      c = self._NewComputation()
      shape = xla_client.Shape.array_shape(
          xla_client.dtype_to_etype(dtype), (2, 3))
      ops.Iota(c, shape, 1)
      expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype)
      self._ExecuteAndCompareExact(c, expected=[expected])

    def testBooleanAnd(self):
      c = self._NewComputation()
      ops.And(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]])

    def testBooleanOr(self):
      c = self._NewComputation()
      ops.Or(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]])

    def testBooleanXor(self):
      c = self._NewComputation()
      ops.Xor(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2D(self, dtype):
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)),
          ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype)))
      self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]])

    def testShiftLeft(self):
      c = self._NewComputation()
      ops.ShiftLeft(
          ops.Constant(c, NumpyArrayS32([3])),
          ops.Constant(c, NumpyArrayS32([2])))
      self._ExecuteAndCompareClose(c, expected=[[12]])

    def testShiftRightArithmetic(self):
      c = self._NewComputation()
      ops.ShiftRightArithmetic(
          ops.Constant(c, NumpyArrayS32([-2])),
          ops.Constant(c, NumpyArrayS32([1])))
      self._ExecuteAndCompareClose(c, expected=[[-1]])

    def testShiftRightLogical(self):
      c = self._NewComputation()
      ops.ShiftRightLogical(
          ops.Constant(c, NumpyArrayS32([-1])),
          ops.Constant(c, NumpyArrayS32([1])))
      self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2DWith1DBroadcastDim0(self, dtype):
      # sum of a 2D array with a 1D array where the latter is replicated across
      # dimension 0 to match the former's shape.
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c,
                       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                                dtype=dtype)),
          ops.Constant(c, np.array([10, 20, 30], dtype=dtype)),
          broadcast_dimensions=(0,))
      self._ExecuteAndCompareClose(
          c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2DWith1DBroadcastDim1(self, dtype):
      # sum of a 2D array with a 1D array where the latter is replicated across
      # dimension 1 to match the former's shape.
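      # With broadcast_dimensions=(1,), the vector [10, 20, 30] is added to
      # every row of the matrix, as the expected result below shows.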
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c,
                       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                                dtype=dtype)),
          ops.Constant(c, np.array([10, 20, 30], dtype=dtype)),
          broadcast_dimensions=(1,))
      self._ExecuteAndCompareClose(
          c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantAxpy(self, dtype):
      c = self._NewComputation()
      ops.Add(
          ops.Mul(
              ops.Constant(c, dtype(2)),
              ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))),
          ops.Constant(c, np.array([100, -100, 200, -200], dtype)))
      self._ExecuteAndCompareClose(
          c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3)

    def testCustomCall(self):
      if self.backend.platform != "cpu":
        self.skipTest("Test requires cpu platform")
      c = self._NewComputation()
      for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
        xla_client.register_custom_call_target(name, fn, platform="cpu")
      ops.CustomCallWithLayout(
          c,
          b"test_subtract_f32",
          operands=[
              ops.Constant(c, np.float32(1.25)),
              ops.Constant(c, np.float32(0.5))
          ],
          shape_with_layout=xla_client.Shape.array_shape(
              np.dtype(np.float32), (), ()),
          operand_shapes_with_layout=[
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
          ],
          api_version=xla_client.ops.CustomCallApiVersion
          .API_VERSION_STATUS_RETURNING)
      self._ExecuteAndCompareClose(c, expected=[0.75])

    def testCustomCallWithUnifiedApi(self):
      if self.backend.platform != "cpu":
        self.skipTest("Test requires cpu platform")
      c = self._NewComputation()
      for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
        xla_client.register_custom_call_target(name, fn, platform="cpu")

      opaque_str = b"foo"
      ops.CustomCallWithLayout(
          c,
          b"test_add_input_and_opaque_len",
          operands=[
              ops.Constant(c, np.float32(1.25)),
              ops.Constant(c, np.float32(0.5))
          ],
          shape_with_layout=xla_client.Shape.array_shape(
              np.dtype(np.float32), (), ()),
          operand_shapes_with_layout=[
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
          ],
          # With opaque length = 3.0
          opaque=opaque_str,
          api_version=xla_client.ops.CustomCallApiVersion
          .API_VERSION_STATUS_RETURNING_UNIFIED)
      self._ExecuteAndCompareClose(c, expected=[1.25 + len(opaque_str)])

  tests.append(ComputationsWithConstantsTest)

  class PythonCallbackTest(ComputationTest):

    def testPythonCallback(self):
      if self.backend.platform not in {"cpu", "gpu"}:
        self.skipTest("Test requires cpu or gpu platform")
      c = self._NewComputation()

      f = lambda x, y: (x + y, x - y)

      arg0 = np.array([9, 43, -101, 22], dtype=np.int32)
      arg1 = np.array([10, 15, -2, 7], dtype=np.int32)
      shape = xla_client.shape_from_pyval(arg0)
      shape = shape.with_major_to_minor_layout_if_absent()
      p0 = ops.Parameter(c, 0, shape)
      p1 = ops.Parameter(c, 1, shape)
      out, keepalive = self.backend.emit_python_callback(
          f, c, [p0, p1], [shape, shape])
      self._ExecuteAndCompareExact(
          c, arguments=[arg0, arg1], expected=[arg0 + arg1, arg0 - arg1])
      del out, keepalive

    def testPythonCallbackCanHandleExceptions(self):
      if self.backend.platform not in {"cpu", "gpu"}:
        self.skipTest("Test requires cpu or gpu platform")
      c = self._NewComputation()

      def _Callback(x):
        raise ValueError("Value error raised!")

      arg0 = np.array([9, 43, -101, 22], dtype=np.int32)
      shape = xla_client.shape_from_pyval(arg0)
      shape = shape.with_major_to_minor_layout_if_absent()
      p0 = ops.Parameter(c, 0, shape)
      out, keepalive = self.backend.emit_python_callback(
          _Callback, c, [p0], [shape], has_side_effects=True)
      with self.assertRaisesRegex(xla_client.XlaRuntimeError,
                                  "Value error raised!"):
        self._Execute(c, [arg0])
      del out, keepalive

    def testTokens(self):
      if self.backend.platform not in {"cpu", "gpu"}:
        self.skipTest("Test requires cpu or gpu platform")
      c = self._NewComputation()

      def _Callback(x, y):
        assert y is None, y
        return None, x + 1

      arg0 = np.array([9, 43, -101, 22], dtype=np.int32)
      shape = xla_client.shape_from_pyval(arg0)
      token_shape = xla_client.Shape.token_shape()
      p0 = ops.Parameter(c, 0, shape)
      token = ops.CreateToken(c)
      out, keepalive = self.backend.emit_python_callback(
          _Callback, c, [p0, token], [token_shape, shape])
      out = ops.GetTupleElement(out, 1)
      self._ExecuteAndCompareExact(c, arguments=[arg0], expected=[arg0 + 1])
      del out, keepalive

    def testStriding(self):
      if self.backend.platform not in {"cpu", "gpu"}:
        self.skipTest("Test requires cpu or gpu platform")
      c = self._NewComputation()

      def _Callback(x):
        assert x.flags.f_contiguous, x.strides
        # Force the output array to have C layout, which will require a
        # transpose back to the expected Fortran layout.
        return np.ascontiguousarray(x * 2),

      arg0 = np.arange(12, dtype=np.int16).reshape(3, 4)
      shape_f_layout = xla_client.Shape.array_shape(
          arg0.dtype, arg0.shape, layout=(0, 1))
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      out, keepalive = self.backend.emit_python_callback(
          _Callback, c, [p0], [shape_f_layout], [shape_f_layout])
      self._ExecuteAndCompareExact(c, arguments=[arg0], expected=[arg0 * 2])
      del out, keepalive

  tests.append(PythonCallbackTest)

  class ComputationFromProtoTest(absltest.TestCase):
    """Test computation execution from HLO proto."""

    def setUp(self):
      super(ComputationFromProtoTest, self).setUp()
      self.backend = xla_backend()

    def testExecuteFromProto(self):
      # Build the HLO proto
      b = xla_client.XlaBuilder("computation")
      ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
      serialized_proto = b.build().as_serialized_hlo_module_proto()

      # Load and execute the proto
      c = xla_client.XlaComputation(serialized_proto)
      ans, = xla_client.execute_with_python_values(
          self.backend.compile(c), (), backend=self.backend)
      np.testing.assert_equal(ans, np.int32(3))

  tests.append(ComputationFromProtoTest)

  class ParametersTest(ComputationTest):
    """Tests focusing on Parameter ops and argument-passing."""

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes)
    def testScalarTimesVector(self, dtype):
      c = self._NewComputation()
      arg0 = np.array(3, dtype=dtype)
      arg1 = np.array([10, 15, -2, 7], dtype=dtype)
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      ops.Mul(p0, p1)
      self._ExecuteAndCompareExact(
          c, arguments=[arg0, arg1], expected=[arg0 * arg1])

    # TODO(phawkins): test comparison harness doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testScalarMinusVectorExplicitNumbering(self, dtype):
      # Use explicit numbering and pass parameter_num first. Sub is used since
      # it's not commutative and can help catch parameter reversal within the
      # computation.
      c = self._NewComputation()
      arg0 = np.array(2.0, dtype=dtype)
      arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype)
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      ops.Sub(p1, p0)
      self._ExecuteAndCompareClose(
          c, arguments=[arg0, arg1], expected=[arg1 - arg0])

  tests.append(ParametersTest)

  class BufferTest(ComputationTest):
    """Tests focusing on execution with Buffers."""

    def testConstantSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14)))
      self._ExecuteAndCompareClose(c, expected=[4.25])

    def testOneParameterSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Constant(c, np.float32(3.14)))
      self._ExecuteAndCompareClose(
          c, arguments=[NumpyArrayF32(1.11)], expected=[4.25])

    def testTwoParameterSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.))))
      self._ExecuteAndCompareClose(
          c,
          arguments=[NumpyArrayF32(1.11),
                     NumpyArrayF32(3.14)],
          expected=[4.25])

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCannotCallWithDeletedBuffers(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Constant(c, np.float32(3.14)))
      arg = NumpyArrayF32(1.11)
      compiled_c = self.backend.compile(c.build())
      arg_buffer = self.backend.buffer_from_pyval(arg)
      arg_buffer.delete()
      with self.assertRaises(xla_client.XlaRuntimeError):
        compiled_c.execute([arg_buffer])

    def testXlaShape(self):
      pyval = np.array([[1., 2.]], np.float32)
      local_buffer = self.backend.buffer_from_pyval(pyval)
      xla_shape = local_buffer.xla_shape()
      self.assertEqual(xla_shape.dimensions(), (1, 2))
      self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))

    def testXlaShapeIndex(self):
      a = xla_client.ShapeIndex((1, 2))
      b = xla_client.ShapeIndex((1, 2))
      c = xla_client.ShapeIndex((2, 3))
      self.assertEqual(a, b)
      self.assertNotEqual(b, c)

    def testLayout(self):
      f32 = xla_client.PrimitiveType.F32
      a = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout()
      b = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout()
      c = xla_client.Shape.array_shape(f32, (2, 3), (1, 0)).layout()
      self.assertEqual(a.minor_to_major(), (0, 1))
      self.assertEqual(b.minor_to_major(), (0, 1))
      self.assertEqual(c.minor_to_major(), (1, 0))
      self.assertEqual(a, b)
      self.assertNotEqual(a, c)
      self.assertNotEqual(b, c)
      self.assertEqual(hash(a), hash(b))
      self.assertNotEqual(hash(a), hash(c))
      self.assertNotEqual(hash(b), hash(c))

    def testBlockUntilReadyWorks(self):
      arg = np.array([[1., 2.]], np.float32)
      arg_buffer = self.backend.buffer_from_pyval(arg)
      arg_buffer.block_until_ready()
      # This test merely checks that nothing goes awry when we call
      # block_until_ready(); it's difficult to test anything else.

    def testBlockUntilReadyRaisesOnDeletedBuffer(self):
      arg = np.array([[1., 2.]], np.float32)
      buffer = self.backend.buffer_from_pyval(arg)
      buffer.delete()
      with self.assertRaisesRegex(
          RuntimeError,
          re.escape(
              "BlockHostUntilReady() called on deleted or donated buffer")):
        buffer.block_until_ready()

    def testDeviceArrayBaseSignatures(self):
      # When extending `DeviceArrayBase`, the object behaves as a `DeviceArray`
      # and thus needs to correctly implement the following methods.
      arg = np.array([[1., 2., 3.]], np.float32)
      buffer = self.backend.buffer_from_pyval(arg)
      if not isinstance(buffer, xla_client.DeviceArrayBase):
        raise unittest.SkipTest(
            "The object of type {} does not extend DeviceArrayBase".format(
                type(buffer)))

      self.assertEqual(buffer.__array_priority__, 100)
      self.assertEqual(buffer.shape, (1, 3))
      self.assertEqual(buffer.dtype, np.float32)
      self.assertEqual(buffer.size, 3)
      self.assertEqual(buffer.ndim, 2)

      self.assertIs(buffer, buffer.block_until_ready())
      self.assertTrue(buffer.is_ready())
      buffer.delete()
      with self.assertRaises(xla_client.XlaRuntimeError):
        buffer.block_until_ready()
      with self.assertRaises(xla_client.XlaRuntimeError):
        buffer.is_ready()

    def testOnDeviceSizeInBytes(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.")
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0)
      # OnDeviceSizeInBytes varies depending on the platform. Confirm there's
      # a reasonable value.
      self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0)
      self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0)

    def testLiveBuffers(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support LiveBuffers().")
      self.assertEmpty(self.backend.live_buffers())
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertLen(self.backend.live_buffers(), 3)
      self.assertIs(self.backend.live_buffers()[0], arg2_buffer)
      self.assertIs(self.backend.live_buffers()[1], arg1_buffer)
      self.assertIs(self.backend.live_buffers()[2], arg0_buffer)
      self.assertEqual(self.backend.devices()[0].live_buffers(),
                       self.backend.live_buffers())

      arg1_buffer.delete()
      self.assertLen(self.backend.live_buffers(), 2)
      self.assertIs(self.backend.live_buffers()[0], arg2_buffer)
      self.assertIs(self.backend.live_buffers()[1], arg0_buffer)

      arg0_buffer.delete()
      arg2_buffer.delete()
      self.assertEmpty(self.backend.live_buffers())

    def testCopyToHost(self):
      arg0 = np.array([[1., 2.]], np.float32)
      arg1 = np.array([[3., 4.]], np.float32)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      # Prefetch two buffers using copy_to_host_async, and then retrieve their
      # values using to_py.
      arg0_buffer.copy_to_host_async()
      arg0_buffer.copy_to_host_async()  # Duplicate calls don't do anything.
      arg1_buffer.copy_to_host_async()
      np.testing.assert_equal(arg0, arg0_buffer.to_py())
      np.testing.assert_equal(arg1, arg1_buffer.to_py())
      # copy_to_host_async does nothing after to_py is called.
      arg0_buffer.copy_to_host_async()
      np.testing.assert_equal(arg0, arg0_buffer.to_py())

    def testDevice(self):
      x = np.arange(8, dtype=np.int32)
      for device in self.backend.local_devices():
        buf = self.backend.buffer_from_pyval(x, device=device)
        self.assertEqual(buf.device(), device)
        np.testing.assert_equal(x, buf.to_py())

    def testStandardTypes(self):
      for dtype in standard_dtypes:
        if dtype == bfloat16 or dtype == np.complex128:
          continue
        arr = self.backend.buffer_from_pyval(np.array([0, 1], dtype))
        arr = arr.to_py()
        self.assertEqual(dtype, type(arr[0]))

    def testUnsafeBufferPointer(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support UnsafeBufferPointer().")
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertGreaterEqual(arg0_buffer.unsafe_buffer_pointer(), 0)
      self.assertGreaterEqual(arg1_buffer.unsafe_buffer_pointer(), 0)
      self.assertGreaterEqual(arg2_buffer.unsafe_buffer_pointer(), 0)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testClone(self):
      x = np.array([[3., 4., 5.]], np.float32)
      y = self.backend.buffer_from_pyval(x)
      z = y.clone()
      self.assertNotEqual(id(x), id(y))
      np.testing.assert_array_equal(y.to_py(), z.to_py())
      self.assertEqual(y.unsafe_buffer_pointer(), z.unsafe_buffer_pointer())

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testJaxAttributesHaveCorrectDefaults(self):
      x = np.array([[3., 4., 5.]], np.float32)
      y = self.backend.buffer_from_pyval(x)
      self.assertIsNone(y.aval)
      self.assertIsNone(y._device)

  tests.append(BufferTest)

  class SingleOpTest(ComputationTest):
    """Tests for single ops.

    The goal here is smoke testing: exercise the most basic functionality of
    single XLA ops, adding as few additional ops as possible around the op
    under test.
    """

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConcatenate(self, dtype):
      c = self._NewComputation()
      args = (
          ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)),
          ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)),
      )
      ops.ConcatInDim(c, args, dimension=0)
      self._ExecuteAndCompareExact(
          c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)])

    # pyformat: disable
    @parameterized.named_parameters({
        "testcase_name": "_{}_{}".format(src_dtype.__name__,
                                         dst_dtype.__name__),
        "src_dtype": src_dtype,
        "dst_dtype": dst_dtype,
    } for src_dtype, dst_dtype in itertools.permutations(
        [np.bool_, np.int32, np.int64, np.float32, np.float64], 2))
    # pyformat: enable
    def testConvertElementType(self, src_dtype, dst_dtype):
      if ((src_dtype in [np.int64, np.float64] or
           dst_dtype in [np.int64, np.float64]) and
          self.backend.platform == "tpu"):
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
      ops.ConvertElementType(
          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 1)
      expected = np.array(x, dtype=dst_dtype)

      self.assertEqual(result[0].shape, expected.shape)
      self.assertEqual(result[0].dtype, expected.dtype)
      np.testing.assert_equal(result[0], expected)

    # pyformat: disable
    @parameterized.named_parameters(
        {
            "testcase_name": "_{}_{}".format(src_dtype.__name__,
                                             dst_dtype.__name__),
            "src_dtype": src_dtype,
            "dst_dtype": dst_dtype,
        }
        for dtypes in [[np.int32, np.float32], [np.int64, np.float64]]
        for src_dtype, dst_dtype in itertools.permutations(dtypes, 2))
    # pyformat: enable
    def testBitcastConvertType(self, src_dtype, dst_dtype):
      if (np.float64 in (src_dtype, dst_dtype) and
          self.backend.platform == "tpu"):
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
      ops.BitcastConvertType(
          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 1)
      expected = x.view(dst_dtype)

      self.assertEqual(result[0].shape, expected.shape)
      self.assertEqual(result[0].dtype, expected.dtype)
      np.testing.assert_equal(result[0], expected)

    # TODO(b/123523486) implement AllToAll on CPU
    def DISABLED_testAllToAllOneReplica(self):
      samples = [
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples[:1]:
        c = self._NewComputation()
        ops.AllToAll(ops.Constant(c, lhs), 0, 0)
        self._ExecuteAndCompareExact(c, expected=[lhs])

    def testCrossReplicaSumOneReplica(self):
      samples = [
          NumpyArrayF32(42.0),
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples:
        c = self._NewComputation()
        ops.CrossReplicaSum(ops.Constant(c, lhs))
        self._ExecuteAndCompareExact(c, expected=[lhs])

    def testReplicaId(self):
      c = self._NewComputation()
      _ = ops.ReplicaId(c)
      self._ExecuteAndCompareExact(c, expected=[0])

    def testCrossReplicaSumOneReplicaWithSingletonGroup(self):
      samples = [
          NumpyArrayF32(42.0),
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples:
        c = self._NewComputation()
        ops.CrossReplicaSum(
            ops.Constant(c, lhs), xla_client.make_replica_groups([[0]]))
        self._ExecuteAndCompareExact(c, expected=[lhs])

    # TODO(phawkins): np.dot implementation doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDotMatrixVector(self, dtype):
      c = self._NewComputation()
      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
      rhs = np.array([[10.0], [20.0]], dtype=dtype)
      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])

    # TODO(phawkins): np.dot implementation doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDotMatrixMatrix(self, dtype):
      c = self._NewComputation()
      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
      rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype)
      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])

    def testDotGeneral(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
      dimension_numbers = xla_client.make_dot_dimension_numbers(
          (([2], [1]), ([0], [0])))
      ops.DotGeneral(
          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testDotGeneralWithDotDimensionNumbersProto(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))

      dimension_numbers = xla_client.DotDimensionNumbers()
      dimension_numbers.lhs_contracting_dimensions.append(2)
      dimension_numbers.rhs_contracting_dimensions.append(1)
      dimension_numbers.lhs_batch_dimensions.append(0)
      dimension_numbers.rhs_batch_dimensions.append(0)

      ops.DotGeneral(
          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testDotGeneralWithPrecisionConfig(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
      dimension_numbers = xla_client.make_dot_dimension_numbers(
          (([2], [1]), ([0], [0])))
      config = xla_client.PrecisionConfig()
      config.operand_precision.append(config.Precision.HIGH)
      config.operand_precision.append(config.Precision.HIGHEST)
      ops.DotGeneral(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          dimension_numbers,
          precision_config=config)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testConvGeneralDilatedF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
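      # Dimension numbers: the input/output use NCHW layout (batch, feature,
      # spatial H/W) and the kernel uses OIHW (output feature, input feature,
      # spatial H/W).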
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
          lhs_dilation, rhs_dilation, dimension_numbers)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedF32WithPrecisionConfig(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      config = xla_client.PrecisionConfig()
      config.operand_precision.append(config.Precision.HIGHEST)
      config.operand_precision.append(config.Precision.DEFAULT)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          strides,
          pads,
          lhs_dilation,
          rhs_dilation,
          dimension_numbers,
          precision_config=config)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedPermutedF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)

      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NHWC", "OIHW", "CWNH"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, np.transpose(lhs,
                                       (0, 2, 3, 1))), ops.Constant(c, rhs),
          strides, pads, lhs_dilation, rhs_dilation, dimension_numbers)
      result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.],
                           [40., 50., 0.]]]])
      self._ExecuteAndCompareClose(
          c, expected=[np.transpose(result, (1, 3, 0, 2))])

    def testConvGeneralDilatedGroupedConvolutionF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 2, 2, 3)
      rhs = a(2, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      feature_group_count = 2
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
          lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ], [
          [0., 0., 0.],
          [330., 380., 160.],
          [0., 0., 0.],
          [480., 530., 220.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedWindowReversalF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      window_reversal = [False, True]
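      # window_reversal=[False, True] reverses the kernel along the second
      # spatial dimension, which is why the expected output below differs from
      # testConvGeneralDilatedF32 above.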
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          strides,
          pads,
          lhs_dilation,
          rhs_dilation,
          dimension_numbers,
          window_reversal=window_reversal)
      result = np.array([[[
          [0., 0., 0.],
          [0., 10., 20.],
          [0., 0., 0.],
          [30., 40., 50.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testBooleanNot(self):
      c = self._NewComputation()
      arr = NumpyArrayBool([True, False, True])
      ops.Not(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[~arr])

    def testPopulationCount(self):
      c = self._NewComputation()
      arr = NumpyArrayS32([3, 0, 1])
      ops.PopulationCount(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])])

    def testCountLeadingZeros(self):
      c = self._NewComputation()
      arr = NumpyArrayS32([0x7FFF, 0x12345678])
      ops.Clz(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[[17, 3]])

    def testExp(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Exp(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.exp(arr)])

    def testExpm1(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Expm1(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)])

    def testRound(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Round(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.round(arr)])

    def testLog(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Log(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.log(arr)])

    def testLog1p(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Log1p(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)])

    def testNeg(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Neg(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[-arr])

    def testFloor(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Floor(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.floor(arr)])

    def testCeil(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Ceil(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)])

    def testAbs(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
      ops.Abs(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.abs(arr)])

    def testTanhF32(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001])
      ops.Tanh(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)])

    def testTanhF64(self):
      if self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support 64bit tanh")
      c = self._NewComputation()
      arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001])
      ops.Tanh(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12)

    def testTranspose(self):

      def _TransposeAndTest(array, permutation):
        c = self._NewComputation()
        ops.Transpose(ops.Constant(c, array), permutation)
        expected = np.transpose(array, permutation)
        self._ExecuteAndCompareClose(c, expected=[expected])

      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])

      arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
      for permutation in itertools.permutations(range(arr.ndim)):
        _TransposeAndTest(arr, permutation)
        _TransposeAndTest(np.asfortranarray(arr), permutation)

    def testEq(self):
      c = self._NewComputation()
      ops.Eq(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])

    def testNe(self):
      c = self._NewComputation()
      ops.Ne(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
      self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]])

      ops.Ne(
          ops.Constant(c, NumpyArrayF32([-2.0, 0.0,
                                         float("nan"),
                                         float("nan")])),
          ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0,
                                         float("nan")])))
      self._ExecuteAndAssertWith(
          np.testing.assert_allclose,
          c, (),
          expected=[[True, False, True, True]])

    def testGt(self):
      c = self._NewComputation()
      ops.Gt(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[False, True, True, False, False]])

    def testGe(self):
      c = self._NewComputation()
      ops.Ge(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[True, True, True, False, False]])

    def testLt(self):
      c = self._NewComputation()
      ops.Lt(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[False, False, False, True, True]])

    def testLe(self):
      c = self._NewComputation()
      ops.Le(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[True, False, False, True, True]])

    def testMax(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
      self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]])

    def testMaxExplicitBroadcastDim0(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
          broadcast_dimensions=(0,))
      self._ExecuteAndCompareExact(
          c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]])

    def testMaxExplicitBroadcastDim1(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
          broadcast_dimensions=(1,))
      self._ExecuteAndCompareExact(
          c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]])

    def testMin(self):
      c = self._NewComputation()
      ops.Min(
          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
      self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]])

    def testPad(self):
      c = self._NewComputation()
      ops.Pad(
          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
          ops.Constant(c, NumpyArrayF32(0.0)),
          xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)]))
      self._ExecuteAndCompareClose(
          c,
          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    def testPadWithPaddingConfig(self):
      c = self._NewComputation()
      padding_config = xla_client.PaddingConfig()
      for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
        dimension = xla_client.PaddingConfigDimension()
        dimension.edge_padding_low = lo
        dimension.edge_padding_high = hi
        dimension.interior_padding = interior
        padding_config.dimensions.append(dimension)
      ops.Pad(
          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
          ops.Constant(c, NumpyArrayF32(0.0)), padding_config)
      self._ExecuteAndCompareClose(
          c,
          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    def testReshape(self):
      c = self._NewComputation()
      ops.Reshape(
          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
          dimensions=[0, 1],
          new_sizes=[2, 3])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]])

    def testCollapse(self):
      c = self._NewComputation()
      ops.Collapse(
          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
          dimensions=[1, 2])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]])

    def testRev(self):
      c = self._NewComputation()
      ops.Rev(
          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
          dimensions=[0, 2])
      self._ExecuteAndCompareExact(
          c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]])

    def testReducePrecision(self):
      c = self._NewComputation()
      ops.ReducePrecision(
          ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])),
          exponent_bits=8,
          mantissa_bits=7)
      self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]])

    def testClampF32(self):
      c = self._NewComputation()
      ops.Clamp(
          ops.Constant(c, NumpyArrayF32(-1)),
          ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
          ops.Constant(c, NumpyArrayF32(2)))
      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])

    def testClampS32(self):
      c = self._NewComputation()
      ops.Clamp(
          ops.Constant(c, NumpyArrayS32(-1)),
          ops.Constant(c, NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
          ops.Constant(c, NumpyArrayS32(2)))
      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])

    def testSelect(self):
      c = self._NewComputation()
      ops.Select(
          ops.Constant(c, NumpyArrayBool([True, False, False, True, False])),
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])),
          ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5])))
      self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]])

    def testSlice(self):
      c = self._NewComputation()
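      # Slice takes start_indices, limit_indices, and strides; here it selects
      # rows 1..2 and columns 0..1 of the 3x3 constant.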
      ops.Slice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          [1, 0], [3, 2], [1, 1])
      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])

    def testSliceInDim(self):
      c = self._NewComputation()
      ops.SliceInDim(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          start_index=1,
          limit_index=2,
          stride=1,
          dimno=1)
      self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]])
      ops.SliceInDim(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          start_index=0,
          limit_index=3,
          stride=2,
          dimno=0)
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]])

    def testDynamicSlice(self):
      c = self._NewComputation()
      ops.DynamicSlice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          [ops.Constant(c, NumpyArrayS32([1, 0]))], [2, 2])
      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])

    def testDynamicUpdateSlice(self):
      c = self._NewComputation()
      ops.DynamicUpdateSlice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])),
          [ops.Constant(c, NumpyArrayS32([1, 1]))])
      self._ExecuteAndCompareExact(
          c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]])

    def testTuple(self):
      c = self._NewComputation()
      ops.Tuple(c, [
          ops.Constant(c, np.int32(42)),
          ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
          ops.Constant(c, NumpyArrayBool([True, False, False, True]))
      ])
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 3)
      np.testing.assert_equal(result[0], 42)
      np.testing.assert_allclose(result[1], [1.0, 2.0])
      np.testing.assert_equal(result[2], [True, False, False, True])

    def testGetTupleElement(self):
      c = self._NewComputation()
      ops.GetTupleElement(
          ops.Tuple(c, [
              ops.Constant(c, np.int32(42)),
              ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
              ops.Constant(c, NumpyArrayBool([True, False, False, True]))
          ]), 1)
      self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]])

    def testBroadcast(self):
      c = self._NewComputation()
      ops.Broadcast(
          ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
      self._ExecuteAndCompareExact(
          c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]])

    def testBroadcastInDim(self):
      c = self._NewComputation()
      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0])
      self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]])
      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [1])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]])

    def testRngNormal(self):
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngNormal(
          ops.Constant(c, NumpyArrayF32(0.)),
          ops.Constant(c, NumpyArrayF32(1.)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # since the result is random, we just check shape and uniqueness
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertLen(np.unique(result[0]), np.prod(shape))

    def testRngUniformF32(self):
      lo, hi = 2., 4.
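      # The range assertions below check lo <= x < hi, i.e. a half-open
      # sampling interval.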
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngUniform(
          ops.Constant(c, NumpyArrayF32(lo)),
          ops.Constant(c, NumpyArrayF32(hi)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # since the result is random, we just check shape, uniqueness, and range
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertLen(np.unique(result[0]), np.prod(shape))
      self.assertTrue(np.all(lo <= result[0]))
      self.assertTrue(np.all(result[0] < hi))

    def testRngUniformS32(self):
      lo, hi = 2, 4
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngUniform(
          ops.Constant(c, NumpyArrayS32(lo)),
          ops.Constant(c, NumpyArrayS32(hi)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # since the result is random, we just check shape, integrality, and range
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertEqual(result[0].dtype, np.int32)
      self.assertTrue(np.all(lo <= result[0]))
      self.assertTrue(np.all(result[0] < hi))

    def testCholesky(self):
      l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
                   dtype=np.float32)
      c = self._NewComputation()
      ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T))))
      self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4)

    def testSort(self):
      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
      c = self._NewComputation()
      ops.Sort(c, [ops.Constant(c, keys)], is_stable=True)
      self._ExecuteAndCompareClose(
          c,
          expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)])

    def testSortKeyVal(self):
      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
      c = self._NewComputation()
      ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
      np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])

    def testSortCustomComparator(self):
      b = self._NewComputation("comparator")
      p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0)))
      q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))
      p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0)))
      q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0)))
      ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1)))
      comparator = b.build()

      keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32)
      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
      c = self._NewComputation()
      ops.Sort(
          c, (ops.Constant(c, keys), ops.Constant(c, values)),
          dimension=1,
          comparator=comparator)
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
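      # Keys are sorted in ascending order; equal keys are ordered by the
      # larger value first (the Gt(p1, q1) clause of the comparator), which
      # determines the value permutation checked here.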
np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) 1544 1545 def testQR(self): 1546 a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], 1547 [10, 63, 166, 310]], 1548 dtype=np.float32) 1549 c = self._NewComputation() 1550 ops.Tuple(c, ops.QR(ops.Constant(c, a), full_matrices=True)) 1551 q, r = self._Execute(c, ()) 1552 np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) 1553 1554 def testEigh(self): 1555 a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], 1556 [10, 63, 166, 310]], 1557 dtype=np.float32) 1558 a = (a + a.T) / 2 1559 1560 c = self._NewComputation() 1561 ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True)) 1562 # TODO(b/129396575): Turn this test back on when it passes without 1563 # fastmath. 1564 # v, w = self._Execute(c, ()) 1565 # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) 1566 1567 def testSVD(self): 1568 a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], 1569 [10, 63, 166, 310]], 1570 dtype=np.float32) 1571 c = self._NewComputation() 1572 ops.Tuple(c, ops.SVD(ops.Constant(c, a))) 1573 u, d, v = self._Execute(c, ()) 1574 self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) 1575 1576 def testTriangularSolve(self): 1577 a_vals = np.array( 1578 [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], 1579 dtype=np.float32) 1580 b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 1581 dtype=np.float32) 1582 1583 c = self._NewComputation() 1584 ops.TriangularSolve( 1585 ops.Constant(c, a_vals), 1586 ops.Constant(c, b_vals), 1587 left_side=False, 1588 lower=True, 1589 transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE, 1590 unit_diagonal=False) 1591 self._ExecuteAndCompareClose( 1592 c, 1593 expected=[ 1594 np.array([ 1595 [0.5, 0.08333334, 0.04629629, 0.03367003], 1596 [2.5, -0.25, -0.1388889, -0.1010101], 1597 [4.5, -0.58333331, -0.32407406, -0.23569024], 1598 ], 1599 dtype=np.float32) 1600 ], 1601 rtol=1e-4) 1602 1603 def testApproxTopK(self): 1604 if self.backend.platform != "tpu": 1605 self.skipTest("ApproxTopK is only supported on TPU") 1606 k = 10 1607 qy_size = 256 1608 db_size = 3000 1609 feature = 128 1610 recall_target = 0.95 1611 b = self._NewComputation() 1612 p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))) 1613 q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))) 1614 ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) 1615 ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0))) 1616 ops.Gt(p0, q0) 1617 comparator = b.build() 1618 qy_shape = [qy_size, feature] 1619 db_shape = [feature, db_size] 1620 rng = np.random.RandomState(0) 1621 qy_arg = rng.randn(*qy_shape).astype(np.float32) 1622 db_arg = rng.randn(*db_shape).astype(np.float32) 1623 b = self._NewComputation() 1624 qy = ops.Parameter(b, 0, xla_client.shape_from_pyval(qy_arg)) 1625 db = ops.Parameter(b, 1, xla_client.shape_from_pyval(db_arg)) 1626 scores = ops.Dot(qy, db) 1627 iota = ops.Iota( 1628 b, 1629 xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, 1630 (qy_size, db_size)), 1) 1631 init_val = ops.Constant(b, np.float32(-1)) 1632 init_arg = ops.Constant(b, np.int32(-1)) 1633 ground_truth = ops.TopK(scores, k=k) 1634 approx_topk = ops.ApproxTopK( 1635 b, [scores, iota], [init_val, init_arg], 1636 top_k=k, 1637 reduction_dim=1, 1638 comparator=comparator, 1639 recall_target=recall_target) 1640 ops.Tuple(b, [ 1641 ops.GetTupleElement(ground_truth, 1), 1642 ops.GetTupleElement(approx_topk, 1) 1643 ]) 1644 results = 
self._Execute(b, [qy_arg, db_arg]) 1645 ground_truth_docids = [set(x) for x in results[0]] 1646 hits = sum( 1647 len( 1648 list(x 1649 for x in approx_topk_per_q 1650 if x in ground_truth_docids[q])) 1651 for q, approx_topk_per_q in enumerate(results[1])) 1652 self.assertGreater(hits / (qy_size * k), recall_target) 1653 1654 def testIsConstant(self): 1655 c = self._NewComputation() 1656 a = ops.Constant(c, np.int32(3)) 1657 b = ops.Constant(c, np.int32(1)) 1658 x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0))) 1659 const_expr = ops.Sub(b, a) 1660 non_const_expr = ops.Mul(const_expr, x) 1661 self.assertTrue(c.is_constant(const_expr)) 1662 self.assertFalse(c.is_constant(non_const_expr)) 1663 1664 def testGather(self): 1665 a = np.arange(9).astype(np.int32).reshape((3, 3)) 1666 indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) 1667 dnums = xla_client.GatherDimensionNumbers() 1668 dnums.offset_dims.append(1) 1669 dnums.offset_dims.append(2) 1670 dnums.start_index_map.append(0) 1671 dnums.start_index_map.append(1) 1672 dnums.index_vector_dim = 2 1673 c = self._NewComputation() 1674 ops.Gather( 1675 ops.Constant(c, a), 1676 ops.Constant(c, indices), 1677 dnums, 1678 slice_sizes=[1, 1]) 1679 g, = self._Execute(c, ()) 1680 expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) 1681 np.testing.assert_allclose(g, expected, rtol=1e-4) 1682 1683 def testFft(self): 1684 if self.backend.platform == "tpu": 1685 self.skipTest("TPU only supports 1D FFT") 1686 shape = [2, 3, 4, 5] 1687 rng = np.random.RandomState(0) 1688 a = rng.randn(*shape) + 1.0j * rng.randn(*shape) 1689 a = a.astype(np.complex64) 1690 # FFT 1691 c = self._NewComputation() 1692 ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:]) 1693 self._ExecuteAndCompareClose( 1694 c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4) 1695 # IFFT 1696 c = self._NewComputation() 1697 ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:]) 1698 self._ExecuteAndCompareClose( 1699 c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4) 1700 # RFFT 1701 b = rng.randn(*shape).astype(np.float32) 1702 c = self._NewComputation() 1703 ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:]) 1704 self._ExecuteAndCompareClose( 1705 c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4) 1706 # IRFFT 1707 c = self._NewComputation() 1708 ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8]) 1709 self._ExecuteAndCompareClose( 1710 c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=1e-4) 1711 1712 def testNextAfter(self): 1713 c = self._NewComputation() 1714 ops.NextAfter( 1715 ops.Constant(c, np.array([1, 2], dtype=np.float32)), 1716 ops.Constant(c, np.array([2, 1], dtype=np.float32))) 1717 out, = self._Execute(c, ()) 1718 eps = np.finfo(np.float32).eps 1719 np.testing.assert_equal( 1720 np.array([eps + 1, 2 - eps], dtype=np.float32), out) 1721 1722 @parameterized.named_parameters({ 1723 "testcase_name": "_{}".format(dtype.__name__), 1724 "dtype": dtype, 1725 } for dtype in float_dtypes) 1726 def testRegularizedIncompleteBeta(self, dtype): 1727 x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538], 1728 dtype=dtype) 1729 a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606], 1730 dtype=dtype) 1731 b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677], 1732 dtype=dtype) 1733 c = self._NewComputation() 1734 ops.RegularizedIncompleteBeta( 1735 ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x)) 1736 
expected = np.array( 1737 [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) 1738 self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2) 1739 1740 tests.append(SingleOpTest) 1741 1742 class EmbeddedComputationsTest(ComputationTest): 1743 """Tests for XLA graphs with embedded computations (such as maps).""" 1744 1745 def _CreateConstantComputation(self, in_dtype, out_dtype): 1746 """Computation (A) -> B that returns a constant 1 for any input.""" 1747 c = self._NewComputation("constant_{}_{}_one".format( 1748 in_dtype.__name__, out_dtype.__name__)) 1749 ops.Parameter( 1750 c, 0, 1751 xla_client.shape_from_pyval(np.array( 1752 0, dtype=in_dtype)).with_major_to_minor_layout_if_absent()) 1753 ops.Constant(c, out_dtype(1)) 1754 return c.build() 1755 1756 def _CreateMulBy2Computation(self, dtype): 1757 """Computation (dtype) -> dtype that multiplies its parameter by 2.""" 1758 c = self._NewComputation("mul_f32_by2") 1759 ops.Mul( 1760 ops.Parameter( 1761 c, 0, 1762 xla_client.shape_from_pyval(np.array( 1763 0, dtype=dtype)).with_major_to_minor_layout_if_absent()), 1764 ops.Constant(c, dtype(2.0))) 1765 return c.build() 1766 1767 def _CreateMulF32ByParamComputation(self): 1768 """Computation (f32) -> f32 that multiplies one parameter by the other.""" 1769 c = self._NewComputation("mul_f32_by_param") 1770 ops.Mul( 1771 ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))), 1772 ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))) 1773 return c.build() 1774 1775 def _CreateBinaryAddComputation(self, dtype): 1776 """Computation (dtype, dtype) -> dtype that adds its two parameters.""" 1777 c = self._NewComputation("add_param0_by_param1") 1778 shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) 1779 shape = shape.with_major_to_minor_layout_if_absent() 1780 ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) 1781 return c.build() 1782 1783 def _CreateBinaryGeComputation(self, dtype): 1784 """Computation (dtype, dtype) -> bool that tests param0 >= param1.""" 1785 c = self._NewComputation("param0_lt_param1") 1786 shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) 1787 shape = shape.with_major_to_minor_layout_if_absent() 1788 ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) 1789 return c.build() 1790 1791 def _MakeSample3DArray(self, dtype): 1792 return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], 1793 [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], 1794 dtype=dtype) 1795 1796 @parameterized.named_parameters({ 1797 "testcase_name": "_{}".format(dtype.__name__), 1798 "dtype": dtype, 1799 } for dtype in float_dtypes) 1800 def testCall(self, dtype): 1801 c = self._NewComputation() 1802 ops.Call( 1803 c, 1804 self._CreateMulBy2Computation(dtype), 1805 operands=(ops.Constant(c, dtype(5.0)),)) 1806 self._ExecuteAndCompareClose(c, expected=[10.0]) 1807 1808 @parameterized.named_parameters({ 1809 "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__), 1810 "in_dtype": in_dtype, 1811 "out_dtype": out_dtype, 1812 } for in_dtype, out_dtype in [[np.float32, np.int32]]) 1813 def testMapEachElementToConstant(self, in_dtype, out_dtype): 1814 c = self._NewComputation() 1815 ops.Map(c, 1816 [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))], 1817 self._CreateConstantComputation(in_dtype, out_dtype), [0]) 1818 self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) 1819 1820 @parameterized.named_parameters({ 1821 "testcase_name": "_{}".format(dtype.__name__), 1822 "dtype": dtype, 
1823 } for dtype in float_dtypes) 1824 def testMapMulBy2(self, dtype): 1825 if dtype == np.float64 and self.backend.platform == "tpu": 1826 self.skipTest("TPU doesn't support float64") 1827 c = self._NewComputation() 1828 ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], 1829 self._CreateMulBy2Computation(dtype), [0]) 1830 self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) 1831 1832 @parameterized.named_parameters({ 1833 "testcase_name": "_{}".format(dtype.__name__), 1834 "dtype": dtype, 1835 } for dtype in float_dtypes) 1836 def testSimpleMapChain(self, dtype): 1837 if dtype == np.float64 and self.backend.platform == "tpu": 1838 self.skipTest("TPU doesn't support float64") 1839 # Chains a map of constant-out with a map of mul-by-2 1840 c = self._NewComputation() 1841 const = ops.Map( 1842 c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], 1843 self._CreateConstantComputation(dtype, dtype), [0]) 1844 ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0]) 1845 self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]]) 1846 1847 # TODO(b/154752816): bfloat16 crashes in evaluator. 1848 @parameterized.named_parameters({ 1849 "testcase_name": "_{}".format(dtype.__name__), 1850 "dtype": dtype, 1851 } for dtype in float_dtypes if dtype != bfloat16) 1852 def testDivVectorsWithMap(self, dtype): 1853 1854 def DivComputation(): 1855 c = self._NewComputation("div_param0_by_param1") 1856 shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) 1857 ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) 1858 return c.build() 1859 1860 c = self._NewComputation() 1861 ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)), 1862 ops.Constant(c, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))), 1863 DivComputation(), [0]) 1864 self._ExecuteAndCompareClose( 1865 c, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3) 1866 1867 @parameterized.named_parameters({ 1868 "testcase_name": "_{}".format(dtype.__name__), 1869 "dtype": dtype, 1870 } for dtype in float_dtypes) 1871 def testSelectAndScatter(self, dtype): 1872 if dtype == np.float64 and self.backend.platform == "tpu": 1873 self.skipTest("TPU doesn't support float64") 1874 c = self._NewComputation() 1875 operand = ops.Constant( 1876 c, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype)) 1877 window_dimensions = (2, 1) 1878 window_strides = (1, 2) 1879 padding = xla_client.window_padding_type_to_pad_values( 1880 xla_client.PaddingType.VALID, 1881 c.get_shape(operand).dimensions(), window_dimensions, window_strides) 1882 ops.SelectAndScatterWithGeneralPadding( 1883 operand, 1884 select=self._CreateBinaryGeComputation(dtype), 1885 window_dimensions=window_dimensions, 1886 window_strides=window_strides, 1887 padding=padding, 1888 source=ops.Constant(c, np.array([[0.1, 0.2]], dtype=dtype)), 1889 init_value=ops.Constant(c, np.array(1, dtype=dtype)), 1890 scatter=self._CreateBinaryAddComputation(dtype)) 1891 self._ExecuteAndCompareClose( 1892 c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3) 1893 1894 @parameterized.named_parameters({ 1895 "testcase_name": "_{}".format(dtype.__name__), 1896 "dtype": dtype, 1897 } for dtype in float_dtypes) 1898 def testReduce1DtoScalar(self, dtype): 1899 c = self._NewComputation() 1900 ops.Reduce( 1901 c, 1902 operands=[ 1903 ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)) 1904 ], 1905 init_values=[ops.Constant(c, dtype(0))], 1906 computation=self._CreateBinaryAddComputation(dtype), 1907 dimensions_to_reduce=[0]) 1908 
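      # Add-reducing [1.0, 2.0, 3.0, 4.0] over its only dimension gives the
      # scalar 10.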
self._ExecuteAndCompareClose(c, expected=[10]) 1909 1910 # TODO(phawkins): test comparison harness doesn't support bfloat16 1911 @parameterized.named_parameters({ 1912 "testcase_name": "_{}_dim{}".format(dtype.__name__, dim), 1913 "dtype": dtype, 1914 "dim": dim, 1915 } for dtype in float_dtypes if dtype != bfloat16 for dim in range(2)) 1916 def testReduce2DTo1D(self, dtype, dim): 1917 input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) 1918 c = self._NewComputation() 1919 ops.Reduce( 1920 c, 1921 operands=[ops.Constant(c, input_array)], 1922 init_values=[ops.Constant(c, dtype(0))], 1923 computation=self._CreateBinaryAddComputation(dtype), 1924 dimensions_to_reduce=[dim]) 1925 self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dim)]) 1926 1927 @parameterized.named_parameters({ 1928 "testcase_name": "_{}_dims[{}]".format(dtype.__name__, dims), 1929 "dtype": dtype, 1930 "dims": tuple(dims) 1931 } for dtype in float_dtypes for dims in itertools.permutations(range(3))) 1932 def testReduce3DAllPossibleWaysF32(self, dtype, dims): 1933 input_array = self._MakeSample3DArray(dtype) 1934 c = self._NewComputation() 1935 ops.Reduce( 1936 c, 1937 operands=[ops.Constant(c, input_array)], 1938 init_values=[ops.Constant(c, dtype(0))], 1939 computation=self._CreateBinaryAddComputation(dtype), 1940 dimensions_to_reduce=dims) 1941 self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dims)]) 1942 1943 @parameterized.named_parameters({ 1944 "testcase_name": "_{}".format(dtype.__name__), 1945 "dtype": dtype, 1946 } for dtype in float_dtypes) 1947 def testReduceWindowValidUnitStrides(self, dtype): 1948 if dtype == np.float64 and self.backend.platform == "tpu": 1949 self.skipTest("TPU doesn't support float64") 1950 input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) 1951 c = self._NewComputation() 1952 window_dimensions = (2, 1) 1953 window_strides = (1, 1) 1954 padding = xla_client.window_padding_type_to_pad_values( 1955 xla_client.PaddingType.VALID, input_array.shape, window_dimensions, 1956 window_strides) 1957 ops.ReduceWindowWithGeneralPadding( 1958 operand=ops.Constant(c, input_array), 1959 init_value=ops.Constant(c, dtype(0)), 1960 computation=self._CreateBinaryAddComputation(dtype), 1961 window_dimensions=window_dimensions, 1962 window_strides=window_strides, 1963 base_dilations=[], 1964 window_dilations=[], 1965 padding=padding) 1966 self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]]) 1967 1968 @parameterized.named_parameters({ 1969 "testcase_name": "_{}".format(dtype.__name__), 1970 "dtype": dtype, 1971 } for dtype in float_dtypes) 1972 def testReduceWindowSameUnitStrides(self, dtype): 1973 if dtype == np.float64 and self.backend.platform == "tpu": 1974 self.skipTest("TPU doesn't support float64") 1975 input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) 1976 c = self._NewComputation() 1977 window_dimensions = (2, 1) 1978 window_strides = (1, 1) 1979 padding = xla_client.window_padding_type_to_pad_values( 1980 xla_client.PaddingType.SAME, input_array.shape, window_dimensions, 1981 window_strides) 1982 ops.ReduceWindowWithGeneralPadding( 1983 operand=ops.Constant(c, input_array), 1984 init_value=ops.Constant(c, dtype(0)), 1985 computation=self._CreateBinaryAddComputation(dtype), 1986 window_dimensions=window_dimensions, 1987 window_strides=window_strides, 1988 base_dilations=[], 1989 window_dilations=[], 1990 padding=padding) 1991 self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]]) 1992 
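    # The two tests above differ only in padding: with a (2, 1) window and
    # unit strides, VALID padding admits a single vertical window position
    # ([[1+4, 2+5, 3+6]]), while SAME padding extends past the bottom edge so
    # the output keeps the input shape and the padded row contributes the init
    # value: [[5, 7, 9], [4, 5, 6]]. A minimal NumPy sketch of the VALID case
    # (illustrative only; not part of the XLA API under test):
    #
    #   def valid_add_reduce_window(x, window=(2, 1)):
    #     rows = x.shape[0] - window[0] + 1
    #     cols = x.shape[1] - window[1] + 1
    #     return np.array([[x[i:i + window[0], j:j + window[1]].sum()
    #                       for j in range(cols)] for i in range(rows)])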
1993 @parameterized.named_parameters({ 1994 "testcase_name": "_{}".format(dtype.__name__), 1995 "dtype": dtype, 1996 } for dtype in float_dtypes) 1997 def testReduceWindowValidGeneralStrides(self, dtype): 1998 if dtype == np.float64 and self.backend.platform == "tpu": 1999 self.skipTest("TPU doesn't support float64") 2000 input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) 2001 c = self._NewComputation() 2002 window_dimensions = (2, 1) 2003 window_strides = (1, 2) 2004 padding = xla_client.window_padding_type_to_pad_values( 2005 xla_client.PaddingType.VALID, input_array.shape, window_dimensions, 2006 window_strides) 2007 ops.ReduceWindowWithGeneralPadding( 2008 operand=ops.Constant(c, input_array), 2009 init_value=ops.Constant(c, dtype(0)), 2010 computation=self._CreateBinaryAddComputation(dtype), 2011 window_dimensions=window_dimensions, 2012 window_strides=window_strides, 2013 base_dilations=[], 2014 window_dilations=[], 2015 padding=padding) 2016 self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]]) 2017 2018 def testReduceWindowVariadic(self): 2019 c = self._NewComputation("reducer") 2020 shape = xla_client.shape_from_pyval(np.array(0, dtype=np.int32)) 2021 shape = shape.with_major_to_minor_layout_if_absent() 2022 ps = [ops.Parameter(c, i, shape) for i in range(4)] 2023 which = ops.Ge(ps[0], ps[2]) 2024 ops.Tuple( 2025 c, [ops.Select(which, ps[0], ps[2]), 2026 ops.Select(which, ps[1], ps[3])]) 2027 reducer = c.build() 2028 2029 key_array = np.array([[1, 5, 6], [4, 2, 3]], dtype=np.int32) 2030 val_array = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.int32) 2031 c = self._NewComputation() 2032 window_dimensions = (2, 1) 2033 window_strides = (1, 1) 2034 padding = xla_client.window_padding_type_to_pad_values( 2035 xla_client.PaddingType.VALID, key_array.shape, window_dimensions, 2036 window_strides) 2037 ops.ReduceWindowWithGeneralPadding( 2038 operands=[ops.Constant(c, key_array), 2039 ops.Constant(c, val_array)], 2040 init_values=[ 2041 ops.Constant(c, np.int32(0)), 2042 ops.Constant(c, np.int32(0)) 2043 ], 2044 computation=reducer, 2045 window_dimensions=window_dimensions, 2046 window_strides=window_strides, 2047 base_dilations=[], 2048 window_dilations=[], 2049 padding=padding) 2050 self._ExecuteAndCompareClose(c, expected=[[[4, 5, 6]], [[10, 8, 9]]]) 2051 2052 @parameterized.named_parameters({ 2053 "testcase_name": "_{}".format(dtype.__name__), 2054 "dtype": dtype, 2055 } for dtype in float_dtypes) 2056 def testWhile(self, dtype): 2057 2058 def LessThan10Cond(): 2059 c = self._NewComputation("test_lt_10") 2060 shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) 2061 ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.))) 2062 return c.build() 2063 2064 cond = LessThan10Cond() 2065 body = self._CreateMulBy2Computation(dtype) 2066 c = self._NewComputation() 2067 init = ops.Constant(c, dtype(1.)) 2068 ops.While(cond, body, init) 2069 self._ExecuteAndCompareClose(c, expected=[16.]) 2070 2071 def testConditionalTrue(self): 2072 c = self._NewComputation() 2073 pred = ops.Constant(c, np.bool_(True)) 2074 true_operand = ops.Constant(c, np.float32(3.)) 2075 true_computation = self._CreateMulBy2Computation(np.float32) 2076 false_operand = ops.Constant(c, np.float32(2.)) 2077 false_computation = self._CreateConstantComputation( 2078 np.float32, np.float32) 2079 ops.Conditional(pred, true_operand, true_computation, false_operand, 2080 false_computation) 2081 self._ExecuteAndCompareClose(c, expected=[6.]) 2082 2083 def testConditionalFalse(self): 2084 c = 
self._NewComputation() 2085 pred = ops.Constant(c, np.bool_(False)) 2086 true_operand = ops.Constant(c, np.float32(3.)) 2087 true_computation = self._CreateMulBy2Computation(np.float32) 2088 false_operand = ops.Constant(c, np.float32(2.)) 2089 false_computation = self._CreateConstantComputation( 2090 np.float32, np.float32) 2091 ops.Conditional(pred, true_operand, true_computation, false_operand, 2092 false_computation) 2093 self._ExecuteAndCompareClose(c, expected=[1.]) 2094 2095 @unittest.skipIf(cloud_tpu, "not implemented") 2096 def testInfeedS32Values(self): 2097 to_infeed = NumpyArrayS32([1, 2, 3, 4]) 2098 c = self._NewComputation() 2099 ops.GetTupleElement( 2100 ops.InfeedWithToken( 2101 ops.CreateToken(c), 2102 xla_client.shape_from_pyval( 2103 to_infeed[0]).with_major_to_minor_layout_if_absent()), 0) 2104 compiled_c = self.backend.compile(c.build()) 2105 device = self.backend.local_devices()[0] 2106 for item in to_infeed: 2107 device.transfer_to_infeed(item) 2108 2109 for item in to_infeed: 2110 result, = xla_client.execute_with_python_values( 2111 compiled_c, (), backend=self.backend) 2112 self.assertEqual(result, item) 2113 2114 @unittest.skipIf(cloud_tpu, "not implemented") 2115 def testInfeedTuple(self): 2116 to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]])) 2117 c = self._NewComputation() 2118 ops.GetTupleElement( 2119 ops.InfeedWithToken( 2120 ops.CreateToken(c), 2121 xla_client.shape_from_pyval( 2122 to_infeed).with_major_to_minor_layout_if_absent()), 0) 2123 compiled_c = self.backend.compile(c.build()) 2124 device = self.backend.local_devices()[0] 2125 device.transfer_to_infeed(to_infeed) 2126 2127 result = xla_client.execute_with_python_values( 2128 compiled_c, (), backend=self.backend) 2129 self.assertLen(result, 2) 2130 np.testing.assert_equal(result[0], to_infeed[0]) 2131 np.testing.assert_equal(result[1], to_infeed[1]) 2132 2133 @unittest.skipIf(cloud_tpu, "not implemented") 2134 def testInfeedThenOutfeedS32(self): 2135 to_round_trip = NumpyArrayS32([1, 2, 3, 4]) 2136 c = self._NewComputation() 2137 x_and_token = ops.InfeedWithToken( 2138 ops.CreateToken(c), 2139 xla_client.shape_from_pyval( 2140 to_round_trip[0]).with_major_to_minor_layout_if_absent()) 2141 x = ops.GetTupleElement(x_and_token, 0) 2142 token = ops.GetTupleElement(x_and_token, 1) 2143 outfeed_shape = xla_client.shape_from_pyval( 2144 to_round_trip[0]).with_major_to_minor_layout_if_absent() 2145 ops.OutfeedWithToken(x, token, outfeed_shape) 2146 2147 compiled_c = self.backend.compile(c.build()) 2148 device = self.backend.local_devices()[0] 2149 2150 for want in to_round_trip: 2151 execution = threading.Thread(target=lambda: compiled_c.execute([])) 2152 execution.start() 2153 device.transfer_to_infeed(want) 2154 got = device.transfer_from_outfeed(outfeed_shape) 2155 execution.join() 2156 self.assertEqual(want, got) 2157 2158 def testScatter(self): 2159 a = np.arange(9).astype(np.int32).reshape((3, 3)) 2160 scatter_indices = np.array([0, 2], dtype=np.int32) 2161 updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) 2162 2163 dnums = xla_client.ScatterDimensionNumbers() 2164 dnums.update_window_dims.append(1) 2165 dnums.inserted_window_dims.append(0) 2166 dnums.scatter_dims_to_operand_dims.append(0) 2167 dnums.index_vector_dim = 1 2168 2169 c = self._NewComputation() 2170 ops.Scatter( 2171 ops.Constant(c, a), ops.Constant(c, scatter_indices), 2172 ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32), 2173 dnums) 2174 expected = np.array([[10, 21, 32], [3, 4, 5], 
[76, 87, 98]], 2175 dtype=np.int32) 2176 self._ExecuteAndCompareClose(c, expected=[expected]) 2177 2178 class DeviceTest(ComputationTest): 2179 2180 def testPlatform(self): 2181 for device in self.backend.local_devices(): 2182 self.assertEqual(device.platform, self.backend.platform) 2183 2184 tests.append(DeviceTest) 2185 2186 class ErrorTest(ComputationTest): 2187 2188 def setUp(self): 2189 super(ErrorTest, self).setUp() 2190 self.f32_scalar_2 = NumpyArrayF32(2.0) 2191 self.s32_scalar_2 = NumpyArrayS32(2) 2192 2193 def testCompileWithWrongElementTypeInLayout(self): 2194 c = self._NewComputation() 2195 c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) 2196 ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) 2197 c.clear_op_metadata() 2198 2199 options = xla_client.CompileOptions() 2200 options.argument_layouts = [ 2201 xla_client.Shape.array_shape(np.dtype(np.float32), []) 2202 ] 2203 2204 def TestFun(): 2205 return self.backend.compile(c.build(), compile_options=options) 2206 2207 self.assertRaisesRegex( 2208 RuntimeError, r".*Invalid argument shape.*" 2209 r"expected s32\[\], got f32\[\].*", TestFun) 2210 2211 def testInvokeWithWrongElementType(self): 2212 c = self._NewComputation() 2213 c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) 2214 ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) 2215 c.clear_op_metadata() 2216 2217 def TestFun(): 2218 return xla_client.execute_with_python_values( 2219 self.backend.compile(c.build()), [self.f32_scalar_2], self.backend) 2220 2221 self.assertRaisesRegex( 2222 RuntimeError, r"Invalid argument: Argument does not match.*" 2223 r"want s32\[\], got f32\[\].*", TestFun) 2224 2225 tests.append(EmbeddedComputationsTest) 2226 2227 class ComputationRootTest(ComputationTest): 2228 """Tests related to setting the root of the computation.""" 2229 2230 def testComputationRootDifferentFromLastOp(self): 2231 c = self._NewComputation() 2232 x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) 2233 result = ops.Add(x, ops.Constant(c, np.float32(3.14))) 2234 ops.Add(result, ops.Constant(c, np.float32(1.618))) 2235 2236 arg = NumpyArrayF32(1.0) 2237 compiled_c = self.backend.compile(c.build(result)) 2238 ans, = xla_client.execute_with_python_values( 2239 compiled_c, [arg], backend=self.backend) 2240 np.testing.assert_allclose(ans, 4.14) 2241 2242 tests.append(ComputationRootTest) 2243 2244 class SetShardingTest(ComputationTest): 2245 """Tests related to set OpSharding.""" 2246 2247 def testSetSharding(self): 2248 c = self._NewComputation() 2249 sharding = xla_client.OpSharding() 2250 sharding.type = xla_client.OpSharding.Type.REPLICATED 2251 sharding.tile_assignment_dimensions = [1] 2252 sharding.tile_assignment_devices = [0] 2253 c.set_sharding(sharding) 2254 x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) 2255 c.clear_sharding() 2256 2257 result = ops.Add(x, ops.Constant(c, np.float32(3.14))) 2258 ops.Add(result, ops.Constant(c, np.float32(1.618))) 2259 arg = NumpyArrayF32(1.0) 2260 compiled_c = self.backend.compile(c.build(result)) 2261 ans, = xla_client.execute_with_python_values( 2262 compiled_c, [arg], backend=self.backend) 2263 np.testing.assert_allclose(ans, 4.14) 2264 2265 tests.append(SetShardingTest) 2266 2267 testcase_shapes = [ 2268 (), 2269 (1,), 2270 (2, 3), 2271 (2, 0), 2272 (0, 7), 2273 (4, 1, 2), 2274 (2, 1, 3), 2275 (2, 4, 1), 2276 (3, 1), 2277 (1, 3), 2278 ] 2279 2280 def FormatShapeAndDtype(shape, dtype): 2281 return 
"_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape))) 2282 2283 class DLPackTest(parameterized.TestCase): 2284 2285 def setUp(self): 2286 super(DLPackTest, self).setUp() 2287 self.backend = xla_backend() 2288 if self.backend.platform not in ("cpu", "gpu"): 2289 self.skipTest("DLPack requires CPU or GPU") 2290 self.cpu_backend = ( 2291 self.backend 2292 if self.backend.platform == "cpu" else xla_client.make_cpu_client()) 2293 self.gpu_backend = ( 2294 self.backend if self.backend.platform == "gpu" else None) 2295 2296 def tearDown(self): 2297 super().tearDown() 2298 del self.backend 2299 del self.cpu_backend 2300 del self.gpu_backend 2301 2302 # pylint: disable=g-complex-comprehension 2303 # pyformat: disable 2304 @parameterized.named_parameters({ 2305 "testcase_name": "{}_own={}_gpu={}".format( 2306 FormatShapeAndDtype(shape, dtype), take_ownership, gpu), 2307 "dtype": dtype, 2308 "shape": shape, 2309 "take_ownership": take_ownership, 2310 "gpu": gpu 2311 } for dtype in dlpack_dtypes for shape in testcase_shapes 2312 for take_ownership in [False, True] 2313 for gpu in [False, True]) 2314 # pyformat: enable 2315 def testRoundTrip(self, dtype, shape, take_ownership, gpu): 2316 if gpu and self.gpu_backend is None: 2317 raise unittest.SkipTest("Test not running with GPU support") 2318 backend = self.gpu_backend if gpu else self.cpu_backend 2319 if dtype == np.bool_: 2320 x = np.random.randint(0, 2, size=shape).astype(np.bool_) 2321 else: 2322 x = np.array(np.random.rand(*shape) * 100, dtype=dtype) 2323 buffer = backend.buffer_from_pyval(x) 2324 dlt = xla_client._xla.buffer_to_dlpack_managed_tensor( 2325 buffer, take_ownership=take_ownership) 2326 del buffer # Free "buffer" to make sure dlt retains ownership. 2327 self.assertEqual(type(dlt).__name__, "PyCapsule") 2328 y = xla_client._xla.dlpack_managed_tensor_to_buffer( 2329 dlt, self.cpu_backend, self.gpu_backend) 2330 np.testing.assert_array_equal( 2331 x.astype(np.uint8) if dtype == np.bool_ else x, y.to_py()) 2332 2333 def testTensorsCanBeConsumedOnceOnly(self): 2334 x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) 2335 buffer = self.backend.buffer_from_pyval(x) 2336 dlt = xla_client._xla.buffer_to_dlpack_managed_tensor( 2337 buffer, take_ownership=True) 2338 2339 def ConsumeDLPackTensor(): 2340 _ = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend) 2341 2342 ConsumeDLPackTensor() 2343 self.assertRaisesRegex( 2344 RuntimeError, ".*a DLPack tensor may be consumed at most once.*", 2345 ConsumeDLPackTensor) 2346 2347 def testTensorsCanBeOwnedOnceOnly(self): 2348 x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) 2349 buffer = self.backend.buffer_from_pyval(x) 2350 _ = xla_client._xla.buffer_to_dlpack_managed_tensor( 2351 buffer, take_ownership=True) 2352 self.assertTrue(buffer.is_deleted()) 2353 with self.assertRaisesRegex( 2354 RuntimeError, 2355 "Cannot convert deleted/invalid buffer to DLPack tensor.*"): 2356 _ = xla_client._xla.buffer_to_dlpack_managed_tensor( 2357 buffer, take_ownership=True) 2358 2359 def testNonOwnedDlpackCanBeViewedTwice(self): 2360 x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) 2361 buffer = self.backend.buffer_from_pyval(x) 2362 d1 = xla_client._xla.buffer_to_dlpack_managed_tensor( 2363 buffer, take_ownership=False) 2364 d2 = xla_client._xla.buffer_to_dlpack_managed_tensor( 2365 buffer, take_ownership=False) 2366 2367 y = xla_client._xla.dlpack_managed_tensor_to_buffer(d1, self.backend) 2368 z = xla_client._xla.dlpack_managed_tensor_to_buffer(d2, self.backend) 
2369 del d1, d2 2370 np.testing.assert_array_equal(x, buffer.to_py()) 2371 np.testing.assert_array_equal(x, y.to_py()) 2372 np.testing.assert_array_equal(x, z.to_py()) 2373 2374 tests.append(DLPackTest) 2375 2376 class BufferProtocolTest(parameterized.TestCase): 2377 2378 def setUp(self): 2379 super(BufferProtocolTest, self).setUp() 2380 self.backend = xla_backend() 2381 if self.backend.platform != "cpu": 2382 self.skipTest("Test requires CPU") 2383 2384 # pylint: disable=g-complex-comprehension 2385 @parameterized.named_parameters({ 2386 "testcase_name": FormatShapeAndDtype(shape, dtype), 2387 "dtype": dtype, 2388 "shape": shape 2389 } for dtype in standard_dtypes if dtype != bfloat16 2390 for shape in testcase_shapes) 2391 def testRoundTrip(self, dtype, shape): 2392 x = np.array(np.random.rand(*shape) * 100, dtype=dtype) 2393 x_ptr = x.__array_interface__["data"][0] 2394 buffer = self.backend.buffer_from_pyval( 2395 x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY) 2396 y = np.array(buffer, copy=False) 2397 y_ptr = y.__array_interface__["data"][0] 2398 np.testing.assert_array_equal(x, y) 2399 # If the input was sufficiently aligned, the input and output should 2400 # alias. 2401 self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr) 2402 self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer()) 2403 2404 during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL 2405 buffer2 = self.backend.buffer_from_pyval( 2406 x, host_buffer_semantics=during_call) 2407 z = np.array(buffer2, copy=False) 2408 self.assertNotEqual(x.__array_interface__["data"][0], 2409 z.__array_interface__["data"][0]) 2410 2411 def testDeleteWithActiveView(self): 2412 x = np.random.randn(20, 10) 2413 buffer = self.backend.buffer_from_pyval(x) 2414 buffer_ptr = buffer.unsafe_buffer_pointer() 2415 y = np.array(buffer, copy=False) 2416 buffer.delete() 2417 # It is still legal to access `y`; the array view must keep it alive. 2418 np.testing.assert_array_equal(x, y) 2419 self.assertEqual(y.__array_interface__["data"][0], buffer_ptr) 2420 2421 tests.append(BufferProtocolTest) 2422 2423 class TracebackTest(absltest.TestCase): 2424 2425 def setUp(self): 2426 super(TracebackTest, self).setUp() 2427 self.backend = xla_backend() 2428 2429 def testNoTracebacksIfDisabled(self): 2430 with xla_client.tracebacks(enabled=False): 2431 self.assertEqual(None, xla_client.Traceback.get_traceback()) 2432 buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) 2433 self.assertEqual(None, buffer.traceback) 2434 2435 b = xla_client.XlaBuilder("computation") 2436 ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) 2437 e = self.backend.compile(b.build()) 2438 self.assertEqual(None, e.traceback) 2439 2440 def assertIsTracebackContaining(self, tb, function): 2441 self.assertIsInstance(tb, xla_client.Traceback) 2442 self.assertIn(function, str(tb)) 2443 self.assertTrue(any(f.function_name == function for f in tb.frames)) 2444 2445 def testTracebacks(self): 2446 with xla_client.tracebacks(enabled=True): 2447 tb = xla_client.Traceback.get_traceback() 2448 self.assertIsTracebackContaining(tb, "testTracebacks") 2449 2450 # Tracebacks are not implemented on the TPU driver extension's variant 2451 # of buffers and executables. 
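      # Skip the buffer and executable traceback checks for those backends.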
2452 if not isinstance(self.backend, xla_client.Client): 2453 return 2454 2455 buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) 2456 self.assertIsTracebackContaining(buffer.traceback, "testTracebacks") 2457 2458 b = xla_client.XlaBuilder("computation") 2459 ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) 2460 e = self.backend.compile(b.build()) 2461 self.assertIsTracebackContaining(e.traceback, "testTracebacks") 2462 2463 def testNestedFunction(self): 2464 2465 def AFunction(): 2466 2467 def AnotherFunction(): 2468 return xla_client.Traceback.get_traceback() 2469 2470 return AnotherFunction() 2471 2472 with xla_client.tracebacks(enabled=True): 2473 tb = AFunction() 2474 self.assertIsInstance(tb, xla_client.Traceback) 2475 frames = tb.frames 2476 i = next( 2477 i for (i, f) in enumerate(frames) if f.function_name == "AFunction") 2478 self.assertEqual(frames[i - 1].function_name, "AnotherFunction") 2479 self.assertEqual(frames[i + 1].function_name, "testNestedFunction") 2480 2481 tests.append(TracebackTest) 2482 2483 class ClientTest(ComputationTest): 2484 2485 def setUp(self): 2486 super(ClientTest, self).setUp() 2487 self.backend = xla_backend() 2488 2489 def testPlatformVersion(self): 2490 version = self.backend.platform_version 2491 logging.info("platform_version:\n%s", version) 2492 if self.backend.platform == "cpu": 2493 self.assertEqual(version, "<unknown>") 2494 elif self.backend.platform == "gpu": 2495 # Following is false if not built with --config=cuda 2496 if test_util.is_gpu_available(cuda_only=True): 2497 self.assertTrue( 2498 re.match(r"^cuda \d{4,}$", version), 2499 msg=f"Expected CUDA version string; got {repr(version)}") 2500 else: 2501 self.assertEqual(version, "<unknown>") 2502 elif self.backend.platform == "tpu" and not cloud_tpu: 2503 self.assertIn("tpu", version.lower()) 2504 self.assertIn("cl/", version) 2505 2506 @unittest.skipIf(cloud_tpu or tfrt_tpu, "not implemented") 2507 def testExecutableSerialization(self): 2508 if self.backend.platform != "tpu": 2509 self.skipTest("Test requires tpu platform") 2510 2511 c = self._NewComputation() 2512 ops.Add( 2513 ops.Constant(c, NumpyArrayS32([1, 2])), 2514 ops.Constant(c, NumpyArrayS32([3, 4]))) 2515 2516 options = xla_client.CompileOptions() 2517 executable = self.backend.compile(c.build(), options) 2518 self.assertLen(executable.hlo_modules(), 1) 2519 2520 serialized = self.backend.serialize_executable(executable) 2521 deserialized = self.backend.deserialize_executable( 2522 serialized, 2523 executable.hlo_modules()[0], options) 2524 2525 expected, = xla_client.execute_with_python_values(executable, (), 2526 self.backend) 2527 actual, = xla_client.execute_with_python_values(deserialized, (), 2528 self.backend) 2529 self.assertTrue(np.all(actual == expected)) 2530 2531 tests.append(ClientTest) 2532 2533 # TODO(b/182461453): Add TFRT and cloud TPU implementation of 2534 # ReadDynamicShapes 2535 class DynamicReshapeTest(ComputationTest): 2536 """Tests related to DynamicReshape.""" 2537 2538 def _CompareToPyAndBufferProtocol(self, builder, args, expected_results, 2539 test_fn): 2540 compiled = self.backend.compile(builder.build()) 2541 output_buffers = compiled.execute([ 2542 self.backend.buffer_from_pyval( 2543 arg, device=compiled.local_devices()[0]) for arg in args 2544 ]) 2545 self.assertLen(output_buffers, len(expected_results)) 2546 for buf, expected in zip(output_buffers, expected_results): 2547 to_py_result = buf.to_py() 2548 self.assertEqual(expected.shape, to_py_result.shape) 2549 
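        # Compare the to_py() result first; the buffer-protocol (memoryview)
        # path below is only expected to succeed on CPU for non-bfloat16 data.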
test_fn(expected, to_py_result) 2550 if self.backend.platform == "cpu" and buf.dtype != bfloat16: 2551 mview = memoryview(buf) 2552 self.assertEqual(expected.shape, mview.shape) 2553 test_fn(expected, np.asarray(mview)) 2554 else: 2555 # Buffer protocol expected to fail on non-cpu platforms and bfloat16 2556 # Note that np.asarray(buf) doesn't throw an exception. To test if the 2557 # error was thrown properly we must use memoryview(buf). 2558 with self.assertRaises(BufferError): 2559 memoryview(buf) 2560 2561 # 1D reshape of full size, half size, and size of 0. 2562 @unittest.skipIf(cloud_tpu or tfrt_tpu or external_tpu, "not implemented") 2563 @parameterized.parameters((5), (3), (0)) 2564 def testReshape1D(self, reshape_size): 2565 full_size = 5 2566 c = self._NewComputation() 2567 arg = np.array(reshape_size, dtype=np.int32) 2568 expected = np.array(range(reshape_size), dtype=np.int32) 2569 p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) 2570 ops.DynamicReshape( 2571 ops.Constant(c, NumpyArrayS32(range(full_size))), [p], [full_size], 2572 [True]) 2573 self._CompareToPyAndBufferProtocol(c, [arg], [expected], 2574 np.testing.assert_equal) 2575 2576 # 2D reshape with an slice on the minor dimension. We test different types 2577 # where the strides may differ between the host and devices. The reshaped 2578 # physical memory layout is not consecutive, and we test if the program can 2579 # return the correct logical view of the data. 2580 @unittest.skipIf(cloud_tpu or tfrt_tpu or external_tpu, "not implemented") 2581 @parameterized.named_parameters({ 2582 "testcase_name": "_{}".format(dtype.__name__), 2583 "dtype": dtype, 2584 } for dtype in int_dtypes + float_dtypes) 2585 def testReshape2D(self, dtype): 2586 arg0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) 2587 arg1 = np.array(2, dtype=np.int32) 2588 expected = np.array([[1, 2], [4, 5]], dtype=np.int32) 2589 c = self._NewComputation() 2590 p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) 2591 p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) 2592 ops.DynamicReshape(p0, [p1, p1], [2, 3], [False, True]) 2593 self._CompareToPyAndBufferProtocol(c, [arg0, arg1], [expected], 2594 np.testing.assert_equal) 2595 2596 @unittest.skipIf(cloud_tpu or tfrt_tpu, "not implemented") 2597 @parameterized.named_parameters({ 2598 "testcase_name": "_{}".format(dtype.__name__), 2599 "dtype": dtype, 2600 } for dtype in int_dtypes + float_dtypes) 2601 def testDynamicShapeArgs(self, dtype): 2602 full_size = 10 2603 dynamic_shape_size = 4 2604 # subcomputation 1 2605 binary_add_builder = self._NewComputation() 2606 scalar_shape = xla_client.Shape.scalar_shape(np.dtype(dtype)) 2607 ops.Add( 2608 ops.Parameter(binary_add_builder, 0, scalar_shape), 2609 ops.Parameter(binary_add_builder, 1, scalar_shape)) 2610 # subcomputation 2 2611 reshape_reduce_builder = self._NewComputation() 2612 dshape = xla_client.Shape.array_shape( 2613 np.dtype(dtype), dims=[full_size], dynamic_dimensions=[True]) 2614 reshape_reduce_p = ops.Parameter(reshape_reduce_builder, 0, dshape) 2615 ops.Reduce( 2616 reshape_reduce_builder, 2617 operands=[reshape_reduce_p], 2618 init_values=[ops.Constant(reshape_reduce_builder, dtype(0))], 2619 computation=binary_add_builder.build(), 2620 dimensions_to_reduce=[0]) 2621 # main computation: sum(range(full_size)[:dynamic_shape_size]) 2622 c = self._NewComputation() 2623 arg = np.array(dynamic_shape_size, dtype=np.int32) 2624 p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) 2625 reshaped = ops.DynamicReshape( 2626 
ops.Constant(c, np.array(range(full_size), dtype=dtype)), [p], 2627 [full_size], [True]) 2628 ops.Call(c, reshape_reduce_builder.build(), operands=(reshaped,)) 2629 self._ExecuteAndCompareClose(c, [arg], [dtype(6)]) 2630 2631 tests.append(DynamicReshapeTest) 2632 2633 class DeviceAssignmentTest(ComputationTest): 2634 2635 def testSerialize(self): 2636 shape = (3, 4) 2637 device_assignment = xla_client.DeviceAssignment.create( 2638 np.arange(np.prod(shape)).reshape(*shape)) 2639 self.assertEqual(device_assignment.replica_count(), shape[0]) 2640 self.assertEqual(device_assignment.computation_count(), shape[1]) 2641 serialized = device_assignment.serialize() 2642 self.assertIsInstance(serialized, bytes) 2643 self.assertNotEmpty(serialized) 2644 2645 tests.append(DeviceAssignmentTest) 2646 2647 class TokenTest(ComputationTest): 2648 """Tests related to PyToken.""" 2649 2650 def testExecuteWithToken(self): 2651 c = self._NewComputation() 2652 ops.Mul( 2653 ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)), 2654 ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32))) 2655 compiled_c = self.backend.compile(c.build()) 2656 results, token = compiled_c.execute_with_token([]) 2657 token.block_until_ready() 2658 self.assertLen(results, 1) 2659 np.testing.assert_allclose( 2660 results[0].to_py(), np.float32([-3, 6.6, 2.4, -2.1]), rtol=3e-3) 2661 2662 def testExecuteShardedOnLocalDevicesWithTokens(self): 2663 c = self._NewComputation() 2664 ops.Mul( 2665 ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)), 2666 ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32))) 2667 num_replicas = 1 2668 options = xla_client.CompileOptions() 2669 options.num_replicas = num_replicas 2670 compiled_c = self.backend.compile(c.build(), compile_options=options) 2671 results, sharded_token = compiled_c.execute_sharded_on_local_devices_with_tokens( 2672 []) 2673 sharded_token.block_until_ready() 2674 self.assertLen(results, 1) 2675 self.assertLen(results[0], 1) 2676 np.testing.assert_allclose( 2677 results[0][0].to_py(), np.float32([-3, 6.6, 2.4, -2.1]), rtol=3e-3) 2678 2679 tests.append(TokenTest) 2680 2681 class HostCallbackTest(ComputationTest): 2682 """Tests related to HostCallback.""" 2683 2684 @unittest.skipIf(not tfrt_tpu, "not implemented") 2685 def testHostCallback(self): 2686 2687 c = self._NewComputation() 2688 token = ops.CreateToken(c) 2689 2690 frontend_attributes = xla_client._xla.FrontendAttributes() 2691 frontend_attributes["_xla_host_transfer_rendezvous"] = "undef" 2692 frontend_attributes["_xla_host_transfer_original_type"] = "u32" 2693 frontend_attributes["_xla_host_transfer_is_lower_bits"] = "false" 2694 frontend_attributes["_xla_host_transfer_handler_name"] = "undef" 2695 c.set_frontend_attributes(frontend_attributes) 2696 2697 send_channel_handle = self.backend.create_channel_handle() 2698 send_channel_handle.type = ( 2699 xla_client._xla.ChannelHandle_ChannelType.DEVICE_TO_HOST) 2700 send_channel_handle.handle = 1 2701 ops.SendToHost( 2702 ops.Constant(c, np.float32(1.25)), 2703 token, 2704 shape_with_layout=xla_client.Shape.scalar_shape(np.dtype(np.float32)), 2705 handle=send_channel_handle) 2706 2707 recv_channel_handle = self.backend.create_channel_handle() 2708 recv_channel_handle.type = ( 2709 xla_client._xla.ChannelHandle_ChannelType.HOST_TO_DEVICE) 2710 recv_channel_handle.handle = 2 2711 data = ops.RecvFromHost( 2712 token, 2713 shape=xla_client.Shape.scalar_shape(np.dtype(np.float32)), 2714 handle=recv_channel_handle) 2715 ops.GetTupleElement(data, 0) 2716 2717 def 
Identity(x): 2718 return (x,) 2719 2720 host_callback = self.backend.make_python_callback_from_host_send_and_recv( 2721 Identity, 2722 operand_shapes=[xla_client.Shape.scalar_shape(np.dtype(np.float32))], 2723 result_shapes=[xla_client.Shape.scalar_shape(np.dtype(np.float32))], 2724 send_channel_ids=[1], 2725 recv_channel_ids=[2]) 2726 2727 compiled_c = self.backend.compile( 2728 c.build(), host_callbacks=[host_callback]) 2729 c.clear_frontend_attributes() 2730 2731 results = compiled_c.execute([]) 2732 self.assertLen(results, 1) 2733 2734 np.testing.assert_equal(results[0].to_py(), np.float32(1.25)) 2735 2736 tests.append(HostCallbackTest) 2737 2738 class HostCallbackMultiReplicaTest(ComputationTest): 2739 """Tests related to HostCallback for multi-replica execution.""" 2740 2741 @unittest.skipIf(not tfrt_tpu, "not implemented") 2742 def testHostCallbackMultiReplica(self): 2743 2744 c = self._NewComputation() 2745 token = ops.CreateToken(c) 2746 2747 frontend_attributes = xla_client._xla.FrontendAttributes() 2748 frontend_attributes["_xla_host_transfer_rendezvous"] = "undef" 2749 frontend_attributes["_xla_host_transfer_original_type"] = "u32" 2750 frontend_attributes["_xla_host_transfer_is_lower_bits"] = "false" 2751 frontend_attributes["_xla_host_transfer_handler_name"] = "undef" 2752 c.set_frontend_attributes(frontend_attributes) 2753 2754 send_channel_handle = self.backend.create_channel_handle() 2755 send_channel_handle.type = ( 2756 xla_client._xla.ChannelHandle_ChannelType.DEVICE_TO_HOST) 2757 send_channel_handle.handle = 1 2758 ops.SendToHost( 2759 ops.ReplicaId(c), 2760 token, 2761 shape_with_layout=xla_client.Shape.scalar_shape(np.dtype(np.uint32)), 2762 handle=send_channel_handle) 2763 2764 recv_channel_handle = self.backend.create_channel_handle() 2765 recv_channel_handle.type = ( 2766 xla_client._xla.ChannelHandle_ChannelType.HOST_TO_DEVICE) 2767 recv_channel_handle.handle = 2 2768 data = ops.RecvFromHost( 2769 token, 2770 shape=xla_client.Shape.scalar_shape(np.dtype(np.uint32)), 2771 handle=recv_channel_handle) 2772 ops.GetTupleElement(data, 0) 2773 2774 def Identity(x): 2775 return (x,) 2776 2777 host_callback = self.backend.make_python_callback_from_host_send_and_recv( 2778 Identity, 2779 operand_shapes=[xla_client.Shape.scalar_shape(np.dtype(np.uint32))], 2780 result_shapes=[xla_client.Shape.scalar_shape(np.dtype(np.uint32))], 2781 send_channel_ids=[1], 2782 recv_channel_ids=[2]) 2783 2784 num_replicas = 2 2785 options = xla_client.CompileOptions() 2786 options.num_replicas = num_replicas 2787 compiled_c = self.backend.compile( 2788 c.build(), compile_options=options, host_callbacks=[host_callback]) 2789 c.clear_frontend_attributes() 2790 2791 results = compiled_c.execute_sharded_on_local_devices([]) 2792 self.assertLen(results, 1) 2793 self.assertLen(results[0], num_replicas) 2794 2795 for i in range(num_replicas): 2796 np.testing.assert_equal(results[0][i].to_py(), np.uint32(i)) 2797 2798 tests.append(HostCallbackMultiReplicaTest) 2799 2800 class ExecutePortableTest(ComputationTest): 2801 2802 def testExecutePortable(self): 2803 devices_by_kind = collections.defaultdict(list) 2804 for device in self.backend.devices(): 2805 devices_by_kind[device.device_kind].append(device) 2806 multi_devices = [d for d in devices_by_kind.values() if len(d) > 1] 2807 if not multi_devices: 2808 raise unittest.SkipTest("Test needs multiple identical devices") 2809 devices = multi_devices[0] 2810 2811 c = self._NewComputation() 2812 args = [ 2813 np.array(3, dtype=np.int32), 2814 
np.array([10, 15, -2, 7], dtype=np.int32) 2815 ] 2816 p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(args[0])) 2817 p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(args[1])) 2818 ops.Mul(p0, p1) 2819 options = xla_client.CompileOptions() 2820 options.compile_portable_executable = True 2821 compiled_c = self.backend.compile(c.build(), compile_options=options) 2822 for device in devices: 2823 out, = compiled_c.execute( 2824 [self.backend.buffer_from_pyval(a, device=device) for a in args], 2825 device=device) 2826 np.testing.assert_array_equal(out.to_py(), args[0] * args[1]) 2827 2828 tests.append(ExecutePortableTest) 2829 2830 return tests 2831 2832 2833def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): 2834 # Avoid creating a new backend per test (this causes GPU OOM, and is probably 2835 # inefficient). 2836 backend_fn = functools.lru_cache(maxsize=None)(backend_fn) 2837 for klass in TestFactory(backend_fn, **kw): 2838 test = type(test_prefix + klass.__name__, (klass,), {}) 2839 # Clean up the qualified names of the tests to not include the test factory. 2840 test.__qualname__ = test.__name__ 2841 globals_dict[test.__name__] = test 2842 2843 2844backends = { 2845 "cpu": xla_client.make_cpu_client, 2846 "gpu": xla_client.make_gpu_client, 2847} 2848 2849if __name__ == "__main__": 2850 flags.DEFINE_string("backend", "cpu", "Target platform.") 2851 # pylint: disable=unnecessary-lambda 2852 InstantiateTests(globals(), lambda: backends[FLAGS.backend]()) 2853 # pylint: enable=unnecessary-lambda 2854 absltest.main() 2855
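# The --backend flag selects a client factory from `backends` above, e.g.
# `--backend=gpu` (which assumes a GPU-enabled build).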