# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""

import gc
import os
import re
import tempfile
from unittest import mock

from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance  # pylint: disable=g-importing-member
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt.test import test_utils
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util.lazy_loader import LazyLoader

_SAVED_MODEL_SIGNATURE_KEY = "mypredict"

gen_trt_ops = LazyLoader(
    "gen_trt_ops", globals(),
    "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")


class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Class to test the TensorFlow-TensorRT integration Python API."""

  # Use a small max_workspace_size for tests so they don't consume too much
  # GPU memory.
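  # (Most V2 tests below additionally cap the workspace via the
  # max_workspace_size_bytes argument of _CreateConverterV2.)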
  _TRT_MAX_WORKSPACE_SIZE_BYTES = (
      trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES)

  def mkdtemp(self):
    return tempfile.mkdtemp(dir=self.get_temp_dir())

  def testTRTEngineInstanceAvailable(self):
    # Test that we can access the TRTEngineInstance protobuf.
    assert hasattr(TRTEngineInstance(), "serialized_engine")

  def _GetConfigProto(self, rewriter_config=None):
    """Get ConfigProto for session creation."""
    config = config_pb2.ConfigProto(
        gpu_options=config_pb2.GPUOptions(allow_growth=True))
    if rewriter_config:
      config.graph_options.rewrite_options.CopyFrom(rewriter_config)
    return config

  @classmethod
  def _GetGraph(cls, inp1, inp2, var):
    """Get the graph for testing."""
    # The graph computes: inp1^2 + inp1*var + inp1 + inp2 + var
    add = inp1 + var
    mul = inp1 * add
    add = mul + add
    add = add + inp2
    out = array_ops.identity(add, name="output")
    return out

  def _GetModelForV2(self):

    class SimpleModel(autotrackable.AutoTrackable):

      def __init__(self):
        self.v = None

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
          tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
      ])
      def run(self, inp1, inp2):
        if self.v is None:
          self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
        return TrtConvertTest._GetGraph(inp1, inp2, self.v)

    return SimpleModel()

  def _GetGraphForV1(self, device):

    def _GraphFn():
      inp1 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input1")
      inp2 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input2")
      var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
      out = TrtConvertTest._GetGraph(inp1, inp2, var)
      return g, var, inp1, inp2, out

    g = ops.Graph()
    with g.as_default():
      if device:
        with g.device(device):
          return _GraphFn()
      return _GraphFn()

  def _GetGraphDefForV1(self, device):
    """Get the graph def for testing."""
    g, var, _, _, _ = self._GetGraphForV1(device)
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      graph_def = graph_util.convert_variables_to_constants(
          sess, g.as_graph_def(add_shapes=True), ["output"])
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    self.assertEqual(
        {
            "v1": "Const",
            "add/ReadVariableOp": "Identity",
            "input1": "Placeholder",
            "input2": "Placeholder",
            "add": "AddV2",
            "mul": "Mul",
            "add_1": "AddV2",
            "add_2": "AddV2",
            "output": "Identity"
        }, node_name_to_op)
    return graph_def

  def _WriteInputSavedModelForV1(self, input_saved_model_dir, device):
    """Write the saved model as an input for testing."""
    g, var, inp1, inp2, out = self._GetGraphForV1(device)
    signature_def = signature_def_utils.build_signature_def(
        inputs={
            "myinput1": utils.build_tensor_info(inp1),
            "myinput2": utils.build_tensor_info(inp2)
        },
        outputs={"myoutput": utils.build_tensor_info(out)},
        method_name=signature_constants.PREDICT_METHOD_NAME)
    saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      saved_model_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map={_SAVED_MODEL_SIGNATURE_KEY: signature_def})
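    # save() serializes the SavedModel proto (with the meta graph registered
    # above) to input_saved_model_dir.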
    saved_model_builder.save()

  def _ConvertGraphV1(self,
                      output_saved_model_dir=None,
                      need_calibration=False,
                      max_batch_size=1,
                      minimum_segment_size=3,
                      is_dynamic_op=False,
                      maximum_cached_engines=1,
                      device=None):
    """Helper method to convert a GraphDef or SavedModel using TF-TRT."""
    input_saved_model_dir = None
    if output_saved_model_dir:
      input_saved_model_dir = self.mkdtemp()
      self._WriteInputSavedModelForV1(input_saved_model_dir, device)

    # Calibration requires dynamic_op.
    if need_calibration:
      is_dynamic_op = True

    # For dynamic ops the converter ignores max_batch_size and requires it to
    # be None.
    if is_dynamic_op:
      max_batch_size = None

    converter = trt_convert.TrtGraphConverter(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
        input_graph_def=None
        if input_saved_model_dir else self._GetGraphDefForV1(device),
        nodes_denylist=None if input_saved_model_dir else ["output"],
        max_batch_size=max_batch_size,
        max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
        precision_mode=(trt_convert.TrtPrecisionMode.INT8 if need_calibration
                        else trt_convert.TrtPrecisionMode.FP32),
        minimum_segment_size=minimum_segment_size,
        is_dynamic_op=is_dynamic_op,
        maximum_cached_engines=maximum_cached_engines)
    output_graph_def = converter.convert()

    if need_calibration:

      class CalibrationData(object):

        def __init__(self):
          self._data = 0

        def next(self):
          self._data += 1
          return {"input1:0": [[[self._data]]], "input2:0": [[[self._data]]]}

      output_graph_def = converter.calibrate(
          fetch_names=["output:0"],
          num_runs=10,
          feed_dict_fn=CalibrationData().next)

    if output_saved_model_dir is not None:
      converter.save(output_saved_model_dir=output_saved_model_dir)
    return output_graph_def

  # Remove the graph sequence number prefix from the name only if the name
  # has a prefix of the form TRTEngineOp_n_.
  def _MayRemoveGraphSequenceNumber(self, name):
    prefix = re.search(r"TRTEngineOp_\d{3,}_", name)
    if prefix and name.startswith(prefix.group(0)):
      parts = name.split("_", maxsplit=2)
      assert len(parts) == 3
      return parts[0] + "_" + parts[2]
    return name

  # Return the unique TRTEngineOp in the given graph def.
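  # Tests use the returned node to derive the engine's resource name and its
  # serialized-engine asset filename.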
  def _GetUniqueTRTEngineOp(self, graph_def):
    trt_engine_nodes = [
        node for node in graph_def.node if node.op == "TRTEngineOp"
    ]
    assert len(trt_engine_nodes) == 1
    return trt_engine_nodes[0]

  def _TestTrtGraphConverter(self,
                             device,
                             output_saved_model_dir=None,
                             need_calibration=False,
                             is_dynamic_op=False):
    """General method to test trt_convert.TrtGraphConverter()."""
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir,
        need_calibration=need_calibration,
        is_dynamic_op=is_dynamic_op,
        device=device)
    graph_defs_to_verify = [output_graph_def]

    if output_saved_model_dir:
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {
          self._MayRemoveGraphSequenceNumber(node.name): node.op
          for node in graph_def.node
      }
      if device is not None and device.startswith("/CPU:"):
        self.assertEqual(
            {
                "add": "AddV2",
                "v1": "Const",
                "add_1": "AddV2",
                "add_2": "AddV2",
                "input1": "Placeholder",
                "input2": "Placeholder",
                "mul": "Mul",
                "output": "Identity"
            }, node_name_to_op)
      else:
        self.assertEqual(
            {
                "input1": "Placeholder",
                "input2": "Placeholder",
                "TRTEngineOp_000": "TRTEngineOp",
                "output": "Identity"
            }, node_name_to_op)

      if need_calibration:
        trt_engine_nodes = [
            node for node in graph_def.node if node.op == "TRTEngineOp"
        ]
        if device is not None and device.startswith("/CPU:"):
          self.assertEmpty(trt_engine_nodes)
          return

        self.assertNotEmpty(trt_engine_nodes)
        for node in trt_engine_nodes:
          self.assertTrue(len(node.attr["calibration_data"].s))
        # Run the calibrated graph.
        # TODO(laigd): consider having some input where the answer is
        # different.
        with ops.Graph().as_default():
          importer.import_graph_def(graph_def, name="")
          with self.session(config=self._GetConfigProto()) as sess:
            for test_data in range(10):
              self.assertEqual((test_data + 1.0)**2 + test_data,
                               sess.run(
                                   "output:0",
                                   feed_dict={
                                       "input1:0": [[[test_data]]],
                                       "input2:0": [[[test_data]]]
                                   }))

  @parameterized.named_parameters([
      ("NoDeviceAssignment", None),
      ("GPU", "/GPU:0"),
      ("CPU", "/CPU:0"),
  ])
  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_OfflineConversion(self, device):
    """Test case for trt_convert.TrtGraphConverter()."""

    for need_calibration in [False, True]:
      # Use GraphDef as input.
      self._TestTrtGraphConverter(device)

      # Use SavedModel as input.
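      # (Calibration is only exercised on the SavedModel path; the GraphDef
      # call above always uses the default need_calibration=False.)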
      self._TestTrtGraphConverter(
          device,
          output_saved_model_dir=self.mkdtemp(),
          need_calibration=need_calibration)

  @parameterized.named_parameters([
      ("NoDeviceAssignment", None),
      ("GPU", "/device:GPU:0"),
      ("CPU", "/device:CPU:0"),
  ])
  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_OnlineConversion(self, device):
    """Test case for TF-TRT conversion using Grappler directly."""

    conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=trt_convert.TrtPrecisionMode.FP32)
    config = self._GetConfigProto(
        rewriter_config=trt_convert.get_tensorrt_rewriter_config(
            conversion_params,
            is_dynamic_op=False,
            max_batch_size=1,
            is_v2=False))

    with ops.Graph().as_default():
      # Online conversion requires a frozen graph, so we reuse inp1 as the var
      # argument.
      inp1 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input1")
      inp2 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input2")
      if device:
        with ops.device(device):
          TrtConvertTest._GetGraph(inp1, inp2, inp1)
      else:
        TrtConvertTest._GetGraph(inp1, inp2, inp1)
      with self.session(config=config) as sess:
        self._TestRun(sess, batch_size=1)

  def _CreateConverterV2(
      self,
      input_saved_model_dir,
      input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
      max_workspace_size_bytes=10 << 20,  # Use a smaller workspace.
      precision_mode=trt_convert.TrtPrecisionMode.FP32,
      maximum_cached_engines=2,
      allow_build_at_runtime=True):
    return trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key=input_saved_model_signature_key,
        max_workspace_size_bytes=max_workspace_size_bytes,
        precision_mode=precision_mode,
        maximum_cached_engines=maximum_cached_engines,
        allow_build_at_runtime=allow_build_at_runtime)

  def _CheckTrtOps(self, concrete_func, check_fn=None, num_engines=1):
    graph_def = concrete_func.graph.as_graph_def()
    trt_op_names = []
    for node in graph_def.node:
      if node.op == "TRTEngineOp":
        trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
        if check_fn:
          check_fn(node)
    for func in graph_def.library.function:
      for node in func.node_def:
        if node.op == "TRTEngineOp":
          trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
          if check_fn:
            check_fn(node)
    self.assertLen(trt_op_names, num_engines)

  def _RandomInput(self, shape, dtype=np.float32):
    inp1 = np.random.random_sample(shape).astype(dtype)
    inp2 = np.random.random_sample(shape).astype(dtype)
    return inp1, inp2

  @test_util.run_v2_only
  def testTrtGraphConverter_DynamicConversion_v2(self):
    """Test case for trt_convert.TrtGraphConverterV2()."""

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    expected_output = root.run(np_input1, np_input2)
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(input_saved_model_dir)
    converter.convert()

    # Verify the converted GraphDef and ConcreteFunction.
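    # (_CheckTrtOps asserts that exactly one TRTEngineOp is present; see its
    # num_engines=1 default.)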
    self._CheckTrtOps(converter._converted_func)  # pylint: disable=protected-access

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    # Save the converted model without any TRT engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    unexpected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertFalse(os.path.exists(unexpected_asset_file))

    # Run the converted function to populate the engine cache.
    def _InputFn():
      yield np_input1, np_input2

    converter.build(input_fn=_InputFn)

    # Save the converted model again with the serialized engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    expected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertTrue(os.path.exists(expected_asset_file))
    self.assertTrue(os.path.getsize(expected_asset_file))

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

    # Load and verify the converted model.
    #
    # TODO(laigd): the name of the new input_signature of the
    # `root_with_trt.run` function is empty string (originally was None),
    # investigate why.
    root_with_trt = load.load(output_saved_model_dir)
    # TODO(laigd): `root_with_trt.run` is still using the original graph
    # without trt. Consider changing that.
    # self._CheckTrtOps(root_with_trt.run.get_concrete_function())
    converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
    self._CheckTrtOps(converted_signature)
    output_with_trt = converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))
    # The output of running the converted signature is a dict due to
    # compatibility reasons with the V1 SavedModel signature mechanism.
    self.assertAllClose(
        expected_output,
        list(output_with_trt.values())[0],
        atol=1e-6,
        rtol=1e-6)

    del root_with_trt
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverter_ShapeOp_Int32InputOutput_v2(self):
    """Testing ShapeOp and int32 values as engine input and output."""

    class ShapeOpModel(autotrackable.AutoTrackable):

      def __init__(self):
        self.v = None

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, None], dtype=dtypes.float32)
      ])
      def run(self, x):
        q = x + 1
        q_shape = array_ops.shape(q)
        # Add an op that is not supported by TF-TRT. This allows TF-TRT to
        # build two engines: the first engine produces an int32 output, and
        # the second engine has an int32 input and an int32 output.
        q = math_ops.cumsum(q_shape)
        q = q * 2
        return array_ops.identity(q, name="output")

    np_input = np.random.random_sample([5, 3]).astype(np.float32)

    def _InputFunc():
      yield (np_input,)

    # Create the SavedModel.
    root = ShapeOpModel()
    expected_output = root.run(np_input)
    input_saved_model_dir = self.mkdtemp()
    save.save(root, input_saved_model_dir, signatures=root.run)

    # Convert the graph to TF-TRT.
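    # minimum_segment_size=2 lets each of the two small segments around the
    # unsupported cumsum op become its own engine.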
    conv_params = trt_convert.TrtConversionParams(minimum_segment_size=2)
    converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        use_dynamic_shape=True,
        **conv_params._asdict())
    converter.convert()

    # Build the graph with the input generator. This runs the TRTEngineOp
    # native segment.
    converter.build(_InputFunc)
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)

    root_with_trt = load.load(output_saved_model_dir)
    converted_signature = root_with_trt.signatures["serving_default"]
    # Check that the graph is converted to two TRTEngineOps.
    self._CheckTrtOps(converted_signature, num_engines=2)
    # Run the graph.
    output_with_trt = converted_signature(x=ops.convert_to_tensor(np_input))
    # Check the result of the run.
    self.assertAllClose(expected_output, list(output_with_trt.values())[0])

  @test_util.run_v2_only
  def testTrtGraphConverter_Int8Conversion_v2(self):

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    expected_output = root.run(np_input1, np_input2)
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir,
        precision_mode=trt_convert.TrtPrecisionMode.INT8,
        maximum_cached_engines=3)

    # Convert and perform INT8 calibration.
    def _CalibrationInputFn():
      yield np_input1, np_input2

    converter.convert(calibration_input_fn=_CalibrationInputFn)

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    def _CheckFn(node):
      self.assertTrue(len(node.attr["calibration_data"].s), node.name)

    # Verify the converted GraphDef.
    self._CheckTrtOps(converter._converted_func, _CheckFn)  # pylint: disable=protected-access

    # Build another engine with a different batch size.
    def _InputFn():
      yield self._RandomInput([5, 1, 1])

    converter.build(input_fn=_InputFn)

    # Save the converted model.
    # TODO(laigd): check that it should contain two engines.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    expected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertTrue(os.path.exists(expected_asset_file))
    self.assertTrue(os.path.getsize(expected_asset_file))

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

    # Load and verify the converted model.
    root_with_trt = load.load(output_saved_model_dir)
    converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
    self._CheckTrtOps(converted_signature, _CheckFn)
    output_with_trt = converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))
    self.assertEqual(1, len(output_with_trt))
    # The output of running the converted signature is a dict due to
    # compatibility reasons with the V1 SavedModel signature mechanism.
    self.assertAllClose(
        expected_output,
        list(output_with_trt.values())[0],
        atol=1e-6,
        rtol=1e-6)

    # Run with an input of different batch size. It should build a new engine
    # using the calibration table.
    # TODO(laigd): check that it should contain three engines.
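    # (Engines for batch sizes 4 and 5 exist already; batch size 6 triggers a
    # runtime build, which maximum_cached_engines=3 still accommodates.)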
    np_input1, np_input2 = self._RandomInput([6, 1, 1])
    converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))

    del root_with_trt
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverter_DestroyEngineCache(self):
    """Test case for trt_convert.TrtGraphConverterV2()."""

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(input_saved_model_dir)
    converter.convert()

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    def _InputFn():
      yield np_input1, np_input2

    converter.build(input_fn=_InputFn)  # Populate the TRT engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)

    def _DestroyCache():
      with ops.device("GPU:0"):
        handle = gen_trt_ops.create_trt_resource_handle(
            resource_name=trt_engine_name)
        gen_resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

    # Load the converted model and make sure the engine cache is populated by
    # default.
    root = load.load(output_saved_model_dir)
    _DestroyCache()
    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

    # Load the converted model again and make sure the engine cache is
    # destroyed when the model goes out of scope.
    root = load.load(output_saved_model_dir)
    del root
    gc.collect()  # Force GC to destroy the TRT engine cache.
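    # With the last reference gone, the resource should already be destroyed,
    # so destroying it again must fail.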
    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

  def _CompareSavedModel(self, model_class):
    signature_key = "serving_default"

    def _GetModelPaths(model_class):
      input_saved_model_dir = self.mkdtemp()
      root = model_class()
      save.save(root, input_saved_model_dir)

      converter = self._CreateConverterV2(
          input_saved_model_dir, input_saved_model_signature_key=signature_key)
      converter.convert()
      output_saved_model_dir = self.mkdtemp()
      converter.save(output_saved_model_dir)
      return input_saved_model_dir, output_saved_model_dir

    def _GetSignatureDef(export_dir):
      saved_model_proto = loader_impl.parse_saved_model(export_dir)
      self.assertEqual(1, len(saved_model_proto.meta_graphs))
      meta_graph = saved_model_proto.meta_graphs[0]
      self.assertIn(signature_key, meta_graph.signature_def)
      return meta_graph.signature_def[signature_key]

    def _CompareSignatureDef(original_def, converted_def, is_input):
      endpoints = original_def.inputs if is_input else original_def.outputs
      converted_endpoints = (
          converted_def.inputs if is_input else converted_def.outputs)
      self.assertEqual(set(endpoints.keys()), set(converted_endpoints.keys()))
      for key in endpoints:
        original_input = endpoints[key]
        converted_input = converted_endpoints[key]
        self.assertEqual(original_input.name, converted_input.name)
        self.assertEqual(original_input.dtype, converted_input.dtype)
        self.assertEqual(
            tensor_shape.TensorShape(original_input.tensor_shape).as_list(),
            tensor_shape.TensorShape(converted_input.tensor_shape).as_list())

    def _GetStructuredOutputs(export_dir):
      root = load.load(export_dir)
      return root.signatures[signature_key].structured_outputs

    saved_model_path, converted_saved_model_path = _GetModelPaths(model_class)
    original_def = _GetSignatureDef(saved_model_path)
    converted_def = _GetSignatureDef(converted_saved_model_path)
    self.assertEqual(original_def.method_name, converted_def.method_name)
    _CompareSignatureDef(original_def, converted_def, True)
    _CompareSignatureDef(original_def, converted_def, False)

    self.assertEqual(
        _GetStructuredOutputs(saved_model_path),
        _GetStructuredOutputs(converted_saved_model_path))

  @test_util.run_v2_only
  def testRetainSignatureInfo_NoInputs(self):

    class _Model(autotrackable.AutoTrackable):

      @def_function.function(input_signature=[])
      def run(self):
        return array_ops.constant(1.0)

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_OneInput(self):

    class _Model(autotrackable.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32)
      ])
      def run(self, inp):
        return inp + inp * inp

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_TwoInputs(self):

    class _Model(autotrackable.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32),
          tensor_spec.TensorSpec(shape=[None, 2], dtype=dtypes.float32)
      ])
      def run(self, inp1, inp2):
        return inp1 + inp2 * inp2

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_OneOutputSignatureKey(self):

    class _Model(autotrackable.AutoTrackable):

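      # A single dict-valued output exercises signature-key retention.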
      @def_function.function(input_signature=[])
      def run(self):
        return {"my_output": array_ops.constant(1.0)}

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_TwoOutputSignatureKeys(self):

    class _Model(autotrackable.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32)
      ])
      def run(self, inp):
        # Here the keys are not ordered lexicographically on purpose.
        return {
            "output_b": array_ops.constant(1.0),
            "output_a": inp + inp * inp
        }

    self._CompareSavedModel(_Model)

  def _TestRun(self, sess, batch_size):
    result = sess.run(
        "output:0",
        feed_dict={
            "input1:0": [[[1.0]]] * batch_size,
            "input2:0": [[[1.0]]] * batch_size
        })
    self.assertAllEqual([[[5.0]]] * batch_size, result)

  @parameterized.named_parameters([
      ("LargeSegmentSize", 7),
      ("NoMainGraphConversionSegmentSize", -1),
  ])
  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_MinimumSegmentSize(self, minimum_segment_size):
    output_graph_def = self._ConvertGraphV1(
        minimum_segment_size=minimum_segment_size)
    node_name_to_op = {node.name: node.op for node in output_graph_def.node}
    self.assertEqual(
        {
            "v1": "Const",
            "input1": "Placeholder",
            "input2": "Placeholder",
            "add": "AddV2",
            "mul": "Mul",
            "add_1": "AddV2",
            "add_2": "AddV2",
            "output": "Identity"
        }, node_name_to_op)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_DynamicOp(self):

    output_saved_model_dir = self.mkdtemp()
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir,
        is_dynamic_op=True,
        maximum_cached_engines=2)

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1; a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2; a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3. Since the number of cached engines has
        # reached the max, it should evict an old engine and create a new one.
        self._TestRun(sess, 3)

    # Test the output SavedModel.
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1; a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2; a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3. Since the number of cached engines has
        # reached the max, it should evict an old engine and create a new one.
        self._TestRun(sess, 3)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_StaticOp(self):

    output_saved_model_dir = self.mkdtemp()
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir,
        maximum_cached_engines=1)

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1; the default engine embedded in the GraphDef
        # will be used.
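        # (In static mode the engine is built at conversion time for the
        # given max_batch_size, so no separate build step is needed.)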
        self._TestRun(sess, 1)
        # Run with batch size 2, which exceeds the max_batch_size; it should
        # fall back to the TF function.
        self._TestRun(sess, 2)

    # Test the output SavedModel.
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1; the default engine embedded in the GraphDef
        # will be used.
        self._TestRun(sess, 1)
        # Run with batch size 2, which exceeds the max_batch_size; it should
        # fall back to the TF function.
        self._TestRun(sess, 2)

  @test_util.run_v2_only
  def testTrtGraphConverter_AllowEngineNativeSegmentExecution(self):
    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    def _InputFn():
      yield np_input1, np_input2

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir, max_workspace_size_bytes=1 << 20)
    converter.convert()

    os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
    os.environ["TF_TRT_ABORT_CUDA_ENGINE_BUILD"] = "True"
    with self.assertRaisesRegex(
        errors.AbortedError,
        r"User disallowed engine native segment execution"):
      try:
        converter.build(input_fn=_InputFn)
      finally:
        # Always reset the environment variables.
        os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
        os.environ["TF_TRT_ABORT_CUDA_ENGINE_BUILD"] = "False"

    converter.build(input_fn=_InputFn)

  @parameterized.parameters((True, True), (True, False), (False, True),
                            (False, False))
  @test_util.run_v2_only
  def testTrtGraphConverter_AllowBuildAtRuntime(self, build_offline,
                                                allow_build_at_runtime):
    if not is_tensorrt_enabled():
      return

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))

    def _InputFn():
      yield np_input1, np_input2

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir, allow_build_at_runtime=allow_build_at_runtime)
    converter.convert()
    if build_offline:
      converter.build(input_fn=_InputFn)
    # Save the converted model.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)

    saved_model_loaded = load.load(
        output_saved_model_dir, tags=[tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[_SAVED_MODEL_SIGNATURE_KEY]

    # Check that the TRTEngineOp(s) have the correct attribute(s).
    def _CheckFn(node):
      self.assertEqual(node.attr["_allow_build_at_runtime"].b,
                       allow_build_at_runtime)

    self._CheckTrtOps(graph_func, _CheckFn)
    # If the engine was not built offline and the user disallowed both
    # building at runtime and running the native segment, executing the
    # function should report an error.
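    # In every other combination the engine either already exists or can be
    # built on the fly, so execution succeeds.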
    if not build_offline and not allow_build_at_runtime:
      with self.assertRaisesRegex(
          errors.AbortedError,
          r"User disallowed engine native segment execution"):
        try:
          os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
          graph_func(inp1=np_input1, inp2=np_input2)
        finally:
          os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
    else:
      output = graph_func(inp1=np_input1, inp2=np_input2)["output_0"]
      self.assertEqual(output.shape, (4, 1, 1))
      self.assertAllClose(
          np.asarray([5.0, 5.0, 5.0, 5.0]).reshape([4, 1, 1]), output)

  @test_util.run_v2_only
  def testBackwardCompatibility(self):
    """Load and execute a model that was saved in TF2.0."""

    model_dir = test.test_src_dir_path(
        "python/compiler/tensorrt/test/testdata/tftrt_2.0_saved_model")
    saved_model_loaded = load.load(model_dir, tags=[tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    output = graph_func(input1=np_input1, input2=np_input2)["output_0"]

    self.assertEqual(output.shape, (4, 1, 1))
    self.assertAllClose(
        np.asarray([5.0, 5.0, 5.0, 5.0]).reshape([4, 1, 1]), output)

  @parameterized.named_parameters([
      ("SaveGPUSpecificEngine", True),
      ("WithoutSaveGPUSpecificEngine", False),
  ])
  @test_util.run_v2_only
  def testTrtGraphConverter_SaveGPUSpecificEngine(self, save_engine_flag):
    """Test case for trt_convert.TrtGraphConverterV2()."""

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.INT8)

    # Convert with calibration, which populates the engine cache.
    def CalibrationFn():
      yield np_input1, np_input2

    converter.convert(calibration_input_fn=CalibrationFn)

    # Verify the converted GraphDef and ConcreteFunction.
    self._CheckTrtOps(converter._converted_func)

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    # Save the converted model with or without the TRT engine cache, depending
    # on the value of save_engine_flag.
    output_saved_model_dir = self.mkdtemp()

    converter.save(
        output_saved_model_dir, save_gpu_specific_engines=save_engine_flag)

    expected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)

    self.assertTrue(os.path.exists(expected_asset_file))
    if save_engine_flag:
      # The engine is saved, so we expect engine data.
      self.assertTrue(os.path.getsize(expected_asset_file))
    else:
      # The engine is not saved, so the file should be empty.
      self.assertFalse(os.path.getsize(expected_asset_file))

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverterV2_SaveWithOptions(self):
    """Test to make sure that the save method respects the options kwarg."""

    # Create a model and save it.
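    # (The model only serves as conversion input here; the assertion below is
    # on the mocked save call, not on the saved artifacts.)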
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(input_saved_model_dir)
    converter.convert()

    # Patch the save function with a mock.
    with mock.patch.object(trt_convert, "save") as mock_save:
      mock_save.save = mock.MagicMock()
      # Save the converted model with options.
      output_saved_model_dir = self.mkdtemp()
      options = save_options.SaveOptions(save_debug_info=True)
      converter.save(output_saved_model_dir, options=options)

      # Assert that the saved_model.save function was called with the given
      # save_options by the TrtGraphConverterV2.save method.
      mock_save.save.assert_called_once_with(
          mock.ANY, mock.ANY, mock.ANY, options=options)

  @parameterized.named_parameters([
      ("NoDeviceAssignment", None),
      ("GPU1", "GPU:1"),
  ])
  @test_util.run_v2_only
  def testTrtGraphConverter_DevicePlacement(self, device_id):
    """Test case for trt_convert.TrtGraphConverterV2()."""

    gpus = config.list_physical_devices("GPU")
    if len(gpus) < 2:
      self.skipTest("Expected at least 2 GPUs but found {} GPUs".format(
          len(gpus)))

    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    converter = self._CreateConverterV2(
        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)

    converted_model = None
    # Specify the device on which the converted model should be placed.
    with ops.device(device_id):
      converted_model = converter.convert()

    # Verify that the TRT engine op has the correct device.
    self._CheckTrtOps(converter._converted_func)

    actual_device_id = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).device

    expected_device_id = None
    if device_id is not None:
      expected_device_id = device_id
    else:
      expected_device_id = "GPU:0"

    self.assertIn(expected_device_id.lower(), actual_device_id.lower())

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverter_DevicePlacementOnCPU(self):
    """Test case for trt_convert.TrtGraphConverterV2()."""

    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)

    converted_model = None
    # Conversion should fail when the model is placed on a CPU device.
    with self.assertRaisesRegex(ValueError, r"Specified device is not a GPU"):
      with ops.device("CPU"):
        converted_model = converter.convert()

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.
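
  # The variable tests below check that VariableV2 / ReadVariableOp nodes end
  # up inside the TRT native segment when graph freezing is disabled (an
  # experimental feature gated by experimental_feature_scope).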
  def _TestVariableHelper(self, variable_op, tf_model_name, tftrt_model_name,
                          output_name):
    """Helper with the common code of the variable converter tests."""

    model_dir = test.test_src_dir_path(
        "python/compiler/tensorrt/test/testdata/" + tf_model_name)
    trt_model_dir = os.path.join(self.mkdtemp(), tftrt_model_name)

    # Load and convert the TF model.
    conv_params = trt_convert.TrtConversionParams(
        precision_mode="FP16",
        minimum_segment_size=3,
        max_workspace_size_bytes=10 << 20,
        maximum_cached_engines=1)
    with test_utils.experimental_feature_scope("disable_graph_freezing"):
      converter = trt_convert.TrtGraphConverterV2(
          input_saved_model_dir=model_dir,
          conversion_params=conv_params,
          use_dynamic_shape=True,
          dynamic_shape_profile_strategy="Optimal")
      converter.convert()

    # Build and save the converted model.
    input_shapes = [[(4, 1, 1), (4, 1, 1)]]

    def _InputFn():
      for shapes in input_shapes:
        # Return a list of input tensors.
        yield [np.ones(shape=shape).astype(np.float32) for shape in shapes]

    converter.build(_InputFn)
    converter.save(trt_model_dir)

    # Load the converted model.
    saved_model_loaded = load.load(trt_model_dir, tags=[tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    # Check that there is one segment and that the two variables are in it.
    graph_def = graph_func.graph.as_graph_def()
    engines = []
    for lib_function in graph_def.library.function:
      if re.search(r"TRTEngineOp_\d+_\d+_native_segment",
                   lib_function.signature.name):
        node_ops = [node.op for node in lib_function.node_def]
        engines.append(node_ops)
    self.assertLen(engines, 1)
    self.assertEqual(engines[0].count(variable_op), 2)

    # Run the function and check the output.
    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(2. *
                                      np.ones([4, 1, 1]).astype(np.float32))
    output = graph_func(input1=np_input1, input2=np_input2)[output_name]
    self.assertEqual(output.shape, (4, 1, 1))
    self.assertAllClose(
        np.asarray([42., 42., 42., 42.]).reshape([4, 1, 1]), output)

  @test_util.run_v2_only
  def testVariableV2(self):
    """Test conversion of VariableV2 nodes."""

    self._TestVariableHelper("VariableV2", "tf_variablev2_saved_model",
                             "tftrt_variablev2_saved_model", "output")

  @test_util.run_v2_only
  def testReadVariableOp(self):
    """Test conversion of ReadVariableOp nodes."""

    self._TestVariableHelper("ReadVariableOp", "tf_readvariableop_saved_model",
                             "tftrt_readvariableop_saved_model", "output_0")


if __name__ == "__main__" and is_tensorrt_enabled():
  test.main()