# xref: /aosp_15_r20/external/pytorch/aten.bzl (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1load("@bazel_skylib//lib:paths.bzl", "paths")
2load("@rules_cc//cc:defs.bzl", "cc_library")
3
# CPU instruction-set variants; intern_build_aten_ops builds one
# ATen_CPU_<NAME> library per entry.
CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"]

# Extra compiler flags appended for each capability in CPU_CAPABILITY_NAMES.
# DEFAULT adds nothing; AVX2 enables the AVX2, FMA and F16C instruction sets.
CAPABILITY_COMPILER_FLAGS = {
    "AVX2": ["-mavx2", "-mfma", "-mf16c"],
    "DEFAULT": [],
}

# Root of the ATen sources; stripped from extra implementation paths.
EXTRA_PREFIX = "aten/src/ATen/"

# Directory containing the native kernel sources that are globbed and
# compiled once per capability.
PREFIX = EXTRA_PREFIX + "native/"

def _copy_for_capability(impl, prefix, cpu_capability):
    """Declare a genrule that copies `impl` to a capability-suffixed .cpp file.

    Args:
      impl: path of the source file to copy.
      prefix: leading path prefix stripped from `impl` to form the genrule
        name, and prepended again to form the output path.
      cpu_capability: capability name appended to both the output file name
        and the genrule name.

    Returns:
      The path of the copied file, `<prefix><rel>.<cpu_capability>.cpp`.
    """

    # Strip `prefix` only when it actually leads the path; a plain
    # `replace(prefix, "")` would also remove it from the middle of a path.
    rel = impl[len(prefix):] if impl.startswith(prefix) else impl
    out = prefix + rel + "." + cpu_capability + ".cpp"
    native.genrule(
        name = rel + "_" + cpu_capability + "_cp",
        srcs = [impl],
        outs = [out],
        # Verbatim copy: the content is unchanged, only the file name differs.
        cmd = "cp $< $@",
    )
    return out

def intern_build_aten_ops(copts, deps, extra_impls):
    """Build the per-CPU-capability ATen kernel libraries.

    For every capability in CPU_CAPABILITY_NAMES, each native kernel source
    (PREFIX + "cpu/*.cpp" and PREFIX + "quantized/cpu/kernels/*.cpp") and each
    entry of `extra_impls` is copied to `<path>.<CAPABILITY>.cpp`. The copies
    are compiled into `ATen_CPU_<CAPABILITY>` with
    `-DCPU_CAPABILITY=<CAPABILITY>`, `-DCPU_CAPABILITY_<CAPABILITY>` and that
    capability's CAPABILITY_COMPILER_FLAGS. A final `ATen_CPU` library depends
    on all of them.

    Args:
      copts: base compiler options shared by every capability library.
      deps: dependencies of every capability library.
      extra_impls: additional source paths, expected to start with
        EXTRA_PREFIX, that are compiled once per capability as well.
    """

    # The glob does not depend on the capability, so evaluate it only once.
    kernel_srcs = native.glob(
        [
            PREFIX + "cpu/*.cpp",
            PREFIX + "quantized/cpu/kernels/*.cpp",
        ],
    )

    for cpu_capability in CPU_CAPABILITY_NAMES:
        # Each capability gets its own distinctly named copy of every source.
        # NOTE(review): presumably this keeps translation units per capability
        # distinct; confirm why the originals are not compiled directly.
        srcs = []
        for impl in kernel_srcs:
            srcs.append(_copy_for_capability(impl, PREFIX, cpu_capability))
        for impl in extra_impls:
            srcs.append(_copy_for_capability(impl, EXTRA_PREFIX, cpu_capability))

        cc_library(
            name = "ATen_CPU_" + cpu_capability,
            srcs = srcs,
            copts = copts + [
                "-DCPU_CAPABILITY=" + cpu_capability,
                "-DCPU_CAPABILITY_" + cpu_capability,
            ] + CAPABILITY_COMPILER_FLAGS[cpu_capability],
            deps = deps,
            linkstatic = 1,
        )

    # Umbrella target depending on every capability-specific library.
    cc_library(
        name = "ATen_CPU",
        deps = [":ATen_CPU_" + cpu_capability for cpu_capability in CPU_CAPABILITY_NAMES],
        linkstatic = 1,
    )

def generate_aten_impl(ctx):
    """Implementation function of the `generate_aten` rule.

    Runs the `generator` executable in a single action. That action declares
    as outputs the tree artifact `aten/src/ATen/ops` plus every file listed in
    the `outs` attribute.

    Args:
      ctx: the rule context.

    Returns:
      A one-element list holding a DefaultInfo whose files are all the
      declared outputs.
    """

    # A single tree artifact stands for the whole ops/ directory; its
    # individual files are not enumerated here.
    ops_tree = ctx.actions.declare_directory("aten/src/ATen/ops")
    declared = [ops_tree] + ctx.outputs.outs

    # The generator installs into the parent of the ops/ tree artifact.
    # NOTE(review): "aten/src/ATen" is a relative source path, presumably
    # resolved against the execution root; confirm for external-repo use.
    generator_args = [
        "--source-path",
        "aten/src/ATen",
        "--per-operator-headers",
        "--install_dir",
        paths.dirname(ops_tree.path),
    ]

    ctx.actions.run(
        executable = ctx.executable.generator,
        arguments = generator_args,
        inputs = ctx.files.srcs,
        outputs = declared,
        mnemonic = "GenerateAten",
        use_default_shell_env = True,
    )
    return [DefaultInfo(files = depset(declared))]

# Rule wrapping generate_aten_impl. The declared outputs are the
# aten/src/ATen/ops tree artifact plus every file listed in `outs`.
generate_aten = rule(
    implementation = generate_aten_impl,
    doc = "Runs the ATen code generator, producing the aten/src/ATen/ops " +
          "directory plus every file listed in `outs`.",
    attrs = {
        "generator": attr.label(
            executable = True,
            allow_files = True,
            mandatory = True,
            # Built for the execution platform because it runs during the build.
            cfg = "exec",
            doc = "Code-generator executable run by the generation action.",
        ),
        "outs": attr.output_list(
            doc = "Additional files written by the generator, declared as outputs of the action.",
        ),
        "srcs": attr.label_list(
            allow_files = True,
            doc = "Files passed as inputs to the generation action.",
        ),
    },
)