# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from pathlib import Path
from typing import Any, Dict, Optional


def get_default_model_resource_dir(model_file_path: str) -> Path:
    """
    Get the default path to resource files (which contain files such as the
    checkpoint and param files), either:
    1. Uses the path from pkg_resources, only works with buck2
    2. Uses the `params` directory located next to `model_file_path`, e.g.
       examples/models/llama/params

    Expected to be called from within a `model.py` file located in a
    `executorch/examples/models/<model_name>` directory.

    Args:
        model_file_path: The file path to the eager model definition.
            For example, `executorch/examples/models/llama/model.py`,
            where `executorch/examples/models/llama` contains all
            the llama2-related files.

    Returns:
        The path to the resource directory containing checkpoint, params, etc.
    """

    try:
        import pkg_resources

        # 1st way: If we can import this path, we are running with buck2 and all resources can be accessed with pkg_resources.
        # pyre-ignore
        from executorch.examples.models.llama import params  # noqa

        # Get the model name from the parent directory of `model_file_path`,
        # assuming that this module is called from a path such as
        # examples/models/<model_name>/model.py.
        model_name = Path(model_file_path).parent.name
        resource_dir = Path(
            pkg_resources.resource_filename(
                f"executorch.examples.models.{model_name}", "params"
            )
        )
    except Exception:
        # 2nd way: any import/resource failure means we are not in the buck2
        # layout, so fall back to the `params` directory beside the model file.
        # `Exception` (rather than a bare `except`) keeps this best-effort
        # fallback while letting KeyboardInterrupt/SystemExit propagate.
        resource_dir = Path(model_file_path).absolute().parent / "params"

    return resource_dir


def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
    """
    Get the dtype of the checkpoint, returning None if the checkpoint is empty.

    The dtype reported is the `.dtype` attribute of the first value in
    `checkpoint` (in iteration order). If any other value has a different
    `.dtype`, a "Mixed dtype model" message listing the mismatches is printed,
    and the first value's dtype is still returned.

    NOTE(review): the return annotation says `Optional[str]`, but the value
    returned is whatever object the entries expose as `.dtype`; confirm
    against callers before tightening the annotation.

    Args:
        checkpoint: Mapping from parameter name to an object with a `.dtype`
            attribute.

    Returns:
        The `.dtype` of the first entry, or None for an empty checkpoint.
    """
    if not checkpoint:
        return None

    first_key = next(iter(checkpoint))
    first = checkpoint[first_key]
    dtype = first.dtype
    mismatched_dtypes = [
        (key, value.dtype)
        for key, value in checkpoint.items()
        if value.dtype != dtype
    ]
    if mismatched_dtypes:
        print(
            f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
        )
    return dtype