"""Utilities for including Python state in TensorFlow checkpoints."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import abc

from tensorflow.python.trackable import base
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.experimental.PythonState")
class PythonState(base.Trackable, metaclass=abc.ABCMeta):
  """A mixin for putting Python state in an object-based checkpoint.

  This is an abstract class which allows extensions to TensorFlow's object-based
  checkpointing (see `tf.train.Checkpoint`). Subclasses implement `serialize`
  and `deserialize` to convert their Python state to and from a string value.
  For example, a wrapper for NumPy arrays:

  ```python
  import io
  import numpy

  class NumpyWrapper(tf.train.experimental.PythonState):

    def __init__(self, array):
      self.array = array

    def serialize(self):
      with io.BytesIO() as string_file:
        numpy.save(string_file, self.array, allow_pickle=False)
        return string_file.getvalue()

    def deserialize(self, string_value):
      with io.BytesIO(string_value) as string_file:
        self.array = numpy.load(string_file, allow_pickle=False)
  ```

  Instances of `NumpyWrapper` are checkpointable objects, and will be saved and
  restored from checkpoints along with TensorFlow state like variables.

  ```python
  prefix = "/tmp/numpy_wrapper_ckpt"  # Any writable checkpoint path prefix.
  root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.])))
  save_path = root.save(prefix)
  root.numpy.array *= 2.
  assert root.numpy.array.tolist() == [2.]
  root.restore(save_path)
  assert root.numpy.array.tolist() == [1.]
  ```
  """

  @abc.abstractmethod
  def serialize(self):
    """Callback to serialize the object.

    Returns:
      A string value encoding this object's Python state. The `NumpyWrapper`
      example above returns `bytes`.
    """

  @abc.abstractmethod
  def deserialize(self, string_value):
    """Callback to deserialize the object.

    Args:
      string_value: A string value in the format produced by `serialize`.
        Implementations should restore this object's Python state from it.
    """