/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/trt_convert_api.h"

#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/cc/tools/freeze_saved_model.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {

namespace tensorrt {
namespace {
// Creates and provisions a new cluster. The caller must call Shutdown before
// the cluster is destroyed.
Status NewCluster(grappler::Cluster** cluster) {
  int num_cpu_cores = grappler::GetNumAvailableLogicalCPUCores();
  int num_gpus = grappler::GetNumAvailableGPUs();
  int timeout_s = 60 * 10;
  *cluster = new grappler::SingleMachine(timeout_s, num_cpu_cores, num_gpus);
  (*cluster)->DisableDetailedStats(true);
  (*cluster)->AllowSoftPlacement(true);
  (*cluster)->SetNumWarmupSteps(10);
  TF_RETURN_IF_ERROR((*cluster)->Provision());
  return Status::OK();
}

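// Runs the Grappler meta optimizer over meta_graph_def with the given feed
// (input) and fetch (output) nodes on the provided cluster, and writes the
// optimized graph to out_graph_def.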
Status RunGrappler(const MetaGraphDef& meta_graph_def,
                   const std::vector<std::string>& input_names,
                   const std::vector<std::string>& output_names,
                   const ConfigProto& config_proto, grappler::Cluster* cluster,
                   GraphDef* out_graph_def) {
  grappler::ItemConfig item_config;

  for (const string& name : input_names) {
    item_config.feed_nodes.insert(name);
  }
  for (const string& name : output_names) {
    item_config.fetch_nodes.insert(name);
  }

  std::unique_ptr<grappler::GrapplerItem> item =
      grappler::GrapplerItemFromMetaGraphDef("tf_graph", meta_graph_def,
                                             item_config);
  if (!item) {
    return tensorflow::errors::Internal(
        "Failed to create grappler item from MetaGraphDef.");
  }

  tensorflow::DeviceBase* cpu_device = nullptr;
  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
      std::move(*item), config_proto, cpu_device, cluster, out_graph_def));
  VLOG(2) << "Grappler finished\n";
  return Status::OK();
}

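// Imports graph_def under the given name prefix and extends the session's
// graph with the imported nodes.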
Status ImportGraphDefToSession(Session* session, const GraphDef& graph_def,
                               const string& prefix) {
  ImportGraphDefOptions opts;
  opts.prefix = prefix;
  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef(opts, graph_def, &graph, nullptr));
  GraphDef new_graph_def;
  graph.ToGraphDef(&new_graph_def);
  TF_RETURN_IF_ERROR(session->Extend(new_graph_def));
  return Status::OK();
}

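// Fills opt_config with a rewriter configuration that runs the
// TensorRTOptimizer pass (preceded by function inlining, constant folding and
// layout optimization), using the given conversion parameters.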
Status GetTrtRewriterConfig(const TfTrtConversionParams& params,
                            const GraphDef& frozen_graph_def,
                            RewriterConfig* opt_config) {
  opt_config->set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE);
  opt_config->set_min_graph_nodes(-1);  // do not skip small graphs

  // Turn off remapping.
  opt_config->set_remapping(RewriterConfig_Toggle::RewriterConfig_Toggle_OFF);

  // If the graph has QDQ nodes, then we need to disable folding of the
  // QDQ with constants. Otherwise, the conversion will not work correctly.
  // Ideally, we do this after segmentation and outlining of TRT regions to
  // functions, but we currently lack that capability. Disabling QDQ-const
  // folding doesn't matter if the graph has no QDQ nodes, so we always set
  // this flag.
  opt_config->set_experimental_disable_folding_quantization_emulation(
      IS_TRT_VERSION_GE(8, 0, 0, 0));

  // Initial transformations before TensorRTOptimizer is called
  opt_config->add_optimizers("function");
  opt_config->add_optimizers("constfold");
  opt_config->add_optimizers("layout");
  opt_config->add_optimizers("constfold");

  // Parameters for TensorRTOptimizer
  auto trt_optimizer = opt_config->add_custom_optimizers();
  trt_optimizer->set_name("TensorRTOptimizer");

  auto trt_parameter_map = trt_optimizer->mutable_parameter_map();
  (*trt_parameter_map)["is_dynamic_op"].set_b(true);
  (*trt_parameter_map)["minimum_segment_size"].set_i(
      params.minimum_segment_size);
  string prec_string;
  TF_RETURN_IF_ERROR(
      TrtPrecisionModeToName(params.precision_mode, &prec_string));
  (*trt_parameter_map)["precision_mode"].set_s(prec_string);
  (*trt_parameter_map)["max_batch_size"].set_i(1);
  (*trt_parameter_map)["max_workspace_size_bytes"].set_i(
      params.max_workspace_size_bytes);
  (*trt_parameter_map)["max_cached_engines"].set_i(params.max_cached_engines);
  (*trt_parameter_map)["use_calibration"].set_b(params.use_calibration);
  (*trt_parameter_map)["profile_strategy"].set_s(
      ProfileStrategyToName(params.profile_strategy));
  (*trt_parameter_map)["use_implicit_batch"].set_b(!params.use_dynamic_shape);
  (*trt_parameter_map)["_allow_build_at_runtime"].set_b(
      params.allow_build_at_runtime);
  return Status::OK();
}

// Runs the TensorRTOptimizer grappler pass.
Status RunTfTrt(const MetaGraphDef& meta_graph_def,
                const std::vector<std::string>& input_names,
                const std::vector<std::string>& output_names,
                const RewriterConfig& rewriter_config,
                GraphDef* segmented_graph_def) {
  ConfigProto config_proto;
  config_proto.mutable_graph_options()->mutable_rewrite_options()->CopyFrom(
      rewriter_config);

  VLOG(4) << "Setting up Grappler parameters\n" << config_proto.DebugString();
  std::unique_ptr<grappler::Cluster> cluster;
  grappler::Cluster* p_cluster;
  mutex mu_cluster;  // There can be only one provisioned cluster per process.
  mutex_lock lock(mu_cluster);
  TF_RETURN_IF_ERROR(NewCluster(&p_cluster));
  cluster.reset(p_cluster);
  TF_RETURN_IF_ERROR(RunGrappler(meta_graph_def, input_names, output_names,
                                 config_proto, cluster.get(),
                                 segmented_graph_def));
  TF_RETURN_IF_ERROR(cluster->Shutdown());
  return Status::OK();
}

// Sets the _profile_generation_mode attribute of all TRTEngineOp nodes in the
// graph to mode.
Status SetProfileGenerationMode(GraphDef* graph_def, bool mode) {
  VLOG(3) << "Setting _profile_generation_mode=" << mode;
  std::string op{"TRTEngineOp"};
  for (auto& node : *(graph_def->mutable_node())) {
    if (!op.compare(node.op())) {
      auto* attr = node.mutable_attr();
      AttrValue profile_generation_mode;
      profile_generation_mode.set_b(mode);
      (*attr)["_profile_generation_mode"] = profile_generation_mode;
    }
  }
  return Status::OK();
}

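// Runs a single inference step on the session with the given input tensors.
// If prefix is non-empty, it is prepended to the input and output node names
// (used when the graph was imported into the session under a prefix).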
Status RunSession(Session* session, const std::vector<std::string>& input_names,
                  const std::vector<std::string>& output_names,
                  const std::vector<Tensor>& input_tensors,
                  string prefix = "") {
  TRT_ENSURE(!input_names.empty());
  TRT_ENSURE(!output_names.empty());
  TRT_ENSURE(!input_tensors.empty());

  std::vector<std::pair<std::string, tensorflow::Tensor>> input_pairs;
  std::vector<std::string> prefixed_output_names;
  auto prefixed_name = [](std::string prefix, std::string name) {
    return prefix.size() > 0 ? absl::StrJoin({prefix, name}, "/") : name;
  };
  for (int i = 0; i < input_names.size(); i++) {
    input_pairs.push_back(
        {prefixed_name(prefix, input_names.at(i)), input_tensors.at(i)});
  }
  for (int i = 0; i < output_names.size(); i++) {
    prefixed_output_names.push_back(prefixed_name(prefix, output_names.at(i)));
  }
  std::vector<tensorflow::Tensor> output_tensors;
  for (int i = 0; i < output_names.size(); i++) {
    output_tensors.push_back({});
  }
  VLOG(3) << "TF-TRT Build mode: running inference\n";
  TF_RETURN_IF_ERROR(
      session->Run(input_pairs, prefixed_output_names, {}, &output_tensors));
  return Status::OK();
}

// Runs the model to create the engines. In dynamic shape mode, before creating
// the engines, we provide shapes to define optimization profiles.
Status Build(GraphDef& segmented_graph_def,
             const std::vector<std::string>& input_names,
             const std::vector<std::string>& output_names,
             const std::vector<std::vector<tensorflow::Tensor>>& inputs,
             Session* session, const TfTrtConversionParams params) {
  VLOG(2) << "Building the model";
  bool need_collect_profiles = params.use_dynamic_shape && inputs.size() > 1;
  if (need_collect_profiles) {
    TF_RETURN_IF_ERROR(SetProfileGenerationMode(&segmented_graph_def, true));
  }
  TF_RETURN_IF_ERROR(session->Create(segmented_graph_def));
  string prefix = "";
  if (need_collect_profiles) {
    for (auto const& input : inputs) {
      TF_RETURN_IF_ERROR(RunSession(session, input_names, output_names, input));
    }
    prefix = "TrtBuildStep";
    TF_RETURN_IF_ERROR(SetProfileGenerationMode(&segmented_graph_def, false));
    VLOG(3) << "Importing graph with _profile_generation_mode disabled";
    TF_RETURN_IF_ERROR(
        ImportGraphDefToSession(session, segmented_graph_def, prefix));
  }
  TF_RETURN_IF_ERROR(
      RunSession(session, input_names, output_names, *inputs.begin(), prefix));
  return Status::OK();
}

// Returns the resource manager associated with the node.
Status GetResourceManager(const NodeDef& node, Session* session,
                          ResourceMgr** rm) {
  const DeviceMgr* device_mgr;
  TF_RETURN_IF_ERROR(session->LocalDeviceManager(&device_mgr));
  Device* device;
  string device_name = node.device().empty()
                           ? "/job:localhost/replica:0/task:0/device:GPU:0"
                           : node.device();
  TF_RETURN_IF_ERROR(device_mgr->LookupDevice(device_name, &device));
  *rm = device->resource_manager();
  return Status::OK();
}

// Looks up the cache resource associated with the TRT node.
Status GetEngineCacheResource(const NodeDef& node, Session* session,
                              TRTEngineCacheResource** resource) {
  ResourceMgr* rm;
  TF_RETURN_IF_ERROR(GetResourceManager(node, session, &rm));

  absl::string_view resource_name = node.name();
  size_t last_slash = resource_name.find_last_of('/');
  if (last_slash != absl::string_view::npos) {
    resource_name.remove_prefix(last_slash + 1);
  }
  const std::string container(kTfTrtContainerName);
  *resource = nullptr;
  TF_RETURN_IF_ERROR(
      rm->Lookup(container, std::string(resource_name), resource));
  if (*resource == nullptr || (*resource)->cache_.size() == 0) {
    return errors::Internal("Engine cache not found for ", resource_name);
  }
  return Status::OK();
}

// Looks up the engine from the engine cache, and serializes the engine.
Status ReadSerializedEngine(
    const NodeDef& node, Session* session,
    TrtUniquePtrType<nvinfer1::IHostMemory>* engine_data) {
  TRTEngineCacheResource* resource;
  TF_RETURN_IF_ERROR(GetEngineCacheResource(node, session, &resource));
  core::ScopedUnref unref_cache_res(resource);
  if (resource->cache_.size() > 1) {
    return errors::Internal(
        "Multiple engines found, but we can only serialize one");
  }
  const std::unique_ptr<EngineContext>& engine =
      resource->cache_.begin()->second;
  if (!engine) {
    return errors::Internal("Engine not found for ", node.name());
  }

  if (engine->GetCudaEngine()) {
    // Serialize the engine.
    engine_data->reset(engine->GetCudaEngine()->serialize());
  } else {
    LOG(WARNING) << "Engine cache contains nullptr";
  }

  return Status::OK();
}

// Saves the TRT engines as attributes of the TRTEngineOp nodes.
Status ConvertToStaticEngine(const GraphDef& graph_def,
                             GraphDef* static_graph_def, Session* session) {
  static_graph_def->CopyFrom(graph_def);
  VLOG(1) << "Saving TRT engines as static engine";
  std::string op{"TRTEngineOp"};
  for (auto& node : *(static_graph_def->mutable_node())) {
    if (!op.compare(node.op())) {
      VLOG(2) << "Saving TRT engine for " << node.name()
              << ", device: " << node.device();
      TrtUniquePtrType<nvinfer1::IHostMemory> engine_data;
      TF_RETURN_IF_ERROR(ReadSerializedEngine(node, session, &engine_data));
      auto* attr = node.mutable_attr();
      AttrValue static_engine;
      static_engine.set_b(true);
      AttrValue engine_string;
      if (engine_data) {
        engine_string.set_s(engine_data->data(), engine_data->size());
      }
      (*attr)["static_engine"] = static_engine;
      (*attr)["serialized_segment"] = engine_string;
    }
  }
  return Status::OK();
}

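// Checks the conversion parameters for combinations that are unsupported or
// that would prevent the engines from being built or saved.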
Status ValidateConversionParams(const TfTrtConversionParams& p, int n_inputs) {
  if (p.precision_mode == TrtPrecisionMode::INT8 && p.use_calibration) {
    return errors::InvalidArgument(
        "Calibration not yet implemented through the C++ interface. Please use "
        "our Python API for calibration.");
  }
  if (p.convert_to_static_engine && n_inputs == 0) {
    return errors::InvalidArgument(
        "TRT Engine needs to be built before we can convert it to static "
        "engine. Please provide input data to build the model.");
  }
  if (!p.convert_to_static_engine && n_inputs > 0) {
    // After the conversion, the session that was used to build the engines
    // will be destroyed. If we do not convert the engine to static engine,
    // then we lose the engines.
    //
    // TODO(tfeher): Provide a way to save dynamic engines and remove this
    // warning.
    LOG(WARNING)
        << "Skipping build mode because we cannot save the "
           "engines. Use the convert_to_static_engine=true conversion "
           "parameter to enable build mode and save the engines in the graph.";
  }
  if (!p.allow_build_at_runtime && n_inputs == 0) {
    LOG(WARNING)
        << "TRT will not be used since allow_build_at_runtime is disabled and "
           "no inputs are provided to build during conversion.";
  }
  return Status::OK();
}

// Returns the configuration used during the build step session run.
tensorflow::SessionOptions GetSessionConfig() {
  // We also need to disable constant folding because we already ran constant
  // folding and may have prevented quantization operation folding on purpose.
  tensorflow::SessionOptions opts;
  auto* rewriter_opts =
      opts.config.mutable_graph_options()->mutable_rewrite_options();
  rewriter_opts->set_experimental_disable_folding_quantization_emulation(true);

  // It seems that we need to disable the optimizer entirely to prevent the
  // folding.
  rewriter_opts->set_disable_meta_optimizer(true);
  return opts;
}

}  // namespace

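// Converts a frozen graph to a TF-TRT graph: runs the TensorRTOptimizer
// grappler pass to replace eligible subgraphs with TRTEngineOp nodes and, if
// input data is provided and convert_to_static_engine is set, builds the
// engines and embeds them in the returned graph as static engines.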
StatusOr<GraphDef> ConvertAndBuild(
    const GraphDef& frozen_graph_def, const std::vector<string>& input_names,
    const std::vector<string>& output_names,
    const std::vector<std::vector<tensorflow::Tensor>>& inputs,
    const TfTrtConversionParams& conv_params) {
  TF_RETURN_IF_ERROR(ValidateConversionParams(conv_params, inputs.size()));
  MetaGraphDef meta_graph;
  meta_graph.mutable_graph_def()->CopyFrom(frozen_graph_def);

  RewriterConfig rewriter_config;
  TF_RETURN_IF_ERROR(
      GetTrtRewriterConfig(conv_params, frozen_graph_def, &rewriter_config));

  GraphDef segmented_graph_def;
  TF_RETURN_IF_ERROR(RunTfTrt(meta_graph, input_names, output_names,
                              rewriter_config, &segmented_graph_def));

  GraphDef output;

  if (inputs.size() > 0 && conv_params.convert_to_static_engine) {
    // The TRTOptimization pass has inserted placeholder TRTEngineOps. Here we
    // trigger the engine build by running inference on the graph.
    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(GetSessionConfig()));
    if (!session.get()) {
      return errors::Internal("Failed to create build session");
    }

    TF_RETURN_IF_ERROR(Build(segmented_graph_def, input_names, output_names,
                             inputs, session.get(), conv_params));

    TF_RETURN_IF_ERROR(
        ConvertToStaticEngine(segmented_graph_def, &output, session.get()));
  } else {
    output.CopyFrom(segmented_graph_def);
  }
  VLOG(1) << "TF-TRT conversion finished";
  return output;
}

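// Inlines the functions of meta_graph_def by running the "function" grappler
// optimizer, and writes the resulting graph to out_graph_def.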
Status InlineFunctions(const MetaGraphDef& meta_graph_def,
                       GraphDef* out_graph_def) {
  ConfigProto config_proto;
  auto opt_config =
      config_proto.mutable_graph_options()->mutable_rewrite_options();

  opt_config->set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE);
  opt_config->set_min_graph_nodes(-1);  // do not skip small graphs
  opt_config->add_optimizers("function");

  TF_RETURN_IF_ERROR(RunGrappler(meta_graph_def, {}, {}, config_proto, nullptr,
                                 out_graph_def));

  VLOG(2) << "Graph is inlined";
  return Status::OK();
}

// Freezes the graph. It is assumed that the functions are inlined and the
// variables are initialized.
Status FreezeGraph(SavedModelBundle& bundle, MetaGraphDef* frozen_meta_graph) {
  std::unordered_set<std::string> inputs;
  std::unordered_set<std::string> outputs;
  GraphDef frozen_graph_def;
  TF_RETURN_IF_ERROR(
      FreezeSavedModel(bundle, &frozen_graph_def, &inputs, &outputs));

  frozen_meta_graph->CopyFrom(bundle.meta_graph_def);
  GraphDef* gdef = frozen_meta_graph->mutable_graph_def();
  gdef->CopyFrom(frozen_graph_def);

  VLOG(2) << "Graph frozen";
  return Status::OK();
}

// Returns the names of the nodes listed in the signature definition.
std::vector<std::string> GetNodeNames(
    const google::protobuf::Map<std::string, tensorflow::TensorInfo>& signature) {
  std::vector<std::string> names;
  for (auto const& item : signature) {
    absl::string_view name = item.second.name();
    // Remove tensor suffix like ":0".
    size_t last_colon = name.find_last_of(':');
    if (last_colon != absl::string_view::npos) {
      name.remove_suffix(name.size() - last_colon);
    }
    names.push_back(std::string(name));
  }
  return names;
}

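// Converts the SavedModel identified by signature_key to TF-TRT: inlines
// functions, freezes the variables, and then converts and (optionally) builds
// the frozen graph with the ConvertAndBuild overload above.
//
// Example usage (illustrative sketch only, not part of this file; the export
// directory, tag and signature key are placeholders, and it assumes the
// SavedModel loader from tensorflow/cc/saved_model/loader.h):
//
//   SavedModelBundle bundle;
//   TF_CHECK_OK(LoadSavedModel(SessionOptions(), RunOptions(),
//                              "/path/to/saved_model", {"serve"}, &bundle));
//   TfTrtConversionParams params;             // default conversion parameters
//   std::vector<std::vector<Tensor>> inputs;  // one vector of Tensors per run
//   StatusOr<GraphDef> converted =
//       ConvertAndBuild(&bundle, "serving_default", inputs, params);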
StatusOr<GraphDef> ConvertAndBuild(
    SavedModelBundle* bundle, const std::string& signature_key,
    const std::vector<std::vector<tensorflow::Tensor>>& inputs,
    const TfTrtConversionParams& conversion_params) {
  // Inline the functions.
  GraphDef inlined_graph_def;
  TF_RETURN_IF_ERROR(
      InlineFunctions(bundle->meta_graph_def, &inlined_graph_def));

  // Replace the graph_def with the inlined graph. Note that bundle->session
  // still has the original graph.
  bundle->meta_graph_def.mutable_graph_def()->CopyFrom(inlined_graph_def);

  // Freeze variables.
  MetaGraphDef frozen_meta_graph;
  TF_RETURN_IF_ERROR(FreezeGraph(*bundle, &frozen_meta_graph));

  // Convert.
  auto signature_map = bundle->GetSignatures();
  const tensorflow::SignatureDef& signature = signature_map[signature_key];
  std::vector<std::string> input_names = GetNodeNames(signature.inputs());
  std::vector<std::string> output_names = GetNodeNames(signature.outputs());
  return ConvertAndBuild(frozen_meta_graph.graph_def(), input_names,
                         output_names, inputs, conversion_params);
}

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT