# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides the `append_slices` and `merge_appended_slices` operations.

This module wraps the generated ops and ensures that the necessary shared
library is loaded.
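
Typical usage example (a minimal sketch, assuming TF1-style graph execution;
the output path and tensor names are hypothetical):

  import tensorflow as tf
  from fcp.tensorflow import append_slices

  with tf.Graph().as_default():
    path = '/tmp/appended_slices'  # Hypothetical output path.
    append_a = append_slices.append_slices(
        path, ['a'], [''], [tf.constant([1.0, 2.0])])
    # The appends are chained explicitly rather than assuming that concurrent
    # appends to the same file are safe.
    with tf.control_dependencies([append_a]):
      append_b = append_slices.append_slices(
          path, ['b'], [''], [tf.constant(3)])
    # The merge must run only after every append has finished.
    with tf.control_dependencies([append_b]):
      merge = append_slices.merge_appended_slices(path)
    with tf.compat.v1.Session() as sess:
      sess.run(merge)  # Runs both appends in order, then the merge.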
"""

import tensorflow as tf

from fcp.tensorflow import gen_append_slices_py

_append_slices_so = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile('./_append_slices_op.so'))


def append_slices(filename, tensor_names, shapes_and_slices, data, name=None):
  """Append slices to `filename`.

  Must be paired with `merge_appended_slices`.

  This op is identical to `tf.raw_ops.SaveSlices`, except that it appends the
  resulting checkpoint to `filename` rather than erasing the contents of
  `filename`.

  Note: the resulting file at `filename` will not be in checkpoint format until
  `merge_appended_slices` has been called.

  Args:
    filename: A `Tensor` of type `string`. Must have a single element. The name
      of the file to which the tensors should be appended.
    tensor_names: A `Tensor` of type `string`. Shape `[N]`. The names of the
      tensors to be saved.
    shapes_and_slices: A `Tensor` of type `string`. Shape `[N]`. The shapes and
      slice specifications to use when saving the tensors, in the same format
      as for `tf.raw_ops.SaveSlices`. An empty string saves the corresponding
      tensor whole.
    data: A list of `Tensor` objects. `N` tensors to save.
    name: A name for the operation (optional).

  Returns:
    The created `Operation`.
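
  Example (a minimal sketch; `path` is a hypothetical filename and the tensor
  names are illustrative):

    # One call appends two tensors. `a` is saved whole (empty spec), while `w`
    # holds rows 0-1 and all columns of a larger tensor of shape [4, 5].
    append_op = append_slices(
        path,
        tensor_names=['a', 'w'],
        shapes_and_slices=['', '4 5 0,2:-'],
        data=[tf.constant([1.0, 2.0]), tf.zeros([2, 5])])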
  """
  return gen_append_slices_py.append_slices(
      filename, tensor_names, shapes_and_slices, data, name=name)


def merge_appended_slices(filename, name=None):
  """Merges the file written by `append_slices` into a single checkpoint.

  The immediate file output of `append_slices` is not in checkpoint format. It
  must be converted to a checkpoint by calling `merge_appended_slices`.

  Note: Users must use `tf.control_dependencies` or another mechanism to
  ensure that all `append_slices` calls have executed before
  `merge_appended_slices` executes.

  Args:
    filename: The name of a file appended to by calls to `append_slices`.
    name: A name for the operation (optional).

  Returns:
    The created `Operation`.
  """
  return gen_append_slices_py.merge_appended_slices(filename, name=name)