/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/trt_convert_api.h"

#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/cc/tools/freeze_saved_model.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {

namespace tensorrt {
namespace {

// Creates and provisions a new cluster. The caller must call Shutdown before
// the cluster is destroyed.
Status NewCluster(grappler::Cluster** cluster) {
  int num_cpu_cores = grappler::GetNumAvailableLogicalCPUCores();
  int num_gpus = grappler::GetNumAvailableGPUs();
  int timeout_s = 60 * 10;
  *cluster = new grappler::SingleMachine(timeout_s, num_cpu_cores, num_gpus);
  (*cluster)->DisableDetailedStats(true);
  (*cluster)->AllowSoftPlacement(true);
  (*cluster)->SetNumWarmupSteps(10);
  TF_RETURN_IF_ERROR((*cluster)->Provision());
  return Status::OK();
}
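
// A minimal usage sketch (hypothetical caller code; RunTfTrt below follows
// the same pattern). The raw pointer is typically handed to a unique_ptr,
// and Shutdown() must run before the cluster is destroyed:
//
//   grappler::Cluster* p_cluster = nullptr;
//   TF_RETURN_IF_ERROR(NewCluster(&p_cluster));
//   std::unique_ptr<grappler::Cluster> cluster(p_cluster);
//   // ... run grappler passes against cluster.get() ...
//   TF_RETURN_IF_ERROR(cluster->Shutdown());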

Status RunGrappler(const MetaGraphDef& meta_graph_def,
                   const std::vector<std::string>& input_names,
                   const std::vector<std::string>& output_names,
                   const ConfigProto& config_proto, grappler::Cluster* cluster,
                   GraphDef* out_graph_def) {
  grappler::ItemConfig item_config;

  for (const string& name : input_names) {
    item_config.feed_nodes.insert(name);
  }
  for (const string& name : output_names) {
    item_config.fetch_nodes.insert(name);
  }

  std::unique_ptr<grappler::GrapplerItem> item =
      grappler::GrapplerItemFromMetaGraphDef("tf_graph", meta_graph_def,
                                             item_config);
  if (!item) {
    return tensorflow::errors::Internal(
        "Failed to create grappler item from MetaGraphDef.");
  }

  tensorflow::DeviceBase* cpu_device = nullptr;
  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
      std::move(*item), config_proto, cpu_device, cluster, out_graph_def));
  VLOG(2) << "Grappler finished\n";
  return Status::OK();
}

Status ImportGraphDefToSession(Session* session, const GraphDef& graph_def,
                               const string& prefix) {
  ImportGraphDefOptions opts;
  opts.prefix = prefix;
  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef(opts, graph_def, &graph, nullptr));
  GraphDef new_graph_def;
  graph.ToGraphDef(&new_graph_def);
  TF_RETURN_IF_ERROR(session->Extend(new_graph_def));
  return Status::OK();
}
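
// Illustrative note on the prefix (example names are hypothetical): importing
// with prefix "TrtBuildStep" remaps a node "x" to "TrtBuildStep/x", so a
// second copy of the graph can coexist with the one already created in the
// session:
//
//   // Node "input" becomes "TrtBuildStep/input" after this call.
//   TF_RETURN_IF_ERROR(
//       ImportGraphDefToSession(session, segmented_graph_def, "TrtBuildStep"));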

Status GetTrtRewriterConfig(const TfTrtConversionParams& params,
                            const GraphDef& frozen_graph_def,
                            RewriterConfig* opt_config) {
  opt_config->set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE);
  opt_config->set_min_graph_nodes(-1);  // do not skip small graphs

  // Turn off remapping.
  opt_config->set_remapping(RewriterConfig_Toggle::RewriterConfig_Toggle_OFF);

  // If the graph has QDQ nodes, then we need to disable folding of the
  // QDQ with constants. Otherwise, the conversion will not work correctly.
  // Ideally, we do this after segmentation and outlining of TRT regions to
  // functions, but we currently lack that capability. Disabling QDQ-const
  // folding doesn't matter if you don't have QDQ nodes, so we always enable
  // this.
  opt_config->set_experimental_disable_folding_quantization_emulation(
      IS_TRT_VERSION_GE(8, 0, 0, 0));

  // Initial transformations before TensorRTOptimizer is called.
  opt_config->add_optimizers("function");
  opt_config->add_optimizers("constfold");
  opt_config->add_optimizers("layout");
  opt_config->add_optimizers("constfold");

  // Parameters for TensorRTOptimizer.
  auto trt_optimizer = opt_config->add_custom_optimizers();
  trt_optimizer->set_name("TensorRTOptimizer");

  auto trt_parameter_map = trt_optimizer->mutable_parameter_map();
  (*trt_parameter_map)["is_dynamic_op"].set_b(true);
  (*trt_parameter_map)["minimum_segment_size"].set_i(
      params.minimum_segment_size);
  string prec_string;
  TF_RETURN_IF_ERROR(
      TrtPrecisionModeToName(params.precision_mode, &prec_string));
  (*trt_parameter_map)["precision_mode"].set_s(prec_string);
  (*trt_parameter_map)["max_batch_size"].set_i(1);
  (*trt_parameter_map)["max_workspace_size_bytes"].set_i(
      params.max_workspace_size_bytes);
  (*trt_parameter_map)["max_cached_engines"].set_i(params.max_cached_engines);
  (*trt_parameter_map)["use_calibration"].set_b(params.use_calibration);
  (*trt_parameter_map)["profile_strategy"].set_s(
      ProfileStrategyToName(params.profile_strategy));
  (*trt_parameter_map)["use_implicit_batch"].set_b(!params.use_dynamic_shape);
  (*trt_parameter_map)["_allow_build_at_runtime"].set_b(
      params.allow_build_at_runtime);
  return Status::OK();
}
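
// For reference, the custom optimizer entry assembled above corresponds
// roughly to this RewriterConfig textproto (a sketch; the values shown are
// hypothetical and depend on `params`):
//
//   custom_optimizers {
//     name: "TensorRTOptimizer"
//     parameter_map { key: "is_dynamic_op" value { b: true } }
//     parameter_map { key: "minimum_segment_size" value { i: 3 } }
//     parameter_map { key: "precision_mode" value { s: "FP16" } }
//     parameter_map { key: "max_cached_engines" value { i: 1 } }
//   }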

// Runs the TensorRTOptimizer grappler pass.
Status RunTfTrt(const MetaGraphDef& meta_graph_def,
                const std::vector<std::string>& input_names,
                const std::vector<std::string>& output_names,
                const RewriterConfig& rewriter_config,
                GraphDef* segmented_graph_def) {
  ConfigProto config_proto;
  config_proto.mutable_graph_options()->mutable_rewrite_options()->CopyFrom(
      rewriter_config);

  VLOG(4) << "Setting up Grappler parameters\n" << config_proto.DebugString();
  std::unique_ptr<grappler::Cluster> cluster;
  grappler::Cluster* p_cluster;
  mutex mu_cluster;  // There can be only one provisioned cluster per process.
  mutex_lock lock(mu_cluster);
  TF_RETURN_IF_ERROR(NewCluster(&p_cluster));
  cluster.reset(p_cluster);
  TF_RETURN_IF_ERROR(RunGrappler(meta_graph_def, input_names, output_names,
                                 config_proto, cluster.get(),
                                 segmented_graph_def));
  TF_RETURN_IF_ERROR(cluster->Shutdown());
  return Status::OK();
}

// Sets the _profile_generation_mode attribute of all TRTEngineOp nodes in the
// graph to `mode`.
Status SetProfileGenerationMode(GraphDef* graph_def, bool mode) {
  VLOG(3) << "Setting _profile_generation_mode=" << mode;
  std::string op{"TRTEngineOp"};
  for (auto& node : *(graph_def->mutable_node())) {
    if (!op.compare(node.op())) {
      auto* attr = node.mutable_attr();
      AttrValue profile_generation_mode;
      profile_generation_mode.set_b(mode);
      (*attr)["_profile_generation_mode"] = profile_generation_mode;
    }
  }
  return Status::OK();
}
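
// Sketch of the resulting NodeDef attribute on each TRTEngineOp after
// SetProfileGenerationMode(&graph_def, true):
//
//   attr { key: "_profile_generation_mode" value { b: true } }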

Status RunSession(Session* session, const std::vector<std::string>& input_names,
                  const std::vector<std::string>& output_names,
                  const std::vector<Tensor>& input_tensors,
                  string prefix = "") {
  TRT_ENSURE(!input_names.empty());
  TRT_ENSURE(!output_names.empty());
  TRT_ENSURE(!input_tensors.empty());

  std::vector<std::pair<std::string, tensorflow::Tensor>> input_pairs;
  std::vector<std::string> prefixed_output_names;
  auto prefixed_name = [](std::string prefix, std::string name) {
    return prefix.size() > 0 ? absl::StrJoin({prefix, name}, "/") : name;
  };
  for (int i = 0; i < input_names.size(); i++) {
    input_pairs.push_back(
        {prefixed_name(prefix, input_names.at(i)), input_tensors.at(i)});
  }
  for (int i = 0; i < output_names.size(); i++) {
    prefixed_output_names.push_back(prefixed_name(prefix, output_names.at(i)));
  }
  std::vector<tensorflow::Tensor> output_tensors;
  for (int i = 0; i < output_names.size(); i++) {
    output_tensors.push_back({});
  }
  VLOG(3) << "TF-TRT Build mode: running inference\n";
  TF_RETURN_IF_ERROR(
      session->Run(input_pairs, prefixed_output_names, {}, &output_tensors));
  return Status::OK();
}
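
// Example of the prefixing above (illustrative values): with
// prefix = "TrtBuildStep", prefixed_name("TrtBuildStep", "input") returns
// "TrtBuildStep/input", matching the node names created by
// ImportGraphDefToSession; an empty prefix leaves names unchanged.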

// Runs the model to create the engines. In dynamic shape mode, before creating
// the engines, we provide shapes to define optimization profiles.
Status Build(GraphDef& segmented_graph_def,
             const std::vector<std::string>& input_names,
             const std::vector<std::string>& output_names,
             const std::vector<std::vector<tensorflow::Tensor>>& inputs,
             Session* session, const TfTrtConversionParams params) {
  VLOG(2) << "Building the model";
  bool need_collect_profiles = params.use_dynamic_shape && inputs.size() > 1;
  if (need_collect_profiles) {
    TF_RETURN_IF_ERROR(SetProfileGenerationMode(&segmented_graph_def, true));
  }
  TF_RETURN_IF_ERROR(session->Create(segmented_graph_def));
  string prefix = "";
  if (need_collect_profiles) {
    for (auto const& input : inputs) {
      TF_RETURN_IF_ERROR(RunSession(session, input_names, output_names, input));
    }
    prefix = "TrtBuildStep";
    TF_RETURN_IF_ERROR(SetProfileGenerationMode(&segmented_graph_def, false));
    VLOG(3) << "Importing graph with _profile_generation_mode disabled";
    TF_RETURN_IF_ERROR(
        ImportGraphDefToSession(session, segmented_graph_def, prefix));
  }
  TF_RETURN_IF_ERROR(
      RunSession(session, input_names, output_names, *inputs.begin(), prefix));
  return Status::OK();
}
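
// The dynamic-shape build above is thus a two-phase process (sketch):
//
//   // Phase 1: with _profile_generation_mode=true, run once per input shape
//   // so the TRTEngineOps record the shape ranges for optimization profiles.
//   // Phase 2: re-import the graph under the "TrtBuildStep" prefix with
//   // _profile_generation_mode=false and run once more; this run builds the
//   // engines using the collected profiles.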

// Returns the resource manager associated with the node.
Status GetResourceManager(const NodeDef& node, Session* session,
                          ResourceMgr** rm) {
  const DeviceMgr* device_mgr;
  TF_RETURN_IF_ERROR(session->LocalDeviceManager(&device_mgr));
  Device* device;
  string device_name = node.device().empty()
                           ? "/job:localhost/replica:0/task:0/device:GPU:0"
                           : node.device();
  TF_RETURN_IF_ERROR(device_mgr->LookupDevice(device_name, &device));
  *rm = device->resource_manager();
  return Status::OK();
}

// Looks up the cache resource associated with the TRT node.
Status GetEngineCacheResource(const NodeDef& node, Session* session,
                              TRTEngineCacheResource** resource) {
  ResourceMgr* rm;
  TF_RETURN_IF_ERROR(GetResourceManager(node, session, &rm));

  absl::string_view resource_name = node.name();
  size_t last_slash = resource_name.find_last_of('/');
  if (last_slash != absl::string_view::npos) {
    resource_name.remove_prefix(last_slash + 1);
  }
  const std::string container(kTfTrtContainerName);
  *resource = nullptr;
  TF_RETURN_IF_ERROR(
      rm->Lookup(container, std::string(resource_name), resource));
  if (*resource == nullptr || (*resource)->cache_.size() == 0) {
    return errors::Internal("Engine cache not found for ", resource_name);
  }
  return Status::OK();
}
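
// Example of the resource-name derivation above (hypothetical node name): a
// TRTEngineOp named "model/segment0/TRTEngineOp_000" is looked up in the
// kTfTrtContainerName container under the key "TRTEngineOp_000"; everything
// up to and including the last '/' is dropped.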

// Looks up the engine from the engine cache, and serializes the engine.
Status ReadSerializedEngine(
    const NodeDef& node, Session* session,
    TrtUniquePtrType<nvinfer1::IHostMemory>* engine_data) {
  TRTEngineCacheResource* resource;
  TF_RETURN_IF_ERROR(GetEngineCacheResource(node, session, &resource));
  core::ScopedUnref unref_cache_res(resource);
  if (resource->cache_.size() > 1) {
    return errors::Internal(
        "Multiple engines found, but we can only serialize one");
  }
  const std::unique_ptr<EngineContext>& engine =
      resource->cache_.begin()->second;
  if (!engine) {
    return errors::Internal("Engine not found for ", node.name());
  }

  if (engine->GetCudaEngine()) {
    // Serialize the engine.
    engine_data->reset(engine->GetCudaEngine()->serialize());
  } else {
    LOG(WARNING) << "Engine cache contains nullptr";
  }

  return Status::OK();
}

// Saves the TRT engines as attributes of the TRTEngineOp nodes.
Status ConvertToStaticEngine(const GraphDef& graph_def,
                             GraphDef* static_graph_def, Session* session) {
  static_graph_def->CopyFrom(graph_def);
  VLOG(1) << "Saving TRT engines as static engine";
  std::string op{"TRTEngineOp"};
  for (auto& node : *(static_graph_def->mutable_node())) {
    if (!op.compare(node.op())) {
      VLOG(2) << "Saving TRT engine for " << node.name()
              << ", device: " << node.device();
      TrtUniquePtrType<nvinfer1::IHostMemory> engine_data;
      TF_RETURN_IF_ERROR(ReadSerializedEngine(node, session, &engine_data));
      auto* attr = node.mutable_attr();
      AttrValue static_engine;
      static_engine.set_b(true);
      AttrValue engine_string;
      if (engine_data) {
        engine_string.set_s(engine_data->data(), engine_data->size());
      }
      (*attr)["static_engine"] = static_engine;
      (*attr)["serialized_segment"] = engine_string;
    }
  }
  return Status::OK();
}
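
// Sketch of the attributes written to each TRTEngineOp above (the serialized
// engine bytes are elided):
//
//   attr { key: "static_engine" value { b: true } }
//   attr { key: "serialized_segment" value { s: "<engine bytes>" } }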

Status ValidateConversionParams(const TfTrtConversionParams& p, int n_inputs) {
  if (p.precision_mode == TrtPrecisionMode::INT8 && p.use_calibration) {
    return errors::InvalidArgument(
        "Calibration not yet implemented through the C++ interface. Please use "
        "our Python API for calibration.");
  }
  if (p.convert_to_static_engine && n_inputs == 0) {
    return errors::InvalidArgument(
        "TRT Engine needs to be built before we can convert it to a static "
        "engine. Please provide input data to build the model.");
  }
  if (!p.convert_to_static_engine && n_inputs > 0) {
    // After the conversion, the session that was used to build the engines
    // will be destroyed. If we do not convert the engine to a static engine,
    // then we lose the engines.
    //
    // TODO(tfeher): Provide a way to save dynamic engines and remove this
    // warning.
    LOG(WARNING)
        << "Skipping build mode because we cannot save the "
           "engines. Use the convert_to_static_engine=true conversion "
           "parameter to enable build mode and save the engines in the graph.";
  }
  if (!p.allow_build_at_runtime && n_inputs == 0) {
    LOG(WARNING)
        << "TRT will not be used since allow_build_at_runtime is disabled and "
           "no inputs are provided to build during conversion.";
  }
  return Status::OK();
}

// Returns the configuration used during the build-step session run.
tensorflow::SessionOptions GetSessionConfig() {
  // We also need to disable constant folding because we already ran constant
  // folding and may have prevented quantization operation folding on purpose.
  tensorflow::SessionOptions opts;
  auto* rewriter_opts =
      opts.config.mutable_graph_options()->mutable_rewrite_options();
  rewriter_opts->set_experimental_disable_folding_quantization_emulation(true);

  // It seems that we need to disable the optimizer entirely to prevent the
  // folding.
  rewriter_opts->set_disable_meta_optimizer(true);
  return opts;
}

}  // namespace

StatusOr<GraphDef> ConvertAndBuild(
    const GraphDef& frozen_graph_def, const std::vector<string>& input_names,
    const std::vector<string>& output_names,
    const std::vector<std::vector<tensorflow::Tensor>>& inputs,
    const TfTrtConversionParams& conv_params) {
  TF_RETURN_IF_ERROR(ValidateConversionParams(conv_params, inputs.size()));
  MetaGraphDef meta_graph;
  meta_graph.mutable_graph_def()->CopyFrom(frozen_graph_def);

  RewriterConfig rewriter_config;
  TF_RETURN_IF_ERROR(
      GetTrtRewriterConfig(conv_params, frozen_graph_def, &rewriter_config));

  GraphDef segmented_graph_def;
  TF_RETURN_IF_ERROR(RunTfTrt(meta_graph, input_names, output_names,
                              rewriter_config, &segmented_graph_def));

  GraphDef output;

  if (inputs.size() > 0 && conv_params.convert_to_static_engine) {
    // The TRTOptimization pass has inserted placeholder TRTEngineOps. Here we
    // trigger the conversion by running inference on the graph.
    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(GetSessionConfig()));
    if (!session.get()) {
      return errors::Internal("Failed to create build session");
    }

    TF_RETURN_IF_ERROR(Build(segmented_graph_def, input_names, output_names,
                             inputs, session.get(), conv_params));

    TF_RETURN_IF_ERROR(
        ConvertToStaticEngine(segmented_graph_def, &output, session.get()));
  } else {
    output.CopyFrom(segmented_graph_def);
  }
  VLOG(1) << "TF-TRT conversion finished";
  return output;
}
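
// Minimal usage sketch for this overload (names, tensors, and parameter
// values are hypothetical):
//
//   TfTrtConversionParams params;
//   params.precision_mode = TrtPrecisionMode::FP16;
//   params.convert_to_static_engine = true;
//   std::vector<std::vector<tensorflow::Tensor>> inputs = {{input_tensor}};
//   StatusOr<GraphDef> converted = ConvertAndBuild(
//       frozen_graph_def, {"input"}, {"output"}, inputs, params);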

Status InlineFunctions(const MetaGraphDef& meta_graph_def,
                       GraphDef* out_graph_def) {
  ConfigProto config_proto;
  auto opt_config =
      config_proto.mutable_graph_options()->mutable_rewrite_options();

  opt_config->set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE);
  opt_config->set_min_graph_nodes(-1);  // do not skip small graphs
  opt_config->add_optimizers("function");

  TF_RETURN_IF_ERROR(RunGrappler(meta_graph_def, {}, {}, config_proto, nullptr,
                                 out_graph_def));

  VLOG(2) << "Graph is inlined";
  return Status::OK();
}

// Freezes the graph. It is assumed that the functions are inlined and the
// variables are initialized.
Status FreezeGraph(SavedModelBundle& bundle, MetaGraphDef* frozen_meta_graph) {
  std::unordered_set<std::string> inputs;
  std::unordered_set<std::string> outputs;
  GraphDef frozen_graph_def;
  TF_RETURN_IF_ERROR(
      FreezeSavedModel(bundle, &frozen_graph_def, &inputs, &outputs));

  frozen_meta_graph->CopyFrom(bundle.meta_graph_def);
  GraphDef* gdef = frozen_meta_graph->mutable_graph_def();
  gdef->CopyFrom(frozen_graph_def);

  VLOG(2) << "Graph frozen";
  return Status::OK();
}

// Returns the names of the nodes listed in the signature definition.
std::vector<std::string> GetNodeNames(
    const google::protobuf::Map<std::string, tensorflow::TensorInfo>& signature) {
  std::vector<std::string> names;
  for (auto const& item : signature) {
    absl::string_view name = item.second.name();
    // Remove the tensor suffix, e.g. ":0".
    size_t last_colon = name.find_last_of(':');
    if (last_colon != absl::string_view::npos) {
      name.remove_suffix(name.size() - last_colon);
    }
    names.push_back(std::string(name));
  }
  return names;
}
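
// Example (hypothetical signature entry): a TensorInfo name
// "serving_default_input:0" yields the node name "serving_default_input"; a
// name without a ':' suffix is returned unchanged.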

StatusOr<GraphDef> ConvertAndBuild(
    SavedModelBundle* bundle, const std::string& signature_key,
    const std::vector<std::vector<tensorflow::Tensor>>& inputs,
    const TfTrtConversionParams& conversion_params) {
  // Inline the functions.
  GraphDef inlined_graph_def;
  TF_RETURN_IF_ERROR(
      InlineFunctions(bundle->meta_graph_def, &inlined_graph_def));

  // Replace the graph_def with the inlined graph. Note that bundle->session
  // still has the original graph.
  bundle->meta_graph_def.mutable_graph_def()->CopyFrom(inlined_graph_def);

  // Freeze the variables.
  MetaGraphDef frozen_meta_graph;
  TF_RETURN_IF_ERROR(FreezeGraph(*bundle, &frozen_meta_graph));

  // Convert.
  auto signature_map = bundle->GetSignatures();
  const tensorflow::SignatureDef& signature = signature_map[signature_key];
  std::vector<std::string> input_names = GetNodeNames(signature.inputs());
  std::vector<std::string> output_names = GetNodeNames(signature.outputs());
  return ConvertAndBuild(frozen_meta_graph.graph_def(), input_names,
                         output_names, inputs, conversion_params);
}
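
// Minimal usage sketch for the SavedModel overload (the path, tag, and
// tensors are hypothetical):
//
//   SavedModelBundle bundle;
//   TF_RETURN_IF_ERROR(LoadSavedModel(SessionOptions(), RunOptions(),
//                                     "/path/to/saved_model", {"serve"},
//                                     &bundle));
//   TfTrtConversionParams params;
//   StatusOr<GraphDef> converted = ConvertAndBuild(
//       &bundle, "serving_default", {{input_tensor}}, params);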

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT