1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/finalize_dataset_op.h"
16
17 #include "tensorflow/core/data/dataset_utils.h"
18 #include "tensorflow/core/data/name_utils.h"
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/model.h"
21 #include "tensorflow/core/framework/partial_tensor_shape.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h"
24 #include "tensorflow/core/kernels/data/model_dataset_op.h"
25 #include "tensorflow/core/kernels/data/optimize_dataset_op.h"
26
27 namespace tensorflow {
28 namespace data {
29
// Out-of-class definitions of FinalizeDatasetOp's static constexpr string
// constants (presumably declared in finalize_dataset_op.h, included above).
// Before C++17 these definitions were required whenever a member was
// ODR-used. Since C++17, static constexpr data members are implicitly inline,
// so these lines are redundant but harmless.
/* static */ constexpr const char* const FinalizeDatasetOp::kDatasetType;
/* static */ constexpr const char* const FinalizeDatasetOp::kInputDataset;
/* static */ constexpr const char* const FinalizeDatasetOp::kOutputTypes;
/* static */ constexpr const char* const FinalizeDatasetOp::kOutputShapes;
/* static */ constexpr const char* const FinalizeDatasetOp::kHasCapturedRef;
35
36 namespace {
37
GetModelDatasetParams(const Options & options,model::AutotuneAlgorithm * algorithm,int64_t * cpu_budget,int64_t * ram_budget)38 void GetModelDatasetParams(const Options& options,
39 model::AutotuneAlgorithm* algorithm,
40 int64_t* cpu_budget, int64_t* ram_budget) {
41 *algorithm = model::AutotuneAlgorithm::HILL_CLIMB;
42 *cpu_budget = options.autotune_options().cpu_budget();
43 *ram_budget = options.autotune_options().ram_budget();
44 }
45
MakeDatasetHelper(OpKernelContext * ctx,bool has_captured_ref,DatasetBase * input,DatasetBase ** output)46 void MakeDatasetHelper(OpKernelContext* ctx, bool has_captured_ref,
47 DatasetBase* input, DatasetBase** output) {
48 *output = input;
49 input->Ref();
50 const Options& options = input->options();
51 if (ShouldConfigureMaxIntraOpParallelism(options)) {
52 experimental::MaxIntraOpParallelismDatasetOp::MakeDatasetFromOptions(
53 ctx, input, options.threading_options().max_intra_op_parallelism(),
54 output);
55 input->Unref();
56 input = *output;
57 }
58 if (ShouldUsePrivateThreadPool(options)) {
59 experimental::PrivateThreadPoolDatasetOp::MakeDatasetFromOptions(
60 ctx, input, options.threading_options().private_threadpool_size(),
61 output);
62 input->Unref();
63 input = *output;
64 }
65 if (ShouldUseAutotuning(options)) {
66 model::AutotuneAlgorithm algorithm;
67 int64_t cpu_budget;
68 int64_t ram_budget;
69 GetModelDatasetParams(options, &algorithm, &cpu_budget, &ram_budget);
70 ModelDatasetOp::MakeDatasetFromOptions(ctx, input, algorithm, cpu_budget,
71 ram_budget, output);
72 input->Unref();
73 input = *output;
74 }
75 absl::flat_hash_set<tstring> optimizations_enabled;
76 absl::flat_hash_set<tstring> optimizations_disabled;
77 absl::flat_hash_set<tstring> optimizations_default;
78 GetOptimizations(options, &optimizations_enabled, &optimizations_disabled,
79 &optimizations_default);
80 if (ShouldApplyOptimizations(options, optimizations_enabled,
81 optimizations_default)) {
82 if (has_captured_ref &&
83 (!optimizations_enabled.empty() || !optimizations_default.empty())) {
84 LOG(WARNING)
85 << "tf.data graph rewrites are not compatible with reference "
86 "variables. The following rewrites will be disabled: "
87 << absl::StrJoin(optimizations_enabled, ", ") << ", "
88 << absl::StrJoin(optimizations_default, ", ") << ". "
89 << "To enable rewrites, use resource variables instead by calling "
90 "`tf.enable_resource_variables()` at the start of the program.";
91 } else {
92 auto optimization_configs = CreateGraphRewriteConfigs(options);
93 OptimizeDatasetOp::MakeDatasetFromOptions(
94 ctx, input, optimizations_enabled, optimizations_disabled,
95 optimizations_default, optimization_configs, output);
96 input->Unref();
97 input = *output;
98 }
99 }
100 }
101
102 } // namespace
103
FinalizeDatasetOp(OpKernelConstruction * ctx)104 FinalizeDatasetOp::FinalizeDatasetOp(OpKernelConstruction* ctx)
105 : UnaryDatasetOpKernel(ctx) {
106 if (ctx->HasAttr(kHasCapturedRef)) {
107 OP_REQUIRES_OK(ctx, ctx->GetAttr(kHasCapturedRef, &has_captured_ref_));
108 } else {
109 has_captured_ref_ = false;
110 }
111 }
112
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)113 void FinalizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
114 DatasetBase** output) {
115 MakeDatasetHelper(ctx, has_captured_ref_, input, output);
116 }
117
namespace {
// CPU kernel: performs the finalization implemented by FinalizeDatasetOp.
// Its Priority(2) is higher than the GPU registration's Priority(1).
// NOTE(review): presumably this makes kernel selection prefer the CPU kernel
// when both are eligible. Confirm against the kernel-priority semantics.
REGISTER_KERNEL_BUILDER(Name("FinalizeDataset").Device(DEVICE_CPU).Priority(2),
                        FinalizeDatasetOp);
// GPU registration: uses FinalizeDatasetNoopOp (defined elsewhere) and pins
// both the `input_dataset` and `handle` arguments to host memory.
// NOTE(review): the no-op kernel presumably forwards the input dataset
// unchanged. Verify against its definition.
REGISTER_KERNEL_BUILDER(Name("FinalizeDataset")
                            .Device(DEVICE_GPU)
                            .HostMemory("input_dataset")
                            .HostMemory("handle")
                            .Priority(1),
                        FinalizeDatasetNoopOp);
}  // namespace
128 } // namespace data
129 } // namespace tensorflow
130