xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/finalize_dataset_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/finalize_dataset_op.h"
16 
17 #include "tensorflow/core/data/dataset_utils.h"
18 #include "tensorflow/core/data/name_utils.h"
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/model.h"
21 #include "tensorflow/core/framework/partial_tensor_shape.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h"
24 #include "tensorflow/core/kernels/data/model_dataset_op.h"
25 #include "tensorflow/core/kernels/data/optimize_dataset_op.h"
26 
27 namespace tensorflow {
28 namespace data {
29 
30 /* static */ constexpr const char* const FinalizeDatasetOp::kDatasetType;
31 /* static */ constexpr const char* const FinalizeDatasetOp::kInputDataset;
32 /* static */ constexpr const char* const FinalizeDatasetOp::kOutputTypes;
33 /* static */ constexpr const char* const FinalizeDatasetOp::kOutputShapes;
34 /* static */ constexpr const char* const FinalizeDatasetOp::kHasCapturedRef;
35 
36 namespace {
37 
GetModelDatasetParams(const Options & options,model::AutotuneAlgorithm * algorithm,int64_t * cpu_budget,int64_t * ram_budget)38 void GetModelDatasetParams(const Options& options,
39                            model::AutotuneAlgorithm* algorithm,
40                            int64_t* cpu_budget, int64_t* ram_budget) {
41   *algorithm = model::AutotuneAlgorithm::HILL_CLIMB;
42   *cpu_budget = options.autotune_options().cpu_budget();
43   *ram_budget = options.autotune_options().ram_budget();
44 }
45 
MakeDatasetHelper(OpKernelContext * ctx,bool has_captured_ref,DatasetBase * input,DatasetBase ** output)46 void MakeDatasetHelper(OpKernelContext* ctx, bool has_captured_ref,
47                        DatasetBase* input, DatasetBase** output) {
48   *output = input;
49   input->Ref();
50   const Options& options = input->options();
51   if (ShouldConfigureMaxIntraOpParallelism(options)) {
52     experimental::MaxIntraOpParallelismDatasetOp::MakeDatasetFromOptions(
53         ctx, input, options.threading_options().max_intra_op_parallelism(),
54         output);
55     input->Unref();
56     input = *output;
57   }
58   if (ShouldUsePrivateThreadPool(options)) {
59     experimental::PrivateThreadPoolDatasetOp::MakeDatasetFromOptions(
60         ctx, input, options.threading_options().private_threadpool_size(),
61         output);
62     input->Unref();
63     input = *output;
64   }
65   if (ShouldUseAutotuning(options)) {
66     model::AutotuneAlgorithm algorithm;
67     int64_t cpu_budget;
68     int64_t ram_budget;
69     GetModelDatasetParams(options, &algorithm, &cpu_budget, &ram_budget);
70     ModelDatasetOp::MakeDatasetFromOptions(ctx, input, algorithm, cpu_budget,
71                                            ram_budget, output);
72     input->Unref();
73     input = *output;
74   }
75   absl::flat_hash_set<tstring> optimizations_enabled;
76   absl::flat_hash_set<tstring> optimizations_disabled;
77   absl::flat_hash_set<tstring> optimizations_default;
78   GetOptimizations(options, &optimizations_enabled, &optimizations_disabled,
79                    &optimizations_default);
80   if (ShouldApplyOptimizations(options, optimizations_enabled,
81                                optimizations_default)) {
82     if (has_captured_ref &&
83         (!optimizations_enabled.empty() || !optimizations_default.empty())) {
84       LOG(WARNING)
85           << "tf.data graph rewrites are not compatible with reference "
86              "variables. The following rewrites will be disabled: "
87           << absl::StrJoin(optimizations_enabled, ", ") << ", "
88           << absl::StrJoin(optimizations_default, ", ") << ". "
89           << "To enable rewrites, use resource variables instead by calling "
90              "`tf.enable_resource_variables()` at the start of the program.";
91     } else {
92       auto optimization_configs = CreateGraphRewriteConfigs(options);
93       OptimizeDatasetOp::MakeDatasetFromOptions(
94           ctx, input, optimizations_enabled, optimizations_disabled,
95           optimizations_default, optimization_configs, output);
96       input->Unref();
97       input = *output;
98     }
99   }
100 }
101 
102 }  // namespace
103 
FinalizeDatasetOp(OpKernelConstruction * ctx)104 FinalizeDatasetOp::FinalizeDatasetOp(OpKernelConstruction* ctx)
105     : UnaryDatasetOpKernel(ctx) {
106   if (ctx->HasAttr(kHasCapturedRef)) {
107     OP_REQUIRES_OK(ctx, ctx->GetAttr(kHasCapturedRef, &has_captured_ref_));
108   } else {
109     has_captured_ref_ = false;
110   }
111 }
112 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)113 void FinalizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
114                                     DatasetBase** output) {
115   MakeDatasetHelper(ctx, has_captured_ref_, input, output);
116 }
117 
118 namespace {
119 REGISTER_KERNEL_BUILDER(Name("FinalizeDataset").Device(DEVICE_CPU).Priority(2),
120                         FinalizeDatasetOp);
121 REGISTER_KERNEL_BUILDER(Name("FinalizeDataset")
122                             .Device(DEVICE_GPU)
123                             .HostMemory("input_dataset")
124                             .HostMemory("handle")
125                             .Priority(1),
126                         FinalizeDatasetNoopOp);
127 }  // namespace
128 }  // namespace data
129 }  // namespace tensorflow
130