/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.minibench;

import android.app.Activity;
import android.content.Intent;
import android.os.AsyncTask;
import android.os.Bundle;
import android.system.ErrnoException;
import android.system.Os;
import com.google.gson.Gson;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.pytorch.executorch.Module;

/**
 * Headless benchmark activity: loads the first {@code .pte} model found in the directory passed
 * via the {@code "model_dir"} intent extra, measures model/method load time and per-iteration
 * forward latency ({@code "num_iter"} extra, default 50), and writes the resulting metrics as
 * JSON to {@code <filesDir>/benchmark_results.json}.
 */
public class BenchmarkActivity extends Activity {
  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);

    try {
      // The Qualcomm backend resolves its DSP skel libraries via ADSP_LIBRARY_PATH;
      // point it at this app's native library directory.
      Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
    } catch (ErrnoException e) {
      finish();
      // BUG FIX: finish() does not abort onCreate(); without this return the rest of the
      // method would still run on a finishing activity.
      return;
    }

    Intent intent = getIntent();
    String modelDirPath = intent.getStringExtra("model_dir");
    if (modelDirPath == null) {
      // BUG FIX: missing extra previously caused new File(null) to throw NPE.
      finish();
      return;
    }
    File modelDir = new File(modelDirPath);
    // BUG FIX: listFiles() returns null when the path does not exist or is not a directory;
    // the original streamed it unconditionally.
    File[] files = modelDir.listFiles();
    Optional<File> modelFile =
        files == null
            ? Optional.empty()
            : Arrays.stream(files).filter(file -> file.getName().endsWith(".pte")).findFirst();
    if (!modelFile.isPresent()) {
      // BUG FIX: the original called Optional.get() unconditionally and threw
      // NoSuchElementException when the directory contained no .pte file.
      finish();
      return;
    }
    File model = modelFile.get();

    int numIter = intent.getIntExtra("num_iter", 50);

    // TODO: Format the string with a parsable format
    Stats stats = new Stats();

    // NOTE(review): AsyncTask is deprecated since API 30; consider migrating to an
    // ExecutorService posting results back to the main thread.
    new AsyncTask<Void, Void, Void>() {
      @Override
      protected Void doInBackground(Void... voids) {

        // Record the time it takes to load the model and the forward method
        stats.loadStart = System.nanoTime();
        Module module = Module.load(model.getPath());
        stats.errorCode = module.loadMethod("forward");
        stats.loadEnd = System.nanoTime();

        // Time each forward() call individually so we can aggregate later.
        for (int i = 0; i < numIter; i++) {
          long start = System.nanoTime();
          module.forward();
          double forwardMs = (System.nanoTime() - start) * 1e-6;
          stats.latency.add(forwardMs);
        }
        return null;
      }

      @Override
      protected void onPostExecute(Void aVoid) {

        final BenchmarkMetric.BenchmarkModel benchmarkModel =
            BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
        final List<BenchmarkMetric> results = new ArrayList<>();
        // The list of metrics we have atm includes:
        // Avg inference latency after N iterations
        results.add(
            new BenchmarkMetric(
                benchmarkModel,
                "avg_inference_latency(ms)",
                stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
                0.0f));
        // Model load time
        results.add(
            new BenchmarkMetric(
                benchmarkModel,
                "model_load_time(ms)",
                (stats.loadEnd - stats.loadStart) * 1e-6,
                0.0f));
        // Load status (0 == forward method loaded successfully)
        results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));

        // NOTE(review): FileWriter uses the platform default charset; the FileWriter(File,
        // Charset) overload requires API 33+, so it is left as-is here.
        try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
          Gson gson = new Gson();
          writer.write(gson.toJson(results));
        } catch (IOException e) {
          e.printStackTrace();
        }
      }
    }.execute();
  }
}

/** Mutable accumulator for the measurements collected during one benchmark run. */
class Stats {
  long loadStart; // System.nanoTime() immediately before Module.load()
  long loadEnd; // System.nanoTime() immediately after loadMethod("forward") returns
  List<Double> latency = new ArrayList<>(); // per-iteration forward latency, milliseconds
  int errorCode = 0; // result of loadMethod("forward"); 0 means success

  @Override
  public String toString() {
    // BUG FIX: the original joined with "" which concatenated all latencies into one
    // unparsable token; use a delimiter so individual values are recoverable.
    return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(", "));
  }
}