# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for operators with > 3 or arbitrary numbers of arguments."""

import unittest

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest


class NAryOpsTest(xla_test.XLATestCase):
  """Tests ops that take a variable number of inputs, run in `test_scope()`."""

  def _testNAry(self, op, args, expected, equality_fn=None):
    """Feeds `args` through placeholders into `op` and checks the result.

    Args:
      op: Callable taking the list of placeholders and returning the
        tensor(s) to evaluate.
      args: List of numpy arrays; one placeholder is created per array with
        the array's dtype and shape.
      expected: Expected value(s) of `op`'s output.
      equality_fn: Comparison called as
        `equality_fn(result, expected, rtol=1e-3)`. Defaults to
        `self.assertAllClose`.
    """
    if equality_fn is None:
      equality_fn = self.assertAllClose
    with self.session() as session:
      with self.test_scope():
        placeholders = [
            array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
            for arg in args
        ]
        # Pair each placeholder with the array it was built from.
        feeds = dict(zip(placeholders, args))
        output = op(placeholders)
      result = session.run(output, feeds)
      equality_fn(result, expected, rtol=1e-3)

  def _nAryListCheck(self, results, expected, **kwargs):
    """Asserts both sequences have equal length and pairwise-close elements.

    Args:
      results: Sequence of computed values.
      expected: Sequence of expected values, same length as `results`.
      **kwargs: Forwarded to `assertAllClose` (e.g. `rtol`).
    """
    self.assertEqual(len(results), len(expected))
    for result, expected_value in zip(results, expected):
      self.assertAllClose(result, expected_value, **kwargs)

  def _testNAryLists(self, op, args, expected):
    """Like `_testNAry`, but for ops whose output is a list of tensors."""
    self._testNAry(op, args, expected, equality_fn=self._nAryListCheck)

  def testFloat(self):
    """Checks `add_n` on float32 inputs with one, two and three operands."""
    self._testNAry(math_ops.add_n,
                   [np.array([[1, 2, 3]], dtype=np.float32)],
                   expected=np.array([[1, 2, 3]], dtype=np.float32))

    self._testNAry(math_ops.add_n,
                   [np.array([1, 2], dtype=np.float32),
                    np.array([10, 20], dtype=np.float32)],
                   expected=np.array([11, 22], dtype=np.float32))
    self._testNAry(math_ops.add_n,
                   [np.array([-4], dtype=np.float32),
                    np.array([10], dtype=np.float32),
                    np.array([42], dtype=np.float32)],
                   expected=np.array([48], dtype=np.float32))

  def testComplex(self):
    """Checks `add_n` for every dtype in `self.complex_types`."""
    for dtype in self.complex_types:
      self._testNAry(
          math_ops.add_n, [np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)],
          expected=np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype))

      self._testNAry(
          math_ops.add_n, [
              np.array([1 + 2j, 2 - 3j], dtype=dtype),
              np.array([10j, 20], dtype=dtype)
          ],
          expected=np.array([1 + 12j, 22 - 3j], dtype=dtype))
      self._testNAry(
          math_ops.add_n, [
              np.array([-4, 5j], dtype=dtype),
              np.array([2 + 10j, -2], dtype=dtype),
              np.array([42j, 3 + 3j], dtype=dtype)
          ],
          expected=np.array([-2 + 52j, 1 + 8j], dtype=dtype))

  @unittest.skip("IdentityN is temporarily CompilationOnly as workaround")
  def testIdentityN(self):
    """Checks `identity_n` returns its inputs unchanged, incl. mixed dtypes."""
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1, 2, 3]], dtype=np.float32)],
                        expected=[np.array([[1, 2, 3]], dtype=np.float32)])
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1, 2], [3, 4]], dtype=np.float32),
                         np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)],
                        expected=[
                            np.array([[1, 2], [3, 4]], dtype=np.float32),
                            np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)])
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1], [2], [3], [4]], dtype=np.int32),
                         np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)],
                        expected=[
                            np.array([[1], [2], [3], [4]], dtype=np.int32),
                            np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)])

  def testConcat(self):
    """Checks `concat` of two 2x3 float32 matrices along axis 0 and axis 1."""
    self._testNAry(
        lambda x: array_ops.concat(x, 0), [
            np.array(
                [[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array(
                    [[7, 8, 9], [10, 11, 12]], dtype=np.float32)
        ],
        expected=np.array(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.float32))

    self._testNAry(
        lambda x: array_ops.concat(x, 1), [
            np.array(
                [[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array(
                    [[7, 8, 9], [10, 11, 12]], dtype=np.float32)
        ],
        expected=np.array(
            [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32))

  def testOneHot(self):
    """Checks `one_hot` with the default axis and with `axis=1`."""
    with self.session() as session, self.test_scope():
      indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32))
      op = array_ops.one_hot(indices,
                             np.int32(4),
                             on_value=np.float32(7), off_value=np.float32(3))
      output = session.run(op)
      expected = np.array([[[3, 3, 7, 3], [3, 3, 3, 7]],
                           [[7, 3, 3, 3], [3, 7, 3, 3]]],
                          dtype=np.float32)
      self.assertAllEqual(output, expected)

      op = array_ops.one_hot(indices,
                             np.int32(4),
                             on_value=np.int32(2), off_value=np.int32(1),
                             axis=1)
      output = session.run(op)
      expected = np.array([[[1, 1], [1, 1], [2, 1], [1, 2]],
                           [[2, 1], [1, 2], [1, 1], [1, 1]]],
                          dtype=np.int32)
      self.assertAllEqual(output, expected)

  def testSplitV(self):
    """Checks `split` with explicit sizes [2, 2] along axis 1."""
    with self.session() as session:
      with self.test_scope():
        output = session.run(
            array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],
                                     dtype=np.float32),
                            [2, 2], 1))
        expected = [np.array([[1, 2], [5, 6], [9, 0]], dtype=np.float32),
                    np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)]
        self.assertAllEqual(output, expected)

  def testSplitVNegativeSizes(self):
    """Checks that `split` rejects negative split sizes."""
    with self.session() as session:
      with self.test_scope():
        # Either exception type is accepted; presumably this depends on
        # whether the size check fires at graph construction or at run time.
        # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which
        # was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(
            (ValueError, errors.InvalidArgumentError),
            "Split size at index 1 must be >= .*. Got: -2"):
          _ = session.run(
              array_ops.split(np.array([1, 2, 3], dtype=np.float32), [-1, -2],
                              axis=0))

  def testStridedSlice(self):
    """Checks `strided_slice` and Python slicing, incl. empty and `newaxis`."""
    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[], [], []], dtype=np.float32),
                    np.array([1, 0], dtype=np.int32),
                    np.array([3, 0], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32)],
                   expected=np.array([[], []], dtype=np.float32))

    # Same slice with int64 begin/end/strides, only when the backend under
    # test lists int64 among its integer types.
    if np.int64 in self.int_types:
      self._testNAry(
          lambda x: array_ops.strided_slice(*x), [
              np.array([[], [], []], dtype=np.float32), np.array(
                  [1, 0], dtype=np.int64), np.array([3, 0], dtype=np.int64),
              np.array([1, 1], dtype=np.int64)
          ],
          expected=np.array([[], []], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32)],
                   expected=np.array([[5, 6], [8, 9]], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32),
                    np.array([0, 2], dtype=np.int32),
                    np.array([2, 0], dtype=np.int32),
                    np.array([1, -1], dtype=np.int32)],
                   expected=np.array([[3, 2], [6, 5]], dtype=np.float32))

    self._testNAry(lambda x: x[0][0:2, array_ops.newaxis, ::-1],
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32)],
                   expected=np.array([[[3, 2, 1]], [[6, 5, 4]]],
                                     dtype=np.float32))

    self._testNAry(lambda x: x[0][1, :, array_ops.newaxis],
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32)],
                   expected=np.array([[4], [5], [6]], dtype=np.float32))

  def testStridedSliceGrad(self):
    """Checks `strided_slice_grad` places gradients into a zero-filled input."""
    # Tests cases where input shape is empty.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.float32(0.5)],
                   expected=np.array(np.float32(0.5), dtype=np.float32))

    # Tests case where input shape is non-empty, but gradients are empty.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3], dtype=np.int32),
                    np.array([0], dtype=np.int32),
                    np.array([0], dtype=np.int32),
                    np.array([1], dtype=np.int32),
                    np.array([], dtype=np.float32)],
                   expected=np.array([0, 0, 0], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 0], dtype=np.int32),
                    np.array([1, 0], dtype=np.int32),
                    np.array([3, 0], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([[], []], dtype=np.float32)],
                   expected=np.array([[], [], []], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([[5, 6], [8, 9]], dtype=np.float32)],
                   expected=np.array([[0, 0, 0], [0, 5, 6], [0, 8, 9]],
                                     dtype=np.float32))

    def ssg_test(x):
      return array_ops.strided_slice_grad(*x, shrink_axis_mask=0x4,
                                          new_axis_mask=0x1)

    self._testNAry(ssg_test,
                   [np.array([3, 1, 3], dtype=np.int32),
                    np.array([0, 0, 0, 2], dtype=np.int32),
                    np.array([0, 3, 1, -4], dtype=np.int32),
                    np.array([1, 2, 1, -3], dtype=np.int32),
                    np.array([[[1], [2]]], dtype=np.float32)],
                   expected=np.array([[[0, 0, 1]], [[0, 0, 0]], [[0, 0, 2]]],
                                     dtype=np.float32))

    def ssg_test2(x):
      return array_ops.strided_slice_grad(*x, new_axis_mask=0x15)

    self._testNAry(ssg_test2,
                   [np.array([4, 4], dtype=np.int32),
                    np.array([0, 0, 0, 1, 0], dtype=np.int32),
                    np.array([0, 3, 0, 4, 0], dtype=np.int32),
                    np.array([1, 2, 1, 2, 1], dtype=np.int32),
                    np.array([[[[[1], [2]]], [[[3], [4]]]]], dtype=np.float32)],
                   expected=np.array([[0, 1, 0, 2], [0, 0, 0, 0], [0, 3, 0, 4],
                                      [0, 0, 0, 0]], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([0, 2], dtype=np.int32),
                    np.array([2, 0], dtype=np.int32),
                    np.array([1, -1], dtype=np.int32),
                    np.array([[1, 2], [3, 4]], dtype=np.float32)],
                   expected=np.array([[0, 2, 1], [0, 4, 3], [0, 0, 0]],
                                     dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([2, 2], dtype=np.int32),
                    np.array([0, 1], dtype=np.int32),
                    np.array([-1, -2], dtype=np.int32),
                    np.array([[1], [2]], dtype=np.float32)],
                   expected=np.array([[0, 0, 0], [0, 0, 2], [0, 0, 1]],
                                     dtype=np.float32))


if __name__ == "__main__":
  googletest.main()