{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g_nWetWWd_ns"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "2pHVBk_seED1"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M7vSdG6sAIQn"
      },
      "source": [
        "# Artistic Style Transfer with TensorFlow Lite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fwc5GKHBASdc"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/style_transfer/overview\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/style_transfer/overview.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/style_transfer/overview.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/style_transfer/overview.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "31O0iaROAw8z"
      },
      "source": [
        "One of the most exciting developments in deep learning to come out recently is [artistic style transfer](https://arxiv.org/abs/1508.06576), or the ability to create a new image, known as a [pastiche](https://en.wikipedia.org/wiki/Pastiche), based on two input images: one representing the artistic style and one representing the content.\n",
        "\n",
        "![Style transfer example](https://storage.googleapis.com/download.tensorflow.org/models/tflite/arbitrary_style_transfer/formula.png)\n",
        "\n",
        "Using this technique, we can generate beautiful new artworks in a range of styles.\n",
        "\n",
        "![Style transfer example](https://storage.googleapis.com/download.tensorflow.org/models/tflite/arbitrary_style_transfer/table.png)\n",
        "\n",
        "If you are new to TensorFlow Lite and are working with Android, we\n",
        "recommend exploring the following example applications that can help you get\n",
        "started.\n",
        "\n",
        "\u003ca class=\"button button-primary\" href=\"https://github.com/tensorflow/examples/tree/master/lite/examples/style_transfer/android\"\u003eAndroid\n",
        "example\u003c/a\u003e \u003ca class=\"button button-primary\" href=\"https://github.com/tensorflow/examples/tree/master/lite/examples/style_transfer/ios\"\u003eiOS\n",
        "example\u003c/a\u003e\n",
        "\n",
        "If you are using a platform other than Android or iOS, or you are already\n",
        "familiar with the\n",
        "\u003ca href=\"https://www.tensorflow.org/api_docs/python/tf/lite\"\u003eTensorFlow Lite\n",
        "APIs\u003c/a\u003e, you can follow this tutorial to learn how to apply style transfer on any pair of content and style image with a pre-trained TensorFlow Lite model. You can use the model to add style transfer to your own mobile applications.\n",
        "\n",
        "The model is open-sourced on [GitHub](https://github.com/tensorflow/magenta/tree/master/magenta/models/arbitrary_image_stylization#train-a-model-on-a-large-dataset-with-data-augmentation-to-run-on-mobile). You can retrain the model with different parameters (e.g. increase content layers' weights to make the output image look more like the content image)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ak0S4gkOCSxs"
      },
      "source": [
        "## Understand the model architecture"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oee6G_bBCgAM"
      },
      "source": [
        "![Model Architecture](https://storage.googleapis.com/download.tensorflow.org/models/tflite/arbitrary_style_transfer/architecture.png)\n",
        "\n",
        "This Artistic Style Transfer model consists of two submodels:\n",
        "1. **Style Prediction Model**: A MobileNetV2-based neural network that maps an input style image to a 100-dimension style bottleneck vector.\n",
        "1. **Style Transform Model**: A neural network that applies a style bottleneck vector to a content image and creates a stylized image.\n",
        "\n",
        "If your app only needs to support a fixed set of style images, you can compute their style bottleneck vectors in advance, and exclude the Style Prediction Model from your app's binary."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a7ZETsRVNMo7"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3n8oObKZN4c8"
      },
      "source": [
        "Import dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xz62Lb1oNm97"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1Ua5FpcJNrIj"
      },
      "outputs": [],
      "source": [
        "import IPython.display as display\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib as mpl\n",
        "mpl.rcParams['figure.figsize'] = (12,12)\n",
        "mpl.rcParams['axes.grid'] = False\n",
        "\n",
        "import numpy as np\n",
        "import time\n",
        "import functools"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1b988wrrQnVF"
      },
      "source": [
        "Download the content and style images, and the pre-trained TensorFlow Lite models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "16g57cIMQnen"
      },
      "outputs": [],
      "source": [
        "content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')\n",
        "style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')\n",
        "\n",
        "style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')\n",
        "style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')"
      ]
    },
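    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kS3EYubVQrR1"
      },
      "source": [
        "As a quick sanity check, you can load the two models with the TensorFlow Lite interpreter and print the input shapes they expect. The cell below is a minimal sketch of that check; it should confirm the (1, 256, 256, 3) and (1, 384, 384, 3) input sizes used in the next section."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kS3EYubVQrR2"
      },
      "outputs": [],
      "source": [
        "# Print the input tensor shape and dtype each downloaded model expects.\n",
        "for name, path in [('style prediction', style_predict_path),\n",
        "                   ('style transform', style_transform_path)]:\n",
        "  interpreter = tf.lite.Interpreter(model_path=path)\n",
        "  interpreter.allocate_tensors()\n",
        "  for detail in interpreter.get_input_details():\n",
        "    print(name, ':', detail['shape'], detail['dtype'])"
      ]
    },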
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MQZXL7kON-gM"
      },
      "source": [
        "## Pre-process the inputs\n",
        "\n",
        "* The content image and the style image must be RGB images with float32 pixel values in the range [0, 1].\n",
        "* The style image size must be (1, 256, 256, 3). We centrally crop and resize the image.\n",
        "* The content image size must be (1, 384, 384, 3). We centrally crop and resize the image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cg0Vi-rXRUFl"
      },
      "outputs": [],
      "source": [
        "# Function to load an image from a file, and add a batch dimension.\n",
        "def load_img(path_to_img):\n",
        "  img = tf.io.read_file(path_to_img)\n",
        "  img = tf.io.decode_image(img, channels=3)\n",
        "  img = tf.image.convert_image_dtype(img, tf.float32)\n",
        "  img = img[tf.newaxis, :]\n",
        "\n",
        "  return img\n",
        "\n",
        "# Function to pre-process an image by resizing and centrally cropping it.\n",
        "def preprocess_image(image, target_dim):\n",
        "  # Resize the image so that the shorter dimension becomes target_dim pixels.\n",
        "  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)\n",
        "  short_dim = min(shape)\n",
        "  scale = target_dim / short_dim\n",
        "  new_shape = tf.cast(shape * scale, tf.int32)\n",
        "  image = tf.image.resize(image, new_shape)\n",
        "\n",
        "  # Central crop the image.\n",
        "  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)\n",
        "\n",
        "  return image\n",
        "\n",
        "# Load the input images.\n",
        "content_image = load_img(content_path)\n",
        "style_image = load_img(style_path)\n",
        "\n",
        "# Preprocess the input images.\n",
        "preprocessed_content_image = preprocess_image(content_image, 384)\n",
        "preprocessed_style_image = preprocess_image(style_image, 256)\n",
        "\n",
        "print('Style Image Shape:', preprocessed_style_image.shape)\n",
        "print('Content Image Shape:', preprocessed_content_image.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xE4Yt8nArTeR"
      },
      "source": [
        "## Visualize the inputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ncPA4esJRcEu"
      },
      "outputs": [],
      "source": [
        "def imshow(image, title=None):\n",
        "  if len(image.shape) \u003e 3:\n",
        "    image = tf.squeeze(image, axis=0)\n",
        "\n",
        "  plt.imshow(image)\n",
        "  if title:\n",
        "    plt.title(title)\n",
        "\n",
        "plt.subplot(1, 2, 1)\n",
        "imshow(preprocessed_content_image, 'Content Image')\n",
        "\n",
        "plt.subplot(1, 2, 2)\n",
        "imshow(preprocessed_style_image, 'Style Image')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CJ7R-CHbjC3s"
      },
      "source": [
        "## Run style transfer with TensorFlow Lite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "euu00ldHjKwD"
      },
      "source": [
        "### Style prediction"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o3zd9cTFRiS_"
      },
      "outputs": [],
      "source": [
        "# Function to run style prediction on preprocessed style image.\n",
        "def run_style_predict(preprocessed_style_image):\n",
        "  # Load the model.\n",
        "  interpreter = tf.lite.Interpreter(model_path=style_predict_path)\n",
        "\n",
        "  # Set model input.\n",
        "  interpreter.allocate_tensors()\n",
        "  input_details = interpreter.get_input_details()\n",
        "  interpreter.set_tensor(input_details[0][\"index\"], preprocessed_style_image)\n",
        "\n",
        "  # Calculate style bottleneck.\n",
        "  interpreter.invoke()\n",
        "  style_bottleneck = interpreter.tensor(\n",
        "      interpreter.get_output_details()[0][\"index\"]\n",
        "      )()\n",
        "\n",
        "  return style_bottleneck\n",
        "\n",
        "# Calculate style bottleneck for the preprocessed style image.\n",
        "style_bottleneck = run_style_predict(preprocessed_style_image)\n",
        "print('Style Bottleneck Shape:', style_bottleneck.shape)"
      ]
    },
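    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fXbQh1JePrS1"
      },
      "source": [
        "As noted in the model architecture section, an app that only supports a fixed set of styles can precompute their style bottleneck vectors and ship without the Style Prediction Model. The cell below is a minimal sketch of that workflow; the single-entry style list and the `style_bottlenecks.npz` file name are illustrative, not part of the released example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fXbQh1JePrS2"
      },
      "outputs": [],
      "source": [
        "# A minimal sketch: precompute style bottlenecks for a fixed set of\n",
        "# styles so an app can ship only the Style Transform Model.\n",
        "# The style list and output file name below are illustrative only.\n",
        "style_paths = {'style23': style_path}\n",
        "\n",
        "style_bottlenecks = {\n",
        "    name: run_style_predict(preprocess_image(load_img(path), 256))\n",
        "    for name, path in style_paths.items()\n",
        "}\n",
        "\n",
        "# Bundle this file with the app instead of the prediction model.\n",
        "np.savez('style_bottlenecks.npz', **style_bottlenecks)\n",
        "print({name: vector.shape for name, vector in style_bottlenecks.items()})"
      ]
    },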
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "00t8S2PekIyW"
      },
      "source": [
        "### Style transform"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cZp5bCj8SX1w"
      },
      "outputs": [],
      "source": [
        "# Function to run style transform on a preprocessed content image.\n",
        "def run_style_transform(style_bottleneck, preprocessed_content_image):\n",
        "  # Load the model.\n",
        "  interpreter = tf.lite.Interpreter(model_path=style_transform_path)\n",
        "\n",
        "  # Gather input details and allocate tensors.\n",
        "  input_details = interpreter.get_input_details()\n",
        "  interpreter.allocate_tensors()\n",
        "\n",
        "  # Set model inputs.\n",
        "  interpreter.set_tensor(input_details[0][\"index\"], preprocessed_content_image)\n",
        "  interpreter.set_tensor(input_details[1][\"index\"], style_bottleneck)\n",
        "  interpreter.invoke()\n",
        "\n",
        "  # Transform content image.\n",
        "  stylized_image = interpreter.tensor(\n",
        "      interpreter.get_output_details()[0][\"index\"]\n",
        "      )()\n",
        "\n",
        "  return stylized_image\n",
        "\n",
        "# Stylize the content image using the style bottleneck.\n",
        "stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)\n",
        "\n",
        "# Visualize the output.\n",
        "imshow(stylized_image, 'Stylized Image')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vv_71Td-QtrW"
      },
      "source": [
        "### Style blending\n",
        "\n",
        "We can blend the style of the content image into the stylized output, which in turn makes the output look more like the content image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eJcAURXQQtJ7"
      },
      "outputs": [],
      "source": [
        "# Calculate style bottleneck of the content image.\n",
        "style_bottleneck_content = run_style_predict(\n",
        "    preprocess_image(content_image, 256)\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4S3yg2MgkmRD"
      },
      "outputs": [],
      "source": [
        "# Define the content blending ratio in the range [0, 1].\n",
        "# 0.0: 0% of the style is extracted from the content image.\n",
        "# 1.0: 100% of the style is extracted from the content image.\n",
        "content_blending_ratio = 0.5 #@param {type:\"slider\", min:0, max:1, step:0.01}\n",
        "\n",
        "# Blend the style bottlenecks of the style image and the content image.\n",
        "style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \\\n",
        "                           + (1 - content_blending_ratio) * style_bottleneck\n",
        "\n",
        "# Stylize the content image using the blended style bottleneck.\n",
        "stylized_image_blended = run_style_transform(style_bottleneck_blended,\n",
        "                                             preprocessed_content_image)\n",
        "\n",
        "# Visualize the output.\n",
        "imshow(stylized_image_blended, 'Blended Stylized Image')"
      ]
    },
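    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hGmVs7sOkpT1"
      },
      "source": [
        "To see how the blending ratio affects the result, you can sweep it over several values. The loop below is a small illustrative addition: it applies the same blending formula as above and re-runs the transform model once per ratio, so it takes several times longer than a single stylization."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hGmVs7sOkpT2"
      },
      "outputs": [],
      "source": [
        "# Sweep the content blending ratio; higher values pull the output\n",
        "# toward the style extracted from the content image itself.\n",
        "for i, ratio in enumerate([0.0, 0.25, 0.5, 0.75, 1.0]):\n",
        "  blended = ratio * style_bottleneck_content \\\n",
        "            + (1 - ratio) * style_bottleneck\n",
        "  plt.subplot(1, 5, i + 1)\n",
        "  imshow(run_style_transform(blended, preprocessed_content_image),\n",
        "         'ratio=%.2f' % ratio)"
      ]
    },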
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9k9jGIep8p1c"
      },
      "source": [
        "## Performance Benchmarks\n",
        "\n",
        "Performance benchmark numbers are generated with the tool [described here](https://www.tensorflow.org/lite/performance/benchmarks).\n",
        "\u003ctable\u003e\u003cthead\u003e\u003ctr\u003e\u003cth\u003eModel name\u003c/th\u003e \u003cth\u003eModel size\u003c/th\u003e \u003cth\u003eDevice\u003c/th\u003e \u003cth\u003eNNAPI\u003c/th\u003e \u003cth\u003eCPU\u003c/th\u003e \u003cth\u003eGPU\u003c/th\u003e\u003c/tr\u003e\u003c/thead\u003e\n",
        "\u003ctr\u003e \u003ctd rowspan=3\u003e \u003ca href=\"https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite\"\u003eStyle prediction model (int8)\u003c/a\u003e \u003c/td\u003e\n",
        "\u003ctd rowspan=3\u003e2.8 MB\u003c/td\u003e\n",
        "\u003ctd\u003ePixel 3 (Android 10)\u003c/td\u003e \u003ctd\u003e142ms\u003c/td\u003e\u003ctd\u003e14ms*\u003c/td\u003e\u003ctd\u003e\u003c/td\u003e\u003c/tr\u003e\n",
        "\u003ctr\u003e\u003ctd\u003ePixel 4 (Android 10)\u003c/td\u003e \u003ctd\u003e5.2ms\u003c/td\u003e\u003ctd\u003e6.7ms*\u003c/td\u003e\u003ctd\u003e\u003c/td\u003e\u003c/tr\u003e\n",
        "\u003ctr\u003e\u003ctd\u003eiPhone XS (iOS 12.4.1)\u003c/td\u003e \u003ctd\u003e\u003c/td\u003e\u003ctd\u003e10.7ms**\u003c/td\u003e\u003ctd\u003e\u003c/td\u003e\u003c/tr\u003e\n",
        "\u003ctr\u003e \u003ctd rowspan=3\u003e \u003ca href=\"https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite\"\u003eStyle transform model (int8)\u003c/a\u003e \u003c/td\u003e\n",
        "\u003ctd rowspan=3\u003e0.2 MB\u003c/td\u003e\n",
        "\u003ctd\u003ePixel 3 (Android 10)\u003c/td\u003e \u003ctd\u003e\u003c/td\u003e\u003ctd\u003e540ms*\u003c/td\u003e\u003ctd\u003e\u003c/td\u003e\u003c/tr\u003e\n",
        "\u003ctr\u003e\u003ctd\u003ePixel 4 (Android 10)\u003c/td\u003e \u003ctd\u003e\u003c/td\u003e\u003ctd\u003e405ms*\u003c/td\u003e\u003ctd\u003e\u003c/td\u003e\u003c/tr\u003e\n",
        "\u003ctr\u003e\u003ctd\u003eiPhone XS (iOS 12.4.1)\u003c/td\u003e \u003ctd\u003e\u003c/td\u003e\u003ctd\u003e251ms**\u003c/td\u003e\u003ctd\u003e\u003c/td\u003e\u003c/tr\u003e\n",
        "\n",
        "\u003ctr\u003e \u003ctd rowspan=2\u003e \u003ca href=\"https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/fp16/prediction/1?lite-format=tflite\"\u003eStyle prediction model (float16)\u003c/a\u003e \u003c/td\u003e\n",
        "\u003ctd rowspan=2\u003e4.7 MB\u003c/td\u003e\n",
        "\u003ctd\u003ePixel 3 (Android 10)\u003c/td\u003e \u003ctd\u003e86ms\u003c/td\u003e\u003ctd\u003e28ms*\u003c/td\u003e\u003ctd\u003e9.1ms\u003c/td\u003e\u003c/tr\u003e\n",
        "\u003ctr\u003e\u003ctd\u003ePixel 4 (Android 10)\u003c/td\u003e\u003ctd\u003e32ms\u003c/td\u003e\u003ctd\u003e12ms*\u003c/td\u003e\u003ctd\u003e10ms\u003c/td\u003e\u003c/tr\u003e\n",
        "\n",
        "\u003ctr\u003e \u003ctd rowspan=2\u003e \u003ca href=\"https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/fp16/transfer/1?lite-format=tflite\"\u003eStyle transfer model (float16)\u003c/a\u003e \u003c/td\u003e\n",
        "\u003ctd rowspan=2\u003e0.4 MB\u003c/td\u003e\n",
        "\u003ctd\u003ePixel 3 (Android 10)\u003c/td\u003e \u003ctd\u003e1095ms\u003c/td\u003e\u003ctd\u003e545ms*\u003c/td\u003e\u003ctd\u003e42ms\u003c/td\u003e\u003c/tr\u003e\n",
        "\u003ctr\u003e\u003ctd\u003ePixel 4 (Android 10)\u003c/td\u003e\u003ctd\u003e603ms\u003c/td\u003e\u003ctd\u003e377ms*\u003c/td\u003e\u003ctd\u003e42ms\u003c/td\u003e\u003c/tr\u003e\n",
        "\n",
        "\u003c/table\u003e\n",
        "\n",
        "*\u0026ast; 4 threads used.* \u003cbr/\u003e\n",
        "*\u0026ast;\u0026ast; 2 threads on iPhone for the best performance.*\n"
      ]
    },
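    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q2beXcLBn9A1"
      },
      "source": [
        "The CPU numbers above were measured with multiple threads. In the Python API, you can control the thread count through the interpreter's `num_threads` argument. The timing sketch below reuses the variables defined earlier in this notebook; the latency it prints depends on your machine and is not comparable to the mobile benchmarks above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q2beXcLBn9A2"
      },
      "outputs": [],
      "source": [
        "# A minimal sketch of multi-threaded CPU inference, mirroring the\n",
        "# 4-thread setting noted in the benchmark footnote above.\n",
        "interpreter = tf.lite.Interpreter(model_path=style_transform_path,\n",
        "                                  num_threads=4)\n",
        "interpreter.allocate_tensors()\n",
        "input_details = interpreter.get_input_details()\n",
        "interpreter.set_tensor(input_details[0]['index'], preprocessed_content_image)\n",
        "interpreter.set_tensor(input_details[1]['index'], style_bottleneck)\n",
        "\n",
        "start = time.time()\n",
        "interpreter.invoke()\n",
        "print('Style transform latency: %.1f ms' % ((time.time() - start) * 1000))"
      ]
    }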
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "overview.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}