xref: /aosp_15_r20/external/pytorch/test/_test_bazel.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: bazel"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker"""
This test module contains minimalistic "smoke tests" for the bazel build.

Currently it doesn't use any testing framework (e.g. pytest).
7*da0073e9SAndroid Build Coastguard WorkerTODO: integrate this into the existing pytorch testing framework.
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard WorkerThe name uses underscore `_test_bazel.py` to avoid globbing into other non-bazel configurations.
10*da0073e9SAndroid Build Coastguard Worker"""
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerimport torch
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
def test_sum() -> None:
    """Check that elementwise tensor addition produces the expected values.

    ``torch.equal`` is used rather than ``torch.eq(...).all()`` because it also
    requires both tensors to have the same shape. ``torch.eq`` broadcasts, so a
    result with a wrong but broadcast-compatible shape would otherwise pass.
    """
    actual = torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]])
    expected = torch.tensor([[5, 7, 9]])
    assert torch.equal(actual, expected), f"expected {expected}, got {actual}"
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
def test_simple_compile_eager() -> None:
    """Smoke-test ``torch.compile`` with the ``"eager"`` backend.

    Compiles a small sin/cos function, runs it, and checks that the compiled
    version returns the same result as calling the function directly on the
    same inputs.
    """

    def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

    opt_foo1 = torch.compile(foo, backend="eager")
    x = torch.randn(10, 10)
    y = torch.randn(10, 10)
    # Compare against the uncompiled function instead of only checking for a
    # non-None return (which cannot fail for a tensor-returning function).
    # This catches wrong results as well as exceptions. assert_close raises
    # AssertionError on mismatch.
    torch.testing.assert_close(opt_foo1(x, y), foo(x, y))
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run the smoke tests only when this file is executed as a script
    # (presumably how the bazel test target invokes it; confirm against the
    # BUILD rule). Importing the module no longer runs them as a side effect.
    test_sum()
    test_simple_compile_eager()
34