xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/save_options.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Options for saving SavedModels."""
16
17import enum
18
19from tensorflow.python.util import compat
20from tensorflow.python.util.tf_export import tf_export
21
22
@tf_export("saved_model.experimental.VariablePolicy")
class VariablePolicy(enum.Enum):
  """Enum defining options for variable handling when saving.

  NONE
    No policy applied: Distributed variables are saved as one variable, with no
    device attached.

  SAVE_VARIABLE_DEVICES
    When saving variables, also save their device assignment.
    This is useful if one wants to hardcode devices in saved models, but it also
    makes them non-portable if soft device placement is disabled (more details
    in `tf.config.set_soft_device_placement`). This is currently not
    fully supported by `saved_model.load`, and is mainly intended to be used
    when one will be reading the saved model at a lower API level. In the
    example below, the graph saved by the call to `saved_model.save` will have
    the variable devices correctly specified:
    ```python
    exported = tf.train.Checkpoint()
    with tf.device('/GPU:0'):
      exported.x_gpu = tf.Variable(1.0)
    with tf.device('/CPU:0'):
      exported.x_cpu = tf.Variable(1.0)
    tf.saved_model.save(exported, export_dir,
        options = tf.saved_model.SaveOptions(
            experimental_variable_policy=
              tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES))
    ```
    Distributed variables are still saved as one variable under this policy.

  EXPAND_DISTRIBUTED_VARIABLES
    Distributed variables will be saved with information about their components,
    allowing for their restoration on load. Also, the saved graph will contain
    references to those variables. This is useful when one wants to use the
    model for training in environments where the original distribution strategy
    is not available. Variable devices are saved under this policy as well.
  """

  NONE = None

  SAVE_VARIABLE_DEVICES = "save_variable_devices"

  EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"

  def _save_variable_devices(self):
    """Returns whether variable device assignments should be recorded.

    True for every policy other than `NONE`, i.e. for both
    `SAVE_VARIABLE_DEVICES` and `EXPAND_DISTRIBUTED_VARIABLES`.
    """
    # Enum members are singletons, so identity is an exact membership test.
    return self is not VariablePolicy.NONE

  def _expand_distributed_variables(self):
    """Returns whether distributed variables should be saved per component."""
    return self is VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES

  @staticmethod
  def from_obj(obj):
    """Converts `obj` into a `VariablePolicy` member.

    Args:
      obj: `None`, an existing `VariablePolicy` member, or any object whose
        `str()` form, lower-cased, equals the value of one of the members
        (e.g. "SAVE_VARIABLE_DEVICES" or "expand_distributed_variables").

    Returns:
      The matching `VariablePolicy`. `None` maps to `VariablePolicy.NONE`, and
      a `VariablePolicy` instance is returned unchanged.

    Raises:
      ValueError: If `obj` does not correspond to any policy.
    """
    if obj is None:
      return VariablePolicy.NONE
    if isinstance(obj, VariablePolicy):
      return obj
    wanted = str(obj).lower()
    # `NONE` has value None, so it can never match a string here.
    matches = [member for member in VariablePolicy if member.value == wanted]
    if not matches:
      raise ValueError(f"Received invalid VariablePolicy value: {obj}.")
    return matches[0]
86    raise ValueError(f"Received invalid VariablePolicy value: {obj}.")
87
88
89@tf_export("saved_model.SaveOptions")
90class SaveOptions:
91  """Options for saving to SavedModel.
92
93  This function may be used in the `options` argument in functions that
94  save a SavedModel (`tf.saved_model.save`, `tf.keras.models.save_model`).
95  """
96
97  # Define object attributes in __slots__ for improved memory and performance.
98  __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
99               "experimental_io_device", "experimental_variable_policy",
100               "experimental_custom_gradients")
101
102  def __init__(self,
103               namespace_whitelist=None,
104               save_debug_info=False,
105               function_aliases=None,
106               experimental_io_device=None,
107               experimental_variable_policy=None,
108               experimental_custom_gradients=True):
109    """Creates an object that stores options for SavedModel saving.
110
111    Args:
112      namespace_whitelist: List of strings containing op namespaces to whitelist
113        when saving a model. Saving an object that uses namespaced ops must
114        explicitly add all namespaces to the whitelist. The namespaced ops must
115        be registered into the framework when loading the SavedModel. If no
116        whitelist is provided, all namespaced ops will be allowed.
117      save_debug_info: Boolean indicating whether debug information is saved. If
118        True, then a debug/saved_model_debug_info.pb file will be written with
119        the contents of a GraphDebugInfo binary protocol buffer containing stack
120        trace information for all ops and functions that are saved.
121      function_aliases: Python dict. Mapping from string to object returned by
122        @tf.function. A single tf.function can generate many ConcreteFunctions.
123        If a downstream tool wants to refer to all concrete functions generated
124        by a single tf.function you can use the `function_aliases` argument to
125        store a map from the alias name to all concrete function names.
126        E.g.
127
128        >>> class Adder(tf.Module):
129        ...   @tf.function
130        ...   def double(self, x):
131        ...     return x + x
132
133        >>> model = Adder()
134        >>> model.double.get_concrete_function(
135        ...   tf.TensorSpec(shape=[], dtype=tf.float32, name="float_input"))
136        >>> model.double.get_concrete_function(
137        ...   tf.TensorSpec(shape=[], dtype=tf.string, name="string_input"))
138
139        >>> options = tf.saved_model.SaveOptions(
140        ...   function_aliases={'double': model.double})
141        >>> tf.saved_model.save(model, '/tmp/adder', options=options)
142
143      experimental_io_device: string. Applies in a distributed setting.
144        Tensorflow device to use to access the filesystem. If `None` (default)
145        then for each variable the filesystem is accessed from the CPU:0 device
146        of the host where that variable is assigned. If specified, the
147        filesystem is instead accessed from that device for all variables.
148
149        This is for example useful if you want to save to a local directory,
150        such as "/tmp" when running in a distributed setting. In that case pass
151        a device for the host where the "/tmp" directory is accessible.
152      experimental_variable_policy: The policy to apply to variables when
153        saving. This is either a `saved_model.experimental.VariablePolicy` enum
154        instance or one of its value strings (case is not important). See that
155        enum documentation for details. A value of `None` corresponds to the
156        default policy.
157      experimental_custom_gradients: Boolean. When True, will save traced
158        gradient functions for the functions decorated by `tf.custom_gradient`.
159        Defaults to `True`.
160    """
161    self.namespace_whitelist = _validate_namespace_whitelist(
162        namespace_whitelist)
163    self.save_debug_info = save_debug_info
164    self.function_aliases = function_aliases if function_aliases else dict()
165    self.experimental_custom_gradients = experimental_custom_gradients
166    self.experimental_io_device = experimental_io_device
167    self.experimental_variable_policy = (
168        VariablePolicy.from_obj(experimental_variable_policy))
169
170
171def _validate_namespace_whitelist(namespace_whitelist):
172  """Validates namespace whitelist argument."""
173  if namespace_whitelist is None:
174    return None
175  if not isinstance(namespace_whitelist, list):
176    raise TypeError("`namespace_whitelist` must be a list of strings. Got: "
177                    f"{namespace_whitelist} with type "
178                    f"{type(namespace_whitelist)}.")
179
180  processed = []
181  for namespace in namespace_whitelist:
182    if not isinstance(namespace, str):
183      raise ValueError("Whitelisted namespace must be a string. Got: "
184                       f"{namespace} of type {type(namespace)}.")
185    processed.append(compat.as_str(namespace))
186  return processed
187