# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""

import gc
import os
import re
import tempfile
from unittest import mock

from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance  # pylint: disable=g-importing-member
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt.test import test_utils
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util.lazy_loader import LazyLoader

_SAVED_MODEL_SIGNATURE_KEY = "mypredict"

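# gen_trt_ops is loaded lazily so that merely importing this module does not
# require the TensorRT op bindings to be available in the build.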
gen_trt_ops = LazyLoader(
    "gen_trt_ops", globals(),
    "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")


class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
69  """Class to test Tensorflow-TensorRT integration python API."""
70
71  # Use a small max_workspace_size for tests so they don't consume too much GPU
72  # memory.
73  _TRT_MAX_WORKSPACE_SIZE_BYTES = (
74      trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES)

  def mkdtemp(self):
    return tempfile.mkdtemp(dir=self.get_temp_dir())

  def testTRTEngineInstanceAvailable(self):
    # Test that the TRTEngineInstance protobuf is accessible.
    assert hasattr(TRTEngineInstance(), "serialized_engine")

  def _GetConfigProto(self, rewriter_config=None):
    """Get ConfigProto for session creation."""
    config = config_pb2.ConfigProto(
        gpu_options=config_pb2.GPUOptions(allow_growth=True))
    if rewriter_config:
      config.graph_options.rewrite_options.CopyFrom(rewriter_config)
    return config

  @classmethod
  def _GetGraph(cls, inp1, inp2, var):
    """Get the graph for testing."""
    # The graph computes: inp1^2 + inp1*var + inp1 + inp2 + var
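    # (with inp1 == inp2 == 1 and var == 1 this evaluates to 5.0, the value
    # asserted by _TestRun; with inp1 == inp2 == x and var == 1 it equals
    # (x + 1)**2 + x, the value asserted by the calibration check.)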
    add = inp1 + var
    mul = inp1 * add
    add = mul + add
    add = add + inp2
    out = array_ops.identity(add, name="output")
    return out

  def _GetModelForV2(self):

    class SimpleModel(autotrackable.AutoTrackable):

      def __init__(self):
        self.v = None

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
          tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
      ])
      def run(self, inp1, inp2):
        if self.v is None:
          self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
        return TrtConvertTest._GetGraph(inp1, inp2, self.v)

    return SimpleModel()

  def _GetGraphForV1(self, device):

    def _GraphFn():
      inp1 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input1")
      inp2 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input2")
      var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
      out = TrtConvertTest._GetGraph(inp1, inp2, var)
      return g, var, inp1, inp2, out

    g = ops.Graph()
    with g.as_default():
      if device:
        with g.device(device):
          return _GraphFn()
      return _GraphFn()

  def _GetGraphDefForV1(self, device):
    """Get the graph def for testing."""
    g, var, _, _, _ = self._GetGraphForV1(device)
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      graph_def = graph_util.convert_variables_to_constants(
          sess, g.as_graph_def(add_shapes=True), ["output"])
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    self.assertEqual(
        {
            "v1": "Const",
            "add/ReadVariableOp": "Identity",
            "input1": "Placeholder",
            "input2": "Placeholder",
            "add": "AddV2",
            "mul": "Mul",
            "add_1": "AddV2",
            "add_2": "AddV2",
            "output": "Identity"
        }, node_name_to_op)
    return graph_def

  def _WriteInputSavedModelForV1(self, input_saved_model_dir, device):
    """Write the saved model as an input for testing."""
    g, var, inp1, inp2, out = self._GetGraphForV1(device)
    signature_def = signature_def_utils.build_signature_def(
        inputs={
            "myinput1": utils.build_tensor_info(inp1),
            "myinput2": utils.build_tensor_info(inp2)
        },
        outputs={"myoutput": utils.build_tensor_info(out)},
        method_name=signature_constants.PREDICT_METHOD_NAME)
    saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      saved_model_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map={_SAVED_MODEL_SIGNATURE_KEY: signature_def})
    saved_model_builder.save()

  def _ConvertGraphV1(self,
                      output_saved_model_dir=None,
                      need_calibration=False,
                      max_batch_size=1,
                      minimum_segment_size=3,
                      is_dynamic_op=False,
                      maximum_cached_engines=1,
                      device=None):
    """Helper method to convert a GraphDef or SavedModel using TF-TRT."""
    input_saved_model_dir = None
    if output_saved_model_dir:
      input_saved_model_dir = self.mkdtemp()
      self._WriteInputSavedModelForV1(input_saved_model_dir, device)

    # Calibration requires dynamic_op.
    if need_calibration:
      is_dynamic_op = True

    # For dynamic_op, max_batch_size is unused and must be left as None.
    if is_dynamic_op:
      max_batch_size = None

    converter = trt_convert.TrtGraphConverter(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
        input_graph_def=None
        if input_saved_model_dir else self._GetGraphDefForV1(device),
        nodes_denylist=None if input_saved_model_dir else ["output"],
        max_batch_size=max_batch_size,
        max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
        precision_mode=(trt_convert.TrtPrecisionMode.INT8 if need_calibration
                        else trt_convert.TrtPrecisionMode.FP32),
        minimum_segment_size=minimum_segment_size,
        is_dynamic_op=is_dynamic_op,
        maximum_cached_engines=maximum_cached_engines)
    output_graph_def = converter.convert()

    if need_calibration:

      class CalibrationData(object):

        def __init__(self):
          self._data = 0

        def next(self):
          self._data += 1
          return {"input1:0": [[[self._data]]], "input2:0": [[[self._data]]]}

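      # calibrate() runs the graph num_runs times on the feeds produced by
      # feed_dict_fn to collect INT8 quantization ranges; the resulting
      # calibration table is stored in the "calibration_data" attribute of the
      # engine nodes.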
      output_graph_def = converter.calibrate(
          fetch_names=["output:0"],
          num_runs=10,
          feed_dict_fn=CalibrationData().next)

    if output_saved_model_dir is not None:
      converter.save(output_saved_model_dir=output_saved_model_dir)
    return output_graph_def

  # Remove the graph sequence number from the name only if the name has a
  # prefix of the form "TRTEngineOp_<nnn>_", where the sequence number has at
  # least three digits.
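  # For example, "TRTEngineOp_000_abc" becomes "TRTEngineOp_abc", while a name
  # like "TRTEngineOp_000" (no trailing underscore) is returned unchanged.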
  def _MayRemoveGraphSequenceNumber(self, name):
    prefix = re.search(r"TRTEngineOp_\d{3,}_", name)
    if prefix and name.startswith(prefix.group(0)):
      parts = name.split("_", maxsplit=2)
      assert len(parts) == 3
      return parts[0] + "_" + parts[2]
    return name

  # Return the unique TRTEngineOp in the given graph def.
  def _GetUniqueTRTEngineOp(self, graph_def):
    trt_engine_nodes = [
        node for node in graph_def.node if node.op == "TRTEngineOp"
    ]
    assert len(trt_engine_nodes) == 1
    return trt_engine_nodes[0]

  def _TestTrtGraphConverter(self,
                             device,
                             output_saved_model_dir=None,
                             need_calibration=False,
                             is_dynamic_op=False):
    """General method to test trt_convert.TrtGraphConverter()."""
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir,
        need_calibration=need_calibration,
        is_dynamic_op=is_dynamic_op,
        device=device)
    graph_defs_to_verify = [output_graph_def]

    if output_saved_model_dir:
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {
          self._MayRemoveGraphSequenceNumber(node.name): node.op
          for node in graph_def.node
      }
      if device is not None and device.startswith("/CPU:"):
        self.assertEqual(
            {
                "add": "AddV2",
                "v1": "Const",
                "add_1": "AddV2",
                "add_2": "AddV2",
                "input1": "Placeholder",
                "input2": "Placeholder",
                "mul": "Mul",
                "output": "Identity"
            }, node_name_to_op)
      else:
        self.assertEqual(
            {
                "input1": "Placeholder",
                "input2": "Placeholder",
                "TRTEngineOp_000": "TRTEngineOp",
                "output": "Identity"
            }, node_name_to_op)

      if need_calibration:
        trt_engine_nodes = [
            node for node in graph_def.node if node.op == "TRTEngineOp"
        ]
        if device is not None and device.startswith("/CPU:"):
          self.assertEmpty(trt_engine_nodes)
          return

        self.assertNotEmpty(trt_engine_nodes)
        for node in trt_engine_nodes:
          self.assertNotEmpty(node.attr["calibration_data"].s)
        # Run the calibrated graph.
        # TODO(laigd): consider having some input where the answer is different.
        with ops.Graph().as_default():
          importer.import_graph_def(graph_def, name="")
          with self.session(config=self._GetConfigProto()) as sess:
            for test_data in range(10):
              self.assertEqual((test_data + 1.0)**2 + test_data,
                               sess.run(
                                   "output:0",
                                   feed_dict={
                                       "input1:0": [[[test_data]]],
                                       "input2:0": [[[test_data]]]
                                   }))

  @parameterized.named_parameters([
      ("NoDeviceAssignment", None),
      ("GPU", "/GPU:0"),
      ("CPU", "/CPU:0"),
  ])
  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_OfflineConversion(self, device):
    """Test case for trt_convert.TrtGraphConverter()."""

    for need_calibration in [False, True]:
      # Use GraphDef as input.
      self._TestTrtGraphConverter(device)

      # Use SavedModel as input.
      self._TestTrtGraphConverter(
          device,
          output_saved_model_dir=self.mkdtemp(),
          need_calibration=need_calibration)

  @parameterized.named_parameters([
      ("NoDeviceAssignment", None),
      ("GPU", "/device:GPU:0"),
      ("CPU", "/device:CPU:0"),
  ])
  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_OnlineConversion(self, device):
    """Test case for TF-TRT conversion using Grappler directly."""

    conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=trt_convert.TrtPrecisionMode.FP32)
    config = self._GetConfigProto(
        rewriter_config=trt_convert.get_tensorrt_rewriter_config(
            conversion_params,
            is_dynamic_op=False,
            max_batch_size=1,
            is_v2=False))

    with ops.Graph().as_default():
      # Online conversion requires a frozen graph, so we reuse inp1 as the var
      # argument.
      inp1 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input1")
      inp2 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input2")
      if device:
        with ops.device(device):
          TrtConvertTest._GetGraph(inp1, inp2, inp1)
      else:
        TrtConvertTest._GetGraph(inp1, inp2, inp1)
      with self.session(config=config) as sess:
        self._TestRun(sess, batch_size=1)

  def _CreateConverterV2(
      self,
      input_saved_model_dir,
      input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
      max_workspace_size_bytes=10 << 20,  # Use a smaller workspace.
      precision_mode=trt_convert.TrtPrecisionMode.FP32,
      maximum_cached_engines=2,
      allow_build_at_runtime=True):
    return trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key=input_saved_model_signature_key,
        max_workspace_size_bytes=max_workspace_size_bytes,
        precision_mode=precision_mode,
        maximum_cached_engines=maximum_cached_engines,
        allow_build_at_runtime=allow_build_at_runtime)

  def _CheckTrtOps(self, concrete_func, check_fn=None, num_engines=1):
    graph_def = concrete_func.graph.as_graph_def()
    trt_op_names = []
    for node in graph_def.node:
      if node.op == "TRTEngineOp":
        trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
        if check_fn:
          check_fn(node)
    for func in graph_def.library.function:
      for node in func.node_def:
        if node.op == "TRTEngineOp":
          trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
          if check_fn:
            check_fn(node)
    self.assertLen(trt_op_names, num_engines)

  def _RandomInput(self, shape, dtype=np.float32):
    inp1 = np.random.random_sample(shape).astype(dtype)
    inp2 = np.random.random_sample(shape).astype(dtype)
    return inp1, inp2

  @test_util.run_v2_only
  def testTrtGraphConverter_DynamicConversion_v2(self):
    """Test case for trt_convert.TrtGraphConverterV2()."""

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    expected_output = root.run(np_input1, np_input2)
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(input_saved_model_dir)
    converter.convert()

    # Verify the converted GraphDef and ConcreteFunction.
    self._CheckTrtOps(converter._converted_func)  # pylint: disable=protected-access

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    # Save the converted model without any TRT engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    unexpected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertFalse(os.path.exists(unexpected_asset_file))

    # Run the converted function to populate the engine cache.
    def _InputFn():
      yield np_input1, np_input2

    converter.build(input_fn=_InputFn)

    # Save the converted model again with serialized engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    expected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertTrue(os.path.exists(expected_asset_file))
    self.assertTrue(os.path.getsize(expected_asset_file))

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

    # Load and verify the converted model.
    #
    # TODO(laigd): the name of the new input_signature of the
    # `root_with_trt.run` function is empty string (originally was None),
    # investigate why.
    root_with_trt = load.load(output_saved_model_dir)
    # TODO(laigd): `root_with_trt.run` is still using the original graph without
    # trt. Consider changing that.
    # self._CheckTrtOps(root_with_trt.run.get_concrete_function())
    converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
    self._CheckTrtOps(converted_signature)
    output_with_trt = converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))
    # The output of running the converted signature is a dict, for
    # compatibility with the V1 SavedModel signature mechanism.
    self.assertAllClose(
        expected_output,
        list(output_with_trt.values())[0],
        atol=1e-6,
        rtol=1e-6)

    del root_with_trt
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverter_ShapeOp_Int32InputOutput_v2(self):
    """Testing ShapeOp and int32 values as engine input and output."""

    class ShapeOpModel(autotrackable.AutoTrackable):

      def __init__(self):
        self.v = None

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, None], dtype=dtypes.float32)
      ])
      def run(self, x):
        q = x + 1
        q_shape = array_ops.shape(q)
        # Add an op that is not supported by TF-TRT. This makes TF-TRT build
        # two engines: the first engine produces an int32 output, and the
        # second engine takes an int32 input and produces an int32 output.
        q = math_ops.cumsum(q_shape)
        q = q * 2
        return array_ops.identity(q, name="output")

    np_input = np.random.random_sample([5, 3]).astype(np.float32)

    def _InputFunc():
      yield (np_input,)

    # Create the SavedModel.
    root = ShapeOpModel()
    expected_output = root.run(np_input)
    input_saved_model_dir = self.mkdtemp()
    save.save(root, input_saved_model_dir, signatures=root.run)

    # Convert the graph to TF-TRT.
    conv_params = trt_convert.TrtConversionParams(minimum_segment_size=2)
    converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        use_dynamic_shape=True,
        **conv_params._asdict())
    converter.convert()

    # Build the graph with the input generator. This runs the TRTEngineOp native
    # segment.
    converter.build(_InputFunc)
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)

    root_with_trt = load.load(output_saved_model_dir)
    converted_signature = root_with_trt.signatures["serving_default"]
    # Check that the graph is converted to two TRTEngineOps.
    self._CheckTrtOps(converted_signature, num_engines=2)
    # Run the graph.
    output_with_trt = converted_signature(x=ops.convert_to_tensor(np_input))
    # Check the result of the run.
    self.assertAllClose(expected_output, list(output_with_trt.values())[0])

  @test_util.run_v2_only
  def testTrtGraphConverter_Int8Conversion_v2(self):

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    expected_output = root.run(np_input1, np_input2)
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir,
        precision_mode=trt_convert.TrtPrecisionMode.INT8,
        maximum_cached_engines=3)

    # Convert and perform INT8 calibration.
    def _CalibrationInputFn():
      yield np_input1, np_input2

    converter.convert(calibration_input_fn=_CalibrationInputFn)

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    def _CheckFn(node):
      self.assertNotEmpty(node.attr["calibration_data"].s, node.name)

    # Verify the converted GraphDef.
    self._CheckTrtOps(converter._converted_func, _CheckFn)  # pylint: disable=protected-access

    # Build another engine with different batch size.
    def _InputFn():
      yield self._RandomInput([5, 1, 1])

    converter.build(input_fn=_InputFn)

    # Save the converted model.
    # TODO(laigd): check that it should contain two engines.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    expected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertTrue(os.path.exists(expected_asset_file))
    self.assertTrue(os.path.getsize(expected_asset_file))

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

    # Load and verify the converted model.
    root_with_trt = load.load(output_saved_model_dir)
    converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
    self._CheckTrtOps(converted_signature, _CheckFn)
    output_with_trt = converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))
    self.assertLen(output_with_trt, 1)
    # The output of running the converted signature is a dict, for
    # compatibility with the V1 SavedModel signature mechanism.
    self.assertAllClose(
        expected_output,
        list(output_with_trt.values())[0],
        atol=1e-6,
        rtol=1e-6)

    # Run with an input of a different batch size. This should build a new
    # engine using the calibration table.
    # TODO(laigd): check that it should contain three engines.
    np_input1, np_input2 = self._RandomInput([6, 1, 1])
    converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))

    del root_with_trt
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverter_DestroyEngineCache(self):
    """Test case for trt_convert.TrtGraphConverterV2()."""

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(input_saved_model_dir)
    converter.convert()

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    def _InputFn():
      yield np_input1, np_input2

    converter.build(input_fn=_InputFn)  # Populate the TRT engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)

    def _DestroyCache():
      with ops.device("GPU:0"):
        handle = gen_trt_ops.create_trt_resource_handle(
            resource_name=trt_engine_name)
        gen_resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

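    # Destroying the resource raises NotFoundError if no engine cache with that
    # name exists, which is what the checks below use to probe whether the
    # cache is alive.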
    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

    # Load the converted model and make sure the engine cache is populated by
    # default.
    root = load.load(output_saved_model_dir)
    _DestroyCache()
    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

    # Load the converted model again and make sure the engine cache is destroyed
    # when the model goes out of scope.
    root = load.load(output_saved_model_dir)
    del root
    gc.collect()  # Force GC to destroy the TRT engine cache.
    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

  def _CompareSavedModel(self, model_class):
    signature_key = "serving_default"

    def _GetModelPaths(model_class):
      input_saved_model_dir = self.mkdtemp()
      root = model_class()
      save.save(root, input_saved_model_dir)

      converter = self._CreateConverterV2(
          input_saved_model_dir, input_saved_model_signature_key=signature_key)
      converter.convert()
      output_saved_model_dir = self.mkdtemp()
      converter.save(output_saved_model_dir)
      return input_saved_model_dir, output_saved_model_dir

    def _GetSignatureDef(export_dir):
      saved_model_proto = loader_impl.parse_saved_model(export_dir)
      self.assertLen(saved_model_proto.meta_graphs, 1)
      meta_graph = saved_model_proto.meta_graphs[0]
      self.assertIn(signature_key, meta_graph.signature_def)
      return meta_graph.signature_def[signature_key]

    def _CompareSignatureDef(original_def, converted_def, is_input):
      endpoints = original_def.inputs if is_input else original_def.outputs
      converted_endpoints = (
          converted_def.inputs if is_input else converted_def.outputs)
      self.assertEqual(set(endpoints.keys()), set(converted_endpoints.keys()))
      for key in endpoints:
        original_input = endpoints[key]
        converted_input = converted_endpoints[key]
        self.assertEqual(original_input.name, converted_input.name)
        self.assertEqual(original_input.dtype, converted_input.dtype)
        self.assertEqual(
            tensor_shape.TensorShape(original_input.tensor_shape).as_list(),
            tensor_shape.TensorShape(converted_input.tensor_shape).as_list())

    def _GetStructuredOutputs(export_dir):
      root = load.load(export_dir)
      return root.signatures[signature_key].structured_outputs

    saved_model_path, converted_saved_model_path = _GetModelPaths(model_class)
    original_def = _GetSignatureDef(saved_model_path)
    converted_def = _GetSignatureDef(converted_saved_model_path)
    self.assertEqual(original_def.method_name, converted_def.method_name)
    _CompareSignatureDef(original_def, converted_def, True)
    _CompareSignatureDef(original_def, converted_def, False)

    self.assertEqual(
        _GetStructuredOutputs(saved_model_path),
        _GetStructuredOutputs(converted_saved_model_path))

  @test_util.run_v2_only
  def testRetainSignatureInfo_NoInputs(self):

    class _Model(autotrackable.AutoTrackable):

      @def_function.function(input_signature=[])
      def run(self):
        return array_ops.constant(1.0)

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_OneInput(self):

    class _Model(autotrackable.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32)
      ])
      def run(self, inp):
        return inp + inp * inp

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_TwoInputs(self):

    class _Model(autotrackable.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32),
          tensor_spec.TensorSpec(shape=[None, 2], dtype=dtypes.float32)
      ])
      def run(self, inp1, inp2):
        return inp1 + inp2 * inp2

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_OneOutputSignatureKey(self):

    class _Model(autotrackable.AutoTrackable):

      @def_function.function(input_signature=[])
      def run(self):
        return {"my_output": array_ops.constant(1.0)}

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_TwoOutputSignatureKeys(self):

    class _Model(autotrackable.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32)
      ])
      def run(self, inp):
        # Here the keys are not ordered lexicographically on purpose.
        return {
            "output_b": array_ops.constant(1.0),
            "output_a": inp + inp * inp
        }

    self._CompareSavedModel(_Model)

  def _TestRun(self, sess, batch_size):
    result = sess.run(
        "output:0",
        feed_dict={
            "input1:0": [[[1.0]]] * batch_size,
            "input2:0": [[[1.0]]] * batch_size
        })
    self.assertAllEqual([[[5.0]]] * batch_size, result)

  @parameterized.named_parameters([
      ("LargeSegmentSize", 7),
      ("NoMainGraphConversionSegmentSize", -1),
  ])
  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_MinimumSegmentSize(self, minimum_segment_size):
    output_graph_def = self._ConvertGraphV1(
        minimum_segment_size=minimum_segment_size)
    node_name_to_op = {node.name: node.op for node in output_graph_def.node}
    self.assertEqual(
        {
            "v1": "Const",
            "input1": "Placeholder",
            "input2": "Placeholder",
            "add": "AddV2",
            "mul": "Mul",
            "add_1": "AddV2",
            "add_2": "AddV2",
            "output": "Identity"
        }, node_name_to_op)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_DynamicOp(self):

    output_saved_model_dir = self.mkdtemp()
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir,
        is_dynamic_op=True,
        maximum_cached_engines=2)

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1: a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2: a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3: since the number of cached engines has reached
        # the max, an old engine is evicted and a new one created.
        self._TestRun(sess, 3)

    # Test the output SavedModel.
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1: a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2: a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3: since the number of cached engines has reached
        # the max, an old engine is evicted and a new one created.
        self._TestRun(sess, 3)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_StaticOp(self):

    output_saved_model_dir = self.mkdtemp()
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir, maximum_cached_engines=1)

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1: the default engine embedded in the graphdef
        # will be used.
        self._TestRun(sess, 1)
        # Run with batch size 2, which exceeds max_batch_size: it should fall
        # back to the native TF function.
        self._TestRun(sess, 2)

    # Test the output SavedModel.
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1: the default engine embedded in the graphdef
        # will be used.
        self._TestRun(sess, 1)
        # Run with batch size 2, which exceeds max_batch_size: it should fall
        # back to the native TF function.
        self._TestRun(sess, 2)

  @test_util.run_v2_only
  def testTrtGraphConverter_AllowEngineNativeSegmentExecution(self):
    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    def _InputFn():
      yield np_input1, np_input2

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir, max_workspace_size_bytes=1 << 20)
    converter.convert()

    os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
    os.environ["TF_TRT_ABORT_CUDA_ENGINE_BUILD"] = "True"
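    # With the CUDA engine build forced to abort and native segment execution
    # disallowed, build() has no fallback path and must raise AbortedError.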
    with self.assertRaisesRegex(
        errors.AbortedError,
        r"User disallowed engine native segment execution"):
      try:
        converter.build(input_fn=_InputFn)
      finally:
        # Always reset the environment variables.
        os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
        os.environ["TF_TRT_ABORT_CUDA_ENGINE_BUILD"] = "False"

    converter.build(input_fn=_InputFn)

  @parameterized.parameters((True, True), (True, False), (False, True),
                            (False, False))
  @test_util.run_v2_only
  def testTrtGraphConverter_AllowBuildAtRuntime(self, build_offline,
                                                allow_build_at_runtime):
    if not is_tensorrt_enabled():
      return

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))

    def _InputFn():
      yield np_input1, np_input2

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir, allow_build_at_runtime=allow_build_at_runtime)
    converter.convert()
    if build_offline:
      converter.build(input_fn=_InputFn)
    # Output saved model dir.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)

    saved_model_loaded = load.load(
        output_saved_model_dir, tags=[tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[_SAVED_MODEL_SIGNATURE_KEY]

    # Check that the TRTEngineOp(s) have the correct attribute(s).
    def _CheckFn(node):
      self.assertEqual(node.attr["_allow_build_at_runtime"].b,
                       allow_build_at_runtime)

    self._CheckTrtOps(graph_func, _CheckFn)
    # If the engine was not built offline and the user disallowed both building
    # at runtime and running native segments, an error should be reported.
    if not build_offline and not allow_build_at_runtime:
      with self.assertRaisesRegex(
          errors.AbortedError,
          r"User disallowed engine native segment execution"):
        try:
          os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
          graph_func(inp1=np_input1, inp2=np_input2)
        finally:
          os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
    else:
      output = graph_func(inp1=np_input1, inp2=np_input2)["output_0"]
      self.assertEqual(output.shape, (4, 1, 1))
      self.assertAllClose(
          np.asarray([5.0, 5.0, 5.0, 5.0]).reshape([4, 1, 1]), output)

  @test_util.run_v2_only
  def testBackwardCompatibility(self):
    """Load and execute a model that was saved in TF2.0."""

    model_dir = test.test_src_dir_path(
        "python/compiler/tensorrt/test/testdata/tftrt_2.0_saved_model")
    saved_model_loaded = load.load(model_dir, tags=[tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    output = graph_func(input1=np_input1, input2=np_input2)["output_0"]

    self.assertEqual(output.shape, (4, 1, 1))
    self.assertAllClose(
        np.asarray([5.0, 5.0, 5.0, 5.0]).reshape([4, 1, 1]), output)

  @parameterized.named_parameters([
      ("SaveGPUSpecificEngine", True),
      ("WithoutSaveGPUSpecificEngine", False),
  ])
  @test_util.run_v2_only
  def testTrtGraphConverter_SaveGPUSpecificEngine(self, save_engine_flag):
    """Test case for trt_convert.TrtGraphConverterV2()."""

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.INT8)

    # Run the converted function to populate the engine cache.
    def CalibrationFn():
      yield np_input1, np_input2

    converter.convert(calibration_input_fn=CalibrationFn)

    # Verify the converted GraphDef and ConcreteFunction.
    self._CheckTrtOps(converter._converted_func)

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    # Save the converted model with or without any TRT engine cache
    # based on the value of save_engine_flag.
    output_saved_model_dir = self.mkdtemp()

    converter.save(
        output_saved_model_dir, save_gpu_specific_engines=save_engine_flag)

    expected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)

    self.assertTrue(os.path.exists(expected_asset_file))
    if save_engine_flag:
      # The engine is saved, so we expect non-empty engine data.
      self.assertTrue(os.path.getsize(expected_asset_file))
    else:
      # The engine is not saved, so the asset file should be empty.
      self.assertFalse(os.path.getsize(expected_asset_file))

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverterV2_SaveWithOptions(self):
    """Test to make sure that save method respects options kwarg."""

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(input_saved_model_dir)
    converter.convert()

1060    with mock.patch.object(trt_convert, "save") as mock_save:
1061      mock_save.save = mock.MagicMock()
1062      # Save converted model with options.
1063      output_saved_model_dir = self.mkdtemp()
1064      options = save_options.SaveOptions(save_debug_info=True)
1065      converter.save(output_saved_model_dir, options=options)
1066
1067      # Assert that the saved_model.save function was called with the given
1068      # save_options by TrtGraphConverterV2.save method.
1069      mock_save.save.assert_called_once_with(
1070          mock.ANY, mock.ANY, mock.ANY, options=options)
1071
  @parameterized.named_parameters([
      ("NoDeviceAssignment", None),
      ("GPU1", "GPU:1"),
  ])
  @test_util.run_v2_only
  def testTrtGraphConverter_DevicePlacement(self, device_id):
    """Test case for trt_convert.TrtGraphConverterV2()."""

    gpus = config.list_physical_devices("GPU")
    if len(gpus) < 2:
      self.skipTest("Expected at least 2 GPUs but found {} GPUs".format(
          len(gpus)))

    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    converter = self._CreateConverterV2(
        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)

    converted_model = None
    # Specify the device on which the converted model should be placed.
    with ops.device(device_id):
      converted_model = converter.convert()

    # Verify that TRT engine op has the correct device.
    self._CheckTrtOps(converter._converted_func)

    actual_device_id = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).device

    expected_device_id = device_id if device_id is not None else "GPU:0"

    self.assertIn(expected_device_id.lower(), actual_device_id.lower())

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverter_DevicePlacementOnCPU(self):
    """Test case for trt_convert.TrtGraphConverterV2()."""

    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)

    converted_model = None
    # Specifying a non-GPU device for the converted model should be rejected.
    with self.assertRaisesRegex(ValueError, r"Specified device is not a GPU"):
      with ops.device("CPU"):
        converted_model = converter.convert()

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

  def _TestVariableHelper(self, variable_op, tf_model_name, tftrt_model_name,
                          output_name):
    """Helper with the common code of variable converter tests."""

    model_dir = test.test_src_dir_path(
        "python/compiler/tensorrt/test/testdata/" + tf_model_name)
    trt_model_dir = os.path.join(self.mkdtemp(), tftrt_model_name)

    # Load and convert the TF model.
    conv_params = trt_convert.TrtConversionParams(
        precision_mode="FP16",
        minimum_segment_size=3,
        max_workspace_size_bytes=10 << 20,
        maximum_cached_engines=1)
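    # Graph freezing would fold the variables into constants, so it is disabled
    # here to keep the variable ops in the graph for the checks below.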
    with test_utils.experimental_feature_scope("disable_graph_freezing"):
      converter = trt_convert.TrtGraphConverterV2(
          input_saved_model_dir=model_dir,
          conversion_params=conv_params,
          use_dynamic_shape=True,
          dynamic_shape_profile_strategy="Optimal")
    converter.convert()

    # Build and save the converted model.
    input_shapes = [[(4, 1, 1), (4, 1, 1)]]

    def _InputFn():
      for shapes in input_shapes:
        # Yield a list of input tensors.
        yield [np.ones(shape=shape).astype(np.float32) for shape in shapes]

    converter.build(_InputFn)
    converter.save(trt_model_dir)

    # Load the converted model.
    saved_model_loaded = load.load(trt_model_dir, tags=[tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    # Check that there is one native segment and that both variable ops are in
    # it.
    graph_def = graph_func.graph.as_graph_def()
    engines = []
    for lib_function in graph_def.library.function:
      if re.search(r"TRTEngineOp_\d+_\d+_native_segment",
                   lib_function.signature.name):
        node_ops = [node.op for node in lib_function.node_def]
        engines.append(node_ops)
    self.assertLen(engines, 1)
    self.assertEqual(engines[0].count(variable_op), 2)

    # Run the function and check the output.
    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(2. *
                                      np.ones([4, 1, 1]).astype(np.float32))
    output = graph_func(input1=np_input1, input2=np_input2)[output_name]
    self.assertEqual(output.shape, (4, 1, 1))
    self.assertAllClose(
        np.asarray([42., 42., 42., 42.]).reshape([4, 1, 1]), output)

  @test_util.run_v2_only
  def testVariableV2(self):
    """Test conversion of VariableV2 nodes."""

    self._TestVariableHelper("VariableV2", "tf_variablev2_saved_model",
                             "tftrt_variablev2_saved_model", "output")

  @test_util.run_v2_only
  def testReadVariableOp(self):
    """Test conversion of ReadVariableOp nodes."""

    self._TestVariableHelper("ReadVariableOp", "tf_readvariableop_saved_model",
                             "tftrt_readvariableop_saved_model", "output_0")


if __name__ == "__main__" and is_tensorrt_enabled():
  test.main()