/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.task.vision.segmenter;

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.task.core.BaseTaskApi;
import org.tensorflow.lite.task.core.TaskJniUtils;
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;

/**
 * Performs segmentation on images.
 *
 * <p>The API expects a TFLite model with <a
 * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
 *
 * <p>The API supports models with one image input tensor and one output tensor. To be more
 * specific, here are the requirements.
 *
 * <ul>
 *   <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
 *       <ul>
 *         <li>image input of size {@code [batch x height x width x channels]}.
 *         <li>batch inference is not supported ({@code batch} is required to be 1).
 *         <li>only RGB inputs are supported ({@code channels} is required to be 3).
 *         <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
 *             to the metadata for input normalization.
 *       </ul>
 *   <li>Output image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
 *       <ul>
 *         <li>tensor of size {@code [batch x mask_height x mask_width x num_classes]}, where {@code
 *             batch} is required to be 1, {@code mask_width} and {@code mask_height} are the
 *             dimensions of the segmentation masks produced by the model, and {@code num_classes}
 *             is the number of classes supported by the model.
 *         <li>optional (but recommended) label map(s) can be attached as AssociatedFile-s with type
 *             TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
 *             any) is used to fill the class name, i.e. {@link ColoredLabel#getClassName} of the
 *             results. The display name, i.e. {@link ColoredLabel#getDisplayName}, is filled from
 *             the AssociatedFile (if any) whose locale matches the `display_names_locale` field of
 *             the `ImageSegmenterOptions` used at creation time ("en" by default, i.e. English). If
 *             none of these are available, only the `index` field of the results will be filled.
 *       </ul>
 * </ul>
 *
 * <p>An example of such model can be found on <a
 * href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub</a>.
 */
public final class ImageSegmenter extends BaseTaskApi {

  private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
  // Sentinel values telling the native layer that the file descriptor covers the whole file
  // (no explicit offset/length is supplied).
  private static final int OPTIONAL_FD_LENGTH = -1;
  private static final int OPTIONAL_FD_OFFSET = -1;

  // Determines how the native output masks are decoded (CATEGORY_MASK vs CONFIDENCE_MASK).
  private final OutputType outputType;

  /**
   * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
   *
   * @param modelPath path of the segmentation model with metadata in the assets
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFile(Context context, String modelPath)
      throws IOException {
    return createFromFileAndOptions(context, modelPath, ImageSegmenterOptions.builder().build());
  }

  /**
   * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
   *
   * @param modelFile the segmentation model {@link File} instance
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFile(File modelFile) throws IOException {
    return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
  }

  /**
   * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
   * ImageSegmenterOptions}.
   *
   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
   *     segmentation model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
   *     {@link MappedByteBuffer}
   */
  public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
    return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
  }

  /**
   * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
   *
   * @param modelPath path of the segmentation model with metadata in the assets
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFileAndOptions(
      Context context, String modelPath, final ImageSegmenterOptions options) throws IOException {
    // The asset FD points into the APK, so the native layer needs the offset/length of the model
    // within the containing file; the try-with-resources closes the descriptor after JNI init.
    try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
      return createFromModelFdAndOptions(
          /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
          /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
          /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
          options);
    }
  }

  /**
   * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
   *
   * @param modelFile the segmentation model {@link File} instance
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFileAndOptions(
      File modelFile, final ImageSegmenterOptions options) throws IOException {
    try (ParcelFileDescriptor descriptor =
        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
      // A standalone file is read from the beginning; pass the sentinels so the native layer
      // maps the whole file.
      return createFromModelFdAndOptions(
          /*fileDescriptor=*/ descriptor.getFd(),
          /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
          /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
          options);
    }
  }

  /**
   * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
   * ImageSegmenterOptions}.
   *
   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
   *     segmentation model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
   *     {@link MappedByteBuffer}
   */
  public static ImageSegmenter createFromBufferAndOptions(
      final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
    // The native layer reads the buffer without copying, so it must be direct or memory-mapped.
    if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
      throw new IllegalArgumentException(
          "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
    }
    return new ImageSegmenter(
        TaskJniUtils.createHandleFromLibrary(
            new EmptyHandleProvider() {
              @Override
              public long createHandle() {
                return initJniWithByteBuffer(
                    modelBuffer,
                    options.getDisplayNamesLocale(),
                    options.getOutputType().getValue(),
                    options.getNumThreads());
              }
            },
            IMAGE_SEGMENTER_NATIVE_LIB),
        options.getOutputType());
  }

  /**
   * Constructor to initialize the JNI with a pointer from C++.
   *
   * @param nativeHandle a pointer referencing memory allocated in C++
   */
  private ImageSegmenter(long nativeHandle, OutputType outputType) {
    super(nativeHandle);
    this.outputType = outputType;
  }

  /** Options for setting up an {@link ImageSegmenter}. */
  @AutoValue
  public abstract static class ImageSegmenterOptions {
    private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
    private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
    // -1 lets the TFLite runtime choose the number of threads.
    private static final int NUM_THREADS = -1;

    public abstract String getDisplayNamesLocale();

    public abstract OutputType getOutputType();

    public abstract int getNumThreads();

    public static Builder builder() {
      return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
          .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
          .setOutputType(DEFAULT_OUTPUT_TYPE)
          .setNumThreads(NUM_THREADS);
    }

    /** Builder for {@link ImageSegmenterOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {

      /**
       * Sets the locale to use for display names specified through the TFLite Model Metadata, if
       * any.
       *
       * <p>Defaults to English({@code "en"}). See the <a
       * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
       * Metadata schema file</a> for the accepted pattern of locale.
       */
      public abstract Builder setDisplayNamesLocale(String displayNamesLocale);

      public abstract Builder setOutputType(OutputType outputType);

      /**
       * Sets the number of threads to be used for TFLite ops that support multi-threading when
       * running inference with CPU. Defaults to -1.
       *
       * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
       * effect to let TFLite runtime set the value.
       */
      public abstract Builder setNumThreads(int numThreads);

      public abstract ImageSegmenterOptions build();
    }
  }

  /**
   * Performs actual segmentation on the provided image.
   *
   * @param image a {@link TensorImage} object that represents an RGB image
   * @return results of performing image segmentation. Note that at this time, a single {@link
   *     Segmentation} element is expected to be returned. The result is stored in a {@link List}
   *     for later extension to e.g. instance segmentation models, which may return one
   *     segmentation per object.
   * @throws AssertionError if error occurs when segmenting the image from the native code
   */
  public List<Segmentation> segment(TensorImage image) {
    return segment(image, ImageProcessingOptions.builder().build());
  }

  /**
   * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
   *
   * @param image a {@link TensorImage} object that represents an RGB image
   * @param options {@link ImageSegmenter} only supports image rotation (through {@link
   *     ImageProcessingOptions.Builder#setOrientation}) currently. The orientation of an image
   *     defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
   * @return results of performing image segmentation. Note that at this time, a single {@link
   *     Segmentation} element is expected to be returned. The result is stored in a {@link List}
   *     for later extension to e.g. instance segmentation models, which may return one
   *     segmentation per object.
   * @throws AssertionError if error occurs when segmenting the image from the native code
   */
  public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
    checkNotClosed();

    // image_segmenter_jni.cc expects an uint8 image. Convert image of other types into uint8.
    TensorImage imageUint8 =
        image.getDataType() == DataType.UINT8
            ? image
            : TensorImage.createFrom(image, DataType.UINT8);
    // Output containers populated by the native layer.
    List<byte[]> maskByteArrays = new ArrayList<>();
    List<ColoredLabel> coloredLabels = new ArrayList<>();
    int[] maskShape = new int[2];
    segmentNative(
        getNativeHandle(),
        imageUint8.getBuffer(),
        imageUint8.getWidth(),
        imageUint8.getHeight(),
        maskByteArrays,
        maskShape,
        coloredLabels,
        options.getOrientation().getValue());

    List<ByteBuffer> maskByteBuffers = new ArrayList<>();
    for (byte[] bytes : maskByteArrays) {
      ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
      // Change the byte order to little_endian, since the buffers were generated in jni.
      byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
      maskByteBuffers.add(byteBuffer);
    }

    return Arrays.asList(
        Segmentation.create(
            outputType,
            outputType.createMasksFromBuffer(maskByteBuffers, maskShape),
            coloredLabels));
  }

  // Shared path for both file-based factories: loads the native library, creates the native
  // segmenter from the descriptor, and wraps the handle.
  private static ImageSegmenter createFromModelFdAndOptions(
      final int fileDescriptor,
      final long fileDescriptorLength,
      final long fileDescriptorOffset,
      final ImageSegmenterOptions options) {
    long nativeHandle =
        TaskJniUtils.createHandleFromLibrary(
            new EmptyHandleProvider() {
              @Override
              public long createHandle() {
                return initJniWithModelFdAndOptions(
                    fileDescriptor,
                    fileDescriptorLength,
                    fileDescriptorOffset,
                    options.getDisplayNamesLocale(),
                    options.getOutputType().getValue(),
                    options.getNumThreads());
              }
            },
            IMAGE_SEGMENTER_NATIVE_LIB);
    return new ImageSegmenter(nativeHandle, options.getOutputType());
  }

  private static native long initJniWithModelFdAndOptions(
      int fileDescriptor,
      long fileDescriptorLength,
      long fileDescriptorOffset,
      String displayNamesLocale,
      int outputType,
      int numThreads);

  private static native long initJniWithByteBuffer(
      ByteBuffer modelBuffer, String displayNamesLocale, int outputType, int numThreads);

  /**
   * The native method to segment the image.
   *
   * <p>{@code maskByteArrays}, {@code maskShape}, {@code coloredLabels} will be updated in the
   * native layer.
   */
  private static native void segmentNative(
      long nativeHandle,
      ByteBuffer image,
      int width,
      int height,
      List<byte[]> maskByteArrays,
      int[] maskShape,
      List<ColoredLabel> coloredLabels,
      int orientation);

  @Override
  protected void deinit(long nativeHandle) {
    deinitJni(nativeHandle);
  }

  /**
   * Native implementation to release memory pointed by the pointer.
   *
   * @param nativeHandle pointer to memory allocated
   */
  private native void deinitJni(long nativeHandle);
}