1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.demo; 17 18 import android.content.res.AssetManager; 19 import android.graphics.Bitmap; 20 import android.graphics.RectF; 21 import android.os.Trace; 22 import java.util.ArrayList; 23 import java.util.Comparator; 24 import java.util.List; 25 import java.util.PriorityQueue; 26 import org.tensorflow.contrib.android.TensorFlowInferenceInterface; 27 import org.tensorflow.demo.env.Logger; 28 import org.tensorflow.demo.env.SplitTimer; 29 30 /** An object detector that uses TF and a YOLO model to detect objects. */ 31 public class TensorFlowYoloDetector implements Classifier { 32 private static final Logger LOGGER = new Logger(); 33 34 // Only return this many results with at least this confidence. 35 private static final int MAX_RESULTS = 5; 36 37 private static final int NUM_CLASSES = 20; 38 39 private static final int NUM_BOXES_PER_BLOCK = 5; 40 41 // TODO(andrewharp): allow loading anchors and classes 42 // from files. 43 private static final double[] ANCHORS = { 44 1.08, 1.19, 45 3.42, 4.41, 46 6.63, 11.38, 47 9.42, 5.11, 48 16.62, 10.52 49 }; 50 51 private static final String[] LABELS = { 52 "aeroplane", 53 "bicycle", 54 "bird", 55 "boat", 56 "bottle", 57 "bus", 58 "car", 59 "cat", 60 "chair", 61 "cow", 62 "diningtable", 63 "dog", 64 "horse", 65 "motorbike", 66 "person", 67 "pottedplant", 68 "sheep", 69 "sofa", 70 "train", 71 "tvmonitor" 72 }; 73 74 // Config values. 75 private String inputName; 76 private int inputSize; 77 78 // Pre-allocated buffers. 79 private int[] intValues; 80 private float[] floatValues; 81 private String[] outputNames; 82 83 private int blockSize; 84 85 private boolean logStats = false; 86 87 private TensorFlowInferenceInterface inferenceInterface; 88 89 /** Initializes a native TensorFlow session for classifying images. */ create( final AssetManager assetManager, final String modelFilename, final int inputSize, final String inputName, final String outputName, final int blockSize)90 public static Classifier create( 91 final AssetManager assetManager, 92 final String modelFilename, 93 final int inputSize, 94 final String inputName, 95 final String outputName, 96 final int blockSize) { 97 TensorFlowYoloDetector d = new TensorFlowYoloDetector(); 98 d.inputName = inputName; 99 d.inputSize = inputSize; 100 101 // Pre-allocate buffers. 102 d.outputNames = outputName.split(","); 103 d.intValues = new int[inputSize * inputSize]; 104 d.floatValues = new float[inputSize * inputSize * 3]; 105 d.blockSize = blockSize; 106 107 d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); 108 109 return d; 110 } 111 TensorFlowYoloDetector()112 private TensorFlowYoloDetector() {} 113 expit(final float x)114 private float expit(final float x) { 115 return (float) (1. / (1. + Math.exp(-x))); 116 } 117 softmax(final float[] vals)118 private void softmax(final float[] vals) { 119 float max = Float.NEGATIVE_INFINITY; 120 for (final float val : vals) { 121 max = Math.max(max, val); 122 } 123 float sum = 0.0f; 124 for (int i = 0; i < vals.length; ++i) { 125 vals[i] = (float) Math.exp(vals[i] - max); 126 sum += vals[i]; 127 } 128 for (int i = 0; i < vals.length; ++i) { 129 vals[i] = vals[i] / sum; 130 } 131 } 132 133 @Override recognizeImage(final Bitmap bitmap)134 public List<Recognition> recognizeImage(final Bitmap bitmap) { 135 final SplitTimer timer = new SplitTimer("recognizeImage"); 136 137 // Log this method so that it can be analyzed with systrace. 138 Trace.beginSection("recognizeImage"); 139 140 Trace.beginSection("preprocessBitmap"); 141 // Preprocess the image data from 0-255 int to normalized float based 142 // on the provided parameters. 143 bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); 144 145 for (int i = 0; i < intValues.length; ++i) { 146 floatValues[i * 3 + 0] = ((intValues[i] >> 16) & 0xFF) / 255.0f; 147 floatValues[i * 3 + 1] = ((intValues[i] >> 8) & 0xFF) / 255.0f; 148 floatValues[i * 3 + 2] = (intValues[i] & 0xFF) / 255.0f; 149 } 150 Trace.endSection(); // preprocessBitmap 151 152 // Copy the input data into TensorFlow. 153 Trace.beginSection("feed"); 154 inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); 155 Trace.endSection(); 156 157 timer.endSplit("ready for inference"); 158 159 // Run the inference call. 160 Trace.beginSection("run"); 161 inferenceInterface.run(outputNames, logStats); 162 Trace.endSection(); 163 164 timer.endSplit("ran inference"); 165 166 // Copy the output Tensor back into the output array. 167 Trace.beginSection("fetch"); 168 final int gridWidth = bitmap.getWidth() / blockSize; 169 final int gridHeight = bitmap.getHeight() / blockSize; 170 final float[] output = 171 new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK]; 172 inferenceInterface.fetch(outputNames[0], output); 173 Trace.endSection(); 174 175 // Find the best detections. 176 final PriorityQueue<Recognition> pq = 177 new PriorityQueue<Recognition>( 178 1, 179 new Comparator<Recognition>() { 180 @Override 181 public int compare(final Recognition lhs, final Recognition rhs) { 182 // Intentionally reversed to put high confidence at the head of the queue. 183 return Float.compare(rhs.getConfidence(), lhs.getConfidence()); 184 } 185 }); 186 187 for (int y = 0; y < gridHeight; ++y) { 188 for (int x = 0; x < gridWidth; ++x) { 189 for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) { 190 final int offset = 191 (gridWidth * (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5))) * y 192 + (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5)) * x 193 + (NUM_CLASSES + 5) * b; 194 195 final float xPos = (x + expit(output[offset + 0])) * blockSize; 196 final float yPos = (y + expit(output[offset + 1])) * blockSize; 197 198 final float w = (float) (Math.exp(output[offset + 2]) * ANCHORS[2 * b + 0]) * blockSize; 199 final float h = (float) (Math.exp(output[offset + 3]) * ANCHORS[2 * b + 1]) * blockSize; 200 201 final RectF rect = 202 new RectF( 203 Math.max(0, xPos - w / 2), 204 Math.max(0, yPos - h / 2), 205 Math.min(bitmap.getWidth() - 1, xPos + w / 2), 206 Math.min(bitmap.getHeight() - 1, yPos + h / 2)); 207 final float confidence = expit(output[offset + 4]); 208 209 int detectedClass = -1; 210 float maxClass = 0; 211 212 final float[] classes = new float[NUM_CLASSES]; 213 for (int c = 0; c < NUM_CLASSES; ++c) { 214 classes[c] = output[offset + 5 + c]; 215 } 216 softmax(classes); 217 218 for (int c = 0; c < NUM_CLASSES; ++c) { 219 if (classes[c] > maxClass) { 220 detectedClass = c; 221 maxClass = classes[c]; 222 } 223 } 224 225 final float confidenceInClass = maxClass * confidence; 226 if (confidenceInClass > 0.01) { 227 LOGGER.i( 228 "%s (%d) %f %s", LABELS[detectedClass], detectedClass, confidenceInClass, rect); 229 pq.add(new Recognition("" + offset, LABELS[detectedClass], confidenceInClass, rect)); 230 } 231 } 232 } 233 } 234 timer.endSplit("decoded results"); 235 236 final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); 237 for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) { 238 recognitions.add(pq.poll()); 239 } 240 Trace.endSection(); // "recognizeImage" 241 242 timer.endSplit("processed results"); 243 244 return recognitions; 245 } 246 247 @Override enableStatLogging(final boolean logStats)248 public void enableStatLogging(final boolean logStats) { 249 this.logStats = logStats; 250 } 251 252 @Override getStatString()253 public String getStatString() { 254 return inferenceInterface.getStatString(); 255 } 256 257 @Override close()258 public void close() { 259 inferenceInterface.close(); 260 } 261 } 262