xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/cli/cli_test_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Testing utilities for tfdbg command-line interface."""
16import re
17
18import numpy as np
19
20
def assert_lines_equal_ignoring_whitespace(test, expected_lines, actual_lines):
  """Assert equality in lines, ignoring all whitespace.

  Corresponding lines are compared after removing every whitespace character
  (as split out by `str.split()`), so differences in indentation, inner
  spacing or trailing padding do not cause a failure. The number of lines
  must match exactly.

  Args:
    test: An instance of unittest.TestCase or its subtypes (e.g.,
      TensorFlowTestCase). Only its `assertEqual` method is used.
    expected_lines: Expected lines as an iterable of strings.
    actual_lines: Actual lines as an iterable of strings.
  """
  # Materialize the inputs: the contract allows arbitrary iterables, but
  # one-shot iterables (e.g. generators) support neither len() nor a second
  # pass over their items.
  expected_lines = list(expected_lines)
  actual_lines = list(actual_lines)
  test.assertEqual(
      len(expected_lines), len(actual_lines),
      "Mismatch in the number of lines: %d vs %d" % (
          len(expected_lines), len(actual_lines)))
  for i, (expected_line, actual_line) in enumerate(
      zip(expected_lines, actual_lines)):
    # Report the offending line index and the raw (unstripped) lines so a
    # failure can be located without re-running the comparison by hand.
    test.assertEqual(
        "".join(expected_line.split()), "".join(actual_line.split()),
        "Mismatch at line %d: %r vs %r" % (i, expected_line, actual_line))
37
38
# Regular expression for separators between values in a string representation
# of an ndarray, excluding whitespace: the word "array", parentheses, square
# brackets, "|" characters and commas.
_ARRAY_VALUE_SEPARATOR_REGEX = re.compile(r"(array|\(|\[|\]|\)|\||,)")

# Regular expression for a dtype annotation that numpy's repr may append to an
# array (e.g. "dtype=float32"). It carries no element values, so it is removed
# before the remaining tokens are parsed as floats.
_ARRAY_DTYPE_ANNOTATION_REGEX = re.compile(r"dtype=\w+")


def assert_array_lines_close(test, expected_array, array_lines):
  """Assert that the array value represented by lines is close to expected.

  Note that the shape of the array represented by the `array_lines` is ignored:
  the parsed values are compared against `expected_array` in flattened order.

  Args:
    test: An instance of TensorFlowTestCase (anything providing an
      `assertAllClose(expected, actual)` method).
    expected_array: Expected value of the array.
    array_lines: A list of strings representing the array.
      E.g., "array([[ 1.0, 2.0 ], [ 3.0, 4.0 ]])"
      Assumes that values are separated by commas, parentheses, brackets, "|"
      characters and whitespace. A dtype annotation such as "dtype=float32"
      is ignored.

  Raises:
    ValueError: If a token left after removing separators cannot be parsed by
      `float()`.
  """
  elements = []
  for line in array_lines:
    # Strip dtype annotations first; otherwise a token like "dtype=float32"
    # would survive separator removal and make float() raise.
    stripped = _ARRAY_DTYPE_ANNOTATION_REGEX.sub(" ", line)
    stripped = _ARRAY_VALUE_SEPARATOR_REGEX.sub(" ", stripped)
    elements.extend(float(token) for token in stripped.split())
  test.assertAllClose(np.array(expected_array).flatten(), elements)
62