/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/shard_dataset_op.h"

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const ShardDatasetOp::kDatasetType;
/* static */ constexpr const char* const ShardDatasetOp::kInputDataset;
/* static */ constexpr const char* const ShardDatasetOp::kNumShards;
/* static */ constexpr const char* const ShardDatasetOp::kIndex;
/* static */ constexpr const char* const ShardDatasetOp::kRequireNonEmpty;
/* static */ constexpr const char* const ShardDatasetOp::kOutputTypes;
/* static */ constexpr const char* const ShardDatasetOp::kOutputShapes;

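// Keys used by SaveInternal()/RestoreInternal() below to checkpoint the
// iterator state.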
constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kNextIndex[] = "next_index";

class ShardDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64_t num_shards, int64_t index,
          bool require_non_empty, const DatasetBase* input)
      : DatasetBase(DatasetContext(ctx)),
        num_shards_(num_shards),
        index_(index),
        input_(input),
        require_non_empty_(require_non_empty),
        traceme_metadata_(
            {{"index", strings::Printf("%lld", static_cast<long long>(index))},
             {"num_shards",
              strings::Printf("%lld", static_cast<long long>(num_shards))}}) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.set_args(num_shards_, index_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

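  // For a finite input of cardinality n, the first n % num_shards_ shards
  // receive one extra element. For example, n = 10 and num_shards_ = 3 yields
  // shards of size 4, 3, and 3.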
  int64_t CardinalityInternal() const override {
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    return n / num_shards_ + (index_ < n % num_shards_ ? 1 : 0);
  }

  int64_t CardinalityInternal(CardinalityOptions options) const override {
    int64_t n = input_->Cardinality(options);
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    return n / num_shards_ + (index_ < n % num_shards_ ? 1 : 0);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

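  // Random access: element i of this shard is element index_ + num_shards_ * i
  // of the input dataset.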
  Status Get(OpKernelContext* ctx, int64 index,
             std::vector<Tensor>* out_tensors) const override {
    TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
    return input_->Get(ctx, index_ + (num_shards_ * index), out_tensors);
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* num_shards = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(num_shards_, &num_shards));
    Node* index = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(index_, &index));

    AttrValue require_non_empty_attr;
    b->BuildAttrValue(require_non_empty_, &require_non_empty_attr);

    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, num_shards, index},
                      {{kRequireNonEmpty, require_non_empty_attr}}, output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params), next_index_(0) {}

    Status Initialize(IteratorContext* ctx) override {
      if (dataset()->num_shards_ == kShardHint) {
        return errors::FailedPrecondition(
            "`tf.data.Dataset.shard(SHARD_HINT, ...)` can only be used in "
            "`tf.distribute.Strategy.experimental_distribute_dataset()` with "
            "`tf.data.experimental.AutoShardPolicy.HINT` policy, or tf.data "
            "service with "
            "`tf.data.experimental.service.ShardingPolicy.HINT` processing "
            "mode.");
      }
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);

      *end_of_sequence = false;
      if (!input_impl_) {
        *end_of_sequence = true;
        return OkStatus();
      }

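      // Skip forward in the input to the next element that belongs to this
      // shard, i.e. the next element whose position is congruent to index_
      // modulo num_shards_.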
      int num_to_skip =
          (dataset()->index_ - next_index_) % dataset()->num_shards_;
      if (num_to_skip < 0) {
        num_to_skip += dataset()->num_shards_;
      }
      int num_skipped;
      TF_RETURN_IF_ERROR(
          input_impl_->Skip(ctx, num_to_skip, end_of_sequence, &num_skipped));
      next_index_ += num_skipped;
      if (*end_of_sequence) {
        input_impl_.reset();
        return OkStatus();
      }

      std::vector<Tensor> result;
      TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &result, end_of_sequence));
      if (*end_of_sequence) {
        input_impl_.reset();
        return OkStatus();
      }
      next_index_++;

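      // While handling the first round of elements, probe the rest of that
      // round to verify that the input has at least num_shards_ elements, so
      // that every shard/worker receives data.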
      if (dataset()->require_non_empty_ &&
          next_index_ < dataset()->num_shards_) {
        int num_skipped;
        Status s = input_impl_->Skip(ctx, dataset()->num_shards_ - next_index_,
                                     end_of_sequence, &num_skipped);
        if (*end_of_sequence || errors::IsOutOfRange(s)) {
          // `dataset()->require_non_empty_` implies that this transformation
          // was introduced by the auto_sharding rewrite, so it's acceptable to
          // produce an error message that assumes an auto-sharding context.
          return errors::InvalidArgument(
              "Could not apply FILE based sharding: the dataset only has ",
              next_index_, " file(s), which is not enough for the required ",
              dataset()->num_shards_,
              " shards/workers. "
              "If you are using datasets with distribution strategy, "
              "consider setting the auto sharding policy to either DATA or "
              "OFF using the `experimental_distribute.auto_shard_policy` "
              "option of `tf.data.Options()`. Or, split your input files into "
              "a larger number of small files such that the number of files "
              "is greater than the number of shards/workers.");
        } else if (!s.ok()) {
          return s;
        }

        next_index_ = dataset()->num_shards_;
      }

      *out_tensors = std::move(result);
      return OkStatus();
    }

   protected:
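    // Each output element consumes num_shards_ input elements, so the
    // iterator is modeled as a known-ratio node for autotuning.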
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->num_shards_);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      } else {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kNextIndex), next_index_));
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (!reader->Contains(full_name(kInputImplEmpty))) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name(kNextIndex), &next_index_));
      } else {
        input_impl_.reset();
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
    int64_t next_index_ TF_GUARDED_BY(mu_);
  };

  const int64_t num_shards_;
  const int64_t index_;
  const DatasetBase* const input_;
  const bool require_non_empty_;
  const TraceMeMetadata traceme_metadata_;
};

ShardDatasetOp::ShardDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kRequireNonEmpty, &require_non_empty_));
}

void ShardDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                 DatasetBase** output) {
  int64_t index = 0;
  int64_t num_shards = 0;

  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64_t>(ctx, kNumShards, &num_shards));
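  // num_shards may also be kShardHint (SHARD_HINT), a placeholder that is
  // later resolved by the auto-sharding rewrite or the tf.data service (see
  // Iterator::Initialize above).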
  OP_REQUIRES(
      ctx, num_shards > 0 || num_shards == kShardHint,
      errors::InvalidArgument("Number of shards must be greater than zero "
                              "(currently num_shards = ",
                              num_shards, ")."));

  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kIndex, &index));
  OP_REQUIRES(
      ctx, (index >= 0 && index < num_shards) || num_shards == kShardHint,
      errors::InvalidArgument("Index must be between 0 and ", num_shards - 1,
                              " (currently index = ", index, ")."));

  *output = new Dataset(ctx, num_shards, index, require_non_empty_, input);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("ShardDataset").Device(DEVICE_CPU),
                        ShardDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow