1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/auto_shard_rewriter.h"
16
17 #include <cstdlib>
18 #include <iterator>
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/string_view.h"
28 #include "absl/strings/substitute.h"
29 #include "absl/types/optional.h"
30 #include "tensorflow/core/data/rewrite_utils.h"
31 #include "tensorflow/core/data/service/common.h"
32 #include "tensorflow/core/data/service/common.pb.h"
33 #include "tensorflow/core/data/service/url.h"
34 #include "tensorflow/core/framework/dataset_options.pb.h"
35 #include "tensorflow/core/framework/graph.pb.h"
36 #include "tensorflow/core/framework/node_def.pb.h"
37 #include "tensorflow/core/framework/types.pb.h"
38 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
39 #include "tensorflow/core/grappler/grappler_item.h"
40 #include "tensorflow/core/grappler/grappler_item_builder.h"
41 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
42 #include "tensorflow/core/grappler/optimizers/data/auto_shard.h"
43 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
44 #include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
45 #include "tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h"
46 #include "tensorflow/core/platform/errors.h"
47 #include "tensorflow/core/platform/status.h"
48 #include "tensorflow/core/platform/statusor.h"
49 #include "tensorflow/core/protobuf/data_service.pb.h"
50 #include "tensorflow/core/protobuf/meta_graph.pb.h"
51
52 namespace tensorflow {
53 namespace data {
54 namespace {
55
56 using ::tensorflow::data::experimental::AutoShardDatasetOp;
57
58 // A dynamic port has form %port% or %port_foo% that is to be replaced with the
59 // actual port.
HasDynamicPort(absl::string_view address)60 bool HasDynamicPort(absl::string_view address) {
61 URL url(address);
62 return url.has_port() && absl::StartsWith(url.port(), "%port") &&
63 absl::EndsWith(url.port(), "%");
64 }
65
66 // Returns true if `config_address` has no port or a dynamic port (e.g.: %port%)
67 // and `worker_address` has an actual port (number of named port).
68 //
69 // For example, it returns true for the following cases:
70 //
71 // config_address worker_address
72 // ----------------------------------------------------------
73 // /worker/task/0 /worker/task/0:worker
74 // /worker/task/0:%port% /worker/task/0:10000
75 // /worker/task/0:%port_worker% /worker/task/0:worker
76 // /worker/task/0:%port_worker% /worker/task/0:10000
77 // localhost localhost:10000
78 // localhost:%port% localhost:10000
ShouldReplaceDynamicPort(absl::string_view config_address,absl::string_view worker_address)79 bool ShouldReplaceDynamicPort(absl::string_view config_address,
80 absl::string_view worker_address) {
81 URL config_url(config_address), worker_url(worker_address);
82 return (!config_url.has_port() || HasDynamicPort(config_address)) &&
83 worker_url.has_port() && config_url.host() == worker_url.host();
84 }
85 } // namespace
86
Create(const TaskDef & task_def)87 StatusOr<AutoShardRewriter> AutoShardRewriter::Create(const TaskDef& task_def) {
88 TF_ASSIGN_OR_RETURN(
89 AutoShardPolicy auto_shard_policy,
90 ToAutoShardPolicy(task_def.processing_mode_def().sharding_policy()));
91 return AutoShardRewriter(auto_shard_policy, task_def.num_workers(),
92 task_def.worker_index());
93 }
94
// Runs the grappler AutoShard optimizer over `graph_def` with this rewriter's
// policy, worker count, and worker index; returns the rewritten graph, or the
// input unchanged when the policy is OFF.
StatusOr<GraphDef> AutoShardRewriter::ApplyAutoShardRewrite(
    const GraphDef& graph_def) {
  // Sharding disabled: hand the caller's graph back untouched.
  if (auto_shard_policy_ == AutoShardPolicy::OFF) {
    return graph_def;
  }

  VLOG(2) << "Applying auto-shard policy "
          << AutoShardPolicy_Name(auto_shard_policy_)
          << ". Number of workers: " << num_workers_
          << "; worker index: " << worker_index_ << ".";
  grappler::AutoShard autoshard;
  // Package the policy/worker settings into the optimizer's custom-config
  // format and initialize the optimizer with it before use.
  tensorflow::RewriterConfig::CustomGraphOptimizer config = GetRewriteConfig();
  TF_RETURN_IF_ERROR(autoshard.Init(&config));

  // GetGrapplerItem takes mutable pointers, so operate on a local copy rather
  // than the caller's const graph.
  GraphDef input_graph = graph_def;
  TF_ASSIGN_OR_RETURN(std::string dataset_node, GetDatasetNode(input_graph));
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      GetGrapplerItem(&input_graph, &dataset_node, /*add_fake_sinks=*/false);

  GraphDef rewritten_graph;
  // An empty device map is sufficient here: the cluster only has to satisfy
  // the optimizer interface, no device placement is needed for this rewrite.
  std::unordered_map<std::string, tensorflow::DeviceProperties> device_map;
  tensorflow::grappler::VirtualCluster cluster(device_map);
  grappler::AutoShard::OptimizationStats stats;
  TF_RETURN_IF_ERROR(autoshard.OptimizeAndCollectStats(
      &cluster, *grappler_item, &rewritten_graph, &stats));
  return rewritten_graph;
}
122
// Stores the sharding parameters. Instances are created through Create(),
// which derives `auto_shard_policy` from the task's processing mode.
AutoShardRewriter::AutoShardRewriter(AutoShardPolicy auto_shard_policy,
                                     int64_t num_workers, int64_t worker_index)
    : auto_shard_policy_(auto_shard_policy),
      num_workers_(num_workers),
      worker_index_(worker_index) {}
128
129 tensorflow::RewriterConfig::CustomGraphOptimizer
GetRewriteConfig() const130 AutoShardRewriter::GetRewriteConfig() const {
131 tensorflow::RewriterConfig::CustomGraphOptimizer config;
132 config.set_name("tf-data-service-auto-shard");
133 (*config.mutable_parameter_map())[AutoShardDatasetOp::kNumWorkers].set_i(
134 num_workers_);
135 (*config.mutable_parameter_map())[AutoShardDatasetOp::kIndex].set_i(
136 worker_index_);
137 (*config.mutable_parameter_map())[AutoShardDatasetOp::kAutoShardPolicy].set_i(
138 auto_shard_policy_);
139 // This parameter is used internally by tf.distribute to rebatch the dataset.
140 // It is not used outside the context of `experimental_distribute_dataset`.
141 (*config.mutable_parameter_map())[AutoShardDatasetOp::kNumReplicas].set_i(1);
142 return config;
143 }
144
ValidateWorker(absl::string_view worker_address) const145 Status WorkerIndexResolver::ValidateWorker(
146 absl::string_view worker_address) const {
147 if (worker_addresses_.empty()) {
148 return OkStatus();
149 }
150
151 for (absl::string_view config_address : worker_addresses_) {
152 if (config_address == worker_address ||
153 ShouldReplaceDynamicPort(config_address, worker_address)) {
154 return OkStatus();
155 }
156 }
157
158 return errors::FailedPrecondition(absl::Substitute(
159 "Failed to assign an index for worker $0. Configured workers list: [$1]. "
160 "The worker's address is not configured, or other workers are already "
161 "running at the configured host. If your worker has restarted, make sure "
162 "it runs at the same address and port.",
163 worker_address, absl::StrJoin(worker_addresses_, ", ")));
164 }
165
AddWorker(absl::string_view worker_address)166 void WorkerIndexResolver::AddWorker(absl::string_view worker_address) {
167 for (std::string& config_address : worker_addresses_) {
168 if (config_address == worker_address) {
169 return;
170 }
171 if (ShouldReplaceDynamicPort(config_address, worker_address)) {
172 config_address = std::string(worker_address);
173 return;
174 }
175 }
176 }
177
GetWorkerIndex(absl::string_view worker_address) const178 StatusOr<int64_t> WorkerIndexResolver::GetWorkerIndex(
179 absl::string_view worker_address) const {
180 const auto it = absl::c_find(worker_addresses_, worker_address);
181 if (it == worker_addresses_.cend()) {
182 return errors::NotFound(absl::Substitute(
183 "Failed to shard dataset in tf.data service: Worker $0 is not in the "
184 "workers list. Got workers list $1.",
185 worker_address, absl::StrJoin(worker_addresses_, ",")));
186 }
187 return std::distance(worker_addresses_.cbegin(), it);
188 }
189
190 } // namespace data
191 } // namespace tensorflow
192