xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/client/validate_utils_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/client/validate_utils.h"
16 
17 #include <memory>
18 
19 #include "tensorflow/core/data/service/client/common.h"
20 #include "tensorflow/core/data/service/common.pb.h"
21 #include "tensorflow/core/data/service/worker_impl.h"
22 #include "tensorflow/core/framework/dataset.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/status.h"
25 #include "tensorflow/core/platform/status_matchers.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/protobuf/data_service.pb.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29 #include "tensorflow/core/protobuf/service_config.pb.h"
30 
31 namespace tensorflow {
32 namespace data {
33 namespace {
34 
35 using ::tensorflow::testing::StatusIs;
36 using ::testing::HasSubstr;
37 
GetDefaultParams()38 DataServiceParams GetDefaultParams() {
39   DataServiceParams params;
40   params.dataset_id = "dataset_id";
41   params.processing_mode.set_sharding_policy(ProcessingModeDef::OFF);
42   params.address = "localhost";
43   params.protocol = "grpc";
44   params.data_transfer_protocol = "grpc";
45   params.metadata.set_cardinality(kUnknownCardinality);
46   return params;
47 }
48 
GetLocalWorker()49 std::shared_ptr<DataServiceWorkerImpl> GetLocalWorker() {
50   experimental::WorkerConfig config;
51   config.set_protocol("grpc");
52   config.set_dispatcher_address("localhost");
53   config.set_worker_address("localhost");
54   return std::make_shared<DataServiceWorkerImpl>(config);
55 }
56 
// The unmodified defaults from `GetDefaultParams()` must pass validation.
TEST(ValidateUtilsTest, DefaultParams) {
  const DataServiceParams default_params = GetDefaultParams();
  TF_EXPECT_OK(ValidateDataServiceParams(default_params));
}
60 
// Targeting local workers is valid while a worker is registered under
// "localhost". The worker is removed again at the end of the test.
TEST(ValidateUtilsTest, LocalWorkerSuccess) {
  constexpr char kWorkerAddress[] = "localhost";

  DataServiceParams local_params = GetDefaultParams();
  LocalWorkers::Add(kWorkerAddress, GetLocalWorker());
  local_params.target_workers = TARGET_WORKERS_LOCAL;

  const auto validation = ValidateDataServiceParams(local_params);
  TF_EXPECT_OK(validation);

  LocalWorkers::Remove(kWorkerAddress);
}
68 
// Targeting local workers without registering one in this test must be
// rejected with INVALID_ARGUMENT and a "no local worker" message.
TEST(ValidateUtilsTest, NoLocalWorker) {
  DataServiceParams local_read_params = GetDefaultParams();
  local_read_params.target_workers = TARGET_WORKERS_LOCAL;

  const auto validation = ValidateDataServiceParams(local_read_params);
  EXPECT_THAT(
      validation,
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr(
              "Local reads require local tf.data workers, but no local worker "
              "is found.")));
}
80 
// With the FILE_OR_DATA sharding policy and local target workers, but no
// worker registered in this test, validation must fail with INVALID_ARGUMENT
// and a message that names the sharding policy.
TEST(ValidateUtilsTest, NoLocalWorkerStaticSharding) {
  DataServiceParams sharded_params = GetDefaultParams();
  sharded_params.target_workers = TARGET_WORKERS_LOCAL;
  sharded_params.processing_mode.set_sharding_policy(
      ProcessingModeDef::FILE_OR_DATA);

  const auto validation = ValidateDataServiceParams(sharded_params);
  EXPECT_THAT(
      validation,
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr(
              "Static sharding policy <FILE_OR_DATA> requires local tf.data "
              "workers, but no local worker is found.")));
}
93 
// Setting `num_consumers` / `consumer_index` together with local target
// workers must be rejected even though a local worker is registered. The
// worker is removed again at the end of the test.
TEST(ValidateUtilsTest, LocalReadDisallowsCoordinatedRead) {
  constexpr char kWorkerAddress[] = "localhost";

  DataServiceParams coordinated_params = GetDefaultParams();
  LocalWorkers::Add(kWorkerAddress, GetLocalWorker());
  coordinated_params.target_workers = TARGET_WORKERS_LOCAL;
  coordinated_params.num_consumers = 1;
  coordinated_params.consumer_index = 0;

  const auto validation = ValidateDataServiceParams(coordinated_params);
  EXPECT_THAT(
      validation,
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Coordinated reads require non-local workers, but "
                         "`target_workers` is \"LOCAL\".")));

  LocalWorkers::Remove(kWorkerAddress);
}
107 
// Cross-trainer cache options pass validation when combined with a job name,
// a repetition of 1, and infinite cardinality.
TEST(ValidateUtilsTest, CrossTrainerCacheSuccess) {
  DataServiceParams cache_params = GetDefaultParams();
  cache_params.cross_trainer_cache_options.emplace();
  cache_params.cross_trainer_cache_options->set_trainer_id("trainer ID");
  cache_params.metadata.set_cardinality(kInfiniteCardinality);
  cache_params.repetition = 1;
  cache_params.job_name = "job_name";

  const auto validation = ValidateDataServiceParams(cache_params);
  TF_EXPECT_OK(validation);
}
117 
// Cross-trainer cache options without a `job_name` must be rejected with
// INVALID_ARGUMENT. The message is matched by substring, like the other
// failure tests in this file, so the expectation does not break if the
// validator later adds context around the message.
TEST(ValidateUtilsTest, CrossTrainerCacheRequiresJobName) {
  DataServiceParams params = GetDefaultParams();
  // `job_name` is deliberately left at its default to trigger the failure.
  params.repetition = 1;
  params.metadata.set_cardinality(kInfiniteCardinality);
  params.cross_trainer_cache_options.emplace();
  params.cross_trainer_cache_options->set_trainer_id("trainer ID");
  EXPECT_THAT(
      ValidateDataServiceParams(params),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr("Cross-trainer caching requires named jobs. Got empty "
                    "`job_name`.")));
}
130 
// Cross-trainer cache options combined with a finite cardinality (10) must be
// rejected with INVALID_ARGUMENT.
TEST(ValidateUtilsTest, CrossTrainerCacheRequiresInfiniteDataset) {
  DataServiceParams finite_params = GetDefaultParams();
  finite_params.cross_trainer_cache_options.emplace();
  finite_params.cross_trainer_cache_options->set_trainer_id("trainer ID");
  finite_params.metadata.set_cardinality(10);
  finite_params.repetition = 1;
  finite_params.job_name = "job_name";

  const auto validation = ValidateDataServiceParams(finite_params);
  EXPECT_THAT(validation,
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Cross-trainer caching requires the input "
                                 "dataset to be infinite.")));
}
143 
// Cross-trainer cache options combined with a repetition of 5 must be
// rejected with INVALID_ARGUMENT.
TEST(ValidateUtilsTest, CrossTrainerCacheDisallowsRepetition) {
  DataServiceParams repeated_params = GetDefaultParams();
  repeated_params.cross_trainer_cache_options.emplace();
  repeated_params.cross_trainer_cache_options->set_trainer_id("trainer ID");
  repeated_params.metadata.set_cardinality(kInfiniteCardinality);
  repeated_params.repetition = 5;
  repeated_params.job_name = "job_name";

  const auto validation = ValidateDataServiceParams(repeated_params);
  EXPECT_THAT(
      validation,
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr(
              "Cross-trainer caching requires infinite datasets and disallows "
              "multiple repetitions of the same dataset.")));
}
159 
// Cross-trainer cache options combined with `num_consumers` /
// `consumer_index` must be rejected with INVALID_ARGUMENT.
TEST(ValidateUtilsTest, CrossTrainerCacheDisallowsCoordinatedRead) {
  DataServiceParams coordinated_params = GetDefaultParams();
  coordinated_params.cross_trainer_cache_options.emplace();
  coordinated_params.cross_trainer_cache_options->set_trainer_id("trainer ID");
  coordinated_params.metadata.set_cardinality(kInfiniteCardinality);
  coordinated_params.consumer_index = 0;
  coordinated_params.num_consumers = 1;
  coordinated_params.repetition = 1;
  coordinated_params.job_name = "job_name";

  const auto validation = ValidateDataServiceParams(coordinated_params);
  EXPECT_THAT(
      validation,
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr(
              "Cross-trainer caching does not support coordinated reads.")));
}
176 }  // namespace
177 }  // namespace data
178 }  // namespace tensorflow
179