xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/graph_helpers.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Utilities for manipulating TensorFlow graph logic."""
15
16from typing import Optional, Union
17
18import tensorflow as tf
19import tensorflow_federated as tff
20
21from fcp.artifact_building import data_spec
22from fcp.artifact_building import tensor_utils
23from fcp.artifact_building import type_checks
24from fcp.tensorflow import external_dataset
25from tensorflow_federated.proto.v0 import computation_pb2
26
27TfValue = Union[tf.Variable, tf.Tensor]
28DatasetTensor = tf.Tensor
29Argument = Union[TfValue, list[TfValue], DatasetTensor]
30Args = Optional[Union[Argument, tuple[Argument, ...]]]
31
32Result = Argument
33MaybeSplitOutputs = Union[Result, tuple[Result, ...]]
34
35
36EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX = 'example_selector'
37
38
def generate_example_selector_placeholders(
    type_spec: tff.Type,
    name_prefix: str,
):
  """Builds a `tf.compat.v1.placeholder` for each sequence leaf of a type spec.

  The placeholders are returned in depth-first order, matching the element
  ordering produced by `tff.structure.to_elements()`.

  Each placeholder is named by joining `name_prefix` with the sequence of
  struct indexes traversed to reach the corresponding leaf of the type.

  Args:
    type_spec: The type to derive placeholders from. Expected to be either a
      `tff.SequenceType`, or a `tff.StructType` whose tree bottoms out in
      `tff.SequenceType` leaves; this mirrors the TFF type signature of the
      input client data.
    name_prefix: Prefix used when naming each generated placeholder.

  Returns:
    A list of scalar `tf.string` placeholders, one per sequence leaf.
  """
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if type_spec.is_sequence():
    # A leaf corresponds to one stream of serialized `tf.Example`s; a single
    # `ExampleSelector` placeholder determines which stream is read from the
    # on-device data store, hence exactly one placeholder per leaf.
    return [tf.compat.v1.placeholder(tf.string, shape=[], name=name_prefix)]
  type_spec.check_struct()
  # Recurse into each struct element, extending the name with its index so
  # placeholder names uniquely encode the path to their leaf.
  return [
      placeholder
      for index, (_, element_type) in enumerate(
          tff.structure.to_elements(type_spec)
      )
      for placeholder in generate_example_selector_placeholders(
          element_type, f'{name_prefix}_{index}'
      )
  ]
86
87
def embed_data_logic(
    client_data_type: tff.Type,
    dataspec: Optional[data_spec.NestedDataSpec] = None,
) -> tuple[tf.Tensor, list[MaybeSplitOutputs], list[tf.Tensor]]:
  """Embeds the data logic into the current TensorFlow graph.

  Adds dataset ops to the current graph via the custom `ExternalDataset` op,
  which consumes a placeholder token. Returns that token placeholder, the
  dataset output values, and any example selector placeholders created.

  Args:
    client_data_type: The TFF type signature of the input client data.
    dataspec: If provided, either a `data_spec.DataSpec` or a nested structure
      of these matching the structure of the first element of the input to the
      client work part of the computation.

  Returns:
    A `tuple` containing, in order:
      token_placeholder: A dataset token placeholder tensor.
      data_values: A list of dataset output values.
      example_selector_placeholders: Placeholders used to pass example
        selector information into the client graph; empty iff `dataspec` is
        supplied.

  Raises:
    ValueError: If the number of dataset outputs from one data source is not 1.
    ValueError: If the graph already contains a node whose name collides with
      the example selector placeholder names.
  """
  # Token placeholder consumed by the custom ExternalDataset op.
  token_placeholder = tf.compat.v1.placeholder(
      tf.string, shape=[], name='data_token'
  )

  example_selector_placeholders = []
  if dataspec is not None:
    data_sources = make_data_sources_with_dataspec(client_data_type, dataspec)
  else:
    example_selector_placeholders = generate_example_selector_placeholders(
        client_data_type, EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX
    )
    if example_selector_placeholders:
      # If the first placeholder lacks the expected prefix, some other node in
      # the graph (likely from the input tff.Computation) already claimed the
      # special name. Reject that so the generated artifact can reliably
      # locate example selector placeholders by name.
      first_name = example_selector_placeholders[0].name
      has_expected_prefix = first_name.startswith(
          f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}:'
      ) or first_name.startswith(f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}_0')
      if not has_expected_prefix:
        raise ValueError(
            'Graph already contains a placeholder with name '
            f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}. Please '
            'avoid the use of this special name.'
        )
    data_sources = make_data_sources_without_dataspec(client_data_type)
    assert len(example_selector_placeholders) == len(data_sources)

  # Import each data source computation into the current graph, wiring in the
  # token (and, when present, the matching example selector placeholder).
  data_values = []
  for index, source_comp in enumerate(data_sources):
    import_args = [token_placeholder]
    if example_selector_placeholders:
      import_args.append(example_selector_placeholders[index])
    dataset_outputs = import_tensorflow(
        f'data_{index}', source_comp, import_args
    )  # pytype: disable=wrong-arg-types
    if len(dataset_outputs) != 1:
      raise ValueError(
          'Expected one dataset output from a data source, found {}.'.format(
              str(len(dataset_outputs))
          )
      )
    data_values.extend(dataset_outputs)

  return token_placeholder, data_values, example_selector_placeholders
170
171
def import_tensorflow(
    name: str,
    comp: tff.framework.ConcreteComputation,
    args: Args = None,
    split_outputs: bool = False,
    session_token_tensor: Optional[tf.Tensor] = None,
) -> MaybeSplitOutputs:
  """Imports a tensorflow computation into the current graph.

  Args:
    name: The string name to use as the graph import prefix.
    comp: An instance of `tff.framework.ConcreteComputation` with just the
      `tensorflow` section.
    args: Either a single argument, a tuple of arguments, or None. An argument
      must be either: - a Python `list` containing either tensors or variables,
      or - a single variant tensor representing a dataset input.
    split_outputs: Whether to unpack the result tuple into a Python tuple. If
      `True`, `import_tensorflow` will return a tuple with multiple result
      objects, corresponding to the return elements in the type signature of
      `comp`. Notice that the return type signature of `comp` must be a tuple in
      this case. If `False`, `import_tensorflow` will return the entire result
      in a flattened form as a single Python result object. Each Python result
      object, similar to the arguments in `args`, will be either a Python `list`
      of variant tensors or a singleton Python list containing only the dataset
      variant tensor.
    session_token_tensor: A tensor in the current graph containing the "session
      token" of the TensorFlow being imported. This is useful for passing a
      session-global identifier into the graph for use with ops like
      `ServeSlices` and `ExternalDataset` that take in a token which references
      session-global state.

  Returns:
    One of:
      - A single result (Python `list` of variable value or variant tensors) if
        `split_outputs` is `False`.
      - A Python `tuple` of such results, if `split_outputs` is `True`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  type_checks.check_type(name, str, name='name')
  type_checks.check_type(comp, tff.framework.ConcreteComputation, name='comp')
  type_checks.check_type(split_outputs, bool, name='split_outputs')

  comp_proto = tff.framework.ConcreteComputation.get_proto(comp)
  type_checks.check_type(
      comp_proto, computation_pb2.Computation, name='comp_proto'
  )

  # Only computations whose `computation` oneof is `tensorflow` can be imported
  # as a GraphDef.
  which_comp = comp_proto.WhichOneof('computation')
  if which_comp != 'tensorflow':
    raise TypeError(
        'Expected a TensorFlow computation, found {}.'.format(which_comp)
    )
  # Build the input map from tensor names in the computation's parameter
  # binding to the tensors supplied in `args`. A tuple of args maps onto the
  # elements of a struct binding, element by element.
  if args is None:
    input_map = None
  elif isinstance(args, tuple):
    which_binding = comp_proto.tensorflow.parameter.WhichOneof('binding')
    if which_binding != 'struct':
      raise TypeError(
          'Expected a struct binding with a struct of args, found {}.'.format(
              which_binding
          )
      )
    input_map = {}
    for index, arg in enumerate(args):
      input_map.update(
          create_tensor_map(
              comp_proto.tensorflow.parameter.struct.element[index], arg
          )
      )
  else:
    input_map = create_tensor_map(comp_proto.tensorflow.parameter, args)
  if input_map is not None:
    # Add remappings for all potential control dependencies in the graph as
    # well. Since `tf.graph_util.import_graph_def` input map works on the tensor
    # (not graph node) level, we must handle this case also.
    def control_dep_name(name: str) -> str:
      # Convert a tensor name like "node:0" into a control-dependency name
      # "^node"; names already in control-dep form are returned unchanged.
      if name.startswith('^'):
        return name
      node_name = name.split(':', maxsplit=1)[0]
      return f'^{node_name}'

    input_map.update(
        {
            control_dep_name(k): control_dep_name(v.name)
            for k, v in input_map.items()
            if not k.startswith('^')
        }
    )
  input_map = {} if input_map is None else input_map
  # Wire the caller-provided session token into the imported graph when the
  # computation declares a session token tensor.
  if (
      session_token_tensor is not None
      and comp_proto.tensorflow.session_token_tensor_name
  ):
    input_map[comp_proto.tensorflow.session_token_tensor_name] = (
        session_token_tensor
    )
  # Determine which tensors to return from the import. With `split_outputs`,
  # record per-struct-element subset sizes so the flat import result can be
  # re-partitioned afterwards.
  if split_outputs:
    return_elements = []
    subset_sizes = []
    which_binding = comp_proto.tensorflow.result.WhichOneof('binding')
    if which_binding != 'struct':
      raise TypeError(
          'If `split_outputs` is `True`, the result of the computation we are '
          'importing must be a `struct`; found {}.'.format(which_binding)
      )
    for binding in comp_proto.tensorflow.result.struct.element:
      tensor_names = _list_tensor_names_in_binding(binding)
      return_elements.extend(tensor_names)
      subset_sizes.append(len(tensor_names))
  else:
    return_elements = _list_tensor_names_in_binding(
        comp_proto.tensorflow.result
    )
    subset_sizes = [len(return_elements)]

  graph_def = tensor_utils.import_graph_def_from_any(
      comp_proto.tensorflow.graph_def
  )

  # We will be importing multiple GraphDefs into the server or client graphs.
  # These individual graphs may have identical `shared_name` attributes on
  # variable ops, which causes the runtime to reference the same resource, which
  # is highly undesired. We must uniquify the names before importing.
  def uniquify_shared_names(
      graph_def: tf.compat.v1.GraphDef, suffix: bytes
  ) -> tf.compat.v1.GraphDef:
    # Mutates `graph_def` in place, appending `_<suffix>` to every node's
    # `shared_name` attribute (when present).
    for x in graph_def.node:
      shared_name = x.attr.get('shared_name')
      if shared_name is not None:
        if not shared_name.s:
          # Encountered an empty string shared name, avoid creating a shared
          # name that starts with an underscore (not allowed by TF).
          shared_name.s = b'None'
        shared_name.s += b'_' + suffix
    return graph_def

  uniquified_graph_def = uniquify_shared_names(
      graph_def, suffix=name.encode('utf-8')
  )
  # Ensure the computation's initialize op runs before every other node by
  # attaching it as a control dependency throughout the graph.
  if comp_proto.tensorflow.initialize_op:
    uniquified_graph_def = add_control_deps_for_init_op(
        uniquified_graph_def, comp_proto.tensorflow.initialize_op
    )
  import_result = tf.graph_util.import_graph_def(
      uniquified_graph_def,
      input_map=input_map,
      return_elements=return_elements,
      name=name,
  )

  # Re-partition the flat list of imported tensors back into per-element
  # subsets when `split_outputs` was requested.
  if split_outputs:
    subsets = []
    offset = 0
    for subset_size in subset_sizes:
      next_offset = offset + subset_size
      subsets.append(import_result[offset:next_offset])
      offset = next_offset
    results = tuple(subsets)
  else:
    results = import_result[: subset_sizes[0]]
  return results
335
336
def _get_deps_for_graph_node(
    graph_def: tf.compat.v1.GraphDef, node_name: str
) -> set[str]:
  """Returns the set of node names that a node named `node_name` depends on.

  Note that this function does not work for nodes in the function library.

  Args:
    graph_def: The input graph, an instance of `tf.compat.v1.GraphDef`.
    node_name: The node name, a string.

  Returns:
    An instance of `set()` containing string names of the nodes `node_name`
    depends on in graph_def.

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(graph_def, tf.compat.v1.GraphDef, name='graph_def')
  type_checks.check_type(node_name, str, name='node_name')
  # Map each node to the (bare) names of its direct inputs.
  inputs_by_node = {
      node.name: {tensor_utils.bare_name(inp) for inp in node.input}
      for node in graph_def.node
  }
  # Breadth-first traversal of the transitive input closure of `node_name`.
  visited: set[str] = set()
  frontier = {node_name}
  while frontier:
    visited.update(frontier)
    next_frontier: set[str] = set()
    for current in frontier:
      next_frontier |= inputs_by_node[current]
    frontier = next_frontier - visited
  # The node is not considered its own dependency.
  return visited - {node_name}
369
370
def add_control_deps_for_init_op(
    graph_def: tf.compat.v1.GraphDef, init_op: str
) -> tf.compat.v1.GraphDef:
  """Adds control deps on `init_op` to nodes in GraphDef.

  Note that control deps are not added to any of the ancestors of `init_op`
  (which would result in a control dep cycle) and control deps are not added to
  any nodes in the function library of a GraphDef.

  Args:
    graph_def: The input graph, an instance of `tf.compat.v1.GraphDef`.
    init_op: The init op name, a string.

  Returns:
    The updated graph, an instance of `tf.compat.v1.GraphDef`.

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(graph_def, tf.compat.v1.GraphDef, name='graph_def')
  type_checks.check_type(init_op, str, name='init_op')
  bare_init = tensor_utils.bare_name(init_op)
  control_dep = f'^{bare_init}'
  # The init op and all of its ancestors must be excluded, otherwise the added
  # control edge would form a cycle.
  excluded = _get_deps_for_graph_node(graph_def, bare_init) | {bare_init}
  updated = tf.compat.v1.GraphDef()
  updated.CopyFrom(graph_def)
  for node in updated.node:
    if node.name in excluded:
      continue
    if control_dep not in node.input:
      node.input.append(control_dep)
  return updated
405
406
def create_tensor_map(
    binding: computation_pb2.TensorFlow.Binding,
    arg: list[Union[tf.Tensor, tf.Variable]],
) -> dict[str, tf.Tensor]:
  """Creates a `dict` mapping tensor names in the binding to tensors in `arg`.

  Args:
    binding: An instance of `computation_pb2.TensorFlow.Binding`.
    arg: Either a singleton Python `list` with variant tensor in case of a
      sequence binding, or a Python `list` of tensors or resource variables
      otherwise for a tuple binding.

  Returns:
    An instance of Python `dict` with the map as specified above.

  Raises:
    TypeError: If the argument types are incorrect.
    ValueError: If the arguments are malformed (e.g., multiple variant tensors).
  """
  type_checks.check_type(
      binding, computation_pb2.TensorFlow.Binding, name='binding'
  )
  type_checks.check_type(arg, list, name='arg')
  tensor_names = _list_tensor_names_in_binding(binding)
  if binding.WhichOneof('binding') == 'sequence':
    # A sequence binding is represented by exactly one variant tensor.
    if len(tensor_names) != 1 or len(arg) != 1:
      raise ValueError('Multiple variant tensors found.')
    variant = arg[0]
    if not tf.is_tensor(variant):
      raise TypeError('Expected a tensor, found {!r}.'.format(type(variant)))
    if variant.dtype != tf.variant:
      raise TypeError('Expected `tf.variant`, found {!r}.'.format(variant.dtype))
    return {tensor_names[0]: variant}
  # Non-sequence case: pair names with values positionally, reading resource
  # variables into plain tensors where applicable.
  mapping = {}
  for tensor_name, value in zip(tensor_names, arg):
    mapping[tensor_name] = (
        value.read_value() if hasattr(value, 'read_value') else value
    )
  return mapping
447
448
def _validate_data_comp(data_comp: tff.Computation, type_spec: tff.Type):
  """Checks that `data_comp` is a function whose result is assignable to `type_spec`."""
  type_checks.check_type(data_comp.type_signature, tff.FunctionType)
  result_type = data_comp.type_signature.result
  if type_spec.is_assignable_from(result_type):
    return
  mismatch_details = tff.types.type_mismatch_error_message(
      type_spec,
      result_type,
      tff.types.TypeRelation.ASSIGNABLE,
  )
  raise TypeError(
      'The data source constructed with the supplied dataspec returns data '
      'which does not match type of request. Details of the mismatch:\n'
      + mismatch_details
  )
462
463
def make_data_sources_with_dataspec(
    type_spec: tff.Type, ds: data_spec.NestedDataSpec
) -> list[tff.Computation]:
  """Creates a list of computations that feed data into the graph using specified example selectors.

  The computations use the custom ExternalDataset op to feed in example data.
  The computations will expect one input:
    -- A token specifying where the data store is on the device.
  Example selectors that describes what data to take from the on-device data
  store will be hard-coded into the computations.

  Args:
    type_spec: The TFF type signature of the output, which must be either a
      sequence, or a named tuple of sequences.
    ds: Either a single `data_spec.DataSpec`, or a nested structure of these,
      made up of Python containers, that exactly matches the structure of the
      `type_spec`.

  Returns:
    A list of `tff.Computation`s, each of which accepts a single `string`-typed
    tensor as input (the token for the ExternalDataset op) and returns a
    sequence as output (with the result that matches the corresponding part of
    `type_spec`). The computations appear on the list in a depth-first order
    (matching exactly the convention used in the
    `_list_tensor_names_in_binding()` method below).

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  assert ds
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if type_spec.is_sequence():
    type_checks.check_type(ds, data_spec.DataSpec)
    assert isinstance(ds, data_spec.DataSpec)
    assert ds.example_selector_proto is not None
    # Serialize the example selector once so the traced computation below
    # captures it as a constant.
    sel_bytes = ds.example_selector_proto.SerializeToString()

    @tff.tf_computation(tf.string)
    def data_comp(token):
      """The data source computation.

      Args:
        token: The token placeholder tensor (`tf.string`).

      Returns:
        An instance of `tf.data.Dataset`.
      """
      # Apply the dataspec's preprocessing function (when present) to the
      # raw ExternalDataset; the selector is hard-coded via `sel_bytes`.
      if ds.preprocessing_fn is not None:
        processed_ds = ds.preprocessing_fn(
            external_dataset.ExternalDataset(token=token, selector=sel_bytes)
        )
      else:
        processed_ds = external_dataset.ExternalDataset(
            token=token, selector=sel_bytes
        )

      # Duck-typed dataset check: accepts any tf.data dataset variant whose
      # class name contains 'Dataset'.
      if 'Dataset' not in type(processed_ds).__name__:
        raise TypeError(
            'The preprocessing function returned an unrecognized non-dataset '
            'type {!r}.'.format(type(processed_ds))
        )
      return processed_ds

    _validate_data_comp(data_comp, type_spec)
    return [data_comp]
  else:
    type_spec.check_struct()
    if isinstance(ds, data_spec.DataSpec):
      raise TypeError(
          'Expected nested structure of `DataSpec`s conforming to '
          f'the structure of the type {type_spec}. '
          'Found single `DataSpec` instead.'
      )
    ds = tff.structure.from_container(ds)
    assert isinstance(ds, tff.structure.Struct)
    type_spec_elements = tff.structure.to_elements(type_spec)
    data_spec_elements = tff.structure.to_elements(ds)
    # The nested dataspec must mirror the type structure element-for-element,
    # including element names.
    type_spec_element_names = [str(k) for k, _ in type_spec_elements]
    data_spec_element_names = [str(k) for k, _ in data_spec_elements]
    if type_spec_element_names != data_spec_element_names:
      raise TypeError(
          'Type vs. data spec elements names mismatch: {} vs. {}.'.format(
              str(type_spec_element_names), str(data_spec_element_names)
          )
      )
    # Depth-first recursion preserves the ordering contract documented above.
    elements = []
    for element_index, (_, element_type) in enumerate(type_spec_elements):
      elements.extend(
          make_data_sources_with_dataspec(element_type, ds[element_index])
      )
    return elements
558
559
def make_data_sources_without_dataspec(type_spec) -> list[tff.Computation]:
  """Creates a list of computations that feed data into the graph.

  The computations use the custom ExternalDataset op to feed in example data.
  The computations will expect two inputs:
    -- A token specifying where the data store is on the device.
    -- An example selector that describes what data to take from the on-device
      data store.

  Args:
    type_spec: The TFF type signature of the output, which must be either a
      sequence, or a named tuple of sequences.

  Returns:
    A list of `tff.Computation`s, each of which accepts a single `string`-typed
    tensor as input (the token for the ExternalDataset op) and returns a
    sequence as output (with the result that matches the corresponding part of
    `type_spec`). The computations appear on the list in a depth-first order
    (matching exactly the convention used in the
    `_list_tensor_names_in_binding()` method below).

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if type_spec.is_sequence():

    @tff.tf_computation(tf.string, tf.string)
    def data_comp(token, example_selector):
      """The data source computation.

      Args:
        token: The token placeholder tensor (`tf.string`).
        example_selector: The example selector placeholder tensor (`tf.string`).

      Returns:
        An instance of `tf.data.Dataset`.
      """
      # Unlike the dataspec variant, the selector arrives as a graph input
      # rather than being hard-coded into the computation.
      dataset = external_dataset.ExternalDataset(
          token=token, selector=example_selector
      )

      # Duck-typed dataset check: accepts any class whose name contains
      # 'Dataset'.
      if 'Dataset' not in type(dataset).__name__:
        raise TypeError(
            'The preprocessing function returned an unrecognized non-dataset '
            'type {!r}.'.format(type(dataset))
        )
      return dataset

    _validate_data_comp(data_comp, type_spec)
    return [data_comp]
  # type_spec is a struct: recurse depth-first into each element and flatten.
  type_spec.check_struct()
  return [
      comp
      for _, element_type in tff.structure.to_elements(type_spec)
      for comp in make_data_sources_without_dataspec(element_type)
  ]
621
622
def _list_tensor_names_in_binding(
    binding: computation_pb2.TensorFlow.Binding,
) -> list[str]:
  """Returns a flat Python list of tensor names that appear in the `binding`.

  Args:
    binding: An instance of `computation_pb2.TensorFlow.Binding` in which any
      sequence bindings must contain variant tensors.

  Returns:
    A list of `str` instances with tensor names that appear in `binding` in the
    order in which they appear in the depth-first traversal of the potentially
    nested binding structure.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  type_checks.check_type(binding, computation_pb2.TensorFlow.Binding)
  kind = binding.WhichOneof('binding')
  if kind == 'tensor':
    return [str(binding.tensor.tensor_name)]
  if kind == 'struct':
    # Depth-first flatten of nested struct bindings.
    names = []
    for child in binding.struct.element:
      names += _list_tensor_names_in_binding(child)
    return names
  if kind == 'sequence':
    sequence_kind = binding.sequence.WhichOneof('binding')
    if sequence_kind != 'variant_tensor_name':
      raise TypeError(
          'Expected a variant tensor in sequence binding, found {}.'.format(
              sequence_kind
          )
      )
    return [binding.sequence.variant_tensor_name]
  raise TypeError('Unexpected type of binding {}.'.format(kind))
660