1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import random 8import unittest 9 10import torch 11from torch import nn 12 13# TODO hack tracking here: T143942601 14batch_norm_op = torch.ops.aten.native_batch_norm.default 15if torch._C.DispatchKey.Autograd in batch_norm_op.py_kernels: 16 del batch_norm_op.py_kernels[torch._C.DispatchKey.Autograd] 17batch_norm_op._dispatch_cache.clear() 18 19 20class BatchNormModel(nn.Module): 21 def __init__(self): 22 super().__init__() 23 self.bn = nn.BatchNorm2d(3) 24 25 def forward(self, x): 26 return self.bn(x) 27 28 def get_upper_bound_inputs(self): 29 return (torch.rand(5, 3, 5, 5),) 30 31 def get_random_inputs(self): 32 bs = random.randint(2, 5) 33 return (torch.rand(bs, 3, 5, 5),) 34 35 @staticmethod 36 def verify_graph(testcase: unittest.TestCase, graph_module: torch.fx.GraphModule): 37 bn_node = [ 38 nd 39 for nd in graph_module.graph.nodes 40 if nd.target == torch.ops.aten.native_batch_norm.out 41 ] 42 testcase.assertEqual(1, len(bn_node)) 43 bn_node = bn_node[0] 44 45 speclist = bn_node.meta["spec"] 46 testcase.assertEqual(3, len(speclist)) 47 48 # for infernece, the save_mean and save_var should be empty 49 _, save_mean_spec, save_var_spec = speclist 50 testcase.assertEqual(list(save_mean_spec.shape), [0]) 51 testcase.assertEqual(list(save_var_spec.shape), [0]) 52
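

# A minimal usage sketch, not part of the original test harness: it assumes only
# standard PyTorch APIs and shows the model running in eval mode, where BatchNorm
# uses its tracked running statistics. This is the same inference setting under
# which the exported graph's save_mean/save_var specs are expected to be empty.
if __name__ == "__main__":
    model = BatchNormModel().eval()
    (example_input,) = model.get_random_inputs()
    with torch.no_grad():
        output = model(example_input)
    # Output shape matches the input shape: (batch, 3, 5, 5).
    print(output.shape)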