1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""CSR sparse matrix tests."""
16
17import itertools
18
19import numpy as np
20
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import gradient_checker
25from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_grad  # pylint: disable=unused-import
26from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
27from tensorflow.python.platform import test
28from tensorflow.python.platform import tf_logging
29
30
31def _add_test(test, op_name, testcase_name, fn):  # pylint: disable=redefined-outer-name
32  if fn is None:
33    return
34  test_name = "_".join(["test", op_name, testcase_name])
35  if hasattr(test, test_name):
36    raise RuntimeError("Test %s defined more than once" % test_name)
37  setattr(test, test_name, fn)
38
39
class CSRSparseMatrixDenseMatMulGradTest(test.TestCase):
  """Gradient checks for `sparse_matrix_mat_mul` with a CSR left operand."""

  @classmethod
  def setUpClass(cls):
    """Records once per class whether a GPU is available."""
    super().setUpClass()
    cls._gpu_available = test_util.is_gpu_available()

  # TODO(penporn): Make these tests runnable on eager mode.
  # (tf.gradients and gradient_checker only run in graph mode.)
  @test_util.run_deprecated_v1
  def _testLargeBatchSparseMatrixMatMulGrad(
      self,
      datatype,
      transpose_a,
      transpose_b,
      adjoint_a,
      adjoint_b,
      transpose_output,
      conjugate_output,
      batched_inputs,
  ):
    """Compares theoretical and numerical gradients of a sparse-dense matmul.

    Random operands are generated with NumPy, the left one is converted to a
    CSR sparse matrix, and `gradient_checker.compute_gradient` is run with
    respect to each dense input in turn.

    Args:
      datatype: NumPy dtype the random values are cast to; also the dtype of
        the tensors built from them.
      transpose_a: Forwarded to `sparse_matrix_mat_mul`. When set, the left
        operand's last two axes are swapped before it is handed to the op.
      transpose_b: Same as `transpose_a`, for the right operand.
      adjoint_a: Forwarded to `sparse_matrix_mat_mul`. When set, the left
        operand is transposed and conjugated before it is handed to the op.
      adjoint_b: Same as `adjoint_a`, for the right operand.
      transpose_output: Forwarded to `sparse_matrix_mat_mul`.
      conjugate_output: Forwarded to `sparse_matrix_mat_mul`.
      batched_inputs: If true, operands have shapes (3, 5, 11) and (3, 11, 13);
        otherwise (5, 11) and (11, 13).
    """
    lhs_shape = (5, 11)
    rhs_shape = (11, 13)
    # `None` makes np.transpose reverse the axes, i.e. a plain 2-D transpose.
    perm = None
    if batched_inputs:
      lhs_shape = (3,) + lhs_shape
      rhs_shape = (3,) + rhs_shape
      # Keep the batch axis in place and swap only the two matrix axes.
      perm = (0, 2, 1)

    def swap_matrix_axes(mat):
      return np.transpose(mat, perm)

    def random_complex(shape):
      real_part = np.random.randn(*shape)
      imag_part = np.random.randn(*shape)
      return real_part + 1.j * imag_part

    # Zero out the entries that fail the `> 0` test so the left operand is
    # sparse, then cast to the requested dtype.
    a_dense = random_complex(lhs_shape)
    a_mats_val = (a_dense * (a_dense > 0)).astype(datatype)
    if adjoint_a:
      a_mats_val = np.conj(swap_matrix_axes(a_mats_val))
    elif transpose_a:
      a_mats_val = swap_matrix_axes(a_mats_val)

    b_mats_val = random_complex(rhs_shape).astype(datatype)
    if adjoint_b:
      b_mats_val = np.conj(swap_matrix_axes(b_mats_val))
    elif transpose_b:
      b_mats_val = swap_matrix_axes(b_mats_val)

    with self.test_session():
      a_mats = ops.convert_to_tensor(a_mats_val, dtype=datatype)
      b_mats = ops.convert_to_tensor(b_mats_val, dtype=datatype)
      # Indices of the nonzero entries define the sparsity pattern of `a`.
      nonzero_locs = array_ops.where(abs(a_mats_val) > 0)
      a_sm = sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(
          a_mats, nonzero_locs)
      c_mats = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
          a_sm,
          b_mats,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b,
          transpose_output=transpose_output,
          conjugate_output=conjugate_output)
      cases = (
          (a_mats, a_mats_val, "a"),
          (b_mats, b_mats_val, "b"),
      )
      for tensor, init_value, label in cases:
        tf_logging.info("Testing gradients for %s", label)
        theoretical, numerical = gradient_checker.compute_gradient(
            tensor,
            tensor.get_shape().as_list(),
            c_mats,
            c_mats.get_shape().as_list(),
            x_init_value=init_value,
            delta=1e-3)
        self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=1e-3)
109
110
111# These tests are refactored from sparse_csr_matrix_grad_test to keep its size
112# "medium".
dtypes_to_test = [np.float32, np.complex64]


def create_mat_mul_test_fn(dtype_, t_a_, t_b_, adj_a_, adj_b_, t_out_,
                           conj_out_, batched_):
  """Builds a test method for one flag combination, or `None` to skip it.

  Args:
    dtype_: NumPy dtype passed to the gradient test.
    t_a_: Value for `transpose_a`.
    t_b_: Value for `transpose_b`.
    adj_a_: Value for `adjoint_a`.
    adj_b_: Value for `adjoint_b`.
    t_out_: Value for `transpose_output`.
    conj_out_: Value for `conjugate_output`.
    batched_: Value for `batched_inputs`.

  Returns:
    A function taking the test case instance as `self`, or `None` when the
    combination is skipped.
  """
  # Skip invalid cases.
  transpose_and_adjoint = (t_a_ and adj_a_) or (t_b_ and adj_b_)
  if transpose_and_adjoint:
    return None
  # Skip cases where we conjugate real matrices.
  conjugates = adj_a_ or adj_b_ or conj_out_
  if dtype_ == np.float32 and conjugates:
    return None

  def test_fn(self):
    self._testLargeBatchSparseMatrixMatMulGrad(
        dtype_, t_a_, t_b_, adj_a_, adj_b_, t_out_, conj_out_, batched_)

  return test_fn


# Register one test per dtype and per combination of the seven boolean flags.
for dtype in dtypes_to_test:
  for (t_a, t_b, adj_a, adj_b, t_out, conj_out,
       batched) in itertools.product([False, True], repeat=7):
    name = (
        "_testLargeBatchSparseMatrixMatMulGrad_dtype_%s_t_a_%s_t_b_%s_adj_a_%s_"
        "adj_b_%s_t_out_%s_conj_out_%s_batched_%s" %
        (dtype.__name__, t_a, t_b, adj_a, adj_b, t_out, conj_out, batched))

    _add_test(
        CSRSparseMatrixDenseMatMulGradTest,
        "CSRSparseMatrixGradTest",
        name,
        create_mat_mul_test_fn(dtype, t_a, t_b, adj_a, adj_b, t_out, conj_out,
                               batched))
143
# Invoke `test.main()` only when this file is executed directly, not when it
# is imported.
if __name__ == "__main__":
  test.main()
146