import functools
import os
import shutil
import sys
from io import BytesIO

import torch
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter


# Operator names collected from every model exported by save_model.
_OPERATORS = set()
# Paths of the .ptl files written by save_model (each path recorded once).
_FILENAMES = []
# Model classes exported by save_model (each class recorded once).
_MODELS = []


def save_model(cls):
    """Save a model and dump all the ops.

    Replaces ``cls`` with a factory function of the same name. Calling the
    factory instantiates ``cls``, scripts it with ``torch.jit.script``, adds the
    operators reported by the lite interpreter to ``_OPERATORS``, and saves the
    scripted module for the lite interpreter as ``./<ClassName>.ptl``.

    Args:
        cls: an ``torch.nn.Module`` subclass to export.

    Returns:
        The factory. It forwards its positional/keyword arguments to ``cls``
        and returns the scripted module.
    """

    @functools.wraps(cls)
    def wrapper_save(*args, **kwargs):
        model = cls(*args, **kwargs)
        scripted = torch.jit.script(model)
        # Take the operator list from the module as loaded back by the lite
        # interpreter, not from the full-JIT scripted module.
        # A freshly constructed BytesIO is already positioned at offset 0.
        buffer = BytesIO(scripted._save_to_buffer_for_lite_interpreter())
        mobile_module = _load_for_lite_interpreter(buffer)
        _OPERATORS.update(_export_operator_list(mobile_module))
        path = f"./{cls.__name__}.ptl"
        scripted._save_for_lite_interpreter(path)
        # Register only after a successful save, and only once, so repeated
        # calls do not produce duplicate entries.
        if cls not in _MODELS:
            _MODELS.append(cls)
        if path not in _FILENAMES:
            _FILENAMES.append(path)
        return scripted

    return wrapper_save


# Calls torch.ones with explicit dtype, layout, device and pin_memory arguments.
@save_model
class ModelWithDTypeDeviceLayoutPinMemory(torch.nn.Module):
    def forward(self, x: int):
        a = torch.ones(
            size=[3, x],
            dtype=torch.int64,
            layout=torch.strided,
            device="cpu",
            pin_memory=False,
        )
        return a


# Fills a 2x2 tensor element by element, then returns a[index].
@save_model
class ModelWithTensorOptional(torch.nn.Module):
    def forward(self, index):
        a = torch.zeros(2, 2)
        a[0][1] = 1
        a[1][0] = 2
        a[1][1] = 3
        return a[index]


# gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[]
# a == 0 passes `spacing` as a 0-d float64 tensor; a == 1 passes it as a list
# holding one Python scalar. Other values of `a` reach no return statement.
@save_model
class ModelWithScalarList(torch.nn.Module):
    def forward(self, a: int):
        values = torch.tensor(
            [4.0, 1.0, 1.0, 16.0],
        )
        if a == 0:
            return torch.gradient(
                values, spacing=torch.scalar_tensor(2.0, dtype=torch.float64)
            )
        elif a == 1:
            return torch.gradient(values, spacing=[torch.tensor(1.0).item()])
# upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
# A torch.nn.Upsample in "linear" mode whose scale_factor is a one-element
# float tuple; forward is inherited from torch.nn.Upsample.
# NOTE: constructor arguments here determine which operator schema gets
# exported, so keep them as-is when editing.
@save_model
class ModelWithFloatList(torch.nn.Upsample):
    def __init__(self) -> None:
        super().__init__(
            scale_factor=(2.0,),
            mode="linear",
            align_corners=False,
            recompute_scale_factor=True,
        )


# index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
# Indexes a 1x4 tensor with two tensor indices: a constant torch.tensor(0) and
# the caller-supplied `index`.
@save_model
class ModelWithListOfOptionalTensors(torch.nn.Module):
    def forward(self, index):
        values = torch.tensor([[4.0, 1.0, 1.0, 16.0]])
        return values[torch.tensor(0), index]


# conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1,
# int groups=1) -> Tensor
# A torch.nn.Conv2d constructed with tuple-valued kernel size, stride and
# padding; forward is inherited from torch.nn.Conv2d.
@save_model
class ModelWithArrayOfInt(torch.nn.Conv2d):
    def __init__(self) -> None:
        super().__init__(1, 2, (2, 2), stride=(1, 1), padding=(1, 1))
# add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
# ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None,
# MemoryFormat? memory_format=None) -> Tensor
@save_model
class ModelWithTensors(torch.nn.Module):
    def forward(self, a):
        b = torch.ones_like(a)
        return a + b


# Calls torch.div once with an out= tensor and once with rounding_mode="trunc",
# and returns both results.
@save_model
class ModelWithStringOptional(torch.nn.Module):
    def forward(self, b):
        a = torch.tensor(3, dtype=torch.int64)
        out = torch.empty(size=[1], dtype=torch.float)
        torch.div(b, a, out=out)
        return [torch.div(b, a, rounding_mode="trunc"), out]


# Writes -2 into x[1] in place, then applies ReLU followed by Flatten.
@save_model
class ModelWithMultipleOps(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ops = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Flatten(),
        )

    def forward(self, x):
        x[1] = -2
        return self.ops(x)


# Every model exported by the "setup" command. This is listed explicitly rather
# than read from _MODELS: each invocation of this script handles one command,
# so during "shutdown" no model has been instantiated and _MODELS/_FILENAMES
# are empty.
_TEST_MODELS = [
    ModelWithDTypeDeviceLayoutPinMemory,
    ModelWithTensorOptional,
    ModelWithScalarList,
    ModelWithFloatList,
    ModelWithListOfOptionalTensors,
    ModelWithArrayOfInt,
    ModelWithTensors,
    ModelWithStringOptional,
    ModelWithMultipleOps,
]


def _model_path(model) -> str:
    """Return the .ptl path that save_model writes for ``model``.

    Must stay in sync with the path built inside save_model. The decorated
    factory carries the original class name via ``functools.wraps``.
    """
    return f"./{model.__name__}.ptl"


def main(argv) -> None:
    """Run the ``setup`` or ``shutdown`` command.

    Usage: ``<script> {setup|shutdown} <ops_yaml>``

    setup: export every test model to ``./<ClassName>.ptl``, copy ``ops_yaml``
        to ``ops_yaml + ".bak"``, then append one ``- <op>`` line per collected
        operator. Operators are sorted so the output does not depend on set
        iteration order.
    shutdown: delete the exported .ptl files and move the backup back over
        ``ops_yaml``.

    Raises:
        SystemExit: if fewer than two arguments are given or the command is
            not recognized.
    """
    if len(argv) < 3:
        prog = argv[0] if argv else "<script>"
        sys.exit(f"usage: {prog} {{setup|shutdown}} <ops_yaml>")
    command, ops_yaml = argv[1], argv[2]
    backup = ops_yaml + ".bak"
    if command == "setup":
        for model in _TEST_MODELS:
            model()  # saves the .ptl file and records the model's operators
        shutil.copyfile(ops_yaml, backup)
        with open(ops_yaml, "a") as f:
            for op in sorted(_OPERATORS):
                f.write(f"- {op}\n")
    elif command == "shutdown":
        # Paths are derived from the model list, not from _MODELS: that list
        # holds classes rather than paths and is empty in this invocation.
        for model in _TEST_MODELS:
            path = _model_path(model)
            if os.path.isfile(path):
                os.remove(path)
        shutil.move(backup, ops_yaml)
    else:
        sys.exit(f"unknown command {command!r}; expected 'setup' or 'shutdown'")


if __name__ == "__main__":
    main(sys.argv)