/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.demo;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.Trace;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.demo.env.SplitTimer;

/** An object detector that uses TF and a YOLO model to detect objects. */
public class TensorFlowYoloDetector implements Classifier {
  private static final Logger LOGGER = new Logger();

  // Return at most this many results.
  private static final int MAX_RESULTS = 5;

  private static final int NUM_CLASSES = 20;

  private static final int NUM_BOXES_PER_BLOCK = 5;

  // TODO(andrewharp): allow loading anchors and classes
  // from files.
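  // Width/height anchor priors in grid-cell units: one (w, h) pair for each of the
  // NUM_BOXES_PER_BLOCK boxes predicted per grid cell (see the decode step below).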
  private static final double[] ANCHORS = {
    1.08, 1.19,
    3.42, 4.41,
    6.63, 11.38,
    9.42, 5.11,
    16.62, 10.52
  };

  private static final String[] LABELS = {
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor"
  };

  // Config values.
  private String inputName;
  private int inputSize;

  // Pre-allocated buffers.
  private int[] intValues;
  private float[] floatValues;
  private String[] outputNames;

  private int blockSize;

  private boolean logStats = false;

  private TensorFlowInferenceInterface inferenceInterface;

  /** Initializes a native TensorFlow session for classifying images. */
  public static Classifier create(
      final AssetManager assetManager,
      final String modelFilename,
      final int inputSize,
      final String inputName,
      final String outputName,
      final int blockSize) {
    TensorFlowYoloDetector d = new TensorFlowYoloDetector();
    d.inputName = inputName;
    d.inputSize = inputSize;

    // Pre-allocate buffers.
    d.outputNames = outputName.split(",");
    d.intValues = new int[inputSize * inputSize];
    d.floatValues = new float[inputSize * inputSize * 3];
    d.blockSize = blockSize;

    d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

    return d;
  }

  private TensorFlowYoloDetector() {}

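  /** Logistic sigmoid; squashes box-center offsets and objectness logits into (0, 1). */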
  private float expit(final float x) {
    return (float) (1. / (1. + Math.exp(-x)));
  }

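  /** Numerically stable in-place softmax: subtracts the max logit before exponentiating. */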
  private void softmax(final float[] vals) {
    float max = Float.NEGATIVE_INFINITY;
    for (final float val : vals) {
      max = Math.max(max, val);
    }
    float sum = 0.0f;
    for (int i = 0; i < vals.length; ++i) {
      vals[i] = (float) Math.exp(vals[i] - max);
      sum += vals[i];
    }
    for (int i = 0; i < vals.length; ++i) {
      vals[i] = vals[i] / sum;
    }
  }

  @Override
  public List<Recognition> recognizeImage(final Bitmap bitmap) {
    final SplitTimer timer = new SplitTimer("recognizeImage");

    // Log this method so that it can be analyzed with systrace.
    Trace.beginSection("recognizeImage");

    Trace.beginSection("preprocessBitmap");
    // Preprocess the image data from 0-255 int to normalized float based
    // on the provided parameters.
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

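    // intValues holds packed ARGB pixels; unpack the R, G and B channels and scale each to [0, 1].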
    for (int i = 0; i < intValues.length; ++i) {
      floatValues[i * 3 + 0] = ((intValues[i] >> 16) & 0xFF) / 255.0f;
      floatValues[i * 3 + 1] = ((intValues[i] >> 8) & 0xFF) / 255.0f;
      floatValues[i * 3 + 2] = (intValues[i] & 0xFF) / 255.0f;
    }
    Trace.endSection(); // preprocessBitmap

    // Copy the input data into TensorFlow.
    Trace.beginSection("feed");
    inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
    Trace.endSection();

    timer.endSplit("ready for inference");

    // Run the inference call.
    Trace.beginSection("run");
    inferenceInterface.run(outputNames, logStats);
    Trace.endSection();

    timer.endSplit("ran inference");

    // Copy the output Tensor back into the output array.
    Trace.beginSection("fetch");
    final int gridWidth = bitmap.getWidth() / blockSize;
    final int gridHeight = bitmap.getHeight() / blockSize;
    final float[] output =
        new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK];
    inferenceInterface.fetch(outputNames[0], output);
    Trace.endSection();

    // Find the best detections.
    final PriorityQueue<Recognition> pq =
        new PriorityQueue<Recognition>(
            1,
            new Comparator<Recognition>() {
              @Override
              public int compare(final Recognition lhs, final Recognition rhs) {
                // Intentionally reversed to put high confidence at the head of the queue.
                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
              }
            });

    for (int y = 0; y < gridHeight; ++y) {
      for (int x = 0; x < gridWidth; ++x) {
        for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) {
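          // Output layout per cell: NUM_BOXES_PER_BLOCK boxes, each with 4 box parameters,
          // 1 objectness logit, and NUM_CLASSES class logits (NUM_CLASSES + 5 values per box).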
          final int offset =
              (gridWidth * (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5))) * y
                  + (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5)) * x
                  + (NUM_CLASSES + 5) * b;

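          // Decode the box: the center is the cell index plus a sigmoid offset, and the size is
          // the anchor prior scaled by exp() of the prediction, all converted to pixel units.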
          final float xPos = (x + expit(output[offset + 0])) * blockSize;
          final float yPos = (y + expit(output[offset + 1])) * blockSize;

          final float w = (float) (Math.exp(output[offset + 2]) * ANCHORS[2 * b + 0]) * blockSize;
          final float h = (float) (Math.exp(output[offset + 3]) * ANCHORS[2 * b + 1]) * blockSize;

          final RectF rect =
              new RectF(
                  Math.max(0, xPos - w / 2),
                  Math.max(0, yPos - h / 2),
                  Math.min(bitmap.getWidth() - 1, xPos + w / 2),
                  Math.min(bitmap.getHeight() - 1, yPos + h / 2));
          final float confidence = expit(output[offset + 4]);

          int detectedClass = -1;
          float maxClass = 0;

          final float[] classes = new float[NUM_CLASSES];
          for (int c = 0; c < NUM_CLASSES; ++c) {
            classes[c] = output[offset + 5 + c];
          }
          softmax(classes);

          for (int c = 0; c < NUM_CLASSES; ++c) {
            if (classes[c] > maxClass) {
              detectedClass = c;
              maxClass = classes[c];
            }
          }

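          // The final score is the objectness confidence times the best class probability;
          // only candidates above a small threshold are kept.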
          final float confidenceInClass = maxClass * confidence;
          if (confidenceInClass > 0.01) {
            LOGGER.i(
                "%s (%d) %f %s", LABELS[detectedClass], detectedClass, confidenceInClass, rect);
            pq.add(new Recognition("" + offset, LABELS[detectedClass], confidenceInClass, rect));
          }
        }
      }
    }
    timer.endSplit("decoded results");

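    // Return only the top-scoring detections from the queue; no non-max suppression is applied.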
    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
    // Capture the count up front, since pq.size() shrinks as results are polled.
    final int numResults = Math.min(pq.size(), MAX_RESULTS);
    for (int i = 0; i < numResults; ++i) {
      recognitions.add(pq.poll());
    }
    Trace.endSection(); // "recognizeImage"

    timer.endSplit("processed results");

    return recognitions;
  }

  @Override
  public void enableStatLogging(final boolean logStats) {
    this.logStats = logStats;
  }

  @Override
  public String getStatString() {
    return inferenceInterface.getStatString();
  }

  @Override
  public void close() {
    inferenceInterface.close();
  }
}