# Owner(s): ["module: functorch"]

import inspect
import random
import unittest
from typing import Callable

import torch
import torch.fx as fx
import torch.nn as nn
from functorch import make_fx
from functorch.compile import memory_efficient_fusion
from torch._functorch.compile_utils import fx_graph_cse
from torch.nn import functional as F
from torch.testing._internal.common_utils import run_tests, TestCase


HAS_CUDA = torch.cuda.is_available()


def _num_args(fn: Callable):
    return len(inspect.signature(fn).parameters)


def gelu_bias(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


def swish(x):
    return x * torch.sigmoid(x)


def mish(x):
    return x.mul(torch.tanh(F.softplus(x)))


def hard_sigmoid(x):
    return (x + 3.0).clamp(min=0.0, max=6.0).div(6.0)


def hard_swish(x):
    return x * (x + 3.0).clamp(min=0.0, max=6.0).div(6.0)


def hard_mish(x):
    return 0.5 * x * (x + 2.0).clamp(min=0.0, max=2.0)


# todo: convert these into tests
# def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
#     B, C, H, W = x.shape
#     x_dtype = x.dtype
#     if flatten:
#         x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
#         std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
#     else:
#         x = x.reshape(B, groups, C // groups, H, W)
#         std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
#     return std.expand(x.shape).reshape(B, C, H, W)

# class EvoNorm2dS0(nn.Module):
#     def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
#         super().__init__()
#         self.apply_act = apply_act  # apply activation (non-linearity)
#         if group_size:
#             assert num_features % group_size == 0
#             self.groups = num_features // group_size
#         else:
#             self.groups = groups
#         self.eps = eps
#         self.weight = nn.Parameter(torch.ones(num_features))
#         self.bias = nn.Parameter(torch.zeros(num_features))
#         self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
#         self.reset_parameters()

#     def reset_parameters(self):
#         nn.init.ones_(self.weight)
#         nn.init.zeros_(self.bias)
#         if self.v is not None:
#             nn.init.ones_(self.v)

#     def forward(self, x):
#         x_dtype = x.dtype
#         v_shape = (1, -1, 1, 1)
#         if self.v is not None:
#             v = self.v.view(v_shape).to(dtype=x_dtype)
#             x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
#         return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)


# device = "cuda"
# dtype = torch.float

# evo_norm = EvoNorm2dS0(2048)
# evo_norm_inp = [(128, 2048, 8, 8)]


def run_and_compare_activation(self, fn, inps):
    with torch.jit.fuser("fuser1"):
        device = "cuda"
        dtype = torch.float
        if isinstance(fn, nn.Module):
            fn = fn.to(device=device, dtype=dtype)

        ref_args = [
            torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
            for shape in inps
        ]
        res_args = [i.clone().detach().requires_grad_(True) for i in ref_args]

        ref = fn(*ref_args)
        ref.sum().backward()

        mem_optimized_fn = memory_efficient_fusion(fn)
        for _ in range(5):
            for i in res_args:
                i.grad = None
            res = mem_optimized_fn(*res_args)
            res.sum().backward()

        self.assertEqual(ref, res)
        for ref_arg, res_arg in zip(ref_args, res_args):
            self.assertEqual(ref_arg.grad, res_arg.grad)


@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
class TestMemoryEfficientOpAuthoring(TestCase):
    def test_gelu_bias(self):
        run_and_compare_activation(self, gelu_bias, [(1024,), (1024,)])

    def test_mish(self):
        run_and_compare_activation(self, mish, [(1024,)])

    def test_swish(self):
        run_and_compare_activation(self, swish, [(1024,)])

    def test_hard_sigmoid(self):
        run_and_compare_activation(self, hard_sigmoid, [(1024,)])

    def test_hard_swish(self):
        run_and_compare_activation(self, hard_swish, [(1024,)])

    def test_layer_norm(self):
        def layer_norm(x, weight, bias):
            dim = -1
            eps = 1e-5
            mean = torch.mean(x, dim, keepdim=True)
            centered = x - mean
            var = torch.sum(centered * centered, dim, keepdim=True) / x.size(-1)
            rvar = 1.0 / torch.sqrt(var + eps)
            normed = (x - mean) * rvar
            return normed * weight + bias

        bs = 10
        ln_size = 16
        layer_norm_inps = [(bs, ln_size), (ln_size,), (ln_size,)]
        run_and_compare_activation(self, layer_norm, layer_norm_inps)

    def test_rmsnorm(self):
        class T5LayerNorm(nn.Module):
            def __init__(self, hidden_size, eps=1e-6):
                """
                Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
                """
                super().__init__()
                self.weight = nn.Parameter(torch.ones(hidden_size))
                self.variance_epsilon = eps

            def forward(self, hidden_states):
                # layer norm should always be calculated in float32
                variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
                hidden_states = hidden_states * torch.rsqrt(
                    variance + self.variance_epsilon
                )

                # convert into half-precision if necessary
                if self.weight.dtype in [torch.float16, torch.bfloat16]:
                    hidden_states = hidden_states.to(self.weight.dtype)

                return self.weight * hidden_states

        bs = 256
        seq = 256
        hidden = 1024
        t5_norm = T5LayerNorm(hidden)
        t5_norm_inputs = [(bs, seq, hidden)]
        run_and_compare_activation(self, t5_norm, t5_norm_inputs)

    # TODO - Assertion failure
    # def test_hard_mish(self):
    #     for compiler in compilers:
    #         run_and_compare_activation(hard_mish, 1024)


# Check that the CSE-modified graph of f has `delta` fewer nodes than the original,
# and that a second CSE pass does not reduce the number of nodes further.
# `delta` is an integer >= -1. If delta == -1, only check that the new graph
# has at most as many nodes as the original.
def check(f, t, delta, check_val=True, graph_input=False):
    if graph_input:
        fx_g = f
    else:
        fx_g = make_fx(f)(t)
    new_graph = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, new_graph)

    # the number of nodes should decrease or stay the same
    old_num_nodes = len(fx_g.graph.nodes)
    new_num_nodes = len(new_graph.nodes)
    if delta == -1:
        assert (
            old_num_nodes >= new_num_nodes
        ), f"number of nodes increased {old_num_nodes}, {new_num_nodes}"
    else:
        assert (
            old_num_nodes == new_num_nodes + delta
        ), f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"

    # a second pass should not reduce the number of nodes further
    pass_2_graph = fx_graph_cse(new_graph)
    pass_2_num_nodes = len(pass_2_graph.nodes)
    assert (
        pass_2_num_nodes == new_num_nodes
    ), f"second pass graph has fewer nodes {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"

    # check correctness
    if check_val:
        true_result = fx_g(t)
        our_result = new_g(t)
        if true_result is None:  # both return None
            assert (
                our_result is None
            ), f"true result is None, CSE result is {our_result}"
        else:  # results returned are the same
            assert torch.all(
                true_result == our_result
            ), f"results are different {true_result}, {our_result}"


class NoChangeTestCase(TestCase):
    def test_nochange(self):
        def f(x):
            a = x + 1
            b = x + a
            a = x
            d = x + a
            return b + d

        t = torch.randn(2, 2)
        check(f, t, 0)

    def test_empty(self):
        def f(x):
            pass

        t = torch.randn(2, 2)
        check(f, t, 0)

    def test_rand_like(self):
        def f(x):
            a = torch.rand_like(x)
            b = torch.rand_like(x)
            return a + b

        t = torch.randn(2, 2)
        check(f, t, 0, check_val=False)

    def test_rand_n(self):
        def f(x):
            a = torch.randn(4)
            b = torch.randn(4)
            return a + b

        t = torch.randn(2, 2)
        check(f, t, 0, check_val=False)

    def test_hash_with_numbers(self):
        # Test to repro issue with fx_graph_cse when
        # hash((primals_2, 1.0)) == hash((primals_2, 1))

        if torch._dynamo.is_compiling():
            self.skipTest("Unsupported if test run is compiled")

        def f(inpt, osize):
            size = inpt.shape[-1]
            s1 = size - 1
            s2 = size - 1.0
            scale = s2 / (osize - 1.0)
            inpt = torch.clamp(inpt, 0, s1)
            return scale * inpt

        # Fetch dynamic graph
        gms = []

        def toy_backend(gm, _):
            gms.append(gm)
            return gm.forward

        torch._dynamo.reset()
        fn = torch.compile(backend=toy_backend, dynamic=True)(f)

        t = torch.rand(3, 100)
        _ = fn(t, 50)
        assert len(gms) == 1, gms
        fx_g = gms[0]
        check(fx_g, None, 0, check_val=False, graph_input=True)


class ReduceTestCase(TestCase):
    def test_immutable_list_type(self):
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1)
            c = x.sum()
            d = x.sum()
            return a + b + c + d

        t = torch.randn(2, 2)
        check(f, t, 2)

    def test_immutable_list_multiple_entries(self):
        def f(x):
            a = x.sum(dim=[0, 1])
            b = x.sum(dim=[0, 1])
            c = x.sum(dim=1)
            d = x.sum(dim=1)
            return a + b + c + d

        t = torch.randn(2, 2)
        check(f, t, 2)

    def test_simple(self):
        def f(x):
            a = x.cos()
            b = x.cos()
            c = a + a
            d = b + b
            return c + d

        t = torch.randn(2, 2)
        check(f, t, 2)

    def test_simple_2(self):
        def f(x):
            a = x.cos().sin()
            b = x.cos().sin()
            c = a + a
            d = b + b
            return c + d

        t = torch.randn(1)
        check(f, t, 3)

    def test_two_args_default(self):
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1, keepdim=False)
            c = x.sum(dim=1, keepdim=False)
            d = x.sum(dim=1)
            return a + b + c + d

        t = torch.randn(2, 2)
        check(f, t, 3)

    def test_two_args(self):
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1, keepdim=True)
            c = x.sum(dim=1, keepdim=True)
            d = x.sum(dim=1)
            return a + b + c + d

        t = torch.randn(2, 2)
        check(f, t, 2)

    def test_simple_multiple_same_ops(self):
        def f(x):
            a = x.sum()
            b = x.sum()
            c = x.sum()
            d = x.sum()
            return a + b + c + d

        t = torch.randn(2, 2)
        check(f, t, 3)

    def test_nested_immutable_list_type(self):
        def f(x):
            a = torch.cat((x, x))
            b = torch.cat((x, x))
            return a + b

        t = torch.randn(2, 2)
        check(f, t, 1)

    def test_kwarg(self):
        def f(x):
            a = torch.ones_like(x)
            b = torch.ones_like(x)
            return a + b

        t = torch.randn(2, 2)
        check(f, t, 1)


class RandomOpTestCase(TestCase):
    def test_random(self):
        def f(x):
            vals = [x]
            ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu]
            for _ in range(100):
                new_val = random.choice(ops)(random.choice(vals))
                vals.append(new_val)
            return vals[-1]

        fx_g = fx.symbolic_trace(f)
        fx_g.graph.eliminate_dead_code()
        fx_g.recompile()
        t = torch.randn(2, 2)

        for _ in range(30):
            check(fx_g, t, -1, graph_input=True)


if __name__ == "__main__":
    run_tests()