1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <memory>
17 #include <vector>
18
19 #include "absl/memory/memory.h"
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
24 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
25 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
26 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
27 #include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
28 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
29 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
30 #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
31 #include "tensorflow/core/common_runtime/function.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/common_runtime/graph_optimizer.h"
34 #include "tensorflow/core/framework/function.h"
35 #include "tensorflow/core/framework/graph_to_functiondef.h"
36 #include "tensorflow/core/framework/node_def_builder.h"
37 #include "tensorflow/core/framework/op.h"
38 #include "tensorflow/core/framework/op_kernel.h"
39 #include "tensorflow/core/graph/algorithm.h"
40 #include "tensorflow/core/grappler/clusters/utils.h"
41 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
42 #include "tensorflow/core/lib/core/refcount.h"
43 #include "tensorflow/core/lib/strings/str_util.h"
44 #include "tensorflow/core/lib/strings/strcat.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/mutex.h"
47 #include "tensorflow/core/platform/stream_executor.h"
48 #include "tensorflow/core/platform/thread_annotations.h"
49 #include "tensorflow/core/platform/types.h"
50 #include "tensorflow/core/profiler/lib/traceme.h"
51 #include "tensorflow/core/util/env_var.h"
52
53 #if GOOGLE_CUDA && GOOGLE_TENSORRT
54 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
55 #include "third_party/tensorrt/NvInfer.h"
56
57 namespace tensorflow {
58 namespace tensorrt {
59 namespace {
60 Logger& logger = *Logger::GetLogger();
61 using absl::StrAppend;
62 using absl::StrCat;
63 using ::nvinfer1::IRuntime;
64
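// Logs at most the first five occurrences of a warning per call site, with a
// "TF-TRT Warning: " prefix so the messages are easy to grep in TF logs.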
65 #define LOG_FIRST_FEW_WARNING_WITH_PREFIX \
66 LOG_FIRST_N(WARNING, 5) << "TF-TRT Warning: "
67
68 // Allocates device memory for an execution context to execute a TensorRT
69 // engine and records the relevant information for deallocating the memory when
70 // the engine finishes execution.
71 class ContextDeviceMemory {
72 public:
73 ContextDeviceMemory()
74 : execution_context_(nullptr),
75 device_memory_allocator_(nullptr),
76 device_memory_(nullptr) {}
77
78 ~ContextDeviceMemory() {
79 if (device_memory_) {
80 device_memory_allocator_->free(device_memory_);
81 }
82 }
83
84 Status AllocateDeviceMemory(nvinfer1::IExecutionContext* execution_context,
85 TRTBaseAllocator* device_memory_allocator,
86 size_t device_memory_size) {
87 execution_context_ = execution_context;
88 device_memory_allocator_ = device_memory_allocator;
89 device_memory_ = nullptr;
90 VLOG(2) << "Device memory size for TensorRT engine " << device_memory_size;
91 if (device_memory_size > 0) {
92 device_memory_ = device_memory_allocator_->allocate(
93 device_memory_size,
94 /*unused alignment=*/0, /*flags=*/0);
95 if (device_memory_ == nullptr) {
96 return errors::InvalidArgument(
97 "Out of GPU memory for execution context");
98 }
99 }
100 {
101 tensorflow::profiler::TraceMe activity(
102 "setDeviceMemory", tensorflow::profiler::TraceMeLevel::kInfo);
103 execution_context_->setDeviceMemory(device_memory_);
104 }
105 return Status::OK();
106 }
107
108 private:
109 nvinfer1::IExecutionContext* execution_context_;
110 TRTBaseAllocator* device_memory_allocator_;
111 void* device_memory_;
112 };
113
114 // Macros for asynchronous execution, such as OP_REQUIRES_OK_ASYNC, require an
115 // object with an operator(). This struct provides such an object with a no-op
116 // operator() because we don't need those macros to invoke the DoneCallback for
117 // the TRTEngineOp.
118 struct DummyAsyncHelper {
119 void operator()() {}
120 };
121
122 // A helper class that calls the DoneCallback for the TRTEngineOp when the
123 // object is destructed, to support asynchronous execution of both the native
124 // segment and the TRT engine for the TRTEngineOp.
125 class AsyncHelper : public core::RefCounted {
126 public:
127 AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {}
128
129 ~AsyncHelper() override { done_(); }
130
131 private:
132 AsyncOpKernel::DoneCallback done_;
133 };
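// Usage note: ComputeAsync holds one reference to an AsyncHelper for the
// lifetime of the call, and each asynchronously launched piece of work (e.g.
// the native segment in ExecuteNativeSegment) takes an extra Ref() that is
// released when that work completes. The DoneCallback therefore runs exactly
// once, when the last reference is dropped.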
134
135 } // end anonymous namespace
136
137 // This op can construct a TRT engine on the fly; if engine construction
138 // fails, it executes the equivalent subgraph as a native TensorFlow function.
139 class TRTEngineOp : public AsyncOpKernel {
140 public:
141 explicit TRTEngineOp(OpKernelConstruction* context);
142
143 void ComputeAsync(OpKernelContext* context,
144 AsyncOpKernel::DoneCallback done) override;
145
146 private:
147 // Executes calibration asynchronously.
148 void ExecuteCalibration(OpKernelContext* ctx,
149 TRTEngineCacheResource* cache_res,
150 AsyncHelper* async_helper);
151
152 // Constructs a function handle for the segment of the TRTEngineOp.
153 StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle(
154 FunctionLibraryRuntime* lib, const string& device_name,
155 bool allow_soft_placement = false, size_t num_inputs = 0,
156 size_t num_outputs = 0);
157
158 // Imports the GraphDef for the segment of the TRTEngineOp into
159 // segment_graph_def_.
160 Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
161 const string& device_name);
162
163 // Executes the native segment as function Op asynchronously.
164 void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* async_helper);
165
166 // Allocates the device memory for the execution context and enqueues the
167 // TensorRT engine for execution. Also deallocates the device memory. Returns
168 // a non-OK status if the caller should retry by running the native segment.
169 Status ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context,
170 int trt_context_idx,
171 const TrtShapeOptimizationProfile& profiles,
172 TRTBaseAllocator* allocator);
173
174 // Allocates necessary resources for calibration.
175 Status AllocateCalibrationResources(OpKernelContext* ctx,
176 TRTEngineCacheResource* cache_res);
177
178 Status GetEngineCacheResource(OpKernelContext* ctx,
179 TRTEngineCacheResource** cache_res);
180
181 // Returns a pair of 1) an EngineContext object that is compatible with the
182 // input and 2) the index of the IExecutionContext compatible with the input.
183 // If a cuda engine for the given input shapes can't be found, returns
184 // (nullptr, 0) to allow native segment execution. Returns an error code for
185 // any problem that would prevent both TensorRT engine execution and native
186 // segment execution.
187 StatusOr<std::pair<EngineContext*, int>> GetEngine(
188 const std::vector<TensorShape>& input_concrete_shapes,
189 OpKernelContext* ctx, TRTEngineCacheResource* cache_resource);
190
191 // Builds and returns a cuda engine for the input shapes. If building the
192 // engine fails, enters a dummy entry into the cache_resource cache so we
193 // don't continually try to build the same failing engine.
194 StatusOr<TrtUniquePtrType<nvinfer1::ICudaEngine>> BuildEngine(
195 const std::vector<TensorShape>& input_concrete_shapes, int batch_size,
196 bool use_calibration, TRTInt8Calibrator* calibrator,
197 TRTEngineCacheResource* cache_resource, OpKernelContext* ctx);
198
199 // Verifies that the input shapes are consistent and can be handled by this op.
200 Status VerifyInputShapes(const std::vector<TensorShape>& shapes);
201
202 std::vector<string> input_nodes_;
203 std::vector<string> output_nodes_;
204
205 // Serialized protobuf segment or TRT engine, depending on the static_engine_ flag.
206 string serialized_segment_;
207
208 // The function for TF native execution of the segment.
209 NameAttrList func_;
210
211 // GraphDef representation of the segment.
212 GraphDef segment_graph_def_;
213
214 // Engine Precision mode.
215 TrtPrecisionMode precision_mode_;
216
217 // Whether engine is constructed during the conversion or needs to be
218 // constructed from protobuf segment.
219 bool static_engine_;
220
221 // Whether to calibrate INT8 engine.
222 bool calibration_mode_;
223
224 // Whether to use implicit batch dimension for TensorRT.
225 bool use_implicit_batch_;
226
227 // Whether to collect optimization profiles for TensorRT, only used when
228 // use_implicit_batch_=false.
229 bool profile_generation_mode_;
230
231 // Optimization profile generation strategy.
232 ProfileStrategy profile_strategy_;
233
234 // Whether the TRTEngineOp has any input with unknown dimensions.
235 bool has_dynamic_shape_input_;
236
237 // Whether to build TensorRT engines at runtime.
238 bool allow_build_at_runtime_;
239
240 // Whether to allow soft placement when the graph is executed with native
241 // TensorFlow.
242 bool allow_soft_placement_;
243
244 // Maximum number of cached engines.
245 int max_cached_engines_;
246
247 int64 workspace_size_;
248 mutex engine_mutex_;
249 FunctionLibraryRuntime::Handle native_execution_func_handle_;
250
251 // The finalized calibrator for inference.
252 std::unique_ptr<TRTInt8Calibrator> calibrator_;
253
254 // If true, create calibration graph for INT8 mode. Otherwise, we are using
255 // user-provided quantization ranges.
256 bool use_calibration_;
257
258 tensorflow::grappler::Cluster* cluster_;
259
260 // Array of all input shapes, collected from the input_shapes attribute when
261 // constructing the TRTEngineOp. The input_shapes attribute is set at graph
262 // conversion time. This data is used to determine which input dimensions
263 // could be unknown. At inference time this information is not available
264 // otherwise (all shapes are known (concrete) shapes when we run inference).
265 std::vector<PartialTensorShape> input_partial_shapes_;
266 // Shapes, excluding resource inputs.
267 std::vector<PartialTensorShape> input_partial_shapes_filtered_;
268
269 // The TF node can have more inputs than the TRT engine: resource inputs are
270 // saved as weights in the engine instead of being passed as engine inputs.
271 // The input mask is true for those TF inputs that are TRT engine inputs.
272 std::vector<bool> input_mask_;
273
274 // Whether to use explicit precision (QDQ) mode.
275 bool use_explicit_precision_;
276 };
277
278 #define TYPECASE(dt, X, Y) \
279 case dt: { \
280 return (void*)X->flat<EnumToDataType<dt>::Type>().data(); \
281 }
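// Note that the Y argument is unused in the macro body; call sites pass a
// placeholder name (dest_ptr) that the macro simply ignores.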
282
283 void* GetTensorAddress(const Tensor* tensor_ptr) {
284 auto tensor_type = tensor_ptr->dtype();
285 switch (tensor_type) {
286 TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr);
287 TYPECASE(DT_HALF, tensor_ptr, dest_ptr);
288 TYPECASE(DT_INT8, tensor_ptr, dest_ptr);
289 TYPECASE(DT_INT32, tensor_ptr, dest_ptr);
290 default: {
291 LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
292 return nullptr;
293 }
294 }
295 }
296
297 static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
298 FunctionLibraryRuntime* flib_runtime,
299 GraphDef* graph_def) {
300 const FunctionLibraryDefinition* flib_def =
301 flib_runtime->GetFunctionLibraryDefinition();
302 const FunctionBody* fbody;
303 fbody = flib_runtime->GetFunctionBody(handle);
304 if (!fbody) {
305 return errors::Internal(
306 "Function body is null when converting from FuncDef to GraphDef.");
307 }
308 std::unique_ptr<Graph> graph(new Graph(flib_def));
309 CopyGraph(*fbody->graph, graph.get());
310
311 auto replace_name = [](const char* const prefix, string* name) {
312 if (absl::StartsWith(*name, absl::AsciiStrToLower(prefix))) {
313 name->replace(0, strlen(prefix), prefix);
314 return true;
315 }
316 return false;
317 };
318 graph->ToGraphDef(graph_def);
319 // GraphToFunctionDef() will convert all the node names to lowercase.
320 for (auto& node : *graph_def->mutable_node()) {
321 if (!replace_name(IONamePrefixes::kInputPHName, node.mutable_name())) {
322 if (replace_name(IONamePrefixes::kOutputPHName, node.mutable_name())) {
323 // Instantiation of the function will append _RetVal to the node name; we
324 // need to remove it for backward compatibility.
325 const char* const suffix_to_remove = "_RetVal";
326 if (absl::EndsWith(node.name(), suffix_to_remove)) {
327 node.mutable_name()->erase(node.name().size() -
328 strlen(suffix_to_remove));
329 }
330 }
331 }
332 for (auto& input : *node.mutable_input()) {
333 if (!replace_name(IONamePrefixes::kInputPHName, &input)) {
334 replace_name(IONamePrefixes::kOutputPHName, &input);
335 }
336 }
337 }
338 return Status::OK();
339 }
340
341 StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle(
342 FunctionLibraryRuntime* lib, const string& device_name,
343 bool allow_soft_placement, size_t num_inputs, size_t num_outputs) {
344 VLOG(1) << "Constructing function handle";
345 if (lib == nullptr) {
346 return errors::Internal("Context function library is null");
347 }
348 FunctionLibraryRuntime::InstantiateOptions inst_ops;
349 inst_ops.state_handle = "";
350 inst_ops.target = device_name;
351 if (allow_soft_placement) {
352 const FunctionDef* fdef =
353 lib->GetFunctionLibraryDefinition()->Find(func_.name());
354 if (!fdef) {
355 return errors::Internal(
356 StrCat("Can't find FunctionDef for ", func_.name()));
357 }
358 bool ints_on_device =
359 fdef->attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
360 fdef->attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
361 // kIntsOnDeviceAttr is not compatible with is_multi_device_function which
362 // is needed to support allow_soft_placement.
363 if (ints_on_device) {
364 LOG_FIRST_FEW_WARNING_WITH_PREFIX
365 << "Function " << name()
366 << " has attribute kIntsOnDeviceAttr=true "
367 "and will be executed natively with allow_soft_placement=false. "
368 "If this is a problem, please re-generate your SavedModel with "
369 "the TF-TRT runtime you are using.";
370 } else {
371 inst_ops.is_multi_device_function = true;
372 inst_ops.input_devices.resize(num_inputs, device_name);
373 inst_ops.output_devices.resize(num_outputs, device_name);
374 inst_ops.config_proto.set_allow_soft_placement(true);
375 }
376 }
377 FunctionLibraryRuntime::Handle func_handle;
378 Status status = lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
379 inst_ops, &func_handle);
380 if (status.ok()) {
381 return func_handle;
382 }
383 return status;
384 }
385
386 Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
387 const string& device_name) {
388 TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle,
389 ConstructFunctionHandle(lib, device_name));
390 return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_);
391 }
392
393 TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
394 : AsyncOpKernel(context) {
395 // Read the serialized segment.
396 OP_REQUIRES_OK(context,
397 context->GetAttr("serialized_segment", &serialized_segment_));
398 OP_REQUIRES_OK(context,
399 context->GetAttr("workspace_size_bytes", &workspace_size_));
400 OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
401
402 VLOG(1) << "Constructing " << name();
403 string precision_string;
404 OP_REQUIRES_OK(context,
405 context->GetAttr("precision_mode", &precision_string));
406 string calibration_data;
407 OP_REQUIRES_OK(context,
408 context->GetAttr("calibration_data", &calibration_data));
409 OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_));
410 OP_REQUIRES(context, !func_.name().empty(),
411 errors::InvalidArgument(
412 "The TF function for the TRT segment could not be empty"));
413 OP_REQUIRES_OK(context,
414 TrtPrecisionModeFromName(precision_string, &precision_mode_));
415 OP_REQUIRES_OK(context,
416 context->GetAttr("use_calibration", &use_calibration_));
417 OP_REQUIRES_OK(context,
418 context->GetAttr("input_shapes", &input_partial_shapes_));
419 auto status =
420 context->GetAttr("_allow_build_at_runtime", &allow_build_at_runtime_);
421 if (status.code() == tensorflow::error::NOT_FOUND) {
422 VLOG(2) << "Not found _allow_build_at_runtime in "
423 << context->device()->name()
424 << ", thus setting _allow_build_at_runtime=true";
425 allow_build_at_runtime_ = true;
426 } else {
427 OP_REQUIRES_OK(context, status);
428 }
429
430 // Get a mask of non-resource inputs.
431 std::vector<DataType> in_types;
432 input_mask_.resize(input_partial_shapes_.size());
433 OP_REQUIRES_OK(context, context->GetAttr("InT", &in_types));
434 for (int i = 0; i < input_mask_.size(); i++) {
435 input_mask_[i] = (in_types[i] != DataType::DT_RESOURCE);
436 }
437
438 // Filter the shapes to exclude resources.
439 for (int i = 0; i < input_partial_shapes_.size(); i++) {
440 if (input_mask_[i]) {
441 input_partial_shapes_filtered_.push_back(input_partial_shapes_[i]);
442 }
443 }
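// Example: with InT = {DT_FLOAT, DT_RESOURCE, DT_INT32}, input_mask_ becomes
// {true, false, true} and input_partial_shapes_filtered_ keeps only the
// shapes of the float and int32 inputs.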
444
445 status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_);
446 if (status.code() == tensorflow::error::NOT_FOUND) {
447 allow_soft_placement_ = true;
448 } else {
449 OP_REQUIRES_OK(context, status);
450 }
451
452 status = context->GetAttr("use_explicit_precision", &use_explicit_precision_);
453 if (!status.ok()) {
454 use_explicit_precision_ = false;
455 }
456
457 native_execution_func_handle_ = kInvalidHandle;
458 if (!static_engine_) {
459 OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(),
460 context->device()->name()));
461 }
462 // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
463 // backward compatibility reasons. Remove it once all known users switch to
464 // 2.0.
465 calibration_mode_ =
466 (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
467 calibration_data.empty());
468 if (!calibration_data.empty()) {
469 calibrator_.reset(new TRTInt8Calibrator(calibration_data));
470 calibration_data.resize(0);
471 }
472 OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
473 &max_cached_engines_));
474
475 status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_);
476 if (status.code() == tensorflow::error::NOT_FOUND) {
477 VLOG(2) << "Not found _use_implicit_batch in " << context->device()->name()
478 << ", thus setting _use_implicit_batch=true";
479 use_implicit_batch_ = true;
480 }
481
482 status =
483 context->GetAttr("_profile_generation_mode", &profile_generation_mode_);
484 if (status.code() == tensorflow::error::NOT_FOUND) {
485 VLOG(2) << "Not found _profile_generation_mode in "
486 << context->device()->name()
487 << ", thus setting _profile_generation_mode=false";
488 profile_generation_mode_ = false;
489 }
490 if (static_engine_) {
491 if (profile_generation_mode_) profile_generation_mode_ = false;
492 }
493 if (use_implicit_batch_) {
494 OP_REQUIRES(context, !profile_generation_mode_,
495 errors::InvalidArgument(
496 "profile_generation_mode_=true is only supported if "
497 "use_implicit_batch=false"));
498 if (input_partial_shapes_.empty()) {
499 VLOG(1) << "Attribute input_shapes is not set. This happens probably "
500 << "because you are using a model that is already converted "
501 << "to TensorRT with a previous version of TF-TRT (i.e. includes "
502 << "TRTEngineOp in graph). This is not an error. If you convert "
503 << "the original model again to TensorRT, the attributes "
504 << "input_shapes will be set automatically.";
505 }
506 } else {
507 OP_REQUIRES(
508 context, !input_partial_shapes_.empty(),
509 errors::InvalidArgument(
510 "Explicit batch mode requires attribute input_shapes to be set."
511 "If you are using a model that was converted to TensorRT by a "
512 "previous version of TF-TRT, (i.e. includes TRTEngineOp in graph "
513 "without the input_shapes attribute), then you need to convert the "
514 "original model again to TensorRT in order to set the attribute "
515 "input_shapes."));
516
517 string profile_strategy_name;
518 status = context->GetAttr("profile_strategy", &profile_strategy_name);
519 if (status.code() == tensorflow::error::NOT_FOUND) {
520 VLOG(2) << "Not found strategy in " << context->device()->name()
521 << ", thus setting profile_strategy='Range'";
522 profile_strategy_ = ProfileStrategy::kRange;
523 } else {
524 OP_REQUIRES_OK(context, ProfileStrategyFromName(profile_strategy_name,
525 &profile_strategy_));
526 }
527 }
528 has_dynamic_shape_input_ = absl::c_any_of(
529 input_partial_shapes_filtered_,
530 [](PartialTensorShape shape) { return !shape.IsFullyDefined(); });
531 VLOG(2) << "TRTEngineOp has_dynamic_shape_input_: "
532 << has_dynamic_shape_input_;
533 }
534
535 // Copies the input tensor ctx->input(i) (which is in device memory) to the
536 // host, and places the resulting host tensor at the back of native_inputs.
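// The copy is asynchronous on `stream`; callers must synchronize the stream
// (as ExecuteNativeSegment does) before the host tensor is read.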
537 Status CopyToHostAsync(OpKernelContext* ctx, std::vector<Tensor>* native_inputs,
538 int i, const cudaStream_t stream) {
539 // The TRTEngineOp has all ctx->inputs on the device. In contrast, the
540 // native segment expects to find int32 inputs on the host. We copy int32
541 // inputs from device to host.
542
543 AllocatorAttributes allocator_attr;
544 allocator_attr.set_on_host(true);
545 Tensor t;
546 TF_RETURN_IF_ERROR(ctx->allocate_temp(
547 ctx->input_dtype(i), ctx->input(i).shape(), &t, allocator_attr));
548 native_inputs->push_back(t);
549 const Tensor& gpu_tensor = ctx->input(i);
550 auto ret = cudaMemcpyAsync(
551 t.flat<int32>().data(), gpu_tensor.flat<int32>().data(),
552 t.NumElements() * sizeof(int32), cudaMemcpyDeviceToHost, stream);
553 if (ret != 0) {
554 return errors::Internal("Could not copy tensor for native segment input");
555 }
556 return Status::OK();
557 }
558
559 // Copies native_tensor, which is in host memory, to ctx->output(t), which is
560 // in device memory.
561 Status CopyToDeviceAsync(OpKernelContext* ctx, const Tensor& native_tensor,
562 int t, cudaStream_t stream) {
563 Tensor* gpu_tensor;
564 TF_RETURN_IF_ERROR(
565 ctx->allocate_output(t, native_tensor.shape(), &gpu_tensor));
566 auto ret = cudaMemcpyAsync(gpu_tensor->flat<int32>().data(),
567 native_tensor.flat<int32>().data(),
568 native_tensor.NumElements() * sizeof(int32),
569 cudaMemcpyHostToDevice, stream);
570 if (ret != 0) {
571 return errors::Internal("Could not copy tensor for native segment output");
572 }
573 return Status::OK();
574 }
575
576 void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
577 AsyncHelper* async_helper) {
578 tensorflow::profiler::TraceMe activity(
579 "TRTEngineOp::ExecuteNativeSegment",
580 tensorflow::profiler::TraceMeLevel::kInfo);
581 std::vector<Tensor> native_inputs;
582 std::vector<Tensor>* native_outputs = new std::vector<Tensor>();
583 DummyAsyncHelper dummy_async_helper;
584 if (native_execution_func_handle_ == kInvalidHandle) {
585 StatusOr<FunctionLibraryRuntime::Handle> status_or_handle =
586 ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
587 allow_soft_placement_, ctx->num_inputs(),
588 ctx->num_outputs());
589 OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), dummy_async_helper);
590 native_execution_func_handle_ = *status_or_handle;
591 }
592
593 auto lib = ctx->function_library();
594 FunctionLibraryRuntime::Options opts;
595 opts.rendezvous = ctx->rendezvous();
596 opts.cancellation_manager = ctx->cancellation_manager();
597 opts.runner = ctx->runner();
598 native_inputs.reserve(ctx->num_inputs());
599 int n_copies = 0;
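// Extract the raw CUDA stream backing the op's device context; the same
// pattern appears in ExecuteCalibration and ExecuteTrtEngine below.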
600 const cudaStream_t* stream = CHECK_NOTNULL(
601 reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
602 ->stream()
603 ->implementation()
604 ->GpuStreamMemberHack()));
605 for (int i = 0; i < ctx->num_inputs(); i++) {
606 if (ctx->input_dtype(i) != DT_INT32) {
607 native_inputs.push_back(ctx->input(i));
608 } else {
609 OP_REQUIRES_OK_ASYNC(ctx,
610 CopyToHostAsync(ctx, &native_inputs, i, *stream),
611 dummy_async_helper);
612 n_copies++;
613 }
614 }
615 if (n_copies > 0) {
616 // If we have any int32 tensors, then wait until data is copied to host.
617 cudaStreamSynchronize(*stream);
618 }
619 VLOG(1) << "Executing native segment: " << name();
620 // Increment the reference count of the async_helper by 1. When the native
621 // segment finishes execution asynchronously, we decrement the reference
622 // count of the object.
623 async_helper->Ref();
624 lib->Run(
625 opts, native_execution_func_handle_, native_inputs, native_outputs,
626 [this, ctx, native_outputs, async_helper, stream](const Status& s) {
627 core::ScopedUnref sc(async_helper);
628 DummyAsyncHelper dummy_async_helper;
629 std::unique_ptr<std::vector<Tensor>> outputs_wrapper(native_outputs);
630 OP_REQUIRES_OK_ASYNC(ctx, s, dummy_async_helper);
631 VLOG(1) << "Native Segment completed";
632 int n_copies = 0;
633 for (size_t t = 0; t < native_outputs->size(); ++t) {
634 if (native_outputs->at(t).dtype() == DT_INT32) {
635 OP_REQUIRES_OK_ASYNC(
636 ctx, CopyToDeviceAsync(ctx, native_outputs->at(t), t, *stream),
637 dummy_async_helper);
638 n_copies++;
639 } else {
640 ctx->set_output(t, native_outputs->at(t));
641 }
642 }
643 if (n_copies > 0) {
644 cudaStreamSynchronize(*stream);
645 }
646 });
647 }
648
649 void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
650 TRTEngineCacheResource* cache_res,
651 AsyncHelper* async_helper) {
652 tensorflow::profiler::TraceMe activity(
653 "TRTEngineOp::ExecuteCalibration",
654 tensorflow::profiler::TraceMeLevel::kInfo);
655 VLOG(1) << "Executing TRT calibration: " << name();
656 DummyAsyncHelper dummy_async_helper;
657
658 CalibrationContext* calib_ctx = cache_res->calib_ctx_.get();
659 const int num_inputs = ctx->num_inputs();
660 // TODO(laigd): need to check that input shape matches.
661 // Pass input data to calibrator
662 std::unordered_map<string, void*> input_data;
663 for (int i = 0; i < num_inputs; i++) {
664 const Tensor& t = ctx->input(i);
665 void* data_address = GetTensorAddress(&t);
666 OP_REQUIRES_ASYNC(ctx, data_address,
667 errors::InvalidArgument(
668 "Unsupported data type encountered in input ", i),
669 dummy_async_helper);
670 // Check the allocated buffer is sufficient for input
671 const auto device_tensor = &calib_ctx->device_tensors_.at(i);
672 CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
673 input_data.emplace(StrCat(IONamePrefixes::kInputPHName, i), data_address);
674 }
675 VLOG(2) << "Filled map for sending";
676 // Copied from gpu_kernel_helper.h as the header can only be used in *.cu.cc
677 // files.
678 const cudaStream_t* stream = CHECK_NOTNULL(
679 reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
680 ->stream()
681 ->implementation()
682 ->GpuStreamMemberHack()));
683 // TRTInt8Calibrator::setBatch will wait until TRTInt8Calibrator::getBatch is
684 // called before proceeding with feeding the calibration data to the
685 // calibrator. It returns true if the calibration data is accepted and
686 // returns false if calibration is terminated due to errors.
687 //
688 // If TRTInt8Calibrator::getBatch is never called, which could happen if
689 // there is any problem in building the cuda engine for calibration inside
690 // TensorRT, then the TRTInt8Calibrator::setBatch call here will hang until
691 // TRTInt8Calibrator::setDone is called by the calibration thread in
692 // AllocateCalibrationResources.
693 //
694 // In both of the above cases, setBatch here returns a boolean value to
695 // indicate the result of the calibration process.
696 if (!calib_ctx->calibrator_->setBatch(input_data, *stream)) {
697 VLOG(2) << "Failed to feed calibration data";
698 } else {
699 VLOG(2) << "Passed calibration data";
700 }
701 ExecuteNativeSegment(ctx, async_helper);
702 }
703
704 Status TRTEngineOp::VerifyInputShapes(
705 const std::vector<TensorShape>& input_concrete_shapes) {
706 if (input_concrete_shapes.empty()) {
707 return errors::InvalidArgument("Input shapes are empty, for ", name());
708 }
709
710 if (input_partial_shapes_filtered_.empty()) {
711 if (!use_implicit_batch_) {
712 return errors::InvalidArgument(
713 "Explicit batch mode requires input_partial_shapes_ ",
714 "to contain the dynamic input shapes to TRTEngineOp");
715 }
716 // If the graph was converted with an earlier version of TF-TRT, it can
717 // happen that the input_partial_shapes_ vector is not set (see
718 // input_shapes attribute handling in the TRTEngineOp constructor).
719 // In implicit batch mode it is allowed to have empty input_partial_shapes_,
720 // since it is only required in explicit batch mode (see the input_shapes
721 // attribute of ConvertGraphDefToEngine in TRTEngineOp::GetEngine).
722 } else {
723 // Additional consistency checks if input_partial_shapes_ is present.
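// For example, a stored partial shape [-1, 224, 224, 3] is compatible with a
// concrete input shape [8, 224, 224, 3], but not with [8, 224, 224] (rank
// mismatch) or [8, 224, 224, 1] (known dimension mismatch).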
724 const string error_msg = StrCat(
725 "Input shapes do not match input partial shapes stored in graph, for ",
726 name(), ": ", DebugString(input_concrete_shapes),
727 " != ", DebugString(input_partial_shapes_filtered_));
728 if (input_concrete_shapes.size() != input_partial_shapes_filtered_.size()) {
729 return errors::InvalidArgument(error_msg);
730 }
731 for (int i = 0; i < input_concrete_shapes.size(); i++) {
732 if (input_concrete_shapes[i].dims() !=
733 input_partial_shapes_filtered_[i].dims()) {
734 return errors::InvalidArgument(error_msg);
735 }
736 }
737 for (int i = 0; i < input_concrete_shapes.size(); i++) {
738 for (int d = 0; d < input_concrete_shapes[i].dims(); d++) {
739 if (input_partial_shapes_filtered_[i].dim_size(d) != -1) {
740 if (input_concrete_shapes[i].dim_size(d) !=
741 input_partial_shapes_filtered_[i].dim_size(d)) {
742 return errors::InvalidArgument(error_msg);
743 }
744 }
745 }
746 }
747 }
748
749 if (use_implicit_batch_) {
750 if (input_concrete_shapes[0].dims() < 1) {
751 return errors::InvalidArgument(
752 "Input shapes contain scalar, for ", name(), ": ",
753 TensorShapeUtils::ShapeListString(input_concrete_shapes));
754 }
755 const int batch_size = input_concrete_shapes[0].dim_size(0);
756 if (batch_size < 1) {
757 return errors::InvalidArgument(
758 "Incorrect batch dimension, for ", name(), ": ",
759 TensorShapeUtils::ShapeListString(input_concrete_shapes));
760 }
761 for (const TensorShape& shape : input_concrete_shapes) {
762 if (batch_size != shape.dim_size(0)) {
763 return errors::InvalidArgument(
764 "Input shapes are inconsistent on the batch dimension, for ",
765 name(), ": ",
766 TensorShapeUtils::ShapeListString(input_concrete_shapes));
767 }
768 }
769 }
770 return Status::OK();
771 }
772
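// Returns false only when the TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION
// environment variable is explicitly set to a false value (typically "0" or
// "false"); by default, falling back to the native segment is allowed.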
773 static bool AllowEngineNativeSegmentExecution() {
774 bool value;
775 Status status =
776 ReadBoolFromEnvVar("TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION",
777 /*default_value=*/true, &value);
778 if (!status.ok()) {
779 LOG(ERROR) << status;
780 }
781 return value;
782 }
783
784 void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
785 AsyncOpKernel::DoneCallback done) {
786 tensorflow::profiler::TraceMe activity(
787 "TRTEngineOp::ComputeAsync", tensorflow::profiler::TraceMeLevel::kInfo);
788
789 // Invoke the DoneCallback when this object is destructed, which could be
790 // after this routine finishes execution, in particular when the native
791 // segment is executed.
792 auto async_helper = new AsyncHelper(done);
793 core::ScopedUnref sc(async_helper);
794
795 // For all async execution macros, use this object, as there is no need to
796 // call the DoneCallback from those macros.
797 DummyAsyncHelper dummy_async_helper;
798
799 // Get TRT resource.
800 TRTEngineCacheResource* cache_res = nullptr;
801 OP_REQUIRES_OK_ASYNC(ctx, GetEngineCacheResource(ctx, &cache_res),
802 dummy_async_helper);
803 core::ScopedUnref unref_cache_res(cache_res);
804
805 // Get shapes of inputs to engine.
806 std::vector<TensorShape> input_concrete_shapes;
807 input_concrete_shapes.reserve(ctx->num_inputs());
808 std::vector<TensorShape> input_concrete_shapes_filtered;
809 for (int i = 0; i < ctx->num_inputs(); ++i) {
810 input_concrete_shapes.push_back(ctx->input(i).shape());
811 if (ctx->input(i).dtype() != DataType::DT_RESOURCE) {
812 input_concrete_shapes_filtered.push_back(ctx->input(i).shape());
813 }
814 }
815
816 /// TODO(lsugy): fix case of engine with only resource inputs.
817 Status verify_input_shape_status =
818 VerifyInputShapes(input_concrete_shapes_filtered);
819 // TODO(bixia): Fix the segmentation.
820 if (!verify_input_shape_status.ok()) {
821 LOG_FIRST_FEW_WARNING_WITH_PREFIX
822 << "Running native segment for" << name()
823 << " due to failure in verifying input shapes: "
824 << verify_input_shape_status.error_message();
825 ExecuteNativeSegment(ctx, async_helper);
826 return;
827 }
828
829 if (!use_implicit_batch_ &&
830 (has_dynamic_shape_input_ || cache_res->profiles_.HasShapeTensor())) {
831 OP_REQUIRES_OK_ASYNC(ctx, cache_res->profiles_.CollectShapeValues(ctx),
832 dummy_async_helper);
833 cache_res->profiles_.SetInputMask(input_mask_);
834 if (profile_generation_mode_) {
835 // Collecting new shapes for profiles can be only done once. After the
836 // shapes are converted to TRT profiles, no shapes can be collected
837 // anymore.
838 OP_REQUIRES_ASYNC(ctx, cache_res->profiles_.GetNumProfiles() == 0,
839 errors::Unimplemented("Cannot collect new shapes when "
840 "profiles are already created."),
841 dummy_async_helper);
842 // Just collect the input shape info and return. The shapes are used to
843 // generate optimization profiles during engine creation.
844 cache_res->profiles_.AddShape(input_concrete_shapes);
845 VLOG(1) << "Native segment is used during collecting shapes for profiles";
846 ExecuteNativeSegment(ctx, async_helper);
847 return;
848 } else if (cache_res->profiles_.GetNumProfiles() == 0 && !static_engine_) {
849 // Add current shape if we did not collect any shapes so far.
850 if (!cache_res->profiles_.HasShape()) {
851 cache_res->profiles_.AddShape(input_concrete_shapes);
852 }
853 // Create profiles out of collected shapes during profile generation.
854 cache_res->profiles_.InitProfiles(input_partial_shapes_,
855 profile_strategy_);
856 }
857 }
858
859 // Run calibration if in int8+calibration mode.
860 // * Logic in TF 1.x:
861 // - During conversion: calibration_mode_ is true and cache size is 0, so it
862 // will run calibration.
863 // - During inference: calibration_data will be set, so calibration_mode_
864 // is false and it won't trigger calibration.
865 // * Logic in TF 2.0:
866 // - During conversion: similar to 1.x.
867 // - During inference: calibration_data will still be empty, but cache will
868 // contain the calibrated engine, so it won't trigger calibration.
869 //
870 // TODO(laigd): consider the following alternatives:
871 // 1. Serialize the state (calibration or inference) using
872 // TRTEngineInstance proto (or a new proto), so we know which mode we're
873 // in and don't run calibration during inference (which is invalid).
874 // 2. Reuse the calibration_data attribute or use a new attribute in the
875 // NodeDef to indicate whether it's in calibration mode.
876 if (calibration_mode_ && cache_res->cache_.size() == 0) {
877 if (!cache_res->calib_ctx_) {
878 // TODO(laigd): better encapsulation.
879 mutex_lock lock(engine_mutex_);
880 if (!cache_res->calib_ctx_) {
881 // Add profiles if we are in dynamic shape mode.
882 if (!use_implicit_batch_ && (has_dynamic_shape_input_ ||
883 cache_res->profiles_.HasShapeTensor())) {
884 cache_res->profiles_.InitCalibProfile(input_concrete_shapes);
885 }
886 OP_REQUIRES_OK_ASYNC(ctx, AllocateCalibrationResources(ctx, cache_res),
887 dummy_async_helper);
888 }
889 }
890 // TODO(laigd): check that the input shapes match the shapes of the
891 // persistent tensor in the calibration resource.
892 ExecuteCalibration(ctx, cache_res, async_helper);
893 return;
894 }
895
896 StatusOr<std::pair<EngineContext*, int>> status =
897 GetEngine(input_concrete_shapes, ctx, cache_res);
898 OP_REQUIRES_OK_ASYNC(ctx, status.status(), dummy_async_helper);
899
900 EngineContext* engine_context = status.ValueOrDie().first;
901 int trt_context_idx = status.ValueOrDie().second;
902 auto may_execute_native_segment = [&] {
903 if (!AllowEngineNativeSegmentExecution()) {
904 ctx->CtxFailure(
905 errors::Aborted("User disallowed engine native segment execution"));
906 return false;
907 }
908 return true;
909 };
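// If may_execute_native_segment() returns false, ctx has already been marked
// failed via CtxFailure, so callers simply return without running the native
// segment.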
910 if (!engine_context->GetCudaEngine()) {
911 LOG_FIRST_FEW_WARNING_WITH_PREFIX
912 << "Engine retrieval for input shapes: "
913 << TensorShapeUtils::ShapeListString(input_concrete_shapes)
914 << " failed. Running native segment for " << name();
915 if (may_execute_native_segment()) {
916 ExecuteNativeSegment(ctx, async_helper);
917 }
918 return;
919 }
920 Status stat =
921 ExecuteTrtEngine(ctx, engine_context, trt_context_idx,
922 cache_res->profiles_, cache_res->allocator_.get());
923 if (stat.ok()) return;
924
925 LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat
926 << " Retrying with native segment for "
927 << name();
928 if (!may_execute_native_segment()) {
929 return;
930 }
931 // Release any outputs that are already allocated; ExecuteNativeSegment will
932 // re-allocate them and would fail if they were still allocated.
933 // The Tensor pointer in the returned TensorValue must be explicitly
934 // deleted.
935 for (int i = 0; i < ctx->num_outputs(); i++) {
936 delete ctx->release_output(i).tensor;
937 }
938 ExecuteNativeSegment(ctx, async_helper);
939 }
940
941 Status TRTEngineOp::ExecuteTrtEngine(
942 OpKernelContext* ctx, EngineContext* engine_context, int trt_context_idx,
943 const TrtShapeOptimizationProfile& profiles, TRTBaseAllocator* allocator) {
944 tensorflow::profiler::TraceMe activity(
945 "TRTEngineOp::ExecuteTrtEngine",
946 tensorflow::profiler::TraceMeLevel::kInfo);
947 VLOG(1) << "Executing TRT engine: " << name();
948 nvinfer1::ICudaEngine* cuda_engine = engine_context->GetCudaEngine();
949
950 if (VLOG_IS_ON(2)) {
951 VLOG(2) << " Network name: " << cuda_engine->getName();
952 VLOG(2) << " Activation size: " << engine_context->GetDeviceMemorySize()
953 << " bytes";
954 #if !IS_TRT_VERSION_GE(8, 0, 0, 0)
955 // getWorkspaceSize() is deprecated as of TRT 8
956 VLOG(2) << " Workspace size: " << cuda_engine->getWorkspaceSize()
957 << " bytes";
958 #endif // #if !IS_TRT_VERSION_GE(8, 0, 0, 0)
959 VLOG(2) << " Datatype of " << cuda_engine->getNbBindings()
960 << " inputs/outputs";
961 string binding_types = "";
962 for (int i = 0; i < cuda_engine->getNbBindings(); i++) {
963 binding_types += " " + string(cuda_engine->getBindingName(i)) + ": " +
964 DebugString(cuda_engine->getBindingDataType(i)) + "\n";
965 }
966 VLOG(2) << binding_types;
967 }
968
969 const int num_binding = cuda_engine->getNbBindings();
970 std::vector<void*> buffers(num_binding);
971
972 // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex
973 // for it.
974 mutex_lock lock(engine_context->mu);
975 nvinfer1::IExecutionContext* execution_context;
976 bool has_device_memory;
977 TF_RETURN_IF_ERROR(engine_context->GetExecutionContext(
978 trt_context_idx, &execution_context, &has_device_memory));
979
980 if (VLOG_IS_ON(2)) {
981 VLOG(2) << "Selected execution context: " << trt_context_idx;
982 }
983 const int num_batch =
984 use_implicit_batch_ ? ctx->input(0).shape().dim_size(0) : 0;
985
986 TF_RETURN_IF_ERROR(SetTrtEngineInputs(
987 cuda_engine, execution_context, trt_context_idx, buffers,
988 use_implicit_batch_, num_batch, profiles, ctx));
989
990 TF_RETURN_IF_ERROR(SetTrtEngineOutputs(cuda_engine, execution_context,
991 trt_context_idx, buffers,
992 use_implicit_batch_, num_batch, ctx));
993
994 // Copied from gpu_kernel_helper.h as the header can only be used in *.cu.cc
995 // files.
996 const cudaStream_t* stream = CHECK_NOTNULL(
997 reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
998 ->stream()
999 ->implementation()
1000 ->GpuStreamMemberHack()));
1001
1002 ContextDeviceMemory context_device_memory;
1003 if (!has_device_memory) {
1004 tensorflow::profiler::TraceMe activity(
1005 "TRTEngineOp::AllocateDeviceMemory",
1006 tensorflow::profiler::TraceMeLevel::kInfo);
1007 // Allocate device memory for the TensorRT engine execution. The device
1008 // memory will be released when context_device_memory goes out of scope.
1009 TF_RETURN_IF_ERROR(context_device_memory.AllocateDeviceMemory(
1010 execution_context, allocator, engine_context->GetDeviceMemorySize()));
1011 }
1012 // Enqueue the TensorRT engine for execution.
1013 return TrtEnqueue(execution_context, buffers, *stream, use_implicit_batch_,
1014 num_batch);
1015 }
1016
1017 Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx,
1018 TRTEngineCacheResource** cache_res) {
1019 tensorflow::profiler::TraceMe activity(
1020 "TRTEngineOp::GetEngineCachResource",
1021 tensorflow::profiler::TraceMeLevel::kInfo);
1022 // Canonicalize the op name by removing the scopes, if any. This is mainly
1023 // because in TFv2 the function graph can be instantiated in various ways,
1024 // which inserts scope names into the names of the TRTEngineOps. That would
1025 // result in many different engine caches if we used the instantiated op name
1026 // directly, but we still want all of them to share the same cache (if they
1027 // represent the same subgraph).
1028 absl::string_view resource_name = name();
1029 size_t last_slash = resource_name.find_last_of('/');
1030 if (last_slash != absl::string_view::npos) {
1031 resource_name.remove_prefix(last_slash + 1);
1032 }
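// For example, both "while/body/TRTEngineOp_0" and "TRTEngineOp_0" map to the
// resource name "TRTEngineOp_0", so they share a single engine cache.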
1033
1034 // Get engine cache.
1035 return ctx->resource_manager()->LookupOrCreate(
1036 std::string(kTfTrtContainerName), std::string(resource_name), cache_res,
1037 {[this, ctx](TRTEngineCacheResource** cr) -> Status {
1038 *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
1039 return Status::OK();
1040 }});
1041 }
1042
1043 StatusOr<TrtUniquePtrType<nvinfer1::ICudaEngine>> TRTEngineOp::BuildEngine(
1044 const std::vector<TensorShape>& input_concrete_shapes, int batch_size,
1045 bool use_calibration, TRTInt8Calibrator* calibrator,
1046 TRTEngineCacheResource* cache_resource, OpKernelContext* ctx) {
1047 TRT_ENSURE(cache_resource);
1048 TRT_ENSURE(ctx);
1049 // Use concrete shapes for implicit batch mode and partial shapes for
1050 // explicit batch mode.
1051 bool use_concrete_shapes =
1052 use_implicit_batch_ || cache_resource->profiles_.IsStaticCompatible();
1053 const std::vector<PartialTensorShape>& conversion_input_shapes =
1054 use_concrete_shapes
1055 ? std::vector<PartialTensorShape>(input_concrete_shapes.begin(),
1056 input_concrete_shapes.end())
1057 : input_partial_shapes_;
1058
1059 VLOG(1) << "Building a new TensorRT engine for " << name()
1060 << " with input shapes: " << DebugString(conversion_input_shapes);
1061
1062 std::unordered_map<string, tensorflow::DeviceProperties> device_map;
1063 DeviceNameUtils::ParsedName full_parsed_name;
1064 DeviceNameUtils::ParseFullName(ctx->device()->name(), &full_parsed_name);
1065 device_map.emplace(ctx->device()->name(),
1066 grappler::GetDeviceInfo(full_parsed_name));
1067 tensorflow::grappler::VirtualCluster cluster(device_map);
1068
1069 TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
1070 auto status = convert::ConvertGraphDefToEngine(
1071 segment_graph_def_, ctx, precision_mode_, batch_size, workspace_size_,
1072 conversion_input_shapes, &logger, cache_resource->allocator_.get(),
1073 calibrator, &engine, use_calibration, use_implicit_batch_, nullptr,
1074 &cache_resource->profiles_, name(), use_explicit_precision_, &cluster);
1075 if (!status.ok()) {
1076 LOG_FIRST_FEW_WARNING_WITH_PREFIX
1077 << "Engine creation for " << name() << " failed. "
1078 << "The native segment will be used instead. "
1079 << "Reason: " << status;
1080 // Store an empty engine in the cache for these input shapes so we don't try
1081 // to build the same failing engine again.
1082 cache_resource->cache_.emplace(input_concrete_shapes,
1083 std::make_unique<EngineContext>());
1084 return status;
1085 }
1086 return engine;
1087 }
1088
1089 StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
1090 const std::vector<TensorShape>& input_concrete_shapes, OpKernelContext* ctx,
1091 TRTEngineCacheResource* cache_res) {
1092 static EngineContext empty_context;
1093 tensorflow::profiler::TraceMe activity(
1094 "TRTEngineOp::GetEngine", tensorflow::profiler::TraceMeLevel::kInfo);
1095 mutex_lock lock(engine_mutex_);
1096 // Using first input to get batch size is reliable - VerifyInputShapes()
1097 // guarantees that the first input is not a scalar. As such we can always use
1098 // the first input to get the batch size for implicit batch mode. For explicit
1099 // batch mode, this value is not used.
1100 const int batch_size = input_concrete_shapes[0].dim_size(0);
1101 // TODO(Tamas): remove the need for batch_size in explicit_batch mode
1102 auto& cache = cache_res->cache_;
1103 auto allocator = cache_res->allocator_.get();
1104 if (allocator == nullptr) {
1105 return std::pair<EngineContext*, int>(&empty_context, 0);
1106 }
1107
1108 // Handle the static engine case. For static engines, the cache will have a
1109 // single element containing the only engine.
1110 if (static_engine_) {
1111 if (cache.size()) {
1112 // TODO(laigd): need a better shape compatibility check for the case where
1113 // implicit batch is disabled.
1114 if (!use_implicit_batch_ ||
1115 AreShapesCompatible(input_concrete_shapes, cache.begin()->first)) {
1116 int profile_id = 0;
1117 if (!use_implicit_batch_)
1118 profile_id =
1119 cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
1120 if (profile_id != -1) {
1121 return std::pair<EngineContext*, int>(cache.begin()->second.get(),
1122 profile_id);
1123 }
1124 }
1125 return std::pair<EngineContext*, int>(&empty_context, 0);
1126 }
1127
1128 TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
1129 infer->setGpuAllocator(allocator);
1130 // Need to initialize plugins in order to deserialize engines that contain
1131 // plugins.
1132 MaybeInitializeTrtPlugins(&logger);
1133 TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
1134 infer->deserializeCudaEngine(serialized_segment_.c_str(),
1135 serialized_segment_.size(), nullptr));
1136 int profile_id = 0;
1137 if (static_engine && !use_implicit_batch_) {
1138 // load profiles
1139 std::vector<ExecutionContext> exec_contexts;
1140 TF_RETURN_IF_ERROR(cache_res->profiles_.RestoreProfiles(
1141 static_engine.get(), ctx->num_inputs()));
1142 TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts(
1143 static_engine.get(), &exec_contexts));
1144 cache.emplace(input_concrete_shapes,
1145 std::make_unique<EngineContext>(std::move(static_engine),
1146 std::move(exec_contexts)));
1147 VLOG(1) << "Added new engine to cache of " << name()
1148 << ". Cache size: " << cache.size();
1149 // Query which profile of the new engine matches the actual input.
1150 profile_id = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
1151 if (profile_id == -1) {
1152 return std::pair<EngineContext*, int>(&empty_context, 0);
1153 }
1154 EngineContext* engine_context = cache_res->GetEngineContext(profile_id);
1155 return std::pair<EngineContext*, int>(engine_context, profile_id);
1156 }
1157
1158 if (!static_engine) {
1159 if (!allow_build_at_runtime_) {
1160 // Store an empty engine in the cache so we don't try to load the same
1161 // failing engine again.
1162 cache.emplace(input_concrete_shapes, std::make_unique<EngineContext>());
1163 return std::pair<EngineContext*, int>(&empty_context, 0);
1164 }
1165 if (segment_graph_def_.node().empty()) {
1166 Status status = ImportSegmentGraphDef(ctx->function_library(),
1167 ctx->device()->name());
1168 if (!status.ok()) {
1169 LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for "
1170 << name() << " failed. "
1171 << "Reason: " << status;
1172 }
1173 }
1174 auto result = BuildEngine(input_concrete_shapes, batch_size,
1175 /*use_calibration=*/false,
1176 /*calibrator=*/nullptr, cache_res, ctx);
1177 if (!result.ok()) {
1178 return std::pair<EngineContext*, int>(&empty_context, 0);
1179 }
1180 static_engine = std::move(result.ValueOrDie());
1181 }
1182
1183 auto raw_static_engine = static_engine.get();
1184 std::vector<TensorShape> engine_input_shapes(input_concrete_shapes);
1185
1186 int max_batch_size = 1;
1187 if (use_implicit_batch_) {
1188 max_batch_size = raw_static_engine->getMaxBatchSize();
1189 // Static engine will have max_batch_size for batch size so that all
1190 // inputs will map to this single engine.
1191 for (int i = 0; i < engine_input_shapes.size(); i++) {
1192 engine_input_shapes[i].set_dim(0, max_batch_size);
1193 }
1194 }
1195
1196 ExecutionContext context = ExecutionContext::Create(raw_static_engine);
1197 // TODO(laigd): here we assume engine_input_shapes matches the actual input
1198 // shapes of the engine, we should verify that.
1199 cache.emplace(engine_input_shapes,
1200 std::make_unique<EngineContext>(std::move(static_engine),
1201 std::move(context)));
1202 // Runtime is safe to delete after engine creation
1203 VLOG(1) << "Size of serialized TRT engine: "
1204 << serialized_segment_.capacity();
1205 string tmp;
1206 // Swap with temporary empty string to deallocate the CPU memory.
1207 serialized_segment_.swap(tmp);
1208 if (use_implicit_batch_ && (max_batch_size < batch_size)) {
1209 return std::pair<EngineContext*, int>(&empty_context, 0);
1210 }
1211 return std::pair<EngineContext*, int>(cache.at(engine_input_shapes).get(),
1212 0);
1213 } // static_engine_
1214
1215 int profile_id = -1;
1216 if (!use_implicit_batch_) {
1217 profile_id = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
1218 // Since all profiles are already created at this point, finding no
1219 // compatible profiles results in falling back to native TF.
1220 if (profile_id == -1) {
1221 return std::pair<EngineContext*, int>(&empty_context, 0);
1222 }
1223 }
1224
1225 EngineContext* engine_contexts;
1226 if (use_implicit_batch_) {
1227 engine_contexts = cache_res->GetEngineContext(input_concrete_shapes);
1228 } else {
1229 engine_contexts = cache_res->GetEngineContext(profile_id);
1230 }
1231
1232 // If cache does not have a compatible engine then create a new engine.
1233 if (engine_contexts == nullptr) {
1234 if (!allow_build_at_runtime_) {
1235 LOG_FIRST_FEW_WARNING_WITH_PREFIX
1236 << "Found no engine in cache matching input shapes. "
1237 << "Not building a new engine because "
1238 << "allow_build_at_runtime=False. "
1239 << "The native segment will be used instead.";
1240 // Store an empty engine in the cache for these input shapes so we don't
1241 // try to build the same failing engine again.
1242 cache.emplace(input_concrete_shapes, std::make_unique<EngineContext>());
1243 return std::pair<EngineContext*, int>(&empty_context, 0);
1244 }
1245
1246 // Up to this point, calibrator_ can never be empty, since otherwise it
1247 // means calibration_mode_ is true and this path won't get executed.
1248 auto result =
1249 BuildEngine(input_concrete_shapes, batch_size, use_calibration_,
1250 calibrator_.get(), cache_res, ctx);
1251 if (!result.ok()) {
1252 return std::pair<EngineContext*, int>(&empty_context, 0);
1253 }
1254 TrtUniquePtrType<nvinfer1::ICudaEngine> engine =
1255 std::move(result.ValueOrDie());
1256 std::vector<ExecutionContext> exec_contexts;
1257 TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts(
1258 engine.get(), &exec_contexts));
1259 cache.emplace(input_concrete_shapes,
1260 std::make_unique<EngineContext>(std::move(engine),
1261 std::move(exec_contexts)));
1262 VLOG(1) << "Added new engine to cache of " << name()
1263 << ". Cache size: " << cache.size();
1264 engine_contexts = cache.at(input_concrete_shapes).get();
1265 // Query which profile of the new engine matches the actual input.
1266 profile_id = cache_res->profiles_.GetProfileNumber(input_concrete_shapes);
1267 }
1268 return std::pair<EngineContext*, int>(engine_contexts,
1269 use_implicit_batch_ ? 0 : profile_id);
1270 }
1271
1272 // TODO(hinsu): Move this allocation to CalibrationContext constructor, if
1273 // possible.
1274 Status TRTEngineOp::AllocateCalibrationResources(
1275 OpKernelContext* ctx, TRTEngineCacheResource* cache_res) {
1276 cache_res->calib_ctx_ = std::make_unique<CalibrationContext>();
1277 auto* cres = cache_res->calib_ctx_.get();
1278
1279 // Get the input shapes.
1280 /// TODO(lsugy): support INT8 calibration in non-frozen mode.
1281 const int batch_size = ctx->input(0).dim_size(0);
1282 const int num_inputs = ctx->num_inputs();
1283 std::vector<TensorShape> shapes;
1284 cres->device_tensors_.resize(num_inputs);
1285 VLOG(1) << "Constructing calibrator";
1286 for (int i = 0; i < num_inputs; i++) {
1287 // allocate workspace on device for inputs
1288 const Tensor& t = ctx->input(i);
1289 shapes.emplace_back(t.shape());
1290 TF_RETURN_IF_ERROR(
1291 ctx->allocate_temp(t.dtype(), t.shape(), &cres->device_tensors_.at(i)));
1292 CHECK_EQ(t.TotalBytes(), // Crash OK
1293 (cres->device_tensors_.at(i)).TotalBytes());
1294 void* device_address = GetTensorAddress(&cres->device_tensors_.at(i));
1295 if (device_address == nullptr) {
1296 return errors::InvalidArgument(
1297 "Unsupported data type encountered in input ", i);
1298 }
1299 cres->device_buffers_.emplace(
1300 StrCat(IONamePrefixes::kInputPHName, i),
1301 std::pair<void*, size_t>(device_address,
1302 cres->device_tensors_.at(i).TotalBytes()));
1303 }
1304 cres->calibrator_.reset(
1305 new TRTInt8Calibrator(cres->device_buffers_, batch_size, name()));
1306 const int platform_device_id =
1307 ctx->device()->tensorflow_accelerator_device_info()->gpu_id;
1308 if (platform_device_id < 0) {
1309 LOG(ERROR) << "Can't get gpu_device_info from context->device()";
1310 return errors::InvalidArgument(
1311 "Context->device doesn't contain device info!");
1312 }
1313
1314 bool use_concrete_shapes =
1315 use_implicit_batch_ || cache_res->profiles_.IsStaticCompatible();
1316 const std::vector<PartialTensorShape>& conversion_input_shapes =
1317 use_concrete_shapes
1318 ? std::vector<PartialTensorShape>(shapes.begin(), shapes.end())
1319 : input_partial_shapes_;
1320
1321 cache_res->Ref();
1322 string platform_device_name = ctx->device()->name();
1323 cres->thr_.reset(new std::thread([this, cres, shapes, conversion_input_shapes,
1324 platform_device_id, platform_device_name,
1325 cache_res, ctx]() {
1326 core::ScopedUnref sc(cache_res);
1327
1328 VLOG(1) << "Starting calibration thread on device " << platform_device_id
1329 << ", Calibration Resource @ " << cres;
1330 auto err = cudaSetDevice(platform_device_id);
1331 if (err != cudaSuccess) {
1332 // TODO(aaroey): should return error here.
1333 LOG(ERROR) << "Couldn't set cuda device to " << platform_device_id
1334 << " in calibration thread";
1335 }
1336
1337 std::unordered_map<string, tensorflow::DeviceProperties> device_map;
1338 DeviceNameUtils::ParsedName full_parsed_name;
1339 DeviceNameUtils::ParseFullName(platform_device_name, &full_parsed_name);
1340 device_map.emplace(platform_device_name,
1341 grappler::GetDeviceInfo(full_parsed_name));
1342 tensorflow::grappler::VirtualCluster cluster(device_map);
1343
1344 // ConvertGraphDefToEngine() will try to build the engine. This thread
1345 // will loop inside buildCudaEngine(), consuming the calibration data
1346 // that is set by the TF op, and drive the builder until the calibrator
1347 // returns false. The engine is discarded after the calibration table is
1348 // generated.
1349 //
1350 // TODO(aaroey): maybe set the max batch size using the python
1351 // calibration wrapper class.
1352 auto s = convert::ConvertGraphDefToEngine(
1353 this->segment_graph_def_, ctx, TrtPrecisionMode::INT8,
1354 cres->calibrator_->getBatchSize(), this->workspace_size_,
1355 conversion_input_shapes, &cache_res->GetLogger(),
1356 cache_res->allocator_.get(), cres->calibrator_.get(), &cres->engine_,
1357 /*use_calibration=*/true, this->use_implicit_batch_,
1358 /*convert_successfully=*/nullptr,
1359 /*profiles=*/&cache_res->profiles_, name(),
1360 /*use_explicit_precision=*/use_explicit_precision_,
1361 /*cluster=*/&cluster);
1362 if (!s.ok()) {
1363 LOG(ERROR) << "Calibration failed: " << s;
1364 cres->calibrator_->setDone(); // Ignore further pushes
1365 cache_res->cache_.emplace(shapes, std::make_unique<EngineContext>());
1366 } else {
1367 // Transfer the ownership of the engine to the engine cache, so we can
1368 // dump it out during conversion for TF 2.0.
1369 mutex_lock lock(this->engine_mutex_);
1370 this->calibrator_ = std::move(cres->calibrator_);
1371 if (!use_implicit_batch_ &&
1372 (has_dynamic_shape_input_ || cache_res->profiles_.HasShapeTensor())) {
1373 std::vector<ExecutionContext> exec_contexts;
1374 auto calib_result = cache_res->profiles_.CreateExecutionContexts(
1375 cres->engine_.get(), &exec_contexts);
1376 cache_res->cache_.emplace(
1377 shapes, std::make_unique<EngineContext>(std::move(cres->engine_),
1378 std::move(exec_contexts)));
1379 } else {
1380 ExecutionContext context =
1381 ExecutionContext::Create(cres->engine_.get());
1382 cache_res->cache_.emplace(
1383 shapes, std::make_unique<EngineContext>(std::move(cres->engine_),
1384 std::move(context)));
1385 }
1386 }
1387
1388 VLOG(1) << "Calibration loop terminated " << this->name();
1389 }));
1390 VLOG(1) << "initialized calibrator resource";
1391 return Status::OK();
1392 }
1393
1394 REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
1395
1396 } // namespace tensorrt
1397 } // namespace tensorflow
1398
1399 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
1400