xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/BUILD (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Description: Operations defined for Cloud TPUs
2load("//tensorflow:tensorflow.bzl", "pytype_strict_library")
3load("//tensorflow:tensorflow.bzl", "pytype_library")  # buildifier: disable=same-origin-load
4load("//tensorflow:tensorflow.bzl", "tf_py_test")  # buildifier: disable=same-origin-load
5load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
6load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
7
# Do not add any more paths here. You do not need to be in the visibility list
# to use TPU symbols. They are accessible from tf.contrib.tpu in TF 1.x and
# from tf.tpu and tf.compat.v1.tpu in TF 2.x.
# Package-level defaults. Targets below that do not set their own `visibility`
# are visible only to the package groups listed here.
package(
    default_visibility = [
        "//learning/brain:__subpackages__",
        "//learning/deepmind:__subpackages__",
        "//learning/serving:__subpackages__",
        "//research/graph:__subpackages__",
        "//tensorflow:__subpackages__",
        "//waymo/ml/deploy/sync_test/tools:__subpackages__",
    ],
    licenses = ["notice"],
)
22
# Export the wrapper script so it can be referenced by file label
# (":tpu_test_wrapper.py") from outside this package.
exports_files(
    ["tpu_test_wrapper.py"],
)
24
# Unit test for tpu_test_wrapper.py. The wrapper is compiled into the test as a
# source file (not a dep), and `main` selects the test module as entry point.
py_test(
    name = "tpu_test_wrapper_test",
    srcs = [
        "tpu_test_wrapper.py",
        "tpu_test_wrapper_test.py",
    ],
    main = "tpu_test_wrapper_test.py",
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "no_oss_py35",
        "no_pip",
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:platform",
        "@absl_py//absl/testing:flagsaver",
    ],
)
44
# ":tpu_ops" forwards to the default target of the ops subpackage.
alias(
    name = "tpu_ops",
    actual = "//tensorflow/python/tpu/ops:ops",
)
49
# Library target for async_checkpoint.py, declared through the pytype_library
# macro loaded from //tensorflow:tensorflow.bzl.
pytype_library(
    name = "async_checkpoint",
    srcs = ["async_checkpoint.py"],
    srcs_version = "PY3",
    deps = [
        # Core TensorFlow Python ops and framework pieces.
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:summary_ops_v2",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        # Estimator and summary packages.
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/summary",
    ],
)
70
# Test for :async_checkpoint, declared through the tpu_py_test macro from
# //tensorflow/python/tpu:tpu.bzl.
# NOTE(review): the exact meaning of disable_experimental / disable_mlir_bridge
# is defined by that macro - confirm there before changing either flag.
tpu_py_test(
    name = "async_checkpoint_test",
    size = "medium",
    srcs = ["async_checkpoint_test.py"],
    disable_experimental = True,
    disable_mlir_bridge = False,
    deps = [
        ":async_checkpoint",
        ":tpu_estimator",
        ":tpu_lib",
        "//tensorflow/python:lib",
        "//tensorflow/python:platform",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//third_party/py/numpy",
    ],
)
87
# Library target for device_assignment.py. Overrides the package default
# visibility with the //tensorflow:internal package group.
pytype_library(
    name = "device_assignment",
    srcs = ["device_assignment.py"],
    visibility = ["//tensorflow:internal"],
    deps = [
        ":topology",
        "//tensorflow/python:platform",
        "//tensorflow/python/util:tf_export",
    ],
)
100
# Library target for preempted_hook.py; depends on the session-run-hook API and
# the TPU cluster resolver.
pytype_library(
    name = "preempted_hook_py",
    srcs = ["preempted_hook.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:errors",
        "//tensorflow/python:platform",
        "//tensorflow/python:session_run_hook",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
    ],
)
112
# Groups tpu_estimator.py with the config, context, error-handling, embedding
# and utility modules listed in `srcs` into a single library target.
py_library(
    name = "tpu_estimator",
    srcs = [
        "_tpu_estimator_embedding.py",
        "error_handling.py",
        "tpu_config.py",
        "tpu_context.py",
        "tpu_estimator.py",
        "util.py",
    ],
    srcs_version = "PY3",
    deps = [
        # Targets from this package.
        ":async_checkpoint",
        ":feature_column",
        ":feature_column_v2",
        ":functional",
        ":preempted_hook_py",
        ":tpu_embedding",
        ":tpu_lib",
        # Protos and core TensorFlow Python targets.
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:function",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:session",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:summary_ops_v2",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        # Estimator and summary packages.
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/estimator:util",
        "//tensorflow/python/summary",
    ],
)
151
# Library target for functional.py. Publicly visible, unlike most targets in
# this package.
py_library(
    name = "functional",
    srcs = ["functional.py"],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = ["//tensorflow/python:tpu_ops_gen"],
)
163
# Library target for topology.py; depends on the TPU topology proto bindings
# and numpy.
pytype_library(
    name = "topology",
    srcs = ["topology.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/core/protobuf/tpu:topology_proto_py",
        "//tensorflow/python/util:tf_export",
        "//third_party/py/numpy",
    ],
)
174
# Umbrella target: the package __init__.py plus the feature-column, embedding,
# estimator and core TPU libraries.
py_library(
    name = "tpu",
    srcs = ["__init__.py"],
    srcs_version = "PY3",
    deps = [
        ":feature_column",
        ":feature_column_v2",
        ":tpu_embedding",
        ":tpu_estimator",
        ":tpu_lib",
    ],
)
189
# Umbrella target that, as its name indicates, carries no dependency on
# :tpu_estimator. Ships __init__.py and api.py together with the embedding,
# feature-column, hardware-feature and core TPU libraries.
py_library(
    name = "tpu_noestimator",
    srcs = [
        "__init__.py",
        "api.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":feature_column",
        ":feature_column_v2",
        ":preempted_hook_py",
        ":tpu_embedding",
        ":tpu_embedding_base",
        ":tpu_embedding_for_serving",
        ":tpu_embedding_v1",
        ":tpu_embedding_v2",
        ":tpu_embedding_v2_utils",
        ":tpu_hardware_feature",
        ":tpu_lib",
    ],
)
211
# Core TPU Python library: package __init__.py, bfloat16 support, session
# support, tensor tracer, TPU optimizer, strategy utilities and training loop.
# NOTE(review): tpu_strategy_util.py is also the source of :tpu_strategy_util,
# and __init__.py is also listed by :tpu and :tpu_noestimator - confirm the
# duplication is intentional before restructuring.
pytype_library(
    name = "tpu_lib",
    srcs = [
        "__init__.py",
        "bfloat16.py",
        "session_support.py",
        "tensor_tracer.py",
        "tensor_tracer_flags.py",
        "tensor_tracer_report.py",
        "tpu_optimizer.py",
        "tpu_strategy_util.py",
        "training_loop.py",
    ],
    srcs_version = "PY3",
    deps = [
        # Targets from this package.
        ":datasets",
        ":device_assignment",
        ":functional",
        ":tensor_tracer_proto_py",
        ":topology",
        ":tpu_feed",
        ":tpu_function",
        ":tpu_sharding",
        # XLA helpers.
        "//tensorflow/compiler/xla/experimental/xla_sharding",
        "//tensorflow/compiler/xla/python_api:xla_shape",
        # Proto bindings.
        "//tensorflow/core:protos_all_py",
        "//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
        "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
        "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_py",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
        # Core TensorFlow Python targets.
        "//tensorflow/python:array_ops",
        "//tensorflow/python:batch_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:control_flow_util",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:platform_analytics",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tf2",
        "//tensorflow/python:tpu_ops_gen",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/compiler/xla",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
        "//tensorflow/python/ops/losses",
        # TPU subpackages.
        "//tensorflow/python/tpu/ops",
        "//tensorflow/python/tpu/profiler",
    ],
)
264
# Library target for tpu.py on its own (separate from the broader :tpu_lib).
pytype_library(
    name = "tpu_py",
    srcs = ["tpu.py"],
    deps = [
        # Targets from this package.
        ":device_assignment",
        ":tpu_feed",
        ":tpu_function",
        ":tpu_name_util",
        # Proto bindings.
        "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
        # Core TensorFlow Python targets.
        "//tensorflow/python:array_ops",
        "//tensorflow/python:auto_control_deps",
        "//tensorflow/python:c_api_util",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:config",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:device",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:func_graph",
        "//tensorflow/python:function",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/compiler/xla",
        "//tensorflow/python/distribute:device_util",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/tpu/ops",
        "//tensorflow/python/util:tf_export",
        # Third-party.
        "//third_party/py/numpy",
        "@absl_py//absl/logging",
    ],
)
300
# Library target for tpu_feed.py; builds on the sharding and name utilities of
# this package plus the XLA sharding helpers.
pytype_library(
    name = "tpu_feed",
    srcs = ["tpu_feed.py"],
    deps = [
        ":tpu_name_util",
        ":tpu_sharding",
        "//tensorflow/compiler/xla/experimental/xla_sharding",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python/tpu/ops",
        "//tensorflow/python/user_ops:ops",
    ],
)
315
# Library target for tpu_function.py; declares no dependencies.
pytype_library(
    name = "tpu_function",
    srcs = ["tpu_function.py"],
)
320
# Library target for tpu_sharding.py; its only dependency is tensor_shape.
pytype_library(
    name = "tpu_sharding",
    srcs = ["tpu_sharding.py"],
    deps = ["//tensorflow/python:tensor_shape"],
)
328
# Library target for tpu_system_metadata.py; depends on :tpu_py and the TPU
# embedding configuration proto bindings.
pytype_library(
    name = "tpu_system_metadata",
    srcs = ["tpu_system_metadata.py"],
    deps = [
        ":tpu_py",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
        "//tensorflow/python:config",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:session",
        "//tensorflow/python:util",
        "//tensorflow/python/distribute:device_util",
    ],
)
344
# Library target for datasets.py; depends on the tf.data dataset, iterator,
# reader and interleave targets.
pytype_library(
    name = "datasets",
    srcs = ["datasets.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:dtypes",
        "//tensorflow/python:function",
        "//tensorflow/python:functional_ops",
        "//tensorflow/python/data/experimental/ops:interleave_ops",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:iterator_ops",
        "//tensorflow/python/data/ops:readers",
    ],
)
361
# Test for :datasets. Split into 4 shards, built with gRPC enabled, and tagged
# "no_oss".
tf_py_test(
    name = "datasets_test",
    size = "medium",
    srcs = ["datasets_test.py"],
    grpc_enabled = True,
    shard_count = 4,
    tags = ["no_oss"],
    deps = [
        ":datasets",
        "//tensorflow/python:client_testlib",
    ],
)
374
# Small test for tpu_test.py, run against the umbrella :tpu target.
tf_py_test(
    name = "tpu_test",
    size = "small",
    srcs = ["tpu_test.py"],
    tags = [
        "no_oss",  # TODO(b/131157871): Reenable in OSS when fixed
        "no_windows",  # TODO: needs investigation on Windows
    ],
    deps = [
        ":tpu",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:layers",
    ],
)
391
# Small test for :tpu_sharding.
tf_py_test(
    name = "tpu_sharding_test",
    size = "small",
    srcs = ["tpu_sharding_test.py"],
    deps = [
        ":tpu_sharding",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
    ],
)
402
# Small test for bfloat16_test.py; pulls in the umbrella :tpu target.
tf_py_test(
    name = "bfloat16_test",
    size = "small",
    srcs = ["bfloat16_test.py"],
    deps = [
        ":tpu",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
    ],
)
413
# Small test for tpu_infeed_test.py; pulls in the umbrella :tpu target.
tf_py_test(
    name = "tpu_infeed_test",
    size = "small",
    srcs = ["tpu_infeed_test.py"],
    deps = [
        ":tpu",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_test_lib",
    ],
)
424
# Medium-sized test for :topology.
tf_py_test(
    name = "topology_test",
    size = "medium",
    srcs = ["topology_test.py"],
    deps = [
        ":topology",
        "//tensorflow/python:framework_test_lib",
    ],
)
434
# Library target for tpu_embedding.py and its companion
# tpu_embedding_gradient.py.
pytype_library(
    name = "tpu_embedding",
    srcs = [
        "tpu_embedding.py",
        "tpu_embedding_gradient.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":tpu_lib",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:partitioned_variables",
        "//tensorflow/python:tpu_ops_gen",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
    ],
)
455
# Stand-alone library target for tpu_strategy_util.py with its own visibility
# list (package defaults minus the waymo entry, plus one tensorflow_numerics
# package).
# NOTE(review): the same source file is also listed in the srcs of :tpu_lib,
# which this target depends on - confirm that is intentional.
pytype_library(
    name = "tpu_strategy_util",
    srcs = ["tpu_strategy_util.py"],
    visibility = [
        "//learning/brain:__subpackages__",
        "//learning/deepmind:__subpackages__",
        "//learning/serving:__subpackages__",
        "//research/graph:__subpackages__",
        "//tensorflow:__subpackages__",
        "//third_party/py/tensorflow_numerics/extensions:__pkg__",
    ],
    deps = [
        ":tpu_lib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/distribute:device_util",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:tape",
    ],
)
478
# Library target for tpu_hardware_feature.py; depends on the TPU topology proto
# bindings and tf_export.
pytype_library(
    name = "tpu_hardware_feature",
    srcs = ["tpu_hardware_feature.py"],
    deps = [
        "//tensorflow/core/protobuf/tpu:topology_proto_py",
        "//tensorflow/python/util:tf_export",
    ],
)
487
# Library target for tpu_name_util.py; its only dependency is tf_export.
py_library(
    name = "tpu_name_util",
    srcs = ["tpu_name_util.py"],
    srcs_version = "PY3",
    deps = ["//tensorflow/python/util:tf_export"],
)
496
# Library target for feature_column.py; builds on :tpu_lib and the core
# feature_column package.
pytype_library(
    name = "feature_column",
    srcs = ["feature_column.py"],
    deps = [
        ":tpu_lib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
    ],
)
509
# Library target for feature_column_v2.py; same dependencies as
# :feature_column, plus :feature_column itself.
pytype_library(
    name = "feature_column_v2",
    srcs = ["feature_column_v2.py"],
    deps = [
        ":feature_column",
        ":tpu_lib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
    ],
)
523
# Test for :feature_column.
tf_py_test(
    name = "feature_column_test",
    srcs = ["feature_column_test.py"],
    main = "feature_column_test.py",
    deps = [
        ":feature_column",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:parsing_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:variables",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
        "//third_party/py/numpy",
    ],
)
545
# Test for :feature_column_v2; tagged "no_oss" (see the tag comment below).
tf_py_test(
    name = "feature_column_v2_test",
    srcs = ["feature_column_v2_test.py"],
    main = "feature_column_v2_test.py",
    tags = ["no_oss"],  # Due to the usage of keras component.
    deps = [
        ":feature_column_v2",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:parsing_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:variables",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
        "//third_party/py/numpy",
    ],
)
568
# Library target for tpu_embedding_v2_utils.py, which is depended on by the
# v1, v2, base and for-serving embedding targets in this package.
pytype_library(
    name = "tpu_embedding_v2_utils",
    srcs = ["tpu_embedding_v2_utils.py"],
    srcs_version = "PY3",
    deps = [
        ":tpu_lib",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/distribute:device_util",
        "//tensorflow/python/distribute:sharded_variable",
        "//tensorflow/python/tpu/ops",
    ],
)
582
# Library target for tpu_embedding_v2.py; depends on the v2 utilities, :tpu_lib
# and the distribute / saved-model registration targets.
pytype_library(
    name = "tpu_embedding_v2",
    srcs = ["tpu_embedding_v2.py"],
    srcs_version = "PY3",
    deps = [
        ":tpu_embedding_v2_utils",
        ":tpu_lib",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/distribute:device_util",
        "//tensorflow/python/distribute:distribute_utils",
        "//tensorflow/python/distribute:sharded_variable",
        "//tensorflow/python/distribute:tpu_strategy",
        "//tensorflow/python/saved_model/registration",
        "//tensorflow/python/tpu/ops",
    ],
)
599
# Library target for tpu_embedding_base.py, declared through the
# pytype_strict_library macro from //tensorflow:tensorflow.bzl.
pytype_strict_library(
    name = "tpu_embedding_base",
    srcs = ["tpu_embedding_base.py"],
    srcs_version = "PY3",
    deps = [
        ":tpu_embedding_v2_utils",
        "//tensorflow/python:variables",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/trackable:autotrackable",
        "//tensorflow/python/util",
    ],
)
612
# Library target for tpu_embedding_for_serving.py; builds on
# :tpu_embedding_base and the shared v2 utilities.
pytype_strict_library(
    name = "tpu_embedding_for_serving",
    srcs = ["tpu_embedding_for_serving.py"],
    srcs_version = "PY3",
    deps = [
        # Targets from this package.
        ":tpu_embedding_base",
        ":tpu_embedding_v2_utils",
        # Ops.
        "//tensorflow/python:array_ops",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:variables",
        # Distribution strategy.
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:tpu_strategy",
        # Framework, types and utilities.
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/framework:sparse_tensor",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python/types",
        "//tensorflow/python/util",
        "//tensorflow/python/util:tf_export",
    ],
)
636
# Test for :tpu_embedding_for_serving.
tf_py_test(
    name = "tpu_embedding_for_serving_test",
    srcs = ["tpu_embedding_for_serving_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":tpu_embedding_for_serving",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//third_party/py/numpy",
    ],
)
652
# Library target for tpu_embedding_v1.py; builds on :tpu_embedding_base, the
# shared v2 utilities and :tpu_py.
pytype_strict_library(
    name = "tpu_embedding_v1",
    srcs = ["tpu_embedding_v1.py"],
    srcs_version = "PY3",
    deps = [
        # Targets from this package.
        ":tpu_embedding_base",
        ":tpu_embedding_v2_utils",
        ":tpu_py",
        # Ops.
        "//tensorflow/python:array_ops",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:variables",
        # Distribution strategy.
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:tpu_strategy",
        # Framework and utilities.
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/framework:sparse_tensor",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python/util",
        "//tensorflow/python/util:tf_export",
    ],
)
676
# Test for tpu_embedding_v2_utils_test.py. It depends on :tpu_embedding_v2,
# which in turn depends on :tpu_embedding_v2_utils.
tf_py_test(
    name = "tpu_embedding_v2_utils_test",
    srcs = ["tpu_embedding_v2_utils_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":tpu_embedding_v2",
        "//tensorflow/python/compat:v2_compat",
    ],
)
689
# Test for tpu_outside_compilation_test.py, declared through the tpu_py_test
# macro from //tensorflow/python/tpu:tpu.bzl and tagged "no_oss".
# NOTE(review): the exact meaning of disable_experimental / disable_mlir_bridge
# is defined by that macro - confirm there before changing either flag.
tpu_py_test(
    name = "tpu_outside_compilation_test",
    srcs = ["tpu_outside_compilation_test.py"],
    disable_experimental = True,
    disable_mlir_bridge = False,
    python_version = "PY3",
    tags = ["no_oss"],
    deps = [
        ":tpu_lib",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:remote",
        "//tensorflow/python/eager:test",
    ],
)
707
# NOTE: this target should only be depended on by the tpu_test_wrapper macro.
# It has no sources of its own and only forwards client_testlib.
py_library(
    name = "tpu_test_deps",
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/python:client_testlib",
    ],
)
715
# Proto library for tensor_tracer.proto, declared through the tf_proto_library
# macro from //tensorflow/core/platform:build_config.bzl. Publicly visible.
# :tpu_lib depends on the "tensor_tracer_proto_py" label of this package,
# presumably generated by this macro - verify in build_config.bzl.
tf_proto_library(
    name = "tensor_tracer_proto",
    srcs = ["tensor_tracer.proto"],
    cc_api_version = 2,
    protodeps = ["//tensorflow/core:protos_all"],
    visibility = ["//visibility:public"],
)
725
726# copybara:uncomment_begin(google-only)
727# py_proto_library(
728#     name = "tensor_tracer_py_pb2",
729#     api_version = 2,
730#     visibility = ["//visibility:public"],
731#     deps = [":tensor_tracer_proto"],
732# )
733# copybara:uncomment_end
734