xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tests/nary_ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for operators with more than three or a variable number of arguments."""
16
17import unittest
18
19import numpy as np
20
21from tensorflow.compiler.tests import xla_test
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import errors
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.platform import googletest
27
28
class NAryOpsTest(xla_test.XLATestCase):
  """Tests ops taking many (or a variable number of) inputs under test_scope."""

  def _testNAry(self, op, args, expected, equality_fn=None):
    """Runs `op` on placeholders fed with `args` and checks the result.

    Args:
      op: Callable that receives a list of placeholder tensors (one per entry
        in `args`, in order) and returns the tensor(s) to evaluate.
      args: Sequence of numpy arrays/scalars. Each one gets a placeholder with
        the same dtype and shape and is fed as that placeholder's value.
      expected: Expected value of the evaluated output.
      equality_fn: Optional comparison callable, invoked as
        `equality_fn(result, expected, rtol=1e-3)`. Defaults to
        `self.assertAllClose`.
    """
    with self.session() as session:
      with self.test_scope():
        placeholders = [
            array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
            for arg in args
        ]
        # Pair each placeholder with the numpy value it was created from.
        feeds = dict(zip(placeholders, args))
        output = op(placeholders)
      result = session.run(output, feeds)
      if equality_fn is None:
        equality_fn = self.assertAllClose
      equality_fn(result, expected, rtol=1e-3)

  def _nAryListCheck(self, results, expected, **kwargs):
    """Asserts two lists have equal length and are element-wise close.

    Extra keyword arguments (e.g. `rtol`) are forwarded to `assertAllClose`.
    """
    self.assertEqual(len(results), len(expected))
    for (r, e) in zip(results, expected):
      self.assertAllClose(r, e, **kwargs)

  def _testNAryLists(self, op, args, expected):
    """Like `_testNAry`, but for ops whose output is a list of tensors."""
    self._testNAry(op, args, expected, equality_fn=self._nAryListCheck)

  def testFloat(self):
    """add_n over one, two and three float32 inputs."""
    self._testNAry(math_ops.add_n,
                   [np.array([[1, 2, 3]], dtype=np.float32)],
                   expected=np.array([[1, 2, 3]], dtype=np.float32))

    self._testNAry(math_ops.add_n,
                   [np.array([1, 2], dtype=np.float32),
                    np.array([10, 20], dtype=np.float32)],
                   expected=np.array([11, 22], dtype=np.float32))
    self._testNAry(math_ops.add_n,
                   [np.array([-4], dtype=np.float32),
                    np.array([10], dtype=np.float32),
                    np.array([42], dtype=np.float32)],
                   expected=np.array([48], dtype=np.float32))

  def testComplex(self):
    """add_n over one, two and three inputs for every complex dtype."""
    for dtype in self.complex_types:
      self._testNAry(
          math_ops.add_n, [np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)],
          expected=np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype))

      self._testNAry(
          math_ops.add_n, [
              np.array([1 + 2j, 2 - 3j], dtype=dtype),
              np.array([10j, 20], dtype=dtype)
          ],
          expected=np.array([1 + 12j, 22 - 3j], dtype=dtype))
      self._testNAry(
          math_ops.add_n, [
              np.array([-4, 5j], dtype=dtype),
              np.array([2 + 10j, -2], dtype=dtype),
              np.array([42j, 3 + 3j], dtype=dtype)
          ],
          expected=np.array([-2 + 52j, 1 + 8j], dtype=dtype))

  @unittest.skip("IdentityN is temporarily CompilationOnly as workaround")
  def testIdentityN(self):
    """identity_n returns its inputs unchanged, including mixed dtypes."""
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1, 2, 3]], dtype=np.float32)],
                        expected=[np.array([[1, 2, 3]], dtype=np.float32)])
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1, 2], [3, 4]], dtype=np.float32),
                         np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)],
                        expected=[
                            np.array([[1, 2], [3, 4]], dtype=np.float32),
                            np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)])
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1], [2], [3], [4]], dtype=np.int32),
                         np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)],
                        expected=[
                            np.array([[1], [2], [3], [4]], dtype=np.int32),
                            np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)])

  def testConcat(self):
    """concat of two 2x3 float32 matrices along axis 0 and axis 1."""
    self._testNAry(
        lambda x: array_ops.concat(x, 0), [
            np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
            np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)
        ],
        expected=np.array(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.float32))

    self._testNAry(
        lambda x: array_ops.concat(x, 1), [
            np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
            np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)
        ],
        expected=np.array(
            [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32))

  def testOneHot(self):
    """one_hot with float32 and int32 on/off values, default axis and axis=1."""
    with self.session() as session, self.test_scope():
      indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32))
      op = array_ops.one_hot(indices,
                             np.int32(4),
                             on_value=np.float32(7), off_value=np.float32(3))
      output = session.run(op)
      expected = np.array([[[3, 3, 7, 3], [3, 3, 3, 7]],
                           [[7, 3, 3, 3], [3, 7, 3, 3]]],
                          dtype=np.float32)
      self.assertAllEqual(output, expected)

      op = array_ops.one_hot(indices,
                             np.int32(4),
                             on_value=np.int32(2), off_value=np.int32(1),
                             axis=1)
      output = session.run(op)
      expected = np.array([[[1, 1], [1, 1], [2, 1], [1, 2]],
                           [[2, 1], [1, 2], [1, 1], [1, 1]]],
                          dtype=np.int32)
      self.assertAllEqual(output, expected)

  def testSplitV(self):
    """split of a 3x4 matrix into sizes [2, 2] along axis 1."""
    with self.session() as session:
      with self.test_scope():
        output = session.run(
            array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],
                                     dtype=np.float32),
                            [2, 2], 1))
        expected = [np.array([[1, 2], [5, 6], [9, 0]], dtype=np.float32),
                    np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)]
        self.assertAllEqual(output, expected)

  def testSplitVNegativeSizes(self):
    """split rejects the negative size -2 given at index 1."""
    with self.session() as session:
      with self.test_scope():
        # assertRaisesRegexp is a deprecated alias that was removed from
        # unittest in Python 3.12; assertRaisesRegex is the supported name.
        with self.assertRaisesRegex(
            (ValueError, errors.InvalidArgumentError),
            "Split size at index 1 must be >= .*. Got: -2"):
          _ = session.run(
              array_ops.split(np.array([1, 2, 3], dtype=np.float32), [-1, -2],
                              axis=0))

  def testStridedSlice(self):
    """strided_slice with empty inputs, int64 indices, negative strides and
    Python slicing syntax including newaxis."""
    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[], [], []], dtype=np.float32),
                    np.array([1, 0], dtype=np.int32),
                    np.array([3, 0], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32)],
                   expected=np.array([[], []], dtype=np.float32))

    # int64 begin/end/strides are only exercised when the backend supports them.
    if np.int64 in self.int_types:
      self._testNAry(
          lambda x: array_ops.strided_slice(*x), [
              np.array([[], [], []], dtype=np.float32),
              np.array([1, 0], dtype=np.int64),
              np.array([3, 0], dtype=np.int64),
              np.array([1, 1], dtype=np.int64)
          ],
          expected=np.array([[], []], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32)],
                   expected=np.array([[5, 6], [8, 9]], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32),
                    np.array([0, 2], dtype=np.int32),
                    np.array([2, 0], dtype=np.int32),
                    np.array([1, -1], dtype=np.int32)],
                   expected=np.array([[3, 2], [6, 5]], dtype=np.float32))

    self._testNAry(lambda x: x[0][0:2, array_ops.newaxis, ::-1],
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32)],
                   expected=np.array([[[3, 2, 1]], [[6, 5, 4]]],
                                     dtype=np.float32))

    self._testNAry(lambda x: x[0][1, :, array_ops.newaxis],
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32)],
                   expected=np.array([[4], [5], [6]], dtype=np.float32))

  def testStridedSliceGrad(self):
    """strided_slice_grad for empty shapes, empty gradients, axis masks and
    negative strides."""
    # Tests cases where input shape is empty.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.float32(0.5)],
                   expected=np.array(np.float32(0.5), dtype=np.float32))

    # Tests case where input shape is non-empty, but gradients are empty.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3], dtype=np.int32),
                    np.array([0], dtype=np.int32),
                    np.array([0], dtype=np.int32),
                    np.array([1], dtype=np.int32),
                    np.array([], dtype=np.float32)],
                   expected=np.array([0, 0, 0], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 0], dtype=np.int32),
                    np.array([1, 0], dtype=np.int32),
                    np.array([3, 0], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([[], []], dtype=np.float32)],
                   expected=np.array([[], [], []], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([[5, 6], [8, 9]], dtype=np.float32)],
                   expected=np.array([[0, 0, 0], [0, 5, 6], [0, 8, 9]],
                                     dtype=np.float32))

    def ssg_test(x):
      """strided_slice_grad with shrink_axis_mask=0x4 and new_axis_mask=0x1."""
      return array_ops.strided_slice_grad(*x, shrink_axis_mask=0x4,
                                          new_axis_mask=0x1)

    self._testNAry(ssg_test,
                   [np.array([3, 1, 3], dtype=np.int32),
                    np.array([0, 0, 0, 2], dtype=np.int32),
                    np.array([0, 3, 1, -4], dtype=np.int32),
                    np.array([1, 2, 1, -3], dtype=np.int32),
                    np.array([[[1], [2]]], dtype=np.float32)],
                   expected=np.array([[[0, 0, 1]], [[0, 0, 0]], [[0, 0, 2]]],
                                     dtype=np.float32))

    def ssg_test2(x):
      """strided_slice_grad with new_axis_mask=0x15."""
      return array_ops.strided_slice_grad(*x, new_axis_mask=0x15)

    self._testNAry(ssg_test2,
                   [np.array([4, 4], dtype=np.int32),
                    np.array([0, 0, 0, 1, 0], dtype=np.int32),
                    np.array([0, 3, 0, 4, 0], dtype=np.int32),
                    np.array([1, 2, 1, 2, 1], dtype=np.int32),
                    np.array([[[[[1], [2]]], [[[3], [4]]]]], dtype=np.float32)],
                   expected=np.array([[0, 1, 0, 2], [0, 0, 0, 0], [0, 3, 0, 4],
                                      [0, 0, 0, 0]], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([0, 2], dtype=np.int32),
                    np.array([2, 0], dtype=np.int32),
                    np.array([1, -1], dtype=np.int32),
                    np.array([[1, 2], [3, 4]], dtype=np.float32)],
                   expected=np.array([[0, 2, 1], [0, 4, 3], [0, 0, 0]],
                                     dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([2, 2], dtype=np.int32),
                    np.array([0, 1], dtype=np.int32),
                    np.array([-1, -2], dtype=np.int32),
                    np.array([[1], [2]], dtype=np.float32)],
                   expected=np.array([[0, 0, 0], [0, 0, 2], [0, 0, 1]],
                                     dtype=np.float32))
288
if __name__ == "__main__":
  # Executed only when this file is run directly: hand off to googletest.
  googletest.main()
291