xref: /aosp_15_r20/external/tensorflow/tensorflow/python/grappler/BUILD (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# buildifier: disable=same-origin-load
2load("//tensorflow:tensorflow.bzl", "cuda_py_test")
3
4# buildifier: disable=same-origin-load
5load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
6
7# buildifier: disable=same-origin-load
8load("//tensorflow:tensorflow.bzl", "tf_py_test")
9
10# buildifier: disable=same-origin-load
11load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
12
13# buildifier: disable=same-origin-load
14load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
15load("//tensorflow/core/platform:build_config.bzl", "tf_protos_grappler")
16load("//tensorflow:tensorflow.bzl", "if_not_windows")
17
# Package-wide defaults: targets in this file are visible to
# //tensorflow:internal unless a rule sets its own `visibility`.
package(
    default_visibility = ["//tensorflow:internal"],
    licenses = ["notice"],
)
22
# TODO(gunan): Investigate making this action hermetic so we do not need
# to run it locally.
# NOTE(review): the TODO above talks about an "action", but the rule below is
# a plain cc_library; the note may be stale -- confirm before acting on it.
#
# C++ library built from cost_analyzer.cc / cost_analyzer.h. It is wrapped by
# :cost_analyzer_headers for use from the pybind extension further down.
cc_library(
    name = "cost_analyzer_lib",
    srcs = ["cost_analyzer.cc"],
    hdrs = ["cost_analyzer.h"],
    compatible_with = get_compatible_with_cloud(),
    # Labels are kept in sorted order, matching the other targets in this file.
    deps = [
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler/clusters:cluster",
        "//tensorflow/core/grappler/costs:analytical_cost_estimator",
        "//tensorflow/core/grappler/costs:cost_estimator",
        "//tensorflow/core/grappler/costs:measuring_cost_estimator",
        "//tensorflow/core/grappler/costs:utils",
    ] + tf_protos_grappler(),
    alwayslink = 1,
)
42
# Wrapper over :cost_analyzer_lib that the pywrap extension below depends on.
# It has to stay a separate target: combining the two does not work properly.
tf_pybind_cc_library_wrapper(
    name = "cost_analyzer_headers",
    deps = [":cost_analyzer_lib"],
)
51
# pybind extension compiled from cost_analyzer_wrapper.cc; consumed by the
# :cost_analyzer py_library.
tf_python_pybind_extension(
    name = "_pywrap_cost_analyzer",
    srcs = ["cost_analyzer_wrapper.cc"],
    hdrs = [
        # Header from this package.
        "cost_analyzer.h",
        # Header groups and individual headers exported by other packages.
        "//tensorflow/cc:pywrap_required_hdrs",
        "//tensorflow/core/grappler:pywrap_required_hdrs",
        "//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
        "//tensorflow/core/grappler/costs:pywrap_required_hdrs",
        "//tensorflow/core/public:session.h",
        "//tensorflow/core/public:session_options.h",
    ],
    deps = [
        # The analyzer itself is reached through the wrapper target.
        ":cost_analyzer_headers",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:lib_headers_for_pybind",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
        "//tensorflow/core/common_runtime/gpu:gpu_id",
        "//tensorflow/python/lib/core:pybind11_status",
        "@pybind11",
    ],
)
75
# C++ library built from model_analyzer.cc / model_analyzer.h.
cc_library(
    name = "model_analyzer_lib",
    srcs = ["model_analyzer.cc"],
    hdrs = ["model_analyzer.h"],
    deps = [
        # Core TensorFlow targets.
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        # Grappler targets.
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler/costs:graph_properties",
    ],
)
88
# pybind extension compiled from model_analyzer_wrapper.cc; consumed by the
# :model_analyzer py_library.
tf_python_pybind_extension(
    name = "_pywrap_model_analyzer",
    srcs = ["model_analyzer_wrapper.cc"],
    hdrs = [
        # Header from this package.
        "model_analyzer.h",
        # Header group exported by the grappler package.
        "//tensorflow/core/grappler:pywrap_required_hdrs",
    ],
    deps = [
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:lib_headers_for_pybind",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/python/lib/core:pybind11_status",
        "@pybind11",
    ],
)
104
# Publicly visible Python library for item.py, backed by the :_pywrap_tf_item
# extension.
py_library(
    name = "tf_item",
    srcs = ["item.py"],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":_pywrap_tf_item",
        "//tensorflow/core/grappler/costs:op_performance_data_py",
    ],
)
117
# pybind extension compiled from item_wrapper.cc; consumed by the :tf_item
# py_library.
tf_python_pybind_extension(
    name = "_pywrap_tf_item",
    srcs = ["item_wrapper.cc"],
    hdrs = [
        "//tensorflow/cc:pywrap_required_hdrs",
        "//tensorflow/core/grappler:pywrap_required_hdrs",
        "//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
        "//tensorflow/core/grappler/costs:pywrap_required_hdrs",
        "//tensorflow/core/grappler/utils:pywrap_required_hdrs",
    ],
    # Labels are kept in sorted order, matching the other pybind extensions in
    # this file. graph_properties is only added on non-Windows builds.
    deps = [
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
        "//tensorflow/core/common_runtime/gpu:gpu_id",
        "//tensorflow/python/lib/core:pybind11_status",
        "@pybind11",
    ] + if_not_windows(["//tensorflow/core/grappler/costs:graph_properties"]),  # b/148556093
)
137
# Runs item_test.py against :tf_item.
tf_py_test(
    name = "item_test",
    size = "small",
    srcs = ["item_test.py"],
    python_version = "PY3",
    tags = [
        "grappler",
        "no_pip",  # tf_optimizer is not available in pip.
    ],
    deps = [
        # Target under test.
        ":tf_item",
        # Supporting TensorFlow targets.
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/framework:for_generated_wrappers",
    ],
)
157
# Runs datasets_test.py; depends on :tf_item and //tensorflow/python/data.
tf_py_test(
    name = "datasets_test",
    size = "small",
    srcs = ["datasets_test.py"],
    python_version = "PY3",
    tags = [
        "grappler",
        "no_pip",  # tf_optimizer is not available in pip.
    ],
    deps = [
        # Local target.
        ":tf_item",
        # Supporting TensorFlow targets.
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/data",
        "//tensorflow/python/framework:combinations",
        "//tensorflow/python/framework:for_generated_wrappers",
    ],
)
179
# Publicly visible Python library for cluster.py, backed by the
# :_pywrap_tf_cluster extension.
py_library(
    name = "tf_cluster",
    srcs = ["cluster.py"],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":_pywrap_tf_cluster",
        "//tensorflow/core/grappler/costs:op_performance_data_py",
    ],
)
192
# pybind extension compiled from cluster_wrapper.cc; consumed by the
# :tf_cluster py_library.
tf_python_pybind_extension(
    name = "_pywrap_tf_cluster",
    srcs = ["cluster_wrapper.cc"],
    hdrs = [
        # Header groups exported by other packages.
        "//tensorflow/cc:pywrap_required_hdrs",
        "//tensorflow/core/grappler:pywrap_required_hdrs",
        "//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
        "//tensorflow/core/grappler/costs:pywrap_required_hdrs",
        "//tensorflow/core/grappler/utils:pywrap_required_hdrs",
    ],
    deps = [
        # In-repo TensorFlow targets.
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:lib_headers_for_pybind",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
        "//tensorflow/core/common_runtime/gpu:gpu_id",
        "//tensorflow/python/lib/core:pybind11_status",
        # External repositories.
        "@com_google_absl//absl/types:span",
        "@pybind11",
    ],
)
214
# Runs cluster_test.py against :tf_cluster and :tf_item, split into 10 shards.
cuda_py_test(
    name = "cluster_test",
    size = "small",
    srcs = ["cluster_test.py"],
    python_version = "PY3",
    shard_count = 10,
    tags = [
        "grappler",
        "no_pip",  # tf_optimizer is not available in pip.
        "no_windows",  # b/173520599
        "notap",  # TODO(b/135924227): Re-enable after fixing flakiness.
    ],
    # This test will not run on XLA because it primarily tests the TF Classic flow.
    xla_enable_strict_auto_jit = False,
    deps = [
        # Targets under test.
        ":tf_cluster",
        ":tf_item",
        # Supporting TensorFlow targets.
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/framework:for_generated_wrappers",
    ],
)
239
# Publicly visible Python library for tf_optimizer.py, backed by the
# :_pywrap_tf_optimizer extension and :tf_cluster.
py_library(
    name = "tf_optimizer",
    srcs = ["tf_optimizer.py"],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":_pywrap_tf_optimizer",
        ":tf_cluster",
    ],
)
252
# pybind extension compiled from tf_optimizer_wrapper.cc; consumed by the
# :tf_optimizer py_library.
tf_python_pybind_extension(
    name = "_pywrap_tf_optimizer",
    srcs = ["tf_optimizer_wrapper.cc"],
    hdrs = [
        # Header groups exported by other packages.
        "//tensorflow/cc:pywrap_required_hdrs",
        "//tensorflow/core/grappler:pywrap_required_hdrs",
        "//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
        "//tensorflow/core/grappler/costs:pywrap_required_hdrs",
        "//tensorflow/core/grappler/optimizers:pywrap_required_hdrs",
        "//tensorflow/core/grappler/verifiers:pywrap_required_hdrs",
    ],
    deps = [
        # In-repo TensorFlow targets.
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:lib_headers_for_pybind",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
        "//tensorflow/core/common_runtime/gpu:gpu_id",
        "//tensorflow/python/lib/core:pybind11_status",
        # External repository.
        "@pybind11",
    ],
)
274
# Runs tf_optimizer_test.py against :tf_optimizer and :tf_item.
tf_py_test(
    name = "tf_optimizer_test",
    size = "small",
    srcs = ["tf_optimizer_test.py"],
    python_version = "PY3",
    tags = [
        "grappler",
        "no_pip",  # tf_optimizer is not available in pip.
    ],
    deps = [
        # Targets under test.
        ":tf_item",
        ":tf_optimizer",
        # Supporting targets.
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/framework:for_generated_wrappers",
        "//third_party/py/numpy",
    ],
)
296
# Runs memory_optimizer_test.py; depends on :tf_optimizer.
tf_py_test(
    name = "memory_optimizer_test",
    size = "medium",
    srcs = ["memory_optimizer_test.py"],
    python_version = "PY3",
    tags = ["grappler"],
    deps = [
        # Local target.
        ":tf_optimizer",
        # Supporting targets.
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:session",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/framework:for_generated_wrappers",
        "//tensorflow/python/framework:random_seed",
        "//third_party/py/numpy",
    ],
)
322
# Runs constant_folding_test.py. No target from this package is listed in
# deps; only TensorFlow Python/proto targets and numpy are.
cuda_py_test(
    name = "constant_folding_test",
    size = "medium",
    srcs = ["constant_folding_test.py"],
    python_version = "PY3",
    tags = ["grappler"],
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:functional_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:ops",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:for_generated_wrappers",
        "//third_party/py/numpy",
    ],
)
346
# Runs arithmetic_optimizer_test.py with XLA strict auto-JIT turned off.
cuda_py_test(
    name = "arithmetic_optimizer_test",
    size = "small",
    srcs = ["arithmetic_optimizer_test.py"],
    python_version = "PY3",
    tags = ["grappler"],
    xla_enable_strict_auto_jit = False,
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/framework:for_generated_wrappers",
        "//third_party/py/numpy",
    ],
)
367
368# TODO(b/131764887) Remove once LayoutOptimizer is swapped out with GenericLayoutOptimizer.
369#
370# cuda_py_test(
371#     name = "layout_optimizer_test",
372#     size = "medium",
373#     srcs = [
374#         "layout_optimizer_test.py",
375#     ],
376#     deps = [
377#         "//tensorflow/python:client_testlib",
378#         "//tensorflow/python/framework:for_generated_wrappers",
379#         "//tensorflow/python:array_ops",
380#         "//tensorflow/python:functional_ops",
381#         "//tensorflow/python:math_ops",
382#         "//tensorflow/python:nn",
383#         "//tensorflow/python:ops",
384#         "//tensorflow/python:random_ops",
385#         "//tensorflow/python:state_ops",
386#         ":tf_cluster",
387#         ":tf_optimizer",
388#         "//tensorflow/python:training",
389#         "//third_party/py/numpy",
390#         "//tensorflow/core:protos_all_py",
391#         "//tensorflow/python/framework:constant_op",
392#         "//tensorflow/python/framework:dtypes",
393#     ],
394#     shard_count = 10,
395#     tags = [
396#         "grappler",
397#     ],
398#     # This test will not run on XLA because it primarily tests the TF Classic flow.
399#     xla_enable_strict_auto_jit = False,
400# )
401
# Python library for cost_analyzer.py, backed by the :_pywrap_cost_analyzer
# extension plus :tf_cluster and :tf_item. Uses the package default visibility.
py_library(
    name = "cost_analyzer",
    srcs = ["cost_analyzer.py"],
    srcs_version = "PY3",
    deps = [
        ":_pywrap_cost_analyzer",
        ":tf_cluster",
        ":tf_item",
    ],
)
414
# Python binary whose entry point is cost_analyzer_tool.py.
py_binary(
    name = "cost_analyzer_tool",
    srcs = ["cost_analyzer_tool.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        # Local libraries.
        ":cost_analyzer",
        ":tf_optimizer",
        # Supporting TensorFlow targets.
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python/framework:for_generated_wrappers",
    ],
)
429
# Runs cost_analyzer_test.py against :cost_analyzer.
tf_py_test(
    name = "cost_analyzer_test",
    size = "small",
    srcs = [
        "cost_analyzer_test.py",
    ],
    python_version = "PY3",
    tags = [
        "grappler",
        "no_cuda_on_cpu_tap",
        "no_mac",
        "no_pip",
        "no_windows",  # TODO(b/151942037)
    ],
    deps = [
        # Target under test.
        ":cost_analyzer",
        # Supporting targets.
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:nn_grad",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variables",
        "//tensorflow/python/framework:for_generated_wrappers",
        "//third_party/py/numpy",
    ],
)
458
# Python library for model_analyzer.py, backed by the :_pywrap_model_analyzer
# extension. Uses the package default visibility.
py_library(
    name = "model_analyzer",
    srcs = ["model_analyzer.py"],
    srcs_version = "PY3",
    deps = [
        ":_pywrap_model_analyzer",
    ],
)
467
# Runs model_analyzer_test.py against :model_analyzer.
tf_py_test(
    name = "model_analyzer_test",
    size = "small",
    srcs = ["model_analyzer_test.py"],
    # Pinned explicitly, like every other Python test target in this file.
    python_version = "PY3",
    tags = [
        "grappler",
        "no_pip",
    ],
    deps = [
        ":model_analyzer",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:state_ops",
        "//tensorflow/python/framework:for_generated_wrappers",
        "//third_party/py/numpy",
    ],
)
487
# Runs auto_mixed_precision_test.py with XLA strict auto-JIT turned off.
cuda_py_test(
    name = "auto_mixed_precision_test",
    size = "medium",
    srcs = ["auto_mixed_precision_test.py"],
    python_version = "PY3",
    tags = [
        "grappler",
    ],
    # This test analyzes the graph, but XLA changes the names of nodes.
    xla_enable_strict_auto_jit = False,
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:training",
        "//tensorflow/python/framework:constant_op",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:for_generated_wrappers",
        "//third_party/py/numpy",
    ],
)
514
# Runs remapper_test.py with XLA strict auto-JIT turned off.
cuda_py_test(
    name = "remapper_test",
    size = "medium",
    srcs = ["remapper_test.py"],
    python_version = "PY3",
    tags = [
        "grappler",
    ],
    # This test analyzes the graph, but XLA changes the names of nodes.
    xla_enable_strict_auto_jit = False,
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:training",
        "//tensorflow/python/framework:constant_op",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:for_generated_wrappers",
        "//third_party/py/numpy",
    ],
)
541
# pybind extension compiled from graph_analyzer_tool_wrapper.cc; consumed by
# the :graph_analyzer py_binary.
tf_python_pybind_extension(
    name = "_pywrap_graph_analyzer",
    srcs = [
        "graph_analyzer_tool_wrapper.cc",
    ],
    deps = [
        "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool",
        "@pybind11",
    ],
)
550
# Python binary whose entry point is graph_analyzer.py, backed by the
# :_pywrap_graph_analyzer extension.
py_binary(
    name = "graph_analyzer",
    srcs = ["graph_analyzer.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":_pywrap_graph_analyzer",
        "//tensorflow/python/framework:for_generated_wrappers",
    ],
)
563