1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <memory>
17 #include <vector>
18 
19 #include "absl/memory/memory.h"
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
24 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
25 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
26 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
27 #include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
28 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
29 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
30 #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
31 #include "tensorflow/core/common_runtime/function.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/common_runtime/graph_optimizer.h"
34 #include "tensorflow/core/framework/function.h"
35 #include "tensorflow/core/framework/graph_to_functiondef.h"
36 #include "tensorflow/core/framework/node_def_builder.h"
37 #include "tensorflow/core/framework/op.h"
38 #include "tensorflow/core/framework/op_kernel.h"
39 #include "tensorflow/core/graph/algorithm.h"
40 #include "tensorflow/core/grappler/clusters/utils.h"
41 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
42 #include "tensorflow/core/lib/core/refcount.h"
43 #include "tensorflow/core/lib/strings/str_util.h"
44 #include "tensorflow/core/lib/strings/strcat.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/mutex.h"
47 #include "tensorflow/core/platform/stream_executor.h"
48 #include "tensorflow/core/platform/thread_annotations.h"
49 #include "tensorflow/core/platform/types.h"
50 #include "tensorflow/core/profiler/lib/traceme.h"
51 #include "tensorflow/core/util/env_var.h"
52 
53 #if GOOGLE_CUDA && GOOGLE_TENSORRT
54 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
55 #include "third_party/tensorrt/NvInfer.h"
56 
57 namespace tensorflow {
58 namespace tensorrt {
59 namespace {
60 Logger& logger = *Logger::GetLogger();
61 using absl::StrAppend;
62 using absl::StrCat;
63 using ::nvinfer1::IRuntime;
64 
65 #define LOG_FIRST_FEW_WARNING_WITH_PREFIX \
66   LOG_FIRST_N(WARNING, 5) << "TF-TRT Warning: "
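// Illustrative usage: callers stream message text onto the macro, e.g.
//   LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Engine creation failed.";
// Since the macro expands to LOG_FIRST_N(WARNING, 5), at most five such
// warnings are emitted per call site, each prefixed with "TF-TRT Warning: ".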
67 
68 // Allocates device memory for an execution context to execute a TensorRT
69 // engine and records the relevant information for deallocating the memory when
70 // the engine finishes execution.
71 class ContextDeviceMemory {
72  public:
73   ContextDeviceMemory()
74       : execution_context_(nullptr),
75         device_memory_allocator_(nullptr),
76         device_memory_(nullptr) {}
77 
78   ~ContextDeviceMemory() {
79     if (device_memory_) {
80       device_memory_allocator_->free(device_memory_);
81     }
82   }
83 
84   Status AllocateDeviceMemory(nvinfer1::IExecutionContext* execution_context,
85                               TRTBaseAllocator* device_memory_allocator,
86                               size_t device_memory_size) {
87     execution_context_ = execution_context;
88     device_memory_allocator_ = device_memory_allocator;
89     device_memory_ = nullptr;
90     VLOG(2) << "Device memory size for TensorRT engine " << device_memory_size;
91     if (device_memory_size > 0) {
92       device_memory_ = device_memory_allocator_->allocate(
93           device_memory_size,
94           /*unused alignment=*/0, /*flags=*/0);
95       if (device_memory_ == nullptr) {
96         return errors::InvalidArgument(
97             "Out of GPU memory for execution context");
98       }
99     }
100     {
101       tensorflow::profiler::TraceMe activity(
102           "setDeviceMemory", tensorflow::profiler::TraceMeLevel::kInfo);
103       execution_context_->setDeviceMemory(device_memory_);
104     }
105     return Status::OK();
106   }
107 
108  private:
109   nvinfer1::IExecutionContext* execution_context_;
110   TRTBaseAllocator* device_memory_allocator_;
111   void* device_memory_;
112 };
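// A minimal usage sketch, mirroring ExecuteTrtEngine below; the device memory
// is returned to the allocator when the ContextDeviceMemory object is
// destroyed:
//   ContextDeviceMemory context_device_memory;
//   TF_RETURN_IF_ERROR(context_device_memory.AllocateDeviceMemory(
//       execution_context, allocator, engine_context->GetDeviceMemorySize()));
//   // ... enqueue the engine; the memory is freed when the object goes out of
//   // scope.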
113 
114 // Macros for asynchronous execution, such as OP_REQUIRES_OK_ASYNC, require an
115 // object with operator(). This provides such an object with a no-op
116 // operator() for the cases where we don't need those macros to invoke the
117 // DoneCallback of the TRTEngineOp.
118 struct DummyAsyncHelper {
119   void operator()() {}
120 };
121 
122 // A helper class that calls the DoneCallback of the TRTEngineOp when the
123 // object is destructed, to support asynchronous execution of the native
124 // segment and the TRT engines of the TRTEngineOp.
125 class AsyncHelper : public core::RefCounted {
126  public:
127   AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {}
128 
129   ~AsyncHelper() override { done_(); }
130 
131  private:
132   AsyncOpKernel::DoneCallback done_;
133 };
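// Reference-counting sketch, following the pattern used in ComputeAsync and
// ExecuteNativeSegment below:
//   auto async_helper = new AsyncHelper(done);
//   core::ScopedUnref sc(async_helper);  // drops the initial reference
//   async_helper->Ref();                 // keep it alive across an async call
//   // ... the async callback holds a core::ScopedUnref of its own; the
//   // DoneCallback fires once the last reference is released.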
134 
135 }  // end anonymous namespace
136 
137 //  This op can construct a TRT engine on the fly; if engine construction
138 //  fails, it executes the equivalent subgraph as a native TensorFlow function.
139 class TRTEngineOp : public AsyncOpKernel {
140  public:
141   explicit TRTEngineOp(OpKernelConstruction* context);
142 
143   void ComputeAsync(OpKernelContext* context,
144                     AsyncOpKernel::DoneCallback done) override;
145 
146  private:
147   // Executes calibration asynchronously.
148   void ExecuteCalibration(OpKernelContext* ctx,
149                           TRTEngineCacheResource* cache_res,
150                           AsyncHelper* async_helper);
151 
152   // Constructs a function handle for the segment of the TRTEngineOp.
153   StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle(
154       FunctionLibraryRuntime* lib, const string& device_name,
155       bool allow_soft_placement = false, size_t num_inputs = 0,
156       size_t num_outputs = 0);
157 
158   // Imports the GraphDef for the segment of the TRTEngineOp to
159   // segment_graph_def_.
160   Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
161                                const string& device_name);
162 
163   // Executes the native segment as a function Op asynchronously.
164   void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* async_helper);
165 
166   // Allocates the device memory for the execution context and enqueues the
167   // TensorRT engine for execution. Also deallocates the device memory. A
168   // non-OK status indicates that the caller should retry with the native segment.
169   Status ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context,
170                           int trt_context_idx,
171                           const TrtShapeOptimizationProfile& profiles,
172                           TRTBaseAllocator* allocator);
173 
174   // Allocates necessary resources for calibration.
175   Status AllocateCalibrationResources(OpKernelContext* ctx,
176                                       TRTEngineCacheResource* cache_res);
177 
178   Status GetEngineCacheResource(OpKernelContext* ctx,
179                                 TRTEngineCacheResource** cache_res);
180 
181   // Returns a pair of 1) An EngineContext object that is compatible with the
182   // input and 2) The index of the IExecutionContext compatible with the input.
183   // If a cuda engine for the given input shapes can't be found, returns
184   // (nullptr, 0) to allow native segment execution. Returns an error code for
185   // any problem that would prevent both TensorRT engine execution and native
186   // segment execution.
187   StatusOr<std::pair<EngineContext*, int>> GetEngine(
188       const std::vector<TensorShape>& input_concrete_shapes,
189       OpKernelContext* ctx, TRTEngineCacheResource* cache_resource);
190 
191   // Builds and returns a cuda engine for the input shapes. If building the
192   // engine fails, enters a dummy entry into the cache_resource cache so we
193   // don't continually try to build the same failing engine.
194   StatusOr<TrtUniquePtrType<nvinfer1::ICudaEngine>> BuildEngine(
195       const std::vector<TensorShape>& input_concrete_shapes, int batch_size,
196       bool use_calibration, TRTInt8Calibrator* calibrator,
197       TRTEngineCacheResource* cache_resource, OpKernelContext* ctx);
198 
199   // Verify that the input shapes are consistent and can be handled by this op.
200   Status VerifyInputShapes(const std::vector<TensorShape>& shapes);
201 
202   std::vector<string> input_nodes_;
203   std::vector<string> output_nodes_;
204 
205   // Serialized protobuf segment or TRT engine, depending on static_engine_.
206   string serialized_segment_;
207 
208   // The function for TF native execution of the segment.
209   NameAttrList func_;
210 
211   // GraphDef representation of the segment.
212   GraphDef segment_graph_def_;
213 
214   // Engine Precision mode.
215   TrtPrecisionMode precision_mode_;
216 
217   // Whether engine is constructed during the conversion or needs to be
218   // constructed from protobuf segment.
219   bool static_engine_;
220 
221   // Whether to calibrate INT8 engine.
222   bool calibration_mode_;
223 
224   // Whether to use implicit batch dimension for TensorRT.
225   bool use_implicit_batch_;
226 
227   // Whether to collect optimization profiles for TensorRT, only used when
228   // use_implicit_batch_=false.
229   bool profile_generation_mode_;
230 
231   // Optimization profile generation strategy.
232   ProfileStrategy profile_strategy_;
233 
234   // Whether the TRTEngineOp has any input with unknown dimensions.
235   bool has_dynamic_shape_input_;
236 
237   // Whether to build TensorRT engines at runtime.
238   bool allow_build_at_runtime_;
239 
240   // Whether to allow soft placement when the graph is executed with native
241   // TensorFlow.
242   bool allow_soft_placement_;
243 
244   // Maximum number of cached engines.
245   int max_cached_engines_;
246 
247   int64 workspace_size_;
248   mutex engine_mutex_;
249   FunctionLibraryRuntime::Handle native_execution_func_handle_;
250 
251   // The finalized calibrator for inference.
252   std::unique_ptr<TRTInt8Calibrator> calibrator_;
253 
254   // If true, create calibration graph for INT8 mode. Otherwise, we are using
255   // user-provided quantization ranges.
256   bool use_calibration_;
257 
258   tensorflow::grappler::Cluster* cluster_;
259 
260   // Array of all input shapes, collected from the input_shapes attribute when
261   // constructing the TRTEngineOp. The input_shapes attribute is set during
262   // graph conversion time. This data is used to determine which input
263   // dimensions may be unknown. At inference time this information is not
264   // available otherwise (all shapes are concrete by the time we run inference).
265   std::vector<PartialTensorShape> input_partial_shapes_;
266   // Shapes, excluding resource inputs.
267   std::vector<PartialTensorShape> input_partial_shapes_filtered_;
268 
269   // The TF node can have more inputs than the TRT engine: resource inputs are
270   // saved as weights in the engine instead of being passed as engine inputs.
271   // The input mask is true for those TF inputs that are TRT engine inputs.
272   std::vector<bool> input_mask_;
273 
274   // Whether to use explicit precision (QDQ) mode.
275   bool use_explicit_precision_;
276 };
277 
278 #define TYPECASE(dt, X, Y)                                    \
279   case dt: {                                                  \
280     return (void*)X->flat<EnumToDataType<dt>::Type>().data(); \
281   }
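// For reference, TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr) expands roughly to:
//   case DT_FLOAT: {
//     return (void*)tensor_ptr->flat<EnumToDataType<DT_FLOAT>::Type>().data();
//   }
// The third macro argument is unused and kept only for call-site symmetry.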
282 
283 void* GetTensorAddress(const Tensor* tensor_ptr) {
284   auto tensor_type = tensor_ptr->dtype();
285   switch (tensor_type) {
286     TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr);
287     TYPECASE(DT_HALF, tensor_ptr, dest_ptr);
288     TYPECASE(DT_INT8, tensor_ptr, dest_ptr);
289     TYPECASE(DT_INT32, tensor_ptr, dest_ptr);
290     default: {
291       LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
292       return nullptr;
293     }
294   }
295 }
296 
297 static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
298                                     FunctionLibraryRuntime* flib_runtime,
299                                     GraphDef* graph_def) {
300   const FunctionLibraryDefinition* flib_def =
301       flib_runtime->GetFunctionLibraryDefinition();
302   const FunctionBody* fbody;
303   fbody = flib_runtime->GetFunctionBody(handle);
304   if (!fbody) {
305     return errors::Internal(
306         "Function body is null when converting from FuncDef to GraphDef.");
307   }
308   std::unique_ptr<Graph> graph(new Graph(flib_def));
309   CopyGraph(*fbody->graph, graph.get());
310 
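  // replace_name restores the original placeholder prefix (e.g.
  // IONamePrefixes::kInputPHName) when a name starts with its lowercased form,
  // and reports whether a replacement was made.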
311   auto replace_name = [](const char* const prefix, string* name) {
312     if (absl::StartsWith(*name, absl::AsciiStrToLower(prefix))) {
313       name->replace(0, strlen(prefix), prefix);
314       return true;
315     }
316     return false;
317   };
318   graph->ToGraphDef(graph_def);
319   // GraphToFunctionDef() will convert all the node names to lowercase.
320   for (auto& node : *graph_def->mutable_node()) {
321     if (!replace_name(IONamePrefixes::kInputPHName, node.mutable_name())) {
322       if (replace_name(IONamePrefixes::kOutputPHName, node.mutable_name())) {
323         // Instantiation of the function will append _RetVal to the node name,
324         // need to remove it for backward compatibility.
325         const char* const suffix_to_remove = "_RetVal";
326         if (absl::EndsWith(node.name(), suffix_to_remove)) {
327           node.mutable_name()->erase(node.name().size() -
328                                      strlen(suffix_to_remove));
329         }
330       }
331     }
332     for (auto& input : *node.mutable_input()) {
333       if (!replace_name(IONamePrefixes::kInputPHName, &input)) {
334         replace_name(IONamePrefixes::kOutputPHName, &input);
335       }
336     }
337   }
338   return Status::OK();
339 }
340 
341 StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle(
342     FunctionLibraryRuntime* lib, const string& device_name,
343     bool allow_soft_placement, size_t num_inputs, size_t num_outputs) {
344   VLOG(1) << "Constructing function handle";
345   if (lib == nullptr) {
346     return errors::Internal("Context function library is null");
347   }
348   FunctionLibraryRuntime::InstantiateOptions inst_ops;
349   inst_ops.state_handle = "";
350   inst_ops.target = device_name;
351   if (allow_soft_placement) {
352     const FunctionDef* fdef =
353         lib->GetFunctionLibraryDefinition()->Find(func_.name());
354     if (!fdef) {
355       return errors::Internal(
356           StrCat("Can't find FunctionDef for ", func_.name()));
357     }
358     bool ints_on_device =
359         fdef->attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
360         fdef->attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
361     // kIntsOnDeviceAttr is not compatible with is_multi_device_function which
362     // is needed to support allow_soft_placement.
363     if (ints_on_device) {
364       LOG_FIRST_FEW_WARNING_WITH_PREFIX
365           << "Function " << name()
366           << " has attribute kIntsOnDeviceAttr=true "
367              "and will be executed natively with allow_soft_placement=false. "
368              "If this is a problem, please re-generate your SavedModel with "
369              "the TF-TRT runtime you are using.";
370     } else {
371       inst_ops.is_multi_device_function = true;
372       inst_ops.input_devices.resize(num_inputs, device_name);
373       inst_ops.output_devices.resize(num_outputs, device_name);
374       inst_ops.config_proto.set_allow_soft_placement(true);
375     }
376   }
377   FunctionLibraryRuntime::Handle func_handle;
378   Status status = lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
379                                    inst_ops, &func_handle);
380   if (status.ok()) {
381     return func_handle;
382   }
383   return status;
384 }
385 
386 Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
387                                           const string& device_name) {
388   TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle,
389                       ConstructFunctionHandle(lib, device_name));
390   return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_);
391 }
392 
393 TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
394     : AsyncOpKernel(context) {
395   // Read the serialized segment (a GraphDef or a TRT engine, depending on
396   // static_engine).
396   OP_REQUIRES_OK(context,
397                  context->GetAttr("serialized_segment", &serialized_segment_));
398   OP_REQUIRES_OK(context,
399                  context->GetAttr("workspace_size_bytes", &workspace_size_));
400   OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
401 
402   VLOG(1) << "Constructing " << name();
403   string precision_string;
404   OP_REQUIRES_OK(context,
405                  context->GetAttr("precision_mode", &precision_string));
406   string calibration_data;
407   OP_REQUIRES_OK(context,
408                  context->GetAttr("calibration_data", &calibration_data));
409   OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_));
410   OP_REQUIRES(context, !func_.name().empty(),
411               errors::InvalidArgument(
412                   "The TF function for the TRT segment cannot be empty"));
413   OP_REQUIRES_OK(context,
414                  TrtPrecisionModeFromName(precision_string, &precision_mode_));
415   OP_REQUIRES_OK(context,
416                  context->GetAttr("use_calibration", &use_calibration_));
417   OP_REQUIRES_OK(context,
418                  context->GetAttr("input_shapes", &input_partial_shapes_));
419   auto status =
420       context->GetAttr("_allow_build_at_runtime", &allow_build_at_runtime_);
421   if (status.code() == tensorflow::error::NOT_FOUND) {
422     VLOG(2) << "Not found _allow_build_at_runtime in "
423             << context->device()->name()
424             << ", thus setting _allow_build_at_runtime=true";
425     allow_build_at_runtime_ = true;
426   } else {
427     OP_REQUIRES_OK(context, status);
428   }
429 
430   // Get a mask of non-resource inputs.
431   std::vector<DataType> in_types;
432   input_mask_.resize(input_partial_shapes_.size());
433   OP_REQUIRES_OK(context, context->GetAttr("InT", &in_types));
434   for (int i = 0; i < input_mask_.size(); i++) {
435     input_mask_[i] = (in_types[i] != DataType::DT_RESOURCE);
436   }
437 
438   // Filter the shapes to exclude resources.
439   for (int i = 0; i < input_partial_shapes_.size(); i++) {
440     if (input_mask_[i]) {
441       input_partial_shapes_filtered_.push_back(input_partial_shapes_[i]);
442     }
443   }
444 
445   status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_);
446   if (status.code() == tensorflow::error::NOT_FOUND) {
447     allow_soft_placement_ = true;
448   } else {
449     OP_REQUIRES_OK(context, status);
450   }
451 
452   status = context->GetAttr("use_explicit_precision", &use_explicit_precision_);
453   if (!status.ok()) {
454     use_explicit_precision_ = false;
455   }
456 
457   native_execution_func_handle_ = kInvalidHandle;
458   if (!static_engine_) {
459     OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(),
460                                                   context->device()->name()));
461   }
462   // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
463   // backward compatibility reasons. Remove it once all known users switch to
464   // 2.0.
465   calibration_mode_ =
466       (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
467        calibration_data.empty());
468   if (!calibration_data.empty()) {
469     calibrator_.reset(new TRTInt8Calibrator(calibration_data));
470     calibration_data.resize(0);
471   }
472   OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
473                                            &max_cached_engines_));
474 
475   status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_);
476   if (status.code() == tensorflow::error::NOT_FOUND) {
477     VLOG(2) << "Not found _use_implicit_batch in " << context->device()->name()
478             << ", thus setting _use_implicit_batch=true";
479     use_implicit_batch_ = true;
480   }
481 
482   status =
483       context->GetAttr("_profile_generation_mode", &profile_generation_mode_);
484   if (status.code() == tensorflow::error::NOT_FOUND) {
485     VLOG(2) << "Not found _profile_generation_mode in "
486             << context->device()->name()
487             << ", thus setting _profile_generation_mode=false";
488     profile_generation_mode_ = false;
489   }
490   if (static_engine_) {
491     if (profile_generation_mode_) profile_generation_mode_ = false;
492   }
493   if (use_implicit_batch_) {
494     OP_REQUIRES(context, !profile_generation_mode_,
495                 errors::InvalidArgument(
496                     "profile_generation_mode_=true is only supported if "
497                     "use_implicit_batch=false"));
498     if (input_partial_shapes_.empty()) {
499       VLOG(1) << "Attribute input_shapes is not set. This happens probably "
500               << "because you are using a model that is already converted "
501               << "to TensorRT with a previous version of TF-TRT (i.e. includes "
502               << "TRTEngineOp in graph). This is not an error. If you convert "
503               << "the original model again to TensorRT, the input_shapes "
504               << "attribute will be set automatically.";
505     }
506   } else {
507     OP_REQUIRES(
508         context, !input_partial_shapes_.empty(),
509         errors::InvalidArgument(
510             "Explicit batch mode requires attribute input_shapes to be set."
511             "Explicit batch mode requires attribute input_shapes to be set. "
512             "If you are using a model that was converted to TensorRT by a "
513             "previous version of TF-TRT (i.e. includes TRTEngineOp in graph "
514             "original model again to TensorRT in order to set the attribute "
515             "input_shapes."));
516 
517     string profile_strategy_name;
518     status = context->GetAttr("profile_strategy", &profile_strategy_name);
519     if (status.code() == tensorflow::error::NOT_FOUND) {
520       VLOG(2) << "Not found strategy in " << context->device()->name()
521               << ", thus setting profile_strategy='Range'";
522       profile_strategy_ = ProfileStrategy::kRange;
523     } else {
524       OP_REQUIRES_OK(context, ProfileStrategyFromName(profile_strategy_name,
525                                                       &profile_strategy_));
526     }
527   }
528   has_dynamic_shape_input_ = absl::c_any_of(
529       input_partial_shapes_filtered_,
530       [](PartialTensorShape shape) { return !shape.IsFullyDefined(); });
531   VLOG(2) << "TRTEngineOp has_dynamic_shape_input_: "
532           << has_dynamic_shape_input_;
533 }
534 
535 // Copies input tensor ctx->input(i) (which is in device memory) to the host,
536 // and places the resulting host tensor at the back of native_inputs.
537 Status CopyToHostAsync(OpKernelContext* ctx, std::vector<Tensor>* native_inputs,
538                        int i, const cudaStream_t stream) {
539   // The TRTEngineOp has all ctx->inputs on the device. In contrast, the
540   // native segment expects to find int32 inputs on the host. We copy int32
541   // inputs from device to host.
542 
543   AllocatorAttributes allocator_attr;
544   allocator_attr.set_on_host(true);
545   Tensor t;
546   TF_RETURN_IF_ERROR(ctx->allocate_temp(
547       ctx->input_dtype(i), ctx->input(i).shape(), &t, allocator_attr));
548   native_inputs->push_back(t);
549   const Tensor& gpu_tensor = ctx->input(i);
550   auto ret = cudaMemcpyAsync(
551       t.flat<int32>().data(), gpu_tensor.flat<int32>().data(),
552       t.NumElements() * sizeof(int32), cudaMemcpyDeviceToHost, stream);
553   if (ret != 0) {
554     return errors::Internal("Could not copy tensor for native segment input");
555   }
556   return Status::OK();
557 }
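// Note that the copy above is asynchronous: callers must synchronize the
// stream (see the cudaStreamSynchronize call in ExecuteNativeSegment) before
// reading the host tensor.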
558 
559 // Copies native_tensor, which is in host memory, to ctx->output(t), which is in
560 // device memory.
561 Status CopyToDeviceAsync(OpKernelContext* ctx, const Tensor& native_tensor,
562                          int t, cudaStream_t stream) {
563   Tensor* gpu_tensor;
564   TF_RETURN_IF_ERROR(
565       ctx->allocate_output(t, native_tensor.shape(), &gpu_tensor));
566   auto ret = cudaMemcpyAsync(gpu_tensor->flat<int32>().data(),
567                              native_tensor.flat<int32>().data(),
568                              native_tensor.NumElements() * sizeof(int32),
569                              cudaMemcpyHostToDevice, stream);
570   if (ret != 0) {
571     return errors::Internal("Could not copy tensor for native segment output");
572   }
573   return Status::OK();
574 }
575 
576 void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
577                                        AsyncHelper* async_helper) {
578   tensorflow::profiler::TraceMe activity(
579       "TRTEngineOp::ExecuteNativeSegment",
580       tensorflow::profiler::TraceMeLevel::kInfo);
581   std::vector<Tensor> native_inputs;
582   std::vector<Tensor>* native_outputs = new std::vector<Tensor>();
583   DummyAsyncHelper dummy_async_helper;
584   if (native_execution_func_handle_ == kInvalidHandle) {
585     StatusOr<FunctionLibraryRuntime::Handle> status_or_handle =
586         ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
587                                 allow_soft_placement_, ctx->num_inputs(),
588                                 ctx->num_outputs());
589     OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), dummy_async_helper);
590     native_execution_func_handle_ = *status_or_handle;
591   }
592 
593   auto lib = ctx->function_library();
594   FunctionLibraryRuntime::Options opts;
595   opts.rendezvous = ctx->rendezvous();
596   opts.cancellation_manager = ctx->cancellation_manager();
597   opts.runner = ctx->runner();
598   native_inputs.reserve(ctx->num_inputs());
599   int n_copies = 0;
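  // Obtain the raw CUDA stream backing this op's device context; this is the
  // same GpuStreamMemberHack pattern noted in gpu_kernel_helper.h and used in
  // ExecuteCalibration and ExecuteTrtEngine below.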
600   const cudaStream_t* stream = CHECK_NOTNULL(
601       reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
602                                                 ->stream()
603                                                 ->implementation()
604                                                 ->GpuStreamMemberHack()));
605   for (int i = 0; i < ctx->num_inputs(); i++) {
606     if (ctx->input_dtype(i) != DT_INT32) {
607       native_inputs.push_back(ctx->input(i));
608     } else {
609       OP_REQUIRES_OK_ASYNC(ctx,
610                            CopyToHostAsync(ctx, &native_inputs, i, *stream),
611                            dummy_async_helper);
612       n_copies++;
613     }
614   }
615   if (n_copies > 0) {
616     // If we have any int32 tensors, then wait until data is copied to host.
617     cudaStreamSynchronize(*stream);
618   }
619   VLOG(1) << "Executing native segment: " << name();
620   // Increment the reference count of the async_helper by 1. When the native
621   // segment finishes execution asynchronously, we decrement the reference
622   // count of the object.
623   async_helper->Ref();
624   lib->Run(
625       opts, native_execution_func_handle_, native_inputs, native_outputs,
626       [this, ctx, native_outputs, async_helper, stream](const Status& s) {
627         core::ScopedUnref sc(async_helper);
628         DummyAsyncHelper dummy_async_helper;
629         std::unique_ptr<std::vector<Tensor>> outputs_wrapper(native_outputs);
630         OP_REQUIRES_OK_ASYNC(ctx, s, dummy_async_helper);
631         VLOG(1) << "Native Segment completed";
632         int n_copies = 0;
633         for (size_t t = 0; t < native_outputs->size(); ++t) {
634           if (native_outputs->at(t).dtype() == DT_INT32) {
635             OP_REQUIRES_OK_ASYNC(
636                 ctx, CopyToDeviceAsync(ctx, native_outputs->at(t), t, *stream),
637                 dummy_async_helper);
638             n_copies++;
639           } else {
640             ctx->set_output(t, native_outputs->at(t));
641           }
642         }
643         if (n_copies > 0) {
644           cudaStreamSynchronize(*stream);
645         }
646       });
647 }
648 
649 void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
650                                      TRTEngineCacheResource* cache_res,
651                                      AsyncHelper* async_helper) {
652   tensorflow::profiler::TraceMe activity(
653       "TRTEngineOp::ExecuteCalibration",
654       tensorflow::profiler::TraceMeLevel::kInfo);
655   VLOG(1) << "Executing TRT calibration: " << name();
656   DummyAsyncHelper dummy_async_helper;
657 
658   CalibrationContext* calib_ctx = cache_res->calib_ctx_.get();
659   const int num_inputs = ctx->num_inputs();
660   // TODO(laigd): need to check that input shape matches.
661   // Pass input data to calibrator
662   std::unordered_map<string, void*> input_data;
663   for (int i = 0; i < num_inputs; i++) {
664     const Tensor& t = ctx->input(i);
665     void* data_address = GetTensorAddress(&t);
666     OP_REQUIRES_ASYNC(ctx, data_address,
667                       errors::InvalidArgument(
668                           "Unsupported data type encountered in input ", i),
669                       dummy_async_helper);
670     // Check the allocated buffer is sufficient for input
671     const auto device_tensor = &calib_ctx->device_tensors_.at(i);
672     CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
673     input_data.emplace(StrCat(IONamePrefixes::kInputPHName, i), data_address);
674   }
675   VLOG(2) << "Filled map for sending";
676   // Copied from gpu_kernel_helper.h as the header can only be used in *.cu.cc
677   // files.
678   const cudaStream_t* stream = CHECK_NOTNULL(
679       reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
680                                                 ->stream()
681                                                 ->implementation()
682                                                 ->GpuStreamMemberHack()));
683   // TRTInt8Calibrator::setBatch will wait until TRTInt8Calibrator::getBatch is
684   // called before proceeding with feeding the calibration data to the
685   // calibrator. It returns true if the calibration data is accepted and
686   // returns false if calibration is terminated due to errors.
687   //
688   // If TRTInt8Calibrator::getBatch is never called, which could happen if
689   // there is any problem in building the cuda engine for calibration inside
690   // TensorRT, then the TRTInt8Calibrator::setBatch call here will hang until
691   // TRTInt8Calibrator::setDone is called by the calibration thread in
692   // AllocateCalibrationResources.
693   //
694   // In both of the above cases, setBatch here returns a boolean value to
695   // indicate the result of the calibration process.
696   if (!calib_ctx->calibrator_->setBatch(input_data, *stream)) {
697     VLOG(2) << "Failed to feed calibration data";
698   } else {
699     VLOG(2) << "Passed calibration data";
700   }
701   ExecuteNativeSegment(ctx, async_helper);
702 }
703 
704 Status TRTEngineOp::VerifyInputShapes(
705     const std::vector<TensorShape>& input_concrete_shapes) {
706   if (input_concrete_shapes.empty()) {
707     return errors::InvalidArgument("Input shapes are empty, for ", name());
708   }
709 
710   if (input_partial_shapes_filtered_.empty()) {
711     if (!use_implicit_batch_) {
712       return errors::InvalidArgument(
713           "Explicit batch mode requires input_partial_shapes_ ",
714           "to contain the dynamic input shapes to TRTEngineOp");
715     }
716     // If the graph was converted with an earlier version of TF-TRT, it can
717     // happen that the input_partial_shapes_ vector is not set (see
718     // input_shapes attribute handling in the TRTEngineOp constructor).
719     // In implicit batch mode it is allowed to have empty input_partial_shapes_,
720     // since it is only required in explicit batch mode (see the input_shapes
721   // attribute of ConvertGraphDefToEngine in TRTEngineOp::GetEngine).
722   } else {
723     // Additional consistency checks if input_partial_shapes_ is present.
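    // For example, a concrete shape [8, 224, 224, 3] is consistent with a
    // stored partial shape [-1, 224, 224, 3], whereas a rank mismatch or a
    // differing known dimension triggers the error below.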
724     const string error_msg = StrCat(
725         "Input shapes do not match input partial shapes stored in graph, for ",
726         name(), ": ", DebugString(input_concrete_shapes),
727         " != ", DebugString(input_partial_shapes_filtered_));
728     if (input_concrete_shapes.size() != input_partial_shapes_filtered_.size()) {
729       return errors::InvalidArgument(error_msg);
730     }
731     for (int i = 0; i < input_concrete_shapes.size(); i++) {
732       if (input_concrete_shapes[i].dims() !=
733           input_partial_shapes_filtered_[i].dims()) {
734         return errors::InvalidArgument(error_msg);
735       }
736     }
737     for (int i = 0; i < input_concrete_shapes.size(); i++) {
738       for (int d = 0; d < input_concrete_shapes[i].dims(); d++) {
739         if (input_partial_shapes_filtered_[i].dim_size(d) != -1) {
740           if (input_concrete_shapes[i].dim_size(d) !=
741               input_partial_shapes_filtered_[i].dim_size(d)) {
742             return errors::InvalidArgument(error_msg);
743           }
744         }
745       }
746     }
747   }
748 
749   if (use_implicit_batch_) {
750     if (input_concrete_shapes[0].dims() < 1) {
751       return errors::InvalidArgument(
752           "Input shapes contain scalar, for ", name(), ": ",
753           TensorShapeUtils::ShapeListString(input_concrete_shapes));
754     }
755     const int batch_size = input_concrete_shapes[0].dim_size(0);
756     if (batch_size < 1) {
757       return errors::InvalidArgument(
758           "Incorrect batch dimension, for ", name(), ": ",
759           TensorShapeUtils::ShapeListString(input_concrete_shapes));
760     }
761     for (const TensorShape& shape : input_concrete_shapes) {
762       if (batch_size != shape.dim_size(0)) {
763         return errors::InvalidArgument(
764             "Input shapes are inconsistent on the batch dimension, for ",
765             name(), ": ",
766             TensorShapeUtils::ShapeListString(input_concrete_shapes));
767       }
768     }
769   }
770   return Status::OK();
771 }
772 
773 static bool AllowEngineNativeSegmentExecution() {
774   bool value;
775   Status status =
776       ReadBoolFromEnvVar("TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION",
777                          /*default_value=*/true, &value);
778   if (!status.ok()) {
779     LOG(ERROR) << status;
780   }
781   return value;
782 }
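// The fallback can be disabled from the environment, e.g. (illustrative):
//   export TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION=0
// in which case a failure to build or run a TRT engine aborts the op instead
// of silently running the native segment.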
783 
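// High-level flow of ComputeAsync, summarized from the code below: verify the
// input shapes; in dynamic shape mode, optionally collect shapes or build
// optimization profiles; run INT8 calibration if requested; otherwise look up
// or build an engine and execute it, falling back to the native segment when
// the engine is missing or its execution fails.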
784 void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
785                                AsyncOpKernel::DoneCallback done) {
786   tensorflow::profiler::TraceMe activity(
787       "TRTEngineOp::ComputeAsync", tensorflow::profiler::TraceMeLevel::kInfo);
788 
789   // Invoke DoneCallback when this object is destructed, which could be after
790   // this routine finishes execution, in particular, when native segment is
791   // executed.
792   auto async_helper = new AsyncHelper(done);
793   core::ScopedUnref sc(async_helper);
794 
795   // Use this object for all async execution macros, as those macros do not
796   // need to invoke the DoneCallback.
797   DummyAsyncHelper dummy_async_helper;
798 
799   // Get TRT resource.
800   TRTEngineCacheResource* cache_res = nullptr;
801   OP_REQUIRES_OK_ASYNC(ctx, GetEngineCacheResource(ctx, &cache_res),
802                        dummy_async_helper);
803   core::ScopedUnref unref_cache_res(cache_res);
804 
805   // Get shapes of inputs to engine.
806   std::vector<TensorShape> input_concrete_shapes;
807   input_concrete_shapes.reserve(ctx->num_inputs());
808   std::vector<TensorShape> input_concrete_shapes_filtered;
809   for (int i = 0; i < ctx->num_inputs(); ++i) {
810     input_concrete_shapes.push_back(ctx->input(i).shape());
811     if (ctx->input(i).dtype() != DataType::DT_RESOURCE) {
812       input_concrete_shapes_filtered.push_back(ctx->input(i).shape());
813     }
814   }
815 
816   /// TODO(lsugy): fix case of engine with only resource inputs.
817   Status verify_input_shape_status =
818       VerifyInputShapes(input_concrete_shapes_filtered);
819   // TODO(bixia): Fix the segmentation.
820   if (!verify_input_shape_status.ok()) {
821     LOG_FIRST_FEW_WARNING_WITH_PREFIX
822         << "Running native segment for " << name()
823         << " due to failure in verifying input shapes: "
824         << verify_input_shape_status.error_message();
825     ExecuteNativeSegment(ctx, async_helper);
826     return;
827   }
828 
829   if (!use_implicit_batch_ &&
830       (has_dynamic_shape_input_ || cache_res->profiles_.HasShapeTensor())) {
831     OP_REQUIRES_OK_ASYNC(ctx, cache_res->profiles_.CollectShapeValues(ctx),
832                          dummy_async_helper);
833     cache_res->profiles_.SetInputMask(input_mask_);
834     if (profile_generation_mode_) {
835       // Collecting new shapes for profiles can be only done once. After the
836       // shapes are converted to TRT profiles, no shapes can be collected
837       // anymore.
838       OP_REQUIRES_ASYNC(ctx, cache_res->profiles_.GetNumProfiles() == 0,
839                         errors::Unimplemented("Cannot collect new shapes when "
840                                               "profiles are already created."),
841                         dummy_async_helper);
842       // Just collect the input shape info and return. The shapes are used to
843       // generate optimization profiles during engine creation.
844       cache_res->profiles_.AddShape(input_concrete_shapes);
845       VLOG(1) << "Native segment is used while collecting shapes for profiles";
846       ExecuteNativeSegment(ctx, async_helper);
847       return;
848     } else if (cache_res->profiles_.GetNumProfiles() == 0 && !static_engine_) {
849       // Add current shape if we did not collect any shapes so far.
850       if (!cache_res->profiles_.HasShape()) {
851         cache_res->profiles_.AddShape(input_concrete_shapes);
852       }
853       // Create profiles out of collected shapes during profile generation.
854       cache_res->profiles_.InitProfiles(input_partial_shapes_,
855                                         profile_strategy_);
856     }
857   }
858 
859   // Run calibration if in int8+calibration mode.
860   // * Logic in TF 1.x:
861   //   - During conversion: calibration_mode_ is true and cache size is 0, so it
862   //     will run calibration.
863   //   - During inference: calibration_data will be set, so calibration_mode_
864   //     is false and it won't trigger calibration.
865   // * Logic in TF 2.0:
866   //   - During conversion: similar to 1.x.
867   //   - During inference: calibration_data will still be empty, but cache will
868   //     contain the calibrated engine, so it won't trigger calibration.
869   //
870   // TODO(laigd): consider the following alternatives:
871   // 1. Serialize the state (calibration or inference) using
872   //    TRTEngineInstance proto (or a new proto), so we know which mode we're
873   //    in and don't run calibration during inference (which is invalid).
874   // 2. Reuse the calibration_data attribute or use a new attribute in the
875   //    NodeDef to indicate whether it's in calibration mode.
876   if (calibration_mode_ && cache_res->cache_.size() == 0) {
877     if (!cache_res->calib_ctx_) {
878       // TODO(laigd): better encapsulation.
879       mutex_lock lock(engine_mutex_);
880       if (!cache_res->calib_ctx_) {
881         // Add profiles if we are in dynamic shape mode.
882         if (!use_implicit_batch_ && (has_dynamic_shape_input_ ||
883                                      cache_res->profiles_.HasShapeTensor())) {
884           cache_res->profiles_.InitCalibProfile(input_concrete_shapes);
885         }
886         OP_REQUIRES_OK_ASYNC(ctx, AllocateCalibrationResources(ctx, cache_res),
887                              dummy_async_helper);
888       }
889     }
890     // TODO(laigd): check that the input shapes match the shapes of the
891     // persistent tensor in the calibration resource.
892     ExecuteCalibration(ctx, cache_res, async_helper);
893     return;
894   }
895 
896   StatusOr<std::pair<EngineContext*, int>> status =
897       GetEngine(input_concrete_shapes, ctx, cache_res);
898   OP_REQUIRES_OK_ASYNC(ctx, status.status(), dummy_async_helper);
899 
900   EngineContext* engine_context = status.ValueOrDie().first;
901   int trt_context_idx = status.ValueOrDie().second;
902   auto may_execute_native_segment = [&] {
903     if (!AllowEngineNativeSegmentExecution()) {
904       ctx->CtxFailure(
905           errors::Aborted("User disallowed engine native segment execution"));
906       return false;
907     }
908     return true;
909   };
910   if (!engine_context->GetCudaEngine()) {
911     LOG_FIRST_FEW_WARNING_WITH_PREFIX
912         << "Engine retrieval for input shapes: "
913         << TensorShapeUtils::ShapeListString(input_concrete_shapes)
914         << " failed. Running native segment for " << name();
915     if (may_execute_native_segment()) {
916       ExecuteNativeSegment(ctx, async_helper);
917     }
918     return;
919   }
920   Status stat =
921       ExecuteTrtEngine(ctx, engine_context, trt_context_idx,
922                        cache_res->profiles_, cache_res->allocator_.get());
923   if (stat.ok()) return;
924 
925   LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat
926                                     << " Retrying with native segment for "
927                                     << name();
928   if (!may_execute_native_segment()) {
929     return;
930   }
931   // Release any outputs that are already allocated; ExecuteNativeSegment will
932   // re-allocate them and fail if they are currently allocated.
933   // The Tensor pointer in the returned TensorValue must be explicitly
934   // deleted.
935   for (int i = 0; i < ctx->num_outputs(); i++) {
936     delete ctx->release_output(i).tensor;
937   }
938   ExecuteNativeSegment(ctx, async_helper);
939 }
940 
941 Status TRTEngineOp::ExecuteTrtEngine(
942     OpKernelContext* ctx, EngineContext* engine_context, int trt_context_idx,
943     const TrtShapeOptimizationProfile& profiles, TRTBaseAllocator* allocator) {
944   tensorflow::profiler::TraceMe activity(
945       "TRTEngineOp::ExecuteTrtEngine",
946       tensorflow::profiler::TraceMeLevel::kInfo);
947   VLOG(1) << "Executing TRT engine: " << name();
948   nvinfer1::ICudaEngine* cuda_engine = engine_context->GetCudaEngine();
949 
950   if (VLOG_IS_ON(2)) {
951     VLOG(2) << "  Network name: " << cuda_engine->getName();
952     VLOG(2) << "  Activation size: " << engine_context->GetDeviceMemorySize()
953             << " bytes";
954 #if !IS_TRT_VERSION_GE(8, 0, 0, 0)
955     // getWorkspaceSize() is deprecated as of TRT 8
956     VLOG(2) << "  Workspace size: " << cuda_engine->getWorkspaceSize()
957             << " bytes";
958 #endif  // #if !IS_TRT_VERSION_GE(8, 0, 0, 0)
959     VLOG(2) << "  Datatype of " << cuda_engine->getNbBindings()
960             << " inputs/outputs";
961     string binding_types = "";
962     for (int i = 0; i < cuda_engine->getNbBindings(); i++) {
963       binding_types += "    " + string(cuda_engine->getBindingName(i)) + ": " +
964                        DebugString(cuda_engine->getBindingDataType(i)) + "\n";
965     }
966     VLOG(2) << binding_types;
967   }
968 
969   const int num_binding = cuda_engine->getNbBindings();
970   std::vector<void*> buffers(num_binding);
971 
972   // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex
973   // for it.
974   mutex_lock lock(engine_context->mu);
975   nvinfer1::IExecutionContext* execution_context;
976   bool has_device_memory;
977   TF_RETURN_IF_ERROR(engine_context->GetExecutionContext(
978       trt_context_idx, &execution_context, &has_device_memory));
979 
980   if (VLOG_IS_ON(2)) {
981     VLOG(2) << "Selected execution context: " << trt_context_idx;
982   }
983   const int num_batch =
984       use_implicit_batch_ ? ctx->input(0).shape().dim_size(0) : 0;
985 
986   TF_RETURN_IF_ERROR(SetTrtEngineInputs(
987       cuda_engine, execution_context, trt_context_idx, buffers,
988       use_implicit_batch_, num_batch, profiles, ctx));
989 
990   TF_RETURN_IF_ERROR(SetTrtEngineOutputs(cuda_engine, execution_context,
991                                          trt_context_idx, buffers,
992                                          use_implicit_batch_, num_batch, ctx));
993 
994   // Copied from gpu_kernel_helper.h as the header can only be used in *.cu.cc
995   // files.
996   const cudaStream_t* stream = CHECK_NOTNULL(
997       reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
998                                                 ->stream()
999                                                 ->implementation()
1000                                                 ->GpuStreamMemberHack()));
1001 
1002   ContextDeviceMemory context_device_memory;
1003   if (!has_device_memory) {
1004     tensorflow::profiler::TraceMe activity(
1005         "TRTEngineOp::AllocateDeviceMemory",
1006         tensorflow::profiler::TraceMeLevel::kInfo);
1007     // Allocate device memory for the TensorRT engine execution. The device
1008     // memory will be released when context_device_memory goes out of scope.
1009     TF_RETURN_IF_ERROR(context_device_memory.AllocateDeviceMemory(
1010         execution_context, allocator, engine_context->GetDeviceMemorySize()));
1011   }
1012   // Enqueue the TensorRT engine for execution.
1013   return TrtEnqueue(execution_context, buffers, *stream, use_implicit_batch_,
1014                     num_batch);
1015 }
1016 
1017 Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx,
1018                                            TRTEngineCacheResource** cache_res) {
1019   tensorflow::profiler::TraceMe activity(
1020       "TRTEngineOp::GetEngineCacheResource",
1021       tensorflow::profiler::TraceMeLevel::kInfo);
1022   // Canonicalize the op name by removing the scopes if any. This is mainly
1023   // because in TFv2, the function graph can be instantiated in various ways and
1024   // it'll insert scope names into the names of the TRTEngineOps, which will
1025   // result in many different engine caches if we use the instantiated op name
1026   // directly, but we still want all of them to share the same cache (if they
1027   // were representing the same subgraph).
1028   absl::string_view resource_name = name();
1029   size_t last_slash = resource_name.find_last_of('/');
1030   if (last_slash != absl::string_view::npos) {
1031     resource_name.remove_prefix(last_slash + 1);
1032   }
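  // For example, an instantiated name such as
  // "StatefulPartitionedCall/TRTEngineOp_0" canonicalizes to "TRTEngineOp_0",
  // so every instantiation of the same segment shares one cache entry.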
1033 
1034   // Get engine cache.
1035   return ctx->resource_manager()->LookupOrCreate(
1036       std::string(kTfTrtContainerName), std::string(resource_name), cache_res,
1037       {[this, ctx](TRTEngineCacheResource** cr) -> Status {
1038         *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
1039         return Status::OK();
1040       }});
1041 }
1042 
1043 StatusOr<TrtUniquePtrType<nvinfer1::ICudaEngine>> TRTEngineOp::BuildEngine(
1044     const std::vector<TensorShape>& input_concrete_shapes, int batch_size,
1045     bool use_calibration, TRTInt8Calibrator* calibrator,
1046     TRTEngineCacheResource* cache_resource, OpKernelContext* ctx) {
1047   TRT_ENSURE(cache_resource);
1048   TRT_ENSURE(ctx);
1049   // Use concrete shapes for implicit batch mode and partial shapes for
1050   // explicit batch mode.
1051   bool use_concrete_shapes =
1052       use_implicit_batch_ || cache_resource->profiles_.IsStaticCompatible();
1053   const std::vector<PartialTensorShape>& conversion_input_shapes =
1054       use_concrete_shapes
1055           ? std::vector<PartialTensorShape>(input_concrete_shapes.begin(),
1056                                             input_concrete_shapes.end())
1057           : input_partial_shapes_;
1058 
1059   VLOG(1) << "Building a new TensorRT engine for " << name()
1060           << " with input shapes: " << DebugString(conversion_input_shapes);
1061 
1062   std::unordered_map<string, tensorflow::DeviceProperties> device_map;
1063   DeviceNameUtils::ParsedName full_parsed_name;
1064   DeviceNameUtils::ParseFullName(ctx->device()->name(), &full_parsed_name);
1065   device_map.emplace(ctx->device()->name(),
1066                      grappler::GetDeviceInfo(full_parsed_name));
1067   tensorflow::grappler::VirtualCluster cluster(device_map);
1068 
1069   TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
1070   auto status = convert::ConvertGraphDefToEngine(
1071       segment_graph_def_, ctx, precision_mode_, batch_size, workspace_size_,
1072       conversion_input_shapes, &logger, cache_resource->allocator_.get(),
1073       calibrator, &engine, use_calibration, use_implicit_batch_, nullptr,
1074       &cache_resource->profiles_, name(), use_explicit_precision_, &cluster);
1075   if (!status.ok()) {
1076     LOG_FIRST_FEW_WARNING_WITH_PREFIX
1077         << "Engine creation for " << name() << " failed. "
1078         << "The native segment will be used instead. "
1079         << "Reason: " << status;
1080     // Store an empty engine in the cache for these input shapes so we don't try
1081     // to build the same failing engine again.
1082     cache_resource->cache_.emplace(input_concrete_shapes,
1083                                    std::make_unique<EngineContext>());
1084     return status;
1085   }
1086   return engine;
1087 }
1088 
1089 StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
1090     const std::vector<TensorShape>& input_concrete_shapes, OpKernelContext* ctx,
1091     TRTEngineCacheResource* cache_res) {
1092   static EngineContext empty_context;
1093   tensorflow::profiler::TraceMe activity(
1094       "TRTEngineOp::GetEngine", tensorflow::profiler::TraceMeLevel::kInfo);
1095   mutex_lock lock(engine_mutex_);
1096   // Using first input to get batch size is reliable - VerifyInputShapes()
1097   // guarantees that the first input is not a scalar. As such we can always use
1098   // the first input to get the batch size for implicit batch mode. For explicit
1099   // batch mode, this value is not used.
1100   const int batch_size = input_concrete_shapes[0].dim_size(0);
1101   // TODO(Tamas): remove the need for batch_size in explicit_batch mode
1102   auto& cache = cache_res->cache_;
1103   auto allocator = cache_res->allocator_.get();
1104   if (allocator == nullptr) {
1105     return std::pair<EngineContext*, int>(&empty_context, 0);
1106   }
1107 
1108   // Handle the static engine case. For static engines, the cache will have a
1109   // single element containing the only engine.
1110   if (static_engine_) {
1111     if (cache.size()) {
1112       // TODO(laigd): need a better shape compatibility check for the case where
1113       // implicit batch is disabled.
1114       if (!use_implicit_batch_ ||
1115           AreShapesCompatible(input_concrete_shapes, cache.begin()->first)) {
1116         int profile_id = 0;
1117         if (!use_implicit_batch_)
1118           profile_id =
1119               cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
1120         if (profile_id != -1) {
1121           return std::pair<EngineContext*, int>(cache.begin()->second.get(),
1122                                                 profile_id);
1123         }
1124       }
1125       return std::pair<EngineContext*, int>(&empty_context, 0);
1126     }
1127 
1128     TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
1129     infer->setGpuAllocator(allocator);
1130     // Need to initialize plugins in order to deserialize engines that contain
1131     // plugins.
1132     MaybeInitializeTrtPlugins(&logger);
1133     TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
1134         infer->deserializeCudaEngine(serialized_segment_.c_str(),
1135                                      serialized_segment_.size(), nullptr));
1136     int profile_id = 0;
1137     if (static_engine && !use_implicit_batch_) {
1138       // load profiles
1139       std::vector<ExecutionContext> exec_contexts;
1140       TF_RETURN_IF_ERROR(cache_res->profiles_.RestoreProfiles(
1141           static_engine.get(), ctx->num_inputs()));
1142       TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts(
1143           static_engine.get(), &exec_contexts));
1144       cache.emplace(input_concrete_shapes,
1145                     std::make_unique<EngineContext>(std::move(static_engine),
1146                                                     std::move(exec_contexts)));
1147       VLOG(1) << "Added new engine to cache of " << name()
1148               << ". Cache size: " << cache.size();
1149       // Query which profile of the new engine matches the actual input.
1150       profile_id = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
1151       if (profile_id == -1) {
1152         return std::pair<EngineContext*, int>(&empty_context, 0);
1153       }
1154       EngineContext* engine_context = cache_res->GetEngineContext(profile_id);
1155       return std::pair<EngineContext*, int>(engine_context, profile_id);
1156     }
1157 
1158     if (!static_engine) {
1159       if (!allow_build_at_runtime_) {
1160         // Store an empty engine in the cache so we don't try to load the same
1161         // failing engine again.
1162         cache.emplace(input_concrete_shapes, std::make_unique<EngineContext>());
1163         return std::pair<EngineContext*, int>(&empty_context, 0);
1164       }
1165       if (segment_graph_def_.node().empty()) {
1166         Status status = ImportSegmentGraphDef(ctx->function_library(),
1167                                               ctx->device()->name());
1168         if (!status.ok()) {
1169           LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for "
1170                                             << name() << " failed. "
1171                                             << "Reason: " << status;
1172         }
1173       }
1174       auto result = BuildEngine(input_concrete_shapes, batch_size,
1175                                 /*use_calibration=*/false,
1176                                 /*calibrator=*/nullptr, cache_res, ctx);
1177       if (!result.ok()) {
1178         return std::pair<EngineContext*, int>(&empty_context, 0);
1179       }
1180       static_engine = std::move(result.ValueOrDie());
1181     }
1182 
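         // static_engine is now valid: either it was deserialized above or it
         // was just built at runtime. Cache it with a single execution context.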
1183     auto raw_static_engine = static_engine.get();
1184     std::vector<TensorShape> engine_input_shapes(input_concrete_shapes);
1185 
1186     int max_batch_size = 1;
1187     if (use_implicit_batch_) {
1188       max_batch_size = raw_static_engine->getMaxBatchSize();
1189       // Key the static engine by max_batch_size in the batch dimension so
1190       // that all inputs map to this single engine.
1191       for (int i = 0; i < engine_input_shapes.size(); i++) {
1192         engine_input_shapes[i].set_dim(0, max_batch_size);
1193       }
1194     }
1195 
1196     ExecutionContext context = ExecutionContext::Create(raw_static_engine);
1197     // TODO(laigd): here we assume engine_input_shapes matches the actual
1198     // input shapes of the engine; we should verify that.
1199     cache.emplace(engine_input_shapes,
1200                   std::make_unique<EngineContext>(std::move(static_engine),
1201                                                   std::move(context)));
1202     // The IRuntime created above is safe to destroy after engine creation.
1203     VLOG(1) << "Size of serialized TRT engine: "
1204             << serialized_segment_.capacity();
1205     string tmp;
1206     // Swap with temporary empty string to deallocate the CPU memory.
1207     serialized_segment_.swap(tmp);
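         // An implicit-batch engine cannot run a batch larger than its max
         // batch size; return the empty context so the native segment is used.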
1208     if (use_implicit_batch_ && (max_batch_size < batch_size)) {
1209       return std::pair<EngineContext*, int>(&empty_context, 0);
1210     }
1211     return std::pair<EngineContext*, int>(cache.at(engine_input_shapes).get(),
1212                                           0);
1213   }  // static_engine_
1214 
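       // Dynamic engine case: look up a compatible engine in the cache,
       // building one at runtime if that is allowed.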
1215   int profile_id = -1;
1216   if (!use_implicit_batch_) {
1217     profile_id = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
1218     // Since all profiles are already created at this point, finding no
1219     // compatible profiles results in falling back to native TF.
1220     if (profile_id == -1) {
1221       return std::pair<EngineContext*, int>(&empty_context, 0);
1222     }
1223   }
1224 
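       // Implicit-batch engines are looked up by input shapes; explicit-batch
       // engines are looked up by optimization profile.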
1225   EngineContext* engine_contexts;
1226   if (use_implicit_batch_) {
1227     engine_contexts = cache_res->GetEngineContext(input_concrete_shapes);
1228   } else {
1229     engine_contexts = cache_res->GetEngineContext(profile_id);
1230   }
1231 
1232   // If cache does not have a compatible engine then create a new engine.
1233   if (engine_contexts == nullptr) {
1234     if (!allow_build_at_runtime_) {
1235       LOG_FIRST_FEW_WARNING_WITH_PREFIX
1236           << "Found no engine in cache matching input shapes. "
1237           << "Not building a new engine because "
1238           << "allow_build_at_runtime=False. "
1239           << "The native segment will be used instead.";
1240       // Store an empty engine in the cache for these input shapes so we don't
1241       // try to build the same failing engine again.
1242       cache.emplace(input_concrete_shapes, std::make_unique<EngineContext>());
1243       return std::pair<EngineContext*, int>(&empty_context, 0);
1244     }
1245 
1246     // Up to this point calibrator_ can never be null: if it were, it would
1247     // mean calibration_mode_ is true, and this path would not be executed.
1248     auto result =
1249         BuildEngine(input_concrete_shapes, batch_size, use_calibration_,
1250                     calibrator_.get(), cache_res, ctx);
1251     if (!result.ok()) {
1252       return std::pair<EngineContext*, int>(&empty_context, 0);
1253     }
1254     TrtUniquePtrType<nvinfer1::ICudaEngine> engine =
1255         std::move(result.ValueOrDie());
1256     std::vector<ExecutionContext> exec_contexts;
1257     TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts(
1258         engine.get(), &exec_contexts));
1259     cache.emplace(input_concrete_shapes,
1260                   std::make_unique<EngineContext>(std::move(engine),
1261                                                   std::move(exec_contexts)));
1262     VLOG(1) << "Added new engine to cache of " << name()
1263             << ". Cache size: " << cache.size();
1264     engine_contexts = cache.at(input_concrete_shapes).get();
1265     // Query which profile of the new engine matches the actual input.
1266     profile_id = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
1267   }
1268   return std::pair<EngineContext*, int>(engine_contexts,
1269                                         use_implicit_batch_ ? 0 : profile_id);
1270 }
1271 
1272 // TODO(hinsu): Move this allocation to CalibrationContext constructor, if
1273 // possible.
1274 Status TRTEngineOp::AllocateCalibrationResources(
1275     OpKernelContext* ctx, TRTEngineCacheResource* cache_res) {
1276   cache_res->calib_ctx_ = std::make_unique<CalibrationContext>();
1277   auto* cres = cache_res->calib_ctx_.get();
1278 
1279   // Get the input shapes.
1280   // TODO(lsugy): support INT8 calibration in non-frozen mode.
1281   const int batch_size = ctx->input(0).dim_size(0);
1282   const int num_inputs = ctx->num_inputs();
1283   std::vector<TensorShape> shapes;
1284   cres->device_tensors_.resize(num_inputs);
1285   VLOG(1) << "Constructing calibrator";
1286   for (int i = 0; i < num_inputs; i++) {
1287     // Allocate a device tensor to hold this input during calibration.
1288     const Tensor& t = ctx->input(i);
1289     shapes.emplace_back(t.shape());
1290     TF_RETURN_IF_ERROR(
1291         ctx->allocate_temp(t.dtype(), t.shape(), &cres->device_tensors_.at(i)));
1292     CHECK_EQ(t.TotalBytes(),  // Crash OK
1293              (cres->device_tensors_.at(i)).TotalBytes());
1294     void* device_address = GetTensorAddress(&cres->device_tensors_.at(i));
1295     if (device_address == nullptr) {
1296       return errors::InvalidArgument(
1297           "Unsupported data type encountered in input ", i);
1298     }
1299     cres->device_buffers_.emplace(
1300         StrCat(IONamePrefixes::kInputPHName, i),
1301         std::pair<void*, size_t>(device_address,
1302                                  cres->device_tensors_.at(i).TotalBytes()));
1303   }
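       // Create the INT8 calibrator that serves these device buffers to the
       // TensorRT builder batch by batch during calibration.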
1304   cres->calibrator_.reset(
1305       new TRTInt8Calibrator(cres->device_buffers_, batch_size, name()));
1306   const int platform_device_id =
1307       ctx->device()->tensorflow_accelerator_device_info()->gpu_id;
1308   if (platform_device_id < 0) {
1309     LOG(ERROR) << "Can't get gpu_device_info from context->device()";
1310     return errors::InvalidArgument(
1311         "Context->device doesn't contain device info!");
1312   }
1313 
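       // Conversion uses the concrete shapes seen in this call when the engine
       // uses implicit batch or the optimization profiles are static-
       // compatible; otherwise it uses the partial shapes recorded for the op.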
1314   bool use_concrete_shapes =
1315       use_implicit_batch_ || cache_res->profiles_.IsStaticCompatible();
1316   const std::vector<PartialTensorShape>& conversion_input_shapes =
1317       use_concrete_shapes
1318           ? std::vector<PartialTensorShape>(shapes.begin(), shapes.end())
1319           : input_partial_shapes_;
1320 
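       // Hold a reference on the cache resource for the lifetime of the
       // calibration thread; the thread releases it via ScopedUnref.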
1321   cache_res->Ref();
1322   string platform_device_name = ctx->device()->name();
1323   cres->thr_.reset(new std::thread([this, cres, shapes, conversion_input_shapes,
1324                                     platform_device_id, platform_device_name,
1325                                     cache_res, ctx]() {
1326     core::ScopedUnref sc(cache_res);
1327 
1328     VLOG(1) << "Starting calibration thread on device " << platform_device_id
1329             << ", Calibration Resource @ " << cres;
1330     auto err = cudaSetDevice(platform_device_id);
1331     if (err != cudaSuccess) {
1332       // TODO(aaroey): should return error here.
1333       LOG(ERROR) << "Couldn't set cuda device to " << platform_device_id
1334                  << " in calibration thread";
1335     }
1336 
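         // Describe this GPU in a single-device virtual cluster; the cluster
         // is passed to the converter below so it can query device properties.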
1337     std::unordered_map<string, tensorflow::DeviceProperties> device_map;
1338     DeviceNameUtils::ParsedName full_parsed_name;
1339     DeviceNameUtils::ParseFullName(platform_device_name, &full_parsed_name);
1340     device_map.emplace(platform_device_name,
1341                        grappler::GetDeviceInfo(full_parsed_name));
1342     tensorflow::grappler::VirtualCluster cluster(device_map);
1343 
1344     // ConvertGraphDefToEngine() will try to build the engine. This thread
1345     // will loop inside buildCudaEngine(), consuming the calibration data
1346     // that is set by the TF op, and drive the builder until the calibrator
1347     // returns false. The engine is discarded after the calibration table
1348     // is generated.
1349     //
1350     // TODO(aaroey): maybe set the max batch size using the Python
1351     // calibration wrapper class.
1352     auto s = convert::ConvertGraphDefToEngine(
1353         this->segment_graph_def_, ctx, TrtPrecisionMode::INT8,
1354         cres->calibrator_->getBatchSize(), this->workspace_size_,
1355         conversion_input_shapes, &cache_res->GetLogger(),
1356         cache_res->allocator_.get(), cres->calibrator_.get(), &cres->engine_,
1357         /*use_calibration=*/true, this->use_implicit_batch_,
1358         /*convert_successfully=*/nullptr,
1359         /*profiles=*/&cache_res->profiles_, name(),
1360         /*use_explicit_precision=*/use_explicit_precision_,
1361         /*cluster=*/&cluster);
1362     if (!s.ok()) {
1363       LOG(ERROR) << "Calibration failed: " << s;
1364       cres->calibrator_->setDone();  // Ignore further pushes
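           // Cache an empty context for these shapes so later executions fall
           // back to the native segment.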
1365       cache_res->cache_.emplace(shapes, std::make_unique<EngineContext>());
1366     } else {
1367       // Transfer the ownership of the engine to the engine cache, so we can
1368       // dump it out during conversion for TF 2.0.
1369       mutex_lock lock(this->engine_mutex_);
1370       this->calibrator_ = std::move(cres->calibrator_);
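           // With explicit batch and dynamic shape inputs (or shape tensors),
           // create one execution context per optimization profile; otherwise
           // a single context suffices.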
1371       if (!use_implicit_batch_ &&
1372           (has_dynamic_shape_input_ || cache_res->profiles_.HasShapeTensor())) {
1373         std::vector<ExecutionContext> exec_contexts;
1374         auto calib_result = cache_res->profiles_.CreateExecutionContexts(
1375             cres->engine_.get(), &exec_contexts);
1376         cache_res->cache_.emplace(
1377             shapes, std::make_unique<EngineContext>(std::move(cres->engine_),
1378                                                     std::move(exec_contexts)));
1379       } else {
1380         ExecutionContext context =
1381             ExecutionContext::Create(cres->engine_.get());
1382         cache_res->cache_.emplace(
1383             shapes, std::make_unique<EngineContext>(std::move(cres->engine_),
1384                                                     std::move(context)));
1385       }
1386     }
1387 
1388     VLOG(1) << "Calibration loop terminated " << this->name();
1389   }));
1390   VLOG(1) << "initialized calibrator resource";
1391   return Status::OK();
1392 }
1393 
1394 REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
1395 
1396 }  // namespace tensorrt
1397 }  // namespace tensorflow
1398 
1399 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
1400