# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Grappler Remapper."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import sysconfig
from tensorflow.python.platform import test
from tensorflow.python.util import _pywrap_utils


def _input(shape):
  """Generates an input of a given shape."""
  return variables.Variable(random_ops.truncated_normal(shape, seed=0))


def _weight(shape):
  """Generates a weight of a given shape."""
  # Note that the lambda is needed to allow construction inside loops.
  return variables.Variable(
      lambda: init_ops.glorot_uniform_initializer(seed=0)(shape))


def _bias(shape):
  """Generates a bias of a given shape."""
  return constant_op.constant(0.1, shape=shape)


def _get_config(remapping_on=False):
  """Returns a ConfigProto with the remapper optimizer on/off."""
  rewrite_config = rewriter_config_pb2.RewriterConfig(
      remapping=rewriter_config_pb2.RewriterConfig.ON
      if remapping_on else rewriter_config_pb2.RewriterConfig.OFF)
  # A negative value means optimization is not skipped even for tiny graphs.
  rewrite_config.min_graph_nodes = -1
  graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_config)
  config = config_pb2.ConfigProto(graph_options=graph_options)
  return config


class RemapperTest(test.TestCase, parameterized.TestCase):
  """Tests the Grappler remapper optimizer."""

  def setUp(self):
    super(RemapperTest, self).setUp()
    # Gelu fusion on GPU requires cublasLt.
    os.environ['TF_USE_CUBLASLT'] = '1'
    # Conv runtime fusion on GPU requires cuDNN frontend APIs.
    os.environ['TF_CUDNN_USE_FRONTEND'] = '1'
    os.environ['TF_CUDNN_USE_RUNTIME_FUSION'] = '1'

  def maybe_skip_test(self, mode):
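    """Skips the test if prerequisites for the given mode are not met."""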
    if mode == 'cuda':
      # It seems that Windows cannot correctly query the cuda_version.
      # TODO(kaixih@nvidia): Remove this when it works.
      if os.name == 'nt':
        self.skipTest("This test doesn't support Windows")

      # The cublasLt matmul with the GELU epilogue is only supported since
      # CUDA 11.4.
      if not test.is_gpu_available(cuda_only=True):
        self.skipTest('This test requires GPU.')
      cuda_version_str = sysconfig.get_build_info().get('cuda_version', '0.0')
      cuda_version = tuple([int(x) for x in cuda_version_str.split('.')])
      if cuda_version < (11, 4):
        self.skipTest('This test requires CUDA >= 11.4.')

    if mode == 'mkl' and not test_util.IsMklEnabled():
      self.skipTest('MKL is not enabled.')

  def _VerifyValues(self, model_fn, use_low_precision, fused_op, epilog_ops):
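    """Runs model_fn with remapping off and on and compares the results.

    Checks that the optimized graph contains one of the ops in `fused_op`
    whose 'fused_ops' attribute matches `epilog_ops`, and that the fused
    output is close to the reference output. Returns the optimized
    partition graph.
    """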
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    metadata = config_pb2.RunMetadata()
    # Compute reference value.
    config = _get_config(remapping_on=False)
    with session.Session(config=config) as sess:
      sess.run(variables.global_variables_initializer())
      output_ref = sess.run(
          model_fn, options=run_options, run_metadata=metadata)
    # Compute output with fusion.
    config = _get_config(remapping_on=True)
    with session.Session(config=config) as sess:
      sess.run(variables.global_variables_initializer())
      output_val = sess.run(
          model_fn, options=run_options, run_metadata=metadata)
      graph = metadata.partition_graphs[0]

    # Graph should contain fused op.
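    # The fused kernel records its epilogue ops (e.g. BiasAdd and the
    # activation) in the 'fused_ops' attribute, which must match exactly.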
    found_fused_op = False
    for node in graph.node:
      if node.op in fused_op:
        fused_ops = node.attr['fused_ops'].list.s
        ops_matched = len(fused_ops) >= 1 and len(fused_ops) == len(epilog_ops)
        for op_a, op_b in zip(fused_ops, epilog_ops):
          if op_a != op_b:
            ops_matched = False
            break
        found_fused_op = ops_matched
        break
    self.assertTrue(found_fused_op)

    # Computed output value should be close to reference value.
    tol = 1e-2 if use_low_precision else 1e-5
    self.assertAllClose(output_ref, output_val, atol=tol, rtol=tol)

    return graph

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_matmul_biasadd_gelu_fusion(self, mode):
    """Test MatMul+BiasAdd+Gelu fusion."""
    self.maybe_skip_test(mode)
    data_types = [dtypes.float32]
    if mode == 'cuda':
      data_types.append(dtypes.float16)
    elif mode == 'mkl':
      data_types.append(dtypes.bfloat16)

    is_bf16_supported = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()

    m, n, k = (3, 3, 4)  # Matrix dimensions
    for precision in data_types:
      for approximate in (False, True):
        # Gelu exact (approximate=False) is not supported with bfloat16
        # precision, since the Erf op has no bfloat16 support.
        # TODO(intel-tf): Enable gelu exact with bfloat16, when Erf op is
        # supported with bfloat16.
        if precision == dtypes.bfloat16:
          if not (approximate and is_bf16_supported):
            continue

        # TODO(kaixih@nvidia): Enable gelu exact when Erf op is supported with
        # cublaslt.
        if mode == 'cuda' and not approximate:
          continue

        device = '/device:GPU:0' if mode == 'cuda' else '/device:CPU:0'
        # Create MatMul + BiasAdd + Gelu graph
        ops.reset_default_graph()
        with ops.device(device):
          x = _input([m, k])
          w = _weight([k, n])
          b = _bias([n])
          x = math_ops.cast(x, precision)
          w = math_ops.cast(w, precision)
          b = math_ops.cast(b, precision)
          y = math_ops.matmul(x, w)
          z = nn.bias_add(y, b)
          out = nn.gelu(z, approximate=approximate)

        gelu_type = b'GeluApproximate' if approximate else b'GeluExact'
        epilog_ops = [b'BiasAdd', gelu_type]
        fused_op = ['_MklNativeFusedMatMul', '_MklFusedMatMul', '_FusedMatMul']
        graph = self._VerifyValues(out, precision != dtypes.float32, fused_op,
                                   epilog_ops)

  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv2d_biasadd_act_fusion(self):
    """Test Conv2D+BiasAdd+Activation fusion."""
    if not test_util.is_gpu_available():
      self.skipTest('No GPU available')

    N, H, W, C = (5, 3, 3, 8)  # pylint: disable=invalid-name
    # The runtime fusion requires the output dims to be 32-bit aligned; with
    # 2-byte fp16 elements, an even channel count satisfies this.
    self.assertEqual(C % 2, 0)

    act_fns = [nn.relu]
    act_names = [b'Relu']

    if test_util.is_gpu_available(
        cuda_only=True, min_cuda_compute_capability=(8, 0)):
      act_fns += [nn.elu, nn.relu6, nn.leaky_relu]
      act_names += [b'Elu', b'Relu6', b'LeakyRelu']

    for precision in ('float16', 'float32'):
      for act_fn, act_name in zip(act_fns, act_names):
        use_fp16 = precision == 'float16'
        # The runtime fusion (when the activation is not relu) only supports
        # fp16 at this moment.
        if not use_fp16 and act_name != b'Relu':
          continue

        ops.reset_default_graph()
        x_shape = [N, C, H, W]
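        # The 'NC..'/'N..C' strings tell bias_add whether the bias is applied
        # over a channels-first or channels-last layout.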
        x_format, b_format = ('NCHW', 'NC..')
        if use_fp16:
          x_shape = [N, H, W, C]
          x_format, b_format = ('NHWC', 'N..C')

        x = _input(x_shape)
        w = _weight([2, 2, C, C])
        b = _bias([C])

        if use_fp16:
          x = math_ops.cast(x, dtypes.float16)
          w = math_ops.cast(w, dtypes.float16)
          b = math_ops.cast(b, dtypes.float16)

        y = nn_ops.conv2d(
            x, w, strides=(1, 1), padding='SAME', data_format=x_format)
        z = nn.bias_add(y, b, data_format=b_format)
        out = act_fn(z)
        out = array_ops.identity(out)

        epilog_ops = [b'BiasAdd', act_name]
        fused_op = ['_FusedConv2D']
        graph = self._VerifyValues(out, use_fp16, fused_op, epilog_ops)

  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_two_conv2d_fusions(self):
    """Test two Conv2D patterns where only the second is fusable."""
    if not test_util.is_gpu_available(
        cuda_only=True, min_cuda_compute_capability=(8, 0)):
      self.skipTest('No GPU with compute capability >= 8.0 available')

    N, H, W, C = (5, 3, 3, 8)  # pylint: disable=invalid-name

    ops.reset_default_graph()
    x_shape = [N, C, H, W]
    x_format, b_format = ('NCHW', 'NC..')

    x = _input(x_shape)
    w = _weight([2, 2, C, C])
    b = _bias([C])

    y = nn_ops.conv2d(
        x, w, strides=(1, 1), padding='SAME', data_format=x_format)
    y = nn.bias_add(y, b, data_format=b_format)
    y = nn.leaky_relu(y)
    y = nn_ops.conv2d(
        y, w, strides=(1, 1), padding='SAME', data_format=x_format)
    y = nn.bias_add(y, b, data_format=b_format)
    y = nn.relu(y)
    out = array_ops.identity(y)

    # The first Conv-BiasAdd-LeakyRelu is not fusable because cuDNN requires
    # fp16 for this pattern. The second Conv-BiasAdd-Relu is fusable.
    epilog_ops = [b'BiasAdd', b'Relu']
    fused_op = ['_FusedConv2D']
    self._VerifyValues(out, False, fused_op, epilog_ops)


if __name__ == '__main__':
  test.main()
