xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/auto_shard_rewriter.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_AUTO_SHARD_REWRITER_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_AUTO_SHARD_REWRITER_H_
17 
18 #include <cstdint>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/data/service/common.pb.h"
24 #include "tensorflow/core/framework/dataset_options.pb.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
27 #include "tensorflow/core/platform/status.h"
28 #include "tensorflow/core/platform/statusor.h"
29 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
30 
31 namespace tensorflow {
32 namespace data {
33 
34 // Rewrites the dataset graph by applying an auto-shard policy.
35 class AutoShardRewriter {
36  public:
37   // Creates an `AutoShardRewriter` according to `task_def`. Returns an error if
38   // the sharding policy is not a valid auto-shard policy.
39   static StatusOr<AutoShardRewriter> Create(const TaskDef& task_def);
40 
41   // Applies auto-sharding to `graph_def`. If auto-shard policy is OFF, returns
42   // the same graph as `graph_def`. Otherwise, returns the re-written graph.
43   StatusOr<GraphDef> ApplyAutoShardRewrite(const GraphDef& graph_def);
44 
45  private:
46   AutoShardRewriter(AutoShardPolicy auto_shard_policy, int64_t num_workers,
47                     int64_t worker_index);
48 
49   // Creates a rewrite config based on the auto-shard policy.
50   tensorflow::RewriterConfig::CustomGraphOptimizer GetRewriteConfig() const;
51 
52   const AutoShardPolicy auto_shard_policy_;
53   const int64_t num_workers_;
54   const int64_t worker_index_;
55 };
56 
57 // Maps a worker to its index, given a list of workers. For example, suppose
58 // `worker_addresses` contains
59 //   /worker/task/0:worker, /worker/task/1:worker, /worker/task/2:worker,
60 // then
61 //   /worker/task/0:worker maps to index 0,
62 //   /worker/task/1:worker maps to index 1,
63 //   /worker/task/2:worker maps to index 2.
64 // This is useful for deterministically sharding a dataset among a fixed set of
65 // tf.data service workers.
66 class WorkerIndexResolver {
67  public:
68   // Constructs a `WorkerIndexResolver` to generate worker indexes according to
69   // the specified worker addresses. The worker addresses can be "host" or
70   // "host:port", where "port" is a number, named port, or "%port%" to be
71   // replaced with the actual port.
72   template <class T>
WorkerIndexResolver(const T & worker_addresses)73   explicit WorkerIndexResolver(const T& worker_addresses)
74       : worker_addresses_(worker_addresses.cbegin(), worker_addresses.cend()) {}
75 
76   // Validates `worker_address`. Returns an error if the `worker_addresses` list
77   // is non-empty and `worker_address` is not specified in the worker addresses
78   // list (with optional port replacement).
79   Status ValidateWorker(absl::string_view worker_address) const;
80 
81   // Processes a worker at address `worker_address`. Its index can be retrieved
82   // by calling `GetWorkerIndex`.
83   void AddWorker(absl::string_view worker_address);
84 
85   // Returns the worker index for the worker at `worker_address`. Returns a
86   // NotFound error if the worker is not registered.
87   StatusOr<int64_t> GetWorkerIndex(absl::string_view worker_address) const;
88 
89  private:
90   std::vector<std::string> worker_addresses_;
91 };
92 
93 }  // namespace data
94 }  // namespace tensorflow
95 
96 #endif  // TENSORFLOW_CORE_DATA_SERVICE_AUTO_SHARD_REWRITER_H_
97