# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CSR sparse matrix tests: gradients of CSR-sparse x dense matrix products."""

import itertools

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_grad  # pylint: disable=unused-import
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def _add_test(test, op_name, testcase_name, fn):  # pylint: disable=redefined-outer-name
  """Attaches `fn` to `test` as the method `test_<op_name>_<testcase_name>`.

  Args:
    test: The test class to modify.
    op_name: Middle component of the generated method name.
    testcase_name: Final component of the generated method name.
    fn: The test method to attach, or None, in which case nothing is attached.

  Raises:
    RuntimeError: If `test` already has an attribute with the generated name.
  """
  if fn is None:
    return
  test_name = "_".join(["test", op_name, testcase_name])
  if hasattr(test, test_name):
    raise RuntimeError("Test %s defined more than once" % test_name)
  setattr(test, test_name, fn)


class CSRSparseMatrixDenseMatMulGradTest(test.TestCase):
  """Gradient checks for `sparse_matrix_mat_mul(csr_a, dense_b)`.

  No `test_*` methods are written out in the class body; they are attached
  after the class definition via `_add_test`, each one calling
  `_testLargeBatchSparseMatrixMatMulGrad` with a particular flag combination.
  """

  @classmethod
  def setUpClass(cls):
    super(CSRSparseMatrixDenseMatMulGradTest, cls).setUpClass()
    # NOTE(review): not read by any method of this class; confirm nothing
    # else depends on it before removing.
    cls._gpu_available = test_util.is_gpu_available()

  # TODO(penporn): Make these tests runnable on eager mode.
  # (tf.gradients and gradient_checker only run in graph mode.)
  @test_util.run_deprecated_v1
  def _testLargeBatchSparseMatrixMatMulGrad(
      self,
      datatype,
      transpose_a,
      transpose_b,
      adjoint_a,
      adjoint_b,
      transpose_output,
      conjugate_output,
      batched_inputs,
  ):
    """Compares theoretical and numerical gradients of CSR x dense mat_mul.

    Gradients of the product are checked with respect to both the dense
    tensor `a` is built from and the dense operand `b`.

    Args:
      datatype: numpy dtype of both operands (np.float32 or np.complex64 in
        this file). Non-zero imaginary parts are only produced for complex
        dtypes.
      transpose_a: Forwarded to `sparse_matrix_mat_mul`.
      transpose_b: Forwarded to `sparse_matrix_mat_mul`.
      adjoint_a: Forwarded to `sparse_matrix_mat_mul`.
      adjoint_b: Forwarded to `sparse_matrix_mat_mul`.
      transpose_output: Forwarded to `sparse_matrix_mat_mul`.
      conjugate_output: Forwarded to `sparse_matrix_mat_mul`.
      batched_inputs: If True, both operands get a leading batch dimension of
        size 3; otherwise they are plain matrices.
    """
    if batched_inputs:
      a_shape = (3, 5, 11)
      b_shape = (3, 11, 13)
      transpose = lambda x: np.transpose(x, (0, 2, 1))
    else:
      a_shape = (5, 11)
      b_shape = (11, 13)
      transpose = np.transpose

    is_complex = np.issubdtype(datatype, np.complexfloating)

    def random_values(shape):
      # Both parts are always drawn so the RNG stream consumed is the same for
      # every dtype. For real dtypes the imaginary part is dropped explicitly;
      # a complex->real `astype` yields the same values but emits numpy's
      # ComplexWarning.
      values = np.random.randn(*shape) + 1.j * np.random.randn(*shape)
      return values if is_complex else values.real

    # Zero out entries that are not > 0 so that `a` has structural zeros.
    sparsify = lambda m: m * (m > 0)
    a_mats_val = sparsify(random_values(a_shape)).astype(datatype)
    # When the op is asked to transpose/adjoint an operand, pre-transpose (and
    # pre-conjugate) it here so the effective product is always
    # (..., 5, 11) @ (..., 11, 13).
    if transpose_a or adjoint_a:
      a_mats_val = transpose(a_mats_val)
      if adjoint_a:
        a_mats_val = np.conj(a_mats_val)
    b_mats_val = random_values(b_shape).astype(datatype)
    if transpose_b or adjoint_b:
      b_mats_val = transpose(b_mats_val)
      if adjoint_b:
        b_mats_val = np.conj(b_mats_val)

    # `test_session` is deprecated in favor of `cached_session`.
    with self.cached_session():
      a_mats = ops.convert_to_tensor(a_mats_val, dtype=datatype)
      b_mats = ops.convert_to_tensor(b_mats_val, dtype=datatype)
      # Sparsity pattern of `a`: indices of its non-zero entries.
      locs = array_ops.where(abs(a_mats_val) > 0)
      a_sm = sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(a_mats, locs)
      c_mats = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
          a_sm,
          b_mats,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b,
          transpose_output=transpose_output,
          conjugate_output=conjugate_output)
      operands = ((a_mats, a_mats_val, "a"), (b_mats, b_mats_val, "b"))
      for ten, val, nn in operands:
        tf_logging.info("Testing gradients for %s", nn)
        theoretical, numerical = gradient_checker.compute_gradient(
            ten,
            ten.get_shape().as_list(),
            c_mats,
            c_mats.get_shape().as_list(),
            x_init_value=val,
            delta=1e-3)
        self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=1e-3)


# These tests are refactored from sparse_csr_matrix_grad_test to keep its size
# "medium".
# Element dtypes exercised by the generated tests.
dtypes_to_test = [np.float32, np.complex64]


def create_mat_mul_test_fn(dtype_, t_a_, t_b_, adj_a_, adj_b_, t_out_,
                           conj_out_, batched_):
  """Builds one gradient test method, or returns None for a skipped combo.

  A flag combination is skipped when a transpose flag and the matching
  adjoint flag are both set, or when a float32 operand or output would be
  conjugated.
  """
  transpose_and_adjoint = (t_a_ and adj_a_) or (t_b_ and adj_b_)
  conjugates_real = dtype_ == np.float32 and (adj_a_ or adj_b_ or conj_out_)
  if transpose_and_adjoint or conjugates_real:
    return None

  def test_fn(self):
    self._testLargeBatchSparseMatrixMatMulGrad(dtype_, t_a_, t_b_, adj_a_,
                                               adj_b_, t_out_, conj_out_,
                                               batched_)

  return test_fn


# One candidate test per (dtype, 7 boolean flags); the enumeration order
# matches iterating dtypes outermost and the flags innermost.
for dtype, t_a, t_b, adj_a, adj_b, t_out, conj_out, batched in (
    itertools.product(dtypes_to_test, *(([False, True],) * 7))):
  name = (
      "_testLargeBatchSparseMatrixMatMulGrad_dtype_%s_t_a_%s_t_b_%s_adj_a_%s_"
      "adj_b_%s_t_out_%s_conj_out_%s_batched_%s" %
      (dtype.__name__, t_a, t_b, adj_a, adj_b, t_out, conj_out, batched))
  test_method = create_mat_mul_test_fn(dtype, t_a, t_b, adj_a, adj_b, t_out,
                                       conj_out, batched)
  # `_add_test` attaches nothing when `test_method` is None.
  _add_test(CSRSparseMatrixDenseMatMulGradTest, "CSRSparseMatrixGradTest",
            name, test_method)

if __name__ == "__main__":
  test.main()