xref: /aosp_15_r20/external/executorch/shim/xplat/executorch/kernels/test/util.bzl (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "is_xplat", "runtime")
2
def op_test(name, deps = [], kernel_name = "portable", use_kernel_prefix = False):
    """Defines a cxx_test() for an "op_*_test.cpp" file.

    Args:
        name: "op_<operator-group-name>_test"; e.g., "op_add_test". Must match
            the non-extension part of the test source file (e.g.,
            "op_add_test.cpp"). This name must also agree with the target names
            under //kernels/<kernel>/...; e.g., "op_add_test" will depend on
            "//kernels/portable/cpu:op_add". The operator-group part must be
            non-empty.
        deps: Optional extra deps to add to the cxx_test().
        kernel_name: The name string as in //executorch/kernels/<kernel_name>.
            "aten" is special-cased: the test links the whole ATen generated
            lib and uses the "_aten" variants of the runtime/test utilities.
        use_kernel_prefix: If True, the target name is
            <kernel>_op_<operator-group-name>_test. Used by common kernel testing.
            Ignored when kernel_name is "aten", which always gets the "aten_"
            prefix.
    """
    prefix = "op_"
    suffix = "_test"

    # Require a non-empty operator-group name between prefix and suffix; a bare
    # "op_test" would otherwise pass because "op_" and "_test" overlap.
    if not (name.startswith(prefix) and name.endswith(suffix) and
            len(name) > len(prefix) + len(suffix)):
        fail("'{}' must match the pattern 'op_*_test'".format(name))
    op_root = name[:-len(suffix)]  # E.g., "op_add" if name is "op_add_test".

    if kernel_name == "aten":
        # ATen depends on the whole generated lib rather than a per-op target.
        generated_lib_and_op_deps = [
            "//executorch/kernels/aten:generated_lib",
            #TODO(T187390274): consolidate all aten ops into one target
            "//executorch/kernels/aten/cpu:op__to_dim_order_copy_aten",
            "//executorch/kernels/aten:generated_lib_headers",
            "//executorch/kernels/test:supported_features_aten",
        ]
    else:
        generated_lib_and_op_deps = [
            "//executorch/kernels/{}/cpu:{}".format(kernel_name, op_root),
            "//executorch/kernels/{}:generated_lib_headers".format(kernel_name),
            "//executorch/kernels/{}/test:supported_features".format(kernel_name),
        ]

    name_prefix = ""
    aten_suffix = ""
    if kernel_name == "aten":
        # For aten kernel, we need to use aten specific utils and types
        name_prefix = "aten_"
        aten_suffix = "_aten"
    elif use_kernel_prefix:
        name_prefix = kernel_name + "_"
    runtime.cxx_test(
        name = name_prefix + name,
        srcs = [
            "{}.cpp".format(name),
        ],
        visibility = ["//executorch/kernels/..."],
        deps = [
            "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
            "//executorch/runtime/core/exec_aten/testing_util:tensor_util" + aten_suffix,
            "//executorch/runtime/kernel:kernel_includes" + aten_suffix,
            "//executorch/kernels/test:test_util" + aten_suffix,
        ] + generated_lib_and_op_deps + deps,
    )
57
def generated_op_test(name, op_impl_target, generated_lib_headers_target, supported_features_target, function_header_wrapper_target, deps = []):
    """Defines a cxx_test() that runs the shared op test cases against an external kernel.

    Lets an aten-compliant op from a kernel outside of executorch/ re-use the
    test cases defined here, so the external kernel can be compared with
    portable. The test source is taken from the generated
    //executorch/kernels/test:test_srcs_gen outputs (the fbsource-xplat
    qualified path is used when is_xplat() is true).

    Args:
        name: "op_<operator-group-name>_test"; e.g., "op_add_test". Also selects
            the generated source file "<name>.cpp".
        op_impl_target: The kernel implementation under test
            (e.g. executorch/kernels/portable/cpu:op_add).
        generated_lib_headers_target: Required for dispatching the op to the
            specific kernel (e.g. executorch/kernels/portable:generated_lib_headers).
        supported_features_target: Tells the tests which features the kernel
            supports, so unsupported tests are bypassed
            (e.g. executorch/kernels/portable/test:supported_features).
        function_header_wrapper_target: Provides a header wrapper for
            Functions.h (e.g. executorch/kernels/portable/test:function_header_wrapper_portable).
            Use codegen_function_header_wrapper() to generate it.
        deps: Additional deps.
    """
    if is_xplat():
        test_src = "fbsource//xplat/executorch/kernels/test:test_srcs_gen[{}.cpp]".format(name)
    else:
        test_src = "//executorch/kernels/test:test_srcs_gen[{}.cpp]".format(name)

    # Targets that bind the shared test cases to one particular kernel.
    kernel_specific_deps = [
        op_impl_target,
        generated_lib_headers_target,
        supported_features_target,
        function_header_wrapper_target,
    ]

    common_deps = [
        "//executorch/runtime/core/exec_aten:lib",
        "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
        "//executorch/runtime/kernel:kernel_includes",
        "//executorch/kernels/test:test_util",
    ]

    runtime.cxx_test(
        name = name,
        srcs = [test_src],
        deps = common_deps + kernel_specific_deps + deps,
    )
95
def define_supported_features_lib():
    """Defines the "supported_features" C++ library from supported_features_def.yaml.

    Creates two targets:
      - "supported_feature_gen": a genrule that runs
        //executorch/kernels/test:gen_supported_features over
        supported_features_def.yaml and writes supported_features.cpp.
      - "supported_features": a cxx_library compiling that generated file and
        re-exporting //executorch/kernels/test:supported_features_header,
        visible to //executorch/kernels/....
    """
    gen_rule_name = "supported_feature_gen"
    generated_cpp = "supported_features.cpp"

    # The generator's stdout becomes the .cpp file.
    runtime.genrule(
        name = gen_rule_name,
        cmd = "$(exe //executorch/kernels/test:gen_supported_features) ${SRCS} > $OUT/" + generated_cpp,
        srcs = ["supported_features_def.yaml"],
        outs = {generated_cpp: [generated_cpp]},
        default_outs = ["."],
    )

    runtime.cxx_library(
        name = "supported_features",
        srcs = [":{}[{}]".format(gen_rule_name, generated_cpp)],
        visibility = ["//executorch/kernels/..."],
        exported_deps = ["//executorch/kernels/test:supported_features_header"],
    )
115
def codegen_function_header_wrapper(kernel_path, kernel_name):
    """Generates FunctionHeaderWrapper.h, a one-line header including <kernel_path>/Functions.h.

    Two targets are created:
      - "gen_function_header_wrapper_<kernel_name>": genrule writing the header.
      - "function_header_wrapper_<kernel_name>": cxx_library exporting it as
        "FunctionHeaderWrapper.h". Tests should depend on this one.

    Generate the wrapper for each kernel except aten, which can use portable's
    wrapper because ATen uses portable's functions.yaml.

    Args:
        kernel_path: Include-path prefix of the kernel's Functions.h; the
            generated line is `#include <{kernel_path}/Functions.h>`
            (presumably e.g. "executorch/kernels/portable" — confirm with callers).
        kernel_name: Suffix used to name both generated targets.
    """
    gen_rule_name = "gen_function_header_wrapper_{}".format(kernel_name)

    # The include line is double-quoted so the shell does not treat `<`/`>`
    # as redirections.
    runtime.genrule(
        name = gen_rule_name,
        cmd = "echo \"#include <{}/Functions.h>\" > $OUT/FunctionHeaderWrapper.h".format(kernel_path),
        outs = {"FunctionHeaderWrapper.h": ["FunctionHeaderWrapper.h"]},
        default_outs = ["."],
    )

    runtime.cxx_library(
        name = "function_header_wrapper_{}".format(kernel_name),
        exported_headers = {
            "FunctionHeaderWrapper.h": ":{}[FunctionHeaderWrapper.h]".format(gen_rule_name),
        },
        # TODO(T149423767): So far we have to expose this to users. Ideally this part can also be codegen.
        _is_external_target = True,
        visibility = ["//executorch/...", "//pye/..."],
    )
143