/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/core/kernels/function_ops.h"
17 
18 #include <deque>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/executor.h"
23 #include "tensorflow/core/common_runtime/function.h"
24 #include "tensorflow/core/common_runtime/gradients.h"
25 #include "tensorflow/core/common_runtime/graph_constructor.h"
26 #include "tensorflow/core/common_runtime/memory_types.h"
27 #include "tensorflow/core/framework/cancellation.h"
28 #include "tensorflow/core/framework/full_type.pb.h"
29 #include "tensorflow/core/framework/full_type_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/register_types.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/tracing.h"
35 #include "tensorflow/core/profiler/lib/traceme.h"
36 #include "tensorflow/core/util/device_name_utils.h"
37 
38 namespace tensorflow {
39 
40 static constexpr const char* const kGradientOp =
41     FunctionLibraryDefinition::kGradientOp;
42 
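// ArgOp implements the "_Arg" op: it fetches the index-th argument of the
// enclosing function call from the current call frame. When the frame allows
// it, the argument is consumed (moved out of the frame) to avoid a copy;
// otherwise the tensor is forwarded by reference. As a rough sketch (the
// exact lowering is runtime-dependent), a function signature such as
//
//   foo(x: float) -> (y: float)
//
// is instantiated as a graph of the form
//
//   _Arg(T=DT_FLOAT, index=0) -> <body> -> _Retval(T=DT_FLOAT, index=0)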
ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void ArgOp::Compute(OpKernelContext* ctx) {
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  const Tensor* val;

  auto validate_type = [this](const Tensor& val) {
    if (val.dtype() == dtype_) {
      return OkStatus();
    } else {
      return errors::InvalidArgument("Type mismatch: actual ",
                                     DataTypeString(val.dtype()),
                                     " vs. expect ", DataTypeString(dtype_));
    }
  };

  if (frame->CanConsumeArg(index_)) {
    Tensor val;
    frame->ConsumeArg(index_, &val);
    OP_REQUIRES_OK(ctx, validate_type(val));
    ctx->set_output(0, std::move(val));
  } else {
    OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
    OP_REQUIRES_OK(ctx, validate_type(*val));
    ctx->set_output(0, *val);
  }
}

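// RetvalOp implements the "_Retval" op: it checks that its input has the
// expected dtype and stores it as the index-th return value of the current
// call frame.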
RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void RetvalOp::Compute(OpKernelContext* ctx) {
  const Tensor& val = ctx->input(0);
  OP_REQUIRES(ctx, val.dtype() == dtype_,
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expect ", DataTypeString(dtype_)));
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
}

REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

// TPU ops are only registered when they are required as part of the larger
// TPU runtime, and do not need to be registered when selective registration
// is turned on.
REGISTER_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_TPU_SYSTEM), RetvalOp);

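// Register the function argument/return ops for DEVICE_DEFAULT (i.e.
// accelerator) placements. int32, tstring and ResourceHandle tensors are
// pinned to host memory via HostMemory(), following the usual TensorFlow
// convention that such values live on the host even when the op itself is
// placed on an accelerator.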
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER);
TF_CALL_QUANTIZED_TYPES(REGISTER);
TF_CALL_bool(REGISTER);

REGISTER_KERNEL_BUILDER(
    Name(kDeviceArgOp).Device(DEVICE_DEFAULT).TypeConstraint<int32>("T"),
    ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ArgOp);
#undef REGISTER

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("output")
                            .TypeConstraint<ResourceHandle>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("output")
                            .TypeConstraint<tstring>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(
    Name(kArgOp).Device(DEVICE_DEFAULT).TypeConstraint<Variant>("T"), ArgOp);

#define REGISTER(type)                                               \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(kRetOp).Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      RetvalOp);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER);
TF_CALL_QUANTIZED_TYPES(REGISTER);
TF_CALL_qint16(REGISTER);
TF_CALL_quint16(REGISTER);
REGISTER(Variant);
TF_CALL_bool(REGISTER);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .TypeConstraint<int32>("T"),
                        RetvalOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceRetOp).Device(DEVICE_DEFAULT).TypeConstraint<int32>("T"),
    RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<ResourceHandle>("T")
                            .HostMemory("input"),
                        RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<tstring>("T")
                            .HostMemory("input"),
                        RetvalOp);

#undef REGISTER

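// PassOn forwards each input tensor unchanged to the output at the same
// position, after verifying that input/output counts and dtypes match. It
// backs the "_ListToArray" and "_ArrayToList" ops, which only reinterpret a
// list of tensors as an array (and vice versa) without touching the data.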
class PassOn : public OpKernel {
 public:
  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
                errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
                                 " vs. ", ctx->num_outputs()));
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      OP_REQUIRES(
          ctx, input_type(i) == output_type(i),
          errors::Internal("Input and output types for position ", i,
                           " do not match: ", DataTypeString(input_type(i)),
                           " vs. ", DataTypeString(output_type(i))));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      ctx->set_output(i, ctx->input(i));
    }
  }
};

REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);

#define REGISTER_DEFAULT_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("_ListToArray").Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      PassOn);                                                               \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("_ArrayToList").Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_DEFAULT_KERNELS(Eigen::half);
REGISTER_DEFAULT_KERNELS(float);
REGISTER_DEFAULT_KERNELS(double);

#undef REGISTER_DEFAULT_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);

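// SymbolicGradientOp computes the gradient function of the function named in
// its attributes: it instantiates kGradientOp through the
// FunctionLibraryRuntime and runs the result asynchronously, passing this
// kernel's inputs as arguments and emitting the returned tensors as outputs.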
class SymbolicGradientOp : public AsyncOpKernel {
 public:
  explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  ~SymbolicGradientOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(
        ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done);

    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.collective_executor = ctx->collective_executor();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.stats_collector = ctx->stats_collector();
    opts.step_container = ctx->step_container();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
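    // `rets` is heap-allocated because lib->Run() completes asynchronously;
    // the completion callback below takes ownership and deletes it.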
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    profiler::TraceMe trace_me("SymbolicGradientOp");
    lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
      if (!status.ok()) {
        ctx->SetStatus(status);
      } else if (rets->size() != ctx->num_outputs()) {
        ctx->SetStatus(errors::InvalidArgument(
            "SymGrad expects to return ", ctx->num_outputs(),
            " tensor(s), but got ", rets->size(), " tensor(s) instead."));
      } else {
        for (size_t i = 0; i < rets->size(); ++i) {
          ctx->set_output(i, std::move((*rets)[i]));
        }
      }
      delete rets;
      done();
    });
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
};

REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
                        SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_DEFAULT),
                        SymbolicGradientOp);

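// RemoteCallOp invokes a named function on a target device, which may differ
// from the device this kernel is placed on. Instantiated handles are cached
// per (canonicalized target device, FunctionLibraryRuntime) pair, so repeated
// calls to the same target reuse the same instantiation.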
RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx)
    : AsyncOpKernel(ctx), return_type_(ctx->def().experimental_type()) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
}

void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  FunctionLibraryRuntime* lib = ctx->function_library();
  OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                    errors::Internal("No function library is provided."), done);

  const string& source_device = lib->device()->name();
  const Tensor* target;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);

  FunctionTarget function_target;
  OP_REQUIRES_OK_ASYNC(
      ctx,
      DeviceNameUtils::CanonicalizeDeviceName(
          target->scalar<tstring>()(), source_device, &function_target.first),
      done);
  function_target.second = lib;

  const string& target_device = function_target.first;
  const string& func_name = func_.name();

  FunctionLibraryRuntime::Handle handle;
  {
    mutex_lock l(mu_);
    auto cached_entry = handle_cache_.find(function_target);
    if (cached_entry != handle_cache_.end()) {
      handle = cached_entry->second;
    } else {
      VLOG(1) << "Instantiating " << func_name << " on " << target_device;
      profiler::TraceMe activity(
          [&] {
            return strings::StrCat("RemoteCall: Instantiate: ", func_name,
                                   " on ", target_device);
          },
          profiler::TraceMeLevel::kInfo);
      FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
      const auto* config = (ctx->function_library())
                               ? ctx->function_library()->config_proto()
                               : nullptr;
      if (config) {
        instantiate_opts.config_proto = *config;
      }
      instantiate_opts.target = target_device;
      OP_REQUIRES_OK_ASYNC(ctx,
                           lib->Instantiate(func_name, AttrSlice(&func_.attr()),
                                            instantiate_opts, &handle),
                           done);
      auto insert_result = handle_cache_.insert({function_target, handle});
      CHECK(insert_result.second) << "Insert unsuccessful.";
      VLOG(1) << "Instantiated " << func_name << " on " << target_device
              << ", resulting in handle: " << handle << " flr: " << lib;
    }
  }

  OpInputList arguments;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);

  FunctionLibraryRuntime::Options opts;
  opts.runner = nullptr;  // Use default runner at remote device.
  opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts.source_device = source_device;
  if (opts.source_device != target_device) {
    opts.remote_execution = true;
  }
  opts.create_rendezvous = true;
  CancellationManager* cancel_mgr = nullptr;
  if (ctx->cancellation_manager() != nullptr) {
    cancel_mgr = new CancellationManager(ctx->cancellation_manager());
  }
  opts.cancellation_manager = cancel_mgr;
  opts.collective_executor = ctx->collective_executor();
  std::vector<Tensor> args(arguments.begin(), arguments.end());
  opts.args_alloc_attrs.reserve(input_dtypes_.size());
  for (const auto& dtype : input_dtypes_) {
    AllocatorAttributes arg_alloc_attrs;
    arg_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.args_alloc_attrs.push_back(arg_alloc_attrs);
  }
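  // Decide where each return value should be allocated: in host memory if its
  // dtype always lives on the host, or if the op's full type annotation (when
  // present) marks that output as a host-memory type.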
  opts.rets_alloc_attrs.reserve(output_dtypes_.size());
  DCHECK(!return_type_.IsInitialized() ||
         (return_type_.type_id() == TFT_UNSET) ||
         (output_dtypes_.size() == return_type_.args_size()))
      << "RemoteCall op has full type information for "
      << return_type_.args_size() << " outputs but the number of outputs is "
      << output_dtypes_.size();
  for (const auto& dtype : output_dtypes_) {
    AllocatorAttributes ret_alloc_attrs;
    bool on_host = DataTypeAlwaysOnHost(dtype);
    if (return_type_.IsInitialized() && (return_type_.type_id() != TFT_UNSET)) {
      DCHECK(return_type_.type_id() == TFT_PRODUCT)
          << return_type_.DebugString();
      FullTypeDef ftd = full_type::GetArgDefaultUnset(
          return_type_, opts.rets_alloc_attrs.size());
      if (full_type::IsHostMemoryType(ftd)) {
        on_host = true;
      }
      VLOG(5) << "FulltypeDef for RemoteCall output="
              << opts.rets_alloc_attrs.size()
              << ", IsHostMemoryType=" << full_type::IsHostMemoryType(ftd)
              << ":\n"
              << ftd.DebugString();
    }
    ret_alloc_attrs.set_on_host(on_host);
    opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
  }
  auto* rets = new std::vector<Tensor>;
  VLOG(1) << "Running " << func_name << " on " << target_device
          << " with handle: " << handle;
  profiler::TraceMe trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "RemoteCallOp",
            {{"func_name", func_name}, {"device", target_device}});
      },
      profiler::TraceMeLevel::kInfo);
  lib->Run(
      opts, handle, args, rets,
      [rets, done = std::move(done), func_name, ctx, cancel_mgr,
       target_device = std::move(function_target.first)](const Status& status) {
        profiler::TraceMe activity(
            [&] {
              return profiler::TraceMeEncode(
                  "RemoteCallOpDone",
                  {{"func_name", func_name}, {"device", target_device}});
            },
            profiler::TraceMeLevel::kInfo);
        if (!status.ok()) {
          ctx->SetStatus(status);
        } else {
          for (size_t i = 0; i < rets->size(); ++i) {
            ctx->set_output(i, std::move((*rets)[i]));
          }
        }
        delete cancel_mgr;
        delete rets;
        done();
      });
}

string RemoteCallOp::TraceString(const OpKernelContext& ctx,
                                 bool verbose) const {
  string trace_string = profiler::TraceMeOp(
      strings::StrCat(name_view(), "__", func_.name()), type_string_view());
  if (verbose) {
    string shape = ShapeTraceString(ctx);
    if (!shape.empty()) {
      trace_string =
          profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
    }
  }
  return trace_string;
}

REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_DEFAULT).HostMemory("target"),
    RemoteCallOp);
}  // namespace tensorflow