1{ 2 "cells": [ 3 { 4 "cell_type": "markdown", 5 "metadata": { 6 "id": "Mq-riZs-TJGt" 7 }, 8 "source": [ 9 "##### Copyright 2021 The TensorFlow Authors." 10 ] 11 }, 12 { 13 "cell_type": "code", 14 "execution_count": null, 15 "metadata": { 16 "id": "LEvnopDoTC4M" 17 }, 18 "outputs": [], 19 "source": [ 20 "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", 21 "# you may not use this file except in compliance with the License.\n", 22 "# You may obtain a copy of the License at\n", 23 "#\n", 24 "# https://www.apache.org/licenses/LICENSE-2.0\n", 25 "#\n", 26 "# Unless required by applicable law or agreed to in writing, software\n", 27 "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", 28 "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 29 "# See the License for the specific language governing permissions and\n", 30 "# limitations under the License." 31 ] 32 }, 33 { 34 "cell_type": "markdown", 35 "metadata": { 36 "id": "QSRG6qmtTRSk" 37 }, 38 "source": [ 39 "# TensorFlow Lite Metadata Writer API\n", 40 "\n" 41 ] 42 }, 43 { 44 "cell_type": "markdown", 45 "metadata": { 46 "id": "JlzjEt4Txr0x" 47 }, 48 "source": [ 49 "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", 50 " \u003ctd\u003e\n", 51 " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/models/convert/metadata_writer_tutorial\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", 52 " \u003c/td\u003e\n", 53 " \u003ctd\u003e\n", 54 " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models/convert/metadata_writer_tutorial.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", 55 " \u003c/td\u003e\n", 56 " \u003ctd\u003e\n", 57 " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models/convert/metadata_writer_tutorial.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", 58 " \u003c/td\u003e\n", 59 " \u003ctd\u003e\n", 60 " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/models/convert/metadata_writer_tutorial.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", 61 " \u003c/td\u003e\n", 62 "\n", 63 "\u003c/table\u003e" 64 ] 65 }, 66 { 67 "cell_type": "markdown", 68 "metadata": { 69 "id": "b0gwEhfRYat6" 70 }, 71 "source": [ 72 "[TensorFlow Lite Model Metadata](https://www.tensorflow.org/lite/models/convert/metadata) is a standard model description format. It contains rich semantics for general model information, inputs/outputs, and associated files, which makes the model self-descriptive and exchangeable.\n", 73 "\n", 74 "Model Metadata is currently used in the following two primary use cases:\n", 75 "1. **Enable easy model inference using TensorFlow Lite [Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview) and [codegen tools](https://www.tensorflow.org/lite/inference_with_metadata/codegen)**. 
        "\n",
        "2. **Enabling model creators to include documentation**, such as descriptions of the model's inputs/outputs or instructions on how to use the model. Model users can view this documentation via visualization tools such as [Netron](https://netron.app/).\n",
        "\n",
        "The TensorFlow Lite Metadata Writer API provides an easy-to-use API to create Model Metadata for popular ML tasks supported by the TFLite Task Library. This notebook shows examples of how to populate the metadata for the following tasks:\n",
        "\n",
        "* [Image classifiers](#image_classifiers)\n",
        "* [Object detectors](#object_detectors)\n",
        "* [Image segmenters](#image_segmenters)\n",
        "* [Natural language classifiers](#nl_classifiers)\n",
        "* [Audio classifiers](#audio_classifiers)\n",
        "\n",
        "Metadata writers for BERT natural language classifiers and BERT question answerers are coming soon.\n",
        "\n",
        "If you want to add metadata for use cases that are not supported, use the [Flatbuffers Python API](https://www.tensorflow.org/lite/models/convert/metadata#adding_metadata); that page also includes a tutorial on adding metadata by hand.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GVRIGdA4T6tO"
      },
      "source": [
        "## Prerequisites"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bVTD2KSyotBK"
      },
      "source": [
        "Install the TensorFlow Lite Support PyPI package."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m-8xSrSvUg-6"
      },
      "outputs": [],
      "source": [
        "!pip install tflite-support-nightly"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hyYS87Odpxef"
      },
      "source": [
        "## Create Model Metadata for Task Library and Codegen"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uLxv541TqTim"
      },
      "source": [
        "\u003ca name=image_classifiers\u003e\u003c/a\u003e\n",
        "### Image classifiers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s41TjCGlsyEF"
      },
      "source": [
        "See the [image classifier model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier#model_compatibility_requirements) for more details about the supported model format."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_KsPKmg8T9-8"
      },
      "source": [
        "Step 1: Import the required packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hhgNqEtWrwB3"
      },
      "outputs": [],
      "source": [
        "from tflite_support.metadata_writers import image_classifier\n",
        "from tflite_support.metadata_writers import writer_utils"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o9WBgiFdsiIQ"
      },
      "source": [
        "Step 2: Download the example image classifier, [mobilenet_v2_1.0_224.tflite](https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite), and the [label file](https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6WgSBbNet-Tt"
      },
      "outputs": [],
      "source": [
        "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite -o mobilenet_v2_1.0_224.tflite\n",
        "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt -o mobilenet_labels.txt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ALtlz7woweHe"
      },
      "source": [
        "Step 3: Create the metadata writer and populate the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_SMEBBt2r-W6"
      },
      "outputs": [],
      "source": [
        "ImageClassifierWriter = image_classifier.MetadataWriter\n",
        "_MODEL_PATH = \"mobilenet_v2_1.0_224.tflite\"\n",
        "# Task Library expects label files in the same format as the one below.\n",
        "_LABEL_FILE = \"mobilenet_labels.txt\"\n",
        "_SAVE_TO_PATH = \"mobilenet_v2_1.0_224_metadata.tflite\"\n",
        "# Normalization parameters are required when preprocessing the image. They are\n",
        "# optional if the image pixel values are in the range of [0, 255] and the input\n",
        "# tensor is quantized to uint8. See the introduction to normalization and\n",
        "# quantization parameters for more details:\n",
        "# https://www.tensorflow.org/lite/models/convert/metadata#normalization_and_quantization_parameters\n",
        "_INPUT_NORM_MEAN = 127.5\n",
        "_INPUT_NORM_STD = 127.5\n",
        "\n",
        "# Create the metadata writer.\n",
        "writer = ImageClassifierWriter.create_for_inference(\n",
        "    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],\n",
        "    [_LABEL_FILE])\n",
        "\n",
        "# Verify the metadata generated by the metadata writer.\n",
        "print(writer.get_metadata_json())\n",
        "\n",
        "# Populate the metadata into the model.\n",
        "writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)"
      ]
    },
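    {
      "cell_type": "markdown",
      "metadata": {
        "id": "norm_params_check_md"
      },
      "source": [
        "As a quick sanity check on the normalization parameters above (this cell is illustrative only and not part of the Metadata Writer API): the metadata records the per-channel transform `normalized = (pixel - mean) / std`, so `mean = std = 127.5` maps pixel values in `[0, 255]` to `[-1, 1]`. The small sketch below just evaluates that formula at a few pixel values."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "norm_params_check_code"
      },
      "outputs": [],
      "source": [
        "# Illustrative sketch: evaluate the normalization formula recorded in the\n",
        "# metadata, normalized = (pixel - mean) / std, at the range endpoints and\n",
        "# midpoint of a uint8 image.\n",
        "def normalize(pixel, mean=127.5, std=127.5):\n",
        "  return (pixel - mean) / std\n",
        "\n",
        "for pixel in (0, 127.5, 255):\n",
        "  print(f\"pixel {pixel:>5} -\u003e {normalize(pixel):+.1f}\")"
      ]
    },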
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GhhTDkr-uf0n"
      },
      "source": [
        "\u003ca name=object_detectors\u003e\u003c/a\u003e\n",
        "### Object detectors"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EL9GssnTuf0n"
      },
      "source": [
        "See the [object detector model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector#model_compatibility_requirements) for more details about the supported model format."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r-HUTEtHuf0n"
      },
      "source": [
        "Step 1: Import the required packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2_NIROeouf0o"
      },
      "outputs": [],
      "source": [
        "from tflite_support.metadata_writers import object_detector\n",
        "from tflite_support.metadata_writers import writer_utils"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UM6jijiUuf0o"
      },
      "source": [
        "Step 2: Download the example object detector, [ssd_mobilenet_v1.tflite](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.tflite), and the [label file](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/labelmap.txt)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4i_BBfGzuf0o"
      },
      "outputs": [],
      "source": [
        "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.tflite -o ssd_mobilenet_v1.tflite\n",
        "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/labelmap.txt -o ssd_mobilenet_labels.txt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DG9T3eSDwsnd"
      },
      "source": [
        "Step 3: Create the metadata writer and populate the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vMGGeJfCuf0p"
      },
      "outputs": [],
      "source": [
        "ObjectDetectorWriter = object_detector.MetadataWriter\n",
        "_MODEL_PATH = \"ssd_mobilenet_v1.tflite\"\n",
        "# Task Library expects label files in the same format as the one below.\n",
        "_LABEL_FILE = \"ssd_mobilenet_labels.txt\"\n",
        "_SAVE_TO_PATH = \"ssd_mobilenet_v1_metadata.tflite\"\n",
        "# Normalization parameters are required when preprocessing the image. They are\n",
        "# optional if the image pixel values are in the range of [0, 255] and the input\n",
        "# tensor is quantized to uint8. See the introduction to normalization and\n",
        "# quantization parameters for more details:\n",
        "# https://www.tensorflow.org/lite/models/convert/metadata#normalization_and_quantization_parameters\n",
        "_INPUT_NORM_MEAN = 127.5\n",
        "_INPUT_NORM_STD = 127.5\n",
        "\n",
        "# Create the metadata writer.\n",
        "writer = ObjectDetectorWriter.create_for_inference(\n",
        "    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],\n",
        "    [_LABEL_FILE])\n",
        "\n",
        "# Verify the metadata generated by the metadata writer.\n",
        "print(writer.get_metadata_json())\n",
        "\n",
        "# Populate the metadata into the model.\n",
        "writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QT0Oa0SU6uGS"
      },
      "source": [
        "\u003ca name=image_segmenters\u003e\u003c/a\u003e\n",
        "### Image segmenters"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XaFQmg-S6uGW"
      },
      "source": [
        "See the [image segmenter model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_segmenter#model_compatibility_requirements) for more details about the supported model format."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DiktANhj6uGX"
      },
      "source": [
        "Step 1: Import the required packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H6Lrw3op6uGX"
      },
      "outputs": [],
      "source": [
        "from tflite_support.metadata_writers import image_segmenter\n",
        "from tflite_support.metadata_writers import writer_utils"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9EFs8Oyi6uGX"
      },
      "source": [
        "Step 2: Download the example image segmenter, [deeplabv3.tflite](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.tflite), and the [label file](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/labelmap.txt)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "feQDH0bN6uGY"
      },
      "outputs": [],
      "source": [
        "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.tflite -o deeplabv3.tflite\n",
        "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/labelmap.txt -o deeplabv3_labels.txt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8LhiAbJM6uGY"
      },
      "source": [
        "Step 3: Create the metadata writer and populate the model."
      ]
    },
It is\n", 420 "# optional if the image pixel values are in range of [0, 255] and the input\n", 421 "# tensor is quantized to uint8. See the introduction for normalization and\n", 422 "# quantization parameters below for more details.\n", 423 "# https://www.tensorflow.org/lite/models/convert/metadata#normalization_and_quantization_parameters)\n", 424 "_INPUT_NORM_MEAN = 127.5\n", 425 "_INPUT_NORM_STD = 127.5\n", 426 "\n", 427 "# Create the metadata writer.\n", 428 "writer = ImageSegmenterWriter.create_for_inference(\n", 429 " writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],\n", 430 " [_LABEL_FILE])\n", 431 "\n", 432 "# Verify the metadata generated by metadata writer.\n", 433 "print(writer.get_metadata_json())\n", 434 "\n", 435 "# Populate the metadata into the model.\n", 436 "writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)" 437 ] 438 }, 439 { 440 "cell_type": "markdown", 441 "metadata": { 442 "id": "NnvM80e7AG-h" 443 }, 444 "source": [ 445 "\u003ca name=nl_classifiers\u003e\u003c/a\u003e\n", 446 "###Natural language classifiers" 447 ] 448 }, 449 { 450 "cell_type": "markdown", 451 "metadata": { 452 "id": "dfOPhFwOAG-k" 453 }, 454 "source": [ 455 "See the [natural language classifier model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier#model_compatibility_requirements) for more details about the supported model format." 456 ] 457 }, 458 { 459 "cell_type": "markdown", 460 "metadata": { 461 "id": "WMJ7tvuwAG-k" 462 }, 463 "source": [ 464 "Step 1: Import the required packages." 465 ] 466 }, 467 { 468 "cell_type": "code", 469 "execution_count": null, 470 "metadata": { 471 "id": "_FGVyb2iAG-k" 472 }, 473 "outputs": [], 474 "source": [ 475 "from tflite_support.metadata_writers import nl_classifier\n", 476 "from tflite_support.metadata_writers import metadata_info\n", 477 "from tflite_support.metadata_writers import writer_utils" 478 ] 479 }, 480 { 481 "cell_type": "markdown", 482 "metadata": { 483 "id": "iIg7rATpAG-l" 484 }, 485 "source": [ 486 "Step 2: Download the example natural language classifier, [movie_review.tflite](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review.tflite), the [label file](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/labels.txt), and the [vocab file](https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/nl_classifier/vocab.txt)." 487 ] 488 }, 489 { 490 "cell_type": "code", 491 "execution_count": null, 492 "metadata": { 493 "id": "TzuQcti2AG-l" 494 }, 495 "outputs": [], 496 "source": [ 497 "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review.tflite -o movie_review.tflite\n", 498 "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/labels.txt -o movie_review_labels.txt\n", 499 "!curl -L https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/nl_classifier/vocab.txt -o movie_review_vocab.txt" 500 ] 501 }, 502 { 503 "cell_type": "markdown", 504 "metadata": { 505 "id": "BWxUtHdeAG-m" 506 }, 507 "source": [ 508 "Step 3: Create metadata writer and populate." 
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NGPWzRuHAG-m"
      },
      "outputs": [],
      "source": [
        "NLClassifierWriter = nl_classifier.MetadataWriter\n",
        "_MODEL_PATH = \"movie_review.tflite\"\n",
        "# Task Library expects label files and vocab files in the same formats as the\n",
        "# ones below.\n",
        "_LABEL_FILE = \"movie_review_labels.txt\"\n",
        "_VOCAB_FILE = \"movie_review_vocab.txt\"\n",
        "# NLClassifier supports tokenizing input strings using the regex tokenizer. See\n",
        "# more details about how to set up RegexTokenizer below:\n",
        "# https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/metadata_writers/metadata_info.py#L130\n",
        "_DELIM_REGEX_PATTERN = r\"[^\\w\\']+\"\n",
        "_SAVE_TO_PATH = \"movie_review_metadata.tflite\"\n",
        "\n",
        "# Create the metadata writer.\n",
        "writer = NLClassifierWriter.create_for_inference(\n",
        "    writer_utils.load_file(_MODEL_PATH),\n",
        "    metadata_info.RegexTokenizerMd(_DELIM_REGEX_PATTERN, _VOCAB_FILE),\n",
        "    [_LABEL_FILE])\n",
        "\n",
        "# Verify the metadata generated by the metadata writer.\n",
        "print(writer.get_metadata_json())\n",
        "\n",
        "# Populate the metadata into the model.\n",
        "writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)"
      ]
    },
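    {
      "cell_type": "markdown",
      "metadata": {
        "id": "regex_tokenizer_note_md"
      },
      "source": [
        "To build intuition for `_DELIM_REGEX_PATTERN` (this cell is illustrative only; the actual tokenization happens inside the Task Library, not in Python): the pattern `[^\\w\\']+` treats any run of characters that are neither word characters nor apostrophes as a delimiter. A plain `re.split` with the same pattern approximates how an input sentence is broken into tokens before vocabulary lookup."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "regex_tokenizer_note_code"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "# Illustrative sketch: split a sample review with the same delimiter pattern\n",
        "# recorded in the metadata. Empty strings produced by leading or trailing\n",
        "# delimiters are dropped.\n",
        "sample = \"This movie wasn't bad at all!\"\n",
        "tokens = [t for t in re.split(_DELIM_REGEX_PATTERN, sample) if t]\n",
        "print(tokens)  # ['This', 'movie', \"wasn't\", 'bad', 'at', 'all']"
      ]
    },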
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qv0WDnzW711f"
      },
      "source": [
        "\u003ca name=audio_classifiers\u003e\u003c/a\u003e\n",
        "### Audio classifiers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xqP7X8jww8pL"
      },
      "source": [
        "See the [audio classifier model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/audio_classifier#model_compatibility_requirements) for more details about the supported model format."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7RToKepxw8pL"
      },
      "source": [
        "Step 1: Import the required packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JjddvTXKw8pL"
      },
      "outputs": [],
      "source": [
        "from tflite_support.metadata_writers import audio_classifier\n",
        "from tflite_support.metadata_writers import metadata_info\n",
        "from tflite_support.metadata_writers import writer_utils"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ar418rH6w8pL"
      },
      "source": [
        "Step 2: Download the example audio classifier, [yamnet.tflite](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.tflite), and the [label file](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_521_labels.txt)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5eQY6znmw8pM"
      },
      "outputs": [],
      "source": [
        "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.tflite -o yamnet.tflite\n",
        "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_521_labels.txt -o yamnet_labels.txt\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1TYP5w0Ew8pM"
      },
      "source": [
        "Step 3: Create the metadata writer and populate the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MDlSczBQw8pM"
      },
      "outputs": [],
      "source": [
        "AudioClassifierWriter = audio_classifier.MetadataWriter\n",
        "_MODEL_PATH = \"yamnet.tflite\"\n",
        "# Task Library expects label files in the same format as the one below.\n",
        "_LABEL_FILE = \"yamnet_labels.txt\"\n",
        "# Expected sampling rate of the input audio buffer.\n",
        "_SAMPLE_RATE = 16000\n",
        "# Expected number of channels of the input audio buffer. Note that the Task\n",
        "# Library only supports single-channel audio for now.\n",
        "_CHANNELS = 1\n",
        "_SAVE_TO_PATH = \"yamnet_metadata.tflite\"\n",
        "\n",
        "# Create the metadata writer.\n",
        "writer = AudioClassifierWriter.create_for_inference(\n",
        "    writer_utils.load_file(_MODEL_PATH), _SAMPLE_RATE, _CHANNELS, [_LABEL_FILE])\n",
        "\n",
        "# Verify the metadata generated by the metadata writer.\n",
        "print(writer.get_metadata_json())\n",
        "\n",
        "# Populate the metadata into the model.\n",
        "writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)"
      ]
    },
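    {
      "cell_type": "markdown",
      "metadata": {
        "id": "audio_input_shape_note_md"
      },
      "source": [
        "The sample rate and channel count describe how to interpret the input buffer, while the buffer length itself is fixed by the model's input tensor shape. As an optional, illustrative check (assuming TensorFlow is available in your runtime, as it is in Colab), you can query the interpreter for that shape, as sketched below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "audio_input_shape_note_code"
      },
      "outputs": [],
      "source": [
        "# Optional sketch: inspect the input tensor shape of the populated model to see\n",
        "# how many audio samples one inference consumes. At _SAMPLE_RATE samples per\n",
        "# second, shape[-1] samples correspond to shape[-1] / _SAMPLE_RATE seconds of\n",
        "# single-channel audio.\n",
        "import tensorflow as tf\n",
        "\n",
        "interpreter = tf.lite.Interpreter(model_path=_SAVE_TO_PATH)\n",
        "input_shape = interpreter.get_input_details()[0][\"shape\"]\n",
        "print(\"Input tensor shape:\", input_shape)\n",
        "print(\"Seconds of audio per inference:\", input_shape[-1] / _SAMPLE_RATE)"
      ]
    },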
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YoRLs84yNAJR"
      },
      "source": [
        "## Create Model Metadata with semantic information"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cxXsOBknOGJ2"
      },
      "source": [
        "You can fill in more descriptive information about the model and each tensor through the Metadata Writer API to help improve model understanding. This is done through the `create_from_metadata_info` method of each metadata writer. In general, you fill in data through the parameters of `create_from_metadata_info`, i.e. `general_md`, `input_md`, and `output_md`. See the example below, which creates rich Model Metadata for an image classifier."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q-LW6nrcQ9lv"
      },
      "source": [
        "Step 1: Import the required packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KsL_egYcRGw3"
      },
      "outputs": [],
      "source": [
        "from tflite_support.metadata_writers import image_classifier\n",
        "from tflite_support.metadata_writers import metadata_info\n",
        "from tflite_support.metadata_writers import writer_utils\n",
        "from tflite_support import metadata_schema_py_generated as _metadata_fb"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0UWck_8uRboF"
      },
      "source": [
        "Step 2: Download the example image classifier, [mobilenet_v2_1.0_224.tflite](https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite), and the [label file](https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TqJ-jh-PRVdk"
      },
      "outputs": [],
      "source": [
        "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite -o mobilenet_v2_1.0_224.tflite\n",
        "!curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt -o mobilenet_labels.txt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r4I5wJMQRxzb"
      },
      "source": [
        "Step 3: Create the model and tensor information."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "urd7HDuaR_HC"
      },
      "outputs": [],
      "source": [
        "model_buffer = writer_utils.load_file(\"mobilenet_v2_1.0_224.tflite\")\n",
        "\n",
        "# Create general model information.\n",
        "general_md = metadata_info.GeneralMd(\n",
        "    name=\"ImageClassifier\",\n",
        "    version=\"v1\",\n",
        "    description=(\"Identify the most prominent object in the image from a \"\n",
        "                 \"known set of categories.\"),\n",
        "    author=\"TensorFlow Lite\",\n",
        "    licenses=\"Apache License, Version 2.0\")\n",
        "\n",
        "# Create input tensor information.\n",
        "input_md = metadata_info.InputImageTensorMd(\n",
        "    name=\"input image\",\n",
        "    description=(\"Input image to be classified. The expected image is \"\n",
        "                 \"224 x 224, with three channels (red, green, and blue) per \"\n",
        "                 \"pixel. Each element in the tensor is a value between min and \"\n",
        "                 \"max, where (per-channel) min is [0] and max is [255].\"),\n",
        "    norm_mean=[127.5],\n",
        "    norm_std=[127.5],\n",
        "    color_space_type=_metadata_fb.ColorSpaceType.RGB,\n",
        "    tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0])\n",
        "\n",
        "# Create output tensor information.\n",
        "output_md = metadata_info.ClassificationTensorMd(\n",
        "    name=\"probability\",\n",
        "    description=\"Probabilities of the 1001 labels respectively.\",\n",
        "    label_files=[\n",
        "        metadata_info.LabelFileMd(file_path=\"mobilenet_labels.txt\",\n",
        "                                  locale=\"en\")\n",
        "    ],\n",
        "    tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N5aL5Uxkf4aO"
      },
      "source": [
        "Step 4: Create the metadata writer and populate the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_iWIwdqEf_mr"
      },
      "outputs": [],
      "source": [
        "ImageClassifierWriter = image_classifier.MetadataWriter\n",
        "# Output path for the model with metadata (same as in the image classifier\n",
        "# section above).\n",
        "_SAVE_TO_PATH = \"mobilenet_v2_1.0_224_metadata.tflite\"\n",
        "\n",
        "# Create the metadata writer.\n",
        "writer = ImageClassifierWriter.create_from_metadata_info(\n",
        "    model_buffer, general_md, input_md, output_md)\n",
        "\n",
        "# Verify the metadata generated by the metadata writer.\n",
        "print(writer.get_metadata_json())\n",
        "\n",
        "# Populate the metadata into the model.\n",
        "writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z78vuu6np5sb"
      },
      "source": [
        "## Read the metadata populated to your model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DnWt-4oOselo"
      },
      "source": [
        "You can display the metadata and associated files in a TFLite model with the following code:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5D13YPUsp5VT"
      },
      "outputs": [],
      "source": [
        "from tflite_support import metadata\n",
        "\n",
        "displayer = metadata.MetadataDisplayer.with_model_file(\"mobilenet_v2_1.0_224_metadata.tflite\")\n",
        "print(\"Metadata populated:\")\n",
        "print(displayer.get_metadata_json())\n",
        "\n",
        "print(\"Associated file(s) populated:\")\n",
        "for file_name in displayer.get_packed_associated_file_list():\n",
        "  print(\"file name: \", file_name)\n",
        "  print(\"file content:\")\n",
        "  print(displayer.get_associated_file_buffer(file_name))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "Mq-riZs-TJGt"
      ],
      "name": "Metadata Writer tutorial",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}