# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import re
import unittest

import torch
import torch.fx

from executorch.exir.common import extract_out_arguments, get_schema_for_operators
from executorch.exir.print_program import add_cursor_to_graph


class TestExirCommon(unittest.TestCase):
    """Tests for helpers in ``executorch.exir.common`` and ``print_program``."""

    def test_get_schema_for_operators(self) -> None:
        """Each returned schema string should look like ``name(args) -> ...``."""
        op_list = [
            "torch.ops._caffe2.RoIAlign.default",
            "torch.ops.aten.add.Tensor",
            "torch.ops.aten.batch_norm.default",
            "torch.ops.aten.cat.default",
            "torch.ops.aten.clamp.default",
        ]

        schemas = get_schema_for_operators(op_list)
        # Without this guard the loop below asserts nothing when the returned
        # mapping is empty, and the test would pass vacuously.
        self.assertTrue(schemas)

        # A name, a non-empty parenthesized argument list, then " -> ".
        pat = re.compile(r"[^\(]+\([^\)]+\) -> ")
        for op_name, schema in schemas.items():
            # subTest reports which operator's schema failed to match.
            with self.subTest(op=op_name):
                self.assertIsNotNone(pat.match(schema))

    def test_get_out_args(self) -> None:
        """extract_out_arguments reports the ``out`` argument name for both
        single-Tensor and Tensor-list out-variants."""
        # Single Tensor out-argument.
        schema1 = torch._C.parse_schema(
            "aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
        )
        # Tensor-list out-argument.
        schema2 = torch._C.parse_schema(
            "split_copy.Tensor_out(Tensor self, int split_size, int dim=0, *, Tensor(a!)[] out) -> ()"
        )

        out_args_1 = extract_out_arguments(schema1, {"out": torch.ones(5)})
        out_args_2 = extract_out_arguments(
            schema2, {"out": [torch.ones(5), torch.ones(5)]}
        )

        # The result is unpacked as a two-element (name, value) pair; only the
        # name is checked here.
        out_arg_name_1, _ = out_args_1
        self.assertEqual(out_arg_name_1, "out")

        out_arg_name_2, _ = out_args_2
        self.assertEqual(out_arg_name_2, "out")

    def test_add_cursor(self) -> None:
        """add_cursor_to_graph marks the requested node's line with ``-->``."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        symbolic_traced = torch.fx.symbolic_trace(MyModule())

        # Graph we are testing (sketch; the exact per-node text may differ,
        # only the line order matters for this test):
        # graph():
        #     %x : [#users=1] = placeholder[target=x]
        #     %param : [#users=1] = get_attr[target=param]
        #     %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
        # --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
        #     %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
        #     return clamp
        #
        # Node index 3 is `linear`. Line 0 of the rendered text is the
        # "graph():" header, so the cursor is expected on line 4.
        nodes = list(symbolic_traced.graph.nodes)
        actual_str = add_cursor_to_graph(symbolic_traced.graph, nodes[3])
        self.assertTrue(actual_str.split("\n")[4].startswith("-->"))