xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/auto_shard_rewriter.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/auto_shard_rewriter.h"
16 
17 #include <cstdlib>
18 #include <iterator>
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/string_view.h"
28 #include "absl/strings/substitute.h"
29 #include "absl/types/optional.h"
30 #include "tensorflow/core/data/rewrite_utils.h"
31 #include "tensorflow/core/data/service/common.h"
32 #include "tensorflow/core/data/service/common.pb.h"
33 #include "tensorflow/core/data/service/url.h"
34 #include "tensorflow/core/framework/dataset_options.pb.h"
35 #include "tensorflow/core/framework/graph.pb.h"
36 #include "tensorflow/core/framework/node_def.pb.h"
37 #include "tensorflow/core/framework/types.pb.h"
38 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
39 #include "tensorflow/core/grappler/grappler_item.h"
40 #include "tensorflow/core/grappler/grappler_item_builder.h"
41 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
42 #include "tensorflow/core/grappler/optimizers/data/auto_shard.h"
43 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
44 #include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
45 #include "tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h"
46 #include "tensorflow/core/platform/errors.h"
47 #include "tensorflow/core/platform/status.h"
48 #include "tensorflow/core/platform/statusor.h"
49 #include "tensorflow/core/protobuf/data_service.pb.h"
50 #include "tensorflow/core/protobuf/meta_graph.pb.h"
51 
52 namespace tensorflow {
53 namespace data {
54 namespace {
55 
56 using ::tensorflow::data::experimental::AutoShardDatasetOp;
57 
58 // A dynamic port has form %port% or %port_foo% that is to be replaced with the
59 // actual port.
HasDynamicPort(absl::string_view address)60 bool HasDynamicPort(absl::string_view address) {
61   URL url(address);
62   return url.has_port() && absl::StartsWith(url.port(), "%port") &&
63          absl::EndsWith(url.port(), "%");
64 }
65 
66 // Returns true if `config_address` has no port or a dynamic port (e.g.: %port%)
67 // and `worker_address` has an actual port (number of named port).
68 //
69 // For example, it returns true for the following cases:
70 //
71 //  config_address                    worker_address
72 //  ----------------------------------------------------------
73 //  /worker/task/0                    /worker/task/0:worker
74 //  /worker/task/0:%port%             /worker/task/0:10000
75 //  /worker/task/0:%port_worker%      /worker/task/0:worker
76 //  /worker/task/0:%port_worker%      /worker/task/0:10000
77 //  localhost                         localhost:10000
78 //  localhost:%port%                  localhost:10000
ShouldReplaceDynamicPort(absl::string_view config_address,absl::string_view worker_address)79 bool ShouldReplaceDynamicPort(absl::string_view config_address,
80                               absl::string_view worker_address) {
81   URL config_url(config_address), worker_url(worker_address);
82   return (!config_url.has_port() || HasDynamicPort(config_address)) &&
83          worker_url.has_port() && config_url.host() == worker_url.host();
84 }
85 }  // namespace
86 
Create(const TaskDef & task_def)87 StatusOr<AutoShardRewriter> AutoShardRewriter::Create(const TaskDef& task_def) {
88   TF_ASSIGN_OR_RETURN(
89       AutoShardPolicy auto_shard_policy,
90       ToAutoShardPolicy(task_def.processing_mode_def().sharding_policy()));
91   return AutoShardRewriter(auto_shard_policy, task_def.num_workers(),
92                            task_def.worker_index());
93 }
94 
// Applies the grappler AutoShard rewrite to `graph_def`, returning the
// rewritten graph. A no-op (returns a copy of the input) when the policy is
// OFF.
StatusOr<GraphDef> AutoShardRewriter::ApplyAutoShardRewrite(
    const GraphDef& graph_def) {
  // Sharding disabled: hand back the input graph unmodified.
  if (auto_shard_policy_ == AutoShardPolicy::OFF) {
    return graph_def;
  }

  VLOG(2) << "Applying auto-shard policy "
          << AutoShardPolicy_Name(auto_shard_policy_)
          << ". Number of workers: " << num_workers_
          << "; worker index: " << worker_index_ << ".";
  // Initialize the AutoShard optimizer with this rewriter's policy, worker
  // count, and worker index (packed into a CustomGraphOptimizer config by
  // GetRewriteConfig()).
  grappler::AutoShard autoshard;
  tensorflow::RewriterConfig::CustomGraphOptimizer config = GetRewriteConfig();
  TF_RETURN_IF_ERROR(autoshard.Init(&config));

  // Wrap a mutable copy of the graph in a GrapplerItem rooted at the dataset
  // output node, so the optimizer knows where the input pipeline ends.
  GraphDef input_graph = graph_def;
  TF_ASSIGN_OR_RETURN(std::string dataset_node, GetDatasetNode(input_graph));
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      GetGrapplerItem(&input_graph, &dataset_node, /*add_fake_sinks=*/false);

  // Run the rewrite against a cluster with no devices: auto-sharding is a
  // pure graph transformation and needs no device information.
  GraphDef rewritten_graph;
  std::unordered_map<std::string, tensorflow::DeviceProperties> device_map;
  tensorflow::grappler::VirtualCluster cluster(device_map);
  grappler::AutoShard::OptimizationStats stats;
  TF_RETURN_IF_ERROR(autoshard.OptimizeAndCollectStats(
      &cluster, *grappler_item, &rewritten_graph, &stats));
  return rewritten_graph;
}
122 
// Stores the sharding parameters; all validation is done by Create(), which
// is the only caller of this constructor.
AutoShardRewriter::AutoShardRewriter(AutoShardPolicy auto_shard_policy,
                                     int64_t num_workers, int64_t worker_index)
    : auto_shard_policy_(auto_shard_policy),
      num_workers_(num_workers),
      worker_index_(worker_index) {}
128 
129 tensorflow::RewriterConfig::CustomGraphOptimizer
GetRewriteConfig() const130 AutoShardRewriter::GetRewriteConfig() const {
131   tensorflow::RewriterConfig::CustomGraphOptimizer config;
132   config.set_name("tf-data-service-auto-shard");
133   (*config.mutable_parameter_map())[AutoShardDatasetOp::kNumWorkers].set_i(
134       num_workers_);
135   (*config.mutable_parameter_map())[AutoShardDatasetOp::kIndex].set_i(
136       worker_index_);
137   (*config.mutable_parameter_map())[AutoShardDatasetOp::kAutoShardPolicy].set_i(
138       auto_shard_policy_);
139   // This parameter is used internally by tf.distribute to rebatch the dataset.
140   // It is not used outside the context of `experimental_distribute_dataset`.
141   (*config.mutable_parameter_map())[AutoShardDatasetOp::kNumReplicas].set_i(1);
142   return config;
143 }
144 
ValidateWorker(absl::string_view worker_address) const145 Status WorkerIndexResolver::ValidateWorker(
146     absl::string_view worker_address) const {
147   if (worker_addresses_.empty()) {
148     return OkStatus();
149   }
150 
151   for (absl::string_view config_address : worker_addresses_) {
152     if (config_address == worker_address ||
153         ShouldReplaceDynamicPort(config_address, worker_address)) {
154       return OkStatus();
155     }
156   }
157 
158   return errors::FailedPrecondition(absl::Substitute(
159       "Failed to assign an index for worker $0. Configured workers list: [$1]. "
160       "The worker's address is not configured, or other workers are already "
161       "running at the configured host. If your worker has restarted, make sure "
162       "it runs at the same address and port.",
163       worker_address, absl::StrJoin(worker_addresses_, ", ")));
164 }
165 
AddWorker(absl::string_view worker_address)166 void WorkerIndexResolver::AddWorker(absl::string_view worker_address) {
167   for (std::string& config_address : worker_addresses_) {
168     if (config_address == worker_address) {
169       return;
170     }
171     if (ShouldReplaceDynamicPort(config_address, worker_address)) {
172       config_address = std::string(worker_address);
173       return;
174     }
175   }
176 }
177 
GetWorkerIndex(absl::string_view worker_address) const178 StatusOr<int64_t> WorkerIndexResolver::GetWorkerIndex(
179     absl::string_view worker_address) const {
180   const auto it = absl::c_find(worker_addresses_, worker_address);
181   if (it == worker_addresses_.cend()) {
182     return errors::NotFound(absl::Substitute(
183         "Failed to shard dataset in tf.data service: Worker $0 is not in the "
184         "workers list. Got workers list $1.",
185         worker_address, absl::StrJoin(worker_addresses_, ",")));
186   }
187   return std::distance(worker_addresses_.cbegin(), it);
188 }
189 
190 }  // namespace data
191 }  // namespace tensorflow
192