# xref: /aosp_15_r20/external/pytorch/BUILD.bazel (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerload("@bazel_skylib//lib:paths.bzl", "paths")
2*da0073e9SAndroid Build Coastguard Workerload("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
3*da0073e9SAndroid Build Coastguard Workerload("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")
4*da0073e9SAndroid Build Coastguard Workerload("@rules_python//python:defs.bzl", "py_library", "py_test")
5*da0073e9SAndroid Build Coastguard Workerload("@pytorch//third_party:substitution.bzl", "header_template_rule", "template_rule")
6*da0073e9SAndroid Build Coastguard Workerload("@pytorch//:tools/bazel.bzl", "rules")
7*da0073e9SAndroid Build Coastguard Workerload("@pytorch//tools/rules:cu.bzl", "cu_library")
8*da0073e9SAndroid Build Coastguard Workerload("@pytorch//tools/config:defs.bzl", "if_cuda")
9*da0073e9SAndroid Build Coastguard Workerload("@pytorch//:aten.bzl", "generate_aten", "intern_build_aten_ops")
10*da0073e9SAndroid Build Coastguard Workerload(":build.bzl", "GENERATED_AUTOGRAD_CPP", "GENERATED_AUTOGRAD_PYTHON", "define_targets")
11*da0073e9SAndroid Build Coastguard Workerload(":build_variables.bzl", "jit_core_sources", "lazy_tensor_ts_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_python_core_sources", "torch_cpp_srcs", "libtorch_python_cuda_sources", "libtorch_python_distributed_sources")
12*da0073e9SAndroid Build Coastguard Workerload(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources")
13*da0073e9SAndroid Build Coastguard Workerload("//:tools/bazel.bzl", "rules")
14*da0073e9SAndroid Build Coastguard Worker
# Invoke the define_targets() macro loaded from :build.bzl, passing it the
# `rules` value loaded from tools/bazel.bzl.
define_targets(rules = rules)
16*da0073e9SAndroid Build Coastguard Worker
# Compiler flags shared by the option lists in this file: ATEN_COPTS and
# CAFFE2_COPTS are both built as COMMON_COPTS + [...].
# The second half is contributed through if_cuda().
COMMON_COPTS = [
    "-DHAVE_MALLOC_USABLE_SIZE=1",
    "-DHAVE_MMAP=1",
    "-DHAVE_SHM_OPEN=1",
    "-DHAVE_SHM_UNLINK=1",
    "-D_FILE_OFFSET_BITS=64",
    "-DUSE_FBGEMM",
    "-DUSE_DISTRIBUTED",
    "-DAT_PER_OPERATOR_HEADERS",
    "-DATEN_THREADING=NATIVE",
    "-DNO_CUDNN_DESTROY_HANDLE",
] + if_cuda([
    "-DUSE_CUDA",
    "-DUSE_CUDNN",
    # TODO: pass the following define only when building for CUDA 11.5 or
    # newer. It is needed to use cub in a safe manner; see
    # https://github.com/pytorch/pytorch/pull/55292
    "-DCUB_WRAPPED_NAMESPACE=at_cuda_detail",
])
36*da0073e9SAndroid Build Coastguard Worker
# Inputs passed as `srcs` to the generate_aten() call below: the two YAML
# files plus everything matched under aten/src/ATen/templates/.
aten_generation_srcs = [
    "aten/src/ATen/native/native_functions.yaml",
    "aten/src/ATen/native/tags.yaml",
] + glob(["aten/src/ATen/templates/**"])
38*da0073e9SAndroid Build Coastguard Worker
# Paths declared as outputs of :generated_aten_cpp and compiled as part of
# the :aten library's srcs. Despite the "_cpp" suffix the list mixes .cpp
# sources and .h headers. Entry order is kept exactly as authored.
generated_cpu_cpp = [
    # Register*.cpp
    "aten/src/ATen/RegisterBackendSelect.cpp",
    "aten/src/ATen/RegisterCPU.cpp",
    "aten/src/ATen/RegisterFunctionalization_0.cpp",
    "aten/src/ATen/RegisterFunctionalization_1.cpp",
    "aten/src/ATen/RegisterFunctionalization_2.cpp",
    "aten/src/ATen/RegisterFunctionalization_3.cpp",
    # Intentionally left disabled:
    # "aten/src/ATen/RegisterFunctionalizationEverything.cpp",
    "aten/src/ATen/RegisterMkldnnCPU.cpp",
    "aten/src/ATen/RegisterNestedTensorCPU.cpp",
    "aten/src/ATen/RegisterQuantizedCPU.cpp",
    "aten/src/ATen/RegisterSparseCPU.cpp",
    "aten/src/ATen/RegisterSparseCsrCPU.cpp",
    "aten/src/ATen/RegisterZeroTensor.cpp",
    "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
    "aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
    "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
    "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
    "aten/src/ATen/RegisterMeta.cpp",
    "aten/src/ATen/RegisterSparseMeta.cpp",
    "aten/src/ATen/RegisterQuantizedMeta.cpp",
    "aten/src/ATen/RegisterNestedTensorMeta.cpp",
    "aten/src/ATen/RegisterSchema.cpp",

    # *Functions.h / *Functions_inl.h
    "aten/src/ATen/CPUFunctions.h",
    "aten/src/ATen/CPUFunctions_inl.h",
    "aten/src/ATen/CompositeExplicitAutogradFunctions.h",
    "aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h",
    "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
    "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
    "aten/src/ATen/CompositeImplicitAutogradFunctions.h",
    "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions.h",
    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h",

    # Remaining top-level ATen files
    "aten/src/ATen/CompositeViewCopyKernels.cpp",
    "aten/src/ATen/FunctionalInverses.h",
    "aten/src/ATen/Functions.h",
    "aten/src/ATen/Functions.cpp",
    "aten/src/ATen/RedispatchFunctions.h",
    "aten/src/ATen/Operators.h",
    "aten/src/ATen/Operators_0.cpp",
    "aten/src/ATen/Operators_1.cpp",
    "aten/src/ATen/Operators_2.cpp",
    "aten/src/ATen/Operators_3.cpp",
    "aten/src/ATen/Operators_4.cpp",
    "aten/src/ATen/NativeFunctions.h",
    "aten/src/ATen/MetaFunctions.h",
    "aten/src/ATen/MetaFunctions_inl.h",
    "aten/src/ATen/MethodOperators.h",
    "aten/src/ATen/NativeMetaFunctions.h",
    "aten/src/ATen/RegistrationDeclarations.h",
    "aten/src/ATen/VmapGeneratedPlumbing.h",

    # Files under aten/src/ATen/core/
    "aten/src/ATen/core/aten_interned_strings.h",
    "aten/src/ATen/core/enum_tag.h",
    "aten/src/ATen/core/TensorBody.h",
    "aten/src/ATen/core/TensorMethods.cpp",
    "aten/src/ATen/core/ATenOpList.cpp",
]
96*da0073e9SAndroid Build Coastguard Worker
# CUDA counterpart of generated_cpu_cpp: declared as outputs of
# :generated_aten_cpp and compiled as srcs of :aten_cuda_cpp.
# Contains headers as well as .cpp files.
generated_cuda_cpp = [
    # Headers
    "aten/src/ATen/CUDAFunctions.h",
    "aten/src/ATen/CUDAFunctions_inl.h",
    # Register*CUDA.cpp
    "aten/src/ATen/RegisterCUDA.cpp",
    "aten/src/ATen/RegisterNestedTensorCUDA.cpp",
    "aten/src/ATen/RegisterQuantizedCUDA.cpp",
    "aten/src/ATen/RegisterSparseCUDA.cpp",
    "aten/src/ATen/RegisterSparseCsrCUDA.cpp",
]
106*da0073e9SAndroid Build Coastguard Worker
# ATen code generation target. `srcs` are the YAML/template inputs, `outs`
# enumerates every file the target declares (CPU list, CUDA list, the three
# ufunc-generated lists and Declarations.yaml), and `generator` names the
# //torchgen:gen tool handed to the generate_aten() macro.
generate_aten(
    name = "generated_aten_cpp",
    srcs = aten_generation_srcs,
    outs = (
        generated_cpu_cpp +
        generated_cuda_cpp +
        aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") +
        aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") +
        aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") +
        ["aten/src/ATen/Declarations.yaml"]
    ),
    generator = "//torchgen:gen",
)
121*da0073e9SAndroid Build Coastguard Worker
# Groups the autograd C++ files listed in GENERATED_AUTOGRAD_CPP (loaded from
# :build.bzl); :generate-code is attached as a data dependency.
filegroup(
    name = "cpp_generated_code",
    srcs = GENERATED_AUTOGRAD_CPP,
    data = [
        ":generate-code",
    ],
)
127*da0073e9SAndroid Build Coastguard Worker
# ATen

# .cpp files located directly in aten/src/ATen and in its functorch/,
# detail/ and cpu/ subdirectories (non-recursive patterns).
filegroup(
    name = "aten_base_cpp",
    srcs = glob(
        include = [
            "aten/src/ATen/*.cpp",
            "aten/src/ATen/functorch/*.cpp",
            "aten/src/ATen/detail/*.cpp",
            "aten/src/ATen/cpu/*.cpp",
        ],
    ),
)
138*da0073e9SAndroid Build Coastguard Worker
# Every .cpp file anywhere below aten/src/ATen/core, minus files whose name
# ends in _test.cpp.
filegroup(
    name = "ATen_CORE_SRCS",
    srcs = glob(
        include = ["aten/src/ATen/core/**/*.cpp"],
        exclude = ["aten/src/ATen/core/**/*_test.cpp"],
    ),
)
150*da0073e9SAndroid Build Coastguard Worker
# .cpp files directly inside aten/src/ATen/native (subdirectories are covered
# by separate filegroups).
filegroup(
    name = "aten_native_cpp",
    srcs = glob([
        "aten/src/ATen/native/*.cpp",
    ]),
)
155*da0073e9SAndroid Build Coastguard Worker
# .cpp files directly inside aten/src/ATen/native/sparse.
filegroup(
    name = "aten_native_sparse_cpp",
    srcs = glob([
        "aten/src/ATen/native/sparse/*.cpp",
    ]),
)
160*da0073e9SAndroid Build Coastguard Worker
# .cpp files directly inside aten/src/ATen/native/nested.
filegroup(
    name = "aten_native_nested_cpp",
    srcs = glob([
        "aten/src/ATen/native/nested/*.cpp",
    ]),
)
165*da0073e9SAndroid Build Coastguard Worker
# .cpp files directly inside aten/src/ATen/native/quantized and its cpu/
# subdirectory.
filegroup(
    name = "aten_native_quantized_cpp",
    srcs = glob([
        "aten/src/ATen/native/quantized/*.cpp",
        "aten/src/ATen/native/quantized/cpu/*.cpp",
    ]),
)
175*da0073e9SAndroid Build Coastguard Worker
# .cpp files directly inside aten/src/ATen/native/transformers.
filegroup(
    name = "aten_native_transformers_cpp",
    srcs = glob([
        "aten/src/ATen/native/transformers/*.cpp",
    ]),
)
180*da0073e9SAndroid Build Coastguard Worker
# MKL-related .cpp files: those directly inside aten/src/ATen/native/mkl and
# aten/src/ATen/mkl.
filegroup(
    name = "aten_native_mkl_cpp",
    srcs = glob(
        include = [
            "aten/src/ATen/native/mkl/*.cpp",
            "aten/src/ATen/mkl/*.cpp",
        ],
    ),
)
188*da0073e9SAndroid Build Coastguard Worker
# .cpp files directly inside aten/src/ATen/native/mkldnn.
filegroup(
    name = "aten_native_mkldnn_cpp",
    srcs = glob([
        "aten/src/ATen/native/mkldnn/*.cpp",
    ]),
)
193*da0073e9SAndroid Build Coastguard Worker
# .cpp files directly inside aten/src/ATen/native/xnnpack.
filegroup(
    name = "aten_native_xnnpack",
    srcs = glob([
        "aten/src/ATen/native/xnnpack/*.cpp",
    ]),
)
198*da0073e9SAndroid Build Coastguard Worker
# .cpp files directly inside aten/src/ATen/vulkan.
filegroup(
    name = "aten_base_vulkan",
    srcs = glob([
        "aten/src/ATen/vulkan/*.cpp",
    ]),
)
203*da0073e9SAndroid Build Coastguard Worker
# .cpp files directly inside aten/src/ATen/metal.
filegroup(
    name = "aten_base_metal",
    srcs = glob([
        "aten/src/ATen/metal/*.cpp",
    ]),
)
208*da0073e9SAndroid Build Coastguard Worker
# Every .cpp file anywhere below aten/src/ATen/quantized, minus files whose
# name ends in _test.cpp.
filegroup(
    name = "ATen_QUANTIZED_SRCS",
    srcs = glob(
        include = ["aten/src/ATen/quantized/**/*.cpp"],
        exclude = ["aten/src/ATen/quantized/**/*_test.cpp"],
    ),
)
220*da0073e9SAndroid Build Coastguard Worker
# Host-side (.cpp) sources of the CUDA-related directories; consumed as srcs
# by the :aten_cuda_cpp library. All patterns are non-recursive.
filegroup(
    name = "aten_cuda_cpp_srcs",
    srcs = glob(
        include = [
            # aten/src/ATen/{cuda,cudnn}
            "aten/src/ATen/cuda/*.cpp",
            "aten/src/ATen/cuda/detail/*.cpp",
            "aten/src/ATen/cuda/tunable/*.cpp",
            "aten/src/ATen/cudnn/*.cpp",
            # aten/src/ATen/native/...
            "aten/src/ATen/native/cuda/*.cpp",
            "aten/src/ATen/native/cuda/linalg/*.cpp",
            "aten/src/ATen/native/cudnn/*.cpp",
            "aten/src/ATen/native/miopen/*.cpp",
            "aten/src/ATen/native/nested/cuda/*.cpp",
            "aten/src/ATen/native/quantized/cuda/*.cpp",
            "aten/src/ATen/native/quantized/cudnn/*.cpp",
            "aten/src/ATen/native/sparse/cuda/*.cpp",
            "aten/src/ATen/native/transformers/cuda/*.cpp",
        ],
    ),
)
241*da0073e9SAndroid Build Coastguard Worker
# CUDA (.cu) sources: checked-in files matched by the globs below, followed
# by the ufunc-generated CUDA sources; consumed as srcs by :aten_cuda.
filegroup(
    name = "aten_cu_srcs",
    # The generated ufunc sources are referenced by path only. It is a bit
    # puzzling why it is not necessary to declare the target that generates
    # these sources...
    srcs = glob(
        include = [
            "aten/src/ATen/cuda/*.cu",
            "aten/src/ATen/cuda/detail/*.cu",
            "aten/src/ATen/native/cuda/*.cu",
            "aten/src/ATen/native/nested/cuda/*.cu",
            "aten/src/ATen/native/quantized/cuda/*.cu",
            "aten/src/ATen/native/sparse/cuda/*.cu",
            "aten/src/ATen/native/transformers/cuda/*.cu",
        ],
    ) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"),
)
256*da0073e9SAndroid Build Coastguard Worker
# Produces aten/src/ATen/Config.h from the Config.h.in template using the
# header_template_rule loaded from third_party/substitution.bzl. Each key in
# `substitutions` is a @PLACEHOLDER@ token paired with its replacement text.
header_template_rule(
    name = "aten_src_ATen_config",
    src = "aten/src/ATen/Config.h.in",
    out = "aten/src/ATen/Config.h",
    include = "aten/src",
    substitutions = {
        # MKL / MKL-DNN
        "@AT_MKLDNN_ENABLED@": "1",
        "@AT_MKLDNN_ACL_ENABLED@": "0",
        "@AT_MKL_ENABLED@": "1",
        "@AT_MKL_SEQUENTIAL@": "0",
        # Other optional backends
        "@AT_POCKETFFT_ENABLED@": "0",
        "@AT_NNPACK_ENABLED@": "0",
        "@CAFFE2_STATIC_LINK_CUDA_INT@": "0",
        # BLAS / LAPACK
        "@AT_BUILD_WITH_BLAS@": "1",
        "@AT_BUILD_WITH_LAPACK@": "1",
        # Threading: the native setting is on and the OpenMP one is off, in
        # line with -DATEN_THREADING=NATIVE in COMMON_COPTS.
        "@AT_PARALLEL_OPENMP@": "0",
        "@AT_PARALLEL_NATIVE@": "1",
        "@AT_BLAS_F2C@": "0",
        "@AT_BLAS_USE_CBLAS_DOT@": "1",
    },
)
278*da0073e9SAndroid Build Coastguard Worker
# Produces aten/src/ATen/cuda/CUDAConfig.h from CUDAConfig.h.in with the
# placeholder values below; exposed to C++ through the hdrs of :aten_cuda_cpp.
header_template_rule(
    name = "aten_src_ATen_cuda_config",
    src = "aten/src/ATen/cuda/CUDAConfig.h.in",
    out = "aten/src/ATen/cuda/CUDAConfig.h",
    include = "aten/src",
    substitutions = {
        "@AT_CUDNN_ENABLED@": "1",
        "@AT_CUSPARSELT_ENABLED@": "0",
        "@AT_ROCM_ENABLED@": "0",
        "@AT_MAGMA_ENABLED@": "0",
        # Replaced with the empty string.
        "@NVCC_FLAGS_EXTRA@": "",
    },
)
292*da0073e9SAndroid Build Coastguard Worker
# Header-only library for ATen: two torch/csrc headers, every .h/.hpp under
# aten/src, the .cuh files of the listed directories, the generated Config.h
# and the outputs of :generated_aten_cpp. Adds aten/src to the include path.
cc_library(
    name = "aten_headers",
    hdrs = [
        "torch/csrc/Export.h",
        "torch/csrc/jit/frontend/function_schema_parser.h",
    ] + glob(
        include = [
            "aten/src/**/*.h",
            "aten/src/**/*.hpp",
            "aten/src/ATen/cuda/**/*.cuh",
            "aten/src/ATen/native/**/*.cuh",
            "aten/src/THC/*.cuh",
        ],
    ) + [
        # Generated files, referenced by target label.
        ":aten_src_ATen_config",
        ":generated_aten_cpp",
    ],
    includes = ["aten/src"],
    deps = ["//c10"],
)
317*da0073e9SAndroid Build Coastguard Worker
# Compile options for the ATen targets in this file (passed as `copts` to
# intern_build_aten_ops, :aten, :aten_nvrtc, :aten_cuda_cpp and :aten_cuda).
ATEN_COPTS = COMMON_COPTS + [
    # NOTE(review): this define is spelled ...MAIN_LIBS (plural) here, whereas
    # CAFFE2_COPTS further down uses -DCAFFE2_BUILD_MAIN_LIB. Confirm which
    # spelling the C++ headers actually test before changing either one.
    "-DCAFFE2_BUILD_MAIN_LIBS",
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    # Non-define compiler switches.
    "-fvisibility-inlines-hidden",
    "-fno-math-errno",
    "-fno-trapping-math",
]
326*da0073e9SAndroid Build Coastguard Worker
# Invoke the intern_build_aten_ops() macro from aten.bzl with the ATen compile
# options, the ufunc-generated CPU kernel sources as extra implementations,
# and the listed dependencies.
# NOTE(review): presumably this is what defines the :ATen_CPU target that
# :aten depends on -- confirm in aten.bzl.
intern_build_aten_ops(
    copts = ATEN_COPTS,
    extra_impls = aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}"),
    deps = [
        # Local
        ":aten_headers",
        # External repositories
        "@fbgemm",
        "@mkl",
        "@sleef",
        "@mkl_dnn//:mkl-dnn",
    ],
)
338*da0073e9SAndroid Build Coastguard Worker
# Main ATen library: the source filegroups defined above, the generated
# Config.h, and the generated CPU and ufunc CPU sources. Publicly visible.
cc_library(
    name = "aten",
    srcs = [
        ":ATen_CORE_SRCS",
        ":ATen_QUANTIZED_SRCS",
        ":aten_base_cpp",
        ":aten_base_metal",
        ":aten_base_vulkan",
        ":aten_native_cpp",
        ":aten_native_mkl_cpp",
        ":aten_native_mkldnn_cpp",
        ":aten_native_nested_cpp",
        ":aten_native_quantized_cpp",
        ":aten_native_sparse_cpp",
        ":aten_native_transformers_cpp",
        ":aten_native_xnnpack",
        ":aten_src_ATen_config",
    ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),
    copts = ATEN_COPTS,
    linkopts = ["-ldl"],
    # The NVRTC shared library is a runtime data dependency on the if_cuda()
    # branch; the other branch contributes nothing.
    data = if_cuda([":libcaffe2_nvrtc.so"], []),
    visibility = ["//visibility:public"],
    deps = [
        ":ATen_CPU",
        ":aten_headers",
        ":caffe2_for_aten_headers",
        ":torch_headers",
        "@fbgemm",
        "@ideep",
    ],
    # Keep every object file at link time even when nothing references its
    # symbols directly.
    alwayslink = True,
)
376*da0073e9SAndroid Build Coastguard Worker
# Library built from the .cpp files in aten/src/ATen/cuda/nvrtc_stub, linked
# against the CUDA driver and NVRTC targets. It is the sole dependency of the
# libcaffe2_nvrtc.so binary.
cc_library(
    name = "aten_nvrtc",
    srcs = glob(
        include = ["aten/src/ATen/cuda/nvrtc_stub/*.cpp"],
    ),
    copts = ATEN_COPTS,
    linkstatic = True,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_headers",
        "//c10",
        "@cuda",
        "@cuda//:cuda_driver",
        "@cuda//:nvrtc",
    ],
    alwayslink = True,
)
394*da0073e9SAndroid Build Coastguard Worker
# Shared object (linkshared = True) that wraps :aten_nvrtc; referenced as a
# data dependency by the :aten library.
cc_binary(
    name = "libcaffe2_nvrtc.so",
    linkshared = True,
    visibility = ["//visibility:public"],
    deps = [":aten_nvrtc"],
)
403*da0073e9SAndroid Build Coastguard Worker
# Host-side C++ part of ATen's CUDA support: the checked-in CUDA .cpp sources
# plus the generated CUDA files, with the generated CUDAConfig.h as a header.
cc_library(
    name = "aten_cuda_cpp",
    srcs = [
        ":aten_cuda_cpp_srcs",
    ] + generated_cuda_cpp,
    hdrs = [
        ":aten_src_ATen_cuda_config",
    ],
    copts = ATEN_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":aten",
        "@cuda",
        "@cuda//:cusolver",
        "@cuda//:nvrtc",
        "@cudnn",
        "@cudnn_frontend",
    ],
    alwayslink = True,
)
420*da0073e9SAndroid Build Coastguard Worker
# Extra defines appended to ATEN_COPTS when compiling the :aten_cuda target;
# they relate to CUDA half / bfloat16 handling.
torch_cuda_half_options = [
    "-DCUDA_HAS_FP16=1",
    # __CUDA_NO_*__ switches
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
]
428*da0073e9SAndroid Build Coastguard Worker
# Device-side part of ATen's CUDA support: compiles the :aten_cu_srcs files
# with the cu_library rule from tools/rules/cu.bzl, using the ATen options
# plus the half-precision defines.
cu_library(
    name = "aten_cuda",
    srcs = [
        ":aten_cu_srcs",
    ],
    copts = ATEN_COPTS + torch_cuda_half_options,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_cuda_cpp",
        "//c10/util:bit_cast",
        "@cuda//:cublas",
        "@cuda//:cufft",
        "@cuda//:cusparse",
        "@cutlass",
    ],
    alwayslink = True,
)
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker# caffe2
# Compiler options for the caffe2 targets: the file-wide COMMON_COPTS followed
# by caffe2-specific defines and code-generation flags.
CAFFE2_COPTS = COMMON_COPTS + [
    # caffe2-specific preprocessor defines.
    "-Dcaffe2_EXPORTS",
    "-DCAFFE2_USE_CUDNN",
    "-DCAFFE2_BUILD_MAIN_LIB",
] + [
    # Symbol-visibility and floating-point code-generation flags.
    "-fvisibility-inlines-hidden",
    "-fno-math-errno",
    "-fno-trapping-math",
]
454*da0073e9SAndroid Build Coastguard Worker
# Sources taken from caffe2/core; consumed by the :caffe2 library.
filegroup(
    name = "caffe2_core_srcs",
    srcs = ["caffe2/core/common.cc"],
)
461*da0073e9SAndroid Build Coastguard Worker
# Sources taken from caffe2/perfkernels; consumed by the :caffe2 library.
# The *_avx.cc / *_avx2.cc variants are globbed by the
# :caffe2_perfkernels_avx* targets instead.
filegroup(
    name = "caffe2_perfkernels_srcs",
    srcs = ["caffe2/perfkernels/embedding_lookup_idx.cc"],
)
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Worker
# Sources taken from caffe2/serialize; consumed by the :caffe2 library.
filegroup(
    name = "caffe2_serialize_srcs",
    srcs = [
        "caffe2/serialize/" + basename
        for basename in [
            "file_adapter.cc",
            "inline_container.cc",
            "istream_adapter.cc",
            "read_adapter_interface.cc",
        ]
    ],
)
479*da0073e9SAndroid Build Coastguard Worker
# Sources taken from caffe2/utils (including the thread pool implementation);
# consumed by the :caffe2 library.
filegroup(
    name = "caffe2_utils_srcs",
    srcs = [
        # General utilities.
        "caffe2/utils/proto_wrap.cc",
        "caffe2/utils/string_utils.cc",
        # Thread pool.
        "caffe2/utils/threadpool/ThreadPool.cc",
        "caffe2/utils/threadpool/pthreadpool.cc",
        "caffe2/utils/threadpool/pthreadpool_impl.cc",
        "caffe2/utils/threadpool/thread_pool_guard.cpp",
    ],
)
491*da0073e9SAndroid Build Coastguard Worker
# For finer granularity and easier debugging, caffe2 is split into three
# libraries: ATen, caffe2 and caffe2_for_aten_headers. The ATen library groups
# the sources under the aten/ directory, and caffe2 contains most files under
# the `caffe2/` directory. Because the ATen and caffe2 libraries depend on each
# other, `caffe2_for_aten_headers` is split out of `caffe2` to avoid a
# dependency cycle.
cc_library(
    name = "caffe2_for_aten_headers",
    # A fixed set of individual headers plus every thread pool header.
    hdrs = [
        "caffe2/core/common.h",
        "caffe2/perfkernels/common.h",
        "caffe2/perfkernels/embedding_lookup_idx.h",
        "caffe2/utils/fixed_divisor.h",
    ] + glob(["caffe2/utils/threadpool/*.h"]),
    copts = CAFFE2_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        "//c10",
    ],
)
514*da0073e9SAndroid Build Coastguard Worker
# Header-only view of caffe2 (and of modules/): no srcs, only hdrs. The .cuh
# headers under caffe2/ are added through if_cuda.
cc_library(
    name = "caffe2_headers",
    hdrs = glob(
        include = [
            "caffe2/perfkernels/*.h",
            "caffe2/serialize/*.h",
            "caffe2/utils/*.h",
            "caffe2/utils/threadpool/*.h",
            "modules/**/*.h",
        ],
        # None of the include patterns above reaches into caffe2/core/, so
        # this exclude currently matches nothing; it is kept unchanged.
        exclude = ["caffe2/core/macros.h"],
    ) + if_cuda(glob(["caffe2/**/*.cuh"])),
    copts = CAFFE2_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        ":caffe2_for_aten_headers",
    ],
)
538*da0073e9SAndroid Build Coastguard Worker
# The caffe2 library proper: compiles the four caffe2_*_srcs filegroups and
# links against either the CUDA or the CPU flavour of ATen/tensorpipe.
cc_library(
    name = "caffe2",
    srcs = [
        ":caffe2_core_srcs",
        ":caffe2_perfkernels_srcs",
        ":caffe2_serialize_srcs",
        ":caffe2_utils_srcs",
    ],
    # -mf16c is added on top of the common caffe2 options.
    copts = CAFFE2_COPTS + ["-mf16c"],
    linkstatic = True,
    visibility = ["//visibility:public"],
    # Keep every object file in the final link, even if nothing references it.
    alwayslink = True,
    deps = [
        ":caffe2_core_macros",
        ":caffe2_headers",
        ":caffe2_perfkernels_avx",
        ":caffe2_perfkernels_avx2",
        "//third_party/miniz-2.1.0:miniz",
        "@com_google_protobuf//:protobuf",
        "@eigen",
        "@fbgemm//:fbgemm_src_headers",
        "@fmt",
        "@onnx",
    ] + if_cuda(
        # First list: presumably selected for CUDA builds (see if_cuda in
        # //tools/config:defs.bzl).
        [
            ":aten_cuda",
            "@tensorpipe//:tensorpipe_cuda",
        ],
        # Second list: presumably the fallback for non-CUDA builds.
        [
            ":aten",
            "@tensorpipe//:tensorpipe_cpu",
        ],
    ),
)
573*da0073e9SAndroid Build Coastguard Worker
# CUDA (.cu) sources from torch/csrc/distributed/c10d. The same three files
# are excluded from the srcs glob of the :torch library, which depends on this
# target for CUDA builds.
cu_library(
    name = "torch_cuda",
    srcs = [
        "torch/csrc/distributed/c10d/intra_node_comm.cu",
        "torch/csrc/distributed/c10d/NanCheck.cu",
        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
    ],
    copts = torch_cuda_half_options,
    visibility = ["//visibility:public"],
    # Keep every object file in the final link, even if nothing references it.
    alwayslink = True,
    deps = [
        ":aten",
        "@cuda//:cublas",
        "@cuda//:curand",
        "@cudnn",
        "@eigen",
        "@tensorpipe//:tensorpipe_cuda",
    ],
)
593*da0073e9SAndroid Build Coastguard Worker
# Compiler options for the caffe2 perfkernels targets
# (:caffe2_perfkernels_avx and :caffe2_perfkernels_avx2). Each of those targets
# appends its own instruction-set flags (-mavx, -mavx2, -mfma) to this list.
PERF_COPTS = [
    # Feature and platform defines.
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    "-DENABLE_ALIAS=1",
    "-DHAVE_MALLOC_USABLE_SIZE=1",
    "-DHAVE_MMAP=1",
    "-DHAVE_SHM_OPEN=1",
    "-DHAVE_SHM_UNLINK=1",
    "-DSLEEF_STATIC_LIBS=1",
    # Was spelled "-DTH_BALS_MKL" (letters transposed), which defined a macro
    # that no code would ever test for.
    # NOTE(review): confirm TH_BLAS_MKL is the spelling used by the rest of the
    # build (e.g. COMMON_COPTS) and that enabling it here is intended.
    "-DTH_BLAS_MKL",
    "-D_FILE_OFFSET_BITS=64",
    "-DUSE_FBGEMM",
    # Visibility, warning and floating-point code-generation flags.
    "-fvisibility-inlines-hidden",
    "-Wunused-parameter",
    "-fno-math-errno",
    "-fno-trapping-math",
    "-mf16c",
]
612*da0073e9SAndroid Build Coastguard Worker
# Headers exposed by both perfkernels libraries: every .h directly under
# caffe2/perfkernels and caffe2/core.
PERF_HEADERS = glob(
    include = [
        "caffe2/perfkernels/*.h",
        "caffe2/core/*.h",
    ],
)
617*da0073e9SAndroid Build Coastguard Worker
# perfkernels sources whose file name ends in _avx.cc, compiled with -mavx on
# top of the shared PERF_COPTS.
cc_library(
    name = "caffe2_perfkernels_avx",
    srcs = glob(["caffe2/perfkernels/*_avx.cc"]),
    hdrs = PERF_HEADERS,
    copts = PERF_COPTS + ["-mavx"],
    visibility = ["//visibility:public"],
    # Keep every object file in the final link, even if nothing references it.
    alwayslink = True,
    deps = [
        ":caffe2_headers",
        "//c10",
    ],
)
634*da0073e9SAndroid Build Coastguard Worker
# perfkernels sources whose file name ends in _avx2.cc, compiled with
# -mavx2, -mfma and -mavx on top of the shared PERF_COPTS.
cc_library(
    name = "caffe2_perfkernels_avx2",
    srcs = glob(["caffe2/perfkernels/*_avx2.cc"]),
    hdrs = PERF_HEADERS,
    copts = PERF_COPTS + [
        "-mavx2",
        "-mfma",
        "-mavx",
    ],
    visibility = ["//visibility:public"],
    # Keep every object file in the final link, even if nothing references it.
    alwayslink = True,
    deps = [
        ":caffe2_headers",
        "//c10",
    ],
)
653*da0073e9SAndroid Build Coastguard Worker
654*da0073e9SAndroid Build Coastguard Worker# torch
# Headers directly under torch/csrc/cuda. :torch_headers adds them through
# if_cuda and removes them from its generic header glob.
torch_cuda_headers = glob(
    include = ["torch/csrc/cuda/*.h"],
)
656*da0073e9SAndroid Build Coastguard Worker
# Header-only target for torch: the public torch headers, the generated
# autograd files and the generated version header (:version_h).
cc_library(
    name = "torch_headers",
    hdrs = if_cuda(
        # CUDA headers are contributed only through if_cuda ...
        torch_cuda_headers,
    ) + glob(
        include = [
            "torch/*.h",
            "torch/csrc/**/*.h",
            "torch/csrc/distributed/c10d/**/*.hpp",
            "torch/lib/libshm/*.h",
        ],
        # ... so they are removed from the generic glob here, together with
        # any checked-in headers in a generated/ directory.
        exclude = ["torch/csrc/*/generated/*.h"] + torch_cuda_headers,
    ) + GENERATED_AUTOGRAD_CPP + [":version_h"],
    # Include roots exported to every dependent target.
    includes = [
        "third_party/kineto/libkineto/include",
        "torch/csrc",
        "torch/csrc/api/include",
        "torch/csrc/distributed",
        "torch/lib",
        "torch/lib/libshm",
    ],
    visibility = ["//visibility:public"],
    alwayslink = True,
    deps = [
        ":aten_headers",
        ":caffe2_headers",
        "//c10",
        "@com_github_google_flatbuffers//:flatbuffers",
        "@local_config_python//:python_headers",
        "@onnx",
    ],
)
691*da0073e9SAndroid Build Coastguard Worker
# Compiler options for the :torch library: the file-wide COMMON_COPTS followed
# by torch-specific defines, code-generation flags and warning settings.
TORCH_COPTS = COMMON_COPTS + [
    # Preprocessor defines.
    "-Dtorch_EXPORTS",
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    "-DCAFFE2_USE_GLOO",
    # Symbol-visibility and floating-point code-generation flags.
    "-fvisibility-inlines-hidden",
    # The literal previously carried a stray trailing space
    # ("-fno-math-errno "); removed so the flag matches its spelling in the
    # other copts lists in this file.
    "-fno-math-errno",
    "-fno-trapping-math",
    # Do not let unused-function warnings fail the build.
    "-Wno-error=unused-function",
]
702*da0073e9SAndroid Build Coastguard Worker
# All source files of the :torch library, de-duplicated. The source lists
# loaded from build_variables.bzl / build.bzl are concatenated and used as
# dictionary keys, so a file that appears in several lists is kept only once,
# at the position of its first occurrence.
torch_sources = {
    src: None
    for src in (
        libtorch_core_sources +
        libtorch_distributed_sources +
        torch_cpp_srcs +
        libtorch_extra_sources +
        jit_core_sources +
        lazy_tensor_ts_sources +
        GENERATED_AUTOGRAD_CPP
    )
}.keys()
715*da0073e9SAndroid Build Coastguard Worker
# The main torch library: the de-duplicated torch_sources plus, through
# if_cuda, the libtorch CUDA sources.
cc_library(
    name = "torch",
    srcs = if_cuda(
        glob(
            include = libtorch_cuda_sources,
            exclude = [
                "torch/csrc/cuda/python_nccl.cpp",
                "torch/csrc/cuda/nccl.cpp",
                "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
                "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
                # The three .cu files below are compiled by :torch_cuda.
                "torch/csrc/distributed/c10d/intra_node_comm.cu",
                "torch/csrc/distributed/c10d/NanCheck.cu",
                "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
            ],
        ),
    ) + torch_sources,
    copts = TORCH_COPTS,
    # `defines` (unlike copts) is also propagated to dependent targets.
    defines = ["CAFFE2_NIGHTLY_VERSION=20200115"],
    linkopts = ["-lrt"],
    visibility = ["//visibility:public"],
    # Keep every object file in the final link, even if nothing references it.
    alwayslink = True,
    deps = [
        ":caffe2",
        ":torch_headers",
        "@kineto",
        "@cpp-httplib",
        "@nlohmann",
    ] + if_cuda([
        "@cuda//:nvToolsExt",
        "@cutlass",
        ":torch_cuda",
    ]),
)
751*da0073e9SAndroid Build Coastguard Worker
# libshm: every .cpp file directly under torch/lib/libshm, linked with -lrt
# on top of :torch.
cc_library(
    name = "shm",
    srcs = glob(["torch/lib/libshm/*.cpp"]),
    linkopts = ["-lrt"],
    deps = [":torch"],
)
762*da0073e9SAndroid Build Coastguard Worker
# Catch-all header target: every .h/.cuh file in this package plus the
# generated C++ code, with the torch include roots exported.
cc_library(
    name = "libtorch_headers",
    hdrs = glob(
        include = [
            "**/*.h",
            "**/*.cuh",
        ],
    ) + [
        # The generated files are referenced through their filegroup rather
        # than as a raw list: with the raw list Bazel sees duplicate files,
        # whereas it knows how to de-duplicate the filegroup.
        ":cpp_generated_code",
    ],
    includes = [
        "torch/csrc/api/include",
        "torch/csrc/distributed",
        "torch/lib",
        "torch/lib/libshm",
    ],
    visibility = ["//visibility:public"],
    deps = [":torch_headers"],
)
785*da0073e9SAndroid Build Coastguard Worker
# Python binding sources of torch: the core binding sources, the generated
# autograd Python bindings and, through if_cuda, the CUDA and distributed
# binding sources.
cc_library(
    name = "torch_python",
    srcs = (
        libtorch_python_core_sources +
        if_cuda(libtorch_python_cuda_sources) +
        if_cuda(libtorch_python_distributed_sources) +
        GENERATED_AUTOGRAD_PYTHON
    ),
    # The .cpp files under torch/csrc/generic are listed as headers, not as
    # sources, so they are not compiled on their own.
    # NOTE(review): presumably they are #include'd by other sources -- confirm.
    hdrs = glob(["torch/csrc/generic/*.cpp"]),
    copts = COMMON_COPTS + if_cuda(["-DUSE_CUDA=1"]),
    deps = [
        ":torch",
        ":shm",
        "@pybind11",
    ],
)
802*da0073e9SAndroid Build Coastguard Worker
# The torch/_C Python extension module, built from a C stub on top of
# :torch_python. :pytorch_py ships its output as the data file ":torch/_C.so".
pybind_extension(
    name = "torch/_C",
    srcs = ["torch/csrc/stub.c"],
    deps = [
        ":torch_python",
        ":aten_nvrtc",
    ],
)
811*da0073e9SAndroid Build Coastguard Worker
# functorch "dim" C++ sources and headers (functorch/csrc/dim), built against
# the torch Python bindings and pybind11.
cc_library(
    name = "functorch",
    srcs = glob(["functorch/csrc/dim/*.cpp"]),
    hdrs = glob(["functorch/csrc/dim/*.h"]),
    deps = [
        ":aten_nvrtc",
        ":torch_python",
        "@pybind11",
    ],
)
826*da0073e9SAndroid Build Coastguard Worker
# The functorch/_C Python extension module. :pytorch_py ships its output as
# the data file ":functorch/_C.so".
pybind_extension(
    name = "functorch/_C",
    srcs = ["functorch/csrc/init_dim_only.cpp"],
    # Sets the TORCH_EXTENSION_NAME macro to _C for the source above.
    copts = ["-DTORCH_EXTENSION_NAME=_C"],
    deps = [
        ":functorch",
        ":torch_python",
        ":aten_nvrtc",
    ],
)
841*da0073e9SAndroid Build Coastguard Worker
# The torch_shm_manager executable, built from libshm's manager.cpp.
cc_binary(
    name = "torch/bin/torch_shm_manager",
    srcs = ["torch/lib/libshm/manager.cpp"],
    # Explicitly opt out of static linking for this binary.
    linkstatic = False,
    deps = [":shm"],
)
852*da0073e9SAndroid Build Coastguard Worker
# Generates torch/version.py from its template by substituting the
# {{CUDA_VERSION}} and {{VERSION}} placeholders.
template_rule(
    name = "gen_version_py",
    src = ":torch/version.py.tpl",
    out = "torch/version.py",
    substitutions = if_cuda(
        # First dictionary: presumably used for CUDA builds.
        {
            # The CUDA version defaults to 11.2; otherwise Torchvision
            # complains about incompatibility.
            "{{CUDA_VERSION}}": "11.2",
            "{{VERSION}}": "2.0.0",
        },
        # Second dictionary: presumably the non-CUDA fallback.
        {
            "{{CUDA_VERSION}}": "None",
            "{{VERSION}}": "2.0.0",
        },
    ),
)
866*da0073e9SAndroid Build Coastguard Worker
# The Python side of PyTorch: all torch/ and functorch/ Python sources, with
# the checked-in torch/version.py replaced by the generated one, plus the
# native extension modules and the shm manager binary as runtime data.
py_library(
    name = "pytorch_py",
    srcs = (
        glob(
            include = ["torch/**/*.py"],
            # Excluded because the generated file is added right below.
            exclude = ["torch/version.py"],
        ) +
        [":torch/version.py"] +
        glob(["functorch/**/*.py"])
    ),
    data = [
        ":torch/_C.so",
        ":functorch/_C.so",
        ":torch/bin/torch_shm_manager",
    ],
    visibility = ["//visibility:public"],
    deps = [
        rules.requirement("numpy"),
        rules.requirement("pyyaml"),
        rules.requirement("requests"),
        rules.requirement("setuptools"),
        rules.requirement("sympy"),
        rules.requirement("typing_extensions"),
        "//torchgen",
    ],
)
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker# cpp api tests
# Shared support code for the C++ API tests; test-only.
cc_library(
    name = "test_support",
    testonly = True,
    srcs = ["test/cpp/api/support.cpp"],
    hdrs = [
        "test/cpp/api/init_baseline.h",
        "test/cpp/api/optim_baseline.h",
        "test/cpp/api/support.h",
        "test/cpp/common/support.h",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)
905*da0073e9SAndroid Build Coastguard Worker
906*da0073e9SAndroid Build Coastguard Worker# Torch integration tests rely on a labeled data set from the MNIST database.
907*da0073e9SAndroid Build Coastguard Worker# http://yann.lecun.com/exdb/mnist/
908*da0073e9SAndroid Build Coastguard Worker
# Every C++ API test source except imethod.cpp and integration.cpp;
# integration.cpp is built separately as :integration_test.
cpp_api_tests = glob(
    include = ["test/cpp/api/*.cpp"],
    exclude = [
        "test/cpp/api/imethod.cpp",
        "test/cpp/api/integration.cpp",
    ],
)
916*da0073e9SAndroid Build Coastguard Worker
cc_test(
    name = "integration_test",
    size = "medium",
    srcs = ["test/cpp/api/integration.cpp"],
    # Outputs of the :download_mnist genrule are staged as runtime data.
    data = [":download_mnist"],
    # NOTE(review): "gpu-required" is a project-specific tag; presumably it is
    # used to filter or route tests in CI - confirm.
    tags = ["gpu-required"],
    deps = [
        ":test_support",
        "@com_google_googletest//:gtest_main",
    ],
)
932*da0073e9SAndroid Build Coastguard Worker
# One cc_test per entry of cpp_api_tests. The target name is the file's base
# name without its extension, with "-" replaced by "_" and "_test" appended.
[
    cc_test(
        name = stem.replace("-", "_") + "_test",
        size = "medium",
        srcs = [source_file],
        deps = [
            ":test_support",
            "@com_google_googletest//:gtest_main",
        ],
    )
    for source_file in cpp_api_tests
    for stem in [paths.split_extension(paths.basename(source_file))[0]]
]
945*da0073e9SAndroid Build Coastguard Worker
# Suite bundling the C++ API tests.
# NOTE(review): this list is written out by hand while the per-file cc_test
# targets are generated from a glob, so it may not cover every generated
# test - confirm it is still in sync with test/cpp/api.
test_suite(
    name = "api_tests",
    tests = [
        "any_test",
        "autograd_test",
        "dataloader_test",
        "enum_test",
        "expanding_array_test",
        "functional_test",
        "init_test",
        "integration_test",
        "jit_test",
        "memory_test",
        "misc_test",
        "module_test",
        "modulelist_test",
        "modules_test",
        "nn_utils_test",
        "optim_test",
        "ordered_dict_test",
        "rnn_test",
        "sequential_test",
        "serialize_test",
        "static_test",
        "tensor_options_test",
        "tensor_test",
        "torch_include_test",
    ],
)
975*da0073e9SAndroid Build Coastguard Worker
976*da0073e9SAndroid Build Coastguard Worker# dist autograd tests
cc_test(
    name = "torch_dist_autograd_test",
    size = "small",
    srcs = ["test/cpp/dist_autograd/test_dist_autograd.cpp"],
    tags = [
        # Bazel runs "exclusive" tests with no other test running concurrently.
        "exclusive",
        # NOTE(review): project-specific tag; presumably consumed by CI - confirm.
        "gpu-required",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)
990*da0073e9SAndroid Build Coastguard Worker
991*da0073e9SAndroid Build Coastguard Worker# jit tests
# Because these individual unit tests require custom registration,
# it is easier to mimic the CMake build by globbing them together into a single test.
cc_test(
    name = "jit_tests",
    size = "small",
    # Every JIT and tensorexpr test source is compiled into this one binary.
    srcs = glob(
        [
            "test/cpp/jit/*.cpp",
            "test/cpp/jit/*.h",
            "test/cpp/tensorexpr/*.cpp",
            "test/cpp/tensorexpr/*.h",
        ],
        exclude = [
            # <pybind11/embed.h> is not found in the OSS build, so leave this one out.
            "test/cpp/jit/test_exception.cpp",
        ],
    ),
    linkstatic = True,
    tags = [
        # Bazel runs "exclusive" tests with no other test running concurrently.
        "exclusive",
        # NOTE(review): project-specific tag; presumably consumed by CI - confirm.
        "gpu-required",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)
1019*da0073e9SAndroid Build Coastguard Worker
# Lazy tensor tests, built as a single binary from the sources under test/cpp/lazy.
cc_test(
    name = "lazy_tests",
    size = "small",
    srcs = glob(
        [
            "test/cpp/lazy/*.cpp",
            "test/cpp/lazy/*.h",
        ],
        exclude = [
            # These depend on the generated LazyIr.h, which is not available in bazel yet.
            "test/cpp/lazy/test_ir.cpp",
            "test/cpp/lazy/test_lazy_ops.cpp",
            "test/cpp/lazy/test_lazy_ops_util.cpp",
        ],
    ),
    linkstatic = True,
    # Bazel runs "exclusive" tests with no other test running concurrently.
    tags = ["exclusive"],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)
1044*da0073e9SAndroid Build Coastguard Worker
1045*da0073e9SAndroid Build Coastguard Worker# python api tests
1046*da0073e9SAndroid Build Coastguard Worker
py_test(
    name = "test_bazel",
    srcs = ["test/_test_bazel.py"],
    # The entry point must be named explicitly because the source file name
    # ("_test_bazel.py") does not match the target name.
    main = "test/_test_bazel.py",
    deps = [
        ":pytorch_py",
    ],
)
1053*da0073e9SAndroid Build Coastguard Worker
1054*da0073e9SAndroid Build Coastguard Worker# all tests
# Umbrella suite over the C++ suites/tests in this package plus the c10 tests.
# NOTE(review): :lazy_tests and :test_bazel are declared in this file but are
# not listed here - confirm that the omission is intentional.
test_suite(
    name = "all_tests",
    tests = [
        "api_tests",
        "jit_tests",
        "torch_dist_autograd_test",
        "//c10/test:tests",
    ],
)
1064*da0073e9SAndroid Build Coastguard Worker
# An internal genrule that we are converging with refers to these files
# as if they were in this package, so we alias them for
# compatibility.
1068*da0073e9SAndroid Build Coastguard Worker
# Expose each listed file under its bare file name as a top-level target of
# this package (alias name = base name of the path, actual = the full path).
[
    alias(
        name = paths.basename(aliased_file),
        actual = aliased_file,
    )
    for aliased_file in [
        "aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
        "aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
        "aten/src/ATen/templates/LazyIr.h",
        "aten/src/ATen/templates/LazyNonNativeIr.h",
        "aten/src/ATen/templates/RegisterDispatchKey.cpp",
        "aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
        "aten/src/ATen/native/native_functions.yaml",
        "aten/src/ATen/native/tags.yaml",
        "aten/src/ATen/native/ts_native_functions.yaml",
        "torch/csrc/lazy/core/shape_inference.h",
        "torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
    ]
]
1088*da0073e9SAndroid Build Coastguard Worker
# Produces the four MNIST data files consumed by :integration_test by running
# tools/download_mnist.py with the rule's output directory as destination.
genrule(
    name = "download_mnist",
    srcs = ["//:tools/download_mnist.py"],
    outs = [
        "mnist/train-images-idx3-ubyte",
        "mnist/train-labels-idx1-ubyte",
        "mnist/t10k-images-idx3-ubyte",
        "mnist/t10k-labels-idx1-ubyte",
    ],
    # Resolve the script through $(location ...) rather than a hard-coded
    # execroot-relative path, so the command still finds it when this package
    # is not the main repository (the file then lives under external/<repo>/).
    cmd = "python3 $(location //:tools/download_mnist.py) -d $(RULEDIR)/mnist",
    # NOTE(review): the script presumably fetches the data set over the
    # network (per its name); this tag keeps the action working when the
    # sandbox blocks network access by default - confirm.
    tags = ["requires-network"],
)
1100