xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/external_dataset.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2019 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker
"""Provides the 'ExternalDataset' implementation of tf.data.Dataset.
16*14675a02SAndroid Build Coastguard Worker
17*14675a02SAndroid Build Coastguard WorkerThis wraps the generated op (in external_dataset_py_wrapper).
18*14675a02SAndroid Build Coastguard Worker"""
19*14675a02SAndroid Build Coastguard Worker
20*14675a02SAndroid Build Coastguard Workerfrom __future__ import absolute_import
21*14675a02SAndroid Build Coastguard Workerfrom __future__ import division
22*14675a02SAndroid Build Coastguard Workerfrom __future__ import print_function
23*14675a02SAndroid Build Coastguard Worker
24*14675a02SAndroid Build Coastguard Worker
25*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
26*14675a02SAndroid Build Coastguard Worker
27*14675a02SAndroid Build Coastguard Workerfrom fcp.tensorflow import gen_external_dataset_py
28*14675a02SAndroid Build Coastguard Worker
# Load the compiled custom-op shared library `_external_dataset_op.so`, whose
# path is resolved relative to this module via get_path_to_datafile.
# The returned handle is not referenced anywhere in the code shown here.
# NOTE(review): presumably the load is kept for its side effect of registering
# the ExternalDataset op/kernel with TensorFlow -- confirm before removing.
_external_dataset_so = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile(
        "./_external_dataset_op.so"))
32*14675a02SAndroid Build Coastguard Worker
33*14675a02SAndroid Build Coastguard Worker
class ExternalDataset(tf.data.Dataset):
  """A `tf.data.Dataset` whose contents are supplied from outside the graph.

  What the dataset yields is defined by whoever is running the graph. Two
  string inputs pick it out:

  * `token` tells the runtime which external dataset to use; the graph must
    be fed this value.
  * `selector` is an opaque string handed to that external implementation,
    which alone decides how to interpret it.

  Every element produced is a scalar `tf.string` tensor.
  """

  def __init__(self, token, selector):
    """Creates the underlying dataset op from `token` and `selector`.

    Args:
      token: A value convertible to a `tf.string` tensor naming the external
        dataset to read from.
      selector: A value convertible to a `tf.string` tensor; forwarded
        uninterpreted to the external implementation.
    """
    # Keyword arguments are evaluated left to right, so `token` is converted
    # before `selector`, just as the tensors would be built one after another.
    variant = gen_external_dataset_py.ExternalDataset(
        token=tf.convert_to_tensor(token, dtype=tf.string, name="token"),
        selector=tf.convert_to_tensor(
            selector, dtype=tf.string, name="selector"),
    )
    super(ExternalDataset, self).__init__(variant)

  @property
  def element_spec(self):
    """Each element is a scalar (shape `[]`) `tf.string`."""
    return tf.TensorSpec(shape=[], dtype=tf.string)

  def _inputs(self):
    # No upstream datasets: elements come straight from the external op,
    # not from another tf.data.Dataset.
    return []
54*14675a02SAndroid Build Coastguard Worker    return []
55