# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Options for saving SavedModels."""

import enum

from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


@tf_export("saved_model.experimental.VariablePolicy")
class VariablePolicy(enum.Enum):
  """Enum defining options for variable handling when saving.

  NONE
    No policy applied: Distributed variables are saved as one variable, with no
    device attached.

  SAVE_VARIABLE_DEVICES
    When saving variables, also save their device assignment.
    This is useful if one wants to hardcode devices in saved models, but it also
    makes them non-portable if soft device placement is disabled (more details
    in `tf.config.set_soft_device_placement`). This is currently not
    fully supported by `saved_model.load`, and is mainly intended to be used
    when one will be reading the saved model at a lower API level. In the
    example below, the graph saved by the call to `saved_model.save` will have
    the variable devices correctly specified:
    ```python
    exported = tf.train.Checkpoint()
    with tf.device('/GPU:0'):
      exported.x_gpu = tf.Variable(1.0)
    with tf.device('/CPU:0'):
      exported.x_cpu = tf.Variable(1.0)
    tf.saved_model.save(exported, export_dir,
        options = tf.saved_model.SaveOptions(
            experimental_variable_policy=
              tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES))
    ```
    Distributed variables are still saved as one variable under this policy.

  EXPAND_DISTRIBUTED_VARIABLES
    Distributed variables will be saved with information about their components,
    allowing for their restoration on load. Also, the saved graph will contain
    references to those variables. This is useful when one wants to use the
    model for training in environments where the original distribution strategy
    is not available.
  """

  NONE = None

  SAVE_VARIABLE_DEVICES = "save_variable_devices"

  EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"

  def _save_variable_devices(self):
    """Checks whether variable devices should be saved.

    Returns:
      True for every policy other than `NONE`, i.e. for both
      `SAVE_VARIABLE_DEVICES` and `EXPAND_DISTRIBUTED_VARIABLES`.
    """
    return self != VariablePolicy.NONE

  def _expand_distributed_variables(self):
    """Checks whether distributed variables should be expanded.

    Returns:
      True only for the `EXPAND_DISTRIBUTED_VARIABLES` policy.
    """
    return self == VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES

  @staticmethod
  def from_obj(obj):
    """Tries to convert `obj` to a VariablePolicy instance.

    Args:
      obj: `None`, a `VariablePolicy` instance, or an object whose `str()`
        form equals one of the policy value strings (case is ignored).

    Returns:
      `VariablePolicy.NONE` when `obj` is `None`, `obj` itself when it already
      is a `VariablePolicy`, otherwise the policy whose value matches the
      lower-cased string form of `obj`.

    Raises:
      ValueError: If `obj` does not correspond to any policy.
    """
    if obj is None:
      return VariablePolicy.NONE
    if isinstance(obj, VariablePolicy):
      return obj
    key = str(obj).lower()
    # `NONE` has the value `None`, which never equals a string key, so only
    # the string-valued policies can be matched by this loop.
    for policy in VariablePolicy:
      if key == policy.value:
        return policy
    raise ValueError(f"Received invalid VariablePolicy value: {obj}.")


@tf_export("saved_model.SaveOptions")
class SaveOptions:
  """Options for saving to SavedModel.

  This class may be used in the `options` argument in functions that
  save a SavedModel (`tf.saved_model.save`, `tf.keras.models.save_model`).
  """

  # Define object attributes in __slots__ for improved memory and performance.
  __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
               "experimental_io_device", "experimental_variable_policy",
               "experimental_custom_gradients")

  def __init__(self,
               namespace_whitelist=None,
               save_debug_info=False,
               function_aliases=None,
               experimental_io_device=None,
               experimental_variable_policy=None,
               experimental_custom_gradients=True):
    """Creates an object that stores options for SavedModel saving.

    Args:
      namespace_whitelist: List of strings containing op namespaces to whitelist
        when saving a model. Saving an object that uses namespaced ops must
        explicitly add all namespaces to the whitelist. The namespaced ops must
        be registered into the framework when loading the SavedModel. If no
        whitelist is provided, all namespaced ops will be allowed.
      save_debug_info: Boolean indicating whether debug information is saved. If
        True, then a debug/saved_model_debug_info.pb file will be written with
        the contents of a GraphDebugInfo binary protocol buffer containing stack
        trace information for all ops and functions that are saved.
      function_aliases: Python dict. Mapping from string to object returned by
        @tf.function. A single tf.function can generate many ConcreteFunctions.
        If a downstream tool wants to refer to all concrete functions generated
        by a single tf.function you can use the `function_aliases` argument to
        store a map from the alias name to all concrete function names.
        E.g.

        >>> class Adder(tf.Module):
        ...   @tf.function
        ...   def double(self, x):
        ...     return x + x

        >>> model = Adder()
        >>> model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.float32, name="float_input"))
        >>> model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.string, name="string_input"))

        >>> options = tf.saved_model.SaveOptions(
        ...   function_aliases={'double': model.double})
        >>> tf.saved_model.save(model, '/tmp/adder', options=options)

      experimental_io_device: string. Applies in a distributed setting.
        Tensorflow device to use to access the filesystem. If `None` (default)
        then for each variable the filesystem is accessed from the CPU:0 device
        of the host where that variable is assigned. If specified, the
        filesystem is instead accessed from that device for all variables.

        This is for example useful if you want to save to a local directory,
        such as "/tmp" when running in a distributed setting. In that case pass
        a device for the host where the "/tmp" directory is accessible.
      experimental_variable_policy: The policy to apply to variables when
        saving. This is either a `saved_model.experimental.VariablePolicy` enum
        instance or one of its value strings (case is not important). See that
        enum documentation for details. A value of `None` corresponds to the
        default policy.
      experimental_custom_gradients: Boolean. When True, will save traced
        gradient functions for the functions decorated by `tf.custom_gradient`.
        Defaults to `True`.

    Raises:
      TypeError: If `namespace_whitelist` is neither `None` nor a list.
      ValueError: If an entry of `namespace_whitelist` is not a string, or if
        `experimental_variable_policy` is not a valid policy.
    """
    self.namespace_whitelist = _validate_namespace_whitelist(
        namespace_whitelist)
    self.save_debug_info = save_debug_info
    # Any falsy value (`None`, an empty dict) is replaced by a fresh dict so
    # instances never share one mutable default.
    self.function_aliases = function_aliases if function_aliases else {}
    self.experimental_custom_gradients = experimental_custom_gradients
    self.experimental_io_device = experimental_io_device
    # Normalize strings / `None` into a `VariablePolicy` member up front.
    self.experimental_variable_policy = (
        VariablePolicy.from_obj(experimental_variable_policy))


def _validate_namespace_whitelist(namespace_whitelist):
  """Validates namespace whitelist argument.

  Args:
    namespace_whitelist: `None`, or a list of namespace strings.

  Returns:
    `None` when no whitelist was given; otherwise a new list holding each
    entry after it has been passed through `compat.as_str`.

  Raises:
    TypeError: If `namespace_whitelist` is not `None` and not a list.
    ValueError: If any entry of the list is not a string.
  """
  if namespace_whitelist is None:
    return None
  if not isinstance(namespace_whitelist, list):
    raise TypeError("`namespace_whitelist` must be a list of strings. Got: "
                    f"{namespace_whitelist} with type "
                    f"{type(namespace_whitelist)}.")

  processed = []
  for namespace in namespace_whitelist:
    if not isinstance(namespace, str):
      raise ValueError("Whitelisted namespace must be a string. Got: "
                       f"{namespace} of type {type(namespace)}.")
    processed.append(compat.as_str(namespace))
  return processed