# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os
from typing import Any, Dict, Optional, Tuple

import torch


class EagerModelFactory:
    """
    A factory class for dynamically creating instances of classes implementing
    EagerModelBase.
    """

    @staticmethod
    def create_model(
        module_name: str, model_class_name: str, **kwargs: Any
    ) -> Tuple[torch.nn.Module, Tuple[Any, ...], Optional[Dict[str, Any]], Any]:
        """
        Create an instance of a model class and retrieve its related data.

        The module ``examples.models.<module_name>`` is imported, the class
        named ``model_class_name`` is looked up on it and instantiated with
        ``**kwargs``, and the eager model plus its example data are collected
        from the resulting wrapper object.

        Args:
            module_name (str): The name of the module (under
                ``examples.models``) containing the model class.
            model_class_name (str): The name of the model class to create an
                instance of.
            **kwargs: Forwarded unchanged to the model class constructor.

        Returns:
            A 4-tuple ``(eager_model, example_inputs, example_kwarg_inputs,
            dynamic_shapes)`` where:

            * ``eager_model`` is the result of ``get_eager_model()``;
            * ``example_inputs`` is the result of ``get_example_inputs()``;
            * ``example_kwarg_inputs`` is the result of
              ``get_example_kwarg_inputs()``, or ``None`` when the wrapper
              does not define that method;
            * ``dynamic_shapes`` is the result of ``get_dynamic_shapes()``,
              or ``None`` when the wrapper does not define that method.

        Raises:
            ValueError: If the provided model class is not found in the module.
            ImportError: Propagated from ``importlib.import_module`` when the
                module itself cannot be imported.
        """
        # When the current working directory ends with "executorch" the module
        # is imported without a package prefix; otherwise it is imported
        # through the "executorch." package.
        # NOTE(review): presumably this distinguishes "running from the repo
        # root" from "running against an installed package" -- confirm.
        package_prefix = (
            "executorch." if not os.getcwd().endswith("executorch") else ""
        )
        module = importlib.import_module(
            f"{package_prefix}examples.models.{module_name}"
        )

        # Guard clause: fail early when the requested class is absent.
        if not hasattr(module, model_class_name):
            raise ValueError(
                f"Model class '{model_class_name}' not found in module '{module_name}'."
            )

        model_class = getattr(module, model_class_name)
        model = model_class(**kwargs)

        # Both of these hooks are optional on the wrapper; default to None.
        example_kwarg_inputs = None
        dynamic_shapes = None
        if hasattr(model, "get_example_kwarg_inputs"):
            example_kwarg_inputs = model.get_example_kwarg_inputs()
        if hasattr(model, "get_dynamic_shapes"):
            dynamic_shapes = model.get_dynamic_shapes()

        return (
            model.get_eager_model(),
            model.get_example_inputs(),
            example_kwarg_inputs,
            dynamic_shapes,
        )