# Owner(s): ["module: intel"]

import collections
import sys
import tempfile
import unittest
import unittest.mock

import torch
import torch.xpu._gpu_trace as gpu_trace
from torch.testing._internal.autocast_test_lists import AutocastTestLists
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyXPU,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import ops_and_refs
from torch.testing._internal.common_utils import (
    NoTest,
    run_tests,
    suppress_warnings,
    TEST_WITH_UBSAN,
    TEST_XPU,
    TestCase,
)

# Without an XPU, swap the base class for NoTest so the classes below are
# defined on NoTest instead of the real TestCase.
if not TEST_XPU:
    print("XPU not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811

TEST_MULTIXPU = torch.xpu.device_count() > 1

cpu_device = torch.device("cpu")
xpu_device = torch.device("xpu")

# The CPU/XPU dtype selector is an alias of the existing CPU/CUDA one.
any_common_cpu_xpu_one = OpDTypes.any_common_cpu_cuda_one
_xpu_computation_op_list = [
    "fill",
    "zeros",
    "zeros_like",
    "clone",
    "view_as_real",
    "view_as_complex",
    "view",
    "resize_",
    "resize_as_",
    "add",
    "sub",
    "mul",
    "div",
    "abs",
]
_xpu_tensor_factory_op_list = [
    "as_strided",
    "empty",
    "empty_strided",
]
_xpu_not_test_dtype_op_list = [
    "resize_",  # Skipped by CPU
    "resize_as_",  # Skipped by CPU
    "abs",  # Not aligned dtype
]
_xpu_all_op_list = _xpu_computation_op_list + _xpu_tensor_factory_op_list
# OpInfo entries selected by name from the lists above.
_xpu_all_ops = [op for op in ops_and_refs if op.name in _xpu_all_op_list]
_xpu_computation_ops = [
    op for op in ops_and_refs if op.name in _xpu_computation_op_list
]


class TestXpu(TestCase):
    """Tests for the torch.xpu device, stream, event, RNG and serialization APIs."""

    def test_device_behavior(self):
        # Re-selecting the current device must leave it unchanged.
        current_device = torch.xpu.current_device()
        torch.xpu.set_device(current_device)
        self.assertEqual(current_device, torch.xpu.current_device())

    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
    def test_multi_device_behavior(self):
        # Both device context managers must switch to the target device inside
        # the block and restore the previous device on exit.
        current_device = torch.xpu.current_device()
        target_device = (current_device + 1) % torch.xpu.device_count()

        with torch.xpu.device(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

        with torch.xpu._DeviceGuard(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

    def test_get_device_properties(self):
        # Passing the current index, None, or no argument must all return the
        # same properties / name.
        current_device = torch.xpu.current_device()
        device_properties = torch.xpu.get_device_properties(current_device)
        self.assertEqual(device_properties, torch.xpu.get_device_properties(None))
        self.assertEqual(device_properties, torch.xpu.get_device_properties())

        device_name = torch.xpu.get_device_name(current_device)
        self.assertEqual(device_name, torch.xpu.get_device_name(None))
        self.assertEqual(device_name, torch.xpu.get_device_name())

        # The capability dict must agree with the properties object.
        device_capability = torch.xpu.get_device_capability(current_device)
        self.assertGreater(device_capability["max_work_group_size"], 0)
        self.assertGreater(device_capability["max_num_sub_groups"], 0)
        self.assertEqual(
            device_properties.driver_version, device_capability["driver_version"]
        )
        self.assertEqual(device_properties.has_fp16, device_capability["has_fp16"])
        self.assertEqual(device_properties.has_fp64, device_capability["has_fp64"])
        self.assertEqual(
            device_properties.has_atomic64, device_capability["has_atomic64"]
        )

    def test_wrong_xpu_fork(self):
        # Run a script that touches XPU in the parent and then again in child
        # processes; the expected error text must appear on stderr.
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process
def run(rank):
    torch.xpu.set_device(rank)
if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        # it would work fine without the line below
        torch.xpu.set_device(0)
        p = Process(target=run, args=(rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
"""
        )
        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")

    def test_streams(self):
        # set_stream changes the current stream; the stream() context manager
        # switches temporarily and restores the previous stream on exit.
        s0 = torch.xpu.Stream()
        torch.xpu.set_stream(s0)
        s1 = torch.xpu.current_stream()
        self.assertEqual(s0, s1)
        s2 = torch.xpu.Stream()
        self.assertFalse(s0 == s2)
        torch.xpu.set_stream(s2)
        with torch.xpu.stream(s0):
            self.assertEqual(s0, torch.xpu.current_stream())
        self.assertEqual(s2, torch.xpu.current_stream())

    def test_stream_priority(self):
        # Streams created with either end of priority_range() must report that
        # priority and the requested device.
        low, high = torch.xpu.Stream.priority_range()
        s0 = torch.xpu.Stream(device=0, priority=low)

        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("xpu:0"), s0.device)

        s1 = torch.xpu.Stream(device=0, priority=high)

        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("xpu:0"), s1.device)

    def test_stream_event_repr(self):
        # String forms of streams and events, before and after recording.
        s = torch.xpu.current_stream()
        self.assertIn("torch.xpu.Stream", str(s))
        e = torch.xpu.Event()
        self.assertIn("torch.xpu.Event(uninitialized)", str(e))
        s.record_event(e)
        self.assertIn("torch.xpu.Event", str(e))

    def test_events(self):
        # query() must be true for a fresh event and again after the event has
        # been recorded and synchronized.
        stream = torch.xpu.current_stream()
        event = torch.xpu.Event()
        self.assertTrue(event.query())
        stream.record_event(event)
        event.synchronize()
        self.assertTrue(event.query())

    def test_generic_stream_event(self):
        # Build an XPU stream from a generic torch.Stream's identifiers and
        # check that both refer to the same stream id.
        stream = torch.Stream("xpu")
        self.assertEqual(stream.device_index, torch.xpu.current_device())
        xpu_stream = torch.xpu.Stream(
            stream_id=stream.stream_id,
            device_index=stream.device_index,
            device_type=stream.device_type,
        )
        self.assertEqual(stream.stream_id, xpu_stream.stream_id)
        self.assertNotEqual(stream.stream_id, torch.xpu.current_stream().stream_id)

        event1 = torch.Event("xpu")
        event2 = torch.Event("xpu")
        self.assertEqual(event1.event_id, 0)
        a = torch.randn(1000)
        b = torch.randn(1000)
        with torch.xpu.stream(xpu_stream):
            a_xpu = a.to("xpu", non_blocking=True)
            b_xpu = b.to("xpu", non_blocking=True)
            self.assertEqual(stream.stream_id, torch.xpu.current_stream().stream_id)
        event1.record(stream)
        event1.synchronize()
        self.assertTrue(event1.query())
        c_xpu = a_xpu + b_xpu
        event2.record()
        event2.synchronize()
        self.assertTrue(event2.query())
        self.assertNotEqual(event1.event_id, event2.event_id)
        self.assertEqual(c_xpu.cpu(), a + b)
        with self.assertRaisesRegex(
            NotImplementedError, "elapsedTime is not supported by XPU backend."
        ):
            event1.elapsed_time(event2)

    def test_generator(self):
        # Different seeds must give different XPU RNG states; the same seed
        # must reproduce the same state.
        torch.manual_seed(2024)
        g_state0 = torch.xpu.get_rng_state()
        torch.manual_seed(1234)
        g_state1 = torch.xpu.get_rng_state()
        self.assertNotEqual(g_state0, g_state1)

        torch.xpu.manual_seed(2024)
        g_state2 = torch.xpu.get_rng_state()
        self.assertEqual(g_state0, g_state2)

        torch.xpu.set_rng_state(g_state1)
        self.assertEqual(g_state1, torch.xpu.get_rng_state())

        # Restoring a saved state must also restore the reported initial seed.
        torch.manual_seed(1234)
        torch.xpu.set_rng_state(g_state0)
        self.assertEqual(2024, torch.xpu.initial_seed())

    @onlyXPU
    @suppress_warnings
    @ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
    def test_compare_cpu(self, device, dtype, op):
        # Run every reference input on the test device and on CPU and compare.
        def to_cpu(arg):
            if isinstance(arg, torch.Tensor):
                return arg.to(device="cpu")
            return arg

        samples = op.reference_inputs(device, dtype)

        for sample in samples:
            cpu_sample = sample.transform(to_cpu)
            xpu_results = op(sample.input, *sample.args, **sample.kwargs)
            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            xpu_results = sample.output_process_fn_grad(xpu_results)
            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

            # Lower tolerance because we are running this as a `@slowTest`
            # Don't want the periodic tests to fail frequently
            # NOTE(review): no `@slowTest` decorator is applied to this test
            # here; the remark above is presumably inherited -- confirm.
            self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)

    @onlyXPU
    @ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
    @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior")
    def test_non_standard_bool_values(self, device, dtype, op):
        # Test boolean values other than 0x00 and 0x01 (gh-54789)
        def convert_boolean_tensors(x):
            if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
                return x

            # Map False -> 0 and True -> Random value in [2, 254]
            # (the upper bound passed to randint is exclusive).
            true_vals = torch.randint(
                2, 255, x.shape, dtype=torch.uint8, device=x.device
            )
            false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
            x_int = torch.where(x, true_vals, false_vals)

            # Reinterpret the uint8 bytes as bool; the result must still
            # compare equal to the original tensor.
            ret = x_int.view(torch.bool)
            self.assertEqual(ret, x)
            return ret

        for sample in op.sample_inputs(device, dtype):
            expect = op(sample.input, *sample.args, **sample.kwargs)

            transformed = sample.transform(convert_boolean_tensors)
            actual = op(transformed.input, *transformed.args, **transformed.kwargs)

            self.assertEqual(expect, actual)

    def test_serialization_array_with_storage(self):
        # Round-trip a list holding tensors (one of them twice) and a storage
        # through torch.save / torch.load.
        x = torch.randn(5, 5).xpu()
        y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
        q = [x, y, x, y.storage()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(q, f)
            f.seek(0)
            q_copy = torch.load(f)
        self.assertEqual(q_copy, q, atol=0, rtol=0)
        # Entries 0 and 2 were the same tensor, so writing through one loaded
        # copy must be visible through the other.
        q_copy[0].fill_(5)
        self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0)
        self.assertEqual(q_copy[0].dtype, torch.float)
        self.assertEqual(q_copy[1].dtype, torch.int)
        self.assertEqual(q_copy[2].dtype, torch.float)
        self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
        self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
        # The loaded storage must track writes made to the loaded tensor.
        q_copy[1].fill_(10)
        y.fill_(10)
        self.assertEqual(q_copy[3], y.storage())

    def test_serialization_array_with_empty(self):
        # Round-trip a list that contains an empty tensor.
        x = [
            torch.randn(4, 4).xpu(),
            torch.tensor([], dtype=torch.float, device=torch.device("xpu")),
        ]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f)
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), original.get_device())


instantiate_device_type_tests(TestXpu, globals(), only_for="xpu")


class TestXpuAutocast(TestCase):
    """Autocast tests for the XPU backend, driven by AutocastTestLists."""

    # These operators are not implemented on XPU backend and we can NOT fall back
    # them to CPU. So we have to skip them at this moment.
    # TODO: remove these operators from skip list when they are implemented on XPU backend.
    skip_list = ["gru_cell"]

    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastTestLists(torch.device("xpu"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    def _run_autocast_outofplace(
        self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
    ):
        """Run ``op`` under XPU autocast and check output dtype and numerics.

        Args:
            op: Attribute name, looked up on ``module`` and on ``torch.Tensor``.
            args: Positional arguments for the op. ``args[0]`` is the receiver
                when the ``Tensor.*`` variant is called.
            run_as_type: dtype the control run casts floating-point tensor
                arguments to. ``torch.bfloat16`` selects bfloat16 autocast;
                any other value selects float16 autocast.
            out_type: Expected output dtype; defaults to ``run_as_type``.
            module: Namespace to look ``op`` up in, or ``None`` to exercise
                only the ``Tensor.*`` variant.
            add_kwargs: Extra keyword arguments passed to every call of the op.

        The autocast result is compared with ``torch.equal`` against a control
        computed with autocast disabled on manually cast arguments.
        """

        # helper to cast args
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            # str/bytes are iterable but must be passed through untouched:
            # rebuilding them from a generator would corrupt the value.
            if isinstance(val, (str, bytes)):
                return val
            if isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            return val

        if add_kwargs is None:
            add_kwargs = {}
        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
        self.assertFalse(torch.is_autocast_enabled("xpu"))
        with torch.amp.autocast("xpu", dtype=fast_dtype):
            self.assertTrue(torch.is_autocast_enabled("xpu"))

            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
                    )

            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(
                        out_type == output_method.dtype,
                        f"autocast for Tensor.{op} produced {output_method.dtype}, should produce {out_type}",
                    )

            self.assertTrue(
                (output is not None) or (output_method is not None),
                f"{op} not found as an attribute on either Tensor or the requested module {module}",
            )

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                # Compare strings directly; iterating them would recurse forever.
                if isinstance(first, (str, bytes)):
                    return first == second
                if isinstance(first, collections.abc.Iterable):
                    # Materialize so a length mismatch is a failure rather than
                    # being silently truncated by zip().
                    first, second = tuple(first), tuple(second)
                    return len(first) == len(second) and all(
                        compare(f, s) for f, s in zip(first, second)
                    )
                return first == second

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertIs(type(output), type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison, f"torch.{op} result did not match Tensor.{op} result"
                )

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with torch.amp.autocast("xpu", enabled=False):
                self.assertFalse(torch.is_autocast_enabled("xpu"))

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(
                        *cast(args, run_as_type), **add_kwargs
                    )
                else:
                    control = getattr(args[0].to(run_as_type), op)(
                        *cast(args[1:], run_as_type), **add_kwargs
                    )
                self.assertIs(type(output_to_compare), type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            # Leaving the disabled region must re-enable autocast ...
            self.assertTrue(torch.is_autocast_enabled("xpu"))
        # ... and leaving the outer region must disable it again.
        self.assertFalse(torch.is_autocast_enabled("xpu"))

    def _check_autocast_torch_low_precision(self, run_as_type):
        """Run every non-skipped entry of ``autocast_lists.torch_fp16`` as ``run_as_type``."""
        for op_with_args in self.autocast_lists.torch_fp16:
            op, args = op_with_args[0], op_with_args[1]
            if op in self.skip_list:
                continue  # skip unimplemented op
            if len(op_with_args) == 3:
                continue  # skip cudnn op
            self._run_autocast_outofplace(op, args, run_as_type)

    def test_autocast_torch_fp16(self):
        self._check_autocast_torch_low_precision(torch.float16)

    def test_autocast_torch_bf16(self):
        # Uses the same `torch_fp16` op list as the float16 test.
        self._check_autocast_torch_low_precision(torch.bfloat16)

    def test_autocast_torch_need_autocast_promote(self):
        for op, args in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args, torch.float32)

    def test_autocast_torch_expect_builtin_promote(self):
        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)

    def test_xpu_autocast_dtype(self):
        # The autocast dtype reported for XPU must be float16, and a float32
        # matmul under autocast must produce that dtype.
        dtype = torch.get_autocast_dtype("xpu")
        self.assertEqual(dtype, torch.float16)
        mat0_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
        mat1_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
        with torch.amp.autocast("xpu"):
            result = torch.mm(mat0_fp32, mat1_fp32)
            self.assertEqual(result.dtype, torch.float16)


class TestXpuTrace(TestCase):
    """Checks that callbacks registered through gpu_trace are invoked."""

    def setUp(self):
        torch._C._activate_gpu_trace()
        # Fresh mock per test; each test registers it as the callback.
        self.mock = unittest.mock.MagicMock()

    def test_event_creation_callback(self):
        gpu_trace.register_callback_for_event_creation(self.mock)

        event = torch.xpu.Event()
        event.record()
        self.mock.assert_called_once_with(event._as_parameter_.value)

    def test_event_deletion_callback(self):
        gpu_trace.register_callback_for_event_deletion(self.mock)

        event = torch.xpu.Event()
        event.record()
        # Capture the id before deleting; the event is unusable afterwards.
        event_id = event._as_parameter_.value
        del event
        self.mock.assert_called_once_with(event_id)

    def test_event_record_callback(self):
        gpu_trace.register_callback_for_event_record(self.mock)

        event = torch.xpu.Event()
        event.record()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_event_wait_callback(self):
        gpu_trace.register_callback_for_event_wait(self.mock)

        event = torch.xpu.Event()
        event.record()
        event.wait()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_device_synchronization_callback(self):
        gpu_trace.register_callback_for_device_synchronization(self.mock)

        torch.xpu.synchronize()
        self.mock.assert_called()
Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker def test_stream_synchronization_callback(self): 498*da0073e9SAndroid Build Coastguard Worker gpu_trace.register_callback_for_stream_synchronization(self.mock) 499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Worker stream = torch.xpu.Stream() 501*da0073e9SAndroid Build Coastguard Worker stream.synchronize() 502*da0073e9SAndroid Build Coastguard Worker self.mock.assert_called_once_with(stream.sycl_queue) 503*da0073e9SAndroid Build Coastguard Worker 504*da0073e9SAndroid Build Coastguard Worker def test_event_synchronization_callback(self): 505*da0073e9SAndroid Build Coastguard Worker gpu_trace.register_callback_for_event_synchronization(self.mock) 506*da0073e9SAndroid Build Coastguard Worker 507*da0073e9SAndroid Build Coastguard Worker event = torch.xpu.Event() 508*da0073e9SAndroid Build Coastguard Worker event.record() 509*da0073e9SAndroid Build Coastguard Worker event.synchronize() 510*da0073e9SAndroid Build Coastguard Worker self.mock.assert_called_once_with(event._as_parameter_.value) 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 514*da0073e9SAndroid Build Coastguard Worker run_tests() 515