xref: /aosp_15_r20/external/tensorflow/tensorflow/core/function/trace_type/trace_type_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests and benchmarks for the trace_type module."""
16
17import collections
18import timeit
19
20from absl.testing import parameterized
21
22from tensorflow.core.function import trace_type
23from tensorflow.core.function.trace_type import default_types
24from tensorflow.python.compat import v2_compat
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.data.ops import iterator_ops
27from tensorflow.python.eager import function
28from tensorflow.python.framework import combinations
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.framework import test_util
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import resource_variable_ops
34from tensorflow.python.ops import variables
35from tensorflow.python.ops.ragged import ragged_tensor
36from tensorflow.python.platform import test
37from tensorflow.python.types import trace
38
39
class TestAttr:
  """Bare-bones attribute record exposing only a `name`.

  Instances are used as the entries of `__attrs_attrs__` on `TestAttrsClass`.
  """

  def __init__(self, name):
    # The only state carried by the record is the attribute's name.
    self.name = name
45
46
class TestAttrsClass:
  """Hand-written class exposing `__attrs_attrs__` for fields `a` and `b`."""

  # Field records, in declaration order.
  __attrs_attrs__ = (TestAttr('a'), TestAttr('b'))

  def __init__(self, a, b):
    self.a, self.b = a, b

  def __eq__(self, other):
    # Objects of any other type never compare equal.
    if not isinstance(other, TestAttrsClass):
      return False
    # Field-wise comparison; `b` is only consulted when `a` matches.
    return self.a == other.a and self.b == other.b
59
60
class DummyGenericClass:
  """Empty class used by the GenericType memory-leak test."""
64
65
class CacheKeyGenerationTest(test.TestCase, parameterized.TestCase):
  """Equality and subtype checks for trace types built by `trace_type`."""

  @combinations.generate(combinations.combine(mode=['eager']))
  def testIteratorAliasing(self):
    # Two separate iterators over datasets built from the same values.
    first_iter = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
    second_iter = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))

    # The same iterator repeated twice: which iterator it is does not matter.
    repeated_first = trace_type.from_object((first_iter, first_iter))
    repeated_second = trace_type.from_object((second_iter, second_iter))
    self.assertEqual(repeated_first, repeated_second)

    # Two distinct iterators: their order does not matter.
    first_then_second = trace_type.from_object((first_iter, second_iter))
    second_then_first = trace_type.from_object((second_iter, first_iter))
    self.assertEqual(first_then_second, second_then_first)

    # A repeated iterator must differ from a pair of distinct iterators.
    repeated = trace_type.from_object((first_iter, first_iter))
    distinct = trace_type.from_object((first_iter, second_iter))
    self.assertNotEqual(repeated, distinct)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testIteratorTypesImplementTracing(self):
    # Both the iterator class and its spec class must satisfy the protocol.
    for iterator_cls in (iterator_ops.OwnedIterator, iterator_ops.IteratorSpec):
      self.assertTrue(
          issubclass(iterator_cls, trace.SupportsTracingProtocol))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testCompositeAndSpec(self):
    ragged = ragged_tensor.RaggedTensor.from_row_splits(
        values=[1, 2, 3], row_splits=[0, 2, 3])
    ragged_spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)

    # The ragged value and the spec above must yield equal trace types.
    value_trace = trace_type.from_object(ragged)
    spec_trace = trace_type.from_object(ragged_spec)
    self.assertEqual(value_trace, spec_trace)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testVariableAliasing(self):
    var_a = resource_variable_ops.ResourceVariable([1])
    var_b = resource_variable_ops.ResourceVariable([1])
    var_c = resource_variable_ops.ResourceVariable([1])
    distinct_trace = trace_type.from_object((var_a, var_b, var_c))
    aliased_trace = trace_type.from_object((var_a, var_a, var_a))
    self.assertNotEqual(distinct_trace, aliased_trace)

    # A fresh set of variables (the third name is deliberately rebound) must
    # reproduce the same trace types for the same aliasing pattern.
    var_c = resource_variable_ops.ResourceVariable([2])
    var_d = resource_variable_ops.ResourceVariable([2])
    var_e = resource_variable_ops.ResourceVariable([2])
    distinct_trace_again = trace_type.from_object((var_c, var_d, var_e))
    aliased_trace_again = trace_type.from_object((var_d, var_d, var_d))
    self.assertEqual(distinct_trace, distinct_trace_again)
    self.assertEqual(aliased_trace, aliased_trace_again)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTensorEquality(self):
    tracing_context = trace_type.InternalTracingContext()

    def traced(tensor):
      # All tensors in this test share one tracing context.
      return tensor.__tf_tracing_type__(tracing_context)

    base = traced(array_ops.zeros([11, 3, 5], dtype=dtypes.int32))
    other_shape = traced(array_ops.zeros([11, 4, 5], dtype=dtypes.int32))
    other_dtype = traced(array_ops.zeros([11, 3, 5], dtype=dtypes.float32))
    other_values = traced(array_ops.ones([11, 3, 5], dtype=dtypes.int32))

    # Shape and dtype distinguish trace types; the element values do not.
    self.assertNotEqual(base, other_shape)
    self.assertNotEqual(base, other_dtype)
    self.assertNotEqual(other_shape, other_dtype)
    self.assertEqual(base, other_values)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTensorAndSpecEquality(self):
    tracing_context = trace_type.InternalTracingContext()

    def traced(value):
      return value.__tf_tracing_type__(tracing_context)

    tensor_trace = traced(array_ops.zeros([11, 3, 5], dtype=dtypes.int32))
    unnamed_spec_trace = traced(
        tensor_spec.TensorSpec([11, 3, 5], dtype=dtypes.int32))
    named_spec_trace = traced(
        tensor_spec.TensorSpec([11, 3, 5], dtype=dtypes.int32, name='name'))

    # An unnamed spec matches the tensor; a named spec does not.
    self.assertEqual(tensor_trace, unnamed_spec_trace)
    self.assertNotEqual(tensor_trace, named_spec_trace)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTensorShapeUnknown(self):
    tracing_context = trace_type.InternalTracingContext()

    def unknown_shape_trace():
      unknown_spec = tensor_spec.TensorSpec(None, dtype=dtypes.int32)
      return unknown_spec.__tf_tracing_type__(tracing_context)

    # Two specs with shape `None` must compare equal.
    first_trace = unknown_shape_trace()
    second_trace = unknown_shape_trace()
    self.assertEqual(first_trace, second_trace)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAttrsCacheKeyGeneration(self):
    actual_trace = trace_type.from_object(TestAttrsClass(1, 2))
    field_types = (default_types.Literal(1), default_types.Literal(2))
    expected_trace = default_types.Attrs.from_type_and_attributes(
        TestAttrsClass, field_types)
    self.assertEqual(actual_trace, expected_trace)
    # A trace type is a subtype of itself.
    self.assertTrue(actual_trace.is_subtype_of(actual_trace))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTupleEquality(self):
    reference = trace_type.from_object((1, 2, 3, 4))
    one_element_changed = trace_type.from_object((1, 2, 2, 4))
    shorter = trace_type.from_object((1, 2, 3))
    same_contents = trace_type.from_object((1, 2, 3, 4))
    self.assertNotEqual(reference, one_element_changed)
    self.assertNotEqual(reference, shorter)
    self.assertNotEqual(one_element_changed, shorter)
    self.assertEqual(reference, same_contents)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testListEquality(self):
    reference = trace_type.from_object([1, 2, 3, 4])
    one_element_changed = trace_type.from_object([1, 2, 2, 4])
    shorter = trace_type.from_object([1, 2, 3])
    same_contents = trace_type.from_object([1, 2, 3, 4])
    self.assertNotEqual(reference, one_element_changed)
    self.assertNotEqual(reference, shorter)
    self.assertNotEqual(one_element_changed, shorter)
    self.assertEqual(reference, same_contents)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testDictEquality(self):
    reference = trace_type.from_object({1: 2, 3: 4})
    value_changed = trace_type.from_object({1: 2, 3: 2})
    value_changed_again = trace_type.from_object({1: 2, 3: 0})
    # Same items as `reference`, written in the opposite order.
    reordered = trace_type.from_object({3: 4, 1: 2})
    self.assertNotEqual(reference, value_changed)
    self.assertNotEqual(reference, value_changed_again)
    self.assertNotEqual(value_changed, value_changed_again)
    self.assertEqual(reference, reordered)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testComplexStruct(self):
    nested = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
    first_trace = trace_type.from_object(nested)
    second_trace = trace_type.from_object(nested)
    self.assertEqual(first_trace, second_trace)
    # Equal trace types are subtypes of each other in both directions.
    self.assertTrue(first_trace.is_subtype_of(second_trace))
    self.assertTrue(second_trace.is_subtype_of(first_trace))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testCustomUnequableTypeSucceeds(self):

    class CustomUnequable:

      def __eq__(self, other):
        raise ValueError

      def __hash__(self):
        return 0

    first_obj = CustomUnequable()
    second_obj = CustomUnequable()
    first_trace = trace_type.from_object(first_obj)
    first_trace_again = trace_type.from_object(first_obj)
    second_trace = trace_type.from_object(second_obj)
    # Traces of the very same object compare equal without raising.
    self.assertEqual(first_trace, first_trace_again)

    # Comparing traces of two different objects surfaces the ValueError
    # raised by CustomUnequable.__eq__.
    with self.assertRaises(ValueError):
      first_trace.__eq__(second_trace)

    # After the traced object is deleted, its traces stop comparing equal.
    del first_obj
    self.assertNotEqual(first_trace, first_trace_again)
    self.assertNotEqual(first_trace_again, first_trace)

    # Deleting the unrelated object does not change that outcome.
    del second_obj
    self.assertNotEqual(first_trace, first_trace_again)
    self.assertNotEqual(first_trace_again, first_trace)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testCustomUnhashableTypeFailsGracefully(self):

    class CustomUnhashable:

      # Defining __eq__ without __hash__ makes instances unhashable.
      def __eq__(self, other):
        return True

    unhashable = CustomUnhashable()
    expected_message = (
        r'could not be represented through the generic tracing type')
    with self.assertRaisesRegex(TypeError, expected_message):
      trace_type.from_object(unhashable)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGetPlaceholderValue(self):
    original_value = [1, 2, (3, [4, 5]), {6: [7]}, TestAttrsClass(8, (10, 11))]
    value_type = trace_type.from_object(original_value)
    # The placeholder produced from the type must equal the original value.
    rebuilt_value = value_type._placeholder_value()
    self.assertEqual(original_value, rebuilt_value)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testWrappedNamedTuple(self):
    ActualType = collections.namedtuple('ActualType', ['a', 'b', 'c'])

    class MockWrapper(tuple):
      # Generated through trackable data structures:
      # //tensorflow/python/training/tracking/data_structures.py
      # With design pattern similar to Python functools:
      # https://docs.python.org/3/library/functools.html?highlight=__wrapped__#functools.update_wrapper
      __wrapped__ = ActualType(1, 2, 3)

    # The wrapper must trace the same as the namedtuple it wraps.
    wrapper_trace = trace_type.from_object(MockWrapper())
    plain_trace = trace_type.from_object(ActualType(1, 2, 3))
    self.assertEqual(wrapper_trace, plain_trace)
272
273
class CacheKeyMemoryTest(test.TestCase):
  """Runs `trace_type.from_object` on several value kinds under a leak check.

  Every test is wrapped in
  `test_util.assert_no_new_pyobjects_executing_eagerly`.
  NOTE(review): judging by its name, the decorator fails the test when the
  body leaves new Python objects behind -- confirm against test_util.
  """

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testGeneric(self):
    # A plain int and an instance of a user-defined class.
    trace_type.from_object(1)
    trace_type.from_object(DummyGenericClass())

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testTensor(self):
    # A tensor of zeros with shape [10].
    tensor = array_ops.zeros([10])
    trace_type.from_object(tensor)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testTuple(self):
    # A tuple of ints.
    trace_type.from_object((1, 2, 3))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testDict(self):
    # A dict with int keys and int values.
    trace_type.from_object({1: 1, 2: 2, 3: 3})

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testList(self):
    # A list of ints.
    trace_type.from_object([1, 2, 3])

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testAttrs(self):
    # An object whose class exposes `__attrs_attrs__`.
    trace_type.from_object(TestAttrsClass(1, 2))
301
302
class CacheKeyGenerationBenchmark(test.Benchmark):
  """Benchmarks for trace type generation and calls to defun functions."""

  def _time_and_report(self, fn, name, metric_name, iterations):
    """Times repeated calls of `fn` and reports the result.

    The name does not start with `benchmark`, so that it is not mistaken for
    a benchmark method itself.

    Args:
      fn: Zero-argument callable to time.
      name: Benchmark name forwarded to `report_benchmark`.
      metric_name: Name of the metric holding the average time per call, in
        milliseconds.
      iterations: Number of times `fn` is called.
    """
    total_seconds = timeit.timeit(fn, number=iterations)
    # NOTE(review): `wall_time` receives the total elapsed time for all
    # iterations, not a per-iteration value -- confirm this is intended.
    self.report_benchmark(
        name=name,
        iters=iterations,
        wall_time=total_seconds,
        metrics=[{
            'name': metric_name,
            'value': total_seconds / iterations * 1000
        }])

  def benchmarkTensor(self):
    """Times `from_object` on a list of zero tensors of rank 1 through 4."""
    shapes = [[1], [2, 19], [5, 11, 24], [4, 5, 9, 23]]
    tensors = [array_ops.zeros(shape) for shape in shapes]

    def encode_tensors(tensors):
      trace_type.from_object(tensors)

    self._time_and_report(
        lambda: encode_tensors(tensors),
        name='tensor_cache_key_generation',
        metric_name='tensor_cache_key_generation_avg_ms',
        iterations=100000)

  def benchmarkTensorSpec(self):
    """Times `from_object` on a list of int32 TensorSpecs of rank 1 through 4."""
    shapes = [[1], [2, 19], [5, 11, 24], [4, 5, 9, 23]]
    tensor_specs = [
        tensor_spec.TensorSpec(shape, dtypes.int32) for shape in shapes
    ]

    def encode_tensor_specs(tensor_specs):
      trace_type.from_object(tensor_specs)

    self._time_and_report(
        lambda: encode_tensor_specs(tensor_specs),
        name='tensor_spec_cache_key_generation',
        metric_name='tensor_spec_cache_key_generation_avg_ms',
        iterations=100000)

  def benchmarkVariable(self):
    """Times `from_object` on a list of three variables."""
    var_list = [
        variables.Variable(1.0),
        variables.Variable(1),
        variables.Variable([1])
    ]

    def encode_variables(var_list):
      trace_type.from_object(var_list)

    self._time_and_report(
        lambda: encode_variables(var_list),
        name='variable_cache_key_generation',
        metric_name='variable_cache_key_generation_avg_ms',
        iterations=10000)

  def benchmarkCacheKeyLookup(self):
    """Times calls to a defun function with an already-seen argument shape."""

    @function.defun
    def defined(t):
      return t

    call_arg_list = [
        1,
        array_ops.zeros([5, 13]),
        array_ops.zeros([9, 22, 24]),
        array_ops.zeros([5, 13, 2])
    ]

    # Call once per argument before timing starts.
    for call_arg in call_arg_list:
      defined(call_arg)

    # Same shape as one of the arguments used in the calls above.
    lookup_call_arg = array_ops.zeros([5, 13])

    self._time_and_report(
        lambda: defined(lookup_call_arg),
        name='cache_key_lookup',
        metric_name='cache_key_lookup_avg_ms',
        iterations=10000)

  def benchmarkNestedStruct(self):
    """Times `from_object` on a nested structure of dicts and tuples."""
    struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}

    def encode_struct(struct):
      trace_type.from_object(struct)

    self._time_and_report(
        lambda: encode_struct(struct),
        name='nested_struct_cache_key_generation',
        metric_name='nested_struct_cache_key_generation_avg_ms',
        iterations=100000)

  def benchmarkFunctionInvocation(self):
    """Times calls to a defun function with a variable/tensor structure."""
    struct = (variables.Variable(1.0), array_ops.zeros([5, 13]), {
        'tensor': array_ops.zeros([5, 20]),
        'variable': variables.Variable(1.0)
    })

    @function.defun
    def defined(t):
      return t

    defined(struct)  # Get it traced and cached.

    self._time_and_report(
        lambda: defined(struct),
        name='function_invocation',
        metric_name='function_invocation_time_avg_ms',
        iterations=10000)
436
# Script entry point: call `enable_v2_behavior()` first, then start the
# test runner.
if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()
439  test.main()
440