1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/common.h"
16
#include <string>

#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
26
27 namespace tensorflow {
28 namespace data {
29
namespace {
// Upper-case spellings of the TargetWorkers values. Used when parsing
// user-provided strings and when converting TargetWorkers back to text.
constexpr char kAuto[] = "AUTO";
constexpr char kAny[] = "ANY";
constexpr char kLocal[] = "LOCAL";

// Upper-case spellings of the DeploymentMode values accepted by
// ParseDeploymentMode.
constexpr char kColocated[] = "COLOCATED";
constexpr char kRemote[] = "REMOTE";
constexpr char kHybrid[] = "HYBRID";
}  // namespace
39
IsNoShard(const ProcessingModeDef & processing_mode)40 bool IsNoShard(const ProcessingModeDef& processing_mode) {
41 return processing_mode.sharding_policy() == ProcessingModeDef::OFF;
42 }
43
IsDynamicShard(const ProcessingModeDef & processing_mode)44 bool IsDynamicShard(const ProcessingModeDef& processing_mode) {
45 return processing_mode.sharding_policy() == ProcessingModeDef::DYNAMIC;
46 }
47
IsStaticShard(const ProcessingModeDef & processing_mode)48 bool IsStaticShard(const ProcessingModeDef& processing_mode) {
49 return processing_mode.sharding_policy() == ProcessingModeDef::FILE ||
50 processing_mode.sharding_policy() == ProcessingModeDef::DATA ||
51 processing_mode.sharding_policy() == ProcessingModeDef::FILE_OR_DATA ||
52 processing_mode.sharding_policy() == ProcessingModeDef::HINT;
53 }
54
ValidateProcessingMode(const ProcessingModeDef & processing_mode)55 Status ValidateProcessingMode(const ProcessingModeDef& processing_mode) {
56 if (!IsNoShard(processing_mode) && !IsDynamicShard(processing_mode) &&
57 !IsStaticShard(processing_mode)) {
58 return errors::Internal(
59 "ProcessingMode ", processing_mode.ShortDebugString(),
60 " does not "
61 "specify a valid sharding policy. Please add the policy to either "
62 "`IsDynamicShard` or `IsStaticShard` (i.e., auto-shard).");
63 }
64 return OkStatus();
65 }
66
ToAutoShardPolicy(const ProcessingModeDef::ShardingPolicy sharding_policy)67 StatusOr<AutoShardPolicy> ToAutoShardPolicy(
68 const ProcessingModeDef::ShardingPolicy sharding_policy) {
69 switch (sharding_policy) {
70 case ProcessingModeDef::FILE:
71 return AutoShardPolicy::FILE;
72 case ProcessingModeDef::DATA:
73 return AutoShardPolicy::DATA;
74 case ProcessingModeDef::FILE_OR_DATA:
75 return AutoShardPolicy::AUTO;
76 case ProcessingModeDef::HINT:
77 return AutoShardPolicy::HINT;
78 case ProcessingModeDef::DYNAMIC:
79 case ProcessingModeDef::OFF:
80 return AutoShardPolicy::OFF;
81 default:
82 return errors::Internal(
83 "tf.data service sharding policy ",
84 ProcessingModeDef::ShardingPolicy_Name(sharding_policy),
85 " is not convertible to a valid auto-shard policy. If you're "
86 "defining a new sharding policy, please update the policy mapping.");
87 }
88 }
89
ParseTargetWorkers(absl::string_view s)90 StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s) {
91 std::string str_upper = absl::AsciiStrToUpper(s);
92 if (str_upper.empty() || str_upper == kAuto) {
93 return TARGET_WORKERS_AUTO;
94 }
95 if (str_upper == kAny) {
96 return TARGET_WORKERS_ANY;
97 }
98 if (str_upper == kLocal) {
99 return TARGET_WORKERS_LOCAL;
100 }
101 return errors::InvalidArgument("Unrecognized target workers: ", s);
102 }
103
TargetWorkersToString(TargetWorkers target_workers)104 std::string TargetWorkersToString(TargetWorkers target_workers) {
105 switch (target_workers) {
106 case TARGET_WORKERS_AUTO:
107 return kAuto;
108 case TARGET_WORKERS_ANY:
109 return kAny;
110 case TARGET_WORKERS_LOCAL:
111 return kLocal;
112 default:
113 DCHECK(false);
114 return "UNKNOWN";
115 }
116 }
117
ParseDeploymentMode(absl::string_view s)118 StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s) {
119 std::string str_upper = absl::AsciiStrToUpper(s);
120 if (str_upper == kColocated) {
121 return DEPLOYMENT_MODE_COLOCATED;
122 }
123 if (str_upper == kRemote) {
124 return DEPLOYMENT_MODE_REMOTE;
125 }
126 if (str_upper == kHybrid) {
127 return DEPLOYMENT_MODE_HYBRID;
128 }
129 return errors::InvalidArgument("Invalid tf.data service deployment mode: ", s,
130 ". Supported modes are "
131 "COLOCATED, REMOTE, and HYBRID.");
132 }
133
IsPreemptedError(const Status & status)134 bool IsPreemptedError(const Status& status) {
135 return errors::IsAborted(status) || errors::IsCancelled(status) ||
136 errors::IsUnavailable(status);
137 }
138 } // namespace data
139 } // namespace tensorflow
140