# xref: /aosp_15_r20/external/pytorch/torchgen/_autoheuristic/train_decision.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: ignore-errors

import itertools
import json
import logging
import math
import warnings


warnings.filterwarnings(
    "ignore",
    message="The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
)

from dataclasses import dataclass

import numpy as np
import pandas as pd  # type: ignore[import-untyped]
from ah_tree import DecisionTree
from scipy.stats import gmean
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from train import AHTrain


log = logging.getLogger(__name__)
DEBUG = True
if DEBUG:
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )
    ch.setFormatter(formatter)
    log.addHandler(ch)
    # Without this, DEBUG-level messages are filtered at the logger and never
    # reach the handler.
    log.setLevel(logging.DEBUG)


class AHTrainDecisionTree(AHTrain):
    def __init__(self):
        super().__init__()

    def debug_time(self, row, top_k_choices):
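        """
        Prints all measured choices for a sample, sorted by execution time, and
        marks the choices that are among the top k choices.
        """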
        choices_feedback = json.loads(row["choice2time"])
        timings = sorted(choices_feedback.items(), key=lambda x: x[1])
        for choice, time in timings:
            result = f"{choice} {time}"
            if choice in top_k_choices:
                result += " TOPK"
            print(result)

    def is_unsafe_leaf(self, row, predicted_config, choice2time):
        """
        Can be overridden by subclasses to define their own logic for deciding when a leaf is unsafe. Receives a sample
        that landed in the leaf, the choice predicted by the tree, and a dictionary that maps each choice to its
        execution time. One can, for example, decide to mark a leaf as unsafe if the predicted choice is 2x slower
        than the fastest choice.
        If a leaf is unsafe, the learned heuristic will always return 'unsure' if an input lands in that leaf.
        """

        return False

    def get_unsafe_leaves(self, model, df, feature_columns):
        """
        Given a trained decision tree and a dataframe containing the training data, returns a list of unsafe leaves.
        """
        X = df[feature_columns]
        y = df["winner"]
        leaf_ids = model.apply(X)
        unique_leaves = np.unique(leaf_ids)

        unsafe_leaves = []
        # Iterate over each leaf
        for leaf in unique_leaves:
            leaf_mask = leaf_ids == leaf
            # Get samples that land in this leaf
            leaf_X = X[leaf_mask]

            predicted_config = model.predict(leaf_X.iloc[[0]])[0]

            # For each sample, check if we should mark the leaf as unsafe
            for idx, row in leaf_X.iterrows():
                choice2time = json.loads(df.loc[idx, "choice2time"])
                if self.is_unsafe_leaf(row, predicted_config, choice2time):
                    unsafe_leaves.append(leaf)
                    break
        return unsafe_leaves

    def get_allowed_wrong_prediction_pct(self):
        """
        This is used to determine a threshold for when a learned heuristic returns 'unsure'.
        If this function returns 0.01, we will set the probability required for the decision tree to return a decision
        such that at most 1% of the predictions will be wrong on the validation set.
        """
        return 0.01

    def get_grid_search_values(self):
        """
        Standard values for grid search. Can be overridden.
        """
        return {
            "max_depth": [5, 6, 7],
            "min_samples_leaf": [1, 5, 10, 0.01, 0.05, 0.02],
            "criterion": ["gini", "entropy"],
        }

    def predict(self, model, df, feature_columns):
        """
        Returns the predictions, probabilities, and leaf ids for a given dataframe.
        """
        predictions = model.predict(df[feature_columns])
        proba = model.predict_proba(df[feature_columns])
        leaf_ids = model.apply(df[feature_columns])
        return predictions, proba, leaf_ids

    def ranking_num_choices(self):
        # if the heuristic is used for ranking, this function returns the number
        # of choices that the heuristic will return
        if self.args.ranking is None:
            return 5
        return self.args.ranking

    def train_and_evaluate_models(
        self,
        datasets,
        max_depths,
        min_samples_leafs,
        criterion_list,
        feature_columns,
        ranking=False,
    ):
        """
        Does a grid search over max_depths, min_samples_leafs, and criterion_list and returns the best model.
        """

        results = []
        best_model = None
        best_model_safe_proba = 0
        best_model_num_correct = 0
        best_model_num_wrong = 0
        best_model_unsafe_leaves = []
        columns = ["set", "crit", "max_depth", "min_samples_leaf"]
        metrics_columns = []
        for max_depth, min_samples_leaf, criterion in itertools.product(
            max_depths, min_samples_leafs, criterion_list
        ):
            print(
                f"max_depth={max_depth} min_samples_leaf={min_samples_leaf} criterion={criterion}"
            )
            model = DecisionTreeClassifier(
                max_depth=max_depth,
                min_samples_leaf=min_samples_leaf,
                criterion=criterion,
                random_state=42,
            )
            df_train = datasets["train"]
            df_val = datasets["val"]
            if ranking:
                model.fit(
                    df_train[feature_columns],
                    df_train["winner"],
                    sample_weight=df_train["relative_performance"],
                )
            else:
                model.fit(df_train[feature_columns], df_train["winner"])

            model = DecisionTree(model, feature_columns)

            if ranking:
                model.prune(df_train, "winner", k=self.ranking_num_choices())

            unsafe_leaves = self.get_unsafe_leaves(model, df_train, feature_columns)
            predictions, proba, leaf_ids = self.predict(model, df_val, feature_columns)

            wrong_pct = self.get_allowed_wrong_prediction_pct()
            evaluator = DecisionEvaluator(
                self,
                model,
                predictions,
                df_val,
                proba,
                wrong_pct=wrong_pct,
                unsafe_leaves=unsafe_leaves,
                leaf_ids=leaf_ids,
                k=self.ranking_num_choices(),
                ranking=ranking,
            )
            safe_proba = evaluator.get_safe_proba()
            print(f"safe_proba={safe_proba}")

            def eval(name, df):
                if ranking:
                    # when ranking is enabled, we duplicate each input for each choice that
                    # is almost as good as the best choice
                    # we do not want to evaluate the same input multiple times, so we remove duplicates here
                    df = df[df["winner"] == df["actual_winner"]]
                predictions, proba, leaf_ids = self.predict(model, df, feature_columns)
                evaluator = DecisionEvaluator(
                    self,
                    model,
                    predictions,
                    df,
                    proba,
                    wrong_pct=wrong_pct,
                    threshold=safe_proba,
                    unsafe_leaves=unsafe_leaves,
                    leaf_ids=leaf_ids,
                    k=self.ranking_num_choices(),
                    ranking=ranking,
                )
                return evaluator.get_results()

            for dataset_name, dataset in datasets.items():
                eval_result: EvalResults = eval(dataset_name, dataset)
                eval_result_metrics = eval_result.to_map()
                if dataset_name == "val":
                    num_correct = eval_result.accuracy.num_correct
                    num_wrong = eval_result.accuracy.num_wrong
                    num_total = eval_result.accuracy.total
                    if num_wrong <= num_total * wrong_pct:
                        if num_correct > best_model_num_correct:
                            print(
                                f"new best model with {num_correct} correct and {num_wrong} wrong"
                            )
                            best_model = model
                            best_model_num_correct = num_correct
                            best_model_num_wrong = num_wrong
                            best_model_safe_proba = safe_proba
                            best_model_unsafe_leaves = unsafe_leaves

                result = (dataset_name, criterion, max_depth, min_samples_leaf)
                result += tuple(eval_result_metrics.values())
                results.append(result)
                if len(metrics_columns) == 0:
                    metrics_columns = list(eval_result_metrics.keys())
                    columns += metrics_columns

        return (
            pd.DataFrame(results, columns=columns),
            best_model,
            best_model_safe_proba,
            best_model_unsafe_leaves,
        )

    def get_test_and_val_size(self):
        """
        Returns the size of the test and validation sets.
        """
        return (0.15, 0.15)

    def prepare_datasets(self, df, other_datasets, cat_feature2cats, ranking=False):
        """
        Splits the dataframe into train, val, and test sets.
        Also adds any other datasets specified by the user to the datasets dictionary.
        """
        test_size, val_size = self.get_test_and_val_size()
        # Split into train+val and test
        df_train_val, df_test = train_test_split(
            df, test_size=test_size, random_state=42
        )

        # Split train+val inputs into train and val
        train_val_size = 1 - test_size
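        # With the default sizes of (0.15, 0.15) this gives roughly a 70/15/15
        # train/val/test split: the second split carves out
        # val_size / (1 - test_size) = 0.15 / 0.85 ≈ 17.6% of the train+val data.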
        df_train, df_val = train_test_split(
            df_train_val, test_size=val_size / train_val_size, random_state=42
        )
        datasets = {"train": df_train, "val": df_val, "test": df_test}
        self.add_real_datasets(datasets, other_datasets, cat_feature2cats, ranking)
        return datasets

    def export_to_dot(self, best_model, df, feature_columns):
        """
        Export a learned decision tree to a dot file.
        """
        dot_str = best_model.to_dot()
        with open("best_model.dot", "w") as f:
            f.write(dot_str)

    def get_feature_columns(self, df):
        """
        The dataframe contains columns that are not features, such as 'winner' and 'speedup', which are only used for
        debugging purposes. This function returns the columns that are actually features.
        """
        exclude_columns = [
            "speedup",
            "winner",
            "target",
            "avail_choices",
            "choice2time",
            "index",
            "actual_winner",
            "relative_performance",
        ]
        feature_columns = [col for col in df.columns if col not in exclude_columns]
        return feature_columns

    def add_training_data(self, df_train, datasets):
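        """
        Hook that subclasses can override to augment or replace the training data;
        receives the current training split and the full datasets dictionary.
        The default implementation returns the training split unchanged.
        """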
        return datasets["train"]

    def main(
        self,
        log_path,
        other_datasets,
        nrows,
        heuristic_name,
        save_dot=False,
        ranking=False,
    ):
        """
        Main function that trains a decision tree and generates a heuristic.
        """
        # TODO: Enable apply_filters
        (df, choices, cat_feature2cats, dummy_col_2_col_val, metadata) = self.get_df(
            log_path, nrows=nrows, apply_filters=False, add_near_best=ranking
        )
        self.dummy_col_2_col_val = dummy_col_2_col_val
        datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats, ranking)
        df_train = self.add_training_data(datasets["train"], datasets)
        datasets["train"] = df_train
        print(datasets["train"]["winner"].value_counts().to_string())

        feature_columns = self.get_feature_columns(df)
        grid_search_values = self.get_grid_search_values()
        max_depths = grid_search_values["max_depth"]
        min_samples_leafs = grid_search_values["min_samples_leaf"]
        criterion_list = grid_search_values["criterion"]
        (
            results_df,
            best_model,
            best_model_safe_proba,
            unsafe_leaves,
        ) = self.train_and_evaluate_models(
            datasets,
            max_depths,
            min_samples_leafs,
            criterion_list,
            feature_columns,
            ranking=ranking,
        )

        if ranking:
            columns_to_keep = [
                "set",
                "crit",
                "max_depth",
                "min_samples_leaf",
                "total",
                "top_k_correct",
                "top_k_wrong",
                "top_k_unsure",
                "wrong_max_speedup_k",
                "wrong_gmean_speedup_k",
            ]
            results_df = results_df[columns_to_keep]
        # prints results for all models and datasets
        print(results_df.to_string())

        sort_metric = "top_k_correct" if ranking else "correct"
        # prints results grouped by dataset
        for set_name in results_df["set"].unique():
            dataset_results = results_df[results_df["set"] == set_name]
            dataset_results = dataset_results.sort_values(by=sort_metric)
            print(dataset_results.to_string() + "\n")

        if best_model is not None:
            if save_dot:
                self.export_to_dot(best_model, df, feature_columns)
            self.codegen(
                best_model,
                metadata,
                heuristic_name,
                best_model_safe_proba,
                dummy_col_2_col_val,
                unsafe_leaves,
            )
        else:
            print(
                "All learned models have too many wrong predictions, so no heuristic was generated"
            )

    def get_df(
        self,
        log_path,
        cat_feature2cats=None,
        nrows=None,
        apply_filters=False,
        add_near_best=False,
    ):
        """
        Parses the log file and processes the data into a dataframe that can be used for training.
        """
        (df, metadata, features, categorical_features, choices) = self.parse_log(
            log_path, nrows
        )

        def calculate_stats(group):
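            # For a group of measurements of one (input, choice) pair, compute the
            # number of measurements, the relative standard deviation of the
            # feedback in percent, and the median execution time.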
            count = len(group)
            has_inf = np.isinf(group["feedback"]).any()
            if has_inf:
                relative_std = np.inf
                median = np.inf
            else:
                mean = group["feedback"].mean()
                std = group["feedback"].std()
                relative_std = (std / mean) * 100 if mean != 0 else np.inf
                median = group["feedback"].median()
            if relative_std > 5:
                times = group["feedback"].tolist()
                times_str = ", ".join([f"{t:.3f}" for t in sorted(times)])
                log.debug("High relative std: %f. times=%s", relative_std, times_str)
            return pd.Series(
                {
                    "count": count,
                    "relative_std": relative_std,
                    "median_execution_time": median,
                }
            )

        feature_columns = features
        stats = (
            df.groupby(feature_columns + ["choice"], as_index=False)
            .apply(calculate_stats, include_groups=False)
            .reset_index()
        )

        # TODO: We have to be careful with removing certain choices, because if we e.g. remove the winner, the
        # heuristic will end up learning wrong things. But, execution times with high variance are also bad
        if apply_filters:
            # Filter out inputs with less than 3 measurements or high relative std
            valid_stats = stats[(stats["count"] >= 3) & (stats["relative_std"] <= 5)]
            # Group by input features and count how many valid choices we have for each input
            valid_inputs = valid_stats.groupby(feature_columns).filter(
                lambda x: len(x) >= 2
            )
        else:
            valid_inputs = stats

        # Compute the winner and speedup for each valid input
        def get_winner_and_speedup(group):
            assert len(group) >= 2, "Need at least 2 choices"

            sorted_group = group.sort_values("median_execution_time")
            winner = sorted_group.iloc[0]["choice"]
            winning_time = sorted_group.iloc[0]["median_execution_time"]
            second_best_time = sorted_group.iloc[1]["median_execution_time"]
            speedup = second_best_time / winning_time
            unique_choices = group["choice"].unique()

            choice2time = {}
            for row in group.itertuples():
                choice2time[row.choice] = row.median_execution_time

            assert len(unique_choices) == len(
                group
            ), f"len(unique_choices) != len(group): {len(unique_choices)} != {len(group)}"

            return pd.Series(
                {
                    "winner": winner,
                    "speedup": speedup,
                    "avail_choices": unique_choices,
                    "choice2time": json.dumps(choice2time),
                }
            )

        results = (
            valid_inputs.groupby(feature_columns, as_index=False)
            .filter(lambda x: len(x) >= 2)
            .groupby(feature_columns, as_index=False)
            .apply(get_winner_and_speedup, include_groups=False)
            .reset_index()
        )

        def add_near_best_configs(df):
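            # For ranking: duplicate each input once for every choice whose
            # execution time is within ~2% of the best choice's time
            # (relative_performance >= 0.98); the original winner is preserved
            # in 'actual_winner'.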
            new_rows = []

            for index, row in df.iterrows():
                dictionary = json.loads(row["choice2time"])
                min_value = min(dictionary.values())

                for key, value in dictionary.items():
                    new_row = row.copy()
                    relative_performance = min_value / value
                    new_row["relative_performance"] = relative_performance
                    if relative_performance is None or relative_performance is np.inf:
                        breakpoint()
                    new_row["actual_winner"] = row["winner"]
                    new_row["winner"] = key
                    if relative_performance >= 0.98:
                        new_rows.append(new_row)

            return pd.DataFrame(new_rows).reset_index(drop=True)

        if add_near_best:
            results = add_near_best_configs(results)
        (results, added_categorical_features) = self.add_new_features(results)
        categorical_features += added_categorical_features

        (
            results,
            cat_feature2cats,
            dummy_col_2_col_val,
        ) = self.handle_categorical_features(
            cat_feature2cats, categorical_features, results
        )
        return (results, choices, cat_feature2cats, dummy_col_2_col_val, metadata)

    def ranking_always_included_choices(self):
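        """
        Choices that should always be included in the top k choices returned by a
        ranking heuristic, regardless of their predicted probabilities. Can be
        overridden by subclasses; the default adds none.
        """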
        return []

    def gen_classes(self, classes, num_spaces):
        """
        If classes=['choice1', 'choice2', 'choice3'], then this function returns
        the following string:
        self.choices.append('choice1')
        self.choices.append('choice2')
        self.choices.append('choice3')
        Used in the generated heuristic to map the index of a choice to its name.
        """
        indent = " " * num_spaces
        return "\n".join([f"{indent}self.choices.append('{c}')" for c in classes])

    def get_default_config(self, row):
        """
        Returns the default config for a given sample. The default config could, for example, be the config
        chosen by a current handwritten heuristic. This can, for example, be used in is_unsafe_leaf to
        compare the predicted config with the default config.
        """
        return None

    def gen_predict_fn_def(self):
        """
        Generates the definition of the predict function.
        """
        return "def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:"

    def codegen_boilerplate(
        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes
    ):
        """
        Generates the boilerplate code for the generated heuristic. This includes things like imports, class definition,
        etc.
        """

        boiler_plate = f"""# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/{opt_name}/
from typing import List, Optional, Tuple

from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
    AHMetadata,
    Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
    LearnedHeuristicDecision,
)


class {heuristic_name}(LearnedHeuristicDecision):

    def __init__(self) -> None:
        self.choices: List[Choice] = []
        self.fill_choices()

{self.gen_precondition(opt_name, shared_memory, device_capa)}

    def get_confidence_threshold(self) -> float:
        return {threshold}

    def get_choice(self, idx: int) -> Optional[str]:
        if idx < len(self.choices):
            return self.choices[idx]
        return None

    def fill_choices(self) -> None:
{self.gen_classes(classes, num_spaces=8)}

    def get_name(self) -> str:
        return '{opt_name}'"""
        return boiler_plate

    def add_real_datasets(
        self, datasets, other_datasets, cat_feature2cats, ranking=False
    ):
        """
        Adds datasets specified by the user to the datasets dictionary.
        """
        if other_datasets:
            for name, path in other_datasets:
                (df_other, choices, _, _, _) = self.get_df(
                    path,
                    cat_feature2cats=cat_feature2cats,
                    apply_filters=False,
                    add_near_best=ranking,
                )
                datasets[name] = df_other

    def codegen(
        self,
        tree,
        metadata,
        heuristic_name,
        threshold,
        dummy_col_2_col_val,
        unsafe_leaves,
    ):
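        """
        Generates the source of the learned heuristic: the boilerplate (imports,
        class definition, etc.), the predict function definition, and the code for
        the decision tree itself, and writes everything to a file.
        """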
        lines = []
        device_capa = metadata["device_capa"]
        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
        opt_name = metadata["name"]
        lines.append(
            self.codegen_boilerplate(
                heuristic_name,
                opt_name,
                threshold,
                metadata["shared_memory"],
                device_capa_str,
                tree.classes_,
            )
        )
        fn_def = f"\n    {self.gen_predict_fn_def()}"
        lines.append(fn_def)
        tree.codegen(dummy_col_2_col_val, lines, unsafe_leaves)
        self.write_heuristic_to_file(lines, heuristic_name)


@dataclass
class AccuracyMetrics:
    # Number of correct predictions
    num_correct: int
    # Number of wrong predictions
    num_wrong: int
    # Number of predictions where model is unsure
    num_unsure: int
    # Total number of predictions
    total: int

    def to_map(self):
        return {
            "correct": self.num_correct,
            "wrong": self.num_wrong,
            "unsure": self.num_unsure,
            "total": self.total,
        }


@dataclass
class WrongSpeedupMetrics:
    # If the model predicted the wrong choice, this is the maximum speedup of the best choice over the predicted choice
    max_speedup: float
    # For all wrong predictions, this is the geometric mean of the speedups of the best choices over the predicted choices
    gmean_speedup: float

    def to_map(self):
        return {
            "wrong_max_speedup": self.max_speedup,
            "wrong_gmean_speedup": self.gmean_speedup,
        }


@dataclass
class RankingMetrics:
    # Number of predictions where best choice is in top k choices
    num_correct: int
    # Number of predictions where best choice is not in top k choices
    num_wrong: int
    # Maximum speedup of the best choice over the best choice within the top k (this tells us how much faster the
    # overall best choice, which is not in the top k, is than the best choice within the top k)
    max_speedup: float
    # Geometric mean of speedups of best choice over best choice in top k
    gmean_speedup: float
    # Number of predictions where model is unsure
    unsure: int

    def to_map(self):
        return {
            "top_k_correct": self.num_correct,
            "top_k_wrong": self.num_wrong,
            "wrong_max_speedup_k": self.max_speedup,
            "wrong_gmean_speedup_k": self.gmean_speedup,
            "top_k_unsure": self.unsure,
        }


@dataclass
class DefaultComparisonMetrics:
    # Maximum speedup of predicted choice over default choice
    max_speedup: float
    # Geometric mean of speedups of predicted choices over default choices
    gmean_speedup: float
    # Maximum speedup of default choice over predicted choice
    max_slowdown: float
    # Number of predictions where the predicted choice is not the default choice
    non_default_predictions: int
    # Number of predictions where the default choice is better than the predicted choice
    default_better: int

    def to_map(self):
        return {
            "max_speedup_over_default": self.max_speedup,
            "gmean_speedup_over_default": self.gmean_speedup,
            "max_speedup_default_over_heuristic": self.max_slowdown,
            "non_default_predictions": self.non_default_predictions,
            "default_better": self.default_better,
        }


@dataclass
class EvalResults:
    accuracy: AccuracyMetrics
    speedup: WrongSpeedupMetrics
    ranking: RankingMetrics
    default_comparison: DefaultComparisonMetrics

    def to_map(self):
        return {
            **self.accuracy.to_map(),
            **self.speedup.to_map(),
            **self.ranking.to_map(),
            **self.default_comparison.to_map(),
        }


class DecisionEvaluator:
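    """
    Evaluates the predictions of a learned decision tree on a dataframe and
    aggregates accuracy, wrong-prediction speedup, ranking, and default-comparison
    metrics into an EvalResults object. It is also used to compute the probability
    threshold below which the learned heuristic answers 'unsure' (get_safe_proba).
    """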
    def __init__(
        self,
        train,
        model,
        predictions,
        df,
        probas,
        wrong_pct=0.01,
        threshold=0.0,
        k=10,
        unsafe_leaves=None,
        leaf_ids=None,
        ranking=False,
    ) -> None:
        self.train = train
        self.model = model
        self.predictions = predictions
        self.df = df
        self.probas = probas
        self.wrong_pct = wrong_pct
        self.threshold = threshold
        self.k = k
        self.unsafe_leaves = unsafe_leaves
        self.leaf_ids = leaf_ids
        self.ranking = ranking

        self.num_correct = 0
        self.num_wrong = 0
        self.num_unsure = 0
        self.wrong_probas = []
        self.speedups_wrong = []
        self.num_correct_top_k = 0
        self.num_wrong_top_k = 0
        self.wrong_speedups_top_k = []
        self.top_k_unsure = 0
        self.num_non_default_predictions = 0
        self.speedups_over_default = []
        self.num_default_better = 0

    def compute_speedup_over_default(self, default_config, pred, i, predicted_time):
        if default_config is not None:
            if pred != default_config:
                self.num_non_default_predictions += 1
            default_time = self.get_time(self.df.iloc[i], default_config)
            # TODO: We should keep track of how often this happens
            if default_time is not None and not math.isinf(default_time):
                speedup_over_default = default_time / predicted_time
                if speedup_over_default < 1:
                    self.num_default_better += 1
                self.speedups_over_default.append(speedup_over_default)
            else:
                log.debug(
                    "cannot compute speedup over default because default_time=%s",
                    default_time,
                )

    def get_time(self, row, choice):
        choices_feedback = json.loads(row["choice2time"])
        return choices_feedback.get(choice, None)

    def top_k_classes(self, model, probas, k, avail_choices):
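        """
        Returns up to k available choices, sorted by predicted probability in
        descending order. Choices with zero probability or that are not available
        for this input are dropped; always-included choices are appended and
        duplicates are removed.
        """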
        # Get classes and their corresponding probabilities
        classes = model.classes_
        class_proba_pairs = list(zip(classes, probas))

        # Sort by probability (descending) and filter out zero probabilities
        sorted_classes = [
            c
            for c, p in sorted(zip(classes, probas), key=lambda x: x[1], reverse=True)
            if p > 0 and c in avail_choices
        ]

        # Return top k choices
        top_k_choices = sorted_classes[:k]
        top_k_choices += self.train.ranking_always_included_choices()
        top_k_choices = list(dict.fromkeys(top_k_choices))
        return top_k_choices

    def eval_prediction(
        self, avail_choices, leaf_id, pred, true, prob, threshold, default_config, i
    ):
        predicted_time = self.get_time(self.df.iloc[i], pred)
        max_prob = max(prob)
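        # A prediction is counted as 'unsure' if the sample falls into a leaf that
        # was marked unsafe, the predicted choice is not available for this input,
        # or the prediction confidence does not exceed the threshold.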
        if (
            leaf_id in self.unsafe_leaves
            or pred not in avail_choices
            or (max_prob != 1.0 and max_prob <= threshold)
        ):
            self.num_unsure += 1
            self.speedups_over_default.append(1.0)
        elif pred == true:
            self.compute_speedup_over_default(default_config, pred, i, predicted_time)
            self.num_correct += 1
        else:
            self.compute_speedup_over_default(default_config, pred, i, predicted_time)
            self.num_wrong += 1
            self.wrong_probas.append(max_prob)
            best_time = self.get_time(self.df.iloc[i], true)
            wrong_speedup = predicted_time / best_time
            self.speedups_wrong.append(wrong_speedup)

    def eval_ranking_prediction(self, true, top_k_choices, i):
        if true in top_k_choices:
            self.num_correct_top_k += 1
        else:
            top_k_choices_times = []
            for choice in top_k_choices:
                time = self.get_time(self.df.iloc[i], choice)
                if time is not None:
                    top_k_choices_times.append(time)
            best_time = self.get_time(self.df.iloc[i], true)
            min_time = min(top_k_choices_times, default=None)
            if min_time is not None:
                speedup = min_time / best_time
                self.wrong_speedups_top_k.append(speedup)
                self.num_wrong_top_k += 1
            else:
                self.top_k_unsure += 1
                # TODO (AlnisM): print more info (input and choices)
                log.debug(
                    "All top k choices have no time which means all top k are unavailable"
                )

    def get_safe_proba(self):
        return self.get_results(return_safe_proba=True)

    def compute_safe_proba(self, num_predictions, wrong_probas, wrong_pct):
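        """
        Computes the probability threshold at or below which predictions are
        answered with 'unsure' (predictions with probability 1.0 are always kept),
        chosen so that at most wrong_pct of the predictions on this evaluation set
        remain wrong.

        Illustrative example (not taken from the original file): with
        num_predictions=200, wrong_pct=0.01 and wrong_probas=[0.55, 0.61, 0.72, 0.9],
        allowed_wrong is 2 and too_many_wrong is 2, so the threshold becomes
        wrong_probas[2] = 0.72; every prediction with probability <= 0.72 is then
        treated as 'unsure', leaving only the wrong prediction with probability 0.9.
        """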
        wrong_probas.sort()
        num_wrong = len(wrong_probas)
        allowed_wrong = int(num_predictions * wrong_pct)
        if allowed_wrong >= num_wrong:
            return 0.0
        too_many_wrong = num_wrong - allowed_wrong
        idx = min(too_many_wrong, len(wrong_probas) - 1)
        return wrong_probas[idx]

    def get_results(self, return_safe_proba=False) -> EvalResults:
        """
        Custom evaluation function that evaluates a learned decision tree.
        """

        y_true = self.df["actual_winner"] if self.ranking else self.df["winner"]
        i = 0
        for pred, true, prob, leaf_id in zip(
            self.predictions, y_true, self.probas, self.leaf_ids
        ):
            avail_choices = self.df["avail_choices"].iloc[i]
            top_k_choices = self.top_k_classes(
                self.model, prob, k=self.k, avail_choices=avail_choices
            )
            assert (
                true in avail_choices
            ), f"Best choice {true} not in available choices {avail_choices}"
            default_config = self.train.get_default_config(self.df.iloc[i])
            self.eval_prediction(
                avail_choices,
                leaf_id,
                pred,
                true,
                prob,
                self.threshold,
                default_config,
                i,
            )
            self.eval_ranking_prediction(true, top_k_choices, i)
            i += 1

        total = len(self.predictions)
        if return_safe_proba:
            return self.compute_safe_proba(total, self.wrong_probas, self.wrong_pct)

        def safe_gmean(x):
            return gmean(x) if x else 0

        max_speedup = max(self.speedups_wrong, default=0)
        gmean_speedup = safe_gmean(self.speedups_wrong)
        max_speedup_top_k = max(self.wrong_speedups_top_k, default=0)
        gmean_speedup_top_k = safe_gmean(self.wrong_speedups_top_k)
        max_speedup_over_default = max(self.speedups_over_default, default=0)
        gmean_speedup_over_default = safe_gmean(self.speedups_over_default)
        max_slowdown_over_default = min(self.speedups_over_default, default=0)

        accuracyMetrics = AccuracyMetrics(
            self.num_correct, self.num_wrong, self.num_unsure, total
        )
        wrongSpeedupMetrics = WrongSpeedupMetrics(max_speedup, gmean_speedup)
        rankingMetrics = RankingMetrics(
            self.num_correct_top_k,
            self.num_wrong_top_k,
            max_speedup_top_k,
            gmean_speedup_top_k,
            self.top_k_unsure,
        )
        defaultComparisonMetrics = DefaultComparisonMetrics(
            max_speedup_over_default,
            gmean_speedup_over_default,
            max_slowdown_over_default,
            self.num_non_default_predictions,
            self.num_default_better,
        )
        return EvalResults(
            accuracyMetrics,
            wrongSpeedupMetrics,
            rankingMetrics,
            defaultComparisonMetrics,
        )


if __name__ == "__main__":
    train = AHTrainDecisionTree()
    train.generate_heuristic()
937