/
test_util.py
4299 lines (3510 loc) · 149 KB
/
test_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Test utils for tensorflow."""
import collections
from collections import OrderedDict
from collections.abc import Iterable, Iterator, Callable, Collection, Sequence
import contextlib
import functools
import gc
import itertools
import math
import os
import random
import re
import tempfile
import threading
import time
from typing import Any, cast, Union, Optional, overload, TypeVar
import unittest
from absl.testing import parameterized
import numpy as np
from google.protobuf import descriptor_pool
from google.protobuf import message
from google.protobuf import text_format
from tensorflow.core.config import flags
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_sanitizers
from tensorflow.python import tf2
from tensorflow.python.client import device_lib
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.client import session as s
from tensorflow.python.compat import v2_compat
from tensorflow.python.compat.compat import forward_compatibility_horizon
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import _test_metrics_util
from tensorflow.python.framework import config
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import gpu_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import tfrt_utils
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import gen_sync_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_ops # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import _pywrap_stacktrace_handler
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import _pywrap_util_port
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import traceback_utils
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export
# Type variable for decorators that hand back the callable they were given.
_F = TypeVar("_F", bound=Callable[..., Any])
# Unconstrained type variable used by generic helpers in this module.
_T = TypeVar("_T")
# Type variable for class decorators over `TensorFlowTestCase` subclasses.
_TC = TypeVar("_TC", bound=type["TensorFlowTestCase"])
# If the below import is made available through the BUILD rule, then this
# function is overridden and will instead return True and cause Tensorflow
# graphs to be compiled with XLA.
def is_xla_enabled() -> bool:
  """Returns False unless the import below replaces this definition."""
  return False


# pytype: disable=import-error
try:
  # When importable, this rebinds `is_xla_enabled` at module level.
  from tensorflow.python.framework.is_xla_test_true import is_xla_enabled  # pylint: disable=g-import-not-at-top, unused-import
except Exception:  # pylint: disable=broad-except
  # Any failure to import keeps the default definition above.
  pass
# Uses the same mechanism as above to selectively enable/disable MLIR
# compilation.
def is_mlir_bridge_enabled() -> Optional[bool]:
  """Returns None unless one of the imports below replaces this definition."""
  return None


try:
  # The `..._test_false` module is tried first; if it imports, it wins.
  from tensorflow.python.framework.is_mlir_bridge_test_false import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
except ImportError:
  try:
    # Otherwise fall back to the `..._test_true` module.
    from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
  except ImportError:
    # Neither module is available: keep the default returning None.
    pass
# pytype: enable=import-error
def is_asan_enabled() -> bool:
  """Check if ASAN is enabled.

  Returns:
    The value reported by `pywrap_sanitizers.is_asan_enabled()`.
  """
  return pywrap_sanitizers.is_asan_enabled()
def is_msan_enabled() -> bool:
  """Check if MSAN is enabled.

  Returns:
    The value reported by `pywrap_sanitizers.is_msan_enabled()`.
  """
  return pywrap_sanitizers.is_msan_enabled()
def is_tsan_enabled() -> bool:
  """Check if TSAN is enabled.

  Returns:
    The value reported by `pywrap_sanitizers.is_tsan_enabled()`.
  """
  return pywrap_sanitizers.is_tsan_enabled()
def is_ubsan_enabled() -> bool:
  """Check if UBSAN is enabled.

  Returns:
    The value reported by `pywrap_sanitizers.is_ubsan_enabled()`.
  """
  return pywrap_sanitizers.is_ubsan_enabled()
def _get_object_count_by_type(
    exclude: Iterable[Any] = (),
) -> collections.Counter[str]:
  """Counts the objects known to the garbage collector, keyed by type name.

  Args:
    exclude: Objects whose type names are subtracted from the tally.

  Returns:
    A `Counter` mapping `type(obj).__name__` to the number of objects
    returned by `gc.get_objects()`, minus one per object in `exclude`.
    Counter subtraction drops entries whose count is not positive.
  """
  live_counts = collections.Counter(
      [type(candidate).__name__ for candidate in gc.get_objects()])
  excluded_counts = collections.Counter(
      [type(candidate).__name__ for candidate in exclude])
  return live_counts - excluded_counts
@tf_export("test.gpu_device_name")
def gpu_device_name() -> str:
  """Returns the name of the first local GPU device, or an empty string.

  This method should only be used in tests written with `tf.test.TestCase`.

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_gpu_support():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device(tf.test.gpu_device_name()):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  """
  # Take the first device whose type is "GPU"; default to "" when none exists.
  gpu_names = (
      compat.as_str(device.name)
      for device in device_lib.list_local_devices()
      if device.device_type == "GPU")
  return next(gpu_names, "")
def assert_ops_in_graph(
    expected_ops: dict[str, str], graph: ops.Graph
) -> dict[str, node_def_pb2.NodeDef]:
  """Checks that every expected op exists in `graph` with the expected type.

  Args:
    expected_ops: `dict<string, string>` mapping node name to op type.
    graph: Graph whose `GraphDef` nodes are inspected.

  Returns:
    `dict<string, node>` mapping each matched node name to its node.

  Raises:
    ValueError: If a named node has a different op type than expected, or if
      some expected node names are absent from the graph.
  """
  graph_def = cast(graph_pb2.GraphDef, graph.as_graph_def())
  found: dict[str, node_def_pb2.NodeDef] = {}
  for node in graph_def.node:
    if node.name not in expected_ops:
      continue
    wanted_type = expected_ops[node.name]
    if wanted_type != node.op:
      raise ValueError("Expected op for node %s is different. %s vs %s" %
                       (node.name, wanted_type, node.op))
    found[node.name] = node

  # Every expected name must have been seen among the graph's nodes.
  if set(expected_ops) != set(found):
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), found.keys()))
  return found
@tf_export("test.assert_equal_graph_def", v1=[])
def assert_equal_graph_def_v2(
    expected: graph_pb2.GraphDef, actual: graph_pb2.GraphDef
) -> None:
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent. This function
  ignores randomized attribute values that may appear in V2 checkpoints.

  Note the argument order: `expected` comes first here, and the call below
  forwards the two protos to `assert_equal_graph_def` by keyword.

  Args:
    expected: The `GraphDef` we expected.
    actual: The `GraphDef` we have.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  assert_equal_graph_def(
      actual=actual,
      expected=expected,
      checkpoint_v2=True,
      hash_table_shared_name=True)
@tf_export(v1=["test.assert_equal_graph_def"])
def assert_equal_graph_def_v1(
    actual: graph_pb2.GraphDef,
    expected: graph_pb2.GraphDef,
    checkpoint_v2: bool = False,
    hash_table_shared_name: bool = False
) -> None:
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
      values that appear in V2 checkpoints.
    hash_table_shared_name: boolean determining whether to ignore randomized
      shared_names that appear in HashTableV2 op defs.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  assert_equal_graph_def(
      actual=actual,
      expected=expected,
      checkpoint_v2=checkpoint_v2,
      hash_table_shared_name=hash_table_shared_name)
def assert_equal_graph_def(
    actual: graph_pb2.GraphDef,
    expected: graph_pb2.GraphDef,
    checkpoint_v2: bool = False,
    hash_table_shared_name: bool = False
) -> None:
  """Raises if `actual` and `expected` are reported as different graphs.

  Both protos are serialized and handed to
  `pywrap_tf_session.EqualGraphDefWrapper`; a non-empty result is raised as
  the assertion message.

  Note that the optional stripping steps delete attrs from `actual` and
  `expected` in place before the comparison.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: If true, first run `_strip_checkpoint_v2_randomized` on
      both protos.
    hash_table_shared_name: If true, first run
      `_strip_hash_table_shared_name` on both protos.

  Raises:
    AssertionError: If the comparison reports a difference.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    type(expected).__name__)

  # Collect the requested normalization passes, then apply each to both sides.
  strippers = []
  if checkpoint_v2:
    strippers.append(_strip_checkpoint_v2_randomized)
  if hash_table_shared_name:
    strippers.append(_strip_hash_table_shared_name)
  for strip in strippers:
    strip(actual)
    strip(expected)

  mismatch = pywrap_tf_session.EqualGraphDefWrapper(
      actual.SerializeToString(), expected.SerializeToString())
  if mismatch:
    raise AssertionError(compat.as_str(mismatch))
def assert_meta_graph_protos_equal(
    tester: "TensorFlowTestCase",
    a: meta_graph_pb2.MetaGraphDef,
    b: meta_graph_pb2.MetaGraphDef,
) -> None:
  """Compares MetaGraphDefs `a` and `b` using the assertions of `tester`.

  The `collection_def` and `graph_def` fields are compared separately and
  then cleared on both protos, so `a` and `b` are modified by this call.

  Args:
    tester: Test case whose `assertEqual`/`assertProtoEquals` are used.
    a: First `MetaGraphDef`.
    b: Second `MetaGraphDef`.
  """
  # Both protos must define exactly the same set of collections.
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  for key in a.collection_def.keys():
    lhs = a.collection_def[key]
    rhs = b.collection_def[key]
    proto_type = cast(type[message.Message], ops.get_collection_proto_type(key))
    if not proto_type:
      # No registered proto type: compare the raw collection values.
      tester.assertEqual(lhs, rhs)
      continue

    lhs_proto = proto_type()
    rhs_proto = proto_type()
    lhs_items = lhs.bytes_list.value
    rhs_items = rhs.bytes_list.value
    # The collections must hold the same number of entries.
    tester.assertEqual(len(lhs_items), len(rhs_items))
    for lhs_item, rhs_item in zip(lhs_items, rhs_items):
      lhs_proto.ParseFromString(lhs_item)
      rhs_proto.ParseFromString(rhs_item)
      tester.assertProtoEquals(lhs_proto, rhs_proto)

  # The collections were compared above; drop them so the final whole-proto
  # comparison does not look at their raw values.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Compare the graph_defs, then their versions (which assert_equal_graph_def
  # ignores).
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)

  # Likewise drop the graph_defs before the final whole-proto comparison.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)
# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"


def _strip_checkpoint_v2_randomized(graph_def: graph_pb2.GraphDef) -> None:
  """Deletes node attrs whose tensor string matches the sharded-save pattern.

  An attr is removed when its tensor holds exactly one string value and that
  value starts with a match of `_SHARDED_SAVE_OP_PATTERN`. The proto is
  modified in place.

  Args:
    graph_def: The `GraphDef` to strip.
  """
  sharded_pattern = compat.as_bytes(_SHARDED_SAVE_OP_PATTERN)

  def _is_randomized(attr_value) -> bool:
    tensor = cast(tensor_pb2.TensorProto, attr_value.tensor)
    if not (tensor and len(tensor.string_val) == 1):
      return False
    only_string = tensor.string_val[0]
    return bool(only_string and re.match(sharded_pattern, only_string))

  for node in graph_def.node:
    # Collect first, delete afterwards, so the attr map is not changed while
    # it is being iterated.
    stale_keys: list[str] = [
        key for key in node.attr if _is_randomized(node.attr[key])
    ]
    for key in stale_keys:
      del node.attr[key]
_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+"


def _strip_hash_table_shared_name(graph_def: graph_pb2.GraphDef) -> None:
  """Deletes matching `shared_name` attrs from `HashTableV2` nodes in place.

  The attr is removed only when its bytes value starts with a match of
  `_TABLE_SHARED_NAME_PATTERN`.

  Args:
    graph_def: The `GraphDef` to strip.
  """
  shared_name_pattern = compat.as_bytes(_TABLE_SHARED_NAME_PATTERN)
  for node in graph_def.node:
    if node.op != "HashTableV2" or "shared_name" not in node.attr:
      continue
    if re.match(shared_name_pattern, node.attr["shared_name"].s):
      del node.attr["shared_name"]
def IsGoogleCudaEnabled() -> bool:
  """Returns the value of `_pywrap_util_port.IsGoogleCudaEnabled()`."""
  return _pywrap_util_port.IsGoogleCudaEnabled()
def IsBuiltWithROCm() -> bool:
  """Returns the value of `_pywrap_util_port.IsBuiltWithROCm()`."""
  return _pywrap_util_port.IsBuiltWithROCm()
def IsBuiltWithXLA() -> bool:
  """Returns the value of `_pywrap_util_port.IsBuiltWithXLA()`."""
  return _pywrap_util_port.IsBuiltWithXLA()
def IsBuiltWithNvcc() -> bool:
  """Returns the value of `_pywrap_util_port.IsBuiltWithNvcc()`."""
  return _pywrap_util_port.IsBuiltWithNvcc()
def IsCPUTargetAvailable(target: str) -> bool:
  """Reports whether the CPU target family named by `target` is available.

  The family is selected by prefix match on `target`, and the answer is
  delegated to the corresponding `_pywrap_util_port.Is*Available()` query.
  Prefixes are tested in this order: "arm", "aarch64", "ppc"/"powerpc",
  "x86", "s390x".

  Args:
    target: Name of the CPU target to look up.

  Returns:
    The result of the matching `_pywrap_util_port` availability query.

  Raises:
    AssertionError: If `target` starts with none of the known prefixes.
  """
  if target.startswith("arm"):
    return _pywrap_util_port.IsAArch32Available()
  if target.startswith("aarch64"):
    return _pywrap_util_port.IsAArch64Available()
  if target.startswith(("ppc", "powerpc")):
    return _pywrap_util_port.IsPowerPCAvailable()
  if target.startswith("x86"):
    return _pywrap_util_port.IsX86Available()
  if target.startswith("s390x"):
    return _pywrap_util_port.IsSystemZAvailable()
  # Raise explicitly instead of `assert False`: assert statements are removed
  # under `python -O`, which would make an unknown target silently return
  # None. The exception type is unchanged for existing callers.
  raise AssertionError("Unknown CPU target: " + target)
def GpuSupportsHalfMatMulAndConv() -> bool:
  """Returns `_pywrap_util_port.GpuSupportsHalfMatMulAndConv()`."""
  return _pywrap_util_port.GpuSupportsHalfMatMulAndConv()
def IsMklEnabled() -> bool:
  """Returns the value of `_pywrap_util_port.IsMklEnabled()`."""
  return _pywrap_util_port.IsMklEnabled()
def InstallStackTraceHandler() -> None:
  """Calls `_pywrap_stacktrace_handler.InstallStacktraceHandler()`."""
  _pywrap_stacktrace_handler.InstallStacktraceHandler()
def NHWCToNCHW(
    input_tensor: Union[tensor_lib.Tensor, list[int]]
) -> Union[tensor_lib.Tensor, list[int]]:
  """Reorders a tensor, or a shape list, from NHWC layout to NCHW.

  Args:
    input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape

  Returns:
    The transposed tensor when given a tensor, otherwise a new list holding
    the reordered shape entries.
  """
  # rank -> axis order that moves the last (channel) axis to position 1
  axis_order_by_rank = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if isinstance(input_tensor, tensor_lib.Tensor):
    axis_order = axis_order_by_rank[input_tensor.shape.ndims]
    return array_ops.transpose(input_tensor, axis_order)
  axis_order = axis_order_by_rank[len(input_tensor)]
  return [input_tensor[axis] for axis in axis_order]
def NHWCToNCHW_VECT_C(
    input_shape_or_tensor: Union[tensor_lib.Tensor, list[int]]
) -> Union[tensor_lib.Tensor, list[int]]:
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  The last dimension is split into groups of 4, giving a shape with one extra
  trailing dimension of size 4, which is then permuted. A shape list passed
  in is left unmodified; a new list is returned.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
      divisible by 4.
  """
  # Keyed by the rank after the extra size-4 dimension has been appended.
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, tensor_lib.Tensor)
  # Work on a private copy: the shape is edited in place below, and a
  # caller-supplied shape list must not be changed as a side effect.
  temp_shape: list[int] = list(
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    t = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(t, permutation)
  return [temp_shape[a] for a in permutation]
def NCHW_VECT_CToNHWC(
    input_shape_or_tensor: Union[tensor_lib.Tensor, list[int]]
) -> Union[tensor_lib.Tensor, list[int]]:
  """Converts a tensor, or a shape list, from NCHW_VECT_C layout to NHWC.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  axis_order_by_rank = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, tensor_lib.Tensor)
  if is_tensor:
    source_shape: list[int] = input_shape_or_tensor.shape.as_list()
  else:
    source_shape = input_shape_or_tensor
  vect_size = source_shape[-1]
  if vect_size != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")

  axis_order = axis_order_by_rank[len(source_shape)]
  # Drop the trailing vector axis and fold its size into the channel entry,
  # which is last after the permutation.
  nhwc_shape = [source_shape[axis] for axis in axis_order[:-1]]
  nhwc_shape[-1] = nhwc_shape[-1] * vect_size
  if not is_tensor:
    return nhwc_shape
  transposed = array_ops.transpose(input_shape_or_tensor, axis_order)
  return array_ops.reshape(transposed, nhwc_shape)
def NCHWToNHWC(
    input_tensor: Union[tensor_lib.Tensor, list[int]]
) -> Union[tensor_lib.Tensor, list[int]]:
  """Reorders a tensor, or a shape list, from NCHW layout to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    The transposed tensor when given a tensor, otherwise a new list holding
    the reordered shape entries.
  """
  # rank -> axis order that moves axis 1 (channels) to the end
  axis_order_by_rank = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if isinstance(input_tensor, tensor_lib.Tensor):
    axis_order = axis_order_by_rank[input_tensor.shape.ndims]
    return array_ops.transpose(input_tensor, axis_order)
  axis_order = axis_order_by_rank[len(input_tensor)]
  return [input_tensor[axis] for axis in axis_order]
def skip_if(condition: Union[Callable[[], bool], bool]) -> Callable[[_F], _F]:
  """Skips the decorated function if condition is or evaluates to True.

  "Skipping" here means the wrapped function is not called and the wrapper
  returns None; no skip is reported to the test framework. A callable
  `condition` is evaluated on every call of the decorated function.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean.

  Returns:
    A decorator producing the wrapped function.
  """

  def real_skip_if(fn: _F) -> _F:
    # Carry fn's name, docstring and attributes over to the wrapper so that
    # markers set by other decorators (e.g. `_disable_control_flow_v2`)
    # remain visible on the decorated function.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      skip = condition() if callable(condition) else condition
      if not skip:
        return fn(*args, **kwargs)
      return None

    return wrapper

  return real_skip_if
@contextlib.contextmanager
def skip_if_error(
    test_obj: unittest.TestCase,
    error_type: type[Exception],
    messages: Union[str, list[str], None] = None
) -> Iterator[None]:
  """Context manager that turns selected errors into skipped tests.

  Note that this does not work if used in setUpClass/tearDownClass.
  Usage in setUp/tearDown works fine just like regular test methods.

  Args:
    test_obj: A test object provided as `self` in the test methods; this object
      is usually an instance of `unittest.TestCase`'s subclass and should have
      `skipTest` method.
    error_type: The error type to skip. Note that if `messages` are given, both
      `error_type` and `messages` need to match for the test to be skipped.
    messages: Optional, a string or list of strings. If `None`, the test will be
      skipped if `error_type` matches what is raised; otherwise, the test is
      skipped if any of the `messages` is contained in the message of the error
      raised, and `error_type` matches the error raised.

  Yields:
    Nothing.
  """
  # Normalize a single string or a list into a flat list of fragments.
  wanted_fragments = nest.flatten(messages) if messages else messages

  try:
    yield
  except error_type as err:
    err_text = str(err)
    if wanted_fragments and not any(
        fragment in err_text for fragment in wanted_fragments):
      # Right error type, but none of the requested messages: a real failure.
      raise
    test_obj.skipTest("Skipping error: {}: {}".format(type(err), err_text))
def enable_c_shapes(fn: _F) -> _F:
  """No-op. TODO(b/74620627): Remove this.

  Args:
    fn: The function being decorated.

  Returns:
    `fn`, unchanged.
  """
  return fn
def with_c_shapes(cls: type[_T]) -> type[_T]:
  """No-op. TODO(b/74620627): Remove this.

  Args:
    cls: The class being decorated.

  Returns:
    `cls`, unchanged.
  """
  return cls
def enable_control_flow_v2(fn: _F) -> _F:
  """Decorator that runs a test with CondV2 and WhileV2 switched on.

  `control_flow_util.ENABLE_CONTROL_FLOW_V2` is set to True only around the
  call to `fn`, i.e. after the test class's setup has already run, and its
  previous value is restored afterwards even if `fn` raises.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    previous_setting = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
    try:
      return fn(*args, **kwargs)
    finally:
      # Put the flag back to what it was before this call.
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_setting

  return wrapper
def with_control_flow_v2(cls: _TC) -> _TC:
  """Adds copies of the test methods that run with WhileV2 and CondV2 enabled.

  Note this enables CondV2 and WhileV2 in new methods after running the test
  class's setup method.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  If a test function has _disable_control_flow_v2 attr set to True (using the
  @disable_control_flow_v2 decorator), the v2 function is not generated for it.

  Example:

  @test_util.with_control_flow_v2
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    @test_util.disable_control_flow_v2("b/xyzabc")
    def testDisabledForV2(self):
      ...

  Generated class:
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    def testEnabledForV2WithControlFlowV2(self):
      // Enable V2 flags.
      testEnabledForV2(self)
      // Restore V2 flags.

    def testDisabledForV2(self):
      ...

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  # Nothing to add when V2 control flow is already on globally.
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    return cls

  test_prefix = unittest.TestLoader.testMethodPrefix
  # Iterate over a snapshot, since attributes are added to cls in the loop.
  for name, member in cls.__dict__.copy().items():
    if not callable(member):
      continue
    if not name.startswith(test_prefix):
      continue
    if getattr(member, "_disable_control_flow_v2", False):
      continue
    setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(member))
  return cls
def disable_control_flow_v2(unused_msg: str) -> Callable[[_F], _F]:
  """Marks a test so `with_control_flow_v2` does not generate a V2 copy of it.

  Blocks the function from being run with v2 control flow ops.

  Args:
    unused_msg: Reason for disabling.

  Returns:
    A decorator that sets `_disable_control_flow_v2 = True` on the function
    it receives and returns that same function.
  """

  def wrapper(func: _F) -> _F:
    setattr(func, "_disable_control_flow_v2", True)
    return func

  return wrapper
def enable_output_all_intermediates(fn: _F) -> _F:
  """Force-enables outputting all intermediates from functional control flow.

  While `fn` runs,
  `control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE` is
  set to True; its previous value is restored afterwards, even if `fn` raises.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    previous_override = (
        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
    control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = (
          previous_override)

  return wrapper
def assert_no_new_pyobjects_executing_eagerly(
    warmup_iters: int = 2,
) -> Callable[[Callable[..., Any]], Callable[..., None]]:
  """Decorator for asserting that no new Python objects persist after a test.

  Returns a decorator that runs the test multiple times executing eagerly,
  first as a warmup and then to let objects accumulate. The warmup helps ignore
  caches which do not grow as the test is run repeatedly.

  Useful for checking that there are no missing Py_DECREFs in the C exercised by
  a bit of Python.

  Garbage collection is disabled while the decorated test runs and is always
  re-enabled afterwards, including when the test or one of the checks raises.

  Args:
    warmup_iters: The number of warmup iterations, excluded from measuring.

  Returns:
    A decorator function which can be applied to the test function.
  """

  def wrap_f(f: Callable[..., Any]) -> Callable[..., None]:
    def decorator(self: "TensorFlowTestCase", *args, **kwargs) -> None:
      """Warms up, gets object counts, runs the test, checks for new objects."""
      with context.eager_mode():
        gc.disable()
        try:
          # Python 3.11 removed "errors" and "skipped" as members of
          # unittest.case._Outcome so get them from the test result object
          # instead.
          test_errors = None
          test_skipped = None
          if hasattr(self._outcome, "errors"):
            test_errors = self._outcome.errors
            test_skipped = self._outcome.skipped
          else:
            test_errors = self._outcome.result.errors
            test_skipped = self._outcome.result.skipped
          # Run the test `warmup_iters` times (2 by default) as warmup, in an
          # attempt to fill up caches, which should not grow as the test is
          # run repeatedly below.
          #
          # TODO(b/117156879): Running warmup twice is black magic; we have
          # seen tests that fail with 1 warmup run, and pass with 2, on
          # various versions of python2.7.x.
          for _ in range(warmup_iters):
            f(self, *args, **kwargs)
          # Since we aren't in the normal test lifecycle, we need to manually
          # run cleanups to clear out their object references.
          self.doCleanups()

          # Some objects are newly created by _get_object_count_by_type().  So
          # create and save as a dummy variable to include it as a baseline.
          obj_count_by_type = _get_object_count_by_type()
          gc.collect()

          # Make sure any registered functions are cleaned up in the C++
          # runtime.
          registered_function_names = context.context().list_function_names()

          # unittest.doCleanups adds to self._outcome with each unwound call.
          # These objects are retained across gc collections so we exclude
          # them from the object count calculation.
          obj_count_by_type = _get_object_count_by_type(
              exclude=gc.get_referents(test_errors, test_skipped))

          if ops.has_default_graph():
            collection_sizes_before = {
                collection: len(ops.get_collection(collection))
                for collection in ops.get_default_graph().collections
            }
          for _ in range(3):
            f(self, *args, **kwargs)
          # Since we aren't in the normal test lifecycle, we need to manually
          # run cleanups to clear out their object references.
          self.doCleanups()
          # Note that gc.get_objects misses anything that isn't subject to
          # garbage collection (C types). Collections are a common source of
          # leaks, so we test for collection sizes explicitly.
          if ops.has_default_graph():
            for collection_key in ops.get_default_graph().collections:
              collection = ops.get_collection(collection_key)
              size_before = collection_sizes_before.get(collection_key, 0)
              if len(collection) > size_before:
                raise AssertionError(
                    ("Collection %s increased in size from "
                     "%d to %d (current items %s).") %
                    (collection_key, size_before, len(collection), collection))
              # Make sure our collection checks don't show up as leaked memory
              # by removing references to temporary variables.
              del collection
              del collection_key
              del size_before
            del collection_sizes_before
          gc.collect()

          # There should be no new Python objects hanging around.
          obj_count_by_type = (
              _get_object_count_by_type(
                  exclude=gc.get_referents(test_errors, test_skipped)) -
              obj_count_by_type)

          # There should be no newly registered functions hanging around.
          leftover_functions = (
              context.context().list_function_names() -
              registered_function_names)
          assert not leftover_functions, (
              "The following functions were newly created: %s" %
              leftover_functions)

          # In some cases (specifically on MacOS), new_count is somehow
          # smaller than previous_count.
          # Using plain assert because not all classes using this decorator
          # have assertLessEqual
          assert not obj_count_by_type, (
              "The following objects were newly created: %s" %
              str(obj_count_by_type))
        finally:
          # Turn collection back on even when the test or a check above
          # raises; otherwise a single failure would leave the garbage
          # collector disabled for the rest of the process.
          gc.enable()

    return tf_decorator.make_decorator(f, decorator)

  return wrap_f
def assert_no_new_tensors(f: _F) -> _F:
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  Clears the caches that it knows about, runs the garbage collector, then checks
  that there are no Tensor or Tensor-like objects still around. This includes
  Tensors to which something still has a reference (e.g. from missing
  Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
  of the objects has __del__ defined).

  Args:
    f: The test case to run.

  Returns:
    The decorated test case.
  """

  def decorator(self: "TensorFlowTestCase", **kwargs):
    """Snapshots existing Tensors, runs the test, then looks for new ones."""
    tracked_types = (tensor_lib.Tensor, variables.Variable,
                     tensor_shape.Dimension, tensor_shape.TensorShape)

    def _is_tensorflow_object(obj) -> bool:
      try:
        return isinstance(obj, tracked_types)
      except (ReferenceError, AttributeError):
        # If the object no longer exists, we don't care about it.
        return False

    # Remember ids only, so the snapshot keeps no references to the objects.
    ids_before = {
        id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)
    }
    was_eager_outside = cast(bool, context.executing_eagerly())
    # Run the test in a new graph so that collections get cleared when it's
    # done, but inherit the graph key so optimizers behave.
    inherited_graph_key = ops.get_default_graph()._graph_key
    with ops.Graph().as_default():
      ops.get_default_graph()._graph_key = inherited_graph_key
      if not was_eager_outside:
        result = f(self, **kwargs)
      else:
        with context.eager_mode():
          result = f(self, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    context.context()._clear_caches()  # pylint: disable=protected-access
    gc.collect()
    leaked = [
        obj for obj in gc.get_objects()
        if _is_tensorflow_object(obj) and id(obj) not in ids_before
    ]
    if leaked:
      raise AssertionError(("%d Tensors not deallocated after test: %s" % (
          len(leaked),
          str(leaked),
      )))
    return result

  return tf_decorator.make_decorator(f, decorator)
def _find_reference_cycle(objects: Sequence[Any], idx: int) -> bool:
def get_ignore_reason(obj: Any, denylist: Collection[Any]) -> Optional[str]:
"""Tests whether an object should be omitted from the dependency graph."""
if len(denylist) > 100:
return "<depth limit>"
if tf_inspect.isframe(obj):
if "test_util.py" in tf_inspect.getframeinfo(obj)[0]:
return "<test code>"
for b in denylist:
if b is obj:
return "<test code>"
if obj is denylist:
return "<test code>"
return None
# Note: this function is meant to help with diagnostics. Its output is purely
# a human-readable representation, so you may freely modify it to suit your
# needs.
def describe(
obj: Any, denylist: Collection[Any], leaves_only: bool = False,
) -> str:
"""Returns a custom human-readable summary of obj.
Args:
obj: the value to describe.
denylist: same as denylist in get_ignore_reason.
leaves_only: boolean flag used when calling describe recursively. Useful
for summarizing collections.
"""
if get_ignore_reason(obj, denylist):
return "{}{}".format(get_ignore_reason(obj, denylist), type(obj))
if tf_inspect.isframe(obj):
return "frame: {}".format(tf_inspect.getframeinfo(obj))
elif tf_inspect.ismodule(obj):
return "module: {}".format(obj.__name__)
else:
if leaves_only:
return "{}, {}".format(type(obj), id(obj))
elif isinstance(obj, list):
return "list({}): {}".format(
id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
elif isinstance(obj, tuple):
return "tuple({}): {}".format(
id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
elif isinstance(obj, dict):
return "dict({}): {} keys".format(id(obj), len(obj.keys()))
elif tf_inspect.isfunction(obj):
return "function({}) {}; globals ID: {}".format(
id(obj), obj.__name__, id(obj.__globals__))
else:
return "{}, {}".format(type(obj), id(obj))
def build_ref_graph(
obj: Any,
graph: dict[int, list[int]],
reprs: dict[int, str],
denylist: tuple[Any, ...],
) -> None:
"""Builds a reference graph as <referrer> -> <list of referents>.
Args:
obj: The object to start from. The graph will be built by recursively
adding its referrers.
graph: Dict holding the graph to be built. To avoid creating extra
references, the graph holds object IDs rather than actual objects.
reprs: Auxiliary structure that maps object IDs to their human-readable
description.
denylist: List of objects to ignore.
"""
referrers = gc.get_referrers(obj)
denylist = denylist + (referrers,)
obj_id = id(obj)
for r in referrers:
if get_ignore_reason(r, denylist) is None:
r_id = id(r)
if r_id not in graph:
graph[r_id] = []
if obj_id not in graph[r_id]:
graph[r_id].append(obj_id)
build_ref_graph(r, graph, reprs, denylist)
reprs[r_id] = describe(r, denylist)
def find_cycle(
el: int,
graph: dict[int, list[int]],
reprs: dict[int, str],
path: tuple[int, ...],
) -> Optional[bool]:
"""Finds and prints a single cycle in the dependency graph."""
if el not in graph:
return