1 package software.amazon.awssdk.crt.test; 2 3 import java.util.ArrayList; 4 import java.util.List; 5 import java.util.concurrent.Callable; 6 import java.util.concurrent.CompletableFuture; 7 import java.util.concurrent.ExecutorService; 8 import java.util.concurrent.Executors; 9 10 import org.junit.Assert; 11 import org.junit.Test; 12 13 import software.amazon.awssdk.crt.CRT; 14 import software.amazon.awssdk.crt.CrtResource; 15 import software.amazon.awssdk.crt.Log; 16 17 /** 18 * Checks that the CRT doesn't have any major memory leaks. Probably won't 19 * detect very small leaks but will likely find obvious large ones. 20 */ 21 public class CrtMemoryLeakDetector extends CrtTestFixture { 22 static { CRT()23 new CRT(); // force the CRT to load before doing anything 24 }; 25 26 private final static int DEFAULT_NUM_LEAK_TEST_ITERATIONS = 20; 27 getNativeMemoryInUse()28 private static long getNativeMemoryInUse() { 29 long nativeMemory = CRT.nativeMemory(); 30 Log.log(Log.LogLevel.Trace, Log.LogSubject.JavaCrtGeneral, String.format("Native MemUsage: %d", nativeMemory)); 31 return nativeMemory; 32 } 33 getJvmMemoryInUse()34 private static long getJvmMemoryInUse() { 35 36 Log.log(Log.LogLevel.Trace, Log.LogSubject.JavaCrtGeneral, "Checking JVM Memory Usage"); 37 38 long estimatedMemInUse = Long.MAX_VALUE; 39 40 for (int i = 0; i < 10; i++) { 41 // Force a Java Garbage Collection before measuring to reduce noise in 42 // measurement 43 System.gc(); 44 45 // Take the minimum of several measurements to reduce noise 46 estimatedMemInUse = Long.min(estimatedMemInUse, 47 (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory())); 48 } 49 50 Log.log(Log.LogLevel.Trace, Log.LogSubject.JavaCrtGeneral, 51 String.format("JVM MemUsage: %d", estimatedMemInUse)); 52 53 return estimatedMemInUse; 54 } 55 nativeMemoryLeakCheck()56 public static void nativeMemoryLeakCheck() throws Exception { 57 String output = ""; 58 long nativeMemory = getNativeMemoryInUse(); 59 if (nativeMemory > 0) { 60 Log.initLoggingToFile(Log.LogLevel.Trace, "log.txt"); 61 62 output += "Potential Native Memory Leak!\n"; 63 Assert.fail(String.format("%s\nNative Memory remain: %s\n", output, nativeMemory)); 64 CRT.dumpNativeMemory(); 65 } 66 } 67 leakCheck(Callable<Void> fn)68 public static void leakCheck(Callable<Void> fn) throws Exception { 69 leakCheck(DEFAULT_NUM_LEAK_TEST_ITERATIONS, expectedFixedGrowth(), fn); 70 } 71 leakCheck(int numIterations, int maxLeakage, Callable<Void> fn)72 public static void leakCheck(int numIterations, int maxLeakage, Callable<Void> fn) throws Exception { 73 List<Long> jvmSamples = new ArrayList<>(); 74 List<Long> nativeSamples = new ArrayList<>(); 75 numIterations = Math.max(2, numIterations); // There need to be at least 2 iterations to get deltas 76 77 getJvmMemoryInUse(); // force a few GCs to get a good baseline 78 79 for (int i = 0; i < numIterations; i++) { 80 fn.call(); 81 jvmSamples.add(getJvmMemoryInUse()); 82 nativeSamples.add(getNativeMemoryInUse()); 83 } 84 85 // Get the median deltas 86 List<Long> jvmDeltas = getDeltas(jvmSamples); 87 long medianJvmDelta = jvmDeltas.get(jvmDeltas.size() / 2); 88 List<Long> nativeDeltas = getDeltas(nativeSamples); 89 long medianNativeDelta = nativeDeltas.get(nativeDeltas.size() / 2); 90 91 String output = ""; 92 if (medianJvmDelta > maxLeakage) { 93 output += "Potential Java Memory Leak!\n"; 94 } 95 if (medianNativeDelta > maxLeakage) { 96 output += "Potential Native Memory Leak!\n"; 97 CRT.dumpNativeMemory(); 98 } 99 100 final List<String> resources = new ArrayList<>(); 101 CrtResource.collectNativeResources((resource) -> { 102 resources.add(resource); 103 }); 104 if (resources.size() > 0) { 105 output += String.join("\n", resources); 106 } 107 108 if (output.length() > 0) { 109 Assert.fail(String.format( 110 "%s\nJVM Usage Deltas: %s\nJVM Samples: %s\nNative Usage Deltas: %s\nNative Samples: %s\n", output, 111 jvmDeltas.toString(), jvmSamples.toString(), nativeDeltas.toString(), nativeSamples.toString())); 112 } 113 } 114 getDeltas(List<Long> samples)115 private static List<Long> getDeltas(List<Long> samples) { 116 List<Long> memUseDeltas = new ArrayList<>(); 117 for (int i = 0; i < samples.size() - 1; i++) { 118 long prev = samples.get(i); 119 long curr = samples.get(i + 1); 120 long delta = (curr - prev); 121 memUseDeltas.add(delta); 122 } 123 124 // Sort from smallest to largest 125 memUseDeltas.sort(null); 126 return memUseDeltas; 127 } 128 129 @Test testLeakDetectorSerial()130 public void testLeakDetectorSerial() throws Exception { 131 leakCheck(20, 64, () -> { 132 Thread.sleep(1); 133 return null; 134 }); 135 } 136 137 private static int FIXED_EXECUTOR_GROWTH = 0; 138 expectedFixedGrowth()139 public static int expectedFixedGrowth() { 140 if (FIXED_EXECUTOR_GROWTH == 0) { 141 determineBaselineGrowth(); 142 } 143 return FIXED_EXECUTOR_GROWTH; 144 } 145 determineBaselineGrowth()146 private static void determineBaselineGrowth() { 147 148 getJvmMemoryInUse(); // force a few GCs to get a good baseline 149 150 List<Long> jvmSamples = new ArrayList<>(); 151 152 Callable<Void> fn = () -> { 153 final ExecutorService threadPool = Executors.newFixedThreadPool(32); 154 List<CompletableFuture<Void>> futures = new ArrayList<>(); 155 for (int idx = 0; idx < DEFAULT_NUM_LEAK_TEST_ITERATIONS; ++idx) { 156 CompletableFuture<Void> future = new CompletableFuture<>(); 157 futures.add(future); 158 final int thisIdx = idx; 159 threadPool.execute(() -> { 160 try { 161 Thread.sleep(1); 162 } catch (Exception ex) { 163 // no op 164 } finally { 165 future.complete(null); 166 } 167 }); 168 } 169 for (CompletableFuture f : futures) { 170 f.join(); 171 } 172 return null; 173 }; 174 175 for (int i = 0; i < DEFAULT_NUM_LEAK_TEST_ITERATIONS; ++i) { 176 try { 177 fn.call(); 178 jvmSamples.add(getJvmMemoryInUse()); 179 } catch (Exception ex) { 180 } 181 } 182 183 // Get the median deltas 184 List<Long> jvmDeltas = getDeltas(jvmSamples); 185 long medianJvmDelta = jvmDeltas.get(jvmDeltas.size() / 2); 186 FIXED_EXECUTOR_GROWTH = (int) medianJvmDelta; 187 } 188 runViaThreadPool(int numThreads)189 private static void runViaThreadPool(int numThreads) throws Exception { 190 final ExecutorService threadPool = Executors.newFixedThreadPool(numThreads); 191 List<CompletableFuture<Void>> futures = new ArrayList<>(); 192 leakCheck(20, expectedFixedGrowth(), () -> { 193 for (int idx = 0; idx < 100; ++idx) { 194 CompletableFuture<Void> future = new CompletableFuture<>(); 195 futures.add(future); 196 threadPool.execute(() -> { 197 try { 198 Thread.sleep(1); 199 } catch (Exception ex) { 200 // no op 201 } finally { 202 future.complete(null); 203 } 204 }); 205 } 206 207 for (CompletableFuture f : futures) { 208 f.join(); 209 } 210 return null; 211 }); 212 } 213 214 @Test testLeakDetectorParallel_2()215 public void testLeakDetectorParallel_2() throws Exception { 216 runViaThreadPool(2); 217 } 218 219 @Test testLeakDetectorParallel_8()220 public void testLeakDetectorParallel_8() throws Exception { 221 runViaThreadPool(8); 222 } 223 224 @Test testLeakDetectorParallel_32()225 public void testLeakDetectorParallel_32() throws Exception { 226 runViaThreadPool(32); 227 } 228 } 229