# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utilities that help manage directory path in distributed settings.
16
17In multi-worker training, the need to write a file to distributed file
18location often requires only one copy done by one worker despite many workers
19that are involved in training. The option to only perform saving by chief is
20not feasible for a couple of reasons: 1) Chief and workers may each contain
21a client that runs the same piece of code and it's preferred not to make
22any distinction between the code run by chief and other workers, and 2)
23saving of model or model's related information may require SyncOnRead
24variables to be read, which needs the cooperation of all workers to perform
25all-reduce.
26
27This set of utility is used so that only one copy is written to the needed
28directory, by supplying a temporary write directory path for workers that don't
29need to save, and removing the temporary directory once file writing is done.
30
31Example usage:
32```
33# Before using a directory to write file to.
34self.log_write_dir = write_dirpath(self.log_dir, get_distribution_strategy())
35# Now `self.log_write_dir` can be safely used to write file to.
36
37...
38
39# After the file is written to the directory.
40remove_temp_dirpath(self.log_dir, get_distribution_strategy())
41
42```
43
44Experimental. API is subject to change.
45"""

import os

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.lib.io import file_io


def _get_base_dirpath(strategy):
  # Each worker gets its own temporary subdirectory, named after its task id.
  task_id = strategy.extended._task_id  # pylint: disable=protected-access
  return 'workertemp_' + str(task_id)


def _is_temp_dir(dirpath, strategy):
  return dirpath.endswith(_get_base_dirpath(strategy))


def _get_temp_dir(dirpath, strategy):
  # Resolve (and create if needed) this worker's temporary directory under
  # `dirpath`. If `dirpath` already is the temporary directory, reuse it.
  if _is_temp_dir(dirpath, strategy):
    temp_dir = dirpath
  else:
    temp_dir = os.path.join(dirpath, _get_base_dirpath(strategy))
  file_io.recursive_create_dir_v2(temp_dir)
  return temp_dir
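
# For illustration (hypothetical values): with `dirpath='/tmp/logs'` and a
# worker whose task id is 1, `_get_temp_dir` creates and returns
# '/tmp/logs/workertemp_1'; called again on that returned path, it recognizes
# the 'workertemp_1' suffix and returns the same directory.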


def write_dirpath(dirpath, strategy):
  """Returns the directory to which files should be written in distributed training.

  `dirpath` will be created if it doesn't exist.

  Args:
    dirpath: Original dirpath that would be used without distribution.
    strategy: The tf.distribute strategy object currently used.

  Returns:
    The directory path to use when saving with distribution.
  """
  if strategy is None:
    # Infer strategy from `distribution_strategy_context` if not given.
    strategy = distribution_strategy_context.get_strategy()
  if strategy is None:
    # If strategy is still not available, this is not in distributed training.
    # Fall back to the original dirpath.
    return dirpath
  if not strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
    return dirpath
  if strategy.extended.should_checkpoint:
    return dirpath
  # If this worker is not the chief and hence should not save the file, save
  # it to a temporary directory to be removed later.
  return _get_temp_dir(dirpath, strategy)
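
# For illustration (hypothetical setup): under a two-worker
# MultiWorkerMirroredStrategy with `dirpath='/tmp/ckpt'`, the chief gets back
# '/tmp/ckpt' unchanged, while the non-chief worker with task id 1 gets
# '/tmp/ckpt/workertemp_1', created on the fly. Outside multi-worker mode the
# call is a pass-through.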


def remove_temp_dirpath(dirpath, strategy):
  """Removes the temp path after writing is finished.

  Args:
    dirpath: Original dirpath that would be used without distribution.
    strategy: The tf.distribute strategy object currently used.
  """
  if strategy is None:
    # Infer strategy from `distribution_strategy_context` if not given.
    strategy = distribution_strategy_context.get_strategy()
  if strategy is None:
    # If strategy is still not available, this is not in distributed training.
    # Fall back to a no-op.
    return
  # TODO(anjalisridhar): Consider removing the check for multi worker mode since
  # it is redundant when used with the should_checkpoint property.
  if (strategy.extended._in_multi_worker_mode() and  # pylint: disable=protected-access
      not strategy.extended.should_checkpoint):
    # If this worker is not the chief and hence should not save the file,
    # remove the temporary directory.
    file_io.delete_recursively(_get_temp_dir(dirpath, strategy))
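
# For illustration: on the chief this call is a no-op, since the chief saved
# directly to `dirpath`; a non-chief worker deletes its own
# `<dirpath>/workertemp_<task_id>` directory (values hypothetical).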


def write_filepath(filepath, strategy):
  """Returns the file path to which a file should be written in distributed training.

  The directory containing `filepath` will be created if it doesn't exist.

  Args:
    filepath: Original filepath that would be used without distribution.
    strategy: The tf.distribute strategy object currently used.

  Returns:
    The file path to use when saving with distribution.
  """
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  return os.path.join(write_dirpath(dirpath, strategy), base)
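
# For illustration (hypothetical values): `write_filepath('/tmp/logs/out.txt',
# strategy)` returns '/tmp/logs/out.txt' on the chief and
# '/tmp/logs/workertemp_1/out.txt' on the worker with task id 1.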


def remove_temp_dir_with_filepath(filepath, strategy):
  """Removes the temp path for a file after writing is finished.

  Args:
    filepath: Original filepath that would be used without distribution.
    strategy: The tf.distribute strategy object currently used.
  """
  remove_temp_dirpath(os.path.dirname(filepath), strategy)
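

# The sketch below is illustrative only and is never called by this module.
# It shows the intended write-then-cleanup pairing of the helpers above:
# every worker writes to the path returned by `write_filepath`, and non-chief
# workers then remove their temporary directories. The function name, the
# `log_dir` argument, and the file contents are hypothetical.
def _example_save_summary_distributedly(log_dir):
  """Illustrative sketch of the write-then-cleanup call pattern."""
  strategy = distribution_strategy_context.get_strategy()
  filepath = os.path.join(log_dir, 'summary.txt')
  # The chief writes directly to `filepath`; other workers write to
  # `<log_dir>/workertemp_<task_id>/summary.txt`.
  write_path = write_filepath(filepath, strategy)
  file_io.write_string_to_file(write_path, 'example contents')
  # Non-chief workers now delete their temporary directories; the chief's
  # copy under `log_dir` stays in place.
  remove_temp_dir_with_filepath(filepath, strategy)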