/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/reduce_dataset_op.h"

#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/data/root_dataset.h"
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
namespace data {
namespace {

const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";

}  // namespace

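// Reads the reduce function metadata and the expected output dtypes/shapes
// from the op's attributes at kernel construction time.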
ReduceDatasetOp::ReduceDatasetOp(OpKernelConstruction* ctx)
    : HybridAsyncOpKernel(ctx, "tf_data_reduce_dataset") {
  FunctionMetadata::Params params;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
                                   &params.use_inter_op_parallelism));
  params.use_default_device = false;
  OP_REQUIRES_OK(ctx,
                 FunctionMetadata::Create(ctx, "f", params, &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

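// Drives the reduction: pulls elements from the input dataset one at a time,
// folds each element into the running state with the captured reduce
// function, and emits the final state as the op's outputs.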
Status ReduceDatasetOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
        return profiler::TraceMeEncode("ReduceDatasetOp::DoCompute",
                                       {{"id", ctx->step_id()}});
      },
      profiler::kInfo);
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
  OpInputList inputs;
  TF_RETURN_IF_ERROR(ctx->input_list("initial_state", &inputs));
  std::vector<Tensor> state(inputs.begin(), inputs.end());

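  // Capture the reduce function together with its "other_arguments" inputs.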
  std::unique_ptr<CapturedFunction> captured_func;
  TF_RETURN_IF_ERROR(CapturedFunction::Create(
      ctx, func_metadata_, "other_arguments", &captured_func));

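  // Build an IteratorContext whose function handle cache, resource manager,
  // and cancellation manager are scoped to this single op invocation.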
  IteratorContext::Params params(ctx);
  auto function_handle_cache =
      std::make_unique<FunctionHandleCache>(params.flr);
  params.function_handle_cache = function_handle_cache.get();
  ResourceMgr resource_mgr;
  params.resource_mgr = &resource_mgr;
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  params.cancellation_manager = &cancellation_manager;

  IteratorContext iter_ctx(std::move(params));
  std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
  TF_RETURN_IF_ERROR(
      captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));

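  // On CPU, finalize the dataset before creating the iterator; on other
  // devices, iterate the input dataset directly.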
  std::unique_ptr<IteratorBase> iterator;
  if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
    DatasetBase* finalized_dataset = nullptr;
    TF_RETURN_IF_ERROR(FinalizeDataset(ctx, dataset, &finalized_dataset));
    core::ScopedUnref unref(finalized_dataset);
    TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
        &iter_ctx, /*parent=*/nullptr, "ReduceIterator", &iterator));
  } else {
    TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr,
                                             "ReduceIterator", &iterator));
  }

  // Iterate through the input dataset.
  while (true) {
    if (ctx->cancellation_manager()->IsCancelled()) {
      return errors::Cancelled("Operation was cancelled");
    }
    std::vector<Tensor> next_input_element;
    bool end_of_input;
    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &next_input_element, &end_of_input));
    if (end_of_input) {
      break;
    }

    // Run the reduce function to update the current state.
    std::vector<Tensor> args;
    args.reserve(state.size() + next_input_element.size());
    std::copy(state.begin(), state.end(), std::back_inserter(args));
    std::copy(next_input_element.begin(), next_input_element.end(),
              std::back_inserter(args));

    std::vector<Tensor> reduce_func_output;
    TF_RETURN_IF_ERROR(instantiated_captured_func->Run(
        &iter_ctx, std::move(args), &reduce_func_output, /*node=*/nullptr));
    if (reduce_func_output.size() != state.size()) {
      return errors::InvalidArgument(
          "The number of components of the initial state and the "
          "reduce function output does not match. (initial_state=",
          state.size(), ", output=", reduce_func_output.size(), ").");
    }
    std::swap(reduce_func_output, state);
  }

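  // Check the final state against the declared output dtypes and shapes, then
  // forward each state component as an op output.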
  TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, state));
  TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, state));
  for (size_t i = 0; i < state.size(); ++i) {
    ctx->set_output(i, state[i]);
  }
  return OkStatus();
}

namespace {

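// Register the kernel on CPU and exempt ReduceDataset from input colocation
// constraints.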
REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
                        ReduceDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ReduceDataset");

}  // namespace
}  // namespace data
}  // namespace tensorflow