# Owner(s): ["oncall: jit"]

"""TorchScript autocast tests.

``TestAutocast`` exercises scripted functions and modules that use
``autocast`` context managers (mostly on CUDA) with the JIT autocast pass
enabled; ``TestJitTraceAutocast`` exercises tracing, freezing and scripting
under CPU autocast with the JIT autocast pass disabled.
"""

import torch
from torch.cuda.amp import autocast
from typing import Optional, Tuple

import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing import FileCheck
from jit.test_models import MnistNet

# bfloat16 tests need a CUDA device that reports bfloat16 support.
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()


@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestAutocast(JitTestCase):
    """Scripted code using ``autocast`` regions, with the JIT autocast pass on."""

    def setUp(self):
        # common input tensors (created only when CUDA is available)
        if TEST_CUDA:
            self.a_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.b_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.c_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.d_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.a_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.b_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.c_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.d_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
        # Force the JIT autocast pass on; the returned previous mode is
        # restored in tearDown.
        self.old_value = torch._C._jit_set_autocast_mode(True)
        super().setUp()

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.old_value)
        super().tearDown()

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_generic_autocast(self):
        @torch.jit.script
        def fn_cuda_autocast(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y

        @torch.jit.script
        def fn_generic_autocast(a, b):
            with torch.amp.autocast(device_type='cuda'):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y

        self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32),
                         fn_generic_autocast(self.a_fp32, self.b_fp32))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y

        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
    def test_linear_bf16(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(dtype=torch.bfloat16):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y

        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.bfloat16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_cpu(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)

        result = fn(self.a_fp32.to('cpu'), self.b_fp32.to('cpu'))
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_off(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state(self):
        @torch.jit.script
        def fn(a, b, use_amp: bool):
            with autocast(enabled=use_amp):
                return torch.mm(a, b)

        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32, True)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state_expr(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True if a[0][0] > 0.5 else False):
                return torch.mm(a, b)

        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_explicit_casts(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a.double(), b.double()).float()
                f = torch.mm(c, d).double()
            g = torch.mm(c.double(), f)
            return e, f, g

        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float64)
        self.assertEqual(g.dtype, torch.float64)

    # multiple uses of the same input value
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_duplicate_inputs(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                e = torch.mm(a, a)
                f = torch.mm(e, e)
            return e, f

        e, f = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)

        result = fn(self.a_fp16)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy_with_fp64(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)

        # fp32 policy should not narrow fp64 to fp32!
        result = fn(self.a_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a, b)
                f = torch.addcmul(e, c, d, value=0.1)
            return e, f

        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy_fp64(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return torch.addcmul(a, a, b, value=0.1)

        result = fn(self.a_fp32.double(), self.b_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
            return x, y, z, w

        x, y, z, w = fn(self.a_fp16, self.b_fp16, self.c_fp16, self.d_fp16, None)
        self.assertEqual(x.dtype, torch.float32)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy_fp64(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
            return x, y, z, w

        x, y, z, w = fn(self.a_fp32.double(), self.b_fp32.double(),
                        self.c_fp32.double(), self.d_fp32.double(), None)
        self.assertEqual(x.dtype, torch.float64)
        self.assertEqual(y.dtype, torch.float64)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float64)

    @unittest.skipIf(True, "broken due to lack of type propagation")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_control_flow(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    x = 1
                else:
                    e = torch.mm(c, d)
                    x = 2
                f = torch.mm(d, e) * x
                return e, f

        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    # this works fine in regular Python, but it creates a delicate
    # situation in TorchScript where the types are not consistent across
    # the then/else branches
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_types(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    f = torch.mm(a, b).float()
                else:
                    e = torch.mm(c, d).float()
                    f = torch.mm(a, b)
            return torch.mm(e.float(), f.float())

        result = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # another, more complex case of divergent types
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            if a[0][0] > 0.5:
                with autocast_on:
                    e = torch.mm(a, b)
            else:
                with autocast_off:
                    e = torch.mm(c, d)
            return torch.mm(e, e)

        fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_conditional_autocast(self):
        @torch.jit.script
        def fn(a, b):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            with autocast_on if a[0][0] > 0.5 else autocast_off:
                return torch.mm(a, b)

        # conditional autocast expressions are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_nested_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=False):
                e = torch.mm(a, b)
                with autocast(enabled=True):
                    f = torch.mm(e, c)
                    with autocast(enabled=False):
                        g = torch.mm(e, d)
            return e, f, g

        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_implicitly_nested_autocast(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False), autocast(enabled=True):
                return torch.mm(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_reused_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_instance = autocast(enabled=True)
            with autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                    f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g

        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    # TODO: fix and enable this test?
    # (we could technically fix this, but is it really worth it?)
    @unittest.skipIf(True, "unsuported autocast syntax")
    def test_reused_autocast_expr(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=True) as autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                    f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g

        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees(self):
        def helper(a, b):
            return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                tmp = helper(a, b)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                return helper(tmp, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_on(self):
        def helper(a, b):
            with autocast(enabled=True):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_off(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripting inside eager autocast
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_eager_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)

        # alternate eager autocast on/off so the same scripted function is
        # run repeatedly under both states
        for i in range(8):
            use_autocast = (i % 2 == 0)
            expected_dtype = torch.float16 if use_autocast else torch.float32
            with autocast(enabled=use_autocast):
                result = fn(self.a_fp32, self.b_fp32)
            self.assertEqual(result.dtype, expected_dtype)

    # traced inside scripting
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing(self):
        def helper(a, b):
            return torch.mm(a, b)

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # traced with autocast inside scripting
    @unittest.skipIf(True, "autocast(False) is ignored inside traced functions")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing_with_autocast(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b) * 2.0

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripted called from traced
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_and_script(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)

        def traced(a, b):
            return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # scripted called from traced with autocast
    @unittest.skipIf(True, "scripted called from traced TorchScript is not yet working")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_with_autocast_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)

        def traced(a, b):
            with autocast(enabled=True):
                return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_module(self):
        class TestModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand((N, M), dtype=torch.float32))
                self.linear = torch.nn.Linear(N, M).float()

            def forward(self, input):
                with autocast(enabled=True):
                    output = self.weight.mv(input)
                    output = self.linear(output)
                    return output

        scripted_module = torch.jit.script(TestModule(2, 3)).cuda()
        input = torch.rand(3, dtype=torch.float32, device='cuda')
        result = scripted_module(input)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(True, "autocast decorators not supported")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator(self):
        @torch.jit.script
        @autocast(enabled=True)
        def fn(a, b):
            return torch.mm(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # this is equivalent to running scripted functions inside autocast
    # (see also test_eager_and_script)
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator_outside_jit(self):
        @autocast(enabled=True)
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_inplace(self):
        @torch.jit.script
        def fn(a, b, c):
            with autocast(enabled=True):
                x = torch.addmm(a, b, c)
                y = torch.addmm(a, b, c, out=a)
                z = a.addmm_(b, c)
                return x, y, z

        x, y, z = fn(self.a_fp32, self.b_fp32, self.c_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float32)

    def _test_autocast(self, func, cast_op, *args):
        """Script ``func`` and check it matches eager output dtypes.

        Args:
            func: function or ``nn.Module`` to run eagerly and via ``torch.jit.script``.
            cast_op: if not None, a string that must appear in the optimized
                graph (``graph_for(*args)``) of the scripted function.
            *args: inputs passed to both the eager and scripted versions.

        Outputs may be a single tensor or a tuple of tensors; in both cases
        the number of outputs and each output's dtype must match.
        """
        jit_func = torch.jit.script(func)
        o = func(*args)
        jit_o = jit_func(*args)
        if cast_op is not None:
            FileCheck().check(cast_op).run(jit_func.graph_for(*args))
        # Normalize a single-tensor result so we compare whole-tensor dtypes
        # instead of iterating over tensor rows.
        if isinstance(o, torch.Tensor):
            o = (o,)
        if isinstance(jit_o, torch.Tensor):
            jit_o = (jit_o,)
        # zip() would silently drop unmatched outputs, so check arity first.
        self.assertEqual(len(o), len(jit_o))
        for o0, o1 in zip(o, jit_o):
            self.assertEqual(o0.dtype, o1.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api(self):

        def t_autocast_cpu(x, y):
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            with torch.autocast("cuda", dtype=torch.half):
                return torch.mm(x, y)

        def t_cuda_amp_autocast(x, y):
            with torch.cuda.amp.autocast():
                return torch.mm(x, y)

        def t_cpu_amp_autocast(x, y):
            with torch.cpu.amp.autocast():
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cuda_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cpu_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(True, "we need to provide dtype argument at this moment")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api_not_supported(self):

        def t_autocast_cpu(x, y):
            # no dtype provided is not currently supported
            with torch.autocast("cpu"):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            # no dtype provided is not currently supported
            with torch.autocast("cuda"):
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_mixed_dtypes(self):
        def t(cpu0, cpu1, cuda0, cuda1):
            with torch.autocast("cpu", torch.bfloat16):
                with torch.autocast("cuda", torch.float16):
                    cpu_o = torch.mm(cpu0, cpu1)
                    cuda_o = torch.mm(cuda0, cuda1)
                    return cpu_o, cuda_o

        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        # _test_autocast scripts `t` itself.
        self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_executor_under_autocast(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            cpu_o = torch.mm(cpu0, cpu1)
            cuda_o = torch.mm(cuda0, cuda1)
            return cpu_o, cuda_o

        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)

        # `t` has no autocast region of its own; the casts come from the
        # eager autocast state active when the scripted function runs.
        with torch.autocast("cpu", torch.bfloat16):
            with torch.autocast("cuda", torch.float16):
                self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cpu", torch.bfloat16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cuda", torch.float16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        # no cast op should be observed when executing outside autocast context
        self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_autodiff(self):
        def t(t0, t1):
            o = torch.mm(t0, t1)
            return o.relu()

        jit_t = torch.jit.script(t)
        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

        # run optimization
        for _ in range(5):
            with torch.autocast("cuda", torch.float16):
                jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

        t0.grad = None
        t1.grad = None
        ref_t0 = t0.detach().requires_grad_()
        ref_t1 = t1.detach().requires_grad_()

        with torch.autocast("cuda", torch.float16):
            o = t(ref_t0, ref_t1)
            jit_o = jit_t(t0, t1)
        jit_o.sum().backward()
        o.sum().backward()
        self.assertEqual(o, jit_o)
        self.assertEqual(t0.grad, ref_t0.grad)
        self.assertEqual(t1.grad, ref_t1.grad)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_call_method_under_autocast(self):
        @torch.jit.interface
        class Iface(torch.nn.Module):
            def forward(self, x, y) -> torch.Tensor:
                pass

        class Impl(Iface):
            def forward(self, x, y):
                return torch.mm(x, y)

        class Thing1(torch.nn.Module):
            impl: Iface

            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    a = torch.mm(x, y)
                    b = self.impl.forward(a, x)
                    return b

        scripted_impl = torch.jit.script(Impl())
        thing1 = Thing1()
        thing1.impl = scripted_impl
        scripted_thing1 = torch.jit.script(thing1)
        x = torch.rand([2, 2])
        y = torch.rand([2, 2])

        # make sure this doesn't throw an error
        with torch.cuda.amp.autocast():
            ans = scripted_thing1.forward(x, y)
        self.assertEqual(torch.mm(torch.mm(x, y), x), ans)

        # sanity check: this isn't supported currently when global autocasting
        # isn't enabled
        self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_basic(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(x, y)

        x = torch.rand((3, 4), dtype=torch.float).cuda()
        y = torch.rand((4, 5), dtype=torch.float).cuda()

        mod = TestModule().eval()

        # sanity check
        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)

        # make sure that the runtime pass doesn't duplicate autocast nodes
        frozen_mod(x, y)
        optimized_graph = frozen_mod.graph_for(x, y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_constants(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.rand((3, 4), dtype=torch.float).cuda()

            def forward(self, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(self.x, y)

        y = torch.rand((4, 5), dtype=torch.float).cuda()
        mod = TestModule().eval()

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        # freezing should pre-cast the constant self.x to remove one autocast call
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)

        # the runtime autocasting pass will re-insert the second autocast call,
        # but constant propagation will merge it with the constant that it's casting.
        frozen_mod(y)
        optimized_graph = frozen_mod.graph_for(y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)

    @unittest.skipIf(TEST_CUDA, "CPU-only test")
    def test_jit_autocast_softmax_cpu(self):
        def fn(x):
            with torch.cpu.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.bfloat16)
        # the first call warms up the executor; check the second result
        fn_s(x)
        y = fn_s(x)

        self.assertEqual(y.dtype, torch.bfloat16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_autocast_softmax_gpu(self):
        def fn(x):
            with torch.cuda.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.half).cuda()
        # the first call warms up the executor; check the second result
        fn_s(x)
        y = fn_s(x)

        self.assertEqual(y.dtype, torch.float)

    def test_ignore_amp(self):
        @torch.jit.script
        def foo(x):
            return torch.mm(x, x)

        inp = torch.rand([10, 10], dtype=torch.float)
        foo._set_ignore_amp(True)
        with torch.cpu.amp.autocast():
            foo(inp)
            foo(inp)

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("_autocast_to_reduced").run(g)


class convbn(torch.nn.Module):
    """Conv2d(3 -> 64, kernel 7, stride 2) followed by BatchNorm2d(64)."""

    def __init__(self, bias_enabled=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        return self.bn(self.conv(x))


@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestJitTraceAutocast(JitTestCase):
    """Tracing/freezing under CPU autocast with the JIT autocast pass off."""

    def setUp(self):
        super().setUp()
        self.previous_default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        # self.models[i] is paired with self.inputs[i]
        self.models = [MnistNet(),
                       convbn(bias_enabled=True),
                       convbn(bias_enabled=False)]
        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu')]
        # Disable the JIT autocast pass; the previous mode is restored in tearDown.
        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
        torch.set_default_dtype(self.previous_default_dtype)
        super().tearDown()

    def test_generate_autocast_jit_trace_model(self):
        def check_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)

        for model, x in zip(self.models, self.inputs):
            check_model(model, x)

    def test_nchw_autocast_jit_trace_model(self):
        def check_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
            # the traced model runs outside autocast; compare against eager
            # execution under autocast
            with torch.no_grad():
                y = traced_model(x.clone())
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone())
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)

        for model, x in zip(self.models, self.inputs):
            check_model(model, x)

    def test_nhwc_autocast_jit_trace_model(self):
        def check_model(model, x):
            model = model.to(memory_format=torch.channels_last)
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone().to(memory_format=torch.channels_last))
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone().to(memory_format=torch.channels_last))
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)

        for model, x in zip(self.models, self.inputs):
            if x.dim() == 5:
                # NHWC 3D case not support yet
                continue
            check_model(model, x)

    def test_cat_promote(self):
        class TestModel(torch.nn.Module):
            def forward(self, a, b):
                return torch.cat([a, b], 0)

        with torch.jit.fuser("none"):
            # In this testcase, we will check whether cat has done the promotion in AMP with mixed dtype inputs.
            # To avoid the fusion group from TE, we will disable the fuser here.
            for jit_freeze_or_not in [False, True]:
                test_model = TestModel().eval()
                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                    a = torch.rand(24, 128, 128)
                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                    c = test_model(a, b)
                    traced = torch.jit.trace(test_model, (a, b))
                if jit_freeze_or_not:
                    traced = torch.jit.freeze(traced)
                for _ in range(3):
                    c2 = traced(a, b)
                # assertTrue(x, y) treats y as the failure message and would
                # always pass; compare the dtypes explicitly.
                self.assertEqual(c.dtype, torch.float32)
                self.assertEqual(c2.dtype, torch.float32)
                traced_graph = traced.graph_for(a, b)
                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))

    def test_script_autocast_cpu(self):
        def fn(x):
            if torch.is_autocast_cpu_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cpu.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        # also exercise the autocast-disabled (sin) branch
        with torch.cpu.amp.autocast(enabled=False):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_cpu_enabled" in n.kind() for n in fn_s.graph.nodes()))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_autocast_cuda(self):
        def fn(x):
            if torch.is_autocast_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cuda.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_enabled" in n.kind() for n in fn_s.graph.nodes()))

    def test_scripted_aliasing(self):
        # torch.is_autocast_enabled should not be able to move inside of the autocast context.
        def fn(x):
            if torch.is_autocast_enabled():
                y = True
            else:
                y = False
            with torch.cuda.amp.autocast(enabled=True):
                z = x.relu()
            return y, z

        fn_s = torch.jit.script(fn)
        graph = fn_s.graph

        aliasdb = graph.alias_db()

        is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
        enter_nodes = graph.findAllNodes("prim::Enter")

        self.assertEqual(len(is_enabled_nodes), 1)
        self.assertEqual(len(enter_nodes), 1)

        self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))

    def test_script_autocast_enable_and_check(self):
        def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
            b1 = torch.is_autocast_cpu_enabled()
            v1 = torch.mm(x, y)
            with torch.cpu.amp.autocast(enabled=True):
                b2 = torch.is_autocast_cpu_enabled()
                v2 = torch.mm(x, y)
                with torch.cpu.amp.autocast(enabled=False):
                    b3 = torch.is_autocast_cpu_enabled()
                    v3 = torch.mm(x, y)
            return (v1, b1, v2, b2, v3, b3)

        # bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
        def check_fn_results(arr):
            [v1, b1, v2, b2, v3, b3] = arr
            self.assertTrue((v1.dtype == torch.float) != b1)
            self.assertTrue((v2.dtype == torch.float) != b2)
            self.assertTrue((v3.dtype == torch.float) != b3)

        x = torch.rand((2, 2), dtype=torch.float)
        y = torch.rand((2, 2), dtype=torch.float)

        fn_s = torch.jit.script(fn)

        with torch.cpu.amp.autocast(enabled=False):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))

        with torch.cpu.amp.autocast(enabled=True):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))


if __name__ == "__main__":
    run_tests()