# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os
from typing import Any, Dict, Optional, Tuple

import torch


class EagerModelFactory:
    """
    A factory class for dynamically creating instances of classes implementing EagerModelBase.
    """

    @staticmethod
    def create_model(
        module_name: str, model_class_name: str, **kwargs: Any
    ) -> Tuple[torch.nn.Module, Tuple[Any, ...], Optional[Dict[str, Any]], Any]:
        """
        Create an instance of a model class that implements EagerModelBase and retrieve related data.

        The module is imported as ``executorch.examples.models.<module_name>``,
        or as ``examples.models.<module_name>`` when the current working
        directory path ends with ``executorch``.

        Args:
            module_name (str): The name of the module under ``examples.models``
                containing the model class.
            model_class_name (str): The name of the model class to create an instance of.
            **kwargs: Forwarded unchanged to the model class constructor.

        Returns:
            A 4-tuple of:
              - the eager PyTorch model (``model.get_eager_model()``),
              - the example positional inputs (``model.get_example_inputs()``),
              - the example keyword inputs (``model.get_example_kwarg_inputs()``),
                or ``None`` if the model does not define that method,
              - the dynamic shape information (``model.get_dynamic_shapes()``),
                or ``None`` if the model does not define that method.

        Raises:
            ValueError: If the provided model class is not found in the module.
            ImportError: Propagated from ``importlib.import_module`` when the
                module itself cannot be imported (e.g. ``ModuleNotFoundError``).
        """
        # When the cwd path ends with "executorch" (presumably a source checkout
        # root), ``examples`` is imported as a top-level package; otherwise it is
        # resolved through the ``executorch`` package.
        package_prefix = "executorch." if not os.getcwd().endswith("executorch") else ""
        module = importlib.import_module(
            f"{package_prefix}examples.models.{module_name}"
        )

        if not hasattr(module, model_class_name):
            raise ValueError(
                f"Model class '{model_class_name}' not found in module '{module_name}'."
            )

        model = getattr(module, model_class_name)(**kwargs)

        # Optional hooks: a model may not provide keyword inputs or dynamic
        # shapes, in which case None is reported for that slot.
        example_kwarg_inputs = (
            model.get_example_kwarg_inputs()
            if hasattr(model, "get_example_kwarg_inputs")
            else None
        )
        dynamic_shapes = (
            model.get_dynamic_shapes() if hasattr(model, "get_dynamic_shapes") else None
        )
        return (
            model.get_eager_model(),
            model.get_example_inputs(),
            example_kwarg_inputs,
            dynamic_shapes,
        )