/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/reduce_dataset_op.h"

#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/data/root_dataset.h"
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
namespace data {
namespace {

const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";

}  // namespace

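// Reads the reduce function metadata and the expected output dtypes/shapes
// from the op's attributes at kernel construction time.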
ReduceDatasetOp::ReduceDatasetOp(OpKernelConstruction* ctx)
    : HybridAsyncOpKernel(ctx, "tf_data_reduce_dataset") {
  FunctionMetadata::Params params;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
                                   &params.use_inter_op_parallelism));
  params.use_default_device = false;
  OP_REQUIRES_OK(ctx,
                 FunctionMetadata::Create(ctx, "f", params, &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

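// Reduces the input dataset to a single set of output tensors: builds an
// iterator over the dataset, instantiates the captured reduce function, and
// repeatedly applies the function to the running state and the next input
// element until the input is exhausted.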
Status ReduceDatasetOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
        return profiler::TraceMeEncode("ReduceDatasetOp::DoCompute",
                                       {{"id", ctx->step_id()}});
      },
      profiler::kInfo);
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
  OpInputList inputs;
  TF_RETURN_IF_ERROR(ctx->input_list("initial_state", &inputs));
  std::vector<Tensor> state(inputs.begin(), inputs.end());

  std::unique_ptr<CapturedFunction> captured_func;
  TF_RETURN_IF_ERROR(CapturedFunction::Create(
      ctx, func_metadata_, "other_arguments", &captured_func));

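  // Build an IteratorContext with its own function handle cache, resource
  // manager, and cancellation manager scoped to this reduction.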
  IteratorContext::Params params(ctx);
  auto function_handle_cache =
      std::make_unique<FunctionHandleCache>(params.flr);
  params.function_handle_cache = function_handle_cache.get();
  ResourceMgr resource_mgr;
  params.resource_mgr = &resource_mgr;
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  params.cancellation_manager = &cancellation_manager;

  IteratorContext iter_ctx(std::move(params));
  std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
  TF_RETURN_IF_ERROR(
      captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));

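  // On CPU, create the iterator from the finalized dataset; on other devices,
  // iterate over the input dataset directly.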
  std::unique_ptr<IteratorBase> iterator;
  if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
    DatasetBase* finalized_dataset = nullptr;
    TF_RETURN_IF_ERROR(FinalizeDataset(ctx, dataset, &finalized_dataset));
    core::ScopedUnref unref(finalized_dataset);
    TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
        &iter_ctx, /*parent=*/nullptr, "ReduceIterator", &iterator));
  } else {
    TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr,
                                             "ReduceIterator", &iterator));
  }

  // Iterate through the input dataset.
  while (true) {
    if (ctx->cancellation_manager()->IsCancelled()) {
      return errors::Cancelled("Operation was cancelled");
    }
    std::vector<Tensor> next_input_element;
    bool end_of_input;
    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &next_input_element, &end_of_input));
    if (end_of_input) {
      break;
    }

    // Run the reduce function to update the current state.
    std::vector<Tensor> args;
    args.reserve(state.size() + next_input_element.size());
    std::copy(state.begin(), state.end(), std::back_inserter(args));
    std::copy(next_input_element.begin(), next_input_element.end(),
              std::back_inserter(args));

    std::vector<Tensor> reduce_func_output;
    TF_RETURN_IF_ERROR(instantiated_captured_func->Run(
        &iter_ctx, std::move(args), &reduce_func_output, /*node=*/nullptr));
    if (reduce_func_output.size() != state.size()) {
      return errors::InvalidArgument(
          "The number of components of the initial state and the reduce "
          "function output does not match. (initial_state=",
          state.size(), ", output=", reduce_func_output.size(), ").");
    }
    std::swap(reduce_func_output, state);
  }

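  // Verify that the final state matches the declared output dtypes and shapes
  // before emitting it as the op's outputs.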
  TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, state));
  TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, state));
  for (size_t i = 0; i < state.size(); ++i) {
    ctx->set_output(i, state[i]);
  }
  return OkStatus();
}

namespace {

REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
                        ReduceDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ReduceDataset");

}  // namespace
}  // namespace data
}  // namespace tensorflow