# mypy: ignore-errors
"""Train the decision-tree heuristic for the ``mm`` autoheuristic.

Running this file as a script instantiates :class:`AHTrainDecisionTreeMM`
and calls its ``generate_heuristic()`` method.
"""
import os
import sys

import pandas as pd  # type: ignore[import-untyped]


# Make the parent directory of this file's directory importable so that the
# sibling ``train_decision`` module can be found when run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from train_decision import AHTrainDecisionTree

from torch._inductor.autoheuristic.autoheuristic_utils import mm_operations


class AHTrainDecisionTreeMM(AHTrainDecisionTree):
    """Decision-tree trainer specialised with ``mm``-specific settings."""

    def __init__(self):
        super().__init__()

    def add_new_features(self, results):
        """Add one derived column per ``mm_operations()`` entry.

        Each operation's ``func`` is applied row-wise to ``results`` and the
        outcome is stored in a column named ``op.name``. ``results`` is
        modified in place.

        Returns:
            A ``(results, names)`` tuple where ``names`` lists the added
            columns whose operation is flagged ``is_categorical``.
        """
        categorical_columns = []
        for operation in mm_operations():
            column = operation.name
            results[column] = results.apply(operation.func, axis=1)
            if operation.is_categorical:
                categorical_columns.append(column)
        return results, categorical_columns

    def get_default_config(self, row):
        """Return the default choice; it is ``"extern_mm"`` for every row."""
        return "extern_mm"

    def get_allowed_wrong_prediction_pct(self):
        """Return the allowed wrong-prediction percentage setting (1.0)."""
        return 1.0

    def get_test_and_val_size(self):
        """Return the split sizes as a 2-tuple.

        Presumably ``(test_fraction, validation_fraction)`` going by the
        method name -- confirm against the base class.
        """
        return (0.01, 0.19)

    def get_grid_search_values(self):
        """Return the hyperparameter grid: a single candidate per parameter."""
        return {
            "max_depth": [5],
            "min_samples_leaf": [0.01],
            "criterion": ["entropy"],
        }

    def add_training_data(self, df_train, datasets):
        """Build the training frame from ``datasets``.

        The ``"train_timm"`` and ``"train_hf"`` frames are each added three
        times, because we really want to make sure that the heuristic
        performs well on these datasets. They are appended to
        ``datasets["train"]`` and the combined frame gets a fresh index.

        NOTE(review): the ``df_train`` argument is never read; the base
        training data is always taken from ``datasets["train"]``. Confirm
        that this is intended.
        """
        oversampled = []
        for key in ("train_timm", "train_hf"):
            frame = datasets[key]
            # Select every index label three times in a row, then renumber.
            tripled = frame.loc[frame.index.repeat(3)]
            oversampled.append(tripled.reset_index(drop=True))
        base = datasets["train"]
        return pd.concat([base, *oversampled], ignore_index=True)

    def ranking_always_included_choices(self):
        """Return the choices that are always included in the ranking."""
        return ["extern_mm"]


if __name__ == "__main__":
    trainer = AHTrainDecisionTreeMM()
    trainer.generate_heuristic()