# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides the `append_slices` and `merge_appended_slices` operations.

This wraps the generated ops and ensures that the necessary shared libraries
are loaded.
"""

import tensorflow as tf

from fcp.tensorflow import gen_append_slices_py

_append_slices_so = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile('./_append_slices_op.so'))


def append_slices(filename, tensor_names, shapes_and_slices, data, name=None):
  """Appends slices to `filename`.

  Must be paired with `merge_appended_slices`.

  This op is identical to `tf.raw_ops.SaveSlices`, except that it appends the
  resulting checkpoint to `filename` rather than erasing the contents of
  `filename`.

  Note: the resulting file at `filename` will not be in checkpoint format until
  `merge_appended_slices` has been called.

  Args:
    filename: A `Tensor` of type `string`. Must have a single element. The name
      of the file to which the tensors should be appended.
    tensor_names: A `Tensor` of type `string`. Shape `[N]`. The names of the
      tensors to be saved.
    shapes_and_slices: A `Tensor` of type `string`. Shape `[N]`. The shapes and
      slice specifications to use when saving the tensors.
    data: A list of `Tensor` objects. `N` tensors to save.
    name: A name for the operation (optional).

  Returns:
    The created `Operation`.
  """
  return gen_append_slices_py.append_slices(
      filename, tensor_names, shapes_and_slices, data, name=name)


def merge_appended_slices(filename, name=None):
  """Merges the file appended to by `append_slices` into a single checkpoint.

  The immediate file output of `append_slices` is not in checkpoint format. It
  must be converted to a checkpoint by calling `merge_appended_slices`.

  Note: Users must use `tf.control_dependencies` or another mechanism to ensure
  that all `append_slices` calls have executed before `merge_appended_slices`
  executes.

  Args:
    filename: The name of a file appended to by calls to `append_slices`.
    name: A name for the operation (optional).

  Returns:
    The created `Operation`.
  """
  return gen_append_slices_py.merge_appended_slices(filename, name=name)