# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Backend-dependent tests for the Python XLA client."""

import collections
import functools
import itertools
import re
import threading
import unittest

from absl import flags
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.xla.python import xla_client

# pylint: disable=g-import-not-at-top
try:
  # This import is only used for GPU; the dependency is incompatible with TPU
  # so it results in an import error.
  from tensorflow.python.framework import test_util
except ImportError:
  test_util = None

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.compiler.xla.python import custom_call_for_test
except ImportError:
  custom_call_for_test = None

bfloat16 = xla_client.bfloat16
ops = xla_client.ops

FLAGS = flags.FLAGS

# We choose to ignore pylint's complaints about complex comprehensions, which we
# use widely for parameterizing tests.
# pylint: disable=g-complex-comprehension


def TestFactory(xla_backend,
                cloud_tpu=False,
                tfrt_tpu=False,
                external_tpu=False):
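  """Builds the backend-specific test classes for `xla_backend`.

  Flags such as `cloud_tpu` narrow the dtype coverage and skip test cases
  that are not implemented on that backend.
  """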
  tests = []

  if not cloud_tpu:
    int_dtypes = [np.int32, np.int64, np.uint32, np.uint64]
    # TODO(phawkins): test np.float16, where supported.
    float_dtypes = [bfloat16, np.float32, np.float64]
    complex_dtypes = [np.complex64, np.complex128]
    standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
  else:
    int_dtypes = [np.int32, np.uint32]
    float_dtypes = [np.float32]
    complex_dtypes = [np.complex64]
    standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
  dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes

  class ComputationTest(parameterized.TestCase):
    """Base class for running an XLA Computation through the local client."""

    def setUp(self):
      super(ComputationTest, self).setUp()
      self.backend = xla_backend()

    def _NewComputation(self, name=None):
      if name is None:
        name = self.id()
      return xla_client.XlaBuilder(name)

    def _Execute(self, c, arguments):
      compiled_c = self.backend.compile(c.build())
      return xla_client.execute_with_python_values(
          compiled_c, arguments, backend=self.backend)

    def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
      assert expected is not None
      results = self._Execute(c, arguments)
      self.assertLen(results, len(expected))
      for result, e in zip(results, expected):
        # Numpy's comparison methods are a bit too lenient: they treat inputs
        # as "array-like", so the scalar 4 happily compares equal to [[4]].
        # We want to be stricter, so we also assert that the shapes match.
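        # (e.g. np.testing.assert_allclose(4., [[4.]]) raises no error on its own).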
        self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape)
        assert_func(result, e)

    def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
      self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments,
                                 expected)

    def _ExecuteAndCompareClose(self,
                                c,
                                arguments=(),
                                expected=None,
                                rtol=1e-4,
                                atol=0):
      self._ExecuteAndAssertWith(
          functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
          c, arguments, expected)

  def NumpyArrayF32(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.float32 dtype."""
    return np.array(*args, dtype=np.float32, **kwargs)

  def NumpyArrayF64(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.float64 dtype."""
    return np.array(*args, dtype=np.float64, **kwargs)

  def NumpyArrayS32(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.int32 dtype."""
    return np.array(*args, dtype=np.int32, **kwargs)

  def NumpyArrayBool(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.bool_ dtype."""
    return np.array(*args, dtype=np.bool_, **kwargs)

  class ComputationPrinting(absltest.TestCase):

    def setUp(self):
      super(ComputationPrinting, self).setUp()
      self.backend = xla_backend()

    def ExampleComputation(self):
      builder = xla_client.XlaBuilder("acomputation")
      p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0)))
      p1 = ops.Parameter(
          builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
      x = ops.Mul(p0, p1)
      ops.Add(x, x)
      return builder.build()

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCompiledHloModuleToHloText(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      hlo_modules = executable.hlo_modules()
      self.assertLen(hlo_modules, 1)
      hlo_text = hlo_modules[0].to_string()
      self.assertTrue(hlo_text.startswith("HloModule acomputation"))
      self.assertIn("fusion", hlo_text)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCompiledHloModuleAsSerializedProto(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      hlo_modules = executable.hlo_modules()
      self.assertLen(hlo_modules, 1)
      hlo_text = hlo_modules[0].to_string()
      proto = hlo_modules[0].as_serialized_hlo_module_proto()
      hlo_module_roundtrip = xla_client.XlaComputation(proto).get_hlo_module()
      hlo_text_roundtrip = hlo_module_roundtrip.to_string()
      self.assertEqual(hlo_text, hlo_text_roundtrip)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testStableComputationSerialization(self):
      # Ideally we would test identical computations produced in different
      # processes. For now we have this limited smoke test.
      computation = self.ExampleComputation()
      ref = computation.as_serialized_hlo_module_proto()
      for _ in range(10):
        self.assertEqual(computation.as_serialized_hlo_module_proto(), ref)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testFlopEstimate(self):
      computation = self.ExampleComputation()
      properties = xla_client._xla.hlo_module_cost_analysis(
          self.backend, computation.as_hlo_module())
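      # ExampleComputation multiplies a scalar by a length-4 vector (4 flops)
      # and adds the result to itself (4 more flops), for 8 flops in total.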
      self.assertEqual(properties["flops"], 8.0)

    def testFingerprint(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      fingerprint = executable.fingerprint
      if self.backend.platform == "tpu" and not cloud_tpu:
        logging.info("fingerprint: %s", fingerprint)
        self.assertNotEmpty(fingerprint)
      else:
        self.assertIsNone(fingerprint)

  tests.append(ComputationPrinting)

  class ComputationsWithConstantsTest(ComputationTest):
    """Tests focusing on Constant ops."""

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes + float_dtypes)
    def testConstantScalarSum(self, dtype):
      if dtype == np.int8 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support int8")
      c = self._NewComputation()
      ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14)))
      self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorMul(self, dtype):
      c = self._NewComputation()
      ops.Mul(
          ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)),
          ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype)))
      self._ExecuteAndCompareClose(
          c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorScalarDiv(self, dtype):
      c = self._NewComputation()
      ops.Div(
          ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)),
          ops.Constant(c, dtype(2.0)))
      self._ExecuteAndCompareClose(
          c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorScalarPow(self, dtype):
      c = self._NewComputation()
      ops.Pow(
          ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)),
          ops.Constant(c, dtype(2.)))
      self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]])

    def testIota(self):
      c = self._NewComputation()
      ops.Iota(c, xla_client.PrimitiveType.F32, 10)
      self._ExecuteAndCompareExact(
          c, expected=[np.arange(10, dtype=np.float32)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes)
    def testBroadcastedIota(self, dtype):
      c = self._NewComputation()
      shape = xla_client.Shape.array_shape(
          xla_client.dtype_to_etype(dtype), (2, 3))
      ops.Iota(c, shape, 1)
      expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype)
      self._ExecuteAndCompareExact(c, expected=[expected])

    def testBooleanAnd(self):
      c = self._NewComputation()
      ops.And(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]])

    def testBooleanOr(self):
      c = self._NewComputation()
      ops.Or(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]])

    def testBooleanXor(self):
      c = self._NewComputation()
      ops.Xor(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2D(self, dtype):
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)),
          ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype)))
      self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]])

    def testShiftLeft(self):
      c = self._NewComputation()
      ops.ShiftLeft(
          ops.Constant(c, NumpyArrayS32([3])),
          ops.Constant(c, NumpyArrayS32([2])))
      self._ExecuteAndCompareClose(c, expected=[[12]])

    def testShiftRightArithmetic(self):
      c = self._NewComputation()
      ops.ShiftRightArithmetic(
          ops.Constant(c, NumpyArrayS32([-2])),
          ops.Constant(c, NumpyArrayS32([1])))
      self._ExecuteAndCompareClose(c, expected=[[-1]])

    def testShiftRightLogical(self):
      c = self._NewComputation()
      ops.ShiftRightLogical(
          ops.Constant(c, NumpyArrayS32([-1])),
          ops.Constant(c, NumpyArrayS32([1])))
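      # The int32 bit pattern of -1 is 0xFFFFFFFF; shifting it right logically
      # by one yields 0x7FFFFFFF, i.e. 2**31 - 1.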
      self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2DWith1DBroadcastDim0(self, dtype):
      # sum of a 2D array with a 1D array where the latter is replicated across
      # dimension 0 to match the former's shape.
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c,
                       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                                dtype=dtype)),
          ops.Constant(c, np.array([10, 20, 30], dtype=dtype)),
          broadcast_dimensions=(0,))
      self._ExecuteAndCompareClose(
          c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2DWith1DBroadcastDim1(self, dtype):
      # sum of a 2D array with a 1D array where the latter is replicated across
      # dimension 1 to match the former's shape.
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c,
                       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                                dtype=dtype)),
          ops.Constant(c, np.array([10, 20, 30], dtype=dtype)),
          broadcast_dimensions=(1,))
      self._ExecuteAndCompareClose(
          c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantAxpy(self, dtype):
      c = self._NewComputation()
      ops.Add(
          ops.Mul(
              ops.Constant(c, dtype(2)),
              ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))),
          ops.Constant(c, np.array([100, -100, 200, -200], dtype)))
      self._ExecuteAndCompareClose(
          c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3)

    def testCustomCall(self):
      if self.backend.platform != "cpu":
        self.skipTest("Test requires cpu platform")
      c = self._NewComputation()
      for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
        xla_client.register_custom_call_target(name, fn, platform="cpu")
      ops.CustomCallWithLayout(
          c,
          b"test_subtract_f32",
          operands=[
              ops.Constant(c, np.float32(1.25)),
              ops.Constant(c, np.float32(0.5))
          ],
          shape_with_layout=xla_client.Shape.array_shape(
              np.dtype(np.float32), (), ()),
          operand_shapes_with_layout=[
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
          ],
          api_version=xla_client.ops.CustomCallApiVersion
          .API_VERSION_STATUS_RETURNING)
      self._ExecuteAndCompareClose(c, expected=[0.75])

    def testCustomCallWithUnifiedApi(self):
      if self.backend.platform != "cpu":
        self.skipTest("Test requires cpu platform")
      c = self._NewComputation()
      for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
        xla_client.register_custom_call_target(name, fn, platform="cpu")

      opaque_str = b"foo"
      ops.CustomCallWithLayout(
          c,
          b"test_add_input_and_opaque_len",
          operands=[
              ops.Constant(c, np.float32(1.25)),
              ops.Constant(c, np.float32(0.5))
          ],
          shape_with_layout=xla_client.Shape.array_shape(
              np.dtype(np.float32), (), ()),
          operand_shapes_with_layout=[
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
          ],
          # With opaque length = 3.0
          opaque=opaque_str,
          api_version=xla_client.ops.CustomCallApiVersion
          .API_VERSION_STATUS_RETURNING_UNIFIED)
      self._ExecuteAndCompareClose(c, expected=[1.25 + len(opaque_str)])

  tests.append(ComputationsWithConstantsTest)

  class PythonCallbackTest(ComputationTest):

    def testPythonCallback(self):
      if self.backend.platform not in {"cpu", "gpu"}:
        self.skipTest("Test requires cpu or gpu platform")
      c = self._NewComputation()

      f = lambda x, y: (x + y, x - y)

      arg0 = np.array([9, 43, -101, 22], dtype=np.int32)
      arg1 = np.array([10, 15, -2, 7], dtype=np.int32)
      shape = xla_client.shape_from_pyval(arg0)
      shape = shape.with_major_to_minor_layout_if_absent()
      p0 = ops.Parameter(c, 0, shape)
      p1 = ops.Parameter(c, 1, shape)
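      # emit_python_callback stages `f` as a host callback in the computation;
      # it returns the op producing the callback's outputs and a keepalive
      # object, both of which are held until after execution below.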
      out, keepalive = self.backend.emit_python_callback(
          f, c, [p0, p1], [shape, shape])
      self._ExecuteAndCompareExact(
          c, arguments=[arg0, arg1], expected=[arg0 + arg1, arg0 - arg1])
      del out, keepalive

    def testPythonCallbackCanHandleExceptions(self):
      if self.backend.platform not in {"cpu", "gpu"}:
        self.skipTest("Test requires cpu or gpu platform")
      c = self._NewComputation()

      def _Callback(x):
        raise ValueError("Value error raised!")

      arg0 = np.array([9, 43, -101, 22], dtype=np.int32)
      shape = xla_client.shape_from_pyval(arg0)
      shape = shape.with_major_to_minor_layout_if_absent()
      p0 = ops.Parameter(c, 0, shape)
      out, keepalive = self.backend.emit_python_callback(
          _Callback, c, [p0], [shape], has_side_effects=True)
      with self.assertRaisesRegex(xla_client.XlaRuntimeError,
                                  "Value error raised!"):
        self._Execute(c, [arg0])
      del out, keepalive

    def testTokens(self):
      if self.backend.platform not in {"cpu", "gpu"}:
        self.skipTest("Test requires cpu or gpu platform")
      c = self._NewComputation()

      def _Callback(x, y):
        assert y is None, y
        return None, x + 1

      arg0 = np.array([9, 43, -101, 22], dtype=np.int32)
      shape = xla_client.shape_from_pyval(arg0)
      token_shape = xla_client.Shape.token_shape()
      p0 = ops.Parameter(c, 0, shape)
      token = ops.CreateToken(c)
      out, keepalive = self.backend.emit_python_callback(
          _Callback, c, [p0, token], [token_shape, shape])
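      # The callback returns (token, x + 1), so tuple element 0 is the output
      # token and element 1 carries the data result extracted below.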
      out = ops.GetTupleElement(out, 1)
      self._ExecuteAndCompareExact(c, arguments=[arg0], expected=[arg0 + 1])
      del out, keepalive

    def testStriding(self):
      if self.backend.platform not in {"cpu", "gpu"}:
        self.skipTest("Test requires cpu or gpu platform")
      c = self._NewComputation()

      def _Callback(x):
        assert x.flags.f_contiguous, x.strides
        # Force the output array to have C layout, which will require a
        # transpose back to the expected Fortran layout.
        return np.ascontiguousarray(x * 2),

      arg0 = np.arange(12, dtype=np.int16).reshape(3, 4)
      shape_f_layout = xla_client.Shape.array_shape(
          arg0.dtype, arg0.shape, layout=(0, 1))
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      out, keepalive = self.backend.emit_python_callback(
          _Callback, c, [p0], [shape_f_layout], [shape_f_layout])
      self._ExecuteAndCompareExact(c, arguments=[arg0], expected=[arg0 * 2])
      del out, keepalive

  tests.append(PythonCallbackTest)

  class ComputationFromProtoTest(absltest.TestCase):
    """Test computation execution from HLO proto."""

    def setUp(self):
      super(ComputationFromProtoTest, self).setUp()
      self.backend = xla_backend()

    def testExecuteFromProto(self):
      # Build the HLO proto
      b = xla_client.XlaBuilder("computation")
      ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
      serialized_proto = b.build().as_serialized_hlo_module_proto()

      # Load and execute the proto
      c = xla_client.XlaComputation(serialized_proto)
      ans, = xla_client.execute_with_python_values(
          self.backend.compile(c), (), backend=self.backend)
      np.testing.assert_equal(ans, np.int32(3))

  tests.append(ComputationFromProtoTest)

  class ParametersTest(ComputationTest):
    """Tests focusing on Parameter ops and argument-passing."""

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes)
    def testScalarTimesVector(self, dtype):
      c = self._NewComputation()
      arg0 = np.array(3, dtype=dtype)
      arg1 = np.array([10, 15, -2, 7], dtype=dtype)
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      ops.Mul(p0, p1)
      self._ExecuteAndCompareExact(
          c, arguments=[arg0, arg1], expected=[arg0 * arg1])

    # TODO(phawkins): test comparison harness doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testScalarMinusVectorExplicitNumbering(self, dtype):
      # Use explicit numbering and pass parameter_num first. Sub is used since
      # it's not commutative and can help catch parameter reversal within the
      # computation.
      c = self._NewComputation()
      arg0 = np.array(2.0, dtype=dtype)
      arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype)
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      ops.Sub(p1, p0)
      self._ExecuteAndCompareClose(
          c, arguments=[arg0, arg1], expected=[arg1 - arg0])

  tests.append(ParametersTest)

  class BufferTest(ComputationTest):
    """Tests focusing on execution with Buffers."""

    def testConstantSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14)))
      self._ExecuteAndCompareClose(c, expected=[4.25])

    def testOneParameterSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Constant(c, np.float32(3.14)))
      self._ExecuteAndCompareClose(
          c, arguments=[NumpyArrayF32(1.11)], expected=[4.25])

    def testTwoParameterSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.))))
      self._ExecuteAndCompareClose(
          c,
          arguments=[NumpyArrayF32(1.11),
                     NumpyArrayF32(3.14)],
          expected=[4.25])

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCannotCallWithDeletedBuffers(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Constant(c, np.float32(3.14)))
      arg = NumpyArrayF32(1.11)
      compiled_c = self.backend.compile(c.build())
      arg_buffer = self.backend.buffer_from_pyval(arg)
      arg_buffer.delete()
      with self.assertRaises(xla_client.XlaRuntimeError):
        compiled_c.execute([arg_buffer])

    def testXlaShape(self):
      pyval = np.array([[1., 2.]], np.float32)
      local_buffer = self.backend.buffer_from_pyval(pyval)
      xla_shape = local_buffer.xla_shape()
      self.assertEqual(xla_shape.dimensions(), (1, 2))
      self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))

    def testXlaShapeIndex(self):
      a = xla_client.ShapeIndex((1, 2))
      b = xla_client.ShapeIndex((1, 2))
      c = xla_client.ShapeIndex((2, 3))
      self.assertEqual(a, b)
      self.assertNotEqual(b, c)

    def testLayout(self):
      f32 = xla_client.PrimitiveType.F32
      a = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout()
      b = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout()
      c = xla_client.Shape.array_shape(f32, (2, 3), (1, 0)).layout()
      self.assertEqual(a.minor_to_major(), (0, 1))
      self.assertEqual(b.minor_to_major(), (0, 1))
      self.assertEqual(c.minor_to_major(), (1, 0))
      self.assertEqual(a, b)
      self.assertNotEqual(a, c)
      self.assertNotEqual(b, c)
      self.assertEqual(hash(a), hash(b))
      self.assertNotEqual(hash(a), hash(c))
      self.assertNotEqual(hash(b), hash(c))

    def testBlockUntilReadyWorks(self):
      arg = np.array([[1., 2.]], np.float32)
      arg_buffer = self.backend.buffer_from_pyval(arg)
      arg_buffer.block_until_ready()
      # This test merely checks that nothing goes awry when we call
      # block_until_ready(); it's difficult to test anything else.

    def testBlockUntilReadyRaisesOnDeletedBuffer(self):
      arg = np.array([[1., 2.]], np.float32)
      buffer = self.backend.buffer_from_pyval(arg)
      buffer.delete()
      with self.assertRaisesRegex(
          RuntimeError,
          re.escape(
              "BlockHostUntilReady() called on deleted or donated buffer")):
        buffer.block_until_ready()

    def testDeviceArrayBaseSignatures(self):
      # When extending `DeviceArrayBase`, the object behaves as a `DeviceArray`
      # and thus needs to correctly implement the following methods.
      arg = np.array([[1., 2., 3.]], np.float32)
      buffer = self.backend.buffer_from_pyval(arg)
      if not isinstance(buffer, xla_client.DeviceArrayBase):
        raise unittest.SkipTest(
            "The object of type {} does not extend DeviceArrayBase".format(
                type(buffer)))

      self.assertEqual(buffer.__array_priority__, 100)
      self.assertEqual(buffer.shape, (1, 3))
      self.assertEqual(buffer.dtype, np.float32)
      self.assertEqual(buffer.size, 3)
      self.assertEqual(buffer.ndim, 2)

      self.assertIs(buffer, buffer.block_until_ready())
      self.assertTrue(buffer.is_ready())
      buffer.delete()
      with self.assertRaises(xla_client.XlaRuntimeError):
        buffer.block_until_ready()
      with self.assertRaises(xla_client.XlaRuntimeError):
        buffer.is_ready()

    def testOnDeviceSizeInBytes(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.")
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0)
      # OnDeviceSizeInBytes varies depending on the platform. Confirm there's
      # a reasonable value.
      self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0)
      self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0)

    def testLiveBuffers(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support LiveBuffers().")
      self.assertEmpty(self.backend.live_buffers())
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertLen(self.backend.live_buffers(), 3)
      self.assertIs(self.backend.live_buffers()[0], arg2_buffer)
      self.assertIs(self.backend.live_buffers()[1], arg1_buffer)
      self.assertIs(self.backend.live_buffers()[2], arg0_buffer)
      self.assertEqual(self.backend.devices()[0].live_buffers(),
                       self.backend.live_buffers())

      arg1_buffer.delete()
      self.assertLen(self.backend.live_buffers(), 2)
      self.assertIs(self.backend.live_buffers()[0], arg2_buffer)
      self.assertIs(self.backend.live_buffers()[1], arg0_buffer)

      arg0_buffer.delete()
      arg2_buffer.delete()
      self.assertEmpty(self.backend.live_buffers())

    def testCopyToHost(self):
      arg0 = np.array([[1., 2.]], np.float32)
      arg1 = np.array([[3., 4.]], np.float32)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      # Prefetch two buffers using copy_to_host_async, and then retrieve their
      # values using to_py.
      arg0_buffer.copy_to_host_async()
      arg0_buffer.copy_to_host_async()  # Duplicate calls don't do anything.
      arg1_buffer.copy_to_host_async()
      np.testing.assert_equal(arg0, arg0_buffer.to_py())
      np.testing.assert_equal(arg1, arg1_buffer.to_py())
      # copy_to_host_async does nothing after to_py is called.
      arg0_buffer.copy_to_host_async()
      np.testing.assert_equal(arg0, arg0_buffer.to_py())

    def testDevice(self):
      x = np.arange(8, dtype=np.int32)
      for device in self.backend.local_devices():
        buf = self.backend.buffer_from_pyval(x, device=device)
        self.assertEqual(buf.device(), device)
        np.testing.assert_equal(x, buf.to_py())

    def testStandardTypes(self):
      for dtype in standard_dtypes:
        if dtype == bfloat16 or dtype == np.complex128:
          continue
        arr = self.backend.buffer_from_pyval(np.array([0, 1], dtype))
        arr = arr.to_py()
        self.assertEqual(dtype, type(arr[0]))

    def testUnsafeBufferPointer(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support UnsafeBufferPointer().")
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertGreaterEqual(arg0_buffer.unsafe_buffer_pointer(), 0)
      self.assertGreaterEqual(arg1_buffer.unsafe_buffer_pointer(), 0)
      self.assertGreaterEqual(arg2_buffer.unsafe_buffer_pointer(), 0)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testClone(self):
      x = np.array([[3., 4., 5.]], np.float32)
      y = self.backend.buffer_from_pyval(x)
      z = y.clone()
      self.assertNotEqual(id(x), id(y))
      np.testing.assert_array_equal(y.to_py(), z.to_py())
      self.assertEqual(y.unsafe_buffer_pointer(), z.unsafe_buffer_pointer())

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testJaxAttributesHaveCorrectDefaults(self):
      x = np.array([[3., 4., 5.]], np.float32)
      y = self.backend.buffer_from_pyval(x)
      self.assertIsNone(y.aval)
      self.assertIsNone(y._device)

  tests.append(BufferTest)

  class SingleOpTest(ComputationTest):
    """Tests for single ops.

    The goal here is smoke testing - to exercise the most basic functionality
    of single XLA ops. As few additional ops as possible are added around the
    op being tested.
    """

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConcatenate(self, dtype):
      c = self._NewComputation()
      args = (
          ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)),
          ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)),
      )
      ops.ConcatInDim(c, args, dimension=0)
      self._ExecuteAndCompareExact(
          c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)])

    # pyformat: disable
    @parameterized.named_parameters({
        "testcase_name": "_{}_{}".format(src_dtype.__name__,
                                         dst_dtype.__name__),
        "src_dtype": src_dtype,
        "dst_dtype": dst_dtype,
    } for src_dtype, dst_dtype in itertools.permutations(
        [np.bool_, np.int32, np.int64, np.float32, np.float64], 2))
    # pyformat: enable
    def testConvertElementType(self, src_dtype, dst_dtype):
      if ((src_dtype in [np.int64, np.float64] or
           dst_dtype in [np.int64, np.float64]) and
          self.backend.platform == "tpu"):
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
      ops.ConvertElementType(
          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 1)
      expected = np.array(x, dtype=dst_dtype)

      self.assertEqual(result[0].shape, expected.shape)
      self.assertEqual(result[0].dtype, expected.dtype)
      np.testing.assert_equal(result[0], expected)

    # pyformat: disable
    @parameterized.named_parameters(
        {
            "testcase_name": "_{}_{}".format(src_dtype.__name__,
                                             dst_dtype.__name__),
            "src_dtype": src_dtype,
            "dst_dtype": dst_dtype,
        }
        for dtypes in [[np.int32, np.float32], [np.int64, np.float64]]
        for src_dtype, dst_dtype in itertools.permutations(dtypes, 2))
    # pyformat: enable
    def testBitcastConvertType(self, src_dtype, dst_dtype):
      if (np.float64 in (src_dtype, dst_dtype) and
          self.backend.platform == "tpu"):
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
      ops.BitcastConvertType(
          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 1)
      expected = x.view(dst_dtype)

      self.assertEqual(result[0].shape, expected.shape)
      self.assertEqual(result[0].dtype, expected.dtype)
      np.testing.assert_equal(result[0], expected)

    # TODO(b/123523486) implement AllToAll on CPU
    def DISABLED_testAllToAllOneReplica(self):
      samples = [
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples[:1]:
        c = self._NewComputation()
        ops.AllToAll(ops.Constant(c, lhs), 0, 0)
        self._ExecuteAndCompareExact(c, expected=[lhs])

    def testCrossReplicaSumOneReplica(self):
      samples = [
          NumpyArrayF32(42.0),
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples:
        c = self._NewComputation()
        ops.CrossReplicaSum(ops.Constant(c, lhs))
        self._ExecuteAndCompareExact(c, expected=[lhs])

    def testReplicaId(self):
      c = self._NewComputation()
      _ = ops.ReplicaId(c)
      self._ExecuteAndCompareExact(c, expected=[0])

    def testCrossReplicaSumOneReplicaWithSingletonGroup(self):
      samples = [
          NumpyArrayF32(42.0),
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples:
        c = self._NewComputation()
        ops.CrossReplicaSum(
            ops.Constant(c, lhs), xla_client.make_replica_groups([[0]]))
        self._ExecuteAndCompareExact(c, expected=[lhs])

    # TODO(phawkins): np.dot implementation doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDotMatrixVector(self, dtype):
      c = self._NewComputation()
      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
      rhs = np.array([[10.0], [20.0]], dtype=dtype)
      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])

    # TODO(phawkins): np.dot implementation doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDotMatrixMatrix(self, dtype):
      c = self._NewComputation()
      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
      rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype)
      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])

    def testDotGeneral(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
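      # make_dot_dimension_numbers takes ((lhs_contracting, rhs_contracting),
      # (lhs_batch, rhs_batch)): contract dim 2 of lhs with dim 1 of rhs and
      # batch over dim 0 of each, matching np.matmul on stacked matrices.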
      dimension_numbers = xla_client.make_dot_dimension_numbers(
          (([2], [1]), ([0], [0])))
      ops.DotGeneral(
          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testDotGeneralWithDotDimensionNumbersProto(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))

      dimension_numbers = xla_client.DotDimensionNumbers()
      dimension_numbers.lhs_contracting_dimensions.append(2)
      dimension_numbers.rhs_contracting_dimensions.append(1)
      dimension_numbers.lhs_batch_dimensions.append(0)
      dimension_numbers.rhs_batch_dimensions.append(0)

      ops.DotGeneral(
          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testDotGeneralWithPrecisionConfig(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
      dimension_numbers = xla_client.make_dot_dimension_numbers(
          (([2], [1]), ([0], [0])))
      config = xla_client.PrecisionConfig()
      config.operand_precision.append(config.Precision.HIGH)
      config.operand_precision.append(config.Precision.HIGHEST)
      ops.DotGeneral(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          dimension_numbers,
          precision_config=config)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testConvGeneralDilatedF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
          lhs_dilation, rhs_dilation, dimension_numbers)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedF32WithPrecisionConfig(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      config = xla_client.PrecisionConfig()
      config.operand_precision.append(config.Precision.HIGHEST)
      config.operand_precision.append(config.Precision.DEFAULT)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          strides,
          pads,
          lhs_dilation,
          rhs_dilation,
          dimension_numbers,
          precision_config=config)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedPermutedF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)

      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NHWC", "OIHW", "CWNH"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, np.transpose(lhs,
                                       (0, 2, 3, 1))), ops.Constant(c, rhs),
          strides, pads, lhs_dilation, rhs_dilation, dimension_numbers)
      result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.],
                           [40., 50., 0.]]]])
      self._ExecuteAndCompareClose(
          c, expected=[np.transpose(result, (1, 3, 0, 2))])

    def testConvGeneralDilatedGroupedConvolutionF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 2, 2, 3)
      rhs = a(2, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      feature_group_count = 2
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
          lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ], [
          [0., 0., 0.],
          [330., 380., 160.],
          [0., 0., 0.],
          [480., 530., 220.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedWindowReversalF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      window_reversal = [False, True]
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          strides,
          pads,
          lhs_dilation,
          rhs_dilation,
          dimension_numbers,
          window_reversal=window_reversal)
      result = np.array([[[
          [0., 0., 0.],
          [0., 10., 20.],
          [0., 0., 0.],
          [30., 40., 50.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testBooleanNot(self):
      c = self._NewComputation()
      arr = NumpyArrayBool([True, False, True])
      ops.Not(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[~arr])

    def testPopulationCount(self):
      c = self._NewComputation()
      arr = NumpyArrayS32([3, 0, 1])
      ops.PopulationCount(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])])

    def testCountLeadingZeros(self):
      c = self._NewComputation()
      arr = NumpyArrayS32([0x7FFF, 0x12345678])
      ops.Clz(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[[17, 3]])

    def testExp(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Exp(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.exp(arr)])

    def testExpm1(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Expm1(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)])

    def testRound(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Round(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.round(arr)])

    def testLog(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Log(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.log(arr)])

    def testLog1p(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Log1p(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)])

    def testNeg(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Neg(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[-arr])

    def testFloor(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Floor(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.floor(arr)])

    def testCeil(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Ceil(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)])

    def testAbs(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
      ops.Abs(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.abs(arr)])

    def testTanhF32(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001])
      ops.Tanh(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)])

    def testTanhF64(self):
      if self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support 64bit tanh")
      c = self._NewComputation()
      arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001])
      ops.Tanh(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12)

    def testTranspose(self):

      def _TransposeAndTest(array, permutation):
        c = self._NewComputation()
        ops.Transpose(ops.Constant(c, array), permutation)
        expected = np.transpose(array, permutation)
        self._ExecuteAndCompareClose(c, expected=[expected])

      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])

      arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
      for permutation in itertools.permutations(range(arr.ndim)):
        _TransposeAndTest(arr, permutation)
        _TransposeAndTest(np.asfortranarray(arr), permutation)

    def testEq(self):
      c = self._NewComputation()
      ops.Eq(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])

    def testNe(self):
      c = self._NewComputation()
      ops.Ne(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
      self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]])

      ops.Ne(
          ops.Constant(c, NumpyArrayF32([-2.0, 0.0,
                                         float("nan"),
                                         float("nan")])),
          ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0,
                                         float("nan")])))
      self._ExecuteAndAssertWith(
          np.testing.assert_allclose,
          c, (),
          expected=[[True, False, True, True]])

    def testGt(self):
      c = self._NewComputation()
      ops.Gt(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[False, True, True, False, False]])

    def testGe(self):
      c = self._NewComputation()
      ops.Ge(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[True, True, True, False, False]])

    def testLt(self):
      c = self._NewComputation()
      ops.Lt(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[False, False, False, True, True]])

    def testLe(self):
      c = self._NewComputation()
      ops.Le(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[True, False, False, True, True]])

    def testMax(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
      self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]])

    def testMaxExplicitBroadcastDim0(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
          broadcast_dimensions=(0,))
      self._ExecuteAndCompareExact(
          c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]])

    def testMaxExplicitBroadcastDim1(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
          broadcast_dimensions=(1,))
      self._ExecuteAndCompareExact(
          c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]])

    def testMin(self):
      c = self._NewComputation()
      ops.Min(
          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
      self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]])

    def testPad(self):
      c = self._NewComputation()
      ops.Pad(
          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
          ops.Constant(c, NumpyArrayF32(0.0)),
          xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)]))
      self._ExecuteAndCompareClose(
          c,
          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    def testPadWithPaddingConfig(self):
      c = self._NewComputation()
      padding_config = xla_client.PaddingConfig()
      for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
        dimension = xla_client.PaddingConfigDimension()
        dimension.edge_padding_low = lo
        dimension.edge_padding_high = hi
        dimension.interior_padding = interior
        padding_config.dimensions.append(dimension)
      ops.Pad(
          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
          ops.Constant(c, NumpyArrayF32(0.0)), padding_config)
      self._ExecuteAndCompareClose(
          c,
          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    def testReshape(self):
      c = self._NewComputation()
      ops.Reshape(
          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
          dimensions=[0, 1],
          new_sizes=[2, 3])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]])

    def testCollapse(self):
      c = self._NewComputation()
      ops.Collapse(
          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
          dimensions=[1, 2])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]])

    def testRev(self):
      c = self._NewComputation()
      ops.Rev(
          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
          dimensions=[0, 2])
      self._ExecuteAndCompareExact(
          c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]])

    def testReducePrecision(self):
      c = self._NewComputation()
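      # exponent_bits=8 and mantissa_bits=7 match the bfloat16 format, so the
      # f32 input below is rounded to the nearest bfloat16-representable value.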
      ops.ReducePrecision(
          ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])),
          exponent_bits=8,
          mantissa_bits=7)
      self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]])

    def testClampF32(self):
      c = self._NewComputation()
      ops.Clamp(
          ops.Constant(c, NumpyArrayF32(-1)),
          ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
          ops.Constant(c, NumpyArrayF32(2)))
      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])

    def testClampS32(self):
      c = self._NewComputation()
      ops.Clamp(
          ops.Constant(c, NumpyArrayS32(-1)),
          ops.Constant(c, NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
          ops.Constant(c, NumpyArrayS32(2)))
      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])

    def testSelect(self):
      c = self._NewComputation()
      ops.Select(
          ops.Constant(c, NumpyArrayBool([True, False, False, True, False])),
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])),
          ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5])))
      self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]])

    def testSlice(self):
      c = self._NewComputation()
      ops.Slice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          [1, 0], [3, 2], [1, 1])
      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])

    def testSliceInDim(self):
      c = self._NewComputation()
      ops.SliceInDim(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          start_index=1,
          limit_index=2,
          stride=1,
          dimno=1)
      self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]])
      ops.SliceInDim(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          start_index=0,
          limit_index=3,
          stride=2,
          dimno=0)
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]])

    def testDynamicSlice(self):
      c = self._NewComputation()
      ops.DynamicSlice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          [ops.Constant(c, NumpyArrayS32([1, 0]))], [2, 2])
      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])

    def testDynamicUpdateSlice(self):
      c = self._NewComputation()
      ops.DynamicUpdateSlice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])),
          [ops.Constant(c, NumpyArrayS32([1, 1]))])
      self._ExecuteAndCompareExact(
          c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]])

    def testTuple(self):
      c = self._NewComputation()
      ops.Tuple(c, [
          ops.Constant(c, np.int32(42)),
          ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
          ops.Constant(c, NumpyArrayBool([True, False, False, True]))
      ])
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 3)
      np.testing.assert_equal(result[0], 42)
      np.testing.assert_allclose(result[1], [1.0, 2.0])
      np.testing.assert_equal(result[2], [True, False, False, True])

    def testGetTupleElement(self):
      c = self._NewComputation()
1424      ops.GetTupleElement(
1425          ops.Tuple(c, [
1426              ops.Constant(c, np.int32(42)),
1427              ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
1428              ops.Constant(c, NumpyArrayBool([True, False, False, True]))
1429          ]), 1)
1430      self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]])
1431
1432    def testBroadcast(self):
1433      c = self._NewComputation()
1434      ops.Broadcast(
1435          ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
1436      self._ExecuteAndCompareExact(
1437          c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]])
1438
1439    def testBroadcastInDim(self):
1440      c = self._NewComputation()
1441      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0])
1442      self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]])
1443      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [1])
1444      self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]])
1445
1446    def testRngNormal(self):
1447      shape = (2, 3)
1448      c = self._NewComputation()
1449      ops.RngNormal(
1450          ops.Constant(c, NumpyArrayF32(0.)),
1451          ops.Constant(c, NumpyArrayF32(1.)),
1452          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
1453                                             shape))
1454      result = xla_client.execute_with_python_values(
1455          self.backend.compile(c.build()), (), backend=self.backend)
1456      # since the result is random, we just check shape and uniqueness
1457      self.assertLen(result, 1)
1458      self.assertEqual(result[0].shape, shape)
1459      self.assertLen(np.unique(result[0]), np.prod(shape))
1460
1461    def testRngUniformF32(self):
1462      lo, hi = 2., 4.
1463      shape = (2, 3)
1464      c = self._NewComputation()
1465      ops.RngUniform(
1466          ops.Constant(c, NumpyArrayF32(lo)),
1467          ops.Constant(c, NumpyArrayF32(hi)),
1468          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
1469                                             shape))
1470      result = xla_client.execute_with_python_values(
1471          self.backend.compile(c.build()), (), backend=self.backend)
1472      # since the result is random, we just check shape, uniqueness, and range
1473      self.assertLen(result, 1)
1474      self.assertEqual(result[0].shape, shape)
1475      self.assertLen(np.unique(result[0]), np.prod(shape))
1476      self.assertTrue(np.all(lo <= result[0]))
1477      self.assertTrue(np.all(result[0] < hi))
1478
1479    def testRngUniformS32(self):
1480      lo, hi = 2, 4
1481      shape = (2, 3)
1482      c = self._NewComputation()
1483      ops.RngUniform(
1484          ops.Constant(c, NumpyArrayS32(lo)),
1485          ops.Constant(c, NumpyArrayS32(hi)),
1486          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
1487                                             shape))
1488      result = xla_client.execute_with_python_values(
1489          self.backend.compile(c.build()), (), backend=self.backend)
1490      # since the result is random, we just check shape, integrality, and range
1491      self.assertLen(result, 1)
1492      self.assertEqual(result[0].shape, shape)
1493      self.assertEqual(result[0].dtype, np.int32)
1494      self.assertTrue(np.all(lo <= result[0]))
1495      self.assertTrue(np.all(result[0] < hi))
1496
1497    def testCholesky(self):
1498      l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
1499                   dtype=np.float32)
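      # np.dot(l, l.T) is symmetric positive definite with Cholesky factor l, so
      # the op should recover l up to numerical error.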
1500      c = self._NewComputation()
1501      ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T))))
1502      self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4)
1503
1504    def testSort(self):
1505      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
1506      c = self._NewComputation()
1507      ops.Sort(c, [ops.Constant(c, keys)], is_stable=True)
1508      self._ExecuteAndCompareClose(
1509          c,
1510          expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)])
1511
1512    def testSortKeyVal(self):
1513      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
1514      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
1515      c = self._NewComputation()
1516      ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
1517      result = xla_client.execute_with_python_values(
1518          self.backend.compile(c.build()), (), backend=self.backend)
1519      self.assertLen(result, 2)
1520      np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
1521      np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])
1522
1523    def testSortCustomComparator(self):
1524      b = self._NewComputation("comparator")
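      # The comparator orders (key, value) pairs by key ascending, breaking key
      # ties by taking the larger value first.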
1525      p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0)))
1526      q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))
1527      p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0)))
1528      q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0)))
1529      ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1)))
1530      comparator = b.build()
1531
1532      keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32)
1533      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
1534      c = self._NewComputation()
1535      ops.Sort(
1536          c, (ops.Constant(c, keys), ops.Constant(c, values)),
1537          dimension=1,
1538          comparator=comparator)
1539      result = xla_client.execute_with_python_values(
1540          self.backend.compile(c.build()), (), backend=self.backend)
1541      self.assertLen(result, 2)
1542      np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
1543      np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]])
1544
1545    def testQR(self):
1546      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
1547                    [10, 63, 166, 310]],
1548                   dtype=np.float32)
1549      c = self._NewComputation()
1550      ops.Tuple(c, ops.QR(ops.Constant(c, a), full_matrices=True))
1551      q, r = self._Execute(c, ())
1552      np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4)
1553
1554    def testEigh(self):
1555      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
1556                    [10, 63, 166, 310]],
1557                   dtype=np.float32)
1558      a = (a + a.T) / 2
1559
1560      c = self._NewComputation()
1561      ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True))
1562      # TODO(b/129396575): Turn this test back on when it passes without
1563      # fastmath.
1564      # v, w = self._Execute(c, ())
1565      # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3)
1566
1567    def testSVD(self):
1568      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
1569                    [10, 63, 166, 310]],
1570                   dtype=np.float32)
1571      c = self._NewComputation()
1572      ops.Tuple(c, ops.SVD(ops.Constant(c, a)))
1573      u, d, v = self._Execute(c, ())
1574      self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3)
1575
1576    def testTriangularSolve(self):
1577      a_vals = np.array(
1578          [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
1579          dtype=np.float32)
1580      b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
1581                        dtype=np.float32)
1582
1583      c = self._NewComputation()
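      # With left_side=False and transpose_a=TRANSPOSE this solves x @ a.T == b
      # for x, where a is lower triangular.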
1584      ops.TriangularSolve(
1585          ops.Constant(c, a_vals),
1586          ops.Constant(c, b_vals),
1587          left_side=False,
1588          lower=True,
1589          transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE,
1590          unit_diagonal=False)
1591      self._ExecuteAndCompareClose(
1592          c,
1593          expected=[
1594              np.array([
1595                  [0.5, 0.08333334, 0.04629629, 0.03367003],
1596                  [2.5, -0.25, -0.1388889, -0.1010101],
1597                  [4.5, -0.58333331, -0.32407406, -0.23569024],
1598              ],
1599                       dtype=np.float32)
1600          ],
1601          rtol=1e-4)
1602
1603    def testApproxTopK(self):
1604      if self.backend.platform != "tpu":
1605        self.skipTest("ApproxTopK is only supported on TPU")
1606      k = 10
1607      qy_size = 256
1608      db_size = 3000
1609      feature = 128
1610      recall_target = 0.95
1611      b = self._NewComputation()
1612      p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0)))
1613      q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))
1614      ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0)))
1615      ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0)))
1616      ops.Gt(p0, q0)
1617      comparator = b.build()
1618      qy_shape = [qy_size, feature]
1619      db_shape = [feature, db_size]
1620      rng = np.random.RandomState(0)
1621      qy_arg = rng.randn(*qy_shape).astype(np.float32)
1622      db_arg = rng.randn(*db_shape).astype(np.float32)
1623      b = self._NewComputation()
1624      qy = ops.Parameter(b, 0, xla_client.shape_from_pyval(qy_arg))
1625      db = ops.Parameter(b, 1, xla_client.shape_from_pyval(db_arg))
1626      scores = ops.Dot(qy, db)
1627      iota = ops.Iota(
1628          b,
1629          xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
1630                                       (qy_size, db_size)), 1)
1631      init_val = ops.Constant(b, np.float32(-1))
1632      init_arg = ops.Constant(b, np.int32(-1))
1633      ground_truth = ops.TopK(scores, k=k)
1634      approx_topk = ops.ApproxTopK(
1635          b, [scores, iota], [init_val, init_arg],
1636          top_k=k,
1637          reduction_dim=1,
1638          comparator=comparator,
1639          recall_target=recall_target)
1640      ops.Tuple(b, [
1641          ops.GetTupleElement(ground_truth, 1),
1642          ops.GetTupleElement(approx_topk, 1)
1643      ])
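      # Recall is the fraction of approximate top-k ids per query that also appear
      # in the exact top-k for that query; it should exceed recall_target.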
1644      results = self._Execute(b, [qy_arg, db_arg])
1645      ground_truth_docids = [set(x) for x in results[0]]
1646      hits = sum(
1647          len(
1648              list(x
1649                   for x in approx_topk_per_q
1650                   if x in ground_truth_docids[q]))
1651          for q, approx_topk_per_q in enumerate(results[1]))
1652      self.assertGreater(hits / (qy_size * k), recall_target)
1653
1654    def testIsConstant(self):
1655      c = self._NewComputation()
1656      a = ops.Constant(c, np.int32(3))
1657      b = ops.Constant(c, np.int32(1))
1658      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0)))
1659      const_expr = ops.Sub(b, a)
1660      non_const_expr = ops.Mul(const_expr, x)
1661      self.assertTrue(c.is_constant(const_expr))
1662      self.assertFalse(c.is_constant(non_const_expr))
1663
1664    def testGather(self):
1665      a = np.arange(9).astype(np.int32).reshape((3, 3))
1666      indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32)
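      # Gather dimension numbers: offset_dims are the output dims that hold each
      # gathered slice, start_index_map maps index-vector elements to operand dims,
      # and index_vector_dim is the dimension of `indices` containing index vectors.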
1667      dnums = xla_client.GatherDimensionNumbers()
1668      dnums.offset_dims.append(1)
1669      dnums.offset_dims.append(2)
1670      dnums.start_index_map.append(0)
1671      dnums.start_index_map.append(1)
1672      dnums.index_vector_dim = 2
1673      c = self._NewComputation()
1674      ops.Gather(
1675          ops.Constant(c, a),
1676          ops.Constant(c, indices),
1677          dnums,
1678          slice_sizes=[1, 1])
1679      g, = self._Execute(c, ())
1680      expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32)
1681      np.testing.assert_allclose(g, expected, rtol=1e-4)
1682
1683    def testFft(self):
1684      if self.backend.platform == "tpu":
1685        self.skipTest("TPU only supports 1D FFT")
1686      shape = [2, 3, 4, 5]
1687      rng = np.random.RandomState(0)
1688      a = rng.randn(*shape) + 1.0j * rng.randn(*shape)
1689      a = a.astype(np.complex64)
1690      # FFT
1691      c = self._NewComputation()
1692      ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:])
1693      self._ExecuteAndCompareClose(
1694          c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4)
1695      # IFFT
1696      c = self._NewComputation()
1697      ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:])
1698      self._ExecuteAndCompareClose(
1699          c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4)
1700      # RFFT
1701      b = rng.randn(*shape).astype(np.float32)
1702      c = self._NewComputation()
1703      ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:])
1704      self._ExecuteAndCompareClose(
1705          c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4)
1706      # IRFFT
1707      c = self._NewComputation()
1708      ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8])
1709      self._ExecuteAndCompareClose(
1710          c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=1e-4)
1711
1712    def testNextAfter(self):
1713      c = self._NewComputation()
1714      ops.NextAfter(
1715          ops.Constant(c, np.array([1, 2], dtype=np.float32)),
1716          ops.Constant(c, np.array([2, 1], dtype=np.float32)))
1717      out, = self._Execute(c, ())
1718      eps = np.finfo(np.float32).eps
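      # NextAfter moves one ULP toward the second operand: 1 -> 1 + eps and
      # 2 -> 2 - eps for float32.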
1719      np.testing.assert_equal(
1720          np.array([eps + 1, 2 - eps], dtype=np.float32), out)
1721
1722    @parameterized.named_parameters({
1723        "testcase_name": "_{}".format(dtype.__name__),
1724        "dtype": dtype,
1725    } for dtype in float_dtypes)
1726    def testRegularizedIncompleteBeta(self, dtype):
1727      x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538],
1728                   dtype=dtype)
1729      a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606],
1730                   dtype=dtype)
1731      b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677],
1732                   dtype=dtype)
1733      c = self._NewComputation()
1734      ops.RegularizedIncompleteBeta(
1735          ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x))
1736      expected = np.array(
1737          [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155])
1738      self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2)
1739
1740  tests.append(SingleOpTest)
1741
1742  class EmbeddedComputationsTest(ComputationTest):
1743    """Tests for XLA graphs with embedded computations (such as maps)."""
1744
1745    def _CreateConstantComputation(self, in_dtype, out_dtype):
1746      """Computation (A) -> B that returns a constant 1 for any input."""
1747      c = self._NewComputation("constant_{}_{}_one".format(
1748          in_dtype.__name__, out_dtype.__name__))
1749      ops.Parameter(
1750          c, 0,
1751          xla_client.shape_from_pyval(np.array(
1752              0, dtype=in_dtype)).with_major_to_minor_layout_if_absent())
1753      ops.Constant(c, out_dtype(1))
1754      return c.build()
1755
1756    def _CreateMulBy2Computation(self, dtype):
1757      """Computation (dtype) -> dtype that multiplies its parameter by 2."""
1758      c = self._NewComputation("mul_{}_by2".format(dtype.__name__))
1759      ops.Mul(
1760          ops.Parameter(
1761              c, 0,
1762              xla_client.shape_from_pyval(np.array(
1763                  0, dtype=dtype)).with_major_to_minor_layout_if_absent()),
1764          ops.Constant(c, dtype(2.0)))
1765      return c.build()
1766
1767    def _CreateMulF32ByParamComputation(self):
1768      """Computation (f32) -> f32 that multiplies one parameter by the other."""
1769      c = self._NewComputation("mul_f32_by_param")
1770      ops.Mul(
1771          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))),
1772          ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))))
1773      return c.build()
1774
1775    def _CreateBinaryAddComputation(self, dtype):
1776      """Computation (dtype, dtype) -> dtype that adds its two parameters."""
1777      c = self._NewComputation("add_param0_by_param1")
1778      shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
1779      shape = shape.with_major_to_minor_layout_if_absent()
1780      ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
1781      return c.build()
1782
1783    def _CreateBinaryGeComputation(self, dtype):
1784      """Computation (dtype, dtype) -> bool that tests param0 >= param1."""
1785      c = self._NewComputation("param0_ge_param1")
1786      shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
1787      shape = shape.with_major_to_minor_layout_if_absent()
1788      ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
1789      return c.build()
1790
1791    def _MakeSample3DArray(self, dtype):
1792      return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
1793                       [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
1794                      dtype=dtype)
1795
1796    @parameterized.named_parameters({
1797        "testcase_name": "_{}".format(dtype.__name__),
1798        "dtype": dtype,
1799    } for dtype in float_dtypes)
1800    def testCall(self, dtype):
1801      c = self._NewComputation()
1802      ops.Call(
1803          c,
1804          self._CreateMulBy2Computation(dtype),
1805          operands=(ops.Constant(c, dtype(5.0)),))
1806      self._ExecuteAndCompareClose(c, expected=[10.0])
1807
1808    @parameterized.named_parameters({
1809        "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__),
1810        "in_dtype": in_dtype,
1811        "out_dtype": out_dtype,
1812    } for in_dtype, out_dtype in [[np.float32, np.int32]])
1813    def testMapEachElementToConstant(self, in_dtype, out_dtype):
1814      c = self._NewComputation()
1815      ops.Map(c,
1816              [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))],
1817              self._CreateConstantComputation(in_dtype, out_dtype), [0])
1818      self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]])
1819
1820    @parameterized.named_parameters({
1821        "testcase_name": "_{}".format(dtype.__name__),
1822        "dtype": dtype,
1823    } for dtype in float_dtypes)
1824    def testMapMulBy2(self, dtype):
1825      if dtype == np.float64 and self.backend.platform == "tpu":
1826        self.skipTest("TPU doesn't support float64")
1827      c = self._NewComputation()
1828      ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))],
1829              self._CreateMulBy2Computation(dtype), [0])
1830      self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]])
1831
1832    @parameterized.named_parameters({
1833        "testcase_name": "_{}".format(dtype.__name__),
1834        "dtype": dtype,
1835    } for dtype in float_dtypes)
1836    def testSimpleMapChain(self, dtype):
1837      if dtype == np.float64 and self.backend.platform == "tpu":
1838        self.skipTest("TPU doesn't support float64")
1839      # Chains a map of constant-out with a map of mul-by-2
1840      c = self._NewComputation()
1841      const = ops.Map(
1842          c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))],
1843          self._CreateConstantComputation(dtype, dtype), [0])
1844      ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0])
1845      self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]])
1846
1847    # TODO(b/154752816): bfloat16 crashes in evaluator.
1848    @parameterized.named_parameters({
1849        "testcase_name": "_{}".format(dtype.__name__),
1850        "dtype": dtype,
1851    } for dtype in float_dtypes if dtype != bfloat16)
1852    def testDivVectorsWithMap(self, dtype):
1853
1854      def DivComputation():
1855        c = self._NewComputation("div_param0_by_param1")
1856        shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
1857        ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
1858        return c.build()
1859
1860      c = self._NewComputation()
1861      ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)),
1862                  ops.Constant(c, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))),
1863              DivComputation(), [0])
1864      self._ExecuteAndCompareClose(
1865          c, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3)
1866
1867    @parameterized.named_parameters({
1868        "testcase_name": "_{}".format(dtype.__name__),
1869        "dtype": dtype,
1870    } for dtype in float_dtypes)
1871    def testSelectAndScatter(self, dtype):
1872      if dtype == np.float64 and self.backend.platform == "tpu":
1873        self.skipTest("TPU doesn't support float64")
1874      c = self._NewComputation()
1875      operand = ops.Constant(
1876          c, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype))
1877      window_dimensions = (2, 1)
1878      window_strides = (1, 2)
1879      padding = xla_client.window_padding_type_to_pad_values(
1880          xla_client.PaddingType.VALID,
1881          c.get_shape(operand).dimensions(), window_dimensions, window_strides)
1882      ops.SelectAndScatterWithGeneralPadding(
1883          operand,
1884          select=self._CreateBinaryGeComputation(dtype),
1885          window_dimensions=window_dimensions,
1886          window_strides=window_strides,
1887          padding=padding,
1888          source=ops.Constant(c, np.array([[0.1, 0.2]], dtype=dtype)),
1889          init_value=ops.Constant(c, np.array(1, dtype=dtype)),
1890          scatter=self._CreateBinaryAddComputation(dtype))
1891      self._ExecuteAndCompareClose(
1892          c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3)
1893
1894    @parameterized.named_parameters({
1895        "testcase_name": "_{}".format(dtype.__name__),
1896        "dtype": dtype,
1897    } for dtype in float_dtypes)
1898    def testReduce1DtoScalar(self, dtype):
1899      c = self._NewComputation()
1900      ops.Reduce(
1901          c,
1902          operands=[
1903              ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))
1904          ],
1905          init_values=[ops.Constant(c, dtype(0))],
1906          computation=self._CreateBinaryAddComputation(dtype),
1907          dimensions_to_reduce=[0])
1908      self._ExecuteAndCompareClose(c, expected=[10])
1909
1910    # TODO(phawkins): test comparison harness doesn't support bfloat16
1911    @parameterized.named_parameters({
1912        "testcase_name": "_{}_dim{}".format(dtype.__name__, dim),
1913        "dtype": dtype,
1914        "dim": dim,
1915    } for dtype in float_dtypes if dtype != bfloat16 for dim in range(2))
1916    def testReduce2DTo1D(self, dtype, dim):
1917      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
1918      c = self._NewComputation()
1919      ops.Reduce(
1920          c,
1921          operands=[ops.Constant(c, input_array)],
1922          init_values=[ops.Constant(c, dtype(0))],
1923          computation=self._CreateBinaryAddComputation(dtype),
1924          dimensions_to_reduce=[dim])
1925      self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dim)])
1926
1927    @parameterized.named_parameters({
1928        "testcase_name": "_{}_dims[{}]".format(dtype.__name__, dims),
1929        "dtype": dtype,
1930        "dims": tuple(dims)
1931    } for dtype in float_dtypes for dims in itertools.permutations(range(3)))
1932    def testReduce3DAllPossibleWays(self, dtype, dims):
1933      input_array = self._MakeSample3DArray(dtype)
1934      c = self._NewComputation()
1935      ops.Reduce(
1936          c,
1937          operands=[ops.Constant(c, input_array)],
1938          init_values=[ops.Constant(c, dtype(0))],
1939          computation=self._CreateBinaryAddComputation(dtype),
1940          dimensions_to_reduce=dims)
1941      self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dims)])
1942
1943    @parameterized.named_parameters({
1944        "testcase_name": "_{}".format(dtype.__name__),
1945        "dtype": dtype,
1946    } for dtype in float_dtypes)
1947    def testReduceWindowValidUnitStrides(self, dtype):
1948      if dtype == np.float64 and self.backend.platform == "tpu":
1949        self.skipTest("TPU doesn't support float64")
1950      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
1951      c = self._NewComputation()
1952      window_dimensions = (2, 1)
1953      window_strides = (1, 1)
1954      padding = xla_client.window_padding_type_to_pad_values(
1955          xla_client.PaddingType.VALID, input_array.shape, window_dimensions,
1956          window_strides)
1957      ops.ReduceWindowWithGeneralPadding(
1958          operand=ops.Constant(c, input_array),
1959          init_value=ops.Constant(c, dtype(0)),
1960          computation=self._CreateBinaryAddComputation(dtype),
1961          window_dimensions=window_dimensions,
1962          window_strides=window_strides,
1963          base_dilations=[],
1964          window_dilations=[],
1965          padding=padding)
1966      self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]])
1967
1968    @parameterized.named_parameters({
1969        "testcase_name": "_{}".format(dtype.__name__),
1970        "dtype": dtype,
1971    } for dtype in float_dtypes)
1972    def testReduceWindowSameUnitStrides(self, dtype):
1973      if dtype == np.float64 and self.backend.platform == "tpu":
1974        self.skipTest("TPU doesn't support float64")
1975      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
1976      c = self._NewComputation()
1977      window_dimensions = (2, 1)
1978      window_strides = (1, 1)
1979      padding = xla_client.window_padding_type_to_pad_values(
1980          xla_client.PaddingType.SAME, input_array.shape, window_dimensions,
1981          window_strides)
1982      ops.ReduceWindowWithGeneralPadding(
1983          operand=ops.Constant(c, input_array),
1984          init_value=ops.Constant(c, dtype(0)),
1985          computation=self._CreateBinaryAddComputation(dtype),
1986          window_dimensions=window_dimensions,
1987          window_strides=window_strides,
1988          base_dilations=[],
1989          window_dilations=[],
1990          padding=padding)
1991      self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]])
1992
1993    @parameterized.named_parameters({
1994        "testcase_name": "_{}".format(dtype.__name__),
1995        "dtype": dtype,
1996    } for dtype in float_dtypes)
1997    def testReduceWindowValidGeneralStrides(self, dtype):
1998      if dtype == np.float64 and self.backend.platform == "tpu":
1999        self.skipTest("TPU doesn't support float64")
2000      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
2001      c = self._NewComputation()
2002      window_dimensions = (2, 1)
2003      window_strides = (1, 2)
2004      padding = xla_client.window_padding_type_to_pad_values(
2005          xla_client.PaddingType.VALID, input_array.shape, window_dimensions,
2006          window_strides)
2007      ops.ReduceWindowWithGeneralPadding(
2008          operand=ops.Constant(c, input_array),
2009          init_value=ops.Constant(c, dtype(0)),
2010          computation=self._CreateBinaryAddComputation(dtype),
2011          window_dimensions=window_dimensions,
2012          window_strides=window_strides,
2013          base_dilations=[],
2014          window_dilations=[],
2015          padding=padding)
2016      self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]])
2017
2018    def testReduceWindowVariadic(self):
2019      c = self._NewComputation("reducer")
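      # The variadic reducer keeps whichever (key, value) pair has the larger key
      # (the first pair on ties).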
2020      shape = xla_client.shape_from_pyval(np.array(0, dtype=np.int32))
2021      shape = shape.with_major_to_minor_layout_if_absent()
2022      ps = [ops.Parameter(c, i, shape) for i in range(4)]
2023      which = ops.Ge(ps[0], ps[2])
2024      ops.Tuple(
2025          c, [ops.Select(which, ps[0], ps[2]),
2026              ops.Select(which, ps[1], ps[3])])
2027      reducer = c.build()
2028
2029      key_array = np.array([[1, 5, 6], [4, 2, 3]], dtype=np.int32)
2030      val_array = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.int32)
2031      c = self._NewComputation()
2032      window_dimensions = (2, 1)
2033      window_strides = (1, 1)
2034      padding = xla_client.window_padding_type_to_pad_values(
2035          xla_client.PaddingType.VALID, key_array.shape, window_dimensions,
2036          window_strides)
2037      ops.ReduceWindowWithGeneralPadding(
2038          operands=[ops.Constant(c, key_array),
2039                    ops.Constant(c, val_array)],
2040          init_values=[
2041              ops.Constant(c, np.int32(0)),
2042              ops.Constant(c, np.int32(0))
2043          ],
2044          computation=reducer,
2045          window_dimensions=window_dimensions,
2046          window_strides=window_strides,
2047          base_dilations=[],
2048          window_dilations=[],
2049          padding=padding)
2050      self._ExecuteAndCompareClose(c, expected=[[[4, 5, 6]], [[10, 8, 9]]])
2051
2052    @parameterized.named_parameters({
2053        "testcase_name": "_{}".format(dtype.__name__),
2054        "dtype": dtype,
2055    } for dtype in float_dtypes)
2056    def testWhile(self, dtype):
2057
2058      def LessThan10Cond():
2059        c = self._NewComputation("test_lt_10")
2060        shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
2061        ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.)))
2062        return c.build()
2063
2064      cond = LessThan10Cond()
2065      body = self._CreateMulBy2Computation(dtype)
2066      c = self._NewComputation()
2067      init = ops.Constant(c, dtype(1.))
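      # Doubling from 1.0 while the value is below 10 gives 1 -> 2 -> 4 -> 8 -> 16.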
2068      ops.While(cond, body, init)
2069      self._ExecuteAndCompareClose(c, expected=[16.])
2070
2071    def testConditionalTrue(self):
2072      c = self._NewComputation()
2073      pred = ops.Constant(c, np.bool_(True))
2074      true_operand = ops.Constant(c, np.float32(3.))
2075      true_computation = self._CreateMulBy2Computation(np.float32)
2076      false_operand = ops.Constant(c, np.float32(2.))
2077      false_computation = self._CreateConstantComputation(
2078          np.float32, np.float32)
2079      ops.Conditional(pred, true_operand, true_computation, false_operand,
2080                      false_computation)
2081      self._ExecuteAndCompareClose(c, expected=[6.])
2082
2083    def testConditionalFalse(self):
2084      c = self._NewComputation()
2085      pred = ops.Constant(c, np.bool_(False))
2086      true_operand = ops.Constant(c, np.float32(3.))
2087      true_computation = self._CreateMulBy2Computation(np.float32)
2088      false_operand = ops.Constant(c, np.float32(2.))
2089      false_computation = self._CreateConstantComputation(
2090          np.float32, np.float32)
2091      ops.Conditional(pred, true_operand, true_computation, false_operand,
2092                      false_computation)
2093      self._ExecuteAndCompareClose(c, expected=[1.])
2094
2095    @unittest.skipIf(cloud_tpu, "not implemented")
2096    def testInfeedS32Values(self):
2097      to_infeed = NumpyArrayS32([1, 2, 3, 4])
2098      c = self._NewComputation()
2099      ops.GetTupleElement(
2100          ops.InfeedWithToken(
2101              ops.CreateToken(c),
2102              xla_client.shape_from_pyval(
2103                  to_infeed[0]).with_major_to_minor_layout_if_absent()), 0)
2104      compiled_c = self.backend.compile(c.build())
2105      device = self.backend.local_devices()[0]
2106      for item in to_infeed:
2107        device.transfer_to_infeed(item)
2108
2109      for item in to_infeed:
2110        result, = xla_client.execute_with_python_values(
2111            compiled_c, (), backend=self.backend)
2112        self.assertEqual(result, item)
2113
2114    @unittest.skipIf(cloud_tpu, "not implemented")
2115    def testInfeedTuple(self):
2116      to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]]))
2117      c = self._NewComputation()
2118      ops.GetTupleElement(
2119          ops.InfeedWithToken(
2120              ops.CreateToken(c),
2121              xla_client.shape_from_pyval(
2122                  to_infeed).with_major_to_minor_layout_if_absent()), 0)
2123      compiled_c = self.backend.compile(c.build())
2124      device = self.backend.local_devices()[0]
2125      device.transfer_to_infeed(to_infeed)
2126
2127      result = xla_client.execute_with_python_values(
2128          compiled_c, (), backend=self.backend)
2129      self.assertLen(result, 2)
2130      np.testing.assert_equal(result[0], to_infeed[0])
2131      np.testing.assert_equal(result[1], to_infeed[1])
2132
2133    @unittest.skipIf(cloud_tpu, "not implemented")
2134    def testInfeedThenOutfeedS32(self):
2135      to_round_trip = NumpyArrayS32([1, 2, 3, 4])
2136      c = self._NewComputation()
2137      x_and_token = ops.InfeedWithToken(
2138          ops.CreateToken(c),
2139          xla_client.shape_from_pyval(
2140              to_round_trip[0]).with_major_to_minor_layout_if_absent())
2141      x = ops.GetTupleElement(x_and_token, 0)
2142      token = ops.GetTupleElement(x_and_token, 1)
2143      outfeed_shape = xla_client.shape_from_pyval(
2144          to_round_trip[0]).with_major_to_minor_layout_if_absent()
2145      ops.OutfeedWithToken(x, token, outfeed_shape)
2146
2147      compiled_c = self.backend.compile(c.build())
2148      device = self.backend.local_devices()[0]
2149
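      # The executable blocks waiting on infeed, so run it on a separate thread,
      # feed each value, and then read it back from the outfeed.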
2150      for want in to_round_trip:
2151        execution = threading.Thread(target=lambda: compiled_c.execute([]))
2152        execution.start()
2153        device.transfer_to_infeed(want)
2154        got = device.transfer_from_outfeed(outfeed_shape)
2155        execution.join()
2156        self.assertEqual(want, got)
2157
2158    def testScatter(self):
2159      a = np.arange(9).astype(np.int32).reshape((3, 3))
2160      scatter_indices = np.array([0, 2], dtype=np.int32)
2161      updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32)
2162
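      # Scatter dimension numbers: update_window_dims are the dims of `updates`
      # forming the update window, inserted_window_dims are operand dims with no
      # corresponding update dim, and scatter_dims_to_operand_dims maps each
      # scatter index to an operand dimension.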
2163      dnums = xla_client.ScatterDimensionNumbers()
2164      dnums.update_window_dims.append(1)
2165      dnums.inserted_window_dims.append(0)
2166      dnums.scatter_dims_to_operand_dims.append(0)
2167      dnums.index_vector_dim = 1
2168
2169      c = self._NewComputation()
2170      ops.Scatter(
2171          ops.Constant(c, a), ops.Constant(c, scatter_indices),
2172          ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32),
2173          dnums)
2174      expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]],
2175                          dtype=np.int32)
2176      self._ExecuteAndCompareClose(c, expected=[expected])
2177
2178  class DeviceTest(ComputationTest):
2179
2180    def testPlatform(self):
2181      for device in self.backend.local_devices():
2182        self.assertEqual(device.platform, self.backend.platform)
2183
2184  tests.append(DeviceTest)
2185
2186  class ErrorTest(ComputationTest):
2187
2188    def setUp(self):
2189      super(ErrorTest, self).setUp()
2190      self.f32_scalar_2 = NumpyArrayF32(2.0)
2191      self.s32_scalar_2 = NumpyArrayS32(2)
2192
2193    def testCompileWithWrongElementTypeInLayout(self):
2194      c = self._NewComputation()
2195      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
2196      ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
2197      c.clear_op_metadata()
2198
2199      options = xla_client.CompileOptions()
2200      options.argument_layouts = [
2201          xla_client.Shape.array_shape(np.dtype(np.float32), [])
2202      ]
2203
2204      def TestFun():
2205        return self.backend.compile(c.build(), compile_options=options)
2206
2207      self.assertRaisesRegex(
2208          RuntimeError, r".*Invalid argument shape.*"
2209          r"expected s32\[\], got f32\[\].*", TestFun)
2210
2211    def testInvokeWithWrongElementType(self):
2212      c = self._NewComputation()
2213      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
2214      ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
2215      c.clear_op_metadata()
2216
2217      def TestFun():
2218        return xla_client.execute_with_python_values(
2219            self.backend.compile(c.build()), [self.f32_scalar_2], self.backend)
2220
2221      self.assertRaisesRegex(
2222          RuntimeError, r"Invalid argument: Argument does not match.*"
2223          r"want s32\[\], got f32\[\].*", TestFun)
2224
2225  tests.append(EmbeddedComputationsTest)
2226
2227  class ComputationRootTest(ComputationTest):
2228    """Tests related to setting the root of the computation."""
2229
2230    def testComputationRootDifferentFromLastOp(self):
2231      c = self._NewComputation()
2232      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
2233      result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
2234      ops.Add(result, ops.Constant(c, np.float32(1.618)))
2235
2236      arg = NumpyArrayF32(1.0)
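      # Passing `result` to build() makes it the computation root, so the second
      # Add above is not part of the executed computation.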
2237      compiled_c = self.backend.compile(c.build(result))
2238      ans, = xla_client.execute_with_python_values(
2239          compiled_c, [arg], backend=self.backend)
2240      np.testing.assert_allclose(ans, 4.14)
2241
2242  tests.append(ComputationRootTest)
2243
2244  class SetShardingTest(ComputationTest):
2245    """Tests related to set OpSharding."""
2246
2247    def testSetSharding(self):
2248      c = self._NewComputation()
2249      sharding = xla_client.OpSharding()
2250      sharding.type = xla_client.OpSharding.Type.REPLICATED
2251      sharding.tile_assignment_dimensions = [1]
2252      sharding.tile_assignment_devices = [0]
2253      c.set_sharding(sharding)
2254      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
2255      c.clear_sharding()
2256
2257      result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
2258      ops.Add(result, ops.Constant(c, np.float32(1.618)))
2259      arg = NumpyArrayF32(1.0)
2260      compiled_c = self.backend.compile(c.build(result))
2261      ans, = xla_client.execute_with_python_values(
2262          compiled_c, [arg], backend=self.backend)
2263      np.testing.assert_allclose(ans, 4.14)
2264
2265  tests.append(SetShardingTest)
2266
2267  testcase_shapes = [
2268      (),
2269      (1,),
2270      (2, 3),
2271      (2, 0),
2272      (0, 7),
2273      (4, 1, 2),
2274      (2, 1, 3),
2275      (2, 4, 1),
2276      (3, 1),
2277      (1, 3),
2278  ]
2279
2280  def FormatShapeAndDtype(shape, dtype):
2281    return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape)))
2282
2283  class DLPackTest(parameterized.TestCase):
2284
2285    def setUp(self):
2286      super(DLPackTest, self).setUp()
2287      self.backend = xla_backend()
2288      if self.backend.platform not in ("cpu", "gpu"):
2289        self.skipTest("DLPack requires CPU or GPU")
2290      self.cpu_backend = (
2291          self.backend
2292          if self.backend.platform == "cpu" else xla_client.make_cpu_client())
2293      self.gpu_backend = (
2294          self.backend if self.backend.platform == "gpu" else None)
2295
2296    def tearDown(self):
2297      super().tearDown()
2298      del self.backend
2299      del self.cpu_backend
2300      del self.gpu_backend
2301
2302    # pylint: disable=g-complex-comprehension
2303    # pyformat: disable
2304    @parameterized.named_parameters({
2305        "testcase_name": "{}_own={}_gpu={}".format(
2306            FormatShapeAndDtype(shape, dtype), take_ownership, gpu),
2307        "dtype": dtype,
2308        "shape": shape,
2309        "take_ownership": take_ownership,
2310        "gpu": gpu
2311    } for dtype in dlpack_dtypes for shape in testcase_shapes
2312                                    for take_ownership in [False, True]
2313                                    for gpu in [False, True])
2314    # pyformat: enable
2315    def testRoundTrip(self, dtype, shape, take_ownership, gpu):
2316      if gpu and self.gpu_backend is None:
2317        raise unittest.SkipTest("Test not running with GPU support")
2318      backend = self.gpu_backend if gpu else self.cpu_backend
2319      if dtype == np.bool_:
2320        x = np.random.randint(0, 2, size=shape).astype(np.bool_)
2321      else:
2322        x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
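      # take_ownership=True donates the XLA buffer to the DLPack capsule (the
      # buffer is deleted), while False exports a non-owning view of the buffer.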
2323      buffer = backend.buffer_from_pyval(x)
2324      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(
2325          buffer, take_ownership=take_ownership)
2326      del buffer  # Free "buffer" to make sure dlt retains ownership.
2327      self.assertEqual(type(dlt).__name__, "PyCapsule")
2328      y = xla_client._xla.dlpack_managed_tensor_to_buffer(
2329          dlt, self.cpu_backend, self.gpu_backend)
2330      np.testing.assert_array_equal(
2331          x.astype(np.uint8) if dtype == np.bool_ else x, y.to_py())
2332
2333    def testTensorsCanBeConsumedOnceOnly(self):
2334      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
2335      buffer = self.backend.buffer_from_pyval(x)
2336      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(
2337          buffer, take_ownership=True)
2338
2339      def ConsumeDLPackTensor():
2340        _ = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend)
2341
2342      ConsumeDLPackTensor()
2343      self.assertRaisesRegex(
2344          RuntimeError, ".*a DLPack tensor may be consumed at most once.*",
2345          ConsumeDLPackTensor)
2346
2347    def testTensorsCanBeOwnedOnceOnly(self):
2348      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
2349      buffer = self.backend.buffer_from_pyval(x)
2350      _ = xla_client._xla.buffer_to_dlpack_managed_tensor(
2351          buffer, take_ownership=True)
2352      self.assertTrue(buffer.is_deleted())
2353      with self.assertRaisesRegex(
2354          RuntimeError,
2355          "Cannot convert deleted/invalid buffer to DLPack tensor.*"):
2356        _ = xla_client._xla.buffer_to_dlpack_managed_tensor(
2357            buffer, take_ownership=True)
2358
2359    def testNonOwnedDlpackCanBeViewedTwice(self):
2360      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
2361      buffer = self.backend.buffer_from_pyval(x)
2362      d1 = xla_client._xla.buffer_to_dlpack_managed_tensor(
2363          buffer, take_ownership=False)
2364      d2 = xla_client._xla.buffer_to_dlpack_managed_tensor(
2365          buffer, take_ownership=False)
2366
2367      y = xla_client._xla.dlpack_managed_tensor_to_buffer(d1, self.backend)
2368      z = xla_client._xla.dlpack_managed_tensor_to_buffer(d2, self.backend)
2369      del d1, d2
2370      np.testing.assert_array_equal(x, buffer.to_py())
2371      np.testing.assert_array_equal(x, y.to_py())
2372      np.testing.assert_array_equal(x, z.to_py())
2373
2374  tests.append(DLPackTest)
2375
2376  class BufferProtocolTest(parameterized.TestCase):
2377
2378    def setUp(self):
2379      super(BufferProtocolTest, self).setUp()
2380      self.backend = xla_backend()
2381      if self.backend.platform != "cpu":
2382        self.skipTest("Test requires CPU")
2383
2384    # pylint: disable=g-complex-comprehension
2385    @parameterized.named_parameters({
2386        "testcase_name": FormatShapeAndDtype(shape, dtype),
2387        "dtype": dtype,
2388        "shape": shape
2389    } for dtype in standard_dtypes if dtype != bfloat16
2390                                    for shape in testcase_shapes)
2391    def testRoundTrip(self, dtype, shape):
2392      x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
2393      x_ptr = x.__array_interface__["data"][0]
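      # ZERO_COPY lets the CPU backend alias the host numpy buffer instead of
      # copying it, so the device buffer may share x's memory when x is aligned.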
2394      buffer = self.backend.buffer_from_pyval(
2395          x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
2396      y = np.array(buffer, copy=False)
2397      y_ptr = y.__array_interface__["data"][0]
2398      np.testing.assert_array_equal(x, y)
2399      # If the input was sufficiently aligned, the input and output should
2400      # alias.
2401      self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
2402      self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
2403
2404      during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
2405      buffer2 = self.backend.buffer_from_pyval(
2406          x, host_buffer_semantics=during_call)
2407      z = np.array(buffer2, copy=False)
2408      self.assertNotEqual(x.__array_interface__["data"][0],
2409                          z.__array_interface__["data"][0])
2410
2411    def testDeleteWithActiveView(self):
2412      x = np.random.randn(20, 10)
2413      buffer = self.backend.buffer_from_pyval(x)
2414      buffer_ptr = buffer.unsafe_buffer_pointer()
2415      y = np.array(buffer, copy=False)
2416      buffer.delete()
2417      # It is still legal to access `y`; the array view must keep it alive.
2418      np.testing.assert_array_equal(x, y)
2419      self.assertEqual(y.__array_interface__["data"][0], buffer_ptr)
2420
2421  tests.append(BufferProtocolTest)
2422
2423  class TracebackTest(absltest.TestCase):
2424
2425    def setUp(self):
2426      super(TracebackTest, self).setUp()
2427      self.backend = xla_backend()
2428
2429    def testNoTracebacksIfDisabled(self):
2430      with xla_client.tracebacks(enabled=False):
2431        self.assertEqual(None, xla_client.Traceback.get_traceback())
2432        buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
2433        self.assertEqual(None, buffer.traceback)
2434
2435        b = xla_client.XlaBuilder("computation")
2436        ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
2437        e = self.backend.compile(b.build())
2438        self.assertEqual(None, e.traceback)
2439
2440    def assertIsTracebackContaining(self, tb, function):
2441      self.assertIsInstance(tb, xla_client.Traceback)
2442      self.assertIn(function, str(tb))
2443      self.assertTrue(any(f.function_name == function for f in tb.frames))
2444
2445    def testTracebacks(self):
2446      with xla_client.tracebacks(enabled=True):
2447        tb = xla_client.Traceback.get_traceback()
2448        self.assertIsTracebackContaining(tb, "testTracebacks")
2449
2450        # Tracebacks are not implemented on the TPU driver extension's variant
2451        # of buffers and executables.
2452        if not isinstance(self.backend, xla_client.Client):
2453          return
2454
2455        buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
2456        self.assertIsTracebackContaining(buffer.traceback, "testTracebacks")
2457
2458        b = xla_client.XlaBuilder("computation")
2459        ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
2460        e = self.backend.compile(b.build())
2461        self.assertIsTracebackContaining(e.traceback, "testTracebacks")
2462
2463    def testNestedFunction(self):
2464
2465      def AFunction():
2466
2467        def AnotherFunction():
2468          return xla_client.Traceback.get_traceback()
2469
2470        return AnotherFunction()
2471
2472      with xla_client.tracebacks(enabled=True):
2473        tb = AFunction()
2474        self.assertIsInstance(tb, xla_client.Traceback)
2475        frames = tb.frames
2476        i = next(
2477            i for (i, f) in enumerate(frames) if f.function_name == "AFunction")
2478        self.assertEqual(frames[i - 1].function_name, "AnotherFunction")
2479        self.assertEqual(frames[i + 1].function_name, "testNestedFunction")
2480
2481  tests.append(TracebackTest)
2482
2483  class ClientTest(ComputationTest):
2484
2485    def setUp(self):
2486      super(ClientTest, self).setUp()
2487      self.backend = xla_backend()
2488
2489    def testPlatformVersion(self):
2490      version = self.backend.platform_version
2491      logging.info("platform_version:\n%s", version)
2492      if self.backend.platform == "cpu":
2493        self.assertEqual(version, "<unknown>")
2494      elif self.backend.platform == "gpu":
2495        # is_gpu_available(cuda_only=True) is False if not built with --config=cuda.
2496        if test_util.is_gpu_available(cuda_only=True):
2497          self.assertTrue(
2498              re.match(r"^cuda \d{4,}$", version),
2499              msg=f"Expected CUDA version string; got {repr(version)}")
2500        else:
2501          self.assertEqual(version, "<unknown>")
2502      elif self.backend.platform == "tpu" and not cloud_tpu:
2503        self.assertIn("tpu", version.lower())
2504        self.assertIn("cl/", version)
2505
2506    @unittest.skipIf(cloud_tpu or tfrt_tpu, "not implemented")
2507    def testExecutableSerialization(self):
2508      if self.backend.platform != "tpu":
2509        self.skipTest("Test requires tpu platform")
2510
2511      c = self._NewComputation()
2512      ops.Add(
2513          ops.Constant(c, NumpyArrayS32([1, 2])),
2514          ops.Constant(c, NumpyArrayS32([3, 4])))
2515
2516      options = xla_client.CompileOptions()
2517      executable = self.backend.compile(c.build(), options)
2518      self.assertLen(executable.hlo_modules(), 1)
2519
2520      serialized = self.backend.serialize_executable(executable)
2521      deserialized = self.backend.deserialize_executable(
2522          serialized,
2523          executable.hlo_modules()[0], options)
2524
2525      expected, = xla_client.execute_with_python_values(executable, (),
2526                                                        self.backend)
2527      actual, = xla_client.execute_with_python_values(deserialized, (),
2528                                                      self.backend)
2529      self.assertTrue(np.all(actual == expected))
2530
2531  tests.append(ClientTest)
2532
2533  # TODO(b/182461453): Add TFRT and cloud TPU implementation of
2534  # ReadDynamicShapes
2535  class DynamicReshapeTest(ComputationTest):
2536    """Tests related to DynamicReshape."""
2537
2538    def _CompareToPyAndBufferProtocol(self, builder, args, expected_results,
2539                                      test_fn):
2540      compiled = self.backend.compile(builder.build())
2541      output_buffers = compiled.execute([
2542          self.backend.buffer_from_pyval(
2543              arg, device=compiled.local_devices()[0]) for arg in args
2544      ])
2545      self.assertLen(output_buffers, len(expected_results))
2546      for buf, expected in zip(output_buffers, expected_results):
2547        to_py_result = buf.to_py()
2548        self.assertEqual(expected.shape, to_py_result.shape)
2549        test_fn(expected, to_py_result)
2550        if self.backend.platform == "cpu" and buf.dtype != bfloat16:
2551          mview = memoryview(buf)
2552          self.assertEqual(expected.shape, mview.shape)
2553          test_fn(expected, np.asarray(mview))
2554        else:
2555          # The buffer protocol is expected to fail on non-CPU platforms and for
2556          # bfloat16. Note that np.asarray(buf) doesn't throw an exception; to test
2557          # whether the error is raised properly we must use memoryview(buf).
2558          with self.assertRaises(BufferError):
2559            memoryview(buf)
2560
2561    # 1D reshape to the full size, to a smaller size, and to size 0.
2562    @unittest.skipIf(cloud_tpu or tfrt_tpu or external_tpu, "not implemented")
2563    @parameterized.parameters((5), (3), (0))
2564    def testReshape1D(self, reshape_size):
2565      full_size = 5
2566      c = self._NewComputation()
2567      arg = np.array(reshape_size, dtype=np.int32)
2568      expected = np.array(range(reshape_size), dtype=np.int32)
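      # DynamicReshape marks the output dimension as dynamic ([True]); its bound
      # is full_size and its runtime size is taken from the parameter p.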
2569      p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg))
2570      ops.DynamicReshape(
2571          ops.Constant(c, NumpyArrayS32(range(full_size))), [p], [full_size],
2572          [True])
2573      self._CompareToPyAndBufferProtocol(c, [arg], [expected],
2574                                         np.testing.assert_equal)
2575
2576    # 2D reshape with a slice on the minor dimension. We test different types
2577    # whose strides may differ between the host and the device. The reshaped
2578    # physical memory layout is not contiguous, and we test whether the program
2579    # can return the correct logical view of the data.
2580    @unittest.skipIf(cloud_tpu or tfrt_tpu or external_tpu, "not implemented")
2581    @parameterized.named_parameters({
2582        "testcase_name": "_{}".format(dtype.__name__),
2583        "dtype": dtype,
2584    } for dtype in int_dtypes + float_dtypes)
2585    def testReshape2D(self, dtype):
2586      arg0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
2587      arg1 = np.array(2, dtype=np.int32)
2588      expected = np.array([[1, 2], [4, 5]], dtype=np.int32)
2589      c = self._NewComputation()
2590      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
2591      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
2592      ops.DynamicReshape(p0, [p1, p1], [2, 3], [False, True])
2593      self._CompareToPyAndBufferProtocol(c, [arg0, arg1], [expected],
2594                                         np.testing.assert_equal)
2595
2596    @unittest.skipIf(cloud_tpu or tfrt_tpu, "not implemented")
2597    @parameterized.named_parameters({
2598        "testcase_name": "_{}".format(dtype.__name__),
2599        "dtype": dtype,
2600    } for dtype in int_dtypes + float_dtypes)
2601    def testDynamicShapeArgs(self, dtype):
2602      full_size = 10
2603      dynamic_shape_size = 4
2604      # Subcomputation 1: a binary Add, used as the reduction computation.
2605      binary_add_builder = self._NewComputation()
2606      scalar_shape = xla_client.Shape.scalar_shape(np.dtype(dtype))
2607      ops.Add(
2608          ops.Parameter(binary_add_builder, 0, scalar_shape),
2609          ops.Parameter(binary_add_builder, 1, scalar_shape))
2610      # Subcomputation 2: reduce a dynamically-sized 1D operand to a scalar sum.
2611      reshape_reduce_builder = self._NewComputation()
2612      dshape = xla_client.Shape.array_shape(
2613          np.dtype(dtype), dims=[full_size], dynamic_dimensions=[True])
2614      reshape_reduce_p = ops.Parameter(reshape_reduce_builder, 0, dshape)
2615      ops.Reduce(
2616          reshape_reduce_builder,
2617          operands=[reshape_reduce_p],
2618          init_values=[ops.Constant(reshape_reduce_builder, dtype(0))],
2619          computation=binary_add_builder.build(),
2620          dimensions_to_reduce=[0])
2621      # main computation: sum(range(full_size)[:dynamic_shape_size])
2622      c = self._NewComputation()
2623      arg = np.array(dynamic_shape_size, dtype=np.int32)
2624      p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg))
2625      reshaped = ops.DynamicReshape(
2626          ops.Constant(c, np.array(range(full_size), dtype=dtype)), [p],
2627          [full_size], [True])
2628      ops.Call(c, reshape_reduce_builder.build(), operands=(reshaped,))
2629      self._ExecuteAndCompareClose(c, [arg], [dtype(6)])
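    # Worked example for the sizes above: the constant is [0, 1, ..., 9], the
    # dynamic dimension is set to 4 elements at runtime, and the Reduce
    # subcomputation folds them with Add, giving 0 + 1 + 2 + 3 == 6.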
2630
2631  tests.append(DynamicReshapeTest)
2632
2633  class DeviceAssignmentTest(ComputationTest):
2634
2635    def testSerialize(self):
2636      shape = (3, 4)
2637      device_assignment = xla_client.DeviceAssignment.create(
2638          np.arange(np.prod(shape)).reshape(*shape))
2639      self.assertEqual(device_assignment.replica_count(), shape[0])
2640      self.assertEqual(device_assignment.computation_count(), shape[1])
2641      serialized = device_assignment.serialize()
2642      self.assertIsInstance(serialized, bytes)
2643      self.assertNotEmpty(serialized)
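    # Minimal usage sketch of the API exercised above, assuming a 2x2 logical
    # device grid:
    #   da = xla_client.DeviceAssignment.create(np.arange(4).reshape(2, 2))
    #   da.replica_count()      # 2
    #   da.computation_count()  # 2
    #   da.serialize()          # opaque bytes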
2644
2645  tests.append(DeviceAssignmentTest)
2646
2647  class TokenTest(ComputationTest):
2648    """Tests related to PyToken."""
2649
2650    def testExecuteWithToken(self):
2651      c = self._NewComputation()
2652      ops.Mul(
2653          ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)),
2654          ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32)))
2655      compiled_c = self.backend.compile(c.build())
2656      results, token = compiled_c.execute_with_token([])
2657      token.block_until_ready()
2658      self.assertLen(results, 1)
2659      np.testing.assert_allclose(
2660          results[0].to_py(), np.float32([-3, 6.6, 2.4, -2.1]), rtol=3e-3)
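    # Sketch of the token-returning execution path used above:
    #   results, token = compiled_c.execute_with_token(args)
    #   token.block_until_ready()  # wait for the execution tied to the token
    #   host_values = [r.to_py() for r in results]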
2661
2662    def testExecuteShardedOnLocalDevicesWithTokens(self):
2663      c = self._NewComputation()
2664      ops.Mul(
2665          ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)),
2666          ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32)))
2667      num_replicas = 1
2668      options = xla_client.CompileOptions()
2669      options.num_replicas = num_replicas
2670      compiled_c = self.backend.compile(c.build(), compile_options=options)
2671      results, sharded_token = compiled_c.execute_sharded_on_local_devices_with_tokens(
2672          [])
2673      sharded_token.block_until_ready()
2674      self.assertLen(results, 1)
2675      self.assertLen(results[0], 1)
2676      np.testing.assert_allclose(
2677          results[0][0].to_py(), np.float32([-3, 6.6, 2.4, -2.1]), rtol=3e-3)
2678
2679  tests.append(TokenTest)
2680
2681  class HostCallbackTest(ComputationTest):
2682    """Tests related to HostCallback."""
2683
2684    @unittest.skipIf(not tfrt_tpu, "not implemented")
2685    def testHostCallback(self):
2686
2687      c = self._NewComputation()
2688      token = ops.CreateToken(c)
2689
2690      frontend_attributes = xla_client._xla.FrontendAttributes()
2691      frontend_attributes["_xla_host_transfer_rendezvous"] = "undef"
2692      frontend_attributes["_xla_host_transfer_original_type"] = "u32"
2693      frontend_attributes["_xla_host_transfer_is_lower_bits"] = "false"
2694      frontend_attributes["_xla_host_transfer_handler_name"] = "undef"
2695      c.set_frontend_attributes(frontend_attributes)
2696
2697      send_channel_handle = self.backend.create_channel_handle()
2698      send_channel_handle.type = (
2699          xla_client._xla.ChannelHandle_ChannelType.DEVICE_TO_HOST)
2700      send_channel_handle.handle = 1
2701      ops.SendToHost(
2702          ops.Constant(c, np.float32(1.25)),
2703          token,
2704          shape_with_layout=xla_client.Shape.scalar_shape(np.dtype(np.float32)),
2705          handle=send_channel_handle)
2706
2707      recv_channel_handle = self.backend.create_channel_handle()
2708      recv_channel_handle.type = (
2709          xla_client._xla.ChannelHandle_ChannelType.HOST_TO_DEVICE)
2710      recv_channel_handle.handle = 2
2711      data = ops.RecvFromHost(
2712          token,
2713          shape=xla_client.Shape.scalar_shape(np.dtype(np.float32)),
2714          handle=recv_channel_handle)
2715      ops.GetTupleElement(data, 0)
2716
2717      def Identity(x):
2718        return (x,)
2719
2720      host_callback = self.backend.make_python_callback_from_host_send_and_recv(
2721          Identity,
2722          operand_shapes=[xla_client.Shape.scalar_shape(np.dtype(np.float32))],
2723          result_shapes=[xla_client.Shape.scalar_shape(np.dtype(np.float32))],
2724          send_channel_ids=[1],
2725          recv_channel_ids=[2])
2726
2727      compiled_c = self.backend.compile(
2728          c.build(), host_callbacks=[host_callback])
2729      c.clear_frontend_attributes()
2730
2731      results = compiled_c.execute([])
2732      self.assertLen(results, 1)
2733
2734      np.testing.assert_equal(results[0].to_py(), np.float32(1.25))
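    # Rough shape of the host-callback round trip exercised above: the device
    # sends 1.25 to the host over channel 1, the Python callable (Identity
    # here) transforms it, and the result is fed back to the device over
    # channel 2. The callback is attached at compile time via
    #   self.backend.compile(c.build(), host_callbacks=[host_callback])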
2735
2736  tests.append(HostCallbackTest)
2737
2738  class HostCallbackMultiReplicaTest(ComputationTest):
2739    """Tests related to HostCallback for multi-replica execution."""
2740
2741    @unittest.skipIf(not tfrt_tpu, "not implemented")
2742    def testHostCallbackMultiReplica(self):
2743
2744      c = self._NewComputation()
2745      token = ops.CreateToken(c)
2746
2747      frontend_attributes = xla_client._xla.FrontendAttributes()
2748      frontend_attributes["_xla_host_transfer_rendezvous"] = "undef"
2749      frontend_attributes["_xla_host_transfer_original_type"] = "u32"
2750      frontend_attributes["_xla_host_transfer_is_lower_bits"] = "false"
2751      frontend_attributes["_xla_host_transfer_handler_name"] = "undef"
2752      c.set_frontend_attributes(frontend_attributes)
2753
2754      send_channel_handle = self.backend.create_channel_handle()
2755      send_channel_handle.type = (
2756          xla_client._xla.ChannelHandle_ChannelType.DEVICE_TO_HOST)
2757      send_channel_handle.handle = 1
2758      ops.SendToHost(
2759          ops.ReplicaId(c),
2760          token,
2761          shape_with_layout=xla_client.Shape.scalar_shape(np.dtype(np.uint32)),
2762          handle=send_channel_handle)
2763
2764      recv_channel_handle = self.backend.create_channel_handle()
2765      recv_channel_handle.type = (
2766          xla_client._xla.ChannelHandle_ChannelType.HOST_TO_DEVICE)
2767      recv_channel_handle.handle = 2
2768      data = ops.RecvFromHost(
2769          token,
2770          shape=xla_client.Shape.scalar_shape(np.dtype(np.uint32)),
2771          handle=recv_channel_handle)
2772      ops.GetTupleElement(data, 0)
2773
2774      def Identity(x):
2775        return (x,)
2776
2777      host_callback = self.backend.make_python_callback_from_host_send_and_recv(
2778          Identity,
2779          operand_shapes=[xla_client.Shape.scalar_shape(np.dtype(np.uint32))],
2780          result_shapes=[xla_client.Shape.scalar_shape(np.dtype(np.uint32))],
2781          send_channel_ids=[1],
2782          recv_channel_ids=[2])
2783
2784      num_replicas = 2
2785      options = xla_client.CompileOptions()
2786      options.num_replicas = num_replicas
2787      compiled_c = self.backend.compile(
2788          c.build(), compile_options=options, host_callbacks=[host_callback])
2789      c.clear_frontend_attributes()
2790
2791      results = compiled_c.execute_sharded_on_local_devices([])
2792      self.assertLen(results, 1)
2793      self.assertLen(results[0], num_replicas)
2794
2795      for i in range(num_replicas):
2796        np.testing.assert_equal(results[0][i].to_py(), np.uint32(i))
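    # Each replica feeds its own ReplicaId through the host callback, so with
    # num_replicas == 2 the callback runs once per replica and the per-replica
    # outputs are expected to be uint32 0 and 1.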
2797
2798  tests.append(HostCallbackMultiReplicaTest)
2799
2800  class ExecutePortableTest(ComputationTest):
2801
2802    def testExecutePortable(self):
2803      devices_by_kind = collections.defaultdict(list)
2804      for device in self.backend.devices():
2805        devices_by_kind[device.device_kind].append(device)
2806      multi_devices = [d for d in devices_by_kind.values() if len(d) > 1]
2807      if not multi_devices:
2808        raise unittest.SkipTest("Test needs multiple identical devices")
2809      devices = multi_devices[0]
2810
2811      c = self._NewComputation()
2812      args = [
2813          np.array(3, dtype=np.int32),
2814          np.array([10, 15, -2, 7], dtype=np.int32)
2815      ]
2816      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(args[0]))
2817      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(args[1]))
2818      ops.Mul(p0, p1)
2819      options = xla_client.CompileOptions()
2820      options.compile_portable_executable = True
2821      compiled_c = self.backend.compile(c.build(), compile_options=options)
2822      for device in devices:
2823        out, = compiled_c.execute(
2824            [self.backend.buffer_from_pyval(a, device=device) for a in args],
2825            device=device)
2826        np.testing.assert_array_equal(out.to_py(), args[0] * args[1])
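    # With compile_portable_executable=True the executable is not pinned to a
    # single device at compile time, so the same compiled_c can be dispatched
    # to any of the identical devices by passing device=... to execute(), as
    # done above.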
2827
2828  tests.append(ExecutePortableTest)
2829
2830  return tests
2831
2832
2833def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw):
2834  # Avoid creating a new backend per test (this causes GPU OOM, and is probably
2835  # inefficient).
2836  backend_fn = functools.lru_cache(maxsize=None)(backend_fn)
2837  for klass in TestFactory(backend_fn, **kw):
2838    test = type(test_prefix + klass.__name__, (klass,), {})
2839    # Clean up the tests' qualified names so they don't include the test factory.
2840    test.__qualname__ = test.__name__
2841    globals_dict[test.__name__] = test
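# Hypothetical usage sketch: a CPU-only test module could instantiate the
# suite with
#   InstantiateTests(globals(), xla_client.make_cpu_client, test_prefix="Cpu")
# which registers subclasses such as "CpuDynamicReshapeTest" in the module's
# globals so that absltest can discover them.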
2842
2843
2844backends = {
2845    "cpu": xla_client.make_cpu_client,
2846    "gpu": xla_client.make_gpu_client,
2847}
2848
2849if __name__ == "__main__":
2850  flags.DEFINE_string("backend", "cpu", "Target platform.")
2851  # pylint: disable=unnecessary-lambda
2852  InstantiateTests(globals(), lambda: backends[FLAGS.backend]())
2853  # pylint: enable=unnecessary-lambda
2854  absltest.main()
2855