from abc import ABC, abstractmethod

import torch


class BaseModelChunk(ABC, torch.nn.Module):
    """Abstract ``torch.nn.Module`` base class for one chunk of a model.

    The constructor only records its arguments on the instance; it builds no
    layers. Concrete subclasses must implement :meth:`forward`,
    :meth:`load_weights` and :meth:`get_example_inputs`, otherwise they
    cannot be instantiated.

    Note:
        Constructing an instance calls ``torch.set_default_dtype(dtype)``,
        which changes the process-wide default floating-point dtype, not
        just this module's.
    """

    def __init__(
        self,
        config,
        num_blocks,
        chunk_idx,
        dtype=torch.float32,
        include_tail=False,
        return_attn=False,
        jit_trace=False,
    ):
        """Store the chunk settings on the instance.

        Args:
            config: Stored as ``self.config``. Presumably the model
                configuration object -- TODO confirm against subclasses.
            num_blocks: Stored as ``self.num_blocks``. Presumably the number
                of blocks held by this chunk -- TODO confirm.
            chunk_idx: Stored as ``self.chunk_idx``. Presumably the position
                of this chunk within the full model -- TODO confirm.
            dtype: Stored as ``self.dtype`` and installed as the global
                torch default dtype.
            include_tail: Stored as ``self.include_tail``; not interpreted
                here.
            return_attn: Stored as ``self.return_attn``; not interpreted
                here.
            jit_trace: Stored as ``self.jit_trace``; not interpreted here.
        """
        # Initialise the Module machinery explicitly rather than through
        # super(), so the call target does not depend on a subclass's MRO.
        torch.nn.Module.__init__(self)

        # Global side effect: affects every tensor/module created afterwards
        # without an explicit dtype, including those built by subclasses.
        torch.set_default_dtype(dtype)

        self.dtype = dtype
        self.config = config
        self.num_blocks = num_blocks
        self.chunk_idx = chunk_idx
        self.include_tail = include_tail
        self.return_attn = return_attn
        self.jit_trace = jit_trace
        # Starts empty; nothing in this class populates it.
        self.device_list = []

    @abstractmethod
    def forward(self):
        """Run the chunk's forward pass; must be overridden by subclasses."""

    @abstractmethod
    def load_weights(self, state_dict, state_dict_start_idx, verbose):
        """Load this chunk's weights; must be overridden by subclasses.

        Args:
            state_dict: Source of the weights (presumably a torch state
                dict -- TODO confirm).
            state_dict_start_idx: Presumably the index in ``state_dict`` at
                which this chunk's blocks begin -- TODO confirm.
            verbose: Presumably toggles progress output -- TODO confirm.
        """

    @abstractmethod
    def get_example_inputs(self, num_token, cache_size, get_dym_shape):
        """Build example inputs; must be overridden by subclasses.

        Args:
            num_token: Presumably the number of input tokens -- TODO confirm.
            cache_size: Presumably the cache length -- TODO confirm.
            get_dym_shape: Presumably requests dynamic-shape information in
                addition to the inputs -- TODO confirm.
        """