"""Bazel helpers for ATen: per-CPU-capability kernel libraries and the ATen source generator rule."""

load("@bazel_skylib//lib:paths.bzl", "paths")
load("@rules_cc//cc:defs.bzl", "cc_library")

# CPU capabilities for which a separate kernel library is produced.
CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"]

# Additional compiler flags applied when compiling for each capability.
CAPABILITY_COMPILER_FLAGS = {
    "AVX2": ["-mavx2", "-mfma", "-mf16c"],
    "DEFAULT": [],
}

# Path prefix of the globbed native kernel sources.
PREFIX = "aten/src/ATen/native/"

# Path prefix assumed for the caller-supplied `extra_impls` sources.
EXTRA_PREFIX = "aten/src/ATen/"

def intern_build_aten_ops(copts, deps, extra_impls):
    """Declares one ATen CPU kernel library per CPU capability plus an umbrella library.

    For every entry of CPU_CAPABILITY_NAMES, each kernel source is copied by a
    genrule to `<prefix><relative path>.<capability>.cpp`, and the copies are
    compiled into `ATen_CPU_<capability>` with the capability's defines and
    compiler flags. Finally `ATen_CPU` is declared, depending on all of the
    per-capability libraries.

    Args:
        copts: base compiler options, extended per capability.
        deps: dependencies passed to every per-capability library.
        extra_impls: additional sources (located via EXTRA_PREFIX) to build
            for every capability alongside the globbed kernels.
    """
    for capability in CPU_CAPABILITY_NAMES:
        # Globbed native kernels first, then the caller-supplied extras; each
        # group is paired with the prefix used to derive target/output names.
        source_groups = [
            (
                PREFIX,
                native.glob(
                    [
                        PREFIX + "cpu/*.cpp",
                        PREFIX + "quantized/cpu/kernels/*.cpp",
                    ],
                ),
            ),
            (EXTRA_PREFIX, extra_impls),
        ]

        capability_srcs = []
        for prefix, sources in source_groups:
            for source in sources:
                relative = source.replace(prefix, "")
                copied = prefix + relative + "." + capability + ".cpp"

                # A distinct copy per capability lets the same source be
                # compiled once for each set of flags.
                native.genrule(
                    name = relative + "_" + capability + "_cp",
                    srcs = [source],
                    outs = [copied],
                    cmd = "cp $< $@",
                )
                capability_srcs.append(copied)

        capability_defines = [
            "-DCPU_CAPABILITY=" + capability,
            "-DCPU_CAPABILITY_" + capability,
        ]
        cc_library(
            name = "ATen_CPU_" + capability,
            srcs = capability_srcs,
            copts = copts + capability_defines + CAPABILITY_COMPILER_FLAGS[capability],
            deps = deps,
            linkstatic = 1,
        )

    # Umbrella target aggregating every per-capability library.
    cc_library(
        name = "ATen_CPU",
        deps = [":ATen_CPU_" + capability for capability in CPU_CAPABILITY_NAMES],
        linkstatic = 1,
    )

def generate_aten_impl(ctx):
    """Implementation of the `generate_aten` rule.

    Runs the `generator` executable once, with `srcs` as inputs, producing
    the declared `outs` together with the `aten/src/ATen/ops` directory.

    Args:
        ctx: the rule context.

    Returns:
        A list holding a DefaultInfo whose files are all of the action outputs.
    """

    # The whole ATen/ops/ directory is declared as a single directory output.
    ops_directory = ctx.actions.declare_directory("aten/src/ATen/ops")
    all_outputs = [ops_directory] + ctx.outputs.outs

    # The generator is pointed at the parent of the ops directory.
    generator_args = [
        "--source-path",
        "aten/src/ATen",
        "--per-operator-headers",
        "--install_dir",
        paths.dirname(ops_directory.path),
    ]

    ctx.actions.run(
        outputs = all_outputs,
        inputs = ctx.files.srcs,
        executable = ctx.executable.generator,
        arguments = generator_args,
        use_default_shell_env = True,
        mnemonic = "GenerateAten",
    )
    return [DefaultInfo(files = depset(all_outputs))]

# Rule that invokes `generator` over `srcs` to produce `outs` and the ops directory.
generate_aten = rule(
    implementation = generate_aten_impl,
    attrs = {
        "generator": attr.label(
            executable = True,
            allow_files = True,
            mandatory = True,
            cfg = "exec",
        ),
        "outs": attr.output_list(),
        "srcs": attr.label_list(allow_files = True),
    },
)