xref: /aosp_15_r20/external/executorch/examples/models/model_factory.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
import importlib
import os
from typing import Any, Dict, Optional, Tuple

import torch
12
13
class EagerModelFactory:
    """
    A factory class for dynamically creating instances of classes implementing EagerModelBase.
    """

    @staticmethod
    def create_model(
        module_name, model_class_name, **kwargs
    ) -> Tuple[torch.nn.Module, Tuple[Any, ...], Optional[Dict[str, Any]], Any]:
        """
        Create an instance of a model class that implements EagerModelBase and retrieve related data.

        The model module is imported as ``examples.models.<module_name>`` when the
        current working directory path ends with "executorch" (presumably the
        repository root), and as ``executorch.examples.models.<module_name>``
        otherwise.

        Args:
            module_name (str): The name of the module under ``examples.models``
                containing the model class.
            model_class_name (str): The name of the model class to create an instance of.
            **kwargs: Keyword arguments forwarded unchanged to the model class constructor.

        Returns:
            Tuple[nn.Module, Tuple[Any, ...], Optional[Dict[str, Any]], Any]: A 4-tuple of
              - the eager model, from ``get_eager_model()``;
              - the example positional inputs, from ``get_example_inputs()``;
              - the example keyword inputs, from ``get_example_kwarg_inputs()``,
                or ``None`` if the model does not define that method;
              - the dynamic shape information, from ``get_dynamic_shapes()``,
                or ``None`` if the model does not define that method.

        Raises:
            ValueError: If the provided model class is not found in the module.
            Any exception raised by ``importlib.import_module`` (e.g.
            ``ModuleNotFoundError``) propagates unchanged.
        """
        # NOTE(review): this is a plain string-suffix check on the full cwd path,
        # so any directory whose path ends in "executorch" (e.g. "my_executorch")
        # also selects the un-prefixed import path.
        package_prefix = "executorch." if not os.getcwd().endswith("executorch") else ""
        module = importlib.import_module(
            f"{package_prefix}examples.models.{module_name}"
        )

        if not hasattr(module, model_class_name):
            raise ValueError(
                f"Model class '{model_class_name}' not found in module '{module_name}'."
            )

        model_class = getattr(module, model_class_name)
        model = model_class(**kwargs)

        # Both hooks are optional on the model; a missing hook yields None.
        example_kwarg_inputs = (
            model.get_example_kwarg_inputs()
            if hasattr(model, "get_example_kwarg_inputs")
            else None
        )
        dynamic_shapes = (
            model.get_dynamic_shapes() if hasattr(model, "get_dynamic_shapes") else None
        )

        return (
            model.get_eager_model(),
            model.get_example_inputs(),
            example_kwarg_inputs,
            dynamic_shapes,
        )
61