# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for convert_to_constants.py."""

import os
import re

import numpy as np

from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2
from tensorflow.python.platform import test
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import simple_save
from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model.save import save
from tensorflow.python.trackable import autotrackable
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.util import compat
from tensorflow.python.util import nest


class _GraphMerger(object):
  """GraphDef merging methods for testing purposes."""

  @staticmethod
  def merge_any(x1, x2, empty_fn):
    """Merges two values using the message's CopyFrom/MergeFrom methods.

    Args:
      x1: First message; copied into a fresh instance.
      x2: Second message; merged on top of the copy of `x1`.
      empty_fn: Zero-argument factory returning an empty message of the type.

    Returns:
      A new message equal to `x1` merged with `x2`. Neither input is modified.
    """
    merged = empty_fn()
    merged.CopyFrom(x1)
    merged.MergeFrom(x2)
    return merged

  @staticmethod
  def merge_nodes(node1, node2):
    """Merges two NodeDef messages.

    The `input` field is rebuilt as node1's inputs followed by node2's inputs
    that do not already appear among node1's inputs, instead of the plain
    concatenation that `MergeFrom` produces for repeated fields.
    """
    merged = _GraphMerger.merge_any(node1, node2, node_def_pb2.NodeDef)
    merged_inputs = node1.input[:]
    merged_inputs.extend([i for i in node2.input[:] if i not in merged_inputs])
    merged.input[:] = merged_inputs
    return merged

  @staticmethod
  def merge_lists(repeated1, repeated2, empty_fn, key_fn, merge_fn):
    """Merges two lists representing maps.

    Args:
      repeated1: First sequence of entries.
      repeated2: Second sequence of entries.
      empty_fn: Factory for an empty entry, used when a key exists on only one
        side.
      key_fn: Maps an entry to its key.
      merge_fn: Combines the two entries that share a key.

    Returns:
      A list with one merged entry per distinct key, sorted by `key_fn`.
    """
    xs1 = {key_fn(x): x for x in repeated1}
    xs2 = {key_fn(x): x for x in repeated2}
    merged = {}
    for name in xs1.keys() | xs2.keys():
      x1 = xs1[name] if name in xs1 else empty_fn()
      x2 = xs2[name] if name in xs2 else empty_fn()
      merged[name] = merge_fn(x1, x2)
    return sorted(merged.values(), key=key_fn)

  @staticmethod
  def merge_node_lists(repeated_nodes1, repeated_nodes2):
    """Merges two repeated node fields, matching nodes by name."""
    return _GraphMerger.merge_lists(repeated_nodes1, repeated_nodes2,
                                    node_def_pb2.NodeDef, lambda n: n.name,
                                    _GraphMerger.merge_nodes)

  @staticmethod
  def merge_functions(fn1, fn2):
    """Merges two FunctionDefs.

    Input args, output args and node defs are merged by name rather than
    concatenated, so the result does not depend on repeated-field ordering.
    """
    merged = _GraphMerger.merge_any(fn1, fn2, function_pb2.FunctionDef)

    del merged.signature.input_arg[:]
    merged.signature.input_arg.extend(
        _GraphMerger.merge_lists(
            fn1.signature.input_arg[:], fn2.signature.input_arg[:],
            op_def_pb2.OpDef.ArgDef, lambda a: a.name,
            lambda x, y: _GraphMerger.merge_any(x, y, op_def_pb2.OpDef.ArgDef)))

    del merged.signature.output_arg[:]
    merged.signature.output_arg.extend(
        _GraphMerger.merge_lists(
            fn1.signature.output_arg[:], fn2.signature.output_arg[:],
            op_def_pb2.OpDef.ArgDef, lambda a: a.name,
            lambda x, y: _GraphMerger.merge_any(x, y, op_def_pb2.OpDef.ArgDef)))

    del merged.node_def[:]
    merged.node_def.extend(
        _GraphMerger.merge_node_lists(fn1.node_def[:], fn2.node_def[:]))

    return merged

  @staticmethod
  def merge_graphs(graph1, graph2):
    """Merges two GraphDef messages.

    Only `node` and `library.function` are carried into the result. Nodes are
    matched by name and functions by signature name.
    """
    merged = graph_pb2.GraphDef()
    merged.node.extend(
        _GraphMerger.merge_node_lists(graph1.node[:], graph2.node[:]))

    merged.library.function.extend(
        _GraphMerger.merge_lists(graph1.library.function,
                                 graph2.library.function,
                                 function_pb2.FunctionDef,
                                 lambda f: f.signature.name,
                                 _GraphMerger.merge_functions))

    return merged


def has_stateful_partitioned_call_op(graph_def):
  """Determines if a StatefulPartitionedCall op exists in the graph.

  Only the top-level `graph_def.node` list is scanned; function bodies in
  `graph_def.library` are not inspected.
  """
  return any(node.op == "StatefulPartitionedCall" for node in graph_def.node)


def get_num_variables(graph_def):
  """Returns the number of ReadVariableOp in the graph.

  Only the top-level `graph_def.node` list is counted; function bodies in
  `graph_def.library` are not inspected.
  """
  return sum(node.op == "ReadVariableOp" for node in graph_def.node)


class VariablesToConstantsTest(test.TestCase):

  def _freezeModel(self, func):
    """Freezes the function.

    Args:
      func: Function. `get_concrete_function()` is called on it with no
        arguments, so it must be traceable that way (e.g. it is declared with
        an `input_signature`).

    Returns:
      root: AutoTrackable object whose `f` attribute is the original `func`.
      output_func: frozen ConcreteFunction.
    """
    root = autotrackable.AutoTrackable()
    root.f = func
    input_func = root.f.get_concrete_function()

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func, lower_control_flow=False)
    return root, output_func

  def _testConvertedFunction(self, obj, func, converted_concrete_func,
                             input_data):
    """Checks a frozen ConcreteFunction against the original function.

    Args:
      obj: Unused; kept so existing call sites keep working.
      func: The original (unfrozen) callable.
      converted_concrete_func: The frozen ConcreteFunction under test.
      input_data: Dict of keyword inputs. Its keys must equal the converted
        function's input tensor names with the ":<index>" suffix removed.
    """
    # Ensure the converted graph has no variables and no function calls.
    constant_graph_def = converted_concrete_func.graph.as_graph_def()
    self.assertEqual(0, get_num_variables(constant_graph_def))
    self.assertFalse(has_stateful_partitioned_call_op(constant_graph_def))

    # Check that the converted ConcreteFunction produces the same result as the
    # original Function. Note that zip() stops at the shorter of the two
    # flattened output lists.
    expected_value = nest.flatten(func(**input_data))
    actual_value = nest.flatten(converted_concrete_func(**input_data))

    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(expected.numpy(), actual.numpy())

    # Ensure the shape is retained.
    for tensor in converted_concrete_func.inputs:
      actual_shape = input_data[tensor.name.split(":")[0]].shape
      self.assertEqual(tensor.shape, actual_shape)

    # Save the converted ConcreteFunction as a signature.
    save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
    root = autotrackable.AutoTrackable()
    root.f = converted_concrete_func
    save(root, save_dir, {"mykey": converted_concrete_func})

    # Load it back and make sure it works.
    loaded_obj = load(save_dir)
    actual_value = nest.flatten(loaded_obj.signatures["mykey"](**input_data))
    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(expected.numpy(), actual.numpy())

  @test_util.run_v2_only
  def testConstSavedModel(self):
    """Test a basic model with constants while saving/loading the SavedModel."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data["x"])

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    input_func = saved_model.signatures["serving_default"]

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(0, get_num_variables(variable_graph_def))
    self.assertTrue(variable_graph_def.library.function)

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testVariableModel(self):
    """Test a basic model with Variables."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    input_func = root.f.get_concrete_function(input_data["x"])

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(2, get_num_variables(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testScalarModel(self):
    """Test a basic model with Variables and a scalar (shape []) input."""
    input_data = {"x": constant_op.constant(1., shape=[])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    input_func = root.f.get_concrete_function(input_data["x"])

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(2, get_num_variables(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testVariableSavedModel(self):
    """Test a basic model with Variables with saving/loading the SavedModel."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    to_save = root.f.get_concrete_function(input_data["x"])

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    input_func = saved_model.signatures["serving_default"]

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertTrue(has_stateful_partitioned_call_op(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testMultiFunctionModel(self):
    """Test a basic model with multiple tf.functions."""

    class BasicModel(autotrackable.AutoTrackable):

      def __init__(self):
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = BasicModel()
    input_func = root.add.get_concrete_function(input_data["x"])

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(1, get_num_variables(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.add, output_func, input_data)

  def _singleMetaGraphSavedModel(self):
    """Writes a TF1-style SavedModel and returns its export path.

    The graph computes `output = start * v * local_variable`, where `v` is a
    RefVariable (3.0) and `local_variable` is a resource local variable (1.0).
    An unused `distractor` RefVariable is also created. The local variable's
    initializer is passed as `legacy_init_op`.
    """
    export_graph = ops.Graph()
    with export_graph.as_default():
      start = array_ops.placeholder(
          shape=[1, 1], dtype=dtypes.float32, name="start")
      distractor = variables.RefVariable(-1., name="distractor")
      v = variables.RefVariable(3., name="v")
      local_variable = variables.VariableV1(
          1.,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          trainable=False,
          use_resource=True)
      output = array_ops.identity(start * v * local_variable, name="output")
      with session_lib.Session() as session:
        session.run([v.initializer, distractor.initializer,
                     local_variable.initializer])
        # ops.uid() keeps repeated calls in one test from sharing a directory.
        path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
        simple_save.simple_save(
            session,
            path,
            inputs={"start": start},
            outputs={"output": output},
            legacy_init_op=local_variable.initializer)
    return path

  @test_util.run_v2_only
  def testRefVariableImport(self):
    """Test a model with 1.X ReferenceVariables."""
    input_data = {"start": constant_op.constant(1., shape=[1, 1])}

    saved = self._singleMetaGraphSavedModel()
    imported = load(saved)
    fn = imported.signatures["serving_default"]

    output_func = convert_to_constants.convert_variables_to_constants_v2(fn)
    root = autotrackable.AutoTrackable()
    self._testConvertedFunction(root, fn, output_func, input_data)

  @test_util.run_v2_only
  def testIf(self):
    """Test a model with the If op."""
    input_data = {
        "x": constant_op.constant([1., 2.], shape=[1, 2]),
        "b": constant_op.constant(True)
    }

    weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=dtypes.float32)

    def true_fn(x):
      return math_ops.matmul(x, weights)

    def false_fn(x):
      return math_ops.add(x, weights)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[1, 2], dtype=dtypes.float32),
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)
    ])
    def model(x, b):
      return control_flow_ops.cond(
          b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testStatelessIf(self):
    """Test a model with the StatelessIf op."""
    input_data = {"b": constant_op.constant(True)}

    x = constant_op.constant([1., 2.], shape=[1, 2], name="x")

    def true_fn():
      return x

    def false_fn():
      return x + 2

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)])
    def model(b):
      return cond_v2.cond_v2(b, true_fn, false_fn)

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testStaticRnn(self):
    """Test a StaticRnn containing If ops."""
    input_data = {
        "x":
            constant_op.constant(
                np.array(np.random.random_sample((3, 10)), dtype=np.float32))
    }

    cell = rnn_cell_impl.LSTMCell(10)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[3, 10], dtype=dtypes.float32)
    ])
    def model(x):
      seq = array_ops.split(x, 3, 0)
      return rnn.static_rnn(
          cell, seq, dtype=dtypes.float32, sequence_length=[1])

    root, output_func = self._freezeModel(model)

    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testWhile(self):
    """Test a While loop."""
    input_data = {"x": constant_op.constant([1., 2., 3., 4.], shape=[2, 2])}

    weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=dtypes.float32)

    def condition(x):
      return math_ops.reduce_sum(x) < 100

    def body(x):
      return math_ops.add(x, weights)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[2, 2], dtype=dtypes.float32)
    ])
    def model(x):
      return control_flow_ops.while_loop(condition, body, [x])

    root, output_func = self._freezeModel(model)

    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testStatelessWhile(self):
    """Test a StatelessWhile loop."""
    input_data = {"x": constant_op.constant(2.)}

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
    ])
    def model(x):
      return while_v2.while_loop(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testDynamicRnn(self):
    """Test a DynamicRnn containing While loops."""
    input_data = {
        "x":
            constant_op.constant(
                np.array(
                    np.random.random_sample((3, 10, 10)), dtype=np.float32))
    }

    cell = rnn_cell_impl.LSTMCell(10)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[3, 10, 10], dtype=dtypes.float32)
    ])
    def model(x):
      return rnn.dynamic_rnn(cell, x, dtype=dtypes.float32)

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  @test_util.disable_tfrt("b/180451239")
  def testSwitchCase(self):
    """Test a switch_case statement."""
    input_data = {
        "i": constant_op.constant(np.random.randint(0, 3, dtype=np.int32)),
        "x": constant_op.constant(
            np.asarray(np.random.random_sample((10, 3)), dtype=np.float32)),
    }

    w0 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
    w1 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
    w2 = variables.Variable(np.random.random_sample((4,)), dtype=np.float32)

    def branch0(x):
      return math_ops.matmul(x, w0)

    def branch1(x):
      return math_ops.matmul(x, w1)

    def branch2(x):
      # Pad (10, 3) -> (10, 4) so the result broadcasts against w2's (4,).
      x = array_ops.pad(x, [[0, 0], [0, 1]])
      return x + w2

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
        tensor_spec.TensorSpec(shape=[10, 3], dtype=dtypes.float32),
    ])
    def model(i, x):
      return control_flow_ops.switch_case(i, [
          lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)


class ConvertVariablesToConstantsV2SessionTest(test.TestCase):

  def _freezeModel(self, func):
    """Freezes the function.

    Args:
      func: Function. `get_concrete_function()` is called on it with no
        arguments, so it must be traceable that way (e.g. it is declared with
        an `input_signature`).

    Returns:
      root: AutoTrackable object whose `f` attribute is the original `func`.
      output_func: ConcreteFunction frozen with
        `convert_var_to_const_function_in_v1`.
    """
    root = autotrackable.AutoTrackable()
    root.f = func
    input_func = root.f.get_concrete_function()

    output_func = convert_to_constants.convert_var_to_const_function_in_v1(
        input_func, lower_control_flow=False)
    return root, output_func

  def _testConvertedFunction(self, sess, obj, func, converted_concrete_func,
                             input_data):
    """Checks a frozen ConcreteFunction against the original, using `sess`.

    Args:
      sess: Session used to evaluate the expected and actual tensors.
      obj: Unused; kept so existing call sites keep working.
      func: The original (unfrozen) callable.
      converted_concrete_func: The frozen ConcreteFunction under test.
      input_data: Dict of keyword inputs. Its keys must equal the converted
        function's input tensor names with the ":<index>" suffix removed.
    """
    # Ensure the converted graph has no variables and no function calls.
    constant_graph_def = converted_concrete_func.graph.as_graph_def()
    self.assertEqual(0, get_num_variables(constant_graph_def))
    self.assertFalse(has_stateful_partitioned_call_op(constant_graph_def))

    # Check that the converted ConcreteFunction produces the same result as the
    # original Function. Note that zip() stops at the shorter of the two
    # flattened output lists.
    expected_value = nest.flatten(func(**input_data))
    actual_value = nest.flatten(converted_concrete_func(**input_data))

    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(sess.run(expected), sess.run(actual))

    # Ensure the shape is retained.
    for tensor in converted_concrete_func.inputs:
      actual_shape = input_data[tensor.name.split(":")[0]].shape
      self.assertEqual(tensor.shape, actual_shape)

    # Save the converted ConcreteFunction as a signature.
    save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
    root = autotrackable.AutoTrackable()
    root.f = converted_concrete_func
    save(root, save_dir, {"mykey": converted_concrete_func})

    # Load it back and make sure it works.
    loaded_obj = load(save_dir)
    actual_value = nest.flatten(loaded_obj.signatures["mykey"](**input_data))
    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(sess.run(expected), sess.run(actual))

  def testRaiseErrorInEagerMode(self):
    """Test the raised exception in Eager mode."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    input_func = root.f.get_concrete_function(input_data["x"])

    with self.assertRaisesRegex(RuntimeError,
                                "must be carried out in a Session"):
      convert_to_constants.convert_var_to_const_function_in_v1(
          input_func)

  def testConvertVariables(self):
    """Test a basic model with Variables."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = autotrackable.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        input_func = root.f.get_concrete_function(input_data["x"])

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertEqual(2, get_num_variables(variable_graph_def))

        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
            input_func)

        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testConvertVariablesWithAssignments(self):
    """Test a basic model with Variables and assignment ops."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = autotrackable.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        input_func = root.f.get_concrete_function(input_data["x"])

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertEqual(2, get_num_variables(variable_graph_def))

        # The tuple is built (all three assign ops created) before the loop
        # registers them, in this order, in VAR_ASSIGN_COLLECTION.
        graph = ops.get_default_graph()
        for assign_op in (root.v1.assign(1.5), root.v2.assign(3.0),
                          root.v1.assign(4.0)):
          graph.add_to_collection(convert_to_constants.VAR_ASSIGN_COLLECTION,
                                  assign_op)

        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
            input_func)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testConstSavedModel(self):
    """Test a basic model with constants while saving/loading the SavedModel."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = autotrackable.AutoTrackable()
        root.f = def_function.function(lambda x: 2. * x)
        to_save = root.f.get_concrete_function(input_data["x"])

        save_dir = os.path.join(self.get_temp_dir(), "saved_model")
        save(root, save_dir, to_save)
        saved_model = load(save_dir)
        input_func = saved_model.signatures["serving_default"]

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertEqual(0, get_num_variables(variable_graph_def))
        self.assertTrue(variable_graph_def.library.function)

        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
            input_func)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testVariableSavedModel(self):
    """Test a basic model with Variables with saving/loading the SavedModel."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = autotrackable.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        to_save = root.f.get_concrete_function(input_data["x"])
        sess.run(variables.global_variables_initializer())

        save_dir = os.path.join(self.get_temp_dir(), "saved_model")
        save(root, save_dir, to_save)
        saved_model = load(save_dir)
        input_func = saved_model.signatures["serving_default"]

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertTrue(has_stateful_partitioned_call_op(variable_graph_def))

        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
            input_func)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testMultiFunctionModel(self):
    """Test a basic model with multiple tf.functions."""

    class BasicModel(autotrackable.AutoTrackable):

      def __init__(self):
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = BasicModel()
        input_func = root.add.get_concrete_function(input_data["x"])

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertEqual(1, get_num_variables(variable_graph_def))

        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
            input_func)
        self._testConvertedFunction(sess, root, root.add, output_func,
                                    input_data)

  def testIf(self):
    """Test a model with the If op."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {
            "x": constant_op.constant([1., 2.], shape=[1, 2]),
            "b": constant_op.constant(True)
        }

        weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]],
                                     dtype=dtypes.float32)

        def true_fn(x):
          return math_ops.matmul(x, weights)

        def false_fn(x):
          return math_ops.add(x, weights)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[1, 2], dtype=dtypes.float32),
            tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)
        ])
        def model(x, b):
          return control_flow_ops.cond(
              b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testStatelessIf(self):
    """Test a model with the StatelessIf op."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"b": constant_op.constant(True)}

        x = constant_op.constant([1., 2.], shape=[1, 2], name="x")

        def true_fn():
          return x

        def false_fn():
          return x + 2

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)
        ])
        def model(b):
          return cond_v2.cond_v2(b, true_fn, false_fn)

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testStaticRnn(self):
    """Test a StaticRnn containing If ops."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {
            "x":
                constant_op.constant(
                    np.array(
                        np.random.random_sample((3, 10)), dtype=np.float32))
        }

        cell = rnn_cell_impl.LSTMCell(10)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[3, 10], dtype=dtypes.float32)
        ])
        def model(x):
          seq = array_ops.split(x, 3, 0)
          return rnn.static_rnn(
              cell, seq, dtype=dtypes.float32, sequence_length=[1])

        root, output_func = self._freezeModel(model)

        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testWhile(self):
    """Test a While loop."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant([1., 2., 3., 4.], shape=[2, 2])}

        weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]],
                                     dtype=dtypes.float32)

        def condition(x):
          return math_ops.reduce_sum(x) < 100

        def body(x):
          return math_ops.add(x, weights)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[2, 2], dtype=dtypes.float32)
        ])
        def model(x):
          return control_flow_ops.while_loop(condition, body, [x])

        root, output_func = self._freezeModel(model)

        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testStatelessWhile(self):
    """Test a StatelessWhile loop."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(2.)}

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
        ])
        def model(x):
          return while_v2.while_loop(
              lambda v: v < 4.,
              lambda v: v * v, [x],
              return_same_structure=False,
              name="while_1")  # x**2

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testDynamicRnn(self):
    """Test a DynamicRnn containing While loops."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {
            "x":
                constant_op.constant(
                    np.array(
                        np.random.random_sample((3, 10, 10)), dtype=np.float32))
        }

        cell = rnn_cell_impl.LSTMCell(10)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[3, 10, 10], dtype=dtypes.float32)
        ])
        def model(x):
          return rnn.dynamic_rnn(cell, x, dtype=dtypes.float32)

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  @test_util.disable_tfrt("b/180451239")
  def testSwitchCase(self):
    """Test a switch_case statement."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {
            "i":
                constant_op.constant(np.random.randint(0, 3, dtype=np.int32)),
            "x":
                constant_op.constant(
                    np.asarray(
                        np.random.random_sample((10, 3)), dtype=np.float32)),
        }

        w0 = variables.Variable(
            np.random.random_sample((3, 4)), dtype=np.float32)
        w1 = variables.Variable(
            np.random.random_sample((3, 4)), dtype=np.float32)
        w2 = variables.Variable(np.random.random_sample((4,)), dtype=np.float32)

        def branch0(x):
          return math_ops.matmul(x, w0)

        def branch1(x):
          return math_ops.matmul(x, w1)

        def branch2(x):
          # Pad (10, 3) -> (10, 4) so the result broadcasts against w2's (4,).
          x = array_ops.pad(x, [[0, 0], [0, 1]])
          return x + w2

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
            tensor_spec.TensorSpec(shape=[10, 3], dtype=dtypes.float32),
        ])
        def model(i, x):
          return control_flow_ops.switch_case(
              i, [lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)


class ConvertVariablesToConstantsSessionTest(test.TestCase):

  def _assertGraphContains(self, graph, subgraph):
    """Asserts that the given subgraph is contained within the given graph.

    Args:
      graph: A `GraphDef`, or its text-format string.
      subgraph: A `GraphDef`, or its text-format string.
    """

    def normalize_uids(msg):
      """Replace auto-id function names with something consistent.

      Args:
        msg: A `GraphDef` (converted to text format first) or a text-format
          string.

      Returns:
        The text-format string with every `case_cond_true*_<N>` /
        `case_cond_false*_<N>` suffix renumbered from 0 per distinct prefix,
        in ascending order of the original N.
      """
      # These functions have non-deterministic names, the non-determinism coming
      # from having an ops.uid() suffix in their names. We're replacing these
      # with new sequential IDs starting from 0 for each prefix, which is
      # sufficient for tests.
      if isinstance(msg, graph_pb2.GraphDef):
        msg = text_format.MessageToString(msg)
      name_prefixes = ["case_cond_true.*", "case_cond_false.*"]
      name_regex = r"\b(" + "|".join(name_prefixes) + r")_([0-9]+)\b"
      names = {}
      for (name, index) in re.findall(name_regex, msg):
        names.setdefault(name, set()).add(int(index))
      for name, indices in names.items():
        # Ascending order guarantees new_index <= old_index, so a rename never
        # collides with a suffix that has not been renumbered yet.
        for new_index, old_index in enumerate(sorted(indices)):
          new_name = name + "_" + str(new_index)
          # `name` is captured text, not a pattern: escape it when matching,
          # and insert the replacement literally (no backslash expansion).
          msg = re.sub(r"\b" + re.escape(name) + "_" + str(old_index) + r"\b",
                       lambda _, new_name=new_name: new_name, msg)
      return msg

    norm_graph = text_format.Parse(normalize_uids(graph), graph_pb2.GraphDef())
    norm_subgraph = text_format.Parse(
        normalize_uids(subgraph), graph_pb2.GraphDef())

    # Graph S is contained in C if and only if merge(C,S) == C.
    # We merge the input graph with an empty graph to normalize repeated fields:
    # assertProtoEquals is sensitive to ordering.
    norm_graph = _GraphMerger.merge_graphs(norm_graph, graph_pb2.GraphDef())
    merged_graph = _GraphMerger.merge_graphs(norm_graph, norm_subgraph)
    self.assertProtoEquals(norm_graph, merged_graph)

  def _ensure_no_variables_in_graph(self, graph_def):
    """Ensures there are no variables in the graph.

    Only the top-level `graph_def.node` list is checked, against the op types
    Variable, VariableV2, VarHandleOp and ReadVariableOp.
    """
    for node in graph_def.node:
      self.assertNotIn(
          node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])

  def _test_variable_to_const_conversion(self, use_resource):
    """Freezes `output_node = variable_node * 2.0` and re-runs the result.

    Args:
      use_resource: Forwarded to `variable_scope`, selecting resource or
        reference variables.
    """
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=use_resource):
        variable_node = variable_scope.get_variable(
            "variable_node", initializer=1.0)
        variable_scope.get_variable("unused_variable_node", initializer=1.0)
        output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
        with session_lib.Session() as sess:
          self.evaluate(variable_node.initializer)
          output = self.evaluate(output_node)
          self.assertNear(2.0, output, 0.00001)
          variable_graph_def = sess.graph.as_graph_def()
          constant_graph_def = (
              convert_to_constants
              .convert_variables_to_constants_from_session_graph(
                  session=sess,
                  graph_def=variable_graph_def,
                  output_node_names=["output_node"]))

          self._ensure_no_variables_in_graph(constant_graph_def)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      _ = importer.import_graph_def(constant_graph_def, name="")
      self.assertEqual(4, len(constant_graph_def.node))
      self._ensure_no_variables_in_graph(constant_graph_def)
      with session_lib.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = self.evaluate(output_node)
        self.assertNear(2.0, output, 0.00001)

  def test_resource_variable_can_be_written_after_denylisting(self):
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        variable_node = variable_scope.get_variable(
            "variable_node", initializer=1.0)
        another_variable = variable_scope.get_variable(
            "unused_variable_node", initializer=2.0)
        with ops.control_dependencies(
            [variable_node.assign(another_variable + variable_node)]):
          output_node = array_ops.identity(variable_node, name="output_node")
        initializer_name = variable_node.initializer.name
        with session_lib.Session() as sess:
          self.evaluate(variable_node.initializer)
          self.evaluate(another_variable.initializer)
          output = self.evaluate(output_node)
          self.assertNear(3.0, output, 0.00001)
          variable_graph_def = sess.graph.as_graph_def()

          # Test variable name black list. This should result in the variable
          # not being a const. Furthermore, the paths that read from and assign
          # to the denylisted variable should continue to be valid.
1002 constant_graph_def_with_denylist = ( 1003 convert_to_constants 1004 .convert_variables_to_constants_from_session_graph( 1005 session=sess, 1006 graph_def=variable_graph_def, 1007 output_node_names=["output_node", initializer_name], 1008 variable_names_denylist=set(["variable_node"]))) 1009 1010 variable_node = None 1011 for node in constant_graph_def_with_denylist.node: 1012 if node.name == "variable_node": 1013 variable_node = node 1014 self.assertIsNotNone(variable_node) 1015 self.assertEqual(variable_node.op, "VarHandleOp") 1016 1017 # Now we make sure another_variable is now a constant, but the original 1018 # variable is not, and that the graph can be executed and update the 1019 # variable can be updated with each execution. 1020 with ops.Graph().as_default(): 1021 _ = importer.import_graph_def(constant_graph_def_with_denylist, name="") 1022 with session_lib.Session() as sess: 1023 output_node = sess.graph.get_tensor_by_name("output_node:0") 1024 self.evaluate(sess.graph.get_operation_by_name(initializer_name)) 1025 output = self.evaluate(output_node) 1026 self.assertNear(3.0, output, 0.00001) 1027 output = self.evaluate(output_node) 1028 self.assertNear(5.0, output, 0.00001) 1029 1030 def _inline_functions(self, graph_def, arrays): 1031 meta_graph = export_meta_graph(graph_def=graph_def) 1032 fetch_collection = meta_graph_pb2.CollectionDef() 1033 for name in arrays: 1034 fetch_collection.node_list.value.append(name) 1035 meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) 1036 1037 # Initialize RewriterConfig with everything disabled except function 1038 # inlining. 
1039 config = config_pb2.ConfigProto() 1040 rewrite_options = config.graph_options.rewrite_options 1041 rewrite_options.optimizers.append("function") 1042 return tf_optimizer.OptimizeGraph(config, meta_graph) 1043 1044 def _test_convert_variables_with_functions(self, inline_functions): 1045 """Freezes a graph with functions.""" 1046 1047 @function.Defun(dtypes.float32) 1048 def plus_one(x): 1049 return x + 1.0 1050 1051 with ops.Graph().as_default(): 1052 variable_node = variables.Variable(1.0, name="variable_node") 1053 _ = variables.Variable(1.0, name="unused_variable_node") 1054 defun_node = plus_one(variable_node) 1055 _ = math_ops.multiply(defun_node, 2.0, name="output_node") 1056 1057 with session_lib.Session() as sess: 1058 self.evaluate(variables.variables_initializer([variable_node])) 1059 variable_graph_def = sess.graph.as_graph_def() 1060 1061 if inline_functions: 1062 # Run Grappler to create the VarOpHandle --> Placeholder --> 1063 # ResourceVariable pattern. 1064 variable_graph_def = self._inline_functions( 1065 variable_graph_def, ["variable_node", "output_node"]) 1066 1067 constant_graph_def = ( 1068 convert_to_constants 1069 .convert_variables_to_constants_from_session_graph( 1070 session=sess, 1071 graph_def=variable_graph_def, 1072 output_node_names=["output_node"])) 1073 1074 self._ensure_no_variables_in_graph(constant_graph_def) 1075 1076 def testReferenceVariables(self): 1077 """Freezes a graph with reference variables.""" 1078 self._test_variable_to_const_conversion(use_resource=False) 1079 1080 def testResourceVariables(self): 1081 """Freezes a graph with resource variables.""" 1082 self._test_variable_to_const_conversion(use_resource=True) 1083 1084 def testWithFunctions(self): 1085 """Freezes a graph with functions.""" 1086 self._test_convert_variables_with_functions(inline_functions=False) 1087 1088 def testWithInlinedFunctions(self): 1089 """Freezes a graph with functions that have been inlined using Grappler.""" 1090 
self._test_convert_variables_with_functions(inline_functions=True) 1091 1092 def testGraphWithSwitch(self): 1093 """Freezes a graph which contains a Switch with type RESOURCE_DT.""" 1094 with ops.Graph().as_default(): 1095 with variable_scope.variable_scope("", use_resource=True): 1096 x = variable_scope.get_variable("var_x", initializer=1.0) 1097 y = variable_scope.get_variable("var_y", initializer=2.0) 1098 f1 = lambda: variable_scope.get_variable("var_f1", initializer=17.0) 1099 f2 = lambda: variable_scope.get_variable("var_f2", initializer=23.0) 1100 cond_node = control_flow_ops.case([(gen_math_ops.less(x, y), f1)], 1101 default=f2) 1102 _ = math_ops.multiply(cond_node, 2.0, name="output_node") 1103 1104 with session_lib.Session() as sess: 1105 sess.run(variables.global_variables_initializer()) 1106 variable_graph_def = sess.graph.as_graph_def() 1107 1108 constant_graph_def = ( 1109 convert_to_constants 1110 .convert_variables_to_constants_from_session_graph( 1111 session=sess, 1112 graph_def=variable_graph_def, 1113 output_node_names=["output_node"])) 1114 1115 self._ensure_no_variables_in_graph(constant_graph_def) 1116 1117 def testConvertSingleVariable(self): 1118 """Tests that a single variable is properly converted to a constant.""" 1119 1120 with ops.Graph().as_default(): 1121 with variable_scope.variable_scope("", use_resource=False): 1122 _ = variable_scope.get_variable("x", initializer=1.0) 1123 with session_lib.Session() as sess: 1124 sess.run(variables.global_variables_initializer()) 1125 variable_graph_def = sess.graph.as_graph_def() 1126 constant_graph_def = ( 1127 convert_to_constants 1128 .convert_variables_to_constants_from_session_graph( 1129 sess, variable_graph_def, ["x/read"])) 1130 self._assertGraphContains( 1131 constant_graph_def, """ 1132 node { 1133 name: "x" op: "Const" 1134 attr { key: "dtype" value { type: DT_FLOAT } } 1135 attr { 1136 key: "value" 1137 value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}} 1138 } 1139 node 
{ 1140 name: "x/read" op: "Identity" input: "x" 1141 attr { key: "T" value { type: DT_FLOAT } } 1142 }""") 1143 1144 def testConvertSingleResourceVariable(self): 1145 """Tests that a resource variable is properly converted to a constant.""" 1146 with ops.Graph().as_default(): 1147 with variable_scope.variable_scope("", use_resource=True): 1148 _ = variable_scope.get_variable("x", initializer=1.0) 1149 with session_lib.Session() as sess: 1150 sess.run(variables.global_variables_initializer()) 1151 variable_graph_def = sess.graph.as_graph_def() 1152 constant_graph_def = ( 1153 convert_to_constants 1154 .convert_variables_to_constants_from_session_graph( 1155 sess, variable_graph_def, ["x/Read/ReadVariableOp"])) 1156 self._assertGraphContains( 1157 constant_graph_def, """ 1158 node { 1159 name: "x" op: "Const" 1160 attr { key: "dtype" value { type: DT_FLOAT } } 1161 attr { 1162 key: "value" 1163 value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}} 1164 } 1165 node { 1166 name: "x/Read/ReadVariableOp" op: "Identity" input: "x" 1167 attr { key: "T" value { type: DT_FLOAT } } 1168 }""") 1169 1170 def testConvertOneVariableOfTwo(self): 1171 """Tests that one variable can be kept unconverted.""" 1172 with ops.Graph().as_default(): 1173 with variable_scope.variable_scope("", use_resource=False): 1174 x = variable_scope.get_variable("x", initializer=1.0) 1175 y = variable_scope.get_variable("y", initializer=1.0) 1176 _ = math_ops.multiply(x, y, name="out") 1177 with session_lib.Session() as sess: 1178 sess.run(variables.global_variables_initializer()) 1179 variable_graph_def = sess.graph.as_graph_def() 1180 constant_graph_def = ( 1181 convert_to_constants 1182 .convert_variables_to_constants_from_session_graph( 1183 sess, 1184 variable_graph_def, ["out"], 1185 variable_names_denylist=["y"])) 1186 self._assertGraphContains( 1187 constant_graph_def, """ 1188 node { 1189 name: "x" op: "Const" 1190 attr { key: "dtype" value { type: DT_FLOAT } } 1191 attr { 1192 key: 
"value" 1193 value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}} 1194 } 1195 node { 1196 name: "x/read" op: "Identity" input: "x" 1197 attr { key: "T" value { type: DT_FLOAT } } 1198 } 1199 node { 1200 name: "y" op: "VariableV2" 1201 attr { key: "dtype" value { type: DT_FLOAT } } 1202 } 1203 node { 1204 name: "y/read" op: "Identity" input: "y" 1205 attr { key: "T" value { type: DT_FLOAT } } 1206 } 1207 node { 1208 name: "out" op: "Mul" input: "x/read" input: "y/read" 1209 attr {key: "T" value {type: DT_FLOAT}} 1210 }""") 1211 1212 def testConvertOneResourceVariableOfTwo(self): 1213 """Tests that one variable can be kept unconverted.""" 1214 with ops.Graph().as_default(): 1215 with variable_scope.variable_scope("", use_resource=True): 1216 x = variable_scope.get_variable("x", initializer=1.0) 1217 y = variable_scope.get_variable("y", initializer=1.0) 1218 _ = math_ops.multiply(x, y, name="out") 1219 with session_lib.Session() as sess: 1220 sess.run(variables.global_variables_initializer()) 1221 variable_graph_def = sess.graph.as_graph_def() 1222 constant_graph_def = ( 1223 convert_to_constants 1224 .convert_variables_to_constants_from_session_graph( 1225 sess, 1226 variable_graph_def, ["out"], 1227 variable_names_denylist=["y"])) 1228 self._assertGraphContains( 1229 constant_graph_def, """ 1230 node { 1231 name: "x" op: "Const" 1232 attr { key: "dtype" value { type: DT_FLOAT } } 1233 attr { 1234 key: "value" 1235 value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}} 1236 } 1237 node { 1238 name: "y" op: "VarHandleOp" 1239 attr { key: "dtype" value { type: DT_FLOAT } } 1240 } 1241 node { 1242 name: "out/ReadVariableOp" op: "Identity" input: "x" 1243 attr { key: "T" value { type: DT_FLOAT } } 1244 } 1245 node { 1246 name: "out/ReadVariableOp_1" op: "ReadVariableOp" input: "y" 1247 attr { key: "dtype" value { type: DT_FLOAT } } 1248 } 1249 node { 1250 name: "out" op: "Mul" 1251 input: "out/ReadVariableOp" input: "out/ReadVariableOp_1" 1252 attr 
{key: "T" value {type: DT_FLOAT}} 1253 }""") 1254 1255 def testConvertIdentityChain(self): 1256 """Tests that a chain of Identity ops is converted properly.""" 1257 with ops.Graph().as_default(): 1258 with variable_scope.variable_scope("", use_resource=True): 1259 x = variable_scope.get_variable("x", initializer=1.0) 1260 y = array_ops.identity(x, name="y") 1261 _ = array_ops.identity(y, name="z") 1262 with session_lib.Session() as sess: 1263 sess.run(variables.global_variables_initializer()) 1264 variable_graph_def = sess.graph.as_graph_def() 1265 constant_graph_def = ( 1266 convert_to_constants 1267 .convert_variables_to_constants_from_session_graph( 1268 sess, variable_graph_def, ["z"])) 1269 self._assertGraphContains( 1270 constant_graph_def, """ 1271 node { 1272 name: "x" op: "Const" 1273 attr { key: "dtype" value { type: DT_FLOAT } } 1274 attr { 1275 key: "value" 1276 value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}} 1277 } 1278 node { 1279 name: "y/ReadVariableOp" op: "Identity" input: "x" 1280 attr { key: "T" value { type: DT_FLOAT } } 1281 } 1282 node { 1283 name: "y" op: "Identity" input: "y/ReadVariableOp" 1284 attr { key: "T" value { type: DT_FLOAT } } 1285 } 1286 node { 1287 name: "z" op: "Identity" input: "y" 1288 attr { key: "T" value { type: DT_FLOAT } } 1289 }""") 1290 1291 def testConvertCase(self): 1292 """Tests that a v1 case() construction converts properly.""" 1293 with ops.Graph().as_default(): 1294 with variable_scope.variable_scope("", use_resource=False): 1295 control_flow_v2_toggles.disable_control_flow_v2() 1296 x = variable_scope.get_variable("x", initializer=1.0) 1297 y = variable_scope.get_variable("y", initializer=2.0) 1298 _ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: x)], 1299 default=lambda: y) 1300 with session_lib.Session() as sess: 1301 sess.run(variables.global_variables_initializer()) 1302 variable_graph_def = sess.graph.as_graph_def() 1303 constant_graph_def = ( 1304 convert_to_constants 1305 
.convert_variables_to_constants_from_session_graph( 1306 sess, variable_graph_def, ["case/cond/Merge"])) 1307 self._assertGraphContains( 1308 constant_graph_def, """ 1309 node { 1310 name: "x" op: "Const" 1311 attr { key: "dtype" value { type: DT_FLOAT } } 1312 attr { 1313 key: "value" 1314 value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}} 1315 } 1316 node { 1317 name: "y" op: "Const" 1318 attr { key: "dtype" value { type: DT_FLOAT } } 1319 attr { 1320 key: "value" 1321 value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}} 1322 } 1323 node {name: "x/read" op: "Identity" input: "x"} 1324 node {name: "y/read" op: "Identity" input: "y"} 1325 node {name: "Less" op: "Less" input: "x/read" input: "y/read"} 1326 node {name: "case/cond/pred_id" op: "Identity" input: "Less"} 1327 node { 1328 name: "case/cond/Switch_1" op: "Switch" 1329 input: "case/cond/pred_id" input: "x/read" 1330 } 1331 node { 1332 name: "case/cond/Switch_2" op: "Switch" 1333 input: "case/cond/pred_id" input: "y/read" 1334 } 1335 node { 1336 name: "case/cond/Merge" op: "Merge" 1337 input: "case/cond/Switch_2" input: "case/cond/Switch_1:1" 1338 attr {key: "T" value {type: DT_FLOAT}} 1339 }""") 1340 1341 def testConvertV2Case(self): 1342 """Tests that a v2 case() converts properly.""" 1343 with ops.Graph().as_default(): 1344 with variable_scope.variable_scope("", use_resource=False): 1345 control_flow_v2_toggles.enable_control_flow_v2() 1346 a = variable_scope.get_variable("a", initializer=2.0) 1347 x = variable_scope.get_variable("x", initializer=1.0) 1348 y = variable_scope.get_variable("y", initializer=2.0) 1349 _ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: a)], 1350 default=lambda: y) 1351 control_flow_v2_toggles.disable_control_flow_v2() 1352 with session_lib.Session() as sess: 1353 sess.run(variables.global_variables_initializer()) 1354 variable_graph_def = sess.graph.as_graph_def() 1355 constant_graph_def = ( 1356 convert_to_constants 1357 
.convert_variables_to_constants_from_session_graph( 1358 sess, variable_graph_def, ["case/cond"])) 1359 self._assertGraphContains( 1360 constant_graph_def, """ 1361 node { 1362 name: "x" op: "Const" 1363 attr { key: "dtype" value { type: DT_FLOAT } } 1364 attr { 1365 key: "value" 1366 value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}} 1367 } 1368 node { 1369 name: "y" op: "Const" 1370 attr { key: "dtype" value { type: DT_FLOAT } } 1371 attr { 1372 key: "value" 1373 value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}} 1374 } 1375 node {name: "x/read" op: "Identity" input: "x"} 1376 node {name: "y/read" op: "Identity" input: "y"} 1377 node {name: "Less" op: "Less" input: "x/read" input: "y/read"} 1378 node { 1379 name: "case/cond" op: "StatelessIf" 1380 input: "Less" input: "a/read" input: "y/read" 1381 attr {key: "Tcond" value {type: DT_BOOL}} 1382 attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}} 1383 attr {key: "Tout" value {list {type: DT_FLOAT}}} 1384 } 1385 library { 1386 function { 1387 signature { 1388 name: "case_cond_false_frozen_0" 1389 input_arg {name: "placeholder" type: DT_FLOAT} 1390 input_arg {name: "y_read_0" type: DT_FLOAT} 1391 output_arg {name: "y_read" type: DT_FLOAT} 1392 } 1393 } 1394 function { 1395 signature { 1396 name: "case_cond_true_frozen_0" 1397 input_arg {name: "a_read_0" type: DT_FLOAT} 1398 input_arg {name: "placeholder" type: DT_FLOAT} 1399 output_arg {name: "a_read" type: DT_FLOAT} 1400 } 1401 } 1402 }""") 1403 1404 def testConvertV2ResourceCase(self): 1405 """Tests that a v2 case() with resource variables converts properly.""" 1406 with ops.Graph().as_default(): 1407 with variable_scope.variable_scope("", use_resource=True): 1408 control_flow_v2_toggles.enable_control_flow_v2() 1409 x = variable_scope.get_variable("x", initializer=1.0) 1410 y = variable_scope.get_variable("y", initializer=2.0) 1411 _ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: x)], 1412 default=lambda: y) 1413 
control_flow_v2_toggles.disable_control_flow_v2() 1414 with session_lib.Session() as sess: 1415 sess.run(variables.global_variables_initializer()) 1416 variable_graph_def = sess.graph.as_graph_def() 1417 constant_graph_def = ( 1418 convert_to_constants 1419 .convert_variables_to_constants_from_session_graph( 1420 sess, variable_graph_def, ["case/cond"])) 1421 self._assertGraphContains( 1422 constant_graph_def, """ 1423 node {name: "x" op: "Const"} 1424 node {name: "y" op: "Const"} 1425 node { 1426 name: "case/cond" op: "If" input: "Less" input: "x" input: "y" 1427 attr {key: "Tcond" value {type: DT_BOOL}} 1428 attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}} 1429 attr {key: "Tout" value {list {type: DT_FLOAT}}} 1430 } 1431 library { 1432 function { 1433 signature { 1434 name: "case_cond_false_frozen_0" 1435 input_arg {name: "placeholder" type: DT_FLOAT} 1436 input_arg {name: "readvariableop_y" type: DT_FLOAT} 1437 output_arg {name: "readvariableop" type: DT_FLOAT} 1438 } 1439 } 1440 function { 1441 signature { 1442 name: "case_cond_true_frozen_0" 1443 input_arg {name: "placeholder" type: DT_FLOAT} 1444 input_arg {name: "readvariableop_x" type: DT_FLOAT} 1445 output_arg {name: "readvariableop" type: DT_FLOAT} 1446 } 1447 } 1448 }""") 1449 1450 def testConvertV2UnconvertedResourceNestedCase(self): 1451 """Tests unconverted variable propagation through nested functions.""" 1452 with ops.Graph().as_default(): 1453 with variable_scope.variable_scope("", use_resource=True): 1454 control_flow_v2_toggles.enable_control_flow_v2() 1455 x = variable_scope.get_variable("x", initializer=1.0) 1456 y = variable_scope.get_variable("y", initializer=2.0) 1457 z = variable_scope.get_variable("z", initializer=3.0) 1458 # pylint: disable=g-long-lambda 1459 _ = control_flow_ops.case( 1460 [(gen_math_ops.less(x, y), lambda: x)], 1461 default=lambda: control_flow_ops.case( 1462 [(gen_math_ops.less(z, y), lambda: z)], default=lambda: y)) 1463 # pylint: enable=g-long-lambda 
1464 control_flow_v2_toggles.disable_control_flow_v2() 1465 with session_lib.Session() as sess: 1466 sess.run(variables.global_variables_initializer()) 1467 variable_graph_def = sess.graph.as_graph_def() 1468 constant_graph_def = ( 1469 convert_to_constants 1470 .convert_variables_to_constants_from_session_graph( 1471 sess, 1472 variable_graph_def, ["case/cond"], 1473 variable_names_denylist=["y"])) 1474 self._assertGraphContains( 1475 constant_graph_def, """ 1476 node {name: "x" op: "Const"} 1477 node {name: "y" op: "VarHandleOp"} 1478 node {name: "z" op: "Const"} 1479 1480 node {name: "Less/ReadVariableOp" op: "Identity" input: "x"} 1481 node {name: "Less/ReadVariableOp_1" op: "ReadVariableOp" input: "y"} 1482 1483 node { 1484 name: "case/cond" op: "If" 1485 input: "x" input: "z" input: "y" 1486 attr { 1487 key: "Tin" 1488 value {list 1489 {type: DT_FLOAT type: DT_FLOAT type: DT_RESOURCE}}} 1490 attr { 1491 key: "_read_only_resource_inputs" 1492 value {list {i: 1 i: 2 i: 3}}} 1493 attr {key: "then_branch" 1494 value {func {name: "case_cond_true_frozen_0"}}} 1495 attr {key: "else_branch" 1496 value {func {name: "case_cond_false_frozen_0"}}} 1497 attr {key: "output_shapes" value {list {shape {}}}} 1498 } 1499 library { 1500 function { 1501 signature { 1502 name: "case_cond_true_frozen_0" 1503 input_arg {name: "placeholder" type: DT_FLOAT} 1504 input_arg {name: "placeholder_1" type: DT_RESOURCE} 1505 input_arg {name: "readvariableop_x" type: DT_FLOAT} 1506 output_arg {name: "readvariableop" type: DT_FLOAT} 1507 is_stateful: true 1508 } 1509 1510 node_def {name: "ReadVariableOp" op: "Identity" 1511 input: "readvariableop_x"}} 1512 1513 function { 1514 signature { 1515 name: "case_cond_false_frozen_0" 1516 input_arg {name: "placeholder" type: DT_FLOAT} 1517 input_arg {name: "less_readvariableop_1_y" type: DT_RESOURCE} 1518 input_arg {name: "less_readvariableop_z" type: DT_FLOAT} 1519 output_arg {name: "case_cond_identity" type: DT_FLOAT} 1520 is_stateful: true 1521 } 
1522 1523 node_def {name: "Less/ReadVariableOp_1" op: "ReadVariableOp" 1524 input: "less_readvariableop_1_y"} 1525 1526 node_def {name: "Less/ReadVariableOp" op: "Identity" 1527 input: "less_readvariableop_z"} 1528 1529 node_def {name: "case/cond" op: "If" 1530 input: "less_readvariableop_z" 1531 input: "less_readvariableop_1_y" 1532 attr { 1533 key: "Tin" 1534 value {list {type: DT_FLOAT type: DT_RESOURCE}}} 1535 attr {key: "then_branch" 1536 value {func {name: "case_cond_true_frozen_1"}}} 1537 attr {key: "else_branch" 1538 value {func {name: "case_cond_false_frozen_1"}}} 1539 attr { 1540 key: "_read_only_resource_inputs" 1541 value {list {i: 1 i: 2}}}}} 1542 1543 function { 1544 signature { 1545 name: "case_cond_false_frozen_1" 1546 input_arg {name: "placeholder" type: DT_FLOAT} 1547 input_arg {name: "readvariableop_y" type: DT_RESOURCE} 1548 output_arg {name: "readvariableop" type: DT_FLOAT} 1549 is_stateful: true 1550 } 1551 1552 node_def {name: "ReadVariableOp" op: "ReadVariableOp" 1553 input: "readvariableop_y"}} 1554 1555 function { 1556 signature { 1557 name: "case_cond_true_frozen_1" 1558 input_arg {name: "placeholder" type: DT_RESOURCE} 1559 input_arg {name: "readvariableop_z" type: DT_FLOAT} 1560 output_arg {name: "readvariableop" type: DT_FLOAT} 1561 is_stateful: true 1562 } 1563 1564 node_def {name: "ReadVariableOp" op: "Identity" 1565 input: "readvariableop_z"}}}""") 1566 1567 def _addNoinlineAttributeToFunction(self, saved_model_dir, func_name): 1568 saved_model_proto = loader_impl.parse_saved_model(saved_model_dir) 1569 new_saved_model = saved_model_pb2.SavedModel() 1570 new_saved_model.CopyFrom(saved_model_proto) 1571 new_meta_graph_def = new_saved_model.meta_graphs[0] 1572 prefix_len = len("__inference_") 1573 for func_def in new_meta_graph_def.graph_def.library.function: 1574 func_name_without_prefix = func_def.signature.name[prefix_len:] 1575 if func_name_without_prefix.startswith(func_name): 1576 
func_def.attr["_noinline"].CopyFrom(attr_value_pb2.AttrValue(b=True)) 1577 old_saved_model_file = os.path.join(saved_model_dir, 1578 constants.SAVED_MODEL_FILENAME_PB) 1579 if os.path.exists(old_saved_model_file): 1580 os.remove(old_saved_model_file) 1581 path = os.path.join( 1582 compat.as_bytes(saved_model_dir), 1583 compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) 1584 file_io.write_string_to_file( 1585 path, new_saved_model.SerializeToString(deterministic=True)) 1586 1587 @test_util.run_v2_only 1588 def testVariableModelWithFunctionAndFunctionInliningDisabled(self): 1589 """Test a model with Variables and disable function inlining.""" 1590 1591 class BasicModel: 1592 1593 def __init__(self): 1594 self.v1 = None 1595 self.v2 = variables.Variable(2.) 1596 1597 @def_function.function(input_signature=[ 1598 tensor_spec.TensorSpec(shape=[1], dtype=dtypes.float32) 1599 ]) 1600 def add_all(self, x): 1601 if self.v1 is None: 1602 self.v1 = variables.Variable(3.) 1603 return x + self.v1 + self.v2 1604 1605 def run(self, x): 1606 y = self.add_all(x) 1607 return y 1608 1609 save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model") 1610 with ops.Graph().as_default(): 1611 model = BasicModel() 1612 a = array_ops.placeholder(dtypes.float32, shape=[1]) 1613 b = model.run(a) 1614 with session_lib.Session() as sess: 1615 sess.run(variables.global_variables_initializer()) 1616 simple_save.simple_save(sess, save_dir, {"myinput": a}, {"myoutput": b}) 1617 1618 # Add _noinline to the SavedModel. 1619 self._addNoinlineAttributeToFunction( 1620 saved_model_dir=save_dir, func_name="add_all") 1621 1622 saved_model = load(save_dir) 1623 func = saved_model.signatures["serving_default"] 1624 frozen_func = convert_to_constants.convert_variables_to_constants_v2(func) 1625 constant_graph_def = frozen_func.graph.as_graph_def() 1626 self._ensure_no_variables_in_graph(constant_graph_def) 1627 1628 1629if __name__ == "__main__": 1630 test.main() 1631