xref: /aosp_15_r20/external/executorch/extension/llm/custom_ops/targets.bzl (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2load(
3    "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
4    "get_vec_preprocessor_flags",
5    "get_vec_deps",
6)
7load(
8    "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
9    "get_compiler_optimization_flags",
10)
11
def define_common_targets():
    """Declares the custom-op build targets common to fbcode and xplat.

    Both the TARGETS file and the BUCK file that live next to this
    targets.bzl are expected to invoke this function.
    """

    # Every suffix produces a ":custom_ops<suffix>" kernel library that depends
    # on the matching "libblas<suffix>" target, and a
    # ":custom_ops_aot_lib<suffix>" library that depends on that kernel library.
    for blas_suffix in ("", "_mkl_noomp"):
        kernel_lib = "custom_ops" + blas_suffix

        runtime.cxx_library(
            name = kernel_lib,
            srcs = [
                "op_fallback.cpp",
                "op_fast_hadamard_transform.cpp",
                "op_sdpa.cpp",
                "op_update_quantized_cache.cpp",
            ],
            exported_headers = [
                "op_fallback.h",
                "op_fast_hadamard_transform.h",
                "op_sdpa.h",
                "op_update_quantized_cache.h",
            ],
            preprocessor_flags = get_vec_preprocessor_flags(),
            exported_deps = [
                "//executorch/runtime/kernel:kernel_includes",
                "//executorch/kernels/portable/cpu:scalar_utils",
                "//executorch/kernels/optimized:libblas" + blas_suffix,
                "//executorch/kernels/optimized:libvec",
                "//executorch/extension/kernel_util:kernel_util",
                "//executorch/extension/parallel:thread_parallel",
                "//executorch/extension/threadpool:threadpool",
            ],
            deps = [
                "//executorch/kernels/portable/cpu/util:reduce_util",
                "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
            ] + get_vec_deps(),
            compiler_flags = [
                "-Wno-missing-prototypes",
                "-Wno-global-constructors",
            ] + get_compiler_optimization_flags(),
            visibility = [
                "//executorch/...",
                "//executorch/extension/llm/custom_ops/...",
                "@EXECUTORCH_CLIENTS",
            ],
            # @lint-ignore BUCKLINT link_whole
            link_whole = True,
            force_static = True,
        )

        # Built against libtorch. Note that op_tile_crop.cpp is compiled
        # directly into this library; it does not go through ":op_tile_crop".
        runtime.cxx_library(
            name = "custom_ops_aot_lib" + blas_suffix,
            srcs = [
                "op_fast_hadamard_transform_aten.cpp",
                "op_sdpa_aot.cpp",
                "op_tile_crop.cpp",
                "op_tile_crop_aot.cpp",
            ],
            headers = ["op_tile_crop.h"],
            compiler_flags = ["-Wno-global-constructors"],
            visibility = [
                "//executorch/...",
                "@EXECUTORCH_CLIENTS",
            ],
            external_deps = ["libtorch"],
            deps = [
                ":" + kernel_lib,
                "//executorch/extension/aten_util:aten_bridge",
            ],
        )

    # Python library wrapping sdpa_with_kv_cache.py.
    runtime.python_library(
        name = "custom_ops_aot_py",
        srcs = ["sdpa_with_kv_cache.py"],
        visibility = [
            "//executorch/...",
            "@EXECUTORCH_CLIENTS",
        ],
        deps = ["//caffe2:torch"],
    )

    # Unit tests for the sdpa ops: each target compiles "<name>.cpp" and links
    # the non-MKL ":custom_ops" library.
    for sdpa_test in ("op_sdpa_test", "op_sdpa_with_kv_cache_test"):
        runtime.cxx_test(
            name = sdpa_test,
            srcs = [sdpa_test + ".cpp"],
            visibility = ["//executorch/..."],
            deps = [
                "//executorch/runtime/core/exec_aten:lib",
                "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
                "//executorch/kernels/test:test_util",
                ":custom_ops",
            ],
        )

    # Single-source Python libraries (preprocessing and model sharding); each
    # one depends only on //caffe2:torch.
    torch_only_py_libs = [
        ("preprocess_custom_ops_py", "preprocess_custom_ops.py"),
        ("model_sharding_py", "model_sharding.py"),
    ]
    for lib_name, lib_src in torch_only_py_libs:
        runtime.python_library(
            name = lib_name,
            srcs = [lib_src],
            visibility = [
                "//executorch/...",
                "@EXECUTORCH_CLIENTS",
            ],
            deps = ["//caffe2:torch"],
        )

    # Standalone tile-crop op library and its unit test.
    runtime.cxx_library(
        name = "op_tile_crop",
        srcs = ["op_tile_crop.cpp"],
        exported_headers = ["op_tile_crop.h"],
        exported_deps = [
            "//executorch/runtime/kernel:kernel_includes",
            "//executorch/extension/kernel_util:kernel_util",
        ],
        compiler_flags = [
            "-Wno-missing-prototypes",
            "-Wno-global-constructors",
        ],
        visibility = [
            "//executorch/...",
            "@EXECUTORCH_CLIENTS",
        ],
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        force_static = True,
    )

    runtime.cxx_test(
        name = "op_tile_crop_test",
        srcs = ["op_tile_crop_test.cpp"],
        visibility = ["//executorch/..."],
        deps = [
            "//executorch/runtime/core/exec_aten:lib",
            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
            "//executorch/kernels/test:test_util",
            ":op_tile_crop",
        ],
    )
183