# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for bfloat16 helper."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.tpu import bfloat16


class BFloat16ScopeTest(test.TestCase):
  """Tests for `bfloat16.bfloat16_scope` naming and dtype handling."""

  def testDefaultScopeName(self):
    """Test that an unnamed bfloat16 scope has an empty scope name."""
    with bfloat16.bfloat16_scope() as bf:
      self.assertEqual(bf.name, "")

  def testCustomScopeName(self):
    """Test if custom name for the variable scope is propagated correctly."""
    name = 'bfloat16'
    # Pass the same object that is asserted on, so the two cannot drift apart.
    with bfloat16.bfloat16_scope(name) as bf:
      self.assertEqual(bf.name, name)

  def testVariableName(self):
    """Test that bfloat16 scopes are reflected in the names of ops built in them.

    A named bfloat16 scope nested inside a graph name scope should prefix op
    names with both scope components. An unnamed bfloat16 scope should leave
    op names under the enclosing name scope only.
    """
    g = ops.Graph()
    with g.as_default():
      a = variables.Variable(2.2, name='var_a')
      b = variables.Variable(3.3, name='var_b')
      # Give `d` its own name; previously it reused 'var_b' by mistake.
      d = variables.Variable(4.4, name='var_d')
      with g.name_scope('scope1'):
        # Named scope: ops here are expected under 'scope1/bf16/'.
        with bfloat16.bfloat16_scope('bf16'):
          a = math_ops.cast(a, dtypes.bfloat16)
          b = math_ops.cast(b, dtypes.bfloat16)
          c = math_ops.add(a, b, name='addition')
        # Unnamed scope: ops here are expected directly under 'scope1/'.
        with bfloat16.bfloat16_scope():
          d = math_ops.cast(d, dtypes.bfloat16)
          math_ops.add(c, d, name='addition')

    op_names = {op.name for op in g.get_operations()}

    self.assertIn('scope1/bf16/addition', op_names)
    self.assertIn('scope1/bf16/Cast', op_names)
    self.assertIn('scope1/addition', op_names)
    self.assertIn('scope1/Cast', op_names)

  @test_util.run_deprecated_v1
  def testRequestedDType(self):
    """Test if requested dtype is honored in the getter.

    A variable requested as bfloat16 is returned with a bfloat16 dtype. The
    scope's global variables are nevertheless both stored as float32, as the
    final assertion checks.
    """
    with bfloat16.bfloat16_scope() as scope:
      v1 = variable_scope.get_variable("v1", [])
      self.assertEqual(v1.dtype.base_dtype, dtypes.float32)
      v2 = variable_scope.get_variable("v2", [], dtype=dtypes.bfloat16)
      self.assertEqual(v2.dtype.base_dtype, dtypes.bfloat16)
      self.assertEqual([dtypes.float32, dtypes.float32],
                       [v.dtype.base_dtype for v in scope.global_variables()])


if __name__ == "__main__":
  test.main()