# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests and benchmarks for the trace_type module."""

import collections
import timeit

from absl.testing import parameterized

from tensorflow.core.function import trace_type
from tensorflow.core.function.trace_type import default_types
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.types import trace


class TestAttr:
  """Helps test attrs collections.

  Minimal stand-in for an attribute descriptor listed in `__attrs_attrs__`;
  only the `name` field is provided.
  """

  def __init__(self, name):
    self.name = name


class TestAttrsClass:
  """Helps test attrs collections.

  Exposes `__attrs_attrs__` describing the fields `a` and `b`, so it is
  treated like an attrs-style class. Because `__eq__` is defined without
  `__hash__`, instances are unhashable.
  """

  __attrs_attrs__ = (TestAttr('a'), TestAttr('b'))

  def __init__(self, a, b):
    self.a = a
    self.b = b

  def __eq__(self, other):
    return isinstance(
        other, TestAttrsClass) and self.a == other.a and self.b == other.b


class DummyGenericClass:
  """Helps test memory leaks for GenericType."""


class CacheKeyGenerationTest(test.TestCase, parameterized.TestCase):
  """Tests equality/subtyping of TraceTypes built by `trace_type`."""

  @combinations.generate(combinations.combine(mode=['eager']))
  def testIteratorAliasing(self):
    it1 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
    it2 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))

    # The aliasing pattern within the tuple, not the concrete iterators,
    # decides equality of the resulting trace types.
    self.assertEqual(
        trace_type.from_object((it1, it1)),
        trace_type.from_object((it2, it2)))
    self.assertEqual(
        trace_type.from_object((it1, it2)),
        trace_type.from_object((it2, it1)))
    self.assertNotEqual(
        trace_type.from_object((it1, it1)),
        trace_type.from_object((it1, it2)))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testIteratorTypesImplementTracing(self):
    self.assertTrue(
        issubclass(iterator_ops.OwnedIterator, trace.SupportsTracingProtocol))
    self.assertTrue(
        issubclass(iterator_ops.IteratorSpec, trace.SupportsTracingProtocol))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testCompositeAndSpec(self):
    composite_tensor = ragged_tensor.RaggedTensor.from_row_splits(
        values=[1, 2, 3], row_splits=[0, 2, 3])
    spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)

    self.assertEqual(
        trace_type.from_object(composite_tensor),
        trace_type.from_object(spec))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testVariableAliasing(self):
    v1 = resource_variable_ops.ResourceVariable([1])
    v2 = resource_variable_ops.ResourceVariable([1])
    v3 = resource_variable_ops.ResourceVariable([1])
    all_unique = trace_type.from_object((v1, v2, v3))
    all_same = trace_type.from_object((v1, v1, v1))
    self.assertNotEqual(all_unique, all_same)

    # Fresh variables with a different initial value but the same aliasing
    # pattern must produce equal trace types.
    v3 = resource_variable_ops.ResourceVariable([2])
    v4 = resource_variable_ops.ResourceVariable([2])
    v5 = resource_variable_ops.ResourceVariable([2])
    all_unique_again = trace_type.from_object((v3, v4, v5))
    all_same_again = trace_type.from_object((v4, v4, v4))
    self.assertEqual(all_unique, all_unique_again)
    self.assertEqual(all_same, all_same_again)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTensorEquality(self):
    context = trace_type.InternalTracingContext()
    tensor_a = array_ops.zeros([11, 3, 5],
                               dtype=dtypes.int32).__tf_tracing_type__(context)
    tensor_b = array_ops.zeros([11, 4, 5],
                               dtype=dtypes.int32).__tf_tracing_type__(context)
    tensor_c = array_ops.zeros(
        [11, 3, 5], dtype=dtypes.float32).__tf_tracing_type__(context)
    tensor_d = array_ops.ones([11, 3, 5],
                              dtype=dtypes.int32).__tf_tracing_type__(context)

    self.assertNotEqual(tensor_a, tensor_b)
    self.assertNotEqual(tensor_a, tensor_c)
    self.assertNotEqual(tensor_b, tensor_c)
    # Values differ (zeros vs ones) but shape and dtype match.
    self.assertEqual(tensor_a, tensor_d)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTensorAndSpecEquality(self):
    context = trace_type.InternalTracingContext()
    tensor = array_ops.zeros([11, 3, 5],
                             dtype=dtypes.int32).__tf_tracing_type__(context)
    spec = tensor_spec.TensorSpec(
        [11, 3, 5], dtype=dtypes.int32).__tf_tracing_type__(context)
    spec_with_name = tensor_spec.TensorSpec(
        [11, 3, 5], dtype=dtypes.int32,
        name='name').__tf_tracing_type__(context)

    self.assertEqual(tensor, spec)
    self.assertNotEqual(tensor, spec_with_name)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTensorShapeUnknown(self):
    context = trace_type.InternalTracingContext()
    spec_1 = tensor_spec.TensorSpec(
        None, dtype=dtypes.int32).__tf_tracing_type__(context)
    spec_2 = tensor_spec.TensorSpec(
        None, dtype=dtypes.int32).__tf_tracing_type__(context)
    self.assertEqual(spec_1, spec_2)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAttrsCacheKeyGeneration(self):
    trace_a = trace_type.from_object(TestAttrsClass(1, 2))
    expected = default_types.Attrs.from_type_and_attributes(
        TestAttrsClass,
        (default_types.Literal(1), default_types.Literal(2)))
    self.assertEqual(trace_a, expected)
    self.assertTrue(trace_a.is_subtype_of(trace_a))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTupleEquality(self):
    trace_a = trace_type.from_object((1, 2, 3, 4))
    trace_b = trace_type.from_object((1, 2, 2, 4))
    trace_c = trace_type.from_object((1, 2, 3))
    trace_d = trace_type.from_object((1, 2, 3, 4))
    self.assertNotEqual(trace_a, trace_b)
    self.assertNotEqual(trace_a, trace_c)
    self.assertNotEqual(trace_b, trace_c)
    self.assertEqual(trace_a, trace_d)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testListEquality(self):
    trace_a = trace_type.from_object([1, 2, 3, 4])
    trace_b = trace_type.from_object([1, 2, 2, 4])
    trace_c = trace_type.from_object([1, 2, 3])
    trace_d = trace_type.from_object([1, 2, 3, 4])
    self.assertNotEqual(trace_a, trace_b)
    self.assertNotEqual(trace_a, trace_c)
    self.assertNotEqual(trace_b, trace_c)
    self.assertEqual(trace_a, trace_d)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testDictEquality(self):
    trace_a = trace_type.from_object({1: 2, 3: 4})
    trace_b = trace_type.from_object({1: 2, 3: 2})
    trace_c = trace_type.from_object({1: 2, 3: 0})
    # Same items as trace_a in a different insertion order.
    trace_d = trace_type.from_object({3: 4, 1: 2})
    self.assertNotEqual(trace_a, trace_b)
    self.assertNotEqual(trace_a, trace_c)
    self.assertNotEqual(trace_b, trace_c)
    self.assertEqual(trace_a, trace_d)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testComplexStruct(self):
    struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
    trace_a = trace_type.from_object(struct)
    trace_b = trace_type.from_object(struct)
    self.assertEqual(trace_a, trace_b)
    self.assertTrue(trace_a.is_subtype_of(trace_b))
    self.assertTrue(trace_b.is_subtype_of(trace_a))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testCustomUnequableTypeSucceeds(self):

    class CustomUnequable:

      def __eq__(self, o):
        raise ValueError

      def __hash__(self):
        return 0

    object_a = CustomUnequable()
    object_b = CustomUnequable()
    trace_a_1 = trace_type.from_object(object_a)
    trace_a_2 = trace_type.from_object(object_a)
    trace_b = trace_type.from_object(object_b)
    self.assertEqual(trace_a_1, trace_a_2)

    # Comparing traces of distinct objects reaches the user-defined
    # __eq__, whose ValueError propagates.
    with self.assertRaises(ValueError):
      trace_a_1.__eq__(trace_b)

    # Once the traced objects are deleted, the traces no longer compare
    # equal (presumably the trace only weakly references the object --
    # confirm against the GenericType implementation).
    del object_a
    self.assertNotEqual(trace_a_1, trace_a_2)
    self.assertNotEqual(trace_a_2, trace_a_1)

    del object_b
    self.assertNotEqual(trace_a_1, trace_a_2)
    self.assertNotEqual(trace_a_2, trace_a_1)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testCustomUnhashableTypeFailsGracefully(self):

    class CustomUnhashable:

      def __eq__(self, o):
        return True

    obj = CustomUnhashable()
    with self.assertRaisesRegex(
        TypeError,
        r'could not be represented through the generic tracing type'):
      trace_type.from_object(obj)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGetPlaceholderValue(self):
    composite_value = [1, 2, (3, [4, 5]), {6: [7]}, TestAttrsClass(8, (10, 11))]
    composite_type = trace_type.from_object(composite_value)
    placeholder_value = composite_type._placeholder_value()
    self.assertEqual(composite_value, placeholder_value)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testWrappedNamedTuple(self):
    ActualType = collections.namedtuple('ActualType', ['a', 'b', 'c'])

    class MockWrapper(tuple):
      # Generated through trackable data structures:
      # //tensorflow/python/training/tracking/data_structures.py
      # With design pattern similar to Python functools:
      # https://docs.python.org/3/library/functools.html?highlight=__wrapped__#functools.update_wrapper
      __wrapped__ = ActualType(1, 2, 3)

    self.assertEqual(
        trace_type.from_object(MockWrapper()),
        trace_type.from_object(ActualType(1, 2, 3)))


class CacheKeyMemoryTest(test.TestCase):
  """Checks that building trace types creates no new lingering Python objects."""

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testGeneric(self):
    trace_type.from_object(1)
    trace_type.from_object(DummyGenericClass())

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testTensor(self):
    tensor = array_ops.zeros([10])
    trace_type.from_object(tensor)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testTuple(self):
    trace_type.from_object((1, 2, 3))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testDict(self):
    trace_type.from_object({1: 1, 2: 2, 3: 3})

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testList(self):
    trace_type.from_object([1, 2, 3])

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testAttrs(self):
    trace_type.from_object(TestAttrsClass(1, 2))


class CacheKeyGenerationBenchmark(test.Benchmark):
  """Benchmarks trace type generation and traced-function cache lookups."""

  def _time_and_report(self, name, func, iterations, metric_name=None):
    """Times `func` and reports the result via `report_benchmark`.

    Args:
      name: Benchmark entry name passed to `report_benchmark`.
      func: Zero-argument callable to time.
      iterations: Number of times `func` is invoked.
      metric_name: Name of the average-latency metric. Defaults to
        `name + '_avg_ms'`.
    """
    if metric_name is None:
      metric_name = name + '_avg_ms'
    total_seconds = timeit.timeit(func, number=iterations)
    # NOTE(review): `wall_time` is the total elapsed time over all
    # iterations, not a per-call value; the per-call average is reported
    # separately (in milliseconds) as a metric. Confirm this matches what
    # `report_benchmark` consumers expect.
    self.report_benchmark(
        name=name,
        iters=iterations,
        wall_time=total_seconds,
        metrics=[{
            'name': metric_name,
            'value': total_seconds / iterations * 1000
        }])

  def benchmarkTensor(self):
    shapes = [[1], [2, 19], [5, 11, 24], [4, 5, 9, 23]]
    tensors = [array_ops.zeros(s) for s in shapes]
    self._time_and_report(
        'tensor_cache_key_generation',
        lambda: trace_type.from_object(tensors),
        iterations=100000)

  def benchmarkTensorSpec(self):
    shapes = [[1], [2, 19], [5, 11, 24], [4, 5, 9, 23]]
    tensor_specs = [tensor_spec.TensorSpec(s, dtypes.int32) for s in shapes]
    self._time_and_report(
        'tensor_spec_cache_key_generation',
        lambda: trace_type.from_object(tensor_specs),
        iterations=100000)

  def benchmarkVariable(self):
    var_list = [
        variables.Variable(1.0),
        variables.Variable(1),
        variables.Variable([1])
    ]
    self._time_and_report(
        'variable_cache_key_generation',
        lambda: trace_type.from_object(var_list),
        iterations=10000)

  def benchmarkCacheKeyLookup(self):

    @function.defun
    def defined(t):
      return t

    call_arg_list = [
        1,
        array_ops.zeros([5, 13]),
        array_ops.zeros([9, 22, 24]),
        array_ops.zeros([5, 13, 2])
    ]

    # Populate the function cache with several traces before timing.
    for c in call_arg_list:
      defined(c)

    lookup_call_arg = array_ops.zeros([5, 13])

    self._time_and_report(
        'cache_key_lookup',
        lambda: defined(lookup_call_arg),
        iterations=10000)

  def benchmarkNestedStruct(self):
    struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
    self._time_and_report(
        'nested_struct_cache_key_generation',
        lambda: trace_type.from_object(struct),
        iterations=100000)

  def benchmarkFunctionInvocation(self):
    struct = (variables.Variable(1.0), array_ops.zeros([5, 13]), {
        'tensor': array_ops.zeros([5, 20]),
        'variable': variables.Variable(1.0)
    })

    @function.defun
    def defined(t):
      return t

    defined(struct)  # Get it traced and cached.

    self._time_and_report(
        'function_invocation',
        lambda: defined(struct),
        iterations=10000,
        metric_name='function_invocation_time_avg_ms')


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()