xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/BUILD (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1load("//tensorflow:tensorflow.bzl", "pytype_strict_library")
2load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test")
3load("//tensorflow:tensorflow.bzl", "cuda_py_test")
4load("//tensorflow/tsl/platform/default:distribute.bzl", "distribute_py_test")
5load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
6
7package(  # Defaults for every target below: visibility (unless a target sets its own) and license.
8    default_visibility = [
9        "//tensorflow:internal",
10        "//third_party/py/keras:__subpackages__",  # TODO(scottzhu): remove this once keras is relying on tf.__internal__.
11    ],
12    licenses = ["notice"],
13)
14
15py_library(
16    name = "distribute_test_lib_pip",  # No srcs: only bundles the test helpers listed in deps; "_pip" suggests pip-package use — confirm.
17    srcs_version = "PY3",
18    deps = [
19        ":combinations",
20        ":multi_worker_test_base",
21        ":single_loss_example",
22        ":strategy_combinations",
23        ":strategy_test_lib",
24        ":test_util",
25        ":values_v2",
26        "//tensorflow/python/distribute/v1:all_reduce",
27    ],
28)
29
30py_library(
31    name = "cross_device_ops",  # cross_device_ops.py, layered on :cross_device_utils and :collective_util.
32    srcs = ["cross_device_ops.py"],
33    srcs_version = "PY3",
34    deps = [
35        ":collective_util",
36        ":cross_device_utils",
37        ":device_util",
38        ":distribute_utils",
39        ":ps_values",
40        ":reduce_util",
41        ":tpu_values",
42        ":values",
43        ":values_util",
44        "//tensorflow/python:array_ops",
45        "//tensorflow/python:framework_ops",
46        "//tensorflow/python:kernels",
47        "//tensorflow/python:math_ops",
48        "//tensorflow/python:platform",
49        "//tensorflow/python:resource_variable_ops",
50        "//tensorflow/python:tensor_util",
51        "//tensorflow/python/client:device_lib",
52        "//tensorflow/python/eager:context",
53        "//tensorflow/python/eager:def_function",
54        "//tensorflow/python/framework:indexed_slices",
55        "//tensorflow/python/util",
56        "//tensorflow/python/util:tf_export",
57        "//tensorflow/tools/docs:doc_controls",
58        "@six_archive//:six",
59    ],
60)
61
62pytype_strict_library(
63    name = "cross_device_utils",  # Built with pytype_strict_library (unlike the py_library targets here); pulls in collective_ops and nccl_ops.
64    srcs = ["cross_device_utils.py"],
65    srcs_version = "PY3",
66    deps = [
67        ":collective_util",
68        ":values",
69        "//tensorflow/python:array_ops",
70        "//tensorflow/python:collective_ops",
71        "//tensorflow/python:control_flow_ops",
72        "//tensorflow/python:dtypes",
73        "//tensorflow/python:framework_ops",
74        "//tensorflow/python:math_ops",
75        "//tensorflow/python:nccl_ops",
76        "//tensorflow/python:platform",
77        "//tensorflow/python:resource_variable_ops",
78        "//tensorflow/python/eager:backprop",
79        "//tensorflow/python/eager:context",
80        "//tensorflow/python/framework:indexed_slices",
81        "//tensorflow/python/framework:tensor_spec",
82        "//tensorflow/python/types",
83    ],
84)
85
86py_library(
87    name = "device_util",  # device_util.py; depends only on device, framework_ops and the eager context.
88    srcs = ["device_util.py"],
89    srcs_version = "PY3",
90    deps = [
91        "//tensorflow/python:device",
92        "//tensorflow/python:framework_ops",
93        "//tensorflow/python/eager:context",
94    ],
95)
96
97cuda_py_test(
98    name = "device_util_test",  # cuda_py_test for :device_util; also uses the multi-worker test helpers.
99    srcs = ["device_util_test.py"],
100    python_version = "PY3",
101    deps = [
102        ":combinations",
103        ":device_util",
104        ":multi_worker_test_base",
105        ":multi_worker_util",
106        "//tensorflow/core:protos_all_py",
107        "//tensorflow/python:client_testlib",
108        "//tensorflow/python:extra_py_tests_deps",
109        "//tensorflow/python:framework_ops",
110        "//tensorflow/python:training_server_lib",
111        "//tensorflow/python/eager:context",
112        "@absl_py//absl/testing:parameterized",
113    ],
114)
115
116py_library(
117    name = "distribute",  # No srcs: umbrella over the deps below. NOTE(review): includes :multi_worker_test_base — confirm a test helper belongs here.
118    srcs_version = "PY3",
119    deps = [
120        ":cross_device_ops",
121        ":distribute_lib",
122        ":merge_call_interim",
123        ":mirrored_strategy",
124        ":multi_process_runner",
125        ":multi_worker_test_base",
126        ":one_device_strategy",
127        ":parameter_server_strategy_v2",
128        ":sharded_variable",
129        "//tensorflow/python/distribute/coordinator:cluster_coordinator",
130        "//tensorflow/python/distribute/experimental",
131        "//tensorflow/python/distribute/failure_handling:failure_handling_lib",
132    ],
133)
134
135py_library(
136    name = "distribute_lib",  # Two sources: distribute_lib.py and distribution_strategy_context.py.
137    srcs = [
138        "distribute_lib.py",
139        "distribution_strategy_context.py",
140    ],
141    srcs_version = "PY3",
142    deps = [
143        ":collective_util",
144        ":device_util",
145        ":numpy_dataset",
146        ":reduce_util",
147        "//tensorflow/python:array_ops",
148        "//tensorflow/python:constant_op",
149        "//tensorflow/python:control_flow_ops",
150        "//tensorflow/python:dtypes",
151        "//tensorflow/python:framework_ops",
152        "//tensorflow/python:platform",
153        "//tensorflow/python:resource_variable_ops",
154        "//tensorflow/python:state_ops",
155        "//tensorflow/python:summary_ops_v2",
156        "//tensorflow/python:util",
157        "//tensorflow/python:variable_scope",
158        "//tensorflow/python/data/ops:dataset_ops",
159        "//tensorflow/python/ops/losses",
160        "//tensorflow/tools/docs:doc_controls",
161    ],
162)
163
164py_test(
165    name = "distribute_lib_test",  # Small py_test for :distribute_lib.
166    size = "small",
167    srcs = ["distribute_lib_test.py"],
168    python_version = "PY3",
169    srcs_version = "PY3",
170    deps = [
171        ":combinations",
172        ":distribute_lib",
173        ":input_lib",
174        ":reduce_util",
175        ":values",
176        "//tensorflow/python:client_testlib",
177        "//tensorflow/python:constant_op",
178        "//tensorflow/python:dtypes",
179        "//tensorflow/python:util",
180        "//tensorflow/python:variable_scope",
181        "//tensorflow/python:variables",
182        "//tensorflow/python/autograph/core:test_lib",
183        "//tensorflow/python/data/ops:dataset_ops",
184        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
185    ],
186)
187
188py_library(
189    name = "distribute_config",  # Standalone library: declares no deps.
190    srcs = [
191        "distribute_config.py",
192    ],
193    srcs_version = "PY3",
194    deps = [],
195)
196
197py_library(
198    name = "distribute_coordinator",  # distribute_coordinator.py; uses :multi_worker_util, session and the training server lib.
199    srcs = [
200        "distribute_coordinator.py",
201    ],
202    srcs_version = "PY3",
203    deps = [
204        ":distribute_coordinator_context",
205        ":multi_worker_util",
206        "//tensorflow/core:protos_all_py",
207        "//tensorflow/python:platform",
208        "//tensorflow/python:session",
209        "//tensorflow/python:training_server_lib",
210    ],
211)
212
213py_test(
214    name = "distribute_coordinator_test",  # Medium py_test for :distribute_coordinator; tagged notsan (see TODO below).
215    size = "medium",
216    srcs = ["distribute_coordinator_test.py"],
217    python_version = "PY3",
218    srcs_version = "PY3",
219    tags = [
220        "notsan",  # TODO(b/220133218)
221    ],
222    deps = [
223        ":distribute_coordinator",
224        "//tensorflow/core:protos_all_py",
225        "//tensorflow/python:client_testlib",
226        "//tensorflow/python:control_flow_ops",
227        "//tensorflow/python:distributed_framework_test_lib",
228        "//tensorflow/python:framework_ops",
229        "//tensorflow/python:framework_test_lib",
230        "//tensorflow/python:math_ops",
231        "//tensorflow/python:session",
232        "//tensorflow/python:training",
233        "//tensorflow/python:variable_scope",
234        "//tensorflow/python:variables",
235    ],
236)
237
238py_library(
239    name = "distribute_coordinator_context",  # Standalone library: declares no deps.
240    srcs = [
241        "distribute_coordinator_context.py",
242    ],
243    srcs_version = "PY3",
244    deps = [],
245)
246
247py_library(
248    name = "mirrored_run",  # mirrored_run.py; a dep of :mirrored_strategy and :parameter_server_strategy.
249    srcs = ["mirrored_run.py"],
250    srcs_version = "PY3",
251    deps = [
252        ":device_util",
253        ":distribute_lib",
254        ":reduce_util",
255        ":shared_variable_creator",
256        ":values",
257        "//tensorflow/python:array_ops",
258        "//tensorflow/python:config",
259        "//tensorflow/python:constant_op",
260        "//tensorflow/python:device",
261        "//tensorflow/python:dtypes",
262        "//tensorflow/python:framework_ops",
263        "//tensorflow/python:platform",
264        "//tensorflow/python:pywrap_tfe",
265        "//tensorflow/python:summary_ops_v2",
266        "//tensorflow/python:tensor_util",
267        "//tensorflow/python:training",
268        "//tensorflow/python:util",
269        "//tensorflow/python:variable_scope",
270        "//tensorflow/python/autograph/core",
271        "//tensorflow/python/autograph/impl",
272        "//tensorflow/python/eager:context",
273        "//tensorflow/python/eager:def_function",
274        "//tensorflow/python/util:tf_export",
275    ],
276)
277
278py_library(
279    name = "distribute_utils",  # NOTE(review): deps largely duplicate :mirrored_run's list (autograph, pywrap_tfe, summary_ops_v2, ...); confirm distribute_utils.py imports each.
280    srcs = ["distribute_utils.py"],
281    srcs_version = "PY3",
282    deps = [
283        ":device_util",
284        ":distribute_lib",
285        ":ps_values",
286        ":reduce_util",
287        ":sharded_variable",
288        ":shared_variable_creator",
289        ":tpu_values",
290        ":values",
291        "//tensorflow/python:array_ops",
292        "//tensorflow/python:config",
293        "//tensorflow/python:constant_op",
294        "//tensorflow/python:device",
295        "//tensorflow/python:dtypes",
296        "//tensorflow/python:framework_ops",
297        "//tensorflow/python:platform",
298        "//tensorflow/python:pywrap_tfe",
299        "//tensorflow/python:summary_ops_v2",
300        "//tensorflow/python:tensor_util",
301        "//tensorflow/python:util",
302        "//tensorflow/python:variable_scope",
303        "//tensorflow/python/autograph/core",
304        "//tensorflow/python/autograph/impl",
305        "//tensorflow/python/eager:context",
306        "//tensorflow/python/eager:def_function",
307        "//tensorflow/python/util:tf_export",
308    ],
309)
310
311py_library(
312    name = "tpu_util",  # NOTE(review): unlike most py_library targets in this file, srcs_version is not set.
313    srcs = ["tpu_util.py"],
314    deps = [
315        ":packed_distributed_variable",
316        "//tensorflow/python:framework_ops",
317        "//tensorflow/python/eager:context",
318        "//tensorflow/python/tpu:tpu_py",
319    ],
320)
321
322py_library(
323    name = "mirrored_strategy",  # mirrored_strategy.py; depends on :mirrored_run, :cross_device_ops and both input_lib targets.
324    srcs = ["mirrored_strategy.py"],
325    srcs_version = "PY3",
326    deps = [
327        ":collective_util",
328        ":cross_device_ops",
329        ":device_util",
330        ":distribute_lib",
331        ":distribute_utils",
332        ":input_lib",
333        ":input_util",
334        ":mirrored_run",
335        ":multi_worker_util",
336        ":numpy_dataset",
337        ":reduce_util",
338        ":values",
339        "//tensorflow/python:array_ops",
340        "//tensorflow/python:constant_op",
341        "//tensorflow/python:control_flow_ops",
342        "//tensorflow/python:device",
343        "//tensorflow/python:dtypes",
344        "//tensorflow/python:framework_ops",
345        "//tensorflow/python:util",
346        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
347        "//tensorflow/python/distribute/v1:input_lib",
348        "//tensorflow/python/eager:context",
349        "//tensorflow/python/eager:tape",
350    ],
351)
352
353py_library(
354    name = "parameter_server_strategy",  # Overrides the package default visibility: //tensorflow:internal only (no keras grant).
355    srcs = ["parameter_server_strategy.py"],
356    srcs_version = "PY3",
357    visibility = ["//tensorflow:internal"],
358    deps = [
359        ":cross_device_ops",
360        ":device_util",
361        ":distribute_lib",
362        ":distribute_utils",
363        ":input_lib",
364        ":input_util",
365        ":mirrored_run",
366        ":multi_worker_util",
367        ":numpy_dataset",
368        ":ps_values",
369        ":values",
370        "//tensorflow/python:array_ops",
371        "//tensorflow/python:device",
372        "//tensorflow/python:framework_ops",
373        "//tensorflow/python:platform",
374        "//tensorflow/python:resource_variable_ops",
375        "//tensorflow/python:training",
376        "//tensorflow/python:util",
377        "//tensorflow/python:variable_scope",
378        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
379        "//tensorflow/python/distribute/v1:input_lib",
380        "//tensorflow/python/eager:context",
381        "//tensorflow/python/util:tf_export",
382    ],
383)
384
385py_library(
386    name = "central_storage_strategy",  # Built on :parameter_server_strategy; visible to //tensorflow:internal only.
387    srcs = ["central_storage_strategy.py"],
388    srcs_version = "PY3",
389    visibility = ["//tensorflow:internal"],
390    deps = [
391        ":device_util",
392        ":distribute_lib",
393        ":parameter_server_strategy",
394        "//tensorflow/python:util",
395    ],
396)
397
398py_library(
399    name = "one_device_strategy",  # one_device_strategy.py; visible to //tensorflow:internal only (overrides package default).
400    srcs = ["one_device_strategy.py"],
401    srcs_version = "PY3",
402    visibility = ["//tensorflow:internal"],
403    deps = [
404        ":distribute_lib",
405        ":input_lib",
406        ":input_util",
407        ":numpy_dataset",
408        ":reduce_util",
409        ":values",
410        "//tensorflow/python:array_ops",
411        "//tensorflow/python:dtypes",
412        "//tensorflow/python:framework_ops",
413        "//tensorflow/python:math_ops",
414        "//tensorflow/python/distribute/v1:input_lib",
415        "//tensorflow/python/eager:context",
416        "@six_archive//:six",
417    ],
418)
419
420py_library(
421    name = "collective_all_reduce_strategy",  # Depends on :mirrored_strategy and collective_ops; visible to //tensorflow:internal only.
422    srcs = ["collective_all_reduce_strategy.py"],
423    srcs_version = "PY3",
424    visibility = ["//tensorflow:internal"],
425    deps = [
426        ":collective_util",
427        ":cross_device_ops",
428        ":cross_device_utils",
429        ":device_util",
430        ":distribute_lib",
431        ":distribute_utils",
432        ":input_lib",
433        ":input_util",
434        ":mirrored_strategy",
435        ":multi_worker_util",
436        ":numpy_dataset",
437        ":reduce_util",
438        ":values",
439        "//tensorflow/core:protos_all_py",
440        "//tensorflow/python:array_ops",
441        "//tensorflow/python:collective_ops",
442        "//tensorflow/python:control_flow_util",
443        "//tensorflow/python:errors",
444        "//tensorflow/python:framework_ops",
445        "//tensorflow/python:platform",
446        "//tensorflow/python:training",
447        "//tensorflow/python:util",
448        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
449        "//tensorflow/python/distribute/v1:input_lib",
450        "//tensorflow/python/eager:context",
451        "//tensorflow/python/framework:device",
452        "//tensorflow/python/tpu:tpu_strategy_util",
453        "//tensorflow/python/util:tf_export",
454    ],
455)
456
457py_library(
458    name = "multi_worker_util",  # Minimal deps: core protos and the training server lib only.
459    srcs = [
460        "multi_worker_util.py",
461    ],
462    srcs_version = "PY3",
463    deps = [
464        "//tensorflow/core:protos_all_py",
465        "//tensorflow/python:training_server_lib",
466    ],
467)
468
469cuda_py_test(
470    name = "multi_worker_continuous_run_test",  # Tagged no_windows/noasan/nomsan/notsan; each tag carries a tracking bug below.
471    srcs = [
472        "multi_worker_continuous_run_test.py",
473    ],
474    python_version = "PY3",
475    tags = [
476        "no_windows",  # TODO(b/184424727): Re-enable this.
477        "noasan",  # TODO(b/180630068)
478        "nomsan",  # TODO(b/180630068)
479        "notsan",  # TODO(b/151841995)
480    ],
481    deps = [
482        ":collective_all_reduce_strategy",
483        ":combinations",
484        ":multi_process_runner",
485        ":multi_worker_test_base",
486        ":reduce_util",
487        ":strategy_combinations",
488        "//tensorflow/python:array_ops",
489        "//tensorflow/python:errors",
490        "//tensorflow/python:framework_ops",
491        "//tensorflow/python:framework_test_lib",
492        "//tensorflow/python:math_ops",
493        "//tensorflow/python/data/ops:dataset_ops",
494        "//tensorflow/python/eager:context",
495        "//tensorflow/python/eager:test",
496        "//third_party/py/numpy",
497        "@absl_py//absl/testing:parameterized",
498    ],
499)
500
501py_library(
502    name = "numpy_dataset",  # numpy_dataset.py; depends on numpy and tf.data dataset_ops.
503    srcs = ["numpy_dataset.py"],
504    srcs_version = "PY3",
505    deps = [
506        "//tensorflow/python:array_ops",
507        "//tensorflow/python:dtypes",
508        "//tensorflow/python:framework_ops",
509        "//tensorflow/python:util",
510        "//tensorflow/python:variable_scope",
511        "//tensorflow/python/data/ops:dataset_ops",
512        "//tensorflow/python/eager:context",
513        "//third_party/py/numpy",
514    ],
515)
516
517py_test(
518    name = "numpy_dataset_test",  # Small py_test for :numpy_dataset.
519    size = "small",
520    srcs = ["numpy_dataset_test.py"],
521    python_version = "PY3",
522    srcs_version = "PY3",
523    deps = [
524        ":numpy_dataset",
525        "//tensorflow/python:framework_test_lib",
526        "//tensorflow/python:variable_scope",
527        "//tensorflow/python/eager:test",
528        "//third_party/py/numpy",
529    ],
530)
531
532py_library(
533    name = "input_lib",  # input_lib.py; separate from //tensorflow/python/distribute/v1:input_lib, which the strategies also depend on.
534    srcs = ["input_lib.py"],
535    srcs_version = "PY3",
536    deps = [
537        ":device_util",
538        ":distribute_lib",
539        ":distribute_utils",
540        ":input_ops",
541        ":reduce_util",
542        ":values",
543        "//tensorflow/python:array_ops",
544        "//tensorflow/python:control_flow_ops",
545        "//tensorflow/python:framework_ops",
546        "//tensorflow/python:math_ops",
547        "//tensorflow/python:sparse_tensor",
548        "//tensorflow/python/data/experimental/ops:batching",
549        "//tensorflow/python/data/experimental/ops:cardinality",
550        "//tensorflow/python/data/experimental/ops:distribute",
551        "//tensorflow/python/data/ops:dataset_ops",
552        "//tensorflow/python/data/ops:iterator_ops",
553        "//tensorflow/python/data/ops:multi_device_iterator_ops",
554        "//tensorflow/python/data/ops:optional_ops",
555        "//tensorflow/python/eager:context",
556        "//tensorflow/python/eager:monitoring",
557        "//tensorflow/python/framework:composite_tensor",
558        "//tensorflow/python/framework:device",
559        "//tensorflow/python/framework:dtypes",
560        "//tensorflow/python/framework:errors",
561        "//tensorflow/python/framework:tensor_shape",
562        "//tensorflow/python/framework:tensor_util",
563        "//tensorflow/python/framework:type_spec",
564        "//tensorflow/python/ops/ragged:ragged_tensor",
565        "//tensorflow/python/platform",
566        "//tensorflow/python/types",
567        "//tensorflow/python/util",
568        "//tensorflow/python/util:tf_export",
569        "//tensorflow/tools/docs:doc_controls",
570        "@six_archive//:six",
571    ],
572)
573
574py_library(
575    name = "input_ops",  # input_ops.py; depends only on framework_ops and tf.data's nest util.
576    srcs = ["input_ops.py"],
577    srcs_version = "PY3",
578    deps = [
579        "//tensorflow/python:framework_ops",
580        "//tensorflow/python/data/util:nest",
581    ],
582)
583
584cuda_py_test(
585    name = "input_ops_test",  # cuda_py_test for :input_ops.
586    srcs = ["input_ops_test.py"],
587    python_version = "PY3",
588    deps = [
589        ":input_ops",
590        "//tensorflow/python:client_testlib",
591        "//tensorflow/python:errors",
592        "//tensorflow/python:framework_ops",
593        "//tensorflow/python:framework_test_lib",
594        "//tensorflow/python:io_ops",
595        "//tensorflow/python:util",
596        "//tensorflow/python/data/ops:dataset_ops",
597        "//tensorflow/python/data/ops:readers",
598        "//tensorflow/python/data/util:structure",
599    ],
600)
601
602py_test(
603    name = "multi_worker_util_test",  # py_test for :multi_worker_util.
604    srcs = ["multi_worker_util_test.py"],
605    python_version = "PY3",
606    srcs_version = "PY3",
607    deps = [
608        ":multi_worker_util",
609        "//tensorflow/core:protos_all_py",
610        "//tensorflow/python:constant_op",
611        "//tensorflow/python:framework_ops",
612        "//tensorflow/python:framework_test_lib",
613        "//tensorflow/python:math_ops",
614        "//tensorflow/python:training",
615        "//tensorflow/python/eager:test",
616        "//third_party/py/numpy",
617        "@absl_py//absl/testing:parameterized",
618    ],
619)
620
621py_library(
622    name = "tpu_strategy",  # tpu_strategy.py; visible to //tensorflow:internal only; pulls in the TPU libraries and XLA sharding.
623    srcs = ["tpu_strategy.py"],
624    srcs_version = "PY3",
625    visibility = ["//tensorflow:internal"],
626    deps = [
627        ":cross_device_ops",
628        ":device_util",
629        ":distribute_lib",
630        ":distribute_utils",
631        ":input_lib",
632        ":input_util",
633        ":numpy_dataset",
634        ":reduce_util",
635        ":tpu_replicated_variable",
636        ":tpu_util",
637        ":tpu_values",
638        ":values",
639        "//tensorflow/compiler/xla/experimental/xla_sharding",
640        "//tensorflow/python:array_ops",
641        "//tensorflow/python:constant_op",
642        "//tensorflow/python:control_flow_ops",
643        "//tensorflow/python:device_spec",
644        "//tensorflow/python:dtypes",
645        "//tensorflow/python:framework_ops",
646        "//tensorflow/python:math_ops",
647        "//tensorflow/python:resource_variable_ops",
648        "//tensorflow/python:tensor_shape",
649        "//tensorflow/python:tensor_util",
650        "//tensorflow/python:util",
651        "//tensorflow/python:variables",
652        "//tensorflow/python/autograph/core",
653        "//tensorflow/python/autograph/impl",
654        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
655        "//tensorflow/python/distribute/v1:input_lib",
656        "//tensorflow/python/eager:context",
657        "//tensorflow/python/eager:def_function",
658        "//tensorflow/python/eager:function",
659        "//tensorflow/python/framework:sparse_tensor",
660        "//tensorflow/python/ops/ragged:ragged_tensor",
661        "//tensorflow/python/tpu:device_assignment",
662        "//tensorflow/python/tpu:tpu_hardware_feature",
663        "//tensorflow/python/tpu:tpu_lib",
664        "//tensorflow/python/tpu:tpu_py",
665        "//tensorflow/python/tpu/ops",
666        "//tensorflow/python/util:tf_export",
667        "//third_party/py/numpy",
668        "@absl_py//absl/logging",
669    ],
670)
671
672distribute_py_test(
673    name = "random_generator_test",  # Sharded 12 ways; XLA strict auto-jit disabled (see comment below).
674    srcs = ["random_generator_test.py"],
675    main = "random_generator_test.py",
676    shard_count = 12,
677    tags = [
678        "multi_and_single_gpu",
679        "no_cuda_asan",  # b/213388775
680        "no_oss",  # b/241013307
681    ],
682    tpu_tags = [
683        "no_oss",
684    ],
685    xla_enable_strict_auto_jit = False,  # PSStrategy doesn't work on _xla tests
686    deps = [  # Same-package targets use short ":" labels, matching the rest of this file.
687        ":combinations",
688        ":strategy_combinations",
689        "//tensorflow/python:client_testlib",
690        "//tensorflow/python:stateful_random_ops",
691        "//tensorflow/python:util",
692        "//tensorflow/python/compat:v2_compat",
693    ],
694)
695
696tpu_py_test(
697    name = "tpu_strategy_test",  # tpu_py_test for :tpu_strategy; disable_experimental = True, tagged no_oss.
698    srcs = ["tpu_strategy_test.py"],
699    disable_experimental = True,
700    python_version = "PY3",
701    tags = ["no_oss"],
702    deps = [
703        ":strategy_test_lib",
704        ":tpu_strategy",
705        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
706        "//tensorflow/python/eager:remote",
707        "//tensorflow/python/eager:test",
708        "@absl_py//absl/testing:parameterized",
709    ],
710)
711
712tpu_py_test(
713    name = "tpu_strategy_compilation_test",  # tpu_py_test for :tpu_strategy; unlike tpu_strategy_test it sets disable_mlir_bridge = False.
714    srcs = ["tpu_strategy_compilation_test.py"],
715    disable_experimental = True,
716    disable_mlir_bridge = False,
717    python_version = "PY3",
718    tags = ["no_oss"],
719    deps = [
720        ":tpu_strategy",
721        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
722        "//tensorflow/python/eager:remote",
723        "//tensorflow/python/eager:test",
724    ],
725)
726
727# Used only by estimator.
728py_library(
729    name = "estimator_training",  # Depends only on the two distribute-coordinator libraries and training.
730    srcs = [
731        "estimator_training.py",
732    ],
733    srcs_version = "PY3",
734    deps = [
735        ":distribute_coordinator",
736        ":distribute_coordinator_context",
737        "//tensorflow/python:training",
738    ],
739)
740
741py_library(
742    name = "reduce_util",  # reduce_util.py; a dep of most libraries in this package.
743    srcs = ["reduce_util.py"],
744    srcs_version = "PY3",
745    deps = [
746        "//tensorflow/python:util",
747        "//tensorflow/python:variable_scope",
748    ],
749)
750
751py_library(
752    name = "collective_util",  # NOTE(review): deps are identical to :reduce_util's; confirm collective_util.py needs both.
753    srcs = ["collective_util.py"],
754    srcs_version = "PY3",
755    deps = [
756        "//tensorflow/python:util",
757        "//tensorflow/python:variable_scope",
758    ],
759)
760
761tf_py_test(
762    name = "collective_util_test",  # tf_py_test for :collective_util.
763    srcs = ["collective_util_test.py"],
764    deps = [
765        ":collective_util",
766        "//tensorflow/python/eager:test",
767    ],
768)
769
770py_library(
771    name = "shared_variable_creator",  # shared_variable_creator.py; declares no deps.
772    srcs = ["shared_variable_creator.py"],
773    srcs_version = "PY3",
774)
775
776py_test(
777    name = "shared_variable_creator_test",  # py_test for :shared_variable_creator.
778    srcs = ["shared_variable_creator_test.py"],
779    python_version = "PY3",
780    srcs_version = "PY3",
781    deps = [
782        ":shared_variable_creator",
783        "//tensorflow/python:framework_test_lib",
784        "//tensorflow/python:variable_scope",
785        "//tensorflow/python/eager:test",
786    ],
787)
788
789py_library(
790    name = "summary_op_util",  # summary_op_util.py; depends on :distribute_lib.
791    srcs = ["summary_op_util.py"],
792    srcs_version = "PY3",
793    deps = [
794        ":distribute_lib",
795        "//tensorflow/python:framework_ops",
796        "//tensorflow/python:tensor_util",
797    ],
798)
799
800py_library(
801    name = "packed_distributed_variable",  # packed_distributed_variable.py; a dep of :values, :tpu_util and :tpu_values.
802    srcs = ["packed_distributed_variable.py"],
803    srcs_version = "PY3",
804    deps = [
805        ":device_util",
806        "//tensorflow/python:framework_ops",
807        "//tensorflow/python:math_ops",
808        "//tensorflow/python:resource_variable_ops",
809        "//tensorflow/python/eager:context",
810    ],
811)
812
813py_library(
814    name = "values",  # values.py; depends on :distribute_lib plus the saveable-object and saved_model save_context libraries.
815    srcs = ["values.py"],
816    srcs_version = "PY3",
817    deps = [
818        ":device_util",
819        ":distribute_lib",
820        ":packed_distributed_variable",
821        ":reduce_util",
822        ":values_util",
823        "//tensorflow/python:array_ops",
824        "//tensorflow/python:composite_tensor",
825        "//tensorflow/python:control_flow_ops",
826        "//tensorflow/python:framework_ops",
827        "//tensorflow/python:math_ops",
828        "//tensorflow/python:resource_variable_ops",
829        "//tensorflow/python:type_spec",
830        "//tensorflow/python:variable_scope",
831        "//tensorflow/python:variables",
832        "//tensorflow/python/eager:context",
833        "//tensorflow/python/saved_model:save_context",
834        "//tensorflow/python/trackable:base",
835        "//tensorflow/python/training/saving:saveable_object",
836        "//tensorflow/python/training/saving:saveable_object_util",
837        "//tensorflow/python/types",
838        "//tensorflow/python/util:tf_export",
839    ],
840)
841
842py_library(
843    name = "values_v2",  # NOTE(review): srcs_version is not set, unlike most py_library targets in this file.
844    srcs = ["values_v2.py"],
845    deps = [
846        ":device_util",
847        ":distribute_lib",
848        ":reduce_util",
849        ":tpu_util",
850        ":values",
851        ":values_util",
852        "//tensorflow/python:array_ops",
853        "//tensorflow/python:control_flow_ops",
854        "//tensorflow/python:framework_ops",
855        "//tensorflow/python:math_ops",
856        "//tensorflow/python:resource_variable_ops",
857        "//tensorflow/python:variables",
858        "//tensorflow/python/eager:context",
859        "//tensorflow/python/types",
860    ],
861)
862
863distribute_py_test(
864    name = "values_v2_test",  # distribute_py_test for :values_v2; tagged no_mac (b/190644499).
865    srcs = ["values_v2_test.py"],
866    tags = [
867        "no_mac",  # b/190644499
868    ],
869    deps = [
870        ":combinations",
871        ":strategy_combinations",
872        ":test_util",
873        ":values_v2",
874        "//tensorflow/python:framework_ops",
875        "//tensorflow/python:variables",
876        "//tensorflow/python/eager:test",
877        "@absl_py//absl/testing:parameterized",
878    ],
879)
880
881py_library(
882    name = "ps_values",  # ps_values.py. NOTE(review): depends on //tensorflow/python/keras (load_context) from the distribute package — confirm this layering is intended.
883    srcs = ["ps_values.py"],
884    srcs_version = "PY3",
885    deps = [
886        ":distribute_lib",
887        ":values",
888        ":values_util",
889        "//tensorflow/python:framework_ops",
890        "//tensorflow/python:variable_scope",
891        "//tensorflow/python:variables",
892        "//tensorflow/python/distribute/coordinator:coordinator_context",
893        "//tensorflow/python/keras/saving/saved_model:load_context",
894        "//tensorflow/python/trackable:base",
895        "//tensorflow/python/types",
896    ],
897)
898
899py_library(
900    name = "values_util",  # values_util.py; depends on :distribute_lib and saved_model save_context/save_options.
901    srcs = ["values_util.py"],
902    srcs_version = "PY3",
903    deps = [
904        ":distribute_lib",
905        ":reduce_util",
906        "//tensorflow/python:control_flow_ops",
907        "//tensorflow/python:framework_ops",
908        "//tensorflow/python:math_ops",
909        "//tensorflow/python:tensor_util",
910        "//tensorflow/python:variable_scope",
911        "//tensorflow/python/saved_model:save_context",
912        "//tensorflow/python/saved_model:save_options",
913    ],
914)
915
916py_library(
917    name = "tpu_values",  # tpu_values.py; builds on :values, :tpu_util and :tpu_replicated_variable.
918    srcs = ["tpu_values.py"],
919    srcs_version = "PY3",
920    deps = [
921        ":packed_distributed_variable",
922        ":tpu_replicated_variable",
923        ":tpu_util",
924        ":values",
925        ":values_util",
926        "//tensorflow/python:framework_ops",
927        "//tensorflow/python:math_ops",
928        "//tensorflow/python:resource_variable_ops_gen",
929        "//tensorflow/python:variable_scope",
930        "//tensorflow/python/eager:context",
931        "//tensorflow/python/eager:tape",
932    ],
933)
934
935py_library(
936    name = "tpu_replicated_variable",  # tpu_replicated_variable.py; uses XLA sharding and the generated tpu_partition_ops.
937    srcs = ["tpu_replicated_variable.py"],
938    srcs_version = "PY3",
939    deps = [
940        ":tpu_util",
941        "//tensorflow/compiler/xla/experimental/xla_sharding",
942        "//tensorflow/python:control_flow_ops",
943        "//tensorflow/python:resource_variable_ops_gen",
944        "//tensorflow/python:tpu_partition_ops_gen",
945        "//tensorflow/python:variable_scope",
946        "//tensorflow/python:variables",
947        "//tensorflow/python/eager:context",
948        "//tensorflow/python/framework:for_generated_wrappers",
949        "//tensorflow/python/saved_model:save_context",
950        "//tensorflow/python/trackable:base",
951    ],
952)
953
954tpu_py_test(
955    name = "tpu_replicated_variable_test",  # tpu_py_test for :tpu_replicated_variable.
956    srcs = ["tpu_replicated_variable_test.py"],
957    python_version = "PY3",
958    srcs_version = "PY3",
959    deps = [
960        ":tpu_replicated_variable",
961        "//tensorflow/python:variables",
962        "//tensorflow/python/eager:test",
963        "//tensorflow/python/framework:combinations",
964        "//tensorflow/python/framework:dtypes",
965        "//third_party/py/numpy",
966        "@absl_py//absl/testing:parameterized",
967    ],
968)
969
970py_library(
971    name = "combinations",  # combinations.py; visibility additionally grants //tensorflow_models and keras subpackages.
972    srcs = ["combinations.py"],
973    srcs_version = "PY3",
974    visibility = [
975        "//tensorflow:internal",
976        "//tensorflow_models:__subpackages__",
977        "//third_party/py/keras:__subpackages__",
978    ],
979    deps = [
980        ":collective_all_reduce_strategy",
981        ":distribute_lib",
982        ":multi_process_runner",
983        ":multi_worker_test_base",
984        "//tensorflow/python:framework_combinations",
985        "//tensorflow/python:framework_ops",
986        "//tensorflow/python:framework_test_combinations_lib",
987        "//tensorflow/python:framework_test_lib",
988        "//tensorflow/python:platform",
989        "//tensorflow/python:session",
990        "//tensorflow/python:tf_decorator",
991        "//tensorflow/python/eager:context",
992        "//tensorflow/python/eager:def_function",
993        "//tensorflow/python/util:tf_export",
994        "@six_archive//:six",
995    ],
996)
997
# Test for :combinations, declared with distribute_py_test (loaded from
# //tensorflow/tsl/platform/default:distribute.bzl).
distribute_py_test(
    name = "combinations_test",
    srcs = ["combinations_test.py"],
    python_version = "PY3",
    # The meaning of "multi_gpu" is defined by the test infrastructure outside
    # this file (presumably selects multi-GPU runners; confirm).
    tags = ["multi_gpu"],
    deps = [
        ":combinations",
        ":test_util",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_combinations",
        "//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py",
        "//tensorflow/python/eager:context",
        "@absl_py//absl/testing:parameterized",
    ],
)
1013
# Library target for strategy_combinations.py. Depends on the strategy
# libraries in this package (central storage, collective all-reduce, mirrored,
# one-device, parameter-server v2, TPU), on :combinations, and on
# cluster-resolver, remote-eager and TPU helper targets.
py_library(
    name = "strategy_combinations",
    srcs = ["strategy_combinations.py"],
    srcs_version = "PY3",
    # Same widened visibility as :combinations (adds //tensorflow_models to the
    # package default).
    visibility = [
        "//tensorflow:internal",
        "//tensorflow_models:__subpackages__",
        "//third_party/py/keras:__subpackages__",
    ],
    deps = [
        ":central_storage_strategy",
        ":collective_all_reduce_strategy",
        ":combinations",
        ":distribute_lib",
        ":mirrored_strategy",
        ":multi_process_runner",
        ":multi_worker_test_base",
        ":one_device_strategy",
        ":parameter_server_strategy_v2",
        ":sharded_variable",
        ":test_util",
        ":tpu_strategy",
        "//tensorflow/python:platform",
        "//tensorflow/python:tf2",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:remote",
        "//tensorflow/python/framework:test_lib",
        "//tensorflow/python/tpu:device_assignment",
        "//tensorflow/python/tpu:tpu_lib",
        "//tensorflow/python/training:server_lib",
        "//tensorflow/python/util:tf_export",
    ],
)
1049
# Test for :strategy_combinations. It depends directly on each strategy
# library that :strategy_combinations depends on.
distribute_py_test(
    name = "strategy_combinations_test",
    srcs = ["strategy_combinations_test.py"],
    # Passed through to the distribute_py_test macro; its effect is defined in
    # //tensorflow/tsl/platform/default:distribute.bzl.
    disable_mlir_bridge = False,
    python_version = "PY3",
    # NOTE(review): the trailing comment names two bugs for two sanitizer tags,
    # but does not record which bug tracks which tag. Confirm and annotate each
    # tag individually.
    tags = [
        "no_cuda_asan",
        "noasan",
    ],  # TODO(b/195246941) b/196591124
    deps = [
        ":central_storage_strategy",
        ":collective_all_reduce_strategy",
        ":combinations",
        ":mirrored_strategy",
        ":one_device_strategy",
        ":parameter_server_strategy_v2",
        ":reduce_util",
        ":strategy_combinations",
        ":test_util",
        ":tpu_strategy",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:tf2",
        "//tensorflow/python/eager:def_function",
        "@absl_py//absl/testing:parameterized",
    ],
)
1078
# Library target for multi_worker_test_base.py. It depends on
# distribute_coordinator, multi_process_runner, cluster resolvers,
# remote-eager and numpy.
py_library(
    name = "multi_worker_test_base",
    srcs = ["multi_worker_test_base.py"],
    srcs_version = "PY3",
    deps = [
        ":distribute_coordinator",
        ":multi_process_runner",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:distributed_framework_test_lib",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:session",
        "//tensorflow/python:training_lib",
        "//tensorflow/python:util",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:remote",
        "//third_party/py/numpy",
    ],
)
1101
# Test for :multi_worker_test_base, declared with tf_py_test
# (from //tensorflow:tensorflow.bzl). Only the library under test is listed
# as a direct dependency.
tf_py_test(
    name = "multi_worker_test_base_test",
    srcs = ["multi_worker_test_base_test.py"],
    srcs_version = "PY3",
    tags = [
        "no_oss",  # TODO(b/170834611)
    ],
    deps = [
        ":multi_worker_test_base",
    ],
)
1113
# Medium-size test for checkpoint_utils_test.py, declared with cuda_py_test
# and parameterised through :combinations / :strategy_combinations.
cuda_py_test(
    name = "checkpoint_utils_test",
    size = "medium",
    srcs = ["checkpoint_utils_test.py"],
    python_version = "PY3",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":strategy_combinations",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
    ],
)
1132
# distribute_py_test for checkpointing_test.py. It depends on the
# //tensorflow/python/checkpoint package plus the strategy combinations.
distribute_py_test(
    name = "checkpointing_test",
    srcs = ["checkpointing_test.py"],
    main = "checkpointing_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":strategy_combinations",
        "//tensorflow/python/checkpoint",
        "//tensorflow/python/eager:test",
    ],
)
1147
# Sharded (10-way) distribute_py_test for input_lib_test.py.
# It depends on both this package's :input_lib and the v1 input_lib under
# //tensorflow/python/distribute/v1.
distribute_py_test(
    name = "input_lib_test",
    srcs = ["input_lib_test.py"],
    disable_mlir_bridge = False,
    main = "input_lib_test.py",
    shard_count = 10,
    tags = [
        "multi_and_single_gpu",
        "no_cuda_asan",  # TODO(b/214574707): times out
        "notsan",  # TODO(b/177098062): flaky for over a year
    ],
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":input_lib",
        ":input_util",
        ":mirrored_strategy",
        ":multi_worker_test_base",
        ":reduce_util",
        ":strategy_combinations",
        ":test_util",
        ":values",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:errors",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute/v1:input_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
1184
# Sharded (10-way) distribute_py_test for input_lib_type_spec_test.py.
# Unlike :input_lib_test, it depends on :tpu_strategy and not on the v1
# input_lib.
distribute_py_test(
    name = "input_lib_type_spec_test",
    srcs = ["input_lib_type_spec_test.py"],
    main = "input_lib_type_spec_test.py",
    shard_count = 10,
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":input_lib",
        ":mirrored_strategy",
        ":multi_worker_test_base",
        ":reduce_util",
        ":strategy_combinations",
        ":tpu_strategy",
        ":values",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:errors",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
1216
# cuda_py_test for :cross_device_utils, parameterised via :combinations and
# :strategy_combinations.
cuda_py_test(
    name = "cross_device_utils_test",
    srcs = ["cross_device_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":combinations",
        ":cross_device_utils",
        ":strategy_combinations",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
1235
# 4-way sharded cuda_py_test for :cross_device_ops. It also depends on
# multi-process / multi-worker test helpers and collective ops.
cuda_py_test(
    name = "cross_device_ops_test",
    srcs = ["cross_device_ops_test.py"],
    python_version = "PY3",
    shard_count = 4,
    tags = [
        "multi_and_single_gpu",
        "no_cuda_asan",  # times out
    ],
    deps = [
        ":collective_util",
        ":combinations",
        ":cross_device_ops",
        ":cross_device_utils",
        ":multi_process_runner",
        ":multi_worker_test_base",
        ":reduce_util",
        ":test_util",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:collective_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:indexed_slices",
        "//tensorflow/python:util",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
1270
# cuda_py_test for one_device_strategy_test.py. grpc_enabled is passed to the
# cuda_py_test macro; its exact effect is defined in
# //tensorflow:tensorflow.bzl.
cuda_py_test(
    name = "one_device_strategy_test",
    srcs = ["one_device_strategy_test.py"],
    grpc_enabled = True,
    python_version = "PY3",
    deps = [
        ":combinations",
        ":strategy_combinations",
        ":strategy_test_lib",
        "//tensorflow/python/eager:test",
    ],
)
1283
# Library target for sharded_variable.py. It has no deps inside this package;
# everything comes from core TF python, saved_model, trackable and
# saveable-object utilities, plus numpy.
py_library(
    name = "sharded_variable",
    srcs = ["sharded_variable.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:data_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:partitioned_variables",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:type_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/saved_model:revived_types",
        "//tensorflow/python/saved_model:save_context",
        "//tensorflow/python/trackable:base",
        "//tensorflow/python/training/saving:saveable_object_util",
        "//tensorflow/python/util:tf_export",
        "//third_party/py/numpy",
    ],
)
1311
# Small tf_py_test for :sharded_variable. It depends on parameter-server v2,
# cluster coordinator, checkpoint and saved_model load/save targets.
# NOTE(review): unlike most tests in this file, python_version/srcs_version
# are not set here; presumably tf_py_test's defaults apply. Confirm.
tf_py_test(
    name = "sharded_variable_test",
    size = "small",
    srcs = ["sharded_variable_test.py"],
    deps = [
        ":combinations",
        ":multi_worker_test_base",
        ":parameter_server_strategy_v2",
        ":sharded_variable",
        ":test_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/checkpoint",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/distribute/coordinator:cluster_coordinator",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/module",
        "//tensorflow/python/saved_model:load",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/saved_model:save",
        "//tensorflow/python/saved_model:signature_constants",
        "//tensorflow/python/saved_model:tag_constants",
        "//tensorflow/python/trackable:autotrackable",
        "//tensorflow/python/training:server_lib",
    ],
)
1350
# Library target for strategy_test_lib.py. Although it is a py_library, it
# depends on test-only targets (eager:test, framework:test_lib), so it is
# presumably intended for tests only. Confirm before depending on it from
# non-test code.
py_library(
    name = "strategy_test_lib",
    srcs = ["strategy_test_lib.py"],
    srcs_version = "PY3",
    deps = [
        ":collective_all_reduce_strategy",
        ":distribute_lib",
        ":distribute_utils",
        ":mirrored_strategy",
        ":reduce_util",
        ":tpu_strategy",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:gradients_impl",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:summary_ops_v2",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/client:session",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/framework:test_lib",
        "//tensorflow/python/lib/io:lib",
        "//tensorflow/python/platform",
        "//tensorflow/python/training:training_util",
        "//third_party/py/numpy",
    ],
)
1389
# Medium-size distribute_py_test for values_test.py, covering :values and
# :tpu_values.
distribute_py_test(
    name = "values_test",
    size = "medium",
    srcs = ["values_test.py"],
    main = "values_test.py",
    tags = [
        "multi_and_single_gpu",
        "no_cuda_asan",  # times out
        "no_oss",  # b/178656226
        "noasan",  # b/175816710
        "notsan",  # b/168645872
    ],
    # tpu_tags is consumed by the distribute_py_test macro
    # (//tensorflow/tsl/platform/default:distribute.bzl); presumably it applies
    # only to the TPU variant of the test. Confirm there.
    tpu_tags = [
        "noasan",  # TODO(b/337374867) fails with -fsanitize=null
    ],
    deps = [
        ":combinations",
        ":strategy_combinations",
        ":test_util",
        ":tpu_strategy",
        ":tpu_values",
        ":values",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:saver",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:tf2",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
1431
# Medium-size, 3-way sharded distribute_py_test for
# distributed_variable_test.py. The v2/v3 variants are disabled through the
# macro flags below.
distribute_py_test(
    name = "distributed_variable_test",
    size = "medium",
    srcs = ["distributed_variable_test.py"],
    disable_v2 = True,  # TODO(b/209058825)
    disable_v3 = True,  # TODO(b/209058825)
    main = "distributed_variable_test.py",
    shard_count = 3,
    tags = [
        "multi_and_single_gpu",
        "no_cuda_asan",  # times out
        "no_oss",  # b/178656226
        "noasan",  # b/175816710
        # NOTE(review): unlike the other tags here, this flake has no bug
        # reference; consider filing one so the tag can be tracked and removed.
        "notap",  # Flaky
        "notsan",  # b/168645872
    ],
    # Consumed by distribute_py_test; presumably TPU-variant-only tags (confirm
    # in distribute.bzl).
    tpu_tags = [
        "noasan",  # TODO(b/337374867) fails with -fsanitize=null
    ],
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":distribute_lib",
        ":distribute_utils",
        ":packed_distributed_variable",
        ":parameter_server_strategy",
        ":ps_values",
        ":strategy_combinations",
        ":test_util",
        ":tpu_strategy",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:indexed_slices",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/saved_model:save",
        "//tensorflow/python/saved_model:save_context",
        "//tensorflow/python/saved_model:save_options",
        "//tensorflow/python/types",
        "@absl_py//absl/testing:parameterized",
    ],
)
1485
# Medium-size distribute_py_test for mirrored_values_test.py.
distribute_py_test(
    name = "mirrored_values_test",
    size = "medium",
    srcs = ["mirrored_values_test.py"],
    main = "mirrored_values_test.py",
    tags = [
        "multi_and_single_gpu",
        "no_cuda_asan",  # times out
        "noasan",  # b/175816710
        "notsan",  # b/168645872
    ],
    # Consumed by distribute_py_test; presumably TPU-variant-only tags (confirm
    # in distribute.bzl).
    tpu_tags = [
        "noasan",  # TODO(b/337374867) fails with -fsanitize=null
    ],
    deps = [
        ":combinations",
        ":distribute_lib",
        ":distribute_utils",
        ":packed_distributed_variable",
        ":strategy_combinations",
        ":strategy_test_lib",
        ":test_util",
        ":tpu_strategy",
        ":tpu_values",
        ":values",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:saver",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/framework:constant_op",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/framework:test_lib",
        "//tensorflow/python/types",
        "@absl_py//absl/testing:parameterized",
    ],
)
1525
# Medium-size distribute_py_test for per_replica_test.py.
distribute_py_test(
    name = "per_replica_test",
    size = "medium",
    srcs = ["per_replica_test.py"],
    main = "per_replica_test.py",
    tags = [
        "multi_and_single_gpu",
        "no_cuda_asan",  # times out
        "noasan",  # b/175816710
        "notsan",  # b/168645872
    ],
    # Consumed by distribute_py_test; presumably TPU-variant-only tags (confirm
    # in distribute.bzl).
    tpu_tags = [
        "noasan",  # TODO(b/337374867) fails with -fsanitize=null
    ],
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":distribute_lib",
        ":distribute_utils",
        ":packed_distributed_variable",
        ":parameter_server_strategy",
        ":ps_values",
        ":strategy_combinations",
        ":test_util",
        ":tpu_strategy",
        ":tpu_values",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:tf2",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/types",
        "@absl_py//absl/testing:parameterized",
    ],
)
1570
# distribute_py_test for :distribute_utils. It additionally depends on the
# saved_model mode_keys target and the wrapt package.
distribute_py_test(
    name = "distribute_utils_test",
    srcs = ["distribute_utils_test.py"],
    disable_mlir_bridge = False,
    main = "distribute_utils_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":distribute_utils",
        ":strategy_combinations",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/saved_model/model_utils:mode_keys",
        "@absl_py//absl/testing:parameterized",
        "@wrapt",
    ],
)
1595
# Medium-size, 5-way sharded distribute_py_test for vars_test.py. It depends
# on the TPU cluster resolver and tpu_lib in addition to the strategy
# combinations.
distribute_py_test(
    name = "vars_test",
    size = "medium",
    srcs = ["vars_test.py"],
    main = "vars_test.py",
    shard_count = 5,
    tags = [
        "multi_and_single_gpu",
        "no_cuda_asan",  # times out
    ],
    deps = [
        ":combinations",
        ":distribute_lib",
        ":strategy_combinations",
        ":strategy_test_lib",
        ":test_util",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:indexed_slices",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/checkpoint:checkpoint_management",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/tpu:tpu_lib",
        "@absl_py//absl/testing:parameterized",
    ],
)
1632
# Medium-size distribute_py_test for :ps_values.
distribute_py_test(
    name = "ps_values_test",
    size = "medium",
    srcs = ["ps_values_test.py"],
    disable_mlir_bridge = False,
    main = "ps_values_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":ps_values",
        ":strategy_combinations",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
1653
# 5-way sharded distribute_py_test for moving_averages_test.py.
# It is currently excluded from Windows and TPU runs via the tags below.
distribute_py_test(
    name = "moving_averages_test",
    srcs = ["moving_averages_test.py"],
    main = "moving_averages_test.py",
    shard_count = 5,
    tags = [
        "multi_gpu",
        "no_windows",  # TODO(b/184424727): Re-enable this.
        "notpu",  # TODO(b/210145904)
    ],
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":strategy_combinations",
        ":strategy_test_lib",
        ":test_util",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:training",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
1678
# distribute_py_test for custom_training_loop_gradient_test.py.
distribute_py_test(
    name = "custom_training_loop_gradient_test",
    srcs = ["custom_training_loop_gradient_test.py"],
    disable_mlir_bridge = False,
    main = "custom_training_loop_gradient_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":strategy_combinations",
        "//tensorflow/python:errors",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
1696
# 5-way sharded distribute_py_test for custom_training_loop_input_test.py.
# "no_oss" appears both in tags and in tpu_tags, each with its own reason.
distribute_py_test(
    name = "custom_training_loop_input_test",
    srcs = ["custom_training_loop_input_test.py"],
    main = "custom_training_loop_input_test.py",
    shard_count = 5,
    tags = [
        "multi_and_single_gpu",
        "no_oss",  # TODO(b/227211015)
    ],
    tpu_tags = [
        "no_oss",  # Target too big to run serially reliably.
    ],
    deps = [
        ":combinations",
        ":strategy_combinations",
        "//tensorflow/python:errors",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
1719
# Library target for single_loss_example.py. It depends on :step_fn,
# :strategy_test_lib, the layers target and the dataset ops.
py_library(
    name = "single_loss_example",
    srcs = ["single_loss_example.py"],
    srcs_version = "PY3",
    deps = [
        ":step_fn",
        ":strategy_test_lib",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:layers",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/data/ops:dataset_ops",
    ],
)
1734
# Library target for step_fn.py, depending only on training and eager
# backprop.
py_library(
    name = "step_fn",
    srcs = ["step_fn.py"],
    srcs_version = "PY3",
    # Narrower than the package default: keras subpackages are NOT included,
    # only //tensorflow:internal.
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow/python:training",
        "//tensorflow/python/eager:backprop",
    ],
)
1745
# Medium-size cuda_py_test for warm_starting_util_test.py. Its deps list is
# identical to :checkpoint_utils_test's.
cuda_py_test(
    name = "warm_starting_util_test",
    size = "medium",
    srcs = ["warm_starting_util_test.py"],
    python_version = "PY3",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":strategy_combinations",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
    ],
)
1764
# cuda_py_test for remote_mirrored_strategy_eager_test.py, depending on
# :mirrored_strategy and :multi_worker_test_base.
cuda_py_test(
    name = "remote_mirrored_strategy_eager_test",
    srcs = ["remote_mirrored_strategy_eager_test.py"],
    python_version = "PY3",
    tags = [
        "no_windows",  # TODO(b/197981388): Re-enable this.
    ],
    deps = [
        ":combinations",
        ":distribute_lib",
        ":mirrored_strategy",
        ":multi_worker_test_base",
        ":strategy_test_lib",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
    ],
)
1790
# 5-way sharded cuda_py_test for :mirrored_strategy. It also depends on the
# autograph core test_lib.
cuda_py_test(
    name = "mirrored_strategy_test",
    srcs = ["mirrored_strategy_test.py"],
    python_version = "PY3",
    shard_count = 5,
    tags = [
        "multi_and_single_gpu",
        "no_windows_gpu",  # TODO(b/130551176)
    ],
    deps = [
        ":combinations",
        ":distribute_lib",
        ":mirrored_strategy",
        ":multi_worker_test_base",
        ":strategy_combinations",
        ":strategy_test_lib",
        ":values",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/autograph/core:test_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
    ],
)
1821
# cuda_py_test for mirrored_variable_test.py. It depends on rnn_cell,
# checkpoint and saved_model load/save in addition to the strategy targets.
cuda_py_test(
    name = "mirrored_variable_test",
    srcs = ["mirrored_variable_test.py"],
    python_version = "PY3",
    # "guitar" is interpreted by test infrastructure outside this file
    # (presumably an internal CI selector; confirm).
    tags = [
        "guitar",
        "multi_and_single_gpu",
        "no_windows",  # TODO(b/184424727): Re-enable this.
    ],
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":distribute_lib",
        ":distribute_utils",
        ":strategy_combinations",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:config",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:func_graph",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:rnn_cell",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/checkpoint",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/saved_model:load",
        "//tensorflow/python/saved_model:save",
    ],
)
1857
# distribute_py_test for metrics_v1_test.py, depending on the
# //tensorflow/python:metrics target.
distribute_py_test(
    name = "metrics_v1_test",
    srcs = ["metrics_v1_test.py"],
    main = "metrics_v1_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":strategy_combinations",
        ":strategy_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:metrics",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/framework:ops",
        "@absl_py//absl/testing:parameterized",
    ],
)
1878
# distribute_py_test for zero_batch_test.py. The tag comment says Keras is
# required, yet no Keras target appears in deps. NOTE(review): presumably Keras
# is resolved transitively or outside Bazel deps; confirm.
distribute_py_test(
    name = "zero_batch_test",
    srcs = ["zero_batch_test.py"],
    disable_mlir_bridge = False,
    main = "zero_batch_test.py",
    tags = [
        "no_oss",  # Keras is not available in OSS test
    ],
    deps = [
        ":combinations",
        ":multi_worker_test_base",
        ":strategy_combinations",
    ],
)
1893
# cuda_py_test build of collective_all_reduce_strategy_test.py. The same source
# file is also built as a TPU test (collective_all_reduce_strategy_test_tpu)
# with an identical deps list; keep the two lists in sync.
cuda_py_test(
    name = "collective_all_reduce_strategy_test",
    srcs = ["collective_all_reduce_strategy_test.py"],
    python_version = "PY3",
    tags = [
        "multi_and_single_gpu",
        "nomsan",  # b/154224457: Re-enable when fixed.
        "notsan",  # TODO(b/220133218)
    ],
    # b/155301154 broken with XLA:GPU
    # NOTE(review): the line above says this test is broken with XLA:GPU, yet
    # strict auto-jit is enabled here. Either that note is stale or this flag
    # should be False — confirm against b/155301154.
    xla_enable_strict_auto_jit = True,
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":cross_device_utils",
        ":distribute_lib",
        ":distribute_utils",
        ":input_lib",
        ":multi_worker_test_base",
        ":multi_worker_util",
        ":reduce_util",
        ":strategy_combinations",
        ":strategy_test_lib",
        ":test_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:device",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:training_server_lib",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/tpu:tpu_lib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
1939
# TPU build (tpu_py_test) of collective_all_reduce_strategy_test.py, using the
# same source and entry point as the cuda_py_test target
# collective_all_reduce_strategy_test. The deps list mirrors that target's and
# should be kept in sync with it.
tpu_py_test(
    name = "collective_all_reduce_strategy_test_tpu",
    srcs = ["collective_all_reduce_strategy_test.py"],
    # FIXME(b/227404010): On TFRT TPU, eager CollectiveReduceV2 is broken.
    disable_tfrt = True,
    main = "collective_all_reduce_strategy_test.py",
    python_version = "PY3",
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":cross_device_utils",
        ":distribute_lib",
        ":distribute_utils",
        ":input_lib",
        ":multi_worker_test_base",
        ":multi_worker_util",
        ":reduce_util",
        ":strategy_combinations",
        ":strategy_test_lib",
        ":test_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:device",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:training_server_lib",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/tpu:tpu_lib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
1981
# cuda_py_test build of parameter_server_strategy_test.py. Its deps include
# both :parameter_server_strategy and :central_storage_strategy.
cuda_py_test(
    name = "parameter_server_strategy_test",
    srcs = ["parameter_server_strategy_test.py"],
    python_version = "PY3",
    tags = [
        "multi_and_single_gpu",
        "notsan",  # TODO(b/220133218)
    ],
    # b/141096229: Non-atomic AssignAdd
    xla_enable_strict_auto_jit = False,
    deps = [
        ":central_storage_strategy",
        ":combinations",
        ":device_util",
        ":distribute_lib",
        ":multi_worker_test_base",
        ":multi_worker_util",
        ":parameter_server_strategy",
        ":ps_values",
        ":reduce_util",
        ":strategy_test_lib",
        ":values",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:partitioned_variables",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:training_util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/estimator:run_config",
        "@absl_py//absl/testing:parameterized",
    ],
)
2027
# Library target for multi_process_runner.py, built on top of
# :multi_process_lib; also pulls in dill, six and tblib from external repos.
# NOTE(review): `no_oss` is set on a library rather than a test. Presumably this
# keeps the target (and so anything depending on it) out of OSS builds that
# filter on the tag — confirm that is the intent of b/241013307.
py_library(
    name = "multi_process_runner",
    srcs = ["multi_process_runner.py"],
    srcs_version = "PY3",
    tags = [
        "no_oss",  # b/241013307 disable flaky under docker
    ],
    deps = [
        ":multi_process_lib",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:tf2",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/eager:context",
        "@absl_py//absl/logging",
        "@dill_archive//:dill",
        "@six_archive//:six",
        "@tblib_archive//:tblib",
    ],
)
2047
# Library target for multi_process_lib.py; :multi_process_runner depends on it.
py_library(
    name = "multi_process_lib",
    srcs = ["multi_process_lib.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/eager:test",
        "@absl_py//absl:app",
        "@absl_py//absl/logging",
    ],
)
2058
# Plain py_test (no cuda_py_test / distribute_py_test wrapper) for
# packed_distributed_variable_test.py.
# NOTE(review): unlike most tests in this file, python_version is not set, so
# the rule's default applies.
py_test(
    name = "packed_distributed_variable_test",
    srcs = ["packed_distributed_variable_test.py"],
    tags = [
        "nomac",  # TODO(b/145922293): It causes a Python segfault on macOS.
    ],
    deps = [
        ":device_util",
        ":packed_distributed_variable",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
    ],
)
2076
# cuda_py_test build of multi_process_runner_test.py, which exercises
# :multi_process_runner.
cuda_py_test(
    name = "multi_process_runner_test",
    srcs = ["multi_process_runner_test.py"],
    python_version = "PY3",
    shard_count = 12,
    # b/175904958 (presumably the reason for the noasan/nomsan exclusions
    # below — confirm).
    tags = [
        "multi_gpu",
        "noasan",
        "nomsan",
    ],
    deps = [
        ":combinations",
        ":multi_process_runner",
        ":multi_worker_test_base",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/logging",
        "@absl_py//absl/testing:parameterized",
    ],
)
2097
# Plain py_test for multi_process_runner_no_init_test.py; depends directly on
# :multi_process_runner and :multi_worker_test_base.
py_test(
    name = "multi_process_runner_no_init_test",
    srcs = ["multi_process_runner_no_init_test.py"],
    python_version = "PY3",
    deps = [
        ":multi_process_runner",
        ":multi_worker_test_base",
        "//tensorflow/python/eager:test",
    ],
)
2108
# Runs strategy_common_test.py through the distribute_py_test macro, split into
# two shards. Its deps include both :collective_all_reduce_strategy and
# :tpu_strategy.
distribute_py_test(
    name = "strategy_common_test",
    srcs = ["strategy_common_test.py"],
    disable_mlir_bridge = False,
    python_version = "PY3",
    shard_count = 2,
    tags = [
        "multi_and_single_gpu",
        "notsan",  # TODO(b/160006974)
    ],
    xla_enable_strict_auto_jit = True,
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":distribute_lib",
        ":multi_worker_test_base",
        ":reduce_util",
        ":strategy_combinations",
        ":strategy_test_lib",
        ":test_util",
        ":tpu_strategy",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/util",
        "@absl_py//absl/testing:parameterized",
    ],
)
2143
# Runs strategy_gather_test.py through the distribute_py_test macro, split into
# four shards.
distribute_py_test(
    name = "strategy_gather_test",
    srcs = ["strategy_gather_test.py"],
    disable_mlir_bridge = False,
    python_version = "PY3",
    shard_count = 4,
    tags = [
        "multi_and_single_gpu",
        "no_cuda_asan",  # times out
        "notsan",  # TODO(b/160006974)
    ],
    xla_enable_strict_auto_jit = True,
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":multi_worker_test_base",
        ":reduce_util",
        ":strategy_combinations",
        ":strategy_test_lib",
        ":test_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:def_function",
        "@absl_py//absl/testing:parameterized",
    ],
)
2175
# Runs tf_function_test.py through the distribute_py_test macro; its deps
# include the saved_model save_context/save_options libraries.
distribute_py_test(
    name = "tf_function_test",
    srcs = ["tf_function_test.py"],
    disable_mlir_bridge = False,
    main = "tf_function_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":device_util",
        ":strategy_combinations",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:variables",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/saved_model:save_context",
        "//tensorflow/python/saved_model:save_options",
        "@absl_py//absl/testing:parameterized",
    ],
)
2202
# Shared helper library (test_util.py). Several tests in this file depend on it,
# e.g. collective_all_reduce_strategy_test, strategy_common_test,
# strategy_gather_test, test_util_test and distributed_table_test.
py_library(
    name = "test_util",
    srcs = ["test_util.py"],
    srcs_version = "PY3",
    deps = [
        ":collective_all_reduce_strategy",
        ":multi_process_runner",
        ":tpu_strategy",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:config",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/eager:context",
        "@absl_py//absl:app",
    ],
)
2221
# Runs test_util_test.py (which depends on :test_util) through the
# distribute_py_test macro.
distribute_py_test(
    name = "test_util_test",
    srcs = ["test_util_test.py"],
    disable_mlir_bridge = False,
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":strategy_combinations",
        ":test_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
2240
# Library target for parameter_server_strategy_v2.py, layered on
# :parameter_server_strategy, :sharded_variable and :input_util.
py_library(
    name = "parameter_server_strategy_v2",
    srcs = ["parameter_server_strategy_v2.py"],
    srcs_version = "PY3",
    deps = [
        ":distribute_lib",
        ":distribute_utils",
        ":input_util",
        ":parameter_server_strategy",
        ":sharded_variable",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tf_decorator",
        "//tensorflow/python/eager:remote",
        "//tensorflow/python/trackable:base",
        "//tensorflow/python/training:server_lib",
    ],
)
2261
# Runs parameter_server_strategy_v2_test.py through the distribute_py_test
# macro; excluded from TPU, macOS and TSAN runs via the tags below.
distribute_py_test(
    name = "parameter_server_strategy_v2_test",
    srcs = ["parameter_server_strategy_v2_test.py"],
    python_version = "PY3",
    tags = [
        "multi_and_single_gpu",
        "nomac",  # TODO(b/201788023): Attempt MultiProcessCluster to fix this.
        "notpu",
        "notsan",  # Tsan failure doesn't seem to be caused by TF.
    ],
    deps = [
        ":combinations",
        ":multi_worker_test_base",
        ":parameter_server_strategy_v2",
        ":sharded_variable",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:linalg_ops_impl",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/checkpoint",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/trackable:autotrackable",
        "//tensorflow/python/training:server_lib",
        "@absl_py//absl/testing:parameterized",
    ],
)
2297
# Runs distributed_table_test.py through the distribute_py_test macro. Its deps
# pull in :parameter_server_strategy_v2, the coordinator libraries and
# saved_model load/save.
distribute_py_test(
    name = "distributed_table_test",
    srcs = ["distributed_table_test.py"],
    python_version = "PY3",
    tags = [
        "multi_and_single_gpu",
        "noasan",  # TODO(b/237407459)
        "notpu",
        "notsan",  # Tsan failure doesn't seem to be caused by TF.
    ],
    deps = [
        ":combinations",
        ":device_util",
        ":multi_worker_test_base",
        ":parameter_server_strategy_v2",
        ":ps_values",
        ":test_util",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/distribute/coordinator:cluster_coordinator",
        "//tensorflow/python/distribute/coordinator:coordinator_context",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/framework:constant_op",
        "//tensorflow/python/framework:tensor_spec",
        "//tensorflow/python/keras/saving",
        "//tensorflow/python/module",
        "//tensorflow/python/platform",
        "//tensorflow/python/saved_model:load",
        "//tensorflow/python/saved_model:save",
        "//tensorflow/python/training:server_lib",
        "@absl_py//absl/testing:parameterized",
    ],
)
2338
# TPU test (tpu_py_test) for tpu_strategy_model_parallelism_test.py; its deps
# include //tensorflow/python/tpu:device_assignment.
tpu_py_test(
    name = "tpu_strategy_model_parallelism_test",
    srcs = ["tpu_strategy_model_parallelism_test.py"],
    disable_experimental = True,  # b/202779350
    disable_mlir_bridge = False,
    disable_v3_4chips = False,
    python_version = "PY3",
    tags = [
        # NOTE(review): no rationale is recorded for "no_oss" — consider
        # adding the reason or a tracking bug.
        "no_oss",
        "notsan",  # b/239832964
    ],
    deps = [
        ":distribute_lib",
        ":strategy_test_lib",
        ":tpu_strategy",
        ":tpu_values",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:variables",
        "//tensorflow/python/checkpoint",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:remote",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/framework:constant_op",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/module",
        "//tensorflow/python/platform",
        "//tensorflow/python/tpu:device_assignment",
        "//tensorflow/python/tpu:tpu_lib",
        "//tensorflow/python/training:checkpoint_management",
    ],
)
2373
# Library target for input_util.py, built with pytype_strict_library (loaded
# from //tensorflow:tensorflow.bzl; presumably enforces strict pytype checks).
# Depends on both the current :input_lib and the v1 input_lib.
pytype_strict_library(
    name = "input_util",
    srcs = ["input_util.py"],
    srcs_version = "PY3",
    deps = [
        ":input_lib",
        "//tensorflow/python:tf2",
        "//tensorflow/python/distribute/v1:input_lib",
    ],
)
2384
# Library target for merge_call_interim.py; depends on :distribute_lib and the
# tf_export utility.
py_library(
    name = "merge_call_interim",
    srcs = ["merge_call_interim.py"],
    srcs_version = "PY3",
    deps = [
        ":distribute_lib",
        "//tensorflow/python/util:tf_export",
    ],
)
2396
# Small cuda_py_test for template_mirrored_strategy_test.py.
# Targets from this package (tensorflow/python/distribute) are referenced with
# the short ":name" label form, as everywhere else in this BUILD file.
cuda_py_test(
    name = "template_mirrored_strategy_test",
    size = "small",
    srcs = ["template_mirrored_strategy_test.py"],
    deps = [
        ":distribute_lib",
        ":mirrored_strategy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:template",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
    ],
)
2411