xref: /aosp_15_r20/external/executorch/examples/models/model_base.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerfrom abc import ABC, abstractmethod
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport torch
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker
class EagerModelBase(ABC):
    """
    Common interface that eager-mode model classes implement.

    Subclasses derive from this ABC so every model exposes the same structure.
    All three members below are declared abstract, so ``abc`` refuses to
    instantiate a subclass until it overrides ``__init__``,
    ``get_eager_model`` and ``get_example_inputs``.
    """

    @abstractmethod
    def __init__(self):
        """
        Initialize the model wrapper.

        Because this initializer is abstract, a concrete subclass must define
        its own ``__init__``. The base version performs no work and can safely
        be reached from a subclass via ``super().__init__()``.
        """

    @abstractmethod
    def get_eager_model(self) -> torch.nn.Module:
        """
        Return the model as an eager PyTorch module.

        Returns:
            torch.nn.Module: A model instance suitable for eager execution.

        Raises:
            NotImplementedError: Whenever this base version is invoked (for
                example through ``super()``); subclasses supply the model.
        """
        method_name = "get_eager_model"
        raise NotImplementedError(method_name)

    @abstractmethod
    def get_example_inputs(self):
        """
        Return sample inputs for the model.

        Returns:
            Any: Inputs that can be used to test or trace the model.

        Raises:
            NotImplementedError: Whenever this base version is invoked (for
                example through ``super()``); subclasses supply the inputs.
        """
        method_name = "get_example_inputs"
        raise NotImplementedError(method_name)
48