1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 package org.pytorch.minibench;
10 
11 import android.app.Activity;
12 import android.content.Intent;
13 import android.os.AsyncTask;
14 import android.os.Bundle;
15 import android.system.ErrnoException;
16 import android.system.Os;
17 import com.google.gson.Gson;
18 import java.io.File;
19 import java.io.FileWriter;
20 import java.io.IOException;
21 import java.util.ArrayList;
22 import java.util.Arrays;
23 import java.util.List;
24 import java.util.stream.Collectors;
25 import org.pytorch.executorch.Module;
26 
27 public class BenchmarkActivity extends Activity {
28   @Override
onCreate(Bundle savedInstanceState)29   protected void onCreate(Bundle savedInstanceState) {
30     super.onCreate(savedInstanceState);
31 
32     try {
33       Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
34     } catch (ErrnoException e) {
35       finish();
36     }
37 
38     Intent intent = getIntent();
39     File modelDir = new File(intent.getStringExtra("model_dir"));
40     File model =
41         Arrays.stream(modelDir.listFiles())
42             .filter(file -> file.getName().endsWith(".pte"))
43             .findFirst()
44             .get();
45 
46     int numIter = intent.getIntExtra("num_iter", 50);
47 
48     // TODO: Format the string with a parsable format
49     Stats stats = new Stats();
50 
51     new AsyncTask<Void, Void, Void>() {
52       @Override
53       protected Void doInBackground(Void... voids) {
54 
55         // Record the time it takes to load the model and the forward method
56         stats.loadStart = System.nanoTime();
57         Module module = Module.load(model.getPath());
58         stats.errorCode = module.loadMethod("forward");
59         stats.loadEnd = System.nanoTime();
60 
61         for (int i = 0; i < numIter; i++) {
62           long start = System.nanoTime();
63           module.forward();
64           double forwardMs = (System.nanoTime() - start) * 1e-6;
65           stats.latency.add(forwardMs);
66         }
67         return null;
68       }
69 
70       @Override
71       protected void onPostExecute(Void aVoid) {
72 
73         final BenchmarkMetric.BenchmarkModel benchmarkModel =
74             BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
75         final List<BenchmarkMetric> results = new ArrayList<>();
76         // The list of metrics we have atm includes:
77         // Avg inference latency after N iterations
78         results.add(
79             new BenchmarkMetric(
80                 benchmarkModel,
81                 "avg_inference_latency(ms)",
82                 stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
83                 0.0f));
84         // Model load time
85         results.add(
86             new BenchmarkMetric(
87                 benchmarkModel,
88                 "model_load_time(ms)",
89                 (stats.loadEnd - stats.loadStart) * 1e-6,
90                 0.0f));
91         // Load status
92         results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
93 
94         try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
95           Gson gson = new Gson();
96           writer.write(gson.toJson(results));
97         } catch (IOException e) {
98           e.printStackTrace();
99         }
100       }
101     }.execute();
102   }
103 }
104 
/** Mutable accumulator for the timing data gathered by the benchmark run. */
class Stats {
  // Start/end of the model + "forward" method load window (System.nanoTime() values).
  long loadStart;
  long loadEnd;
  // Per-iteration forward() latency, in milliseconds.
  List<Double> latency = new ArrayList<>();
  // Value returned by Module.loadMethod("forward"); presumably 0 means success — TODO confirm.
  int errorCode = 0;

  @Override
  public String toString() {
    // FIX: joining with "" mashed all latencies into one unreadable number run
    // (e.g. "1.52.5"); use ", " so individual values can be told apart.
    return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(", "));
  }
}
116