xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/bfloat16_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Tests for bfloat16 helper."""
17
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import test_util
20from tensorflow.python.ops import variable_scope
21from tensorflow.python.platform import test
22from tensorflow.python.tpu import bfloat16
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import variables
26
27
class BFloat16ScopeTest(test.TestCase):
  """Tests for `bfloat16.bfloat16_scope`."""

  def testDefaultScopeName(self):
    """Test that an unnamed bfloat16 scope has an empty name."""
    with bfloat16.bfloat16_scope() as bf:
      self.assertEqual(bf.name, "")

  def testCustomScopeName(self):
    """Test if custom name for the variable scope is propagated correctly."""
    name = 'bfloat16'
    with bfloat16.bfloat16_scope(name) as bf:
      self.assertEqual(bf.name, name)

  def testVariableName(self):
    """Test that ops created inside a bfloat16 scope are named after it.

    A named `bfloat16_scope` nested inside a graph name scope prefixes op
    names with both scopes, while the default (unnamed) scope adds no prefix
    of its own beyond the enclosing graph name scope.
    """
    g = ops.Graph()
    with g.as_default():
      # Keep the float32 variables separate from their bfloat16 casts so the
      # names below don't shadow each other.
      var_a = variables.Variable(2.2, name='var_a')
      var_b = variables.Variable(3.3, name='var_b')
      var_d = variables.Variable(4.4, name='var_d')
      with g.name_scope('scope1'):
        with bfloat16.bfloat16_scope('bf16'):
          a = math_ops.cast(var_a, dtypes.bfloat16)
          b = math_ops.cast(var_b, dtypes.bfloat16)
          c = math_ops.add(a, b, name='addition')
        with bfloat16.bfloat16_scope():
          d = math_ops.cast(var_d, dtypes.bfloat16)
          math_ops.add(c, d, name='addition')

    ops_name = [op.name for op in g.get_operations()]

    # Ops built under the named 'bf16' scope carry its prefix.
    self.assertIn('scope1/bf16/addition', ops_name)
    self.assertIn('scope1/bf16/Cast', ops_name)
    # Ops built under the default scope only carry the graph name scope.
    self.assertIn('scope1/addition', ops_name)
    self.assertIn('scope1/Cast', ops_name)

  @test_util.run_deprecated_v1
  def testRequestedDType(self):
    """Test that the requested dtype is honored by the scope's getter.

    A variable requested without a dtype comes back as float32, and one
    requested as bfloat16 comes back as bfloat16. However, every variable
    actually stored in the scope is float32.
    """
    with bfloat16.bfloat16_scope() as scope:
      v1 = variable_scope.get_variable("v1", [])
      self.assertEqual(v1.dtype.base_dtype, dtypes.float32)
      v2 = variable_scope.get_variable("v2", [], dtype=dtypes.bfloat16)
      self.assertEqual(v2.dtype.base_dtype, dtypes.bfloat16)
      self.assertEqual([dtypes.float32, dtypes.float32],
                       [v.dtype.base_dtype for v in scope.global_variables()])
78
79
if __name__ == "__main__":
  # When executed directly as a script, hand control to TensorFlow's test
  # entry point.
  test.main()
82