/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.task.vision.segmenter;

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.task.core.BaseTaskApi;
import org.tensorflow.lite.task.core.TaskJniUtils;
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;

37 /**
38  * Performs segmentation on images.
39  *
40  * <p>The API expects a TFLite model with <a
41  * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
42  *
43  * <p>The API supports models with one image input tensor and one output tensor. To be more
44  * specific, here are the requirements.
45  *
46  * <ul>
47  *   <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
48  *       <ul>
49  *         <li>image input of size {@code [batch x height x width x channels]}.
50  *         <li>batch inference is not supported ({@code batch} is required to be 1).
51  *         <li>only RGB inputs are supported ({@code channels} is required to be 3).
52  *         <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
53  *             to the metadata for input normalization.
54  *       </ul>
55  *   <li>Output image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
56  *       <ul>
57  *         <li>tensor of size {@code [batch x mask_height x mask_width x num_classes]}, where {@code
58  *             batch} is required to be 1, {@code mask_width} and {@code mask_height} are the
59  *             dimensions of the segmentation masks produced by the model, and {@code num_classes}
60  *             is the number of classes supported by the model.
61  *         <li>optional (but recommended) label map(s) can be attached as AssociatedFile-s with type
62  *             TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
63  *             any) is used to fill the class name, i.e. {@link ColoredLabel#getClassName} of the
64  *             results. The display name, i.e. {@link ColoredLabel#getDisplayName}, is filled from
65  *             the AssociatedFile (if any) whose locale matches the `display_names_locale` field of
66  *             the `ImageSegmenterOptions` used at creation time ("en" by default, i.e. English). If
67  *             none of these are available, only the `index` field of the results will be filled.
68  *       </ul>
69  * </ul>
70  *
71  * <p>An example of such model can be found on <a
72  * href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub.</a>.
73  */
74 public final class ImageSegmenter extends BaseTaskApi {
75 
76   private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
77   private static final int OPTIONAL_FD_LENGTH = -1;
78   private static final int OPTIONAL_FD_OFFSET = -1;
79 
80   private final OutputType outputType;
81 
82   /**
83    * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
84    *
85    * @param modelPath path of the segmentation model with metadata in the assets
86    * @throws IOException if an I/O error occurs when loading the tflite model
87    * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
88    *     code
89    */
createFromFile(Context context, String modelPath)90   public static ImageSegmenter createFromFile(Context context, String modelPath)
91       throws IOException {
92     return createFromFileAndOptions(context, modelPath, ImageSegmenterOptions.builder().build());
93   }
94 
95   /**
96    * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
97    *
98    * @param modelFile the segmentation model {@link File} instance
99    * @throws IOException if an I/O error occurs when loading the tflite model
100    * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
101    *     code
102    */
createFromFile(File modelFile)103   public static ImageSegmenter createFromFile(File modelFile) throws IOException {
104     return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
105   }
106 
107   /**
108    * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
109    * ImageSegmenterOptions}.
110    *
111    * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
112    *     classification model
113    * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
114    *     code
115    * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
116    *     {@link MappedByteBuffer}
117    */
createFromBuffer(final ByteBuffer modelBuffer)118   public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
119     return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
120   }
121 
122   /**
123    * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
124    *
125    * @param modelPath path of the segmentation model with metadata in the assets
126    * @throws IOException if an I/O error occurs when loading the tflite model
127    * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
128    *     code
129    */
createFromFileAndOptions( Context context, String modelPath, final ImageSegmenterOptions options)130   public static ImageSegmenter createFromFileAndOptions(
131       Context context, String modelPath, final ImageSegmenterOptions options) throws IOException {
132     try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
133       return createFromModelFdAndOptions(
134           /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
135           /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
136           /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
137           options);
138     }
139   }
140 
141   /**
142    * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
143    *
144    * @param modelFile the segmentation model {@link File} instance
145    * @throws IOException if an I/O error occurs when loading the tflite model
146    * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
147    *     code
148    */
createFromFileAndOptions( File modelFile, final ImageSegmenterOptions options)149   public static ImageSegmenter createFromFileAndOptions(
150       File modelFile, final ImageSegmenterOptions options) throws IOException {
151     try (ParcelFileDescriptor descriptor =
152         ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
153       return createFromModelFdAndOptions(
154           /*fileDescriptor=*/ descriptor.getFd(),
155           /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
156           /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
157           options);
158     }
159   }
160 
161   /**
162    * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
163    * ImageSegmenterOptions}.
164    *
165    * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
166    *     classification model
167    * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
168    *     code
169    * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
170    *     {@link MappedByteBuffer}
171    */
createFromBufferAndOptions( final ByteBuffer modelBuffer, final ImageSegmenterOptions options)172   public static ImageSegmenter createFromBufferAndOptions(
173       final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
174     if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
175       throw new IllegalArgumentException(
176           "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
177     }
178     return new ImageSegmenter(
179         TaskJniUtils.createHandleFromLibrary(
180             new EmptyHandleProvider() {
181               @Override
182               public long createHandle() {
183                 return initJniWithByteBuffer(
184                     modelBuffer,
185                     options.getDisplayNamesLocale(),
186                     options.getOutputType().getValue(),
187                     options.getNumThreads());
188               }
189             },
190             IMAGE_SEGMENTER_NATIVE_LIB),
191         options.getOutputType());
192   }
193 
194   /**
195    * Constructor to initialize the JNI with a pointer from C++.
196    *
197    * @param nativeHandle a pointer referencing memory allocated in C++
198    */
199   private ImageSegmenter(long nativeHandle, OutputType outputType) {
200     super(nativeHandle);
201     this.outputType = outputType;
202   }
203 
204   /** Options for setting up an {@link ImageSegmenter}. */
205   @AutoValue
206   public abstract static class ImageSegmenterOptions {
207     private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
208     private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
209     private static final int NUM_THREADS = -1;
210 
211     public abstract String getDisplayNamesLocale();
212 
213     public abstract OutputType getOutputType();
214 
215     public abstract int getNumThreads();
216 
217     public static Builder builder() {
218       return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
219           .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
220           .setOutputType(DEFAULT_OUTPUT_TYPE)
221           .setNumThreads(NUM_THREADS);
222     }
223 
224     /** Builder for {@link ImageSegmenterOptions}. */
225     @AutoValue.Builder
226     public abstract static class Builder {
227 
228       /**
229        * Sets the locale to use for display names specified through the TFLite Model Metadata, if
230        * any.
231        *
232        * <p>Defaults to English({@code "en"}). See the <a
233        * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
234        * Metadata schema file.</a> for the accepted pattern of locale.
235        */
236       public abstract Builder setDisplayNamesLocale(String displayNamesLocale);
237 
238       public abstract Builder setOutputType(OutputType outputType);
239 
240       /**
241        * Sets the number of threads to be used for TFLite ops that support multi-threading when
242        * running inference with CPU. Defaults to -1.
243        *
244        * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
245        * effect to let TFLite runtime set the value.
246        */
247       public abstract Builder setNumThreads(int numThreads);
248 
249       public abstract ImageSegmenterOptions build();
250     }
251   }
252 
253   /**
254    * Performs actual segmentation on the provided image.
255    *
256    * @param image a {@link TensorImage} object that represents an RGB image
257    * @return results of performing image segmentation. Note that at the time, a single {@link
258    *     Segmentation} element is expected to be returned. The result is stored in a {@link List}
259    *     for later extension to e.g. instance segmentation models, which may return one segmentation
260    *     per object.
261    * @throws AssertionError if error occurs when segmenting the image from the native code
262    */
263   public List<Segmentation> segment(TensorImage image) {
264     return segment(image, ImageProcessingOptions.builder().build());
265   }
266 
267   /**
268    * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
269    *
270    * @param image a {@link TensorImage} object that represents an RGB image
271    * @param options {@link ImageSegmenter} only supports image rotation (through {@link
272    *     ImageProcessingOptions#Builder#setOrientation}) currently. The orientation of an image
273    *     defaults to {@link ImageProcessingOptions#Orientation#TOP_LEFT}.
274    * @return results of performing image segmentation. Note that at the time, a single {@link
275    *     Segmentation} element is expected to be returned. The result is stored in a {@link List}
276    *     for later extension to e.g. instance segmentation models, which may return one segmentation
277    *     per object.
278    * @throws AssertionError if error occurs when segmenting the image from the native code
279    */
280   public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
281     checkNotClosed();
282 
283     // image_segmenter_jni.cc expects an uint8 image. Convert image of other types into uint8.
284     TensorImage imageUint8 =
285         image.getDataType() == DataType.UINT8
286             ? image
287             : TensorImage.createFrom(image, DataType.UINT8);
288     List<byte[]> maskByteArrays = new ArrayList<>();
289     List<ColoredLabel> coloredLabels = new ArrayList<>();
290     int[] maskShape = new int[2];
291     segmentNative(
292         getNativeHandle(),
293         imageUint8.getBuffer(),
294         imageUint8.getWidth(),
295         imageUint8.getHeight(),
296         maskByteArrays,
297         maskShape,
298         coloredLabels,
299         options.getOrientation().getValue());
300 
301     List<ByteBuffer> maskByteBuffers = new ArrayList<>();
302     for (byte[] bytes : maskByteArrays) {
303       ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
304       // Change the byte order to little_endian, since the buffers were generated in jni.
305       byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
306       maskByteBuffers.add(byteBuffer);
307     }
308 
309     return Arrays.asList(
310         Segmentation.create(
311             outputType,
312             outputType.createMasksFromBuffer(maskByteBuffers, maskShape),
313             coloredLabels));
314   }
315 
316   private static ImageSegmenter createFromModelFdAndOptions(
317       final int fileDescriptor,
318       final long fileDescriptorLength,
319       final long fileDescriptorOffset,
320       final ImageSegmenterOptions options) {
321     long nativeHandle =
322         TaskJniUtils.createHandleFromLibrary(
323             new EmptyHandleProvider() {
324               @Override
325               public long createHandle() {
326                 return initJniWithModelFdAndOptions(
327                     fileDescriptor,
328                     fileDescriptorLength,
329                     fileDescriptorOffset,
330                     options.getDisplayNamesLocale(),
331                     options.getOutputType().getValue(),
332                     options.getNumThreads());
333               }
334             },
335             IMAGE_SEGMENTER_NATIVE_LIB);
336     return new ImageSegmenter(nativeHandle, options.getOutputType());
337   }
338 
339   private static native long initJniWithModelFdAndOptions(
340       int fileDescriptor,
341       long fileDescriptorLength,
342       long fileDescriptorOffset,
343       String displayNamesLocale,
344       int outputType,
345       int numThreads);
346 
347   private static native long initJniWithByteBuffer(
348       ByteBuffer modelBuffer, String displayNamesLocale, int outputType, int numThreads);
349 
350   /**
351    * The native method to segment the image.
352    *
353    * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the native
354    * layer.
355    */
356   private static native void segmentNative(
357       long nativeHandle,
358       ByteBuffer image,
359       int width,
360       int height,
361       List<byte[]> maskByteArrays,
362       int[] maskShape,
363       List<ColoredLabel> coloredLabels,
364       int orientation);
365 
366   @Override
367   protected void deinit(long nativeHandle) {
368     deinitJni(nativeHandle);
369   }
370 
371   /**
372    * Native implementation to release memory pointed by the pointer.
373    *
374    * @param nativeHandle pointer to memory allocated
375    */
376   private native void deinitJni(long nativeHandle);
377 }
378