1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/client/validate_utils.h"
16
17 #include <memory>
18
19 #include "tensorflow/core/data/service/client/common.h"
20 #include "tensorflow/core/data/service/common.pb.h"
21 #include "tensorflow/core/data/service/worker_impl.h"
22 #include "tensorflow/core/framework/dataset.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/status.h"
25 #include "tensorflow/core/platform/status_matchers.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/protobuf/data_service.pb.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29 #include "tensorflow/core/protobuf/service_config.pb.h"
30
31 namespace tensorflow {
32 namespace data {
33 namespace {
34
35 using ::tensorflow::testing::StatusIs;
36 using ::testing::HasSubstr;
37
GetDefaultParams()38 DataServiceParams GetDefaultParams() {
39 DataServiceParams params;
40 params.dataset_id = "dataset_id";
41 params.processing_mode.set_sharding_policy(ProcessingModeDef::OFF);
42 params.address = "localhost";
43 params.protocol = "grpc";
44 params.data_transfer_protocol = "grpc";
45 params.metadata.set_cardinality(kUnknownCardinality);
46 return params;
47 }
48
GetLocalWorker()49 std::shared_ptr<DataServiceWorkerImpl> GetLocalWorker() {
50 experimental::WorkerConfig config;
51 config.set_protocol("grpc");
52 config.set_dispatcher_address("localhost");
53 config.set_worker_address("localhost");
54 return std::make_shared<DataServiceWorkerImpl>(config);
55 }
56
// The baseline parameters describe a valid configuration.
TEST(ValidateUtilsTest, DefaultParams) {
  const DataServiceParams params = GetDefaultParams();
  TF_EXPECT_OK(ValidateDataServiceParams(params));
}
60
// Targeting local workers succeeds once a local worker has been registered.
TEST(ValidateUtilsTest, LocalWorkerSuccess) {
  // Register a worker under "localhost". It is removed at the end of the test
  // because LocalWorkers appears to be shared state visible to the other
  // tests (NOTE(review): confirm it is process-wide).
  LocalWorkers::Add("localhost", GetLocalWorker());

  DataServiceParams params = GetDefaultParams();
  params.target_workers = TARGET_WORKERS_LOCAL;
  TF_EXPECT_OK(ValidateDataServiceParams(params));

  LocalWorkers::Remove("localhost");
}
68
// Requesting local workers when none is registered is an invalid argument.
TEST(ValidateUtilsTest, NoLocalWorker) {
  DataServiceParams params = GetDefaultParams();
  params.target_workers = TARGET_WORKERS_LOCAL;

  // Matched as a substring; the full message may carry additional guidance.
  constexpr char kExpectedError[] =
      "Local reads require local tf.data workers, but no local worker "
      "is found.";
  const auto status = ValidateDataServiceParams(params);
  EXPECT_THAT(status,
              StatusIs(error::INVALID_ARGUMENT, HasSubstr(kExpectedError)));
}
80
// With a static sharding policy (FILE_OR_DATA) and no registered local worker,
// the error message names the sharding policy.
TEST(ValidateUtilsTest, NoLocalWorkerStaticSharding) {
  DataServiceParams params = GetDefaultParams();
  params.target_workers = TARGET_WORKERS_LOCAL;
  params.processing_mode.set_sharding_policy(ProcessingModeDef::FILE_OR_DATA);

  constexpr char kExpectedError[] =
      "Static sharding policy <FILE_OR_DATA> requires local tf.data "
      "workers, but no local worker is found.";
  const auto status = ValidateDataServiceParams(params);
  EXPECT_THAT(status,
              StatusIs(error::INVALID_ARGUMENT, HasSubstr(kExpectedError)));
}
93
// Coordinated reads (num_consumers/consumer_index set) cannot be combined with
// TARGET_WORKERS_LOCAL.
TEST(ValidateUtilsTest, LocalReadDisallowsCoordinatedRead) {
  // A local worker is registered so the "no local worker" error exercised by
  // NoLocalWorker presumably cannot mask the coordinated-read error.
  LocalWorkers::Add("localhost", GetLocalWorker());

  DataServiceParams params = GetDefaultParams();
  params.target_workers = TARGET_WORKERS_LOCAL;
  params.num_consumers = 1;
  params.consumer_index = 0;

  constexpr char kExpectedError[] =
      "Coordinated reads require non-local workers, but "
      "`target_workers` is \"LOCAL\".";
  const auto status = ValidateDataServiceParams(params);
  EXPECT_THAT(status,
              StatusIs(error::INVALID_ARGUMENT, HasSubstr(kExpectedError)));

  LocalWorkers::Remove("localhost");
}
107
// Cross-trainer caching is accepted for a named job with a single repetition
// over an infinite dataset.
TEST(ValidateUtilsTest, CrossTrainerCacheSuccess) {
  DataServiceParams params = GetDefaultParams();
  params.cross_trainer_cache_options.emplace();
  params.cross_trainer_cache_options->set_trainer_id("trainer ID");
  params.metadata.set_cardinality(kInfiniteCardinality);
  params.repetition = 1;
  params.job_name = "job_name";
  TF_EXPECT_OK(ValidateDataServiceParams(params));
}
117
// Cross-trainer caching without a job name is rejected. The remaining
// settings (single repetition, infinite cardinality, trainer ID) match
// CrossTrainerCacheSuccess, so only the missing `job_name` differs.
TEST(ValidateUtilsTest, CrossTrainerCacheRequiresJobName) {
  DataServiceParams params = GetDefaultParams();
  params.repetition = 1;
  params.metadata.set_cardinality(kInfiniteCardinality);
  params.cross_trainer_cache_options.emplace();
  params.cross_trainer_cache_options->set_trainer_id("trainer ID");
  // Match on a substring, like the sibling tests, so that appending detail to
  // the error message does not break this test.
  EXPECT_THAT(
      ValidateDataServiceParams(params),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Cross-trainer caching requires named jobs. Got "
                         "empty `job_name`.")));
}
130
// Cross-trainer caching is rejected when the dataset has a finite cardinality
// (here 10).
TEST(ValidateUtilsTest, CrossTrainerCacheRequiresInfiniteDataset) {
  DataServiceParams params = GetDefaultParams();
  params.cross_trainer_cache_options.emplace();
  params.cross_trainer_cache_options->set_trainer_id("trainer ID");
  params.job_name = "job_name";
  params.repetition = 1;
  params.metadata.set_cardinality(10);

  constexpr char kExpectedError[] =
      "Cross-trainer caching requires the input "
      "dataset to be infinite.";
  const auto status = ValidateDataServiceParams(params);
  EXPECT_THAT(status,
              StatusIs(error::INVALID_ARGUMENT, HasSubstr(kExpectedError)));
}
143
// Cross-trainer caching is rejected when the dataset is repeated more than
// once (here 5 repetitions), even if the dataset is infinite.
TEST(ValidateUtilsTest, CrossTrainerCacheDisallowsRepetition) {
  DataServiceParams params = GetDefaultParams();
  params.cross_trainer_cache_options.emplace();
  params.cross_trainer_cache_options->set_trainer_id("trainer ID");
  params.metadata.set_cardinality(kInfiniteCardinality);
  params.job_name = "job_name";
  params.repetition = 5;

  constexpr char kExpectedError[] =
      "Cross-trainer caching requires infinite datasets and disallows "
      "multiple repetitions of the same dataset.";
  const auto status = ValidateDataServiceParams(params);
  EXPECT_THAT(status,
              StatusIs(error::INVALID_ARGUMENT, HasSubstr(kExpectedError)));
}
159
// Cross-trainer caching is rejected when combined with coordinated reads
// (num_consumers/consumer_index set), even when every other cross-trainer
// caching requirement is met.
TEST(ValidateUtilsTest, CrossTrainerCacheDisallowsCoordinatedRead) {
  DataServiceParams params = GetDefaultParams();
  params.cross_trainer_cache_options.emplace();
  params.cross_trainer_cache_options->set_trainer_id("trainer ID");
  params.metadata.set_cardinality(kInfiniteCardinality);
  params.job_name = "job_name";
  params.repetition = 1;
  params.consumer_index = 0;
  params.num_consumers = 1;

  constexpr char kExpectedError[] =
      "Cross-trainer caching does not support coordinated reads.";
  const auto status = ValidateDataServiceParams(params);
  EXPECT_THAT(status,
              StatusIs(error::INVALID_ARGUMENT, HasSubstr(kExpectedError)));
}
176 } // namespace
177 } // namespace data
178 } // namespace tensorflow
179