# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Context for building SavedModel."""

import contextlib
import threading


class SaveContext(threading.local):
  """A context for building a graph of SavedModel.

  Subclassing `threading.local` gives every thread its own copy of the state
  below (`__init__` runs once per thread on first access). Entering a save
  context in one thread is therefore invisible to other threads.
  """

  def __init__(self):
    super().__init__()
    # Whether the current thread is inside a `save_context(...)` block.
    self._in_save_context = False
    # Options recorded by `enter_save_context`; None outside a save context.
    self._options = None

  def options(self):
    """Returns the options recorded by `enter_save_context`.

    Raises:
      ValueError: If the current thread is not in a save context.
    """
    if not self.in_save_context():
      raise ValueError("Not in a SaveContext.")
    return self._options

  def enter_save_context(self, options):
    """Marks the current thread as inside a save context and stores `options`.

    Args:
      options: The object later returned by `options()`. Presumably a
        SavedModel save-options object; this class does not inspect it.
    """
    self._in_save_context = True
    self._options = options

  def exit_save_context(self):
    """Leaves the save context and clears the stored options."""
    self._in_save_context = False
    self._options = None

  def in_save_context(self):
    """Returns True if the current thread is inside a save context."""
    return self._in_save_context


# Module-wide (but per-thread, see `SaveContext`) save-context state.
_save_context = SaveContext()


@contextlib.contextmanager
def save_context(options):
  """Context manager that enters a save context for the current thread.

  Inside the `with` block, `in_save_context()` returns True and
  `get_save_options()` returns `options`. The context is always exited on
  leaving the block, including when the body raises.

  Args:
    options: The options to expose through `get_save_options()`.

  Raises:
    ValueError: If the current thread is already in a save context. Save
      contexts do not nest.
  """
  # Checked before entering, so a rejected nested call leaves the outer
  # context (and its options) untouched.
  if in_save_context():
    raise ValueError("Already in a SaveContext.")
  _save_context.enter_save_context(options)
  try:
    yield
  finally:
    _save_context.exit_save_context()


def in_save_context():
  """Returns whether under a save context."""
  return _save_context.in_save_context()


def get_save_options():
  """Returns the save options if under a save context.

  Raises:
    ValueError: If not under a save context.
  """
  return _save_context.options()