xref: /aosp_15_r20/external/aws-crt-java/src/test/java/software/amazon/awssdk/crt/test/CrtMemoryLeakDetector.java (revision 3c7ae9de214676c52d19f01067dc1a404272dc11)
1 package software.amazon.awssdk.crt.test;
2 
3 import java.util.ArrayList;
4 import java.util.List;
5 import java.util.concurrent.Callable;
6 import java.util.concurrent.CompletableFuture;
7 import java.util.concurrent.ExecutorService;
8 import java.util.concurrent.Executors;
9 
10 import org.junit.Assert;
11 import org.junit.Test;
12 
13 import software.amazon.awssdk.crt.CRT;
14 import software.amazon.awssdk.crt.CrtResource;
15 import software.amazon.awssdk.crt.Log;
16 
17 /**
18  * Checks that the CRT doesn't have any major memory leaks. Probably won't
19  * detect very small leaks but will likely find obvious large ones.
20  */
21 public class CrtMemoryLeakDetector extends CrtTestFixture {
22     static {
CRT()23         new CRT(); // force the CRT to load before doing anything
24     };
25 
26     private final static int DEFAULT_NUM_LEAK_TEST_ITERATIONS = 20;
27 
getNativeMemoryInUse()28     private static long getNativeMemoryInUse() {
29         long nativeMemory = CRT.nativeMemory();
30         Log.log(Log.LogLevel.Trace, Log.LogSubject.JavaCrtGeneral, String.format("Native MemUsage: %d", nativeMemory));
31         return nativeMemory;
32     }
33 
getJvmMemoryInUse()34     private static long getJvmMemoryInUse() {
35 
36         Log.log(Log.LogLevel.Trace, Log.LogSubject.JavaCrtGeneral, "Checking JVM Memory Usage");
37 
38         long estimatedMemInUse = Long.MAX_VALUE;
39 
40         for (int i = 0; i < 10; i++) {
41             // Force a Java Garbage Collection before measuring to reduce noise in
42             // measurement
43             System.gc();
44 
45             // Take the minimum of several measurements to reduce noise
46             estimatedMemInUse = Long.min(estimatedMemInUse,
47                     (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()));
48         }
49 
50         Log.log(Log.LogLevel.Trace, Log.LogSubject.JavaCrtGeneral,
51                 String.format("JVM MemUsage: %d", estimatedMemInUse));
52 
53         return estimatedMemInUse;
54     }
55 
nativeMemoryLeakCheck()56     public static void nativeMemoryLeakCheck() throws Exception {
57         String output = "";
58         long nativeMemory = getNativeMemoryInUse();
59         if (nativeMemory > 0) {
60             Log.initLoggingToFile(Log.LogLevel.Trace, "log.txt");
61 
62             output += "Potential Native Memory Leak!\n";
63             Assert.fail(String.format("%s\nNative Memory remain: %s\n", output, nativeMemory));
64             CRT.dumpNativeMemory();
65         }
66     }
67 
leakCheck(Callable<Void> fn)68     public static void leakCheck(Callable<Void> fn) throws Exception {
69         leakCheck(DEFAULT_NUM_LEAK_TEST_ITERATIONS, expectedFixedGrowth(), fn);
70     }
71 
leakCheck(int numIterations, int maxLeakage, Callable<Void> fn)72     public static void leakCheck(int numIterations, int maxLeakage, Callable<Void> fn) throws Exception {
73         List<Long> jvmSamples = new ArrayList<>();
74         List<Long> nativeSamples = new ArrayList<>();
75         numIterations = Math.max(2, numIterations); // There need to be at least 2 iterations to get deltas
76 
77         getJvmMemoryInUse(); // force a few GCs to get a good baseline
78 
79         for (int i = 0; i < numIterations; i++) {
80             fn.call();
81             jvmSamples.add(getJvmMemoryInUse());
82             nativeSamples.add(getNativeMemoryInUse());
83         }
84 
85         // Get the median deltas
86         List<Long> jvmDeltas = getDeltas(jvmSamples);
87         long medianJvmDelta = jvmDeltas.get(jvmDeltas.size() / 2);
88         List<Long> nativeDeltas = getDeltas(nativeSamples);
89         long medianNativeDelta = nativeDeltas.get(nativeDeltas.size() / 2);
90 
91         String output = "";
92         if (medianJvmDelta > maxLeakage) {
93             output += "Potential Java Memory Leak!\n";
94         }
95         if (medianNativeDelta > maxLeakage) {
96             output += "Potential Native Memory Leak!\n";
97             CRT.dumpNativeMemory();
98         }
99 
100         final List<String> resources = new ArrayList<>();
101         CrtResource.collectNativeResources((resource) -> {
102             resources.add(resource);
103         });
104         if (resources.size() > 0) {
105             output += String.join("\n", resources);
106         }
107 
108         if (output.length() > 0) {
109             Assert.fail(String.format(
110                     "%s\nJVM Usage Deltas: %s\nJVM Samples: %s\nNative Usage Deltas: %s\nNative Samples: %s\n", output,
111                     jvmDeltas.toString(), jvmSamples.toString(), nativeDeltas.toString(), nativeSamples.toString()));
112         }
113     }
114 
getDeltas(List<Long> samples)115     private static List<Long> getDeltas(List<Long> samples) {
116         List<Long> memUseDeltas = new ArrayList<>();
117         for (int i = 0; i < samples.size() - 1; i++) {
118             long prev = samples.get(i);
119             long curr = samples.get(i + 1);
120             long delta = (curr - prev);
121             memUseDeltas.add(delta);
122         }
123 
124         // Sort from smallest to largest
125         memUseDeltas.sort(null);
126         return memUseDeltas;
127     }
128 
129     @Test
testLeakDetectorSerial()130     public void testLeakDetectorSerial() throws Exception {
131         leakCheck(20, 64, () -> {
132             Thread.sleep(1);
133             return null;
134         });
135     }
136 
137     private static int FIXED_EXECUTOR_GROWTH = 0;
138 
expectedFixedGrowth()139     public static int expectedFixedGrowth() {
140         if (FIXED_EXECUTOR_GROWTH == 0) {
141             determineBaselineGrowth();
142         }
143         return FIXED_EXECUTOR_GROWTH;
144     }
145 
determineBaselineGrowth()146     private static void determineBaselineGrowth() {
147 
148         getJvmMemoryInUse(); // force a few GCs to get a good baseline
149 
150         List<Long> jvmSamples = new ArrayList<>();
151 
152         Callable<Void> fn = () -> {
153             final ExecutorService threadPool = Executors.newFixedThreadPool(32);
154             List<CompletableFuture<Void>> futures = new ArrayList<>();
155             for (int idx = 0; idx < DEFAULT_NUM_LEAK_TEST_ITERATIONS; ++idx) {
156                 CompletableFuture<Void> future = new CompletableFuture<>();
157                 futures.add(future);
158                 final int thisIdx = idx;
159                 threadPool.execute(() -> {
160                     try {
161                         Thread.sleep(1);
162                     } catch (Exception ex) {
163                         // no op
164                     } finally {
165                         future.complete(null);
166                     }
167                 });
168             }
169             for (CompletableFuture f : futures) {
170                 f.join();
171             }
172             return null;
173         };
174 
175         for (int i = 0; i < DEFAULT_NUM_LEAK_TEST_ITERATIONS; ++i) {
176             try {
177                 fn.call();
178                 jvmSamples.add(getJvmMemoryInUse());
179             } catch (Exception ex) {
180             }
181         }
182 
183         // Get the median deltas
184         List<Long> jvmDeltas = getDeltas(jvmSamples);
185         long medianJvmDelta = jvmDeltas.get(jvmDeltas.size() / 2);
186         FIXED_EXECUTOR_GROWTH = (int) medianJvmDelta;
187     }
188 
runViaThreadPool(int numThreads)189     private static void runViaThreadPool(int numThreads) throws Exception {
190         final ExecutorService threadPool = Executors.newFixedThreadPool(numThreads);
191         List<CompletableFuture<Void>> futures = new ArrayList<>();
192         leakCheck(20, expectedFixedGrowth(), () -> {
193             for (int idx = 0; idx < 100; ++idx) {
194                 CompletableFuture<Void> future = new CompletableFuture<>();
195                 futures.add(future);
196                 threadPool.execute(() -> {
197                     try {
198                         Thread.sleep(1);
199                     } catch (Exception ex) {
200                         // no op
201                     } finally {
202                         future.complete(null);
203                     }
204                 });
205             }
206 
207             for (CompletableFuture f : futures) {
208                 f.join();
209             }
210             return null;
211         });
212     }
213 
214     @Test
testLeakDetectorParallel_2()215     public void testLeakDetectorParallel_2() throws Exception {
216         runViaThreadPool(2);
217     }
218 
219     @Test
testLeakDetectorParallel_8()220     public void testLeakDetectorParallel_8() throws Exception {
221         runViaThreadPool(8);
222     }
223 
224     @Test
testLeakDetectorParallel_32()225     public void testLeakDetectorParallel_32() throws Exception {
226         runViaThreadPool(32);
227     }
228 }
229