1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/shard_dataset_op.h"
16
17 #include "tensorflow/core/data/dataset_utils.h"
18 #include "tensorflow/core/data/name_utils.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/platform/errors.h"
23 #include "tensorflow/core/platform/stringprintf.h"
24 #include "tensorflow/core/util/batch_util.h"
25
26 namespace tensorflow {
27 namespace data {
28
29 // See documentation in ../../ops/dataset_ops.cc for a high-level
30 // description of the following op.
31
// Out-of-line definitions for the static constexpr op-attribute names declared
// in shard_dataset_op.h (required prior to C++17 inline variables).
/* static */ constexpr const char* const ShardDatasetOp::kDatasetType;
/* static */ constexpr const char* const ShardDatasetOp::kInputDataset;
/* static */ constexpr const char* const ShardDatasetOp::kNumShards;
/* static */ constexpr const char* const ShardDatasetOp::kIndex;
/* static */ constexpr const char* const ShardDatasetOp::kRequireNonEmpty;
/* static */ constexpr const char* const ShardDatasetOp::kOutputTypes;
/* static */ constexpr const char* const ShardDatasetOp::kOutputShapes;

// Keys used to (de)serialize iterator state in checkpoints (see
// Iterator::SaveInternal / RestoreInternal below).
constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kNextIndex[] = "next_index";
42
43 class ShardDatasetOp::Dataset : public DatasetBase {
44 public:
Dataset(OpKernelContext * ctx,int64_t num_shards,int64_t index,bool require_non_empty,const DatasetBase * input)45 Dataset(OpKernelContext* ctx, int64_t num_shards, int64_t index,
46 bool require_non_empty, const DatasetBase* input)
47 : DatasetBase(DatasetContext(ctx)),
48 num_shards_(num_shards),
49 index_(index),
50 input_(input),
51 require_non_empty_(require_non_empty),
52 traceme_metadata_(
53 {{"index", strings::Printf("%lld", static_cast<long long>(index))},
54 {"num_shards",
55 strings::Printf("%lld", static_cast<long long>(num_shards))}}) {
56 input_->Ref();
57 }
58
~Dataset()59 ~Dataset() override { input_->Unref(); }
60
MakeIteratorInternal(const string & prefix) const61 std::unique_ptr<IteratorBase> MakeIteratorInternal(
62 const string& prefix) const override {
63 return std::make_unique<Iterator>(Iterator::Params{
64 this, name_utils::IteratorPrefix(kDatasetType, prefix)});
65 }
66
output_dtypes() const67 const DataTypeVector& output_dtypes() const override {
68 return input_->output_dtypes();
69 }
70
output_shapes() const71 const std::vector<PartialTensorShape>& output_shapes() const override {
72 return input_->output_shapes();
73 }
74
DebugString() const75 string DebugString() const override {
76 name_utils::DatasetDebugStringParams params;
77 params.set_args(num_shards_, index_);
78 return name_utils::DatasetDebugString(kDatasetType, params);
79 }
80
CardinalityInternal() const81 int64_t CardinalityInternal() const override {
82 int64_t n = input_->Cardinality();
83 if (n == kInfiniteCardinality || n == kUnknownCardinality) {
84 return n;
85 }
86 return n / num_shards_ + (index_ < n % num_shards_ ? 1 : 0);
87 }
88
CardinalityInternal(CardinalityOptions options) const89 int64_t CardinalityInternal(CardinalityOptions options) const override {
90 int64_t n = input_->Cardinality(options);
91 if (n == kInfiniteCardinality || n == kUnknownCardinality) {
92 return n;
93 }
94 return n / num_shards_ + (index_ < n % num_shards_ ? 1 : 0);
95 }
96
InputDatasets(std::vector<const DatasetBase * > * inputs) const97 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
98 inputs->push_back(input_);
99 return OkStatus();
100 }
101
CheckExternalState() const102 Status CheckExternalState() const override {
103 return input_->CheckExternalState();
104 }
105
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const106 Status Get(OpKernelContext* ctx, int64 index,
107 std::vector<Tensor>* out_tensors) const override {
108 TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
109 return input_->Get(ctx, index_ + (num_shards_ * index), out_tensors);
110 }
111
112 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const113 Status AsGraphDefInternal(SerializationContext* ctx,
114 DatasetGraphDefBuilder* b,
115 Node** output) const override {
116 Node* input_graph_node = nullptr;
117 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
118 Node* num_shards = nullptr;
119 TF_RETURN_IF_ERROR(b->AddScalar(num_shards_, &num_shards));
120 Node* index = nullptr;
121 TF_RETURN_IF_ERROR(b->AddScalar(index_, &index));
122
123 AttrValue require_non_empty_attr;
124 b->BuildAttrValue(require_non_empty_, &require_non_empty_attr);
125
126 TF_RETURN_IF_ERROR(
127 b->AddDataset(this, {input_graph_node, num_shards, index},
128 {{kRequireNonEmpty, require_non_empty_attr}}, output));
129 return OkStatus();
130 }
131
132 private:
133 class Iterator : public DatasetIterator<Dataset> {
134 public:
Iterator(const Params & params)135 explicit Iterator(const Params& params)
136 : DatasetIterator<Dataset>(params), next_index_(0) {}
137
Initialize(IteratorContext * ctx)138 Status Initialize(IteratorContext* ctx) override {
139 if (dataset()->num_shards_ == kShardHint) {
140 return errors::FailedPrecondition(
141 "`tf.data.Dataset.shard(SHARD_HINT, ...)` can only be used in "
142 "`tf.distribute.Strategy.experimental_distribute_dataset()` with "
143 "`tf.data.experimental.AutoShardPolicy.HINT` policy, or tf.data "
144 "service with "
145 "`tf.data.experimental.service.ShardingPolicy.HINT` processing "
146 "mode.");
147 }
148 return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
149 }
150
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)151 Status GetNextInternal(IteratorContext* ctx,
152 std::vector<Tensor>* out_tensors,
153 bool* end_of_sequence) override {
154 mutex_lock l(mu_);
155
156 *end_of_sequence = false;
157 if (!input_impl_) {
158 *end_of_sequence = true;
159 return OkStatus();
160 }
161
162 int num_to_skip =
163 (dataset()->index_ - next_index_) % dataset()->num_shards_;
164 if (num_to_skip < 0) {
165 num_to_skip += dataset()->num_shards_;
166 }
167 int num_skipped;
168 TF_RETURN_IF_ERROR(
169 input_impl_->Skip(ctx, num_to_skip, end_of_sequence, &num_skipped));
170 next_index_ += num_skipped;
171 if (*end_of_sequence) {
172 input_impl_.reset();
173 return OkStatus();
174 }
175
176 std::vector<Tensor> result;
177 TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &result, end_of_sequence));
178 if (*end_of_sequence) {
179 input_impl_.reset();
180 return OkStatus();
181 }
182 next_index_++;
183
184 if (dataset()->require_non_empty_ &&
185 next_index_ < dataset()->num_shards_) {
186 int num_skipped;
187 Status s = input_impl_->Skip(ctx, dataset()->num_shards_ - next_index_,
188 end_of_sequence, &num_skipped);
189 if (*end_of_sequence || errors::IsOutOfRange(s)) {
190 // `dataset()->require_non_empty_` implies that this transformation
191 // was introduced by auto_sharding rewrite, so it's acceptable
192 // produce an error message that assumes auto-sharding context.
193 return errors::InvalidArgument(
194 "Could not apply FILE based sharding: the dataset only has ",
195 next_index_, " file(s), which is not enough for the required ",
196 dataset()->num_shards_,
197 " shards/workers."
198 "If you are using datasets with distribution strategy, "
199 "consider setting the auto sharding policy to either DATA or "
200 "OFF using the `experimental_distribute.auto_shard_policy` option"
201 "of `tf.data.Options()`. Or, split your input files into a "
202 "larger number of small files such that number of files is "
203 "greater than number of shards/workers.");
204 } else if (!s.ok()) {
205 return s;
206 }
207
208 next_index_ = dataset()->num_shards_;
209 }
210
211 *out_tensors = std::move(result);
212 return OkStatus();
213 }
214
215 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const216 std::shared_ptr<model::Node> CreateNode(
217 IteratorContext* ctx, model::Node::Args args) const override {
218 return model::MakeKnownRatioNode(std::move(args), dataset()->num_shards_);
219 }
220
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)221 Status SaveInternal(SerializationContext* ctx,
222 IteratorStateWriter* writer) override {
223 mutex_lock l(mu_);
224 if (!input_impl_) {
225 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
226 } else {
227 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
228 TF_RETURN_IF_ERROR(
229 writer->WriteScalar(full_name(kNextIndex), next_index_));
230 }
231 return OkStatus();
232 }
233
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)234 Status RestoreInternal(IteratorContext* ctx,
235 IteratorStateReader* reader) override {
236 mutex_lock l(mu_);
237 if (!reader->Contains(full_name(kInputImplEmpty))) {
238 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
239 TF_RETURN_IF_ERROR(
240 reader->ReadScalar(full_name(kNextIndex), &next_index_));
241 } else {
242 input_impl_.reset();
243 }
244 return OkStatus();
245 }
246
GetTraceMeMetadata() const247 TraceMeMetadata GetTraceMeMetadata() const override {
248 return dataset()->traceme_metadata_;
249 }
250
251 private:
252 mutex mu_;
253 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
254 int64_t next_index_ TF_GUARDED_BY(mu_);
255 };
256
257 const int64_t num_shards_;
258 const int64_t index_;
259 const DatasetBase* const input_;
260 const bool require_non_empty_;
261 const TraceMeMetadata traceme_metadata_;
262 };
263
// Reads the static `require_non_empty` attr at kernel-construction time; the
// dynamic inputs (num_shards, index) are parsed per-call in MakeDataset().
ShardDatasetOp::ShardDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kRequireNonEmpty, &require_non_empty_));
}
268
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)269 void ShardDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
270 DatasetBase** output) {
271 int64_t index = 0;
272 int64_t num_shards = 0;
273
274 OP_REQUIRES_OK(ctx,
275 ParseScalarArgument<int64_t>(ctx, kNumShards, &num_shards));
276 OP_REQUIRES(
277 ctx, num_shards > 0 || num_shards == kShardHint,
278 errors::InvalidArgument("Number of shards must be greater than zero "
279 "(currently num_shards = ",
280 num_shards, ")."));
281
282 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kIndex, &index));
283 OP_REQUIRES(
284 ctx, (index >= 0 && index < num_shards) || num_shards == kShardHint,
285 errors::InvalidArgument("Index must be between 0 and ", num_shards - 1,
286 " (currently index = ", index, ")."));
287
288 *output = new Dataset(ctx, num_shards, index, require_non_empty_, input);
289 }
290
namespace {
// Registers this kernel for the "ShardDataset" op on CPU; the op itself is
// declared in ../../ops/dataset_ops.cc.
REGISTER_KERNEL_BUILDER(Name("ShardDataset").Device(DEVICE_CPU),
                        ShardDatasetOp);
}  // namespace
295 } // namespace data
296 } // namespace tensorflow
297