1from typing import Any, Dict, List, Optional, Tuple 2 3import numpy as np 4from sklearn.tree import _tree # type: ignore[import-untyped] 5 6 7class DecisionTreeNode: 8 def __init__( 9 self, 10 feature: Optional[str] = None, 11 threshold: Optional[float] = None, 12 left: Optional["DecisionTreeNode"] = None, 13 right: Optional["DecisionTreeNode"] = None, 14 class_probs: Any = None, 15 num_samples: int = 0, 16 node_id: int = 0, 17 ) -> None: 18 self.feature = feature 19 self.threshold = threshold 20 self.left = left 21 self.right = right 22 self.class_probs = class_probs 23 self.num_samples = num_samples 24 self.id = node_id 25 26 def is_leaf(self) -> bool: 27 return self.left is None or self.right is None 28 29 30class DecisionTree: 31 """ 32 Custom decision tree implementation that mimics some of the sklearn API. 33 The purpose of this class it to be able to perform transformations, such as custom pruning, which 34 does not seem to be easy with sklearn. 35 """ 36 37 def __init__(self, sklearn_tree: Any, feature_names: List[str]) -> None: 38 self.feature_names = feature_names 39 self.root = self._convert_sklearn_tree(sklearn_tree.tree_) 40 self.classes_: List[str] = sklearn_tree.classes_ 41 42 def _convert_sklearn_tree( 43 self, sklearn_tree: Any, node_id: int = 0 44 ) -> DecisionTreeNode: 45 class_probs = sklearn_tree.value[node_id][0] 46 num_samples = sklearn_tree.n_node_samples[node_id] 47 if sklearn_tree.feature[node_id] != _tree.TREE_UNDEFINED: 48 feature_index = sklearn_tree.feature[node_id] 49 feature = self.feature_names[feature_index] 50 left = self._convert_sklearn_tree( 51 sklearn_tree, sklearn_tree.children_left[node_id] 52 ) 53 right = self._convert_sklearn_tree( 54 sklearn_tree, sklearn_tree.children_right[node_id] 55 ) 56 return DecisionTreeNode( 57 feature=feature, 58 threshold=sklearn_tree.threshold[node_id], 59 left=left, 60 right=right, 61 class_probs=class_probs, 62 num_samples=num_samples, 63 node_id=node_id, 64 ) 65 else: 66 return DecisionTreeNode( 67 class_probs=class_probs, num_samples=num_samples, node_id=node_id 68 ) 69 70 def prune(self, df: Any, target_col: str, k: int) -> None: 71 self.root = self._prune_tree(self.root, df, target_col, k) 72 73 def _prune_tree( 74 self, node: DecisionTreeNode, df: Any, target_col: str, k: int 75 ) -> DecisionTreeNode: 76 if node.is_leaf(): 77 return node 78 79 left_df = df[df[node.feature] <= node.threshold] 80 right_df = df[df[node.feature] > node.threshold] 81 82 # number of unique classes in the left and right subtrees 83 left_counts = left_df[target_col].nunique() 84 right_counts = right_df[target_col].nunique() 85 86 # for ranking, we want to ensure that we return at least k classes, so if we have less than k classes in the 87 # left or right subtree, we remove the split and make this node a leaf node 88 if left_counts < k or right_counts < k: 89 return DecisionTreeNode(class_probs=node.class_probs) 90 91 assert node.left is not None, "expected left child to exist" 92 node.left = self._prune_tree(node.left, left_df, target_col, k) 93 assert node.right is not None, "expected right child to exist" 94 node.right = self._prune_tree(node.right, right_df, target_col, k) 95 96 return node 97 98 def to_dot(self) -> str: 99 dot = "digraph DecisionTree {\n" 100 dot += ' node [fontname="helvetica"];\n' 101 dot += ' edge [fontname="helvetica"];\n' 102 dot += self._node_to_dot(self.root) 103 dot += "}" 104 return dot 105 106 def _node_to_dot( 107 self, node: DecisionTreeNode, parent_id: int = 0, edge_label: str = "" 108 ) -> str: 109 if node is None: 110 return "" 111 112 node_id = id(node) 113 114 # Format class_probs array with line breaks 115 class_probs_str = self._format_class_probs_array( 116 node.class_probs, node.num_samples 117 ) 118 119 if node.is_leaf(): 120 label = class_probs_str 121 shape = "box" 122 else: 123 feature_name = f"{node.feature}" 124 label = f"{feature_name} <= {node.threshold:.2f}\\n{class_probs_str}" 125 shape = "oval" 126 127 dot = f' {node_id} [label="{label}", shape={shape}];\n' 128 129 if parent_id != 0: 130 dot += f' {parent_id} -> {node_id} [label="{edge_label}"];\n' 131 132 if not node.is_leaf(): 133 assert node.left is not None, "expected left child to exist" 134 dot += self._node_to_dot(node.left, node_id, "<=") 135 assert node.right is not None, "expected right child to exist" 136 dot += self._node_to_dot(node.right, node_id, ">") 137 138 return dot 139 140 def _format_class_prob(self, num: float) -> str: 141 if num == 0: 142 return "0" 143 return f"{num:.2f}" 144 145 def _format_class_probs_array( 146 self, class_probs: Any, num_samples: int, max_per_line: int = 5 147 ) -> str: 148 # add line breaks to avoid very long lines 149 flat_class_probs = class_probs.flatten() 150 formatted = [self._format_class_prob(v) for v in flat_class_probs] 151 lines = [ 152 formatted[i : i + max_per_line] 153 for i in range(0, len(formatted), max_per_line) 154 ] 155 return f"num_samples={num_samples}\\n" + "\\n".join( 156 [", ".join(line) for line in lines] 157 ) 158 159 def predict(self, X: Any) -> Any: 160 predictions = [self._predict_single(x) for _, x in X.iterrows()] 161 return np.array(predictions) 162 163 def predict_proba(self, X: Any) -> Any: 164 return np.array([self._predict_proba_single(x) for _, x in X.iterrows()]) 165 166 def _get_leaf(self, X: Any) -> DecisionTreeNode: 167 node = self.root 168 while not node.is_leaf(): 169 if X[node.feature] <= node.threshold: 170 assert node.left is not None, "expected left child to exist" 171 node = node.left 172 else: 173 assert node.right is not None, "expected right child to exist" 174 node = node.right 175 return node 176 177 def _predict_single(self, x: Any) -> str: 178 node = self._get_leaf(x) 179 # map index to class name 180 return self.classes_[np.argmax(node.class_probs)] 181 182 def _predict_proba_single(self, x: Any) -> Any: 183 node = self._get_leaf(x) 184 return node.class_probs 185 186 def apply(self, X: Any) -> Any: 187 ids = [self._apply_single(x) for _, x in X.iterrows()] 188 return np.array(ids) 189 190 def _apply_single(self, x: Any) -> int: 191 node = self._get_leaf(x) 192 return node.id 193 194 def codegen( 195 self, 196 dummy_col_2_col_val: Dict[str, Tuple[str, Any]], 197 lines: List[str], 198 unsafe_leaves: List[int], 199 ) -> None: 200 # generates python code for the decision tree 201 def codegen_node(node: DecisionTreeNode, depth: int) -> None: 202 indent = " " * (depth + 1) 203 if node.is_leaf(): 204 lines.append(handle_leaf(node, indent, unsafe_leaves)) 205 else: 206 name = node.feature 207 threshold = node.threshold 208 if name in dummy_col_2_col_val: 209 (orig_name, value) = dummy_col_2_col_val[name] 210 predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':" 211 assert ( 212 threshold == 0.5 213 ), f"expected threshold to be 0.5 but is {threshold}" 214 else: 215 predicate = ( 216 f"{indent}if context.get_value('{name}') <= {threshold}:" 217 ) 218 lines.append(predicate) 219 assert node.left is not None, "expected left child to exist" 220 codegen_node(node.left, depth + 1) 221 lines.append(f"{indent}else:") 222 assert node.right is not None, "expected right child to exist" 223 codegen_node(node.right, depth + 1) 224 225 def handle_leaf( 226 node: DecisionTreeNode, indent: str, unsafe_leaves: List[int] 227 ) -> str: 228 """ 229 This generates the code for a leaf node in the decision tree. If the leaf is unsafe, the learned heuristic 230 will return "unsure" (i.e. None). 231 """ 232 if node.id in unsafe_leaves: 233 return f"{indent}return None" 234 class_probas = node.class_probs 235 return f"{indent}return {best_probas_and_indices(class_probas)}" 236 237 def best_probas_and_indices(class_probas: Any) -> str: 238 """ 239 Given a list of tuples (proba, idx), this function returns a string in which the tuples are 240 sorted by proba in descending order. E.g.: 241 Given class_probas=[(0.3, 0), (0.5, 1), (0.2, 2)] 242 this function returns 243 "[(0.5, 1), (0.3, 0), (0.2, 2)]" 244 """ 245 # we generate a list of tuples (proba, idx) sorted by proba in descending order 246 # idx is the index of a choice 247 # we only generate a tuple if proba > 0 248 probas_indices_sorted = sorted( 249 [ 250 (proba, index) 251 for index, proba in enumerate(class_probas) 252 if proba > 0 253 ], 254 key=lambda x: x[0], 255 reverse=True, 256 ) 257 probas_indices_sorted_str = ", ".join( 258 f"({value:.3f}, {index})" for value, index in probas_indices_sorted 259 ) 260 return f"[{probas_indices_sorted_str}]" 261 262 codegen_node(self.root, 1) 263