1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 10*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 11*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 16*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 17*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 18*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 19*da0073e9SAndroid Build Coastguard Worker "instead." 20*da0073e9SAndroid Build Coastguard Worker ) 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Workerclass TestLogging(JitTestCase): 24*da0073e9SAndroid Build Coastguard Worker def test_bump_numeric_counter(self): 25*da0073e9SAndroid Build Coastguard Worker class ModuleThatLogs(torch.jit.ScriptModule): 26*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 27*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 28*da0073e9SAndroid Build Coastguard Worker for i in range(x.size(0)): 29*da0073e9SAndroid Build Coastguard Worker x += 1.0 30*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.add_stat_value("foo", 1) 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker if bool(x.sum() > 0.0): 33*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.add_stat_value("positive", 1) 34*da0073e9SAndroid Build Coastguard Worker else: 35*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.add_stat_value("negative", 1) 36*da0073e9SAndroid Build Coastguard Worker return x 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker logger = torch.jit._logging.LockingLogger() 39*da0073e9SAndroid Build Coastguard Worker old_logger = torch.jit._logging.set_logger(logger) 40*da0073e9SAndroid Build Coastguard Worker try: 41*da0073e9SAndroid Build Coastguard Worker mtl = ModuleThatLogs() 42*da0073e9SAndroid Build Coastguard Worker for i in range(5): 43*da0073e9SAndroid Build Coastguard Worker mtl(torch.rand(3, 4, 5)) 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker self.assertEqual(logger.get_counter_val("foo"), 15) 46*da0073e9SAndroid Build Coastguard Worker self.assertEqual(logger.get_counter_val("positive"), 5) 47*da0073e9SAndroid Build Coastguard Worker finally: 48*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.set_logger(old_logger) 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Worker def test_trace_numeric_counter(self): 51*da0073e9SAndroid Build Coastguard Worker def foo(x): 52*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.add_stat_value("foo", 1) 53*da0073e9SAndroid Build Coastguard Worker return x + 1.0 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, torch.rand(3, 4)) 56*da0073e9SAndroid Build Coastguard Worker logger = torch.jit._logging.LockingLogger() 57*da0073e9SAndroid Build Coastguard Worker old_logger = torch.jit._logging.set_logger(logger) 58*da0073e9SAndroid Build Coastguard Worker try: 59*da0073e9SAndroid Build Coastguard Worker traced(torch.rand(3, 4)) 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker self.assertEqual(logger.get_counter_val("foo"), 1) 62*da0073e9SAndroid Build Coastguard Worker finally: 63*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.set_logger(old_logger) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker def test_time_measurement_counter(self): 66*da0073e9SAndroid Build Coastguard Worker class ModuleThatTimes(torch.jit.ScriptModule): 67*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 68*da0073e9SAndroid Build Coastguard Worker tp_start = torch.jit._logging.time_point() 69*da0073e9SAndroid Build Coastguard Worker for i in range(30): 70*da0073e9SAndroid Build Coastguard Worker x += 1.0 71*da0073e9SAndroid Build Coastguard Worker tp_end = torch.jit._logging.time_point() 72*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start) 73*da0073e9SAndroid Build Coastguard Worker return x 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker mtm = ModuleThatTimes() 76*da0073e9SAndroid Build Coastguard Worker logger = torch.jit._logging.LockingLogger() 77*da0073e9SAndroid Build Coastguard Worker old_logger = torch.jit._logging.set_logger(logger) 78*da0073e9SAndroid Build Coastguard Worker try: 79*da0073e9SAndroid Build Coastguard Worker mtm(torch.rand(3, 4)) 80*da0073e9SAndroid Build Coastguard Worker self.assertGreater(logger.get_counter_val("mytimer"), 0) 81*da0073e9SAndroid Build Coastguard Worker finally: 82*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.set_logger(old_logger) 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker def test_time_measurement_counter_script(self): 85*da0073e9SAndroid Build Coastguard Worker class ModuleThatTimes(torch.jit.ScriptModule): 86*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 87*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 88*da0073e9SAndroid Build Coastguard Worker tp_start = torch.jit._logging.time_point() 89*da0073e9SAndroid Build Coastguard Worker for i in range(30): 90*da0073e9SAndroid Build Coastguard Worker x += 1.0 91*da0073e9SAndroid Build Coastguard Worker tp_end = torch.jit._logging.time_point() 92*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start) 93*da0073e9SAndroid Build Coastguard Worker return x 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker mtm = ModuleThatTimes() 96*da0073e9SAndroid Build Coastguard Worker logger = torch.jit._logging.LockingLogger() 97*da0073e9SAndroid Build Coastguard Worker old_logger = torch.jit._logging.set_logger(logger) 98*da0073e9SAndroid Build Coastguard Worker try: 99*da0073e9SAndroid Build Coastguard Worker mtm(torch.rand(3, 4)) 100*da0073e9SAndroid Build Coastguard Worker self.assertGreater(logger.get_counter_val("mytimer"), 0) 101*da0073e9SAndroid Build Coastguard Worker finally: 102*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.set_logger(old_logger) 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker def test_counter_aggregation(self): 105*da0073e9SAndroid Build Coastguard Worker def foo(x): 106*da0073e9SAndroid Build Coastguard Worker for i in range(3): 107*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.add_stat_value("foo", 1) 108*da0073e9SAndroid Build Coastguard Worker return x + 1.0 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, torch.rand(3, 4)) 111*da0073e9SAndroid Build Coastguard Worker logger = torch.jit._logging.LockingLogger() 112*da0073e9SAndroid Build Coastguard Worker logger.set_aggregation_type("foo", torch.jit._logging.AggregationType.AVG) 113*da0073e9SAndroid Build Coastguard Worker old_logger = torch.jit._logging.set_logger(logger) 114*da0073e9SAndroid Build Coastguard Worker try: 115*da0073e9SAndroid Build Coastguard Worker traced(torch.rand(3, 4)) 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(logger.get_counter_val("foo"), 1) 118*da0073e9SAndroid Build Coastguard Worker finally: 119*da0073e9SAndroid Build Coastguard Worker torch.jit._logging.set_logger(old_logger) 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker def test_logging_levels_set(self): 122*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_logging_option("foo") 123*da0073e9SAndroid Build Coastguard Worker self.assertEqual("foo", torch._C._jit_get_logging_option()) 124