xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/common.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/common.h"
16 
#include <string>

#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
26 
27 namespace tensorflow {
28 namespace data {
29 
namespace {

// Canonical upper-case spellings accepted by ParseTargetWorkers and produced
// by TargetWorkersToString.
constexpr char kAuto[] = "AUTO";
constexpr char kAny[] = "ANY";
constexpr char kLocal[] = "LOCAL";

// Canonical upper-case spellings accepted by ParseDeploymentMode.
constexpr char kColocated[] = "COLOCATED";
constexpr char kRemote[] = "REMOTE";
constexpr char kHybrid[] = "HYBRID";

}  // namespace
39 
IsNoShard(const ProcessingModeDef & processing_mode)40 bool IsNoShard(const ProcessingModeDef& processing_mode) {
41   return processing_mode.sharding_policy() == ProcessingModeDef::OFF;
42 }
43 
IsDynamicShard(const ProcessingModeDef & processing_mode)44 bool IsDynamicShard(const ProcessingModeDef& processing_mode) {
45   return processing_mode.sharding_policy() == ProcessingModeDef::DYNAMIC;
46 }
47 
IsStaticShard(const ProcessingModeDef & processing_mode)48 bool IsStaticShard(const ProcessingModeDef& processing_mode) {
49   return processing_mode.sharding_policy() == ProcessingModeDef::FILE ||
50          processing_mode.sharding_policy() == ProcessingModeDef::DATA ||
51          processing_mode.sharding_policy() == ProcessingModeDef::FILE_OR_DATA ||
52          processing_mode.sharding_policy() == ProcessingModeDef::HINT;
53 }
54 
ValidateProcessingMode(const ProcessingModeDef & processing_mode)55 Status ValidateProcessingMode(const ProcessingModeDef& processing_mode) {
56   if (!IsNoShard(processing_mode) && !IsDynamicShard(processing_mode) &&
57       !IsStaticShard(processing_mode)) {
58     return errors::Internal(
59         "ProcessingMode ", processing_mode.ShortDebugString(),
60         " does not "
61         "specify a valid sharding policy. Please add the policy to either "
62         "`IsDynamicShard` or `IsStaticShard` (i.e., auto-shard).");
63   }
64   return OkStatus();
65 }
66 
ToAutoShardPolicy(const ProcessingModeDef::ShardingPolicy sharding_policy)67 StatusOr<AutoShardPolicy> ToAutoShardPolicy(
68     const ProcessingModeDef::ShardingPolicy sharding_policy) {
69   switch (sharding_policy) {
70     case ProcessingModeDef::FILE:
71       return AutoShardPolicy::FILE;
72     case ProcessingModeDef::DATA:
73       return AutoShardPolicy::DATA;
74     case ProcessingModeDef::FILE_OR_DATA:
75       return AutoShardPolicy::AUTO;
76     case ProcessingModeDef::HINT:
77       return AutoShardPolicy::HINT;
78     case ProcessingModeDef::DYNAMIC:
79     case ProcessingModeDef::OFF:
80       return AutoShardPolicy::OFF;
81     default:
82       return errors::Internal(
83           "tf.data service sharding policy ",
84           ProcessingModeDef::ShardingPolicy_Name(sharding_policy),
85           " is not convertible to a valid auto-shard policy. If you're "
86           "defining a new sharding policy, please update the policy mapping.");
87   }
88 }
89 
ParseTargetWorkers(absl::string_view s)90 StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s) {
91   std::string str_upper = absl::AsciiStrToUpper(s);
92   if (str_upper.empty() || str_upper == kAuto) {
93     return TARGET_WORKERS_AUTO;
94   }
95   if (str_upper == kAny) {
96     return TARGET_WORKERS_ANY;
97   }
98   if (str_upper == kLocal) {
99     return TARGET_WORKERS_LOCAL;
100   }
101   return errors::InvalidArgument("Unrecognized target workers: ", s);
102 }
103 
TargetWorkersToString(TargetWorkers target_workers)104 std::string TargetWorkersToString(TargetWorkers target_workers) {
105   switch (target_workers) {
106     case TARGET_WORKERS_AUTO:
107       return kAuto;
108     case TARGET_WORKERS_ANY:
109       return kAny;
110     case TARGET_WORKERS_LOCAL:
111       return kLocal;
112     default:
113       DCHECK(false);
114       return "UNKNOWN";
115   }
116 }
117 
ParseDeploymentMode(absl::string_view s)118 StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s) {
119   std::string str_upper = absl::AsciiStrToUpper(s);
120   if (str_upper == kColocated) {
121     return DEPLOYMENT_MODE_COLOCATED;
122   }
123   if (str_upper == kRemote) {
124     return DEPLOYMENT_MODE_REMOTE;
125   }
126   if (str_upper == kHybrid) {
127     return DEPLOYMENT_MODE_HYBRID;
128   }
129   return errors::InvalidArgument("Invalid tf.data service deployment mode: ", s,
130                                  ". Supported modes are "
131                                  "COLOCATED, REMOTE, and HYBRID.");
132 }
133 
IsPreemptedError(const Status & status)134 bool IsPreemptedError(const Status& status) {
135   return errors::IsAborted(status) || errors::IsCancelled(status) ||
136          errors::IsUnavailable(status);
137 }
138 }  // namespace data
139 }  // namespace tensorflow
140