xref: /aosp_15_r20/external/executorch/examples/models/checkpoint.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerfrom pathlib import Path
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Dict, Optional
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Worker
def get_default_model_resource_dir(model_file_path: str) -> Path:
    """
    Get the default path to the resource files of a model (files such as the
    checkpoint and param files).

    The directory is resolved in one of two ways:
    1. Via pkg_resources, which only works when running with buck2.
    2. Otherwise, as the `params` directory that sits next to
       `model_file_path`.

    Expected to be called from within a `model.py` file located in a
    `executorch/examples/models/<model_name>` directory.

    Args:
        model_file_path: The file path to the eager model definition.
            For example, `executorch/examples/models/llama/model.py`,
            where `executorch/examples/models/llama` contains all
            the llama2-related files.

    Returns:
        The path to the resource directory containing checkpoint, params, etc.
    """

    try:
        import pkg_resources

        # 1st way: If we can import this path, we are running with buck2 and
        # all resources can be accessed with pkg_resources. The import is only
        # a probe; `params` itself is not used.
        # pyre-ignore
        from executorch.examples.models.llama import params  # noqa

        # Derive the model name from the directory that holds
        # `model_file_path`, assuming a layout such as
        # examples/models/<model_name>/model.py.
        model_name = Path(model_file_path).parent.name
        resource_dir = Path(
            pkg_resources.resource_filename(
                f"executorch.examples.models.{model_name}", "params"
            )
        )
    except Exception:
        # 2nd way: any import or lookup failure falls back to the `params`
        # directory next to the model file. `Exception` (rather than a bare
        # `except:`) keeps the fallback best-effort while still letting
        # KeyboardInterrupt and SystemExit propagate.
        resource_dir = Path(model_file_path).absolute().parent / "params"

    return resource_dir
53*523fa7a6SAndroid Build Coastguard Worker
54*523fa7a6SAndroid Build Coastguard Worker
def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[Any]:
    """
    Get the dtype of the checkpoint, returning None if the checkpoint is empty.

    The dtype of the first entry (in the mapping's iteration order) is taken
    as the checkpoint's dtype. If any other entry has a different dtype, a
    diagnostic listing the mismatching entries is printed to stdout; the
    first entry's dtype is still returned.

    Args:
        checkpoint: Mapping from key to a value exposing a `.dtype`
            attribute (presumably tensors -- confirm against callers).

    Returns:
        The `.dtype` attribute of the first entry, unmodified, or None when
        `checkpoint` is empty.
    """
    if len(checkpoint) == 0:
        return None

    first_key = next(iter(checkpoint))
    dtype = checkpoint[first_key].dtype

    # Collect (key, dtype) pairs for every entry that disagrees with the first.
    mismatched_dtypes = [
        (key, value.dtype)
        for key, value in checkpoint.items()
        if value.dtype != dtype
    ]
    if mismatched_dtypes:
        print(
            f"Mixed dtype model. Dtype of {first_key}: {dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
        )
    return dtype
74