1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for cross_device_ops.""" 16 17import copy 18import threading 19from typing import Callable, List, Optional, Union 20 21from tensorflow.python.distribute import collective_util 22from tensorflow.python.distribute import values as value_lib 23from tensorflow.python.eager import backprop 24from tensorflow.python.eager import context 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import indexed_slices 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_spec 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import collective_ops 31from tensorflow.python.ops import control_flow_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops import nccl_ops 34from tensorflow.python.ops import resource_variable_ops 35from tensorflow.python.platform import tf_logging as logging 36from tensorflow.python.types import core 37 38INSTANCE_KEY_START_NUMBER = 100 39 40 41def aggregate_gradients_using_nccl(replica_grads): 42 """Aggregate gradients using nccl allreduce.""" 43 agg_all_g_and_v = [] 44 for single_g_and_v in zip(*replica_grads): 45 single_grads = [g for g, _ in single_g_and_v] 46 agg_grads = nccl_ops.all_sum(single_grads) 47 agg_all_g_and_v.append( 48 [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)]) 49 50 agg_all_g_and_v = list(zip(*agg_all_g_and_v)) 51 52 return agg_all_g_and_v 53 54 55def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads): 56 """Aggregate gradients using hierarchical copies. 57 58 Args: 59 avail_devices: available GPU devices. 60 replica_grads: List of lists of (gradient, variable) tuples. The outer list 61 is over replicas. The inner list is over individual gradients. 62 63 Returns: 64 The list of (aggregated_gradient, variable), where the gradient has been 65 summed across all replicas and the variable is chosen from the first 66 replica. 67 """ 68 # This only works for DGX-1 type of machine topology 69 # Device peer to peer matrix 70 # DMA: 0 1 2 3 4 5 6 7 71 # 0: Y Y Y Y Y N N N 72 # 1: Y Y Y Y N Y N N 73 # 2: Y Y Y Y N N Y N 74 # 3: Y Y Y Y N N N Y 75 # 4: Y N N N Y Y Y Y 76 # 5: N Y N N Y Y Y Y 77 # 6: N N Y N Y Y Y Y 78 # 7: N N N Y Y Y Y Y 79 agg_grads = [] 80 num_devices = len(avail_devices) 81 # In the special case of DGX-1 machine topology, the two groups have equal 82 # size. 83 group_size = num_devices // 2 84 for i, single_grads in enumerate(zip(*replica_grads)): 85 group_0_main_device = i % num_devices 86 group_1_main_device = (group_0_main_device + group_size) % num_devices 87 if group_0_main_device < group_size: 88 group_0_begin = 0 89 group_1_begin = group_size 90 else: 91 group_0_begin = group_size 92 group_1_begin = 0 93 94 # Aggregate the first group. 95 group_0_device_grads = single_grads[group_0_begin: 96 group_0_begin + group_size] 97 with ops.device(avail_devices[group_0_main_device]): 98 group_0_agg_grads, _ = aggregate_single_gradient_using_copy( 99 group_0_device_grads, False, False) 100 101 # Aggregate the second group. 102 group_1_device_grads = single_grads[group_1_begin: 103 group_1_begin + group_size] 104 with ops.device(avail_devices[group_1_main_device]): 105 group_1_agg_grads, _ = aggregate_single_gradient_using_copy( 106 group_1_device_grads, False, False) 107 108 # Aggregate between the groups. 109 with ops.device(avail_devices[group_0_main_device]): 110 (agg_total_grads, _), _ = aggregate_single_gradient_using_copy( 111 [group_0_agg_grads, group_1_agg_grads], False, False) 112 113 # Broadcast the result back into the root of each group. 114 with ops.device(avail_devices[group_0_main_device]): 115 group_0_agg_grads_bcast = array_ops.identity(agg_total_grads) 116 with ops.device(avail_devices[group_1_main_device]): 117 group_1_agg_grads_bcast = array_ops.identity(agg_total_grads) 118 119 agg_grads_bcast = [] 120 for j in range(len(single_grads)): 121 with ops.device(avail_devices[j]): 122 # Broadcast the result back to each member in the group from the root. 123 if (group_0_main_device < group_size) == (j < group_size): 124 src_device_grad = group_0_agg_grads_bcast 125 else: 126 src_device_grad = group_1_agg_grads_bcast 127 agg_grads_bcast.append(array_ops.identity(src_device_grad)) 128 129 agg_grads.append( 130 [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)]) 131 132 agg_grads = list(zip(*agg_grads)) 133 134 return agg_grads 135 136 137def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, 138 check_inf_nan): 139 """Calculate the average gradient for a shared variable across all replicas. 140 141 Note that this function provides a synchronization point across all replicas. 142 143 Args: 144 grad_and_vars: A list or tuple of (gradient, variable) tuples. Each 145 (gradient, variable) pair within the outer list represents the gradient 146 of the variable calculated for a single replica, and the number of pairs 147 equals the number of replicas. 148 use_mean: if True, mean is taken, else sum of gradients is taken. 149 check_inf_nan: check grads for nans and infs. 150 151 Returns: 152 The tuple ([(average_gradient, variable),], has_nan_or_inf) where the 153 gradient has been averaged across all replicas. The variable is chosen 154 from the first replica. The has_nan_or_inf indicates the grads has nan or 155 inf. 156 """ 157 grads = [g for g, _ in grad_and_vars] 158 grad = math_ops.add_n(grads) 159 160 if use_mean and len(grads) > 1: 161 grad = array_ops.multiply(grad, 1.0 / len(grads)) 162 163 v = grad_and_vars[0][1] 164 if check_inf_nan: 165 has_nan_or_inf = array_ops.logical_not( 166 array_ops.reduce_all(array_ops.is_finite(grads))) 167 return (grad, v), has_nan_or_inf 168 else: 169 return (grad, v), None 170 171 172# TODO(yuefengz): use random key starts to avoid reusing keys? 173class CollectiveKeys(object): 174 """Class that manages collective keys. 175 176 We need to manage three different keys for collective: 177 178 *Group key*: an integer key to identify the set of cooperative devices. 179 Collective ops work under the same set of devices must using the same group 180 key. 181 182 *Instance key*: an integer key to identify the set of same counterpart of 183 tensors on different devices in a device group that need to be all-reduced. 184 185 This class is thread safe. 186 """ 187 188 def __init__(self, group_key_start=1): 189 """Initializes the object. 190 191 Args: 192 group_key_start: the starting integer of group key. 193 """ 194 self._group_key = group_key_start 195 self._instance_key_table = {} 196 self._lock = threading.Lock() 197 198 def get_group_key(self, devices): 199 """Returns a new group key. 200 201 The caller should store and reuse the same group key for the same set of 202 devices. Calling this method always returns a new group key. 203 204 Args: 205 devices: a list of canonical device strings in a collective group. 206 207 Returns: 208 a new group key. 209 """ 210 with self._lock: 211 new_key = self._group_key 212 self._group_key += 1 213 self._instance_key_table[new_key] = {} 214 for device in devices: 215 self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER 216 return new_key 217 218 def get_instance_key(self, group_key, device): 219 """Returns a new instance key for use in defining a collective op. 220 221 You should call this once per each collective op of a collective instance. 222 223 Args: 224 group_key: the group key returned by get_group_key(). You should not 225 assign the group key yourself. 226 device: a canonical device string. It should be the device this collective 227 op is on. 228 229 Returns: 230 a new instance key. 231 232 Raises: 233 ValueError: when the group key is invalid or the device is not in the 234 group. 235 """ 236 with self._lock: 237 group = self._instance_key_table.get(group_key, None) 238 if group is None: 239 raise ValueError(f'Group {group_key} is not found.') 240 if device not in group: 241 raise ValueError(f'Device {device} is not present in group {group_key}') 242 v = group[device] 243 group[device] += 1 244 return v 245 246 def __deepcopy__(self, memo): 247 # distribute_coordinator deep-copies the strategy object, so 248 # CollectiveKeys needs to support deep copy as well. 249 copied = CollectiveKeys() 250 copied._group_key = self._group_key 251 copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo) 252 return copied 253 254 255class CollectiveReplicaLauncher(object): 256 """Launch collectives on one replica.""" 257 258 _prefer_unique_instance_key = True 259 _prefer_ordering_token = True 260 261 def __init__(self, group_key: int, group_size: int, 262 collective_keys: CollectiveKeys, device: str, 263 options: collective_util.Options): 264 self._group_key = group_key 265 self._group_size = group_size 266 self._collective_keys = collective_keys 267 self._device = device 268 self._options = options 269 if self._use_ordering_token(): 270 with ops.init_scope(), ops.device(device): 271 self._ordering_token = resource_variable_ops.ResourceVariable(0.) 272 else: 273 self._ordering_token = None 274 275 def _control_input(self, control_input: Union[core.TensorLike, 276 ops.Operation]): 277 if control_input is not None and not self._use_ordering_token(): 278 return ops.control_dependencies([control_input]) 279 return ops.NullContextmanager() 280 281 def _use_unique_instance_key(self): 282 if not ops.executing_eagerly_outside_functions(): 283 return False 284 return CollectiveReplicaLauncher._prefer_unique_instance_key 285 286 def _use_ordering_token(self): 287 # We rely on auto control dep to insert control edges between NCCL calls, 288 # but for tf1 graph mode auto control dep is not used. 289 if not ops.executing_eagerly_outside_functions(): 290 return False 291 return CollectiveReplicaLauncher._prefer_ordering_token 292 293 def _next_instance_key(self): 294 """Returns the next instance key.""" 295 if self._use_unique_instance_key(): 296 # Assigning instance keys at function building time have issues since 297 # different workers may retrace the function at different times. With 298 # collective V2 we can use capture_call_time_value to use a placeholder as 299 # the instance key and feed it at function call time. In this way we also 300 # don't reuse instance keys, which allows for per-instance cancellation. 301 graph = ops.get_default_graph() 302 # Control flow ops don't work with capture_call_time_value, so we put the 303 # capture in the function graph of that control flow op. 304 while getattr(graph, 'is_control_flow_graph', False): 305 graph = graph.outer_graph 306 if not context.executing_eagerly() and graph.building_function: 307 with graph.as_default(): 308 # Capture self._next_instance_key so that when building a function 309 # that calls another tf.function, the instance key assignment is 310 # further delayed until we actually call the function in eager. Note 311 # that capture_call_time_value doesn't automatically propagate the 312 # deferred capture to the outer function. 313 return graph.capture_call_time_value( 314 self._next_instance_key, tensor_spec.TensorSpec([], dtypes.int32)) 315 else: 316 instance_key = self._collective_keys.get_instance_key( 317 self._group_key, self._device) 318 with ops.device('CPU:0'): 319 return ops.convert_to_tensor(instance_key, dtype=dtypes.int32) 320 else: 321 return self._collective_keys.get_instance_key(self._group_key, 322 self._device) 323 324 def _get_ordering_token(self): 325 if self._use_ordering_token(): 326 return self._ordering_token.handle 327 328 def can_order_nccl(self): 329 """Whether this launcher can order NCCL operations.""" 330 return self._use_ordering_token() 331 332 def all_reduce( 333 self, 334 input_tensor: core.TensorLike, 335 control_input: Optional[Union[core.TensorLike, ops.Operation]] = None, 336 options: Optional[collective_util.Options] = None) -> core.Tensor: 337 """All-reduce a dense tensor. 338 339 Args: 340 input_tensor: a dense tensor. It must have the same shape on all replicas. 341 control_input: if not None, add control edges between control_input and 342 the all-reduce. 343 options: an optional tf.distribute.experimental.CommunicationOptions. If 344 provided, it overrides the default options. 345 346 Returns: 347 The reduced tensor. 348 """ 349 instance_key = self._next_instance_key() 350 options = self._options.merge(options) 351 ordering_token = self._get_ordering_token() 352 with ops.device(self._device), \ 353 self._control_input(control_input): 354 return collective_ops.all_reduce_v2( 355 input_tensor, 356 self._group_size, 357 self._group_key, 358 instance_key, 359 communication_hint=options.implementation.value, 360 timeout=options.timeout_seconds, 361 ordering_token=ordering_token) 362 363 def _all_gather(self, input_tensor: core.TensorLike, 364 options: Optional[collective_util.Options]) -> core.Tensor: 365 """All-gather a dense tensor. 366 367 Args: 368 input_tensor: a dense tensor. It must have the same shape on all replicas. 369 options: an optional tf.distribute.experimental.CommunicationOptions. If 370 provided, it overrides the default options. 371 372 Returns: 373 The reduced tensor. 374 """ 375 instance_key = self._next_instance_key() 376 options = self._options.merge(options) 377 ordering_token = self._get_ordering_token() 378 with ops.device(self._device): 379 return collective_ops.all_gather_v2( 380 input_tensor, 381 self._group_size, 382 self._group_key, 383 instance_key, 384 communication_hint=options.implementation.value, 385 timeout=options.timeout_seconds, 386 ordering_token=ordering_token) 387 388 def batch_all_reduce( 389 self, 390 input_tensor_packs: List[List[core.TensorLike]], 391 options: Optional[collective_util.Options] = None) -> core.Tensor: 392 """Batch all-reduce dense tensors. 393 394 This takes a list of batches of tensors. Using multiple batches have the 395 benefit that it doesn't need to wait for all inputs to be ready to start the 396 all-reduce. 397 398 Args: 399 input_tensor_packs: a list of lists of dense tensors. 400 options: an optional tf.distribute.experimental.CommunicationOptions. If 401 provided, it overrides the default options. 402 403 Returns: 404 A flat list of reduced tensors. 405 """ 406 options = self._options.merge(options) 407 outputs = [] 408 for pack in input_tensor_packs: 409 if context.executing_eagerly(): 410 # We don't batch in eager as it sometimes makes the performance worse 411 # due the concat/split ops. 412 for input_tensor in pack: 413 outputs.append(self.all_reduce(input_tensor, None, options)) 414 else: 415 # TODO(b/169168846): inserts a parallel all_gather to verify packings 416 # are the same on each replica. 417 with ops.device(self._device): 418 flat_tensors = [array_ops.reshape(t, [-1]) for t in pack] 419 shapes = [array_ops.shape(t) for t in pack] 420 if (options.implementation 421 == collective_util.CommunicationImplementation.NCCL and outputs): 422 control_input = outputs[-1] 423 else: 424 control_input = None 425 reduced = self.all_reduce( 426 array_ops.concat(flat_tensors, axis=0), control_input, options) 427 num_elements = [math_ops.reduce_prod(s) for s in shapes] 428 flat_outputs = array_ops.split(reduced, num_elements, axis=0) 429 for shape, flat_output in zip(shapes, flat_outputs): 430 outputs.append(array_ops.reshape(flat_output, shape)) 431 432 return outputs 433 434 def all_gather( 435 self, 436 input_tensor: core.TensorLike, 437 axis: core.TensorLike, 438 options: Optional[collective_util.Options] = None) -> core.Tensor: 439 """All-gather a dense tensor. 440 441 This method must be called inside a tf.function. 442 443 Args: 444 input_tensor: a dense tensor. It must have the same rank on all replicas, 445 and dimensions other than `axis` need to be the same as well. 446 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the 447 range [0, rank(value)). 448 options: an optional tf.distribute.experimental.CommunicationOptions. If 449 provided, it overrides the default options. 450 451 Returns: 452 The gathered Tensor. 453 454 Raises: 455 RuntimeError: if called in eager mode. 456 """ 457 if context.executing_eagerly(): 458 raise RuntimeError('all_gather is not supported in eager mode.') 459 460 with ops.device(self._device), \ 461 ops.control_dependencies([array_ops.identity(input_tensor)]): 462 # 1. Transpose 463 # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3, 464 # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which 465 # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to 466 # place it back. 467 perm_pre = array_ops.concat( 468 ([axis], math_ops.range(axis), 469 math_ops.range(axis + 1, array_ops.rank(input_tensor))), 470 axis=0) 471 input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre) 472 # 2. Pad 473 gathered_shape = self._all_gather( 474 array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0), 475 options) 476 first_dims = gathered_shape[:, 0] 477 full_axis_dim = math_ops.reduce_max(first_dims) 478 padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim) 479 480 # 3. Gather 481 gather_padded_out_tensor = self._all_gather(padded_input_tensor, options) 482 # 4. Unpad 483 split_tensors = [] 484 for i in range(self._group_size): 485 start_pos = i * full_axis_dim 486 split_tensors.append(gather_padded_out_tensor[start_pos:start_pos + 487 first_dims[i]]) 488 out_tensor_t = array_ops.concat(split_tensors, 0) 489 490 # 5. Transpose back 491 perm_after = array_ops.concat( 492 (math_ops.range(1, axis + 1), [0], 493 math_ops.range(axis + 1, array_ops.rank(input_tensor_t))), 494 axis=0) 495 return array_ops.transpose(out_tensor_t, perm=perm_after) 496 497 def all_reduce_indexed_slices( 498 self, 499 input_slices: indexed_slices.IndexedSlices, 500 options: Optional[collective_util.Options] = None 501 ) -> indexed_slices.IndexedSlices: 502 """All-reduce an IndexedSlices. 503 504 This method must be called inside a tf.function. 505 506 Args: 507 input_slices: an IndexedSlices. 508 options: an optional tf.distribute.experimental.CommunicationOptions. If 509 provided, it overrides the default options. 510 511 Returns: 512 The reduced IndexedSlices. 513 514 Raises: 515 RuntimeError: if called in eager mode. 516 """ 517 if context.executing_eagerly(): 518 raise RuntimeError( 519 'all_reduce_indexed_slices is not supported in eager mode.') 520 521 # Current CollectiveAllGather implementations require input IndexedSlices to 522 # have consistent length across the board, we handle the reduction of 523 # IndexedSlices as follows: 524 # 1. Gather the lengths of IndexedSlices from all participants. 525 # 2. If they have consistent length, apply all_gather. 526 # 3. Otherwise pad IndexedSlices to be the same length across all 527 # participants and apply_gather. 528 options = self._options.merge(options) 529 with ops.device(self._device): 530 531 def all_gather_indexed_slices( 532 all_gather_fn: Callable[ 533 [core.TensorLike, Optional[collective_util.Options]], core.Tensor] 534 ) -> indexed_slices.IndexedSlices: 535 """Use all_gather_fn to aggregate `IndexedSlices`.""" 536 all_values = all_gather_fn(input_slices.values, options) 537 # Add control dependency to order the all-gather. 538 if (options.implementation == 539 collective_util.CommunicationImplementation.NCCL): 540 control = [all_values] 541 else: 542 control = [] 543 with ops.control_dependencies(control): 544 all_indices = all_gather_fn(input_slices.indices, options) 545 return indexed_slices.IndexedSlices( 546 values=all_values, 547 indices=all_indices, 548 dense_shape=input_slices.dense_shape) 549 550 length = array_ops.shape(input_slices.indices) 551 all_lengths = self._all_gather(length, options) 552 553 def all_gather_with_padding( 554 input_tensor: core.TensorLike, 555 options: Optional[collective_util.Options]) -> core.Tensor: 556 """all_gather tensors of different sizes using padding.""" 557 max_length = math_ops.reduce_max(all_lengths) 558 padded_tensor = _pad_util(input_tensor, max_length) 559 all_padded_tensors = self._all_gather(padded_tensor, options) 560 split_tensors = [] 561 for i in range(self._group_size): 562 start_pos = i * max_length 563 split_tensors.append(all_padded_tensors[start_pos:start_pos + 564 all_lengths[i]]) 565 return array_ops.concat(split_tensors, 0) 566 567 return control_flow_ops.cond( 568 math_ops.equal( 569 math_ops.reduce_max(all_lengths), 570 math_ops.reduce_min(all_lengths)), 571 lambda: all_gather_indexed_slices(self._all_gather), 572 lambda: all_gather_indexed_slices(all_gather_with_padding)) 573 574 575def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): 576 """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" 577 if any(isinstance(v, indexed_slices.IndexedSlices) for v in values): 578 return backprop.aggregate_indexed_slices_gradients(values) 579 else: 580 return accumulation_fn(values) 581 582 583def divide_by_n_tensors_or_indexed_slices(value, n): 584 if isinstance(value, indexed_slices.IndexedSlices): 585 value = backprop.flatten_nested_indexed_slices(value) 586 return indexed_slices.IndexedSlices(value.values / n, value.indices, 587 value.dense_shape) 588 else: 589 return value / n 590 591 592def copy_tensor_or_indexed_slices_to_device(value, device): 593 """Copies a tensor or IndexedSlices to a device.""" 594 with ops.device(device): 595 if isinstance(value, indexed_slices.IndexedSlices): 596 copied_values = array_ops.identity(value.values) 597 copied_indices = array_ops.identity(value.indices) 598 if value.dense_shape is not None: 599 copied_shape = array_ops.identity(value.dense_shape) 600 else: 601 copied_shape = None 602 result = indexed_slices.IndexedSlices(copied_values, copied_indices, 603 copied_shape) 604 else: 605 result = array_ops.identity(value) 606 return result 607 608 609def is_indexed_slices(value): 610 if isinstance(value, indexed_slices.IndexedSlices): 611 return True 612 if isinstance(value, value_lib.DistributedValues): 613 return all( 614 isinstance(v, indexed_slices.IndexedSlices) for v in value.values) 615 return False 616 617 618def split_by_sparsity(values): 619 """Split values into dense and sparse values. 620 621 Args: 622 values: a list of tensors or `PerReplica`s. 623 624 Returns: 625 Four lists: 626 a list of dense values, a list of their indices in `values` and 627 a list of sparse values, a list of their indices in `values`. 628 """ 629 dense_values = [] 630 dense_indices = [] 631 sparse_values = [] 632 sparse_indices = [] 633 for i, v in enumerate(values): 634 if is_indexed_slices(v): 635 sparse_values.append(v) 636 sparse_indices.append(i) 637 else: 638 dense_values.append(v) 639 dense_indices.append(i) 640 return dense_values, dense_indices, sparse_values, sparse_indices 641 642 643def stitch_values(values_and_indices_list): 644 """Stitch values together according to their indices. 645 646 Args: 647 values_and_indices_list: a list of tuples of values and indices indicating 648 the values and positions in the returned list. 649 650 Returns: 651 a stitched list of values. 652 """ 653 length = 0 654 for values_and_indices in values_and_indices_list: 655 length += len(values_and_indices[0]) 656 657 result = [None] * length 658 for values_and_indices in values_and_indices_list: 659 if values_and_indices and values_and_indices[0]: 660 for v, i in zip(*values_and_indices): 661 assert result[i] is None 662 result[i] = v 663 return result 664 665 666def group_by_size(input_tensors, bytes_per_pack): 667 """Groups `input_tensors` into chunks of `bytes_per_pack`. 668 669 The method preserves the original order of `input_tensors`. The grouping is 670 best effort, each pack could have more or less bytes than `bytes_per_pack`. 671 It only groups values with known shape. 672 673 Args: 674 input_tensors: a list of Tensor. 675 bytes_per_pack: an integer. 676 677 Returns: 678 A list of packs of Tensor. All values are grouped into one pack if 679 `bytes_per_pack` is zero or any of the value has unknown shape. 680 """ 681 682 if bytes_per_pack == 0: 683 return [input_tensors] 684 packs = [] 685 last_pack_size = 0 686 for value in input_tensors: 687 num_elements = value.shape.num_elements() 688 if num_elements is None: 689 # Can't pack values with unknown shape. 690 logging.warning( 691 'not packing values due to the unknown or inconsistent shape of %s', 692 value) 693 return [input_tensors] 694 size = num_elements * value.dtype.size 695 # Try to keep each pack as close to bytes_per_pack as possible, while each 696 # pack is at least bytes_per_pack large. I.E. we err on the side of having 697 # few but large packs. 698 if not packs or last_pack_size > bytes_per_pack: 699 packs.append([]) 700 last_pack_size = 0 701 packs[-1].append(value) 702 last_pack_size += size 703 return packs 704 705 706def _pad_util(input_tensor, full_axis_dim): 707 """Pad the `input_tensor`'s first dimension to be `full_axis_dim`.""" 708 missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0] 709 tensor_rank = array_ops.rank(input_tensor) 710 paddings_axis = [[0, missing_axis_dim]] 711 paddings = array_ops.concat([ 712 paddings_axis, 713 array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32) 714 ], 715 axis=0) 716 padded_input_tensor = array_ops.pad(input_tensor, paddings) 717 return padded_input_tensor 718