xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/session_debug_multi_gpu_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for debugger functionalities under multiple (i.e., >1) GPUs."""
16import os
17import tempfile
18
19from tensorflow.core.protobuf import config_pb2
20from tensorflow.python.client import device_lib
21from tensorflow.python.client import session
22from tensorflow.python.debug.lib import debug_data
23from tensorflow.python.debug.lib import debug_utils
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import test_util
27from tensorflow.python.lib.io import file_io
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import variables
30from tensorflow.python.platform import googletest
31
32
class SessionDebugMultiGPUTest(test_util.TensorFlowTestCase):
  """Tests debugger file dumps for a graph whose ops are placed on two GPUs.

  Every test gets its own temporary dump directory, created in `setUp` and
  removed again in `tearDown`.
  """

  def setUp(self):
    """Creates a fresh temporary directory for the debugger to dump into."""
    # Chain to the base class so that its own per-test setup still runs.
    super().setUp()
    self._dump_root = tempfile.mkdtemp()

  def tearDown(self):
    """Resets the default graph and deletes the temporary dump directory."""
    ops.reset_default_graph()

    # Tear down temporary dump directory.
    if os.path.isdir(self._dump_root):
      file_io.delete_recursively(self._dump_root)

    # Chain to the base class so that its own per-test cleanup still runs.
    super().tearDown()

  def testMultiGPUSessionRun(self):
    """Runs a watched graph across two GPUs and checks the dumped tensors.

    The graph computes `w = v * v - (v + v)`, with the addition pinned to the
    first GPU (in sorted device-name order) and the multiplication pinned to
    the second. The test is skipped when fewer than two GPUs are listed among
    the local devices.
    """
    gpu_device_names = sorted(
        device.name
        for device in device_lib.list_local_devices()
        if device.device_type == "GPU")

    if len(gpu_device_names) < 2:
      self.skipTest(
          "This test requires at least 2 GPUs, but only %d is available." %
          len(gpu_device_names))

    with session.Session() as sess:
      # Only u0 and u1 are given an explicit device; v and w are created
      # outside of any device scope.
      v = variables.Variable([10.0, 15.0], dtype=dtypes.float32, name="v")
      with ops.device(gpu_device_names[0]):
        u0 = math_ops.add(v, v, name="u0")
      with ops.device(gpu_device_names[1]):
        u1 = math_ops.multiply(v, v, name="u1")
      w = math_ops.subtract(u1, u0, name="w")

      self.evaluate(v.initializer)

      # Request the partition graphs in the run metadata; they are passed to
      # DebugDumpDir below when the dump is loaded.
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(run_options, sess.graph,
                              debug_urls="file://" + self._dump_root)
      run_metadata = config_pb2.RunMetadata()
      # w = v * v - (v + v) = [100 - 20, 225 - 30].
      self.assertAllClose(
          [80.0, 195.0],
          sess.run(w, options=run_options, run_metadata=run_metadata))

      debug_dump_dir = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=run_metadata.partition_graphs)
      # The dump is expected to cover exactly three devices.
      self.assertEqual(3, len(debug_dump_dir.devices()))
      self.assertAllClose(
          [10.0, 15.0], debug_dump_dir.get_tensors("v", 0, "DebugIdentity")[0])
      self.assertAllClose(
          [20.0, 30.0], debug_dump_dir.get_tensors("u0", 0, "DebugIdentity")[0])
      self.assertAllClose(
          [100.0, 225.0],
          debug_dump_dir.get_tensors("u1", 0, "DebugIdentity")[0])
86
87
# Invoke googletest.main() when this file is executed directly as a script.
if __name__ == "__main__":
  googletest.main()
90