"""
This script uses linear programming to analyze outputs of triton mm config tuning.
To generate output that can be fed into this script, set the env var
TORCHINDUCTOR_MM_LOGGING_FILE.

That file can be fed into this script to compute the minimized total, weighted
matmul time as a function of the number of allowed templates.
"""

import json

import click
import pulp


def parse_log_file(file_path):
    """Read a matmul logging file and split it into invocations and benchmarks.

    The file must contain a JSON list of dict entries. An entry with an
    ``"invoke"`` key counts one invocation of the shape stored under that key.
    Any other entry is treated as a mapping from shape to a list of timing
    records, which are accumulated per shape.

    Args:
        file_path: Path of the JSON log file.

    Returns:
        A tuple ``(occurrence_count, benchmark_logs)`` where
        ``occurrence_count`` maps shape -> number of invocations and
        ``benchmark_logs`` maps shape -> list of timing records.
    """
    with open(file_path, encoding="utf-8") as f:
        logs = json.load(f)

    occurrence_count = {}
    benchmark_logs = {}

    for entry in logs:
        if "invoke" in entry:
            shape = entry["invoke"]
            occurrence_count[shape] = occurrence_count.get(shape, 0) + 1
        else:
            for shape, timings in entry.items():
                benchmark_logs.setdefault(shape, []).extend(timings)

    return occurrence_count, benchmark_logs


def _template_key(timing):
    """Return the attribute tuple that identifies the Triton template of a timing."""
    return (
        timing["BLOCK_M"],
        timing["BLOCK_N"],
        timing["BLOCK_K"],
        timing["num_stages"],
        timing["num_warps"],
    )


def _is_selected(variable):
    """Return True if a solved binary variable is (numerically) set to 1."""
    value = pulp.value(variable)
    # Solver output is floating point, so compare against a threshold rather
    # than testing for exact equality with 1. Unsolved variables are None.
    return value is not None and value > 0.5


def optimize_templates(N, occurrence_count, benchmark_logs, verbose=False):
    """Pick exactly N Triton templates minimizing the weighted total matmul time.

    Each shape is served either by cuBLAS or by one of the selected Triton
    templates that was benchmarked for that shape; the objective is the sum
    over shapes of (occurrences * chosen time).

    Note: the model requires *exactly* N selected templates, so it is
    infeasible when N exceeds the number of distinct Triton templates found in
    ``benchmark_logs``. The solver status is not checked here, so the returned
    values are not meaningful in that case.

    Args:
        N: Number of Triton templates that must be selected.
        occurrence_count: Mapping shape -> number of invocations.
        benchmark_logs: Mapping shape -> list of timing records. Each record
            has a ``"type"`` (``"cublas"`` or ``"triton"``) and a ``"time"``;
            triton records also carry ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``,
            ``num_stages`` and ``num_warps``.
        verbose: If true, print the input data, constraints, decision
            variables and per-shape debugging information.

    Returns:
        A tuple ``(selected_templates, total_time)``: the list of selected
        template tuples and the weighted total time of the solution.

    Raises:
        KeyError: If a shape in ``occurrence_count`` has no benchmark entry.
        ValueError: If a shape has no cuBLAS timing.
    """
    # Set of all possible Triton templates keyed by their attributes
    triton_templates = set()
    for timings in benchmark_logs.values():
        for timing in timings:
            if timing["type"] == "triton":
                triton_templates.add(_template_key(timing))

    # Print the initial data
    if verbose:
        print("Occurrence Count:", occurrence_count)
        print("Triton Templates:", triton_templates)

    # Binary variable per template: is it part of the allowed set?
    template_vars = {
        template: pulp.LpVariable(f"Template_{template}", 0, 1, pulp.LpBinary)
        for template in triton_templates
    }

    # Binary variable per (shape, option): which timing option serves the shape?
    selection_vars = {
        (shape, "cublas"): pulp.LpVariable(
            f"Select_{shape}_cublas", 0, 1, pulp.LpBinary
        )
        for shape in occurrence_count
    }
    for shape in occurrence_count:
        for template in triton_templates:
            selection_vars[(shape, template)] = pulp.LpVariable(
                f"Select_{shape}_{template}", 0, 1, pulp.LpBinary
            )

    # Variables for the total time for each shape
    min_time_vars = pulp.LpVariable.dicts(
        "MinTime", occurrence_count.keys(), 0, None, pulp.LpContinuous
    )

    # Define the problem
    prob = pulp.LpProblem("MatrixMultiplicationOptimization", pulp.LpMinimize)

    # Objective: Minimize the weighted total time
    prob += pulp.lpSum(
        [occurrence_count[shape] * min_time_vars[shape] for shape in occurrence_count]
    )

    # Constraints to select exactly N templates
    prob += pulp.lpSum([template_vars[template] for template in triton_templates]) == N

    # Per-shape data kept for the debugging output below
    triton_options_per_shape = {}
    min_cublas_time_per_shape = {}

    # Constraints for the total time for each shape
    for shape in occurrence_count:
        shape_timings = benchmark_logs[shape]

        # Get cuBLAS time
        cublas_times = [
            timing["time"] for timing in shape_timings if timing["type"] == "cublas"
        ]
        if not cublas_times:
            raise ValueError(f"No cuBLAS timing recorded for shape {shape!r}")
        min_cublas_time = min(cublas_times)
        min_cublas_time_per_shape[shape] = min_cublas_time

        # Best time per Triton template for this shape, in a single pass
        best_triton_time = {}
        for timing in shape_timings:
            if timing["type"] != "triton":
                continue
            template = _template_key(timing)
            elapsed = timing["time"]
            if template not in best_triton_time or elapsed < best_triton_time[template]:
                best_triton_time[template] = elapsed

        # Iterate triton_templates (not the dict) to keep a stable option order
        triton_options = [
            (best_triton_time[template], template)
            for template in triton_templates
            if template in best_triton_time
        ]
        triton_options_per_shape[shape] = triton_options

        # Ensure exactly one timing option is selected for each shape
        prob += (
            pulp.lpSum(
                [selection_vars[(shape, "cublas")]]
                + [
                    selection_vars[(shape, template)]
                    for triton_time, template in triton_options
                ]
            )
            == 1
        )

        # Ensure min_time_vars[shape] matches the selected timing option
        prob += min_time_vars[shape] == (
            selection_vars[(shape, "cublas")] * min_cublas_time
            + pulp.lpSum(
                [
                    selection_vars[(shape, template)] * triton_time
                    for triton_time, template in triton_options
                ]
            )
        )

        # Ensure Triton templates can only be selected if they are included
        # in the N allowed templates
        for triton_time, template in triton_options:
            prob += selection_vars[(shape, template)] <= template_vars[template]

    # Print the constraints
    if verbose:
        print("Constraints:")
        for constraint in prob.constraints.values():
            print(constraint)

    # Solve the problem with suppressed output
    prob.solve(pulp.PULP_CBC_CMD(msg=False))

    # Output the selected templates and their configurations
    selected_templates = [
        template
        for template in triton_templates
        if _is_selected(template_vars[template])
    ]
    total_time = sum(
        pulp.value(min_time_vars[shape]) * occurrence_count[shape]
        for shape in occurrence_count
    )

    # Print the values of the decision variables after solving
    if verbose:
        print("Decision Variable Values:")
        for var in prob.variables():
            print(f"{var.name} = {var.varValue}")

    # Debugging information
    if verbose:
        for shape in occurrence_count:
            print(f"Shape: {shape}")
            print(f" Min Time: {pulp.value(min_time_vars[shape])}")
            print(f" Occurrences: {occurrence_count[shape]}")
            # Use the value recorded for *this* shape, not whatever the
            # constraint loop left behind on its final iteration.
            print(
                f" Min CuBLAS Time: {min_cublas_time_per_shape[shape]} Selected: {pulp.value(selection_vars[(shape, 'cublas')])}"
            )
            for triton_time, template in triton_options_per_shape[shape]:
                print(
                    f" Triton Template: {template} Time: {triton_time} Selected: {pulp.value(selection_vars[(shape, template)])}"
                )

    return selected_templates, total_time


# Main code to parse the log file and optimize templates
@click.command()
@click.argument("filename")
@click.option("--min-templates", default=0, help="Minimum number of templates.")
@click.option("--max-templates", default=10, help="Maximum number of templates.")
@click.option("--verbose", is_flag=True, help="Enable verbose output.")
def main(filename, min_templates, max_templates, verbose):
    """Report the best weighted matmul time for each allowed template count.

    FILENAME is the JSON log file to analyze.
    """
    occurrence_count, benchmark_logs = parse_log_file(filename)
    times = []
    for N in range(min_templates, max_templates + 1):
        selected_templates, total_time = optimize_templates(
            N, occurrence_count, benchmark_logs, verbose
        )
        print(f"N = {N}")
        print(f"Selected Templates: {selected_templates}")
        print(f"Total Weighted Time: {total_time}")
        times.append(total_time)
    print(times)


if __name__ == "__main__":
    main()