xref: /aosp_15_r20/external/pytorch/test/test_xpu.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: intel"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport collections
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerimport tempfile
6*da0073e9SAndroid Build Coastguard Workerimport unittest
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Workerimport torch.xpu._gpu_trace as gpu_trace
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.autocast_test_lists import AutocastTestLists
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
12*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
13*da0073e9SAndroid Build Coastguard Worker    onlyXPU,
14*da0073e9SAndroid Build Coastguard Worker    OpDTypes,
15*da0073e9SAndroid Build Coastguard Worker    ops,
16*da0073e9SAndroid Build Coastguard Worker)
17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import ops_and_refs
18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
19*da0073e9SAndroid Build Coastguard Worker    NoTest,
20*da0073e9SAndroid Build Coastguard Worker    run_tests,
21*da0073e9SAndroid Build Coastguard Worker    suppress_warnings,
22*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_UBSAN,
23*da0073e9SAndroid Build Coastguard Worker    TEST_XPU,
24*da0073e9SAndroid Build Coastguard Worker    TestCase,
25*da0073e9SAndroid Build Coastguard Worker)
26*da0073e9SAndroid Build Coastguard Worker
if not TEST_XPU:
    print("XPU not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811

# More than one XPU device visible -> multi-device tests may run.
TEST_MULTIXPU = torch.xpu.device_count() > 1

cpu_device = torch.device("cpu")
xpu_device = torch.device("xpu")

# Reuse the common "any cpu/cuda dtype" selector for cpu/xpu comparisons.
any_common_cpu_xpu_one = OpDTypes.any_common_cpu_cuda_one

# Ops whose results are compared between CPU and XPU in the tests below.
_xpu_computation_op_list = [
    "fill", "zeros", "zeros_like", "clone", "view_as_real", "view_as_complex",
    "view", "resize_", "resize_as_", "add", "sub", "mul", "div", "abs",
]
# Tensor-factory ops: instantiation only, no value comparison.
_xpu_tensor_factory_op_list = ["as_strided", "empty", "empty_strided"]
# Ops excluded from dtype validation.
_xpu_not_test_dtype_op_list = [
    "resize_",  # Skipped by CPU
    "resize_as_",  # Skipped by CPU
    "abs",  # Not aligned dtype
]
_xpu_all_op_list = _xpu_computation_op_list + _xpu_tensor_factory_op_list


def _select_ops(op_names):
    # Pick the OpInfo entries whose name appears in *op_names*,
    # preserving the order of ops_and_refs.
    wanted = frozenset(op_names)
    return [op_info for op_info in ops_and_refs if op_info.name in wanted]


_xpu_all_ops = _select_ops(_xpu_all_op_list)
_xpu_computation_ops = _select_ops(_xpu_computation_op_list)
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker
class TestXpu(TestCase):
    """Runtime sanity tests for the torch.xpu backend: device management,
    streams and events, RNG state, op results vs. CPU, and serialization."""

    def test_device_behavior(self):
        """Setting the current device to itself must leave it unchanged."""
        current_device = torch.xpu.current_device()
        torch.xpu.set_device(current_device)
        self.assertEqual(current_device, torch.xpu.current_device())

    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
    def test_multi_device_behavior(self):
        """Device context managers switch the device inside the block and
        restore the previous device on exit."""
        current_device = torch.xpu.current_device()
        # Pick a different, valid device index.
        target_device = (current_device + 1) % torch.xpu.device_count()

        # Public context manager.
        with torch.xpu.device(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

        # Internal guard must behave the same way.
        with torch.xpu._DeviceGuard(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

    def test_get_device_properties(self):
        """get_device_properties/name accept an index, None, or no argument
        and agree; capability entries mirror the properties struct."""
        current_device = torch.xpu.current_device()
        device_properties = torch.xpu.get_device_properties(current_device)
        self.assertEqual(device_properties, torch.xpu.get_device_properties(None))
        self.assertEqual(device_properties, torch.xpu.get_device_properties())

        device_name = torch.xpu.get_device_name(current_device)
        self.assertEqual(device_name, torch.xpu.get_device_name(None))
        self.assertEqual(device_name, torch.xpu.get_device_name())

        device_capability = torch.xpu.get_device_capability(current_device)
        self.assertTrue(device_capability["max_work_group_size"] > 0)
        self.assertTrue(device_capability["max_num_sub_groups"] > 0)
        # Capability dict entries must match the corresponding property fields.
        self.assertEqual(
            device_properties.driver_version, device_capability["driver_version"]
        )
        self.assertEqual(device_properties.has_fp16, device_capability["has_fp16"])
        self.assertEqual(device_properties.has_fp64, device_capability["has_fp64"])
        self.assertEqual(
            device_properties.has_atomic64, device_capability["has_atomic64"]
        )

    def test_wrong_xpu_fork(self):
        """Initializing XPU in the parent and then forking workers must
        surface the 'cannot re-initialize in forked subprocess' error."""
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process
def run(rank):
    torch.xpu.set_device(rank)
if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        # it would work fine without the line below
        torch.xpu.set_device(0)
        p = Process(target=run, args=(rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
"""
        )
        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")

    def test_streams(self):
        """set_stream changes the current stream; the stream() context
        manager switches it temporarily and restores it on exit."""
        s0 = torch.xpu.Stream()
        torch.xpu.set_stream(s0)
        s1 = torch.xpu.current_stream()
        self.assertEqual(s0, s1)
        s2 = torch.xpu.Stream()
        # Distinct Stream objects must not compare equal.
        self.assertFalse(s0 == s2)
        torch.xpu.set_stream(s2)
        with torch.xpu.stream(s0):
            self.assertEqual(s0, torch.xpu.current_stream())
        self.assertEqual(s2, torch.xpu.current_stream())

    def test_stream_priority(self):
        """Streams created at either end of priority_range() report the
        requested priority and device."""
        low, high = torch.xpu.Stream.priority_range()
        s0 = torch.xpu.Stream(device=0, priority=low)

        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("xpu:0"), s0.device)

        s1 = torch.xpu.Stream(device=0, priority=high)

        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("xpu:0"), s1.device)

    def test_stream_event_repr(self):
        """repr() of streams and events mentions the class name, and an
        unrecorded event is shown as uninitialized."""
        s = torch.xpu.current_stream()
        self.assertTrue("torch.xpu.Stream" in str(s))
        e = torch.xpu.Event()
        self.assertTrue("torch.xpu.Event(uninitialized)" in str(e))
        s.record_event(e)
        self.assertTrue("torch.xpu.Event" in str(e))

    def test_events(self):
        """An event queries complete before recording and again after a
        record + synchronize cycle."""
        stream = torch.xpu.current_stream()
        event = torch.xpu.Event()
        self.assertTrue(event.query())
        stream.record_event(event)
        event.synchronize()
        self.assertTrue(event.query())

    def test_generic_stream_event(self):
        """The generic torch.Stream/torch.Event API interoperates with the
        XPU-specific torch.xpu.Stream, sharing stream ids and event state."""
        stream = torch.Stream("xpu")
        self.assertEqual(stream.device_index, torch.xpu.current_device())
        # Re-wrap the generic stream as an XPU stream via its raw ids.
        xpu_stream = torch.xpu.Stream(
            stream_id=stream.stream_id,
            device_index=stream.device_index,
            device_type=stream.device_type,
        )
        self.assertEqual(stream.stream_id, xpu_stream.stream_id)
        self.assertNotEqual(stream.stream_id, torch.xpu.current_stream().stream_id)

        event1 = torch.Event("xpu")
        event2 = torch.Event("xpu")
        # Unrecorded generic events carry a zero event id.
        self.assertEqual(event1.event_id, 0)
        a = torch.randn(1000)
        b = torch.randn(1000)
        with torch.xpu.stream(xpu_stream):
            a_xpu = a.to("xpu", non_blocking=True)
            b_xpu = b.to("xpu", non_blocking=True)
            self.assertEqual(stream.stream_id, torch.xpu.current_stream().stream_id)
        event1.record(stream)
        event1.synchronize()
        self.assertTrue(event1.query())
        c_xpu = a_xpu + b_xpu
        event2.record()
        event2.synchronize()
        self.assertTrue(event2.query())
        self.assertNotEqual(event1.event_id, event2.event_id)
        self.assertEqual(c_xpu.cpu(), a + b)
        # elapsed_time is unimplemented on XPU and must raise.
        with self.assertRaisesRegex(
            NotImplementedError, "elapsedTime is not supported by XPU backend."
        ):
            event1.elapsed_time(event2)

    def test_generator(self):
        """XPU RNG state tracks torch.manual_seed/torch.xpu.manual_seed and
        round-trips through get_rng_state/set_rng_state."""
        torch.manual_seed(2024)
        g_state0 = torch.xpu.get_rng_state()
        torch.manual_seed(1234)
        g_state1 = torch.xpu.get_rng_state()
        self.assertNotEqual(g_state0, g_state1)

        # Seeding only the XPU generator reproduces the earlier state.
        torch.xpu.manual_seed(2024)
        g_state2 = torch.xpu.get_rng_state()
        self.assertEqual(g_state0, g_state2)

        torch.xpu.set_rng_state(g_state1)
        self.assertEqual(g_state1, torch.xpu.get_rng_state())

        # Restoring a state restores the seed it was captured under.
        torch.manual_seed(1234)
        torch.xpu.set_rng_state(g_state0)
        self.assertEqual(2024, torch.xpu.initial_seed())

    @onlyXPU
    @suppress_warnings
    @ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
    def test_compare_cpu(self, device, dtype, op):
        """Run each op's reference inputs on XPU and CPU and compare results
        within a loose tolerance."""

        def to_cpu(arg):
            # Move tensors to CPU; leave every other argument unchanged.
            if isinstance(arg, torch.Tensor):
                return arg.to(device="cpu")
            return arg

        samples = op.reference_inputs(device, dtype)

        for sample in samples:
            cpu_sample = sample.transform(to_cpu)
            xpu_results = op(sample.input, *sample.args, **sample.kwargs)
            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            xpu_results = sample.output_process_fn_grad(xpu_results)
            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

            # Lower tolerance because we are running this as a `@slowTest`
            # Don't want the periodic tests to fail frequently
            self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)

    @onlyXPU
    @ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
    @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior")
    def test_non_standard_bool_values(self, device, dtype, op):
        """Ops must treat any non-zero byte in a bool tensor as True."""

        # Test boolean values other than 0x00 and 0x01 (gh-54789)
        def convert_boolean_tensors(x):
            if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
                return x

            # Map False -> 0 and True -> Random value in [2, 255]
            true_vals = torch.randint(
                2, 255, x.shape, dtype=torch.uint8, device=x.device
            )
            false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
            x_int = torch.where(x, true_vals, false_vals)

            # Reinterpret the bytes as bool: logical values are preserved.
            ret = x_int.view(torch.bool)
            self.assertEqual(ret, x)
            return ret

        for sample in op.sample_inputs(device, dtype):
            expect = op(sample.input, *sample.args, **sample.kwargs)

            transformed = sample.transform(convert_boolean_tensors)
            actual = op(transformed.input, *transformed.args, **transformed.kwargs)

            self.assertEqual(expect, actual)

    def test_serialization_array_with_storage(self):
        """torch.save/load round-trips XPU tensors and storages, preserving
        dtype and aliasing (q[0] and q[2] are the same tensor)."""
        x = torch.randn(5, 5).xpu()
        y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
        q = [x, y, x, y.storage()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(q, f)
            f.seek(0)
            q_copy = torch.load(f)
        self.assertEqual(q_copy, q, atol=0, rtol=0)
        # Aliasing survives serialization: filling q_copy[0] changes q_copy[2].
        q_copy[0].fill_(5)
        self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0)
        self.assertEqual(q_copy[0].dtype, torch.float)
        self.assertEqual(q_copy[1].dtype, torch.int)
        self.assertEqual(q_copy[2].dtype, torch.float)
        self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
        self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
        # The loaded storage is independent of the original tensor's storage,
        # so both must be filled to compare equal.
        q_copy[1].fill_(10)
        y.fill_(10)
        self.assertEqual(q_copy[3], y.storage())

    def test_serialization_array_with_empty(self):
        """torch.save/load round-trips a list containing an empty XPU tensor,
        preserving type and device."""
        x = [
            torch.randn(4, 4).xpu(),
            torch.tensor([], dtype=torch.float, device=torch.device("xpu")),
        ]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f)
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), original.get_device())
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker
# Generate per-device test classes from TestXpu, restricted to the XPU backend.
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu")
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Workerclass TestXpuAutocast(TestCase):
315*da0073e9SAndroid Build Coastguard Worker    # These operators are not implemented on XPU backend and we can NOT fall back
316*da0073e9SAndroid Build Coastguard Worker    # them to CPU. So we have to skip them at this moment.
317*da0073e9SAndroid Build Coastguard Worker    # TODO: remove these operators from skip list when they are implemented on XPU backend.
318*da0073e9SAndroid Build Coastguard Worker    skip_list = ["gru_cell"]
319*da0073e9SAndroid Build Coastguard Worker
    def setUp(self):
        """Build the shared autocast op/argument lists on an XPU device."""
        super().setUp()
        self.autocast_lists = AutocastTestLists(torch.device("xpu"))
323*da0073e9SAndroid Build Coastguard Worker
    def tearDown(self):
        """Release the autocast fixtures before the base-class teardown."""
        del self.autocast_lists
        super().tearDown()
327*da0073e9SAndroid Build Coastguard Worker
    def _run_autocast_outofplace(
        self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
    ):
        """Run out-of-place *op* under ``torch.amp.autocast("xpu")`` and check
        its autocast behavior.

        Verifies that:
          * the ``module.op(...)`` and/or ``Tensor.op(...)`` variant produces a
            tensor of dtype ``out_type`` (defaults to ``run_as_type``) while
            autocast is enabled;
          * both variants (when both exist) produce identical results;
          * the autocast result bitwise-matches a control computed with
            autocast disabled and the inputs manually cast to ``run_as_type``;
          * the autocast-enabled flag is correctly set/restored around the
            context managers.

        Args:
            op: operator name as a string (looked up on *module* / Tensor).
            args: positional arguments; ``args[0]`` is the receiver for the
                Tensor-method variant.
            run_as_type: dtype autocast is expected to run the op in.
            out_type: expected output dtype; defaults to ``run_as_type``.
            module: namespace holding the functional variant (``None`` to
                skip it); defaults to ``torch``.
            add_kwargs: extra keyword arguments for the op, or ``None``.
        """

        # helper to cast args: floating-point tensors are converted to
        # to_type, iterables are cast recursively, everything else passes
        # through unchanged.
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}
        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
        self.assertFalse(torch.is_autocast_enabled("xpu"))
        with torch.amp.autocast("xpu", dtype=fast_dtype):
            self.assertTrue(torch.is_autocast_enabled("xpu"))

            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
                    )

            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(
                        out_type == output_method.dtype,
                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
                    )

            self.assertTrue(
                (output is not None) or (output_method is not None),
                f"{op} not found as an attribute on either Tensor or the requested module {module}",
            )

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison, f"torch.{op} result did not match Tensor.{op} result"
                )

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with torch.amp.autocast("xpu", enabled=False):
                self.assertFalse(torch.is_autocast_enabled("xpu"))

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(
                        *cast(args, run_as_type), **add_kwargs
                    )
                else:
                    control = getattr(args[0].to(run_as_type), op)(
                        *cast(args[1:], run_as_type), **add_kwargs
                    )
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            # Leaving the nested disabled region must re-enable autocast...
            self.assertTrue(torch.is_autocast_enabled("xpu"))
        # ...and leaving the outer region must disable it again.
        self.assertFalse(torch.is_autocast_enabled("xpu"))
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker    def test_autocast_torch_fp16(self):
412*da0073e9SAndroid Build Coastguard Worker        for op_with_args in self.autocast_lists.torch_fp16:
413*da0073e9SAndroid Build Coastguard Worker            skip_test = False
414*da0073e9SAndroid Build Coastguard Worker            op, args = op_with_args[0], op_with_args[1]
415*da0073e9SAndroid Build Coastguard Worker            if op in self.skip_list:
416*da0073e9SAndroid Build Coastguard Worker                skip_test = True  # skip unimplemented op
417*da0073e9SAndroid Build Coastguard Worker            if len(op_with_args) == 3:
418*da0073e9SAndroid Build Coastguard Worker                skip_test = True  # skip cudnn op
419*da0073e9SAndroid Build Coastguard Worker            if not skip_test:
420*da0073e9SAndroid Build Coastguard Worker                self._run_autocast_outofplace(op, args, torch.float16)
421*da0073e9SAndroid Build Coastguard Worker
422*da0073e9SAndroid Build Coastguard Worker    def test_autocast_torch_bf16(self):
423*da0073e9SAndroid Build Coastguard Worker        for op_with_args in self.autocast_lists.torch_fp16:
424*da0073e9SAndroid Build Coastguard Worker            skip_test = False
425*da0073e9SAndroid Build Coastguard Worker            op, args = op_with_args[0], op_with_args[1]
426*da0073e9SAndroid Build Coastguard Worker            if op in self.skip_list:
427*da0073e9SAndroid Build Coastguard Worker                skip_test = True  # skip unimplemented op
428*da0073e9SAndroid Build Coastguard Worker            if len(op_with_args) == 3:
429*da0073e9SAndroid Build Coastguard Worker                skip_test = True  # skip cudnn op
430*da0073e9SAndroid Build Coastguard Worker            if not skip_test:
431*da0073e9SAndroid Build Coastguard Worker                self._run_autocast_outofplace(op, args, torch.bfloat16)
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker    def test_autocast_torch_need_autocast_promote(self):
434*da0073e9SAndroid Build Coastguard Worker        for op, args in self.autocast_lists.torch_need_autocast_promote:
435*da0073e9SAndroid Build Coastguard Worker            self._run_autocast_outofplace(op, args, torch.float32)
436*da0073e9SAndroid Build Coastguard Worker
437*da0073e9SAndroid Build Coastguard Worker    def test_autocast_torch_expect_builtin_promote(self):
438*da0073e9SAndroid Build Coastguard Worker        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
439*da0073e9SAndroid Build Coastguard Worker            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker    def test_xpu_autocast_dtype(self):
442*da0073e9SAndroid Build Coastguard Worker        dtype = torch.get_autocast_dtype("xpu")
443*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dtype, torch.float16)
444*da0073e9SAndroid Build Coastguard Worker        mat0_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
445*da0073e9SAndroid Build Coastguard Worker        mat1_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
446*da0073e9SAndroid Build Coastguard Worker        with torch.amp.autocast("xpu"):
447*da0073e9SAndroid Build Coastguard Worker            result = torch.mm(mat0_fp32, mat1_fp32)
448*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, torch.float16)
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker
class TestXpuTrace(TestCase):
    """Verify that gpu_trace callbacks fire for XPU event and stream
    operations (creation, deletion, record, wait, synchronize)."""

    def setUp(self):
        # Activate the GPU trace machinery and install a fresh mock that
        # each test registers as a callback and then inspects.
        torch._C._activate_gpu_trace()
        self.mock = unittest.mock.MagicMock()

    def test_event_creation_callback(self):
        gpu_trace.register_callback_for_event_creation(self.mock)

        ev = torch.xpu.Event()
        ev.record()
        self.mock.assert_called_once_with(ev._as_parameter_.value)

    def test_event_deletion_callback(self):
        gpu_trace.register_callback_for_event_deletion(self.mock)

        ev = torch.xpu.Event()
        ev.record()
        # Capture the id before deleting; the callback receives it after
        # the Python object is gone.
        ev_id = ev._as_parameter_.value
        del ev
        self.mock.assert_called_once_with(ev_id)

    def test_event_record_callback(self):
        gpu_trace.register_callback_for_event_record(self.mock)

        ev = torch.xpu.Event()
        ev.record()
        self.mock.assert_called_once_with(
            ev._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_event_wait_callback(self):
        gpu_trace.register_callback_for_event_wait(self.mock)

        ev = torch.xpu.Event()
        ev.record()
        ev.wait()
        self.mock.assert_called_once_with(
            ev._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_device_synchronization_callback(self):
        gpu_trace.register_callback_for_device_synchronization(self.mock)

        torch.xpu.synchronize()
        self.mock.assert_called()

    def test_stream_synchronization_callback(self):
        gpu_trace.register_callback_for_stream_synchronization(self.mock)

        side_stream = torch.xpu.Stream()
        side_stream.synchronize()
        self.mock.assert_called_once_with(side_stream.sycl_queue)

    def test_event_synchronization_callback(self):
        gpu_trace.register_callback_for_event_synchronization(self.mock)

        ev = torch.xpu.Event()
        ev.record()
        ev.synchronize()
        self.mock.assert_called_once_with(ev._as_parameter_.value)
511*da0073e9SAndroid Build Coastguard Worker
512*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Entry point: hand off to PyTorch's common test runner, which discovers
    # and executes every TestCase defined in this module.
    run_tests()
515