xref: /aosp_15_r20/external/tensorflow/tensorflow/python/trackable/python_state.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1"""Utilities for including Python state in TensorFlow checkpoints."""
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import abc
17
18from tensorflow.python.trackable import base
19from tensorflow.python.util.tf_export import tf_export
20
21
@tf_export("train.experimental.PythonState")
class PythonState(base.Trackable, metaclass=abc.ABCMeta):
  """Abstract mixin that lets Python state take part in object-based checkpoints.

  Subclasses extend TensorFlow's object-based checkpointing (see
  `tf.train.Checkpoint`) by implementing two callbacks: `serialize`, which
  returns the object's state as a string, and `deserialize`, which is handed a
  string and rebuilds the state from it.

  As an illustration, here is a small wrapper around a NumPy array:

  ```python
  import io
  import numpy

  class NumpyWrapper(tf.train.experimental.PythonState):

    def __init__(self, array):
      self.array = array

    def serialize(self):
      # The buffer is closed automatically when the `with` block exits.
      with io.BytesIO() as buffer:
        numpy.save(buffer, self.array, allow_pickle=False)
        return buffer.getvalue()

    def deserialize(self, string_value):
      with io.BytesIO(string_value) as buffer:
        self.array = numpy.load(buffer, allow_pickle=False)
  ```

  A `NumpyWrapper` instance is a checkpointable object: it is saved to and
  restored from checkpoints together with TensorFlow state such as variables.

  ```python
  root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.])))
  save_path = root.save(prefix)
  root.numpy.array *= 2.
  assert [2.] == root.numpy.array
  root.restore(save_path)
  assert [1.] == root.numpy.array
  ```
  """

  @abc.abstractmethod
  def serialize(self):
    """Callback to serialize the object.

    Returns:
      A string holding the serialized state of this object.
    """

  @abc.abstractmethod
  def deserialize(self, string_value):
    """Callback to deserialize the object.

    Args:
      string_value: A string of serialized state from which this object
        should restore itself.
    """
76