# Operation Semantics

The following describes the semantics of operations defined in the
[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
interface. Typically, these operations map one-to-one to operations defined in
the RPC interface in
[`xla_data.proto`](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto).

A note on nomenclature: the generalized data type XLA deals with is an
N-dimensional array holding elements of some uniform type (such as 32-bit
float). Throughout the documentation, *array* is used to denote an
arbitrary-dimensional array. For convenience, special cases have more specific
and familiar names; for example a *vector* is a 1-dimensional array and a
*matrix* is a 2-dimensional array.

## AfterAll

See also
[`XlaBuilder::AfterAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

AfterAll takes a variadic number of tokens and produces a single token. Tokens
are primitive types which can be threaded between side-effecting operations to
enforce ordering. `AfterAll` can be used as a join of tokens for ordering an
operation after a set of operations.

<b> `AfterAll(operands)` </b>

Arguments  | Type    | Semantics
---------- | ------- | -------------------------
`operands` | `XlaOp` | variadic number of tokens

## AllGather

See also
[`XlaBuilder::AllGather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Performs concatenation across replicas.

<b> `AllGather(operand, all_gather_dim, shard_count, replica_groups, channel_id)` </b>

| Arguments        | Type                         | Semantics                                            |
| ---------------- | ---------------------------- | ---------------------------------------------------- |
| `operand`        | `XlaOp`                      | Array to concatenate across replicas.                |
| `all_gather_dim` | `int64`                      | Concatenation dimension.                              |
| `shard_count`    | `int64`                      | Size of each replica group.                           |
| `replica_groups` | vector of vectors of `int64` | Groups between which the concatenation is performed.  |
| `channel_id`     | optional `int64`             | Optional channel ID for cross-module communication.   |

-   `replica_groups` is a list of replica groups between which the concatenation
    is performed (the replica id for the current replica can be retrieved using
    [`ReplicaId`](#replicaid)). The order of replicas in each group determines
    the order in which their inputs are located in the result. `replica_groups`
    must either be empty (in which case all replicas belong to a single group,
    ordered from `0` to `N - 1`), or contain the same number of elements as the
    number of replicas. For example, `replica_groups = {0, 2}, {1, 3}` performs
    concatenation between the replicas `0` and `2`, and `1` and `3`.
-   `shard_count` is the size of each replica group. We need this in cases where
    `replica_groups` is empty.
-   `channel_id` is used for cross-module communication: only `all-gather`
    operations with the same `channel_id` can communicate with each other.

The output shape is the input shape with the `all_gather_dim` made `shard_count`
times larger. For example, if there are two replicas and the operand has the
value `[1.0, 2.5]` and `[3.0, 5.25]` respectively on the two replicas, then the
output value from this op where `all_gather_dim` is `0` will be `[1.0, 2.5, 3.0,
5.25]` on both replicas.

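As a rough illustration, here is a sketch of building an `AllGather` with
`XlaBuilder`, in the same style as the `AllToAll` example later in this
document. The `f32[2]` operand shape and the two replica groups are assumptions
made for this example, not values taken from the text above.

```
// Sketch only: shapes and group layout are illustrative assumptions.
XlaBuilder b("all_gather_example");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");

// With 4 replicas split into the groups {0, 1} and {2, 3}, each replica ends
// up with the f32[4] concatenation of the operands from its own group.
std::vector<ReplicaGroup> groups(2);
groups[0].add_replica_ids(0);
groups[0].add_replica_ids(1);
groups[1].add_replica_ids(2);
groups[1].add_replica_ids(3);
AllGather(x, /*all_gather_dimension=*/0, /*shard_count=*/2, groups);
```
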
## AllReduce

See also
[`XlaBuilder::AllReduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Performs a custom computation across replicas.

<b> `AllReduce(operand, computation, replica_groups, channel_id)` </b>

| Arguments        | Type                         | Semantics                                                        |
| ---------------- | ---------------------------- | ---------------------------------------------------------------- |
| `operand`        | `XlaOp`                      | Array or a non-empty tuple of arrays to reduce across replicas.  |
| `computation`    | `XlaComputation`             | Reduction computation.                                           |
| `replica_groups` | vector of vectors of `int64` | Groups between which the reductions are performed.               |
| `channel_id`     | optional `int64`             | Optional channel ID for cross-module communication.              |

-   When `operand` is a tuple of arrays, the all-reduce is performed on each
    element of the tuple.
-   `replica_groups` is a list of replica groups between which the reduction is
    performed (the replica id for the current replica can be retrieved using
    [`ReplicaId`](#replicaid)). `replica_groups` must either be empty (in which
    case all replicas belong to a single group), or contain the same number of
    elements as the number of replicas. For example, `replica_groups = {0, 2},
    {1, 3}` performs reduction between the replicas `0` and `2`, and `1` and
    `3`.
-   `channel_id` is used for cross-module communication: only `all-reduce`
    operations with the same `channel_id` can communicate with each other.

The output shape is the same as the input shape. For example, if there are two
replicas and the operand has the value `[1.0, 2.5]` and `[3.0, 5.25]`
respectively on the two replicas, then the output value from this op and
summation computation will be `[4.0, 7.75]` on both replicas. If the input is a
tuple, the output is a tuple as well.

Computing the result of `AllReduce` requires having one input from each replica,
so if one replica executes an `AllReduce` node more times than another, then the
former replica will wait forever. Since the replicas are all running the same
program, there are not a lot of ways for that to happen, but it is possible when
a while loop's condition depends on data from infeed and the data that is infed
causes the while loop to iterate more times on one replica than another.

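The following sketch shows one plausible way to express a cross-replica sum as
an `AllReduce`: a scalar addition computation is built with a separate
`XlaBuilder` and passed as the reduction `computation`. The shapes, names, and
the use of `Build().value()` to unwrap the `StatusOr` are illustrative
assumptions.

```
// Sketch only: a cross-replica sum expressed as AllReduce over all replicas.
XlaBuilder sum_builder("add_f32");
auto a = Parameter(&sum_builder, 0, ShapeUtil::MakeShape(F32, {}), "a");
auto b = Parameter(&sum_builder, 1, ShapeUtil::MakeShape(F32, {}), "b");
Add(a, b);
XlaComputation add = sum_builder.Build().value();

XlaBuilder builder("all_reduce_example");
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
// Empty replica_groups: all replicas form a single group.
AllReduce(x, add);
```
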
## AllToAll

See also
[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

AllToAll is a collective operation that sends data from all cores to all cores.
It has two phases:

1.  The scatter phase. On each core, the operand is split into `split_count`
    number of blocks along the `split_dimension`, and the blocks are scattered
    to all cores, e.g., the ith block is sent to the ith core.
2.  The gather phase. Each core concatenates the received blocks along the
    `concat_dimension`.

The participating cores can be configured by:

-   `replica_groups`: each ReplicaGroup contains a list of replica ids
    participating in the computation (the replica id for the current replica
    can be retrieved using [`ReplicaId`](#replicaid)). AllToAll will be applied
    within subgroups in the specified order. For example, `replica_groups =
    {{1,2,3}, {4,5,0}}` means that an AllToAll will be applied within replicas
    `{1, 2, 3}`, and in the gather phase, the received blocks will be
    concatenated in the same order of 1, 2, 3. Then, another AllToAll will be
    applied within replicas 4, 5, 0, and the concatenation order is also 4, 5,
    0. If `replica_groups` is empty, all replicas belong to one group, in the
    concatenation order of their appearance.

Prerequisites:

-   The dimension size of the operand on the `split_dimension` is divisible by
    `split_count`.
-   The operand's shape is not a tuple.

<b> `AllToAll(operand, split_dimension, concat_dimension, split_count,
replica_groups)` </b>

| Arguments          | Type                  | Semantics                                                                                                                                                               |
| ------------------ | --------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `operand`          | `XlaOp`               | n dimensional input array                                                                                                                                                 |
| `split_dimension`  | `int64`               | A value in the interval `[0, n)` that names the dimension along which the operand is split                                                                               |
| `concat_dimension` | `int64`               | A value in the interval `[0, n)` that names the dimension along which the split blocks are concatenated                                                                  |
| `split_count`      | `int64`               | The number of cores that participate in this operation. If `replica_groups` is empty, this should be the number of replicas; otherwise, this should be equal to the number of replicas in each group. |
| `replica_groups`   | `ReplicaGroup` vector | Each group contains a list of replica ids.                                                                                                                                |

The following shows an example of AllToAll.

```
XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
```

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="./images/ops_alltoall.png">
</div>

In this example, there are 4 cores participating in the AllToAll. On each core,
the operand is split into 4 parts along dimension 1, so each part has shape
f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates
the received parts along dimension 0, in the order of cores 0-3. So the output
on each core has shape f32[16,4].

## BatchNormGrad

See also
[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.

Calculates gradients of batch norm.

<b> `BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)` </b>

| Arguments       | Type    | Semantics                                                 |
| --------------- | ------- | --------------------------------------------------------- |
| `operand`       | `XlaOp` | n dimensional array to be normalized (x)                  |
| `scale`         | `XlaOp` | 1 dimensional array (\\(\gamma\\))                        |
| `mean`          | `XlaOp` | 1 dimensional array (\\(\mu\\))                           |
| `variance`      | `XlaOp` | 1 dimensional array (\\(\sigma^2\\))                      |
| `grad_output`   | `XlaOp` | Gradients passed to `BatchNormTraining` (\\(\nabla y\\))  |
| `epsilon`       | `float` | Epsilon value (\\(\epsilon\\))                            |
| `feature_index` | `int64` | Index to feature dimension in `operand`                   |

For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the gradients with
respect to `operand`, `offset` and `scale` across all the other dimensions. The
`feature_index` must be a valid index for the feature dimension in `operand`.

The three gradients are defined by the following formulas (assuming a
4-dimensional array as `operand` and with feature dimension index `l`, batch
size `m` and spatial sizes `w` and `h`):

\\[ \begin{split} c_l&=
\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h
\left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right)
\\\\
\nabla x_{ijkl} &= \frac{\gamma_{l}}{\sqrt{\sigma^2_{l}+\epsilon}}
\left( \nabla y_{ijkl} - \mathrm{mean}(\nabla y) - c_l (x_{ijkl} - \mu_{l})
\right)
\\\\
\nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl}
\frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon}} \right)
\\\\
\nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl}
\end{split} \\]

The inputs `mean` and `variance` represent moment values across the batch and
spatial dimensions.

The output type is a tuple of three handles:

| Outputs        | Type    | Semantics                                                      |
| -------------- | ------- | --------------------------------------------------------------- |
| `grad_operand` | `XlaOp` | gradient with respect to input `operand` (\\(\nabla x\\))       |
| `grad_scale`   | `XlaOp` | gradient with respect to input `scale` (\\(\nabla \gamma\\))    |
| `grad_offset`  | `XlaOp` | gradient with respect to input `offset` (\\(\nabla \beta\\))    |

## BatchNormInference

See also
[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.

Normalizes an array across batch and spatial dimensions.

<b> `BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)` </b>

Arguments       | Type    | Semantics
--------------- | ------- | ---------------------------------------
`operand`       | `XlaOp` | n dimensional array to be normalized
`scale`         | `XlaOp` | 1 dimensional array
`offset`        | `XlaOp` | 1 dimensional array
`mean`          | `XlaOp` | 1 dimensional array
`variance`      | `XlaOp` | 1 dimensional array
`epsilon`       | `float` | Epsilon value
`feature_index` | `int64` | Index to feature dimension in `operand`

For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the mean and variance
across all the other dimensions and uses the mean and variance to normalize each
element in `operand`. The `feature_index` must be a valid index for the feature
dimension in `operand`.

`BatchNormInference` is equivalent to calling `BatchNormTraining` without
computing `mean` and `variance` for each batch. It uses the input `mean` and
`variance` instead as estimated values. The purpose of this op is to reduce
latency in inference, hence the name `BatchNormInference`.

The output is an n-dimensional, normalized array with the same shape as input
`operand`.

## BatchNormTraining

See also
[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [`the original batch normalization paper`](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.

Normalizes an array across batch and spatial dimensions.

<b> `BatchNormTraining(operand, scale, offset, epsilon, feature_index)` </b>

Arguments       | Type    | Semantics
--------------- | ------- | ----------------------------------------
`operand`       | `XlaOp` | n dimensional array to be normalized (x)
`scale`         | `XlaOp` | 1 dimensional array (\\(\gamma\\))
`offset`        | `XlaOp` | 1 dimensional array (\\(\beta\\))
`epsilon`       | `float` | Epsilon value (\\(\epsilon\\))
`feature_index` | `int64` | Index to feature dimension in `operand`

For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the mean and variance
across all the other dimensions and uses the mean and variance to normalize each
element in `operand`. The `feature_index` must be a valid index for the feature
dimension in `operand`.

The algorithm goes as follows for each batch in `operand` \\(x\\) that
contains `m` elements with `w` and `h` as the size of spatial dimensions
(assuming `operand` is a 4 dimensional array):

-   Calculates batch mean \\(\mu_l\\) for each feature `l` in feature dimension:
    \\(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\\)

-   Calculates batch variance \\(\sigma^2_l\\):
    \\(\sigma^2_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h (x_{ijkl} - \mu_l)^2\\)

-   Normalizes, scales and shifts:
    \\(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt{\sigma^2_l+\epsilon}}+\beta_l\\)

The epsilon value, usually a small number, is added to avoid divide-by-zero
errors.

The output type is a tuple of three `XlaOp`s:

| Outputs      | Type    | Semantics                                                      |
| ------------ | ------- | --------------------------------------------------------------- |
| `output`     | `XlaOp` | n dimensional array with the same shape as input `operand` (y)  |
| `batch_mean` | `XlaOp` | 1 dimensional array (\\(\mu\\))                                  |
| `batch_var`  | `XlaOp` | 1 dimensional array (\\(\sigma^2\\))                             |

The `batch_mean` and `batch_var` are moments calculated across the batch and
spatial dimensions using the formulas above.

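As a hedged sketch, the following shows how `BatchNormTraining` might be
invoked and its result tuple unpacked with `GetTupleElement`. The NHWC-like
shape and the `feature_index` of 3 are assumptions chosen for the example.

```
// Sketch only: shapes and feature_index are illustrative assumptions.
XlaBuilder b("batch_norm_training_example");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4, 4, 16}), "x");
auto scale = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16}), "scale");
auto offset = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {16}), "offset");

auto result = BatchNormTraining(x, scale, offset, /*epsilon=*/1e-5f,
                                /*feature_index=*/3);
auto normalized = GetTupleElement(result, 0);  // y, same shape as x
auto batch_mean = GetTupleElement(result, 1);  // f32[16]
auto batch_var = GetTupleElement(result, 2);   // f32[16]
```
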
## BitcastConvertType

See also
[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast
operation from a data shape to a target shape. The input and output size must
match: e.g. `s32` elements become `f32` elements via bitcast routine, and one
`s32` element will become four `s8` elements. Bitcast is implemented as a
low-level cast, so machines with different floating-point representations will
give different results.

<b> `BitcastConvertType(operand, new_element_type)` </b>

Arguments          | Type            | Semantics
------------------ | --------------- | ---------------------------
`operand`          | `XlaOp`         | array of type T with dims D
`new_element_type` | `PrimitiveType` | type U

The dimensions of the operand and the target shape must match, apart from the
last dimension which will change by the ratio of the primitive size before and
after the conversion.

The source and destination element types must not be tuples.

### Bitcast-converting to primitive type of different width

`BitcastConvert` HLO instruction supports the case where the size of the output
element type `T'` is not equal to the size of the input element `T`. As the
whole operation is conceptually a bitcast and does not change the underlying
bytes, the shape of the output element has to change. For `B = sizeof(T), B' =
sizeof(T')`, there are two possible cases.

First, when `B > B'`, the output shape gets a new minor-most dimension of size
`B/B'`. For example:

```
  f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)
```

The rule remains the same for effective scalars:

```
  f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)
```

Alternatively, for `B' > B` the instruction requires the last logical dimension
of the input shape to be equal to `B'/B`, and this dimension is dropped during
the conversion:

```
  f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)
```

Note that conversions between different bitwidths are not elementwise.

## Broadcast

See also
[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Adds dimensions to an array by duplicating the data in the array.

<b> `Broadcast(operand, broadcast_sizes)` </b>

Arguments         | Type                | Semantics
----------------- | ------------------- | -------------------------------
`operand`         | `XlaOp`             | The array to duplicate
`broadcast_sizes` | `ArraySlice<int64>` | The sizes of the new dimensions

The new dimensions are inserted on the left, i.e. if `broadcast_sizes` has
values `{a0, ..., aN}` and the operand shape has dimensions `{b0, ..., bM}` then
the shape of the output has dimensions `{a0, ..., aN, b0, ..., bM}`.

The new dimensions index into copies of the operand, i.e.

```
output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
```

For example, if `operand` is a scalar `f32` with value `2.0f`, and
`broadcast_sizes` is `{2, 3}`, then the result will be an array with shape
`f32[2, 3]` and all the values in the result will be `2.0f`.

## BroadcastInDim

See also
[`XlaBuilder::BroadcastInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Expands the size and rank of an array by duplicating the data in the array.

<b> `BroadcastInDim(operand, out_dim_size, broadcast_dimensions)` </b>

| Arguments              | Type                | Semantics                                                                               |
| ---------------------- | ------------------- | ---------------------------------------------------------------------------------------- |
| `operand`              | `XlaOp`             | The array to duplicate                                                                    |
| `out_dim_size`         | `ArraySlice<int64>` | The sizes of the dimensions of the target shape                                           |
| `broadcast_dimensions` | `ArraySlice<int64>` | Which dimension in the target shape each dimension of the operand shape corresponds to    |

Similar to Broadcast, but allows adding dimensions anywhere and expanding
existing dimensions with size 1.

The `operand` is broadcast to the shape described by `out_dim_size`.
`broadcast_dimensions` maps the dimensions of `operand` to the dimensions of the
target shape, i.e. the i'th dimension of the operand is mapped to the
broadcast_dimension\[i\]'th dimension of the output shape. The dimensions of
`operand` must have size 1 or be the same size as the dimension in the output
shape they are mapped to. The remaining dimensions are filled with dimensions of
size 1. Degenerate-dimension broadcasting then broadcasts along these degenerate
dimensions to reach the output shape. The semantics are described in detail on
the [broadcasting page](broadcasting.md).

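A small sketch of `BroadcastInDim`, broadcasting a vector across a new leading
dimension; the shapes and dimension mapping are illustrative assumptions.

```
// Sketch only: broadcast an f32[3] vector to f32[2, 3].
XlaBuilder b("broadcast_in_dim_example");
auto v = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "v");

// Output shape {2, 3}; operand dimension 0 corresponds to output dimension 1,
// so the vector is replicated along the new leading dimension of size 2.
BroadcastInDim(v, /*out_dim_size=*/{2, 3}, /*broadcast_dimensions=*/{1});
```
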
## Call

See also
[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Invokes a computation with the given arguments.

<b> `Call(computation, args...)` </b>

| Arguments     | Type                   | Semantics                                                                              |
| ------------- | ---------------------- | --------------------------------------------------------------------------------------- |
| `computation` | `XlaComputation`       | computation of type `T_0, T_1, ..., T_{N-1} -> S` with N parameters of arbitrary type   |
| `args`        | sequence of N `XlaOp`s | N arguments of arbitrary type                                                            |

The arity and types of the `args` must match the parameters of the
`computation`. It is allowed to have no `args`.

## Cholesky

See also
[`XlaBuilder::Cholesky`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Computes the
[Cholesky decomposition](https://en.wikipedia.org/wiki/Cholesky_decomposition)
of a batch of symmetric (Hermitian) positive definite matrices.

<b> `Cholesky(a, lower)` </b>

Arguments | Type    | Semantics
--------- | ------- | -------------------------------------------------------
`a`       | `XlaOp` | a rank >= 2 array of a complex or floating-point type.
`lower`   | `bool`  | whether to use the upper or lower triangle of `a`.

If `lower` is `true`, computes lower-triangular matrices `l` such that $$ a = l
. l^T $$. If `lower` is `false`, computes upper-triangular matrices `u` such
that $$ a = u^T . u $$.

Input data is read only from the lower/upper triangle of `a`, depending on the
value of `lower`. Values from the other triangle are ignored. Output data is
returned in the same triangle; the values in the other triangle are
implementation-defined and may be anything.

If the rank of `a` is greater than 2, `a` is treated as a batch of matrices,
where all except the minor 2 dimensions are batch dimensions.

If `a` is not symmetric (Hermitian) positive definite, the result is
implementation-defined.

## Clamp

See also
[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Clamps an operand to within the range between a minimum and maximum value.

<b> `Clamp(min, operand, max)` </b>

Arguments | Type    | Semantics
--------- | ------- | ---------------
`min`     | `XlaOp` | array of type T
`operand` | `XlaOp` | array of type T
`max`     | `XlaOp` | array of type T

Given an operand and minimum and maximum values, returns the operand if it is in
the range between the minimum and maximum, else returns the minimum value if the
operand is below this range or the maximum value if the operand is above this
range. That is, `clamp(a, x, b) = min(max(a, x), b)`.

All three arrays must be the same shape. Alternatively, as a restricted form of
[broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`.

Example with scalar `min` and `max`:

```
let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};
```

## Collapse

See also
[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and the `tf.reshape` operation.

Collapses dimensions of an array into one dimension.

<b> `Collapse(operand, dimensions)` </b>

Arguments    | Type           | Semantics
------------ | -------------- | -----------------------------------------------
`operand`    | `XlaOp`        | array of type T
`dimensions` | `int64` vector | in-order, consecutive subset of T's dimensions.

Collapse replaces the given subset of the operand's dimensions by a single
dimension. The input arguments are an arbitrary array of type T and a
compile-time-constant vector of dimension indices. The dimension indices must be
an in-order (low to high dimension numbers), consecutive subset of T's
dimensions. Thus, {0, 1, 2}, {0, 1}, or {1, 2} are all valid dimension sets, but
{1, 0} or {0, 2} are not. They are replaced by a single new dimension, in the
same position in the dimension sequence as those they replace, with the new
dimension size equal to the product of original dimension sizes. The lowest
dimension number in `dimensions` is the slowest varying dimension (most major)
in the loop nest which collapses these dimensions, and the highest dimension
number is the fastest varying (most minor). See the `tf.reshape` operator
if more general collapse ordering is needed.

For example, let v be an array of 24 elements:

```
let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
                    {{20, 21, 22}, {25, 26, 27}},
                    {{30, 31, 32}, {35, 36, 37}},
                    {{40, 41, 42}, {45, 46, 47}}};

// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
                      20, 21, 22, 25, 26, 27,
                      30, 31, 32, 35, 36, 37,
                      40, 41, 42, 45, 46, 47};

// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[8x3] {{10, 11, 12},
                      {15, 16, 17},
                      {20, 21, 22},
                      {25, 26, 27},
                      {30, 31, 32},
                      {35, 36, 37},
                      {40, 41, 42},
                      {45, 46, 47}};

// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[4x6] {{10, 11, 12, 15, 16, 17},
                      {20, 21, 22, 25, 26, 27},
                      {30, 31, 32, 35, 36, 37},
                      {40, 41, 42, 45, 46, 47}};
```

## CollectivePermute

See also
[`XlaBuilder::CollectivePermute`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

CollectivePermute is a collective operation that sends and receives data across
replicas.

<b> `CollectivePermute(operand, source_target_pairs)` </b>

| Arguments             | Type                    | Semantics                                                                                                                |
| --------------------- | ----------------------- | -------------------------------------------------------------------------------------------------------------------------- |
| `operand`             | `XlaOp`                 | n dimensional input array                                                                                                    |
| `source_target_pairs` | `<int64, int64>` vector | A list of (source_replica_id, target_replica_id) pairs. For each pair, the operand is sent from the source replica to the target replica. |

Note that there are the following restrictions on `source_target_pairs`:

-   Any two pairs should not have the same target replica id, and they should
    not have the same source replica id.
-   If a replica id is not a target in any pair, then the output on that replica
    is a tensor consisting of 0(s) with the same shape as the input.

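The following sketch illustrates one way a `CollectivePermute` that rotates
data between four replicas might be built; the operand shape and the exact C++
type of the pair list are assumptions, so treat it as illustrative rather than
authoritative.

```
// Sketch only: rotate data one replica to the right among 4 replicas.
XlaBuilder b("collective_permute_example");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8}), "x");

// Replica 0 sends to 1, 1 to 2, 2 to 3, and 3 back to 0.
std::vector<std::pair<int64_t, int64_t>> pairs = {{0, 1}, {1, 2}, {2, 3}, {3, 0}};
CollectivePermute(x, pairs);
```
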
## Concatenate

See also
[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Concatenate composes an array from multiple array operands. The array is of the
same rank as each of the input array operands (which must be of the same rank as
each other) and contains the arguments in the order that they were specified.

<b> `Concatenate(operands..., dimension)` </b>

| Arguments   | Type                  | Semantics                                                                                        |
| ----------- | --------------------- | -------------------------------------------------------------------------------------------------- |
| `operands`  | sequence of N `XlaOp` | N arrays of type T with dimensions [L0, L1, ...]. Requires N >= 1.                                   |
| `dimension` | `int64`               | A value in the interval `[0, rank)`, where `rank` is the rank of the `operands`, that names the dimension to be concatenated between the `operands`. |

With the exception of `dimension` all dimensions must be the same. This is
because XLA does not support "ragged" arrays. Also note that rank-0 values
cannot be concatenated (as it's impossible to name the dimension along which the
concatenation occurs).

1-dimensional example:

```
Concat({{2, 3}, {4, 5}, {6, 7}}, 0)
>>> {2, 3, 4, 5, 6, 7}
```

2-dimensional example:

```
let a = {
  {1, 2},
  {3, 4},
  {5, 6},
};
let b = {
  {7, 8},
};
Concat({a, b}, 0)
>>> {
  {1, 2},
  {3, 4},
  {5, 6},
  {7, 8},
}
```

Diagram:
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="./images/ops_concatenate.png">
</div>

## Conditional

See also
[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `Conditional(pred, true_operand, true_computation, false_operand,
false_computation)` </b>

<!-- mdformat off(disable mdformat for proper MathJax formatting) -->

Arguments           | Type             | Semantics
------------------- | ---------------- | ---------------------------------------
`pred`              | `XlaOp`          | Scalar of type `PRED`
`true_operand`      | `XlaOp`          | Argument of type \\(T_0\\)
`true_computation`  | `XlaComputation` | XlaComputation of type \\(T_0 \to S\\)
`false_operand`     | `XlaOp`          | Argument of type \\(T_1\\)
`false_computation` | `XlaComputation` | XlaComputation of type \\(T_1 \to S\\)

Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
is `false`, and returns the result.

The `true_computation` must take in a single argument of type \\(T_0\\) and will
be invoked with `true_operand` which must be of the same type. The
`false_computation` must take in a single argument of type \\(T_1\\) and will be
invoked with `false_operand` which must be of the same type. The type of the
returned value of `true_computation` and `false_computation` must be the same.

<!-- mdformat on -->

Note that only one of `true_computation` and `false_computation` will be
executed depending on the value of `pred`.

<b> `Conditional(branch_index, branch_computations, branch_operands)` </b>

<!-- mdformat off(disable mdformat for proper MathJax formatting) -->

| Arguments             | Type                           | Semantics                                                                   |
| --------------------- | ------------------------------ | ----------------------------------------------------------------------------- |
| `branch_index`        | `XlaOp`                        | Scalar of type `S32`                                                           |
| `branch_computations` | sequence of N `XlaComputation` | XlaComputations of type \\( T_0 \to S , T_1 \to S , ..., T_{N-1} \to S \\)     |
| `branch_operands`     | sequence of N `XlaOp`          | Arguments of type \\( T_0 , T_1 , ..., T_{N-1} \\)                             |

<!-- mdformat on -->

Executes `branch_computations[branch_index]`, and returns the result. If
`branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]`
is executed as the default branch.

Each `branch_computations[b]` must take in a single argument of type `T_b` and
will be invoked with `branch_operands[b]` which must be of the same type. The
type of the returned value of each `branch_computations[b]` must be the same.

Note that only one of the `branch_computations` will be executed depending on
the value of `branch_index`.

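Below is a hedged sketch of the two-branch form of `Conditional`: two
single-parameter computations are built with their own builders and then
selected by a `PRED` scalar. The computation names, shapes, and the use of
`Build().value()` to unwrap the `StatusOr` are illustrative assumptions.

```
// Sketch only: both branches take an f32 scalar and return an f32 scalar.
XlaBuilder tb("double");
Mul(Parameter(&tb, 0, ShapeUtil::MakeShape(F32, {}), "x"),
    ConstantR0<float>(&tb, 2.0f));
XlaComputation true_comp = tb.Build().value();

XlaBuilder fb("halve");
Div(Parameter(&fb, 0, ShapeUtil::MakeShape(F32, {}), "x"),
    ConstantR0<float>(&fb, 2.0f));
XlaComputation false_comp = fb.Build().value();

XlaBuilder b("conditional_example");
auto pred = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "pred");
auto x = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "x");
Conditional(pred, /*true_operand=*/x, true_comp,
            /*false_operand=*/x, false_comp);
```
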
## Conv (convolution)

See also
[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

As ConvWithGeneralPadding, but the padding is specified in a short-hand way as
either SAME or VALID. SAME padding pads the input (`lhs`) with zeroes so that
the output has the same shape as the input when not taking striding into
account. VALID padding simply means no padding.

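As an illustrative sketch, the following builds a 2-d convolution with SAME
padding via `Conv`. The shapes (which assume the default dimension numbering of
batch, feature, then spatial dimensions) and the `Padding::kSame` enum from the
client padding helpers are assumptions made for this example.

```
// Sketch only: a single-feature 2-d convolution with SAME padding.
XlaBuilder b("conv_example");
auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 1, 8, 8}), "input");
auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 1, 3, 3}), "kernel");

// Stride 1 in both spatial dimensions; SAME padding keeps the 8x8 spatial size.
Conv(lhs, rhs, /*window_strides=*/{1, 1}, Padding::kSame);
```
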
## ConvWithGeneralPadding (convolution)

See also
[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Computes a convolution of the kind used in neural networks. Here, a convolution
can be thought of as an n-dimensional window moving across an n-dimensional base
area and a computation is performed for each possible position of the window.

| Arguments             | Type                              | Semantics                         |
| --------------------- | --------------------------------- | --------------------------------- |
| `lhs`                 | `XlaOp`                           | rank n+2 array of inputs          |
| `rhs`                 | `XlaOp`                           | rank n+2 array of kernel weights  |
| `window_strides`      | `ArraySlice<int64>`               | n-d array of kernel strides       |
| `padding`             | `ArraySlice<pair<int64, int64>>`  | n-d array of (low, high) padding  |
| `lhs_dilation`        | `ArraySlice<int64>`               | n-d lhs dilation factor array     |
| `rhs_dilation`        | `ArraySlice<int64>`               | n-d rhs dilation factor array     |
| `feature_group_count` | int64                             | the number of feature groups      |
| `batch_group_count`   | int64                             | the number of batch groups        |

Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2
array describing the base area. This is called the input, even though of course
the rhs is also an input. In a neural network, these are the input activations.
The n+2 dimensions are, in this order:

*   `batch`: Each coordinate in this dimension represents an independent input
    for which convolution is carried out.
*   `z/depth/features`: Each (y,x) position in the base area has a vector
    associated to it, which goes into this dimension.
*   `spatial_dims`: Describes the `n` spatial dimensions that define the base
    area that the window moves across.

The `rhs` argument is a rank n+2 array describing the convolutional
filter/kernel/window. The dimensions are, in this order:

*   `output-z`: The `z` dimension of the output.
*   `input-z`: The size of this dimension times `feature_group_count` should
    equal the size of the `z` dimension in lhs.
*   `spatial_dims`: Describes the `n` spatial dimensions that define the n-d
    window that moves across the base area.

The `window_strides` argument specifies the stride of the convolutional window
in the spatial dimensions. For example, if the stride in the first spatial
dimension is 3, then the window can only be placed at coordinates where the
first spatial index is divisible by 3.

The `padding` argument specifies the amount of zero padding to be applied to the
base area. The amount of padding can be negative -- the absolute value of
negative padding indicates the number of elements to remove from the specified
dimension before doing the convolution. `padding[0]` specifies the padding for
dimension `y` and `padding[1]` specifies the padding for dimension `x`. Each
pair has the low padding as the first element and the high padding as the second
element. The low padding is applied in the direction of lower indices while the
high padding is applied in the direction of higher indices. For example, if
`padding[1]` is `(2,3)` then there will be a padding by 2 zeroes on the left and
by 3 zeroes on the right in the second spatial dimension. Using padding is
equivalent to inserting those same zero values into the input (`lhs`) before
doing the convolution.

The `lhs_dilation` and `rhs_dilation` arguments specify the dilation factor to
be applied to the lhs and rhs, respectively, in each spatial dimension. If the
dilation factor in a spatial dimension is d, then d-1 holes are implicitly
placed between each of the entries in that dimension, increasing the size of the
array. The holes are filled with a no-op value, which for convolution means
zeroes.

Dilation of the rhs is also called atrous convolution. For more details, see
`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed
convolution. For more details, see `tf.nn.conv2d_transpose`.

The `feature_group_count` argument (default value 1) can be used for grouped
convolutions. `feature_group_count` needs to be a divisor of both the input and
the output feature dimension. If `feature_group_count` is greater than 1, it
means that conceptually the input and output feature dimension and the `rhs`
output feature dimension are split evenly into `feature_group_count` many
groups, each group consisting of a consecutive subsequence of features. The
input feature dimension of `rhs` needs to be equal to the `lhs` input feature
dimension divided by `feature_group_count` (so it already has the size of a
group of input features). The i-th groups are used together to compute
`feature_group_count` many separate convolutions. The results of these
convolutions are concatenated together in the output feature dimension.

For depthwise convolution the `feature_group_count` argument would be set to the
input feature dimension, and the filter would be reshaped from
`[filter_height, filter_width, in_channels, channel_multiplier]` to
`[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more
details, see `tf.nn.depthwise_conv2d`.

The `batch_group_count` argument (default value 1) can be used for grouped
filters during backpropagation. `batch_group_count` needs to be a divisor of the
size of the `lhs` (input) batch dimension. If `batch_group_count` is greater
than 1, it means that the output batch dimension should be of size `input batch
/ batch_group_count`. The `batch_group_count` must be a divisor of the output
feature size.

The output shape has these dimensions, in this order:

*   `batch`: The size of this dimension times `batch_group_count` should equal
    the size of the `batch` dimension in lhs.
*   `z`: Same size as `output-z` on the kernel (`rhs`).
*   `spatial_dims`: One value for each valid placement of the convolutional
    window.

The valid placements of the convolutional window are determined by the strides
and the size of the base area after padding.

To describe what a convolution does, consider a 2d convolution, and pick some
fixed `batch`, `z`, `y`, `x` coordinates in the output. Then `(y,x)` is a
position of a corner of the window within the base area (e.g. the upper left
corner, depending on how you interpret the spatial dimensions). We now have a 2d
window, taken from the base area, where each 2d point is associated to a 1d
vector, so we get a 3d box. From the convolutional kernel, since we fixed the
output coordinate `z`, we also have a 3d box. The two boxes have the same
dimensions, so we can take the sum of the element-wise products between the two
boxes (similar to a dot product). That is the output value.

Note that if `output-z` is e.g., 5, then each position of the window produces 5
values in the output into the `z` dimension of the output. These values differ
in what part of the convolutional kernel is used - there is a separate 3d box of
values used for each `output-z` coordinate. So you could think of it as 5
separate convolutions with a different filter for each of them.

Here is pseudo-code for a 2d convolution with padding and striding:

```
for (b, oz, oy, ox) {  // output coordinates
  value = 0;
  for (iz, ky, kx) {  // kernel coordinates and input z
    iy = oy*stride_y + ky - pad_low_y;
    ix = ox*stride_x + kx - pad_low_x;
    if ((iy, ix) inside the base area considered without padding) {
      value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
    }
  }
  output(b, oz, oy, ox) = value;
}
```

## ConvertElementType

See also
[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Similar to an element-wise `static_cast` in C++, performs an element-wise
conversion operation from a data shape to a target shape. The dimensions must
match, and the conversion is an element-wise one; e.g. `s32` elements become
`f32` elements via an `s32`-to-`f32` conversion routine.

<b> `ConvertElementType(operand, new_element_type)` </b>

Arguments          | Type            | Semantics
------------------ | --------------- | ---------------------------
`operand`          | `XlaOp`         | array of type T with dims D
`new_element_type` | `PrimitiveType` | type U

The dimensions of the operand and the target shape must match. The source and
destination element types must not be tuples.

A conversion such as `T=s32` to `U=f32` will perform a normalizing int-to-float
conversion routine such as round-to-nearest-even.

> Note: The precise float-to-int and vice-versa conversions are currently
> unspecified, but may become additional arguments to the convert operation in
> the future. Not all possible conversions have been implemented for all
> targets.

```
let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}
```

## CrossReplicaSum

Performs `AllReduce` with a summation computation.

## CustomCall

See also
[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Call a user-provided function within a computation.

<b> `CustomCall(target_name, args..., shape)` </b>

| Arguments     | Type                   | Semantics                                                                                 |
| ------------- | ---------------------- | ------------------------------------------------------------------------------------------ |
| `target_name` | `string`               | Name of the function. A call instruction will be emitted which targets this symbol name.    |
| `args`        | sequence of N `XlaOp`s | N arguments of arbitrary type, which will be passed to the function.                        |
| `shape`       | `Shape`                | Output shape of the function                                                                |

The function signature is the same, regardless of the arity or type of args:

```
extern "C" void target_name(void* out, void** in);
```

For example, if CustomCall is used as follows:

```
let x = f32[2] {1,2};
let y = f32[2x3] {{10, 20, 30}, {40, 50, 60}};

CustomCall("myfunc", {x, y}, f32[3x3])
```

Here is an example of an implementation of `myfunc`:

```
extern "C" void myfunc(void* out, void** in) {
  float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
  float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
  EXPECT_EQ(1, x[0]);
  EXPECT_EQ(2, x[1]);
  EXPECT_EQ(10, y[0][0]);
  EXPECT_EQ(20, y[0][1]);
  EXPECT_EQ(30, y[0][2]);
  EXPECT_EQ(40, y[1][0]);
  EXPECT_EQ(50, y[1][1]);
  EXPECT_EQ(60, y[1][2]);
  float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
  z[0][0] = x[1] + y[1][0];
  // ...
}
```

The user-provided function must not have side-effects and its execution must be
idempotent.

> Note: The opaque nature of the user-provided function restricts optimization
> opportunities for the compiler. Try to express your computation in terms of
> native XLA ops whenever possible; only use CustomCall as a last resort.

## Dot

See also
[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `Dot(lhs, rhs)` </b>

Arguments | Type    | Semantics
--------- | ------- | ---------------
`lhs`     | `XlaOp` | array of type T
`rhs`     | `XlaOp` | array of type T

The exact semantics of this operation depend on the ranks of the operands:

| Input                               | Output         | Semantics                     |
| ----------------------------------- | -------------- | ----------------------------- |
| vector [n] `dot` vector [n]         | scalar         | vector dot product            |
| matrix [m x k] `dot` vector [k]     | vector [m]     | matrix-vector multiplication  |
| matrix [m x k] `dot` matrix [k x n] | matrix [m x n] | matrix-matrix multiplication  |

The operation performs sum of products over the second dimension of `lhs` (or
the first if it has rank 1) and the first dimension of `rhs`. These are the
"contracted" dimensions. The contracted dimensions of `lhs` and `rhs` must be of
the same size. In practice, it can be used to perform dot products between
vectors, vector/matrix multiplications or matrix/matrix multiplications.

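A minimal sketch of a matrix-matrix `Dot`; the `[2, 3]` and `[3, 4]` shapes are
assumptions chosen so that the contracted dimensions match.

```
// Sketch only: [2, 3] dot [3, 4] -> [2, 4].
XlaBuilder b("dot_example");
auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "lhs");
auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {3, 4}), "rhs");

// The second dimension of lhs (size 3) is contracted with the first dimension
// of rhs (size 3).
Dot(lhs, rhs);
```
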
## DotGeneral

See also
[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>

Arguments           | Type                  | Semantics
------------------- | --------------------- | ----------------------------------------
`lhs`               | `XlaOp`               | array of type T
`rhs`               | `XlaOp`               | array of type T
`dimension_numbers` | `DotDimensionNumbers` | contracting and batch dimension numbers

As Dot, but allows contracting and batch dimension numbers to be specified for
both the 'lhs' and 'rhs'.

| DotDimensionNumbers Fields   | Type           | Semantics                            |
| ---------------------------- | -------------- | ------------------------------------ |
| 'lhs_contracting_dimensions' | repeated int64 | 'lhs' contracting dimension numbers  |
| 'rhs_contracting_dimensions' | repeated int64 | 'rhs' contracting dimension numbers  |
| 'lhs_batch_dimensions'       | repeated int64 | 'lhs' batch dimension numbers        |
| 'rhs_batch_dimensions'       | repeated int64 | 'rhs' batch dimension numbers        |

DotGeneral performs the sum of products over contracting dimensions specified
in 'dimension_numbers'.

Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need
to be the same but must have the same dimension sizes.

Example with contracting dimension numbers:

```
lhs = { {1.0, 2.0, 3.0},
        {4.0, 5.0, 6.0} }

rhs = { {1.0, 1.0, 1.0},
        {2.0, 2.0, 2.0} }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);

DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
                                 {15.0, 30.0} }
```

Associated batch dimension numbers from the 'lhs' and 'rhs' must
have the same dimension sizes.

Example with batch dimension numbers (batch size 2, 2x2 matrices):

```
lhs = { { {1.0, 2.0},
          {3.0, 4.0} },
        { {5.0, 6.0},
          {7.0, 8.0} } }

rhs = { { {1.0, 0.0},
          {0.0, 1.0} },
        { {1.0, 0.0},
          {0.0, 1.0} } }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);

DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
                                   {3.0, 4.0} },
                                 { {5.0, 6.0},
                                   {7.0, 8.0} } }
```

| Input                               | Output         | Semantics    |
| ----------------------------------- | -------------- | ------------ |
| [b0, m, k] `dot` [b0, k, n]         | [b0, m, n]     | batch matmul |
| [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n] | batch matmul |

It follows that the resulting dimension numbers start with the batch dimensions,
then the 'lhs' non-contracting/non-batch dimensions, and finally the 'rhs'
non-contracting/non-batch dimensions.

## DynamicSlice

See also
[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

DynamicSlice extracts a sub-array from the input array at dynamic
`start_indices`. The size of the slice in each dimension is passed in
`size_indices`, which specify the end point of exclusive slice intervals in each
dimension: [start, start + size). The shape of `start_indices` must be rank ==
1, with dimension size equal to the rank of `operand`.

<b> `DynamicSlice(operand, start_indices, size_indices)` </b>

| Arguments       | Type                  | Semantics                                                                                          |
| --------------- | --------------------- | ----------------------------------------------------------------------------------------------------- |
| `operand`       | `XlaOp`               | N dimensional array of type T                                                                           |
| `start_indices` | sequence of N `XlaOp` | List of N scalar integers containing the starting indices of the slice for each dimension. Value must be greater than or equal to zero. |
| `size_indices`  | `ArraySlice<int64>`   | List of N integers containing the slice size for each dimension. Each value must be strictly greater than zero, and start + size must be less than or equal to the size of the dimension to avoid wrapping modulo dimension size. |

The effective slice indices are computed by applying the following
transformation for each index `i` in `[0, N)` before performing the slice:

```
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
```

This ensures that the extracted slice is always in-bounds with respect to the
operand array. If the slice is in-bounds before the transformation is applied,
the transformation has no effect.

1-dimensional example:

```
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}

DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}
```

2-dimensional example:

```
let b =
{ {0.0,  1.0,  2.0},
  {3.0,  4.0,  5.0},
  {6.0,  7.0,  8.0},
  {9.0, 10.0, 11.0} }
let s = {2, 1}

DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0,  8.0},
  {10.0, 11.0} }
```

## DynamicUpdateSlice

See also
[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

DynamicUpdateSlice generates a result which is the value of the input array
`operand`, with a slice `update` overwritten at `start_indices`.
The shape of `update` determines the shape of the sub-array of the result which
is updated.
The shape of `start_indices` must be rank == 1, with dimension size equal to
the rank of `operand`.

<b> `DynamicUpdateSlice(operand, update, start_indices)` </b>

| Arguments       | Type                  | Semantics                                                                                        |
| --------------- | --------------------- | --------------------------------------------------------------------------------------------------- |
| `operand`       | `XlaOp`               | N dimensional array of type T                                                                         |
| `update`        | `XlaOp`               | N dimensional array of type T containing the slice update. Each dimension of the update shape must be strictly greater than zero, and start + update must be less than or equal to the operand size for each dimension to avoid generating out-of-bounds update indices. |
| `start_indices` | sequence of N `XlaOp` | List of N scalar integers containing the starting indices of the slice for each dimension. Value must be greater than or equal to zero. |

The effective slice indices are computed by applying the following
transformation for each index `i` in `[0, N)` before performing the slice:

```
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
```

This ensures that the updated slice is always in-bounds with respect to the
operand array.

If the slice is in-bounds before the transformation is applied,
the transformation has no effect.

1-dimensional example:

```
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}

DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}
```

2-dimensional example:

```
let b =
{ {0.0,  1.0,  2.0},
  {3.0,  4.0,  5.0},
  {6.0,  7.0,  8.0},
  {9.0, 10.0, 11.0} }
let u =
{ {12.0, 13.0},
  {14.0, 15.0},
  {16.0, 17.0} }

let s = {1, 1}

DynamicUpdateSlice(b, u, s) produces:
{ {0.0,  1.0,  2.0},
  {3.0, 12.0, 13.0},
  {6.0, 14.0, 15.0},
  {9.0, 16.0, 17.0} }
```

## Element-wise binary arithmetic operations

See also
[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

A set of element-wise binary arithmetic operations is supported.

<b> `Op(lhs, rhs)` </b>

Where `Op` is one of `Add` (addition), `Sub` (subtraction), `Mul`
(multiplication), `Div` (division), `Rem` (remainder), `Max` (maximum), `Min`
(minimum), `LogicalAnd` (logical AND), or `LogicalOr` (logical OR).

Arguments | Type    | Semantics
--------- | ------- | ----------------------------------------
`lhs`     | `XlaOp` | left-hand-side operand: array of type T
`rhs`     | `XlaOp` | right-hand-side operand: array of type T

The arguments' shapes have to be either similar or compatible. See the
[broadcasting](broadcasting.md) documentation about what it means for shapes to
be compatible. The result of an operation has a shape which is the result of
broadcasting the two input arrays. In this variant, operations between arrays of
different ranks are *not* supported, unless one of the operands is a scalar.

When `Op` is `Rem`, the sign of the result is taken from the dividend, and the
absolute value of the result is always less than the divisor's absolute value.

Integer division overflow (signed/unsigned division/remainder by zero or signed
division/remainder of `INT_MIN` with `-1`) produces an implementation-defined
value.

An alternative variant with different-rank broadcasting support exists for these
operations:

<b> `Op(lhs, rhs, broadcast_dimensions)` </b>

Where `Op` is the same as above. This variant of the operation should be used
for arithmetic operations between arrays of different ranks (such as adding a
matrix to a vector).

The additional `broadcast_dimensions` operand is a slice of integers used to
expand the rank of the lower-rank operand up to the rank of the higher-rank
operand. `broadcast_dimensions` maps the dimensions of the lower-rank shape to
the dimensions of the higher-rank shape. The unmapped dimensions of the expanded
shape are filled with dimensions of size one. Degenerate-dimension broadcasting
then broadcasts the shapes along these degenerate dimensions to equalize the
shapes of both operands. The semantics are described in detail on the
[broadcasting page](broadcasting.md).

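As a small sketch of the broadcasting variant, the following adds a vector to
each row of a matrix by mapping the vector's dimension 0 onto the matrix's
dimension 1; the shapes are illustrative assumptions.

```
// Sketch only: add an f32[3] bias vector to each row of an f32[2, 3] matrix.
XlaBuilder b("add_broadcast_example");
auto m = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "m");
auto v = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {3}), "v");

// Dimension 0 of the vector maps to dimension 1 of the matrix, so the vector
// is broadcast across the rows before the element-wise addition.
Add(m, v, /*broadcast_dimensions=*/{1});
```
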
## Element-wise comparison operations

See also
[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

A set of standard element-wise binary comparison operations is supported. Note
that standard IEEE 754 floating-point comparison semantics apply when comparing
floating-point types.

<b> `Op(lhs, rhs)` </b>

Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
(less-than). Another set of operators, `EqTotalOrder`, `NeTotalOrder`,
`GeTotalOrder`, `GtTotalOrder`, `LeTotalOrder`, and `LtTotalOrder`, provide the
same functionality, except that they additionally support a total order over
the floating point numbers, by enforcing -NaN < -Inf < -Finite < -0 < +0 <
+Finite < +Inf < +NaN.

Arguments | Type    | Semantics
--------- | ------- | ----------------------------------------
`lhs`     | `XlaOp` | left-hand-side operand: array of type T
`rhs`     | `XlaOp` | right-hand-side operand: array of type T

The arguments' shapes have to be either similar or compatible. See the
[broadcasting](broadcasting.md) documentation about what it means for shapes to
be compatible. The result of an operation has a shape which is the result of
broadcasting the two input arrays with the element type `PRED`. In this variant,
operations between arrays of different ranks are *not* supported, unless one of
the operands is a scalar.

An alternative variant with different-rank broadcasting support exists for these
operations:

<b> `Op(lhs, rhs, broadcast_dimensions)` </b>

Where `Op` is the same as above. This variant of the operation should be used
for comparison operations between arrays of different ranks (such as comparing a
matrix to a vector).

The additional `broadcast_dimensions` operand is a slice of integers specifying
the dimensions to use for broadcasting the operands. The semantics are described
in detail on the [broadcasting page](broadcasting.md).

## Element-wise unary functions

XlaBuilder supports these element-wise unary functions:

<b>`Abs(operand)`</b> Element-wise abs `x -> |x|`.

<b>`Ceil(operand)`</b> Element-wise ceil `x -> ⌈x⌉`.

<b>`Cos(operand)`</b> Element-wise cosine `x -> cos(x)`.

<b>`Exp(operand)`</b> Element-wise natural exponential `x -> e^x`.

<b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`.

<b>`Imag(operand)`</b> Element-wise imaginary part of a complex (or real)
shape. `x -> imag(x)`. If the operand is a floating point type, returns 0.

<b>`IsFinite(operand)`</b> Tests whether each element of `operand` is finite,
i.e., is not positive or negative infinity, and is not `NaN`. Returns an array
of `PRED` values with the same shape as the input, where each element is `true`
if and only if the corresponding input element is finite.

<b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`.

<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.

<b>`Logistic(operand)`</b> Element-wise logistic function computation `x ->
logistic(x)`.

<b>`PopulationCount(operand)`</b> Computes the number of bits set in each
element of `operand`.

<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.

<b>`Real(operand)`</b> Element-wise real part of a complex (or real) shape.
`x -> real(x)`. If the operand is a floating point type, returns the same value.

<b>`Rsqrt(operand)`</b> Element-wise reciprocal of square root operation
`x -> 1.0 / sqrt(x)`.

1380 1381<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where 1382 1383$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$ 1384 1385using the comparison operator of the element type of `operand`. 1386 1387<b>`Sqrt(operand)`</b> Element-wise square root operation `x -> sqrt(x)`. 1388 1389<b>`Cbrt(operand)`</b> Element-wise cubic root operation `x -> cbrt(x)`. 1390 1391<b>`Tanh(operand)`</b> Element-wise hyperbolic tangent `x -> tanh(x)`. 1392 1393<b>`Round(operand)`</b> Element-wise rounding, ties away from zero. 1394 1395<b>`RoundNearestEven(operand)`</b> Element-wise rounding, ties to nearest even. 1396 1397Arguments | Type | Semantics 1398--------- | ------- | --------------------------- 1399`operand` | `XlaOp` | The operand to the function 1400 1401The function is applied to each element in the `operand` array, resulting in an 1402array with the same shape. It is allowed for `operand` to be a scalar (rank 0). 1403 1404## Fft 1405 1406The XLA FFT operation implements the forward and inverse Fourier Transforms for 1407real and complex inputs/outputs. Multidimensional FFTs on up to 3 axes are 1408supported. 1409 1410See also 1411[`XlaBuilder::Fft`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1412 1413| Arguments | Type | Semantics | 1414| ------------ | ------------------- | ------------------------ | 1415| `operand` | `XlaOp` | The array we are Fourier | 1416: : : transforming. : 1417| `fft_type` | `FftType` | See the table below. | 1418| `fft_length` | `ArraySlice<int64>` | The time-domain lengths | 1419: : : of the axes being : 1420: : : transformed. This is : 1421: : : needed in particular for : 1422: : : IRFFT to right-size the : 1423: : : innermost axis, since : 1424: : : `RFFT(fft_length=[16])` : 1425: : : has the same output : 1426: : : shape as : 1427: : : `RFFT(fft_length=[17])`. : 1428 1429| `FftType` | Semantics | 1430| --------- | ---------------------------------------------------------------- | 1431| `FFT` | Forward complex-to-complex FFT. Shape is unchanged. | 1432| `IFFT` | Inverse complex-to-complex FFT. Shape is unchanged. | 1433| `RFFT` | Forward real-to-complex FFT. Shape of the innermost axis is | 1434: : reduced to `fft_length[-1] // 2 + 1` if `fft_length[-1]` is a : 1435: : non-zero value, omitting the reversed conjugate part of the : 1436: : transformed signal beyond the Nyquist frequency. : 1437| `IRFFT` | Inverse real-to-complex FFT (i.e. takes complex, returns real). | 1438: : Shape of the innermost axis is expanded to `fft_length[-1]` if : 1439: : `fft_length[-1]` is a non-zero value, inferring the part of the : 1440: : transformed signal beyond the Nyquist frequency from the reverse : 1441: : conjugate of the `1` to `fft_length[-1] // 2 + 1` entries. : 1442 1443#### Multidimensional FFT 1444 1445When more than 1 `fft_length` is provided, this is equivalent to applying a 1446cascade of FFT operations to each of the innermost axes. Note that for the 1447real->complex and complex->real cases, the innermost axis transform is 1448(effectively) performed first (RFFT; last for IRFFT), which is why the innermost 1449axis is the one which changes size. Other axis transforms will then be 1450complex->complex. 1451 1452#### Implementation details 1453 1454CPU FFT is backed by Eigen's TensorFFT. GPU FFT uses cuFFT. 
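
As a quick illustration of the `fft_length` convention, here is a sketch using `jax.numpy.fft`, which follows the same real-FFT sizing rules, rather than the `XlaBuilder` API itself; the array contents are arbitrary and only the shapes matter:

```python
import jax.numpy as jnp

x16 = jnp.ones(16, dtype=jnp.float32)
x17 = jnp.ones(17, dtype=jnp.float32)

# RFFT keeps fft_length // 2 + 1 complex values along the innermost axis,
# so time-domain lengths 16 and 17 both produce 9 outputs.
print(jnp.fft.rfft(x16).shape)  # (9,)
print(jnp.fft.rfft(x17).shape)  # (9,)

# IRFFT therefore needs the time-domain length to pick the output size.
y = jnp.fft.rfft(x17)
print(jnp.fft.irfft(y, n=17).shape)  # (17,)
print(jnp.fft.irfft(y, n=16).shape)  # (16,)
```

This ambiguity is exactly why `IRFFT` takes an explicit `fft_length` to right-size the innermost axis.
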
1455 1456## Gather 1457 1458The XLA gather operation stitches together several slices (each slice at a 1459potentially different runtime offset) of an input array. 1460 1461### General Semantics 1462 1463See also 1464[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1465For a more intuitive description, see the "Informal Description" section below. 1466 1467<b> `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` </b> 1468 1469| Arguments | Type | Semantics | 1470| ---------------------- | ------------------- | ----------------------------- | 1471| `operand` | `XlaOp` | The array we’re gathering | 1472: : : from. : 1473| `start_indices` | `XlaOp` | Array containing the starting | 1474: : : indices of the slices we : 1475: : : gather. : 1476| `index_vector_dim` | `int64` | The dimension in | 1477: : : `start_indices` that : 1478: : : "contains" the starting : 1479: : : indices. See below for a : 1480: : : detailed description. : 1481| `offset_dims` | `ArraySlice<int64>` | The set of dimensions in the | 1482: : : output shape that offset into : 1483: : : an array sliced from operand. : 1484| `slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the | 1485: : : bounds for the slice on : 1486: : : dimension `i`. : 1487| `collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each | 1488: : : slice that are collapsed : 1489: : : away. These dimensions must : 1490: : : have size 1. : 1491| `start_index_map` | `ArraySlice<int64>` | A map that describes how to | 1492: : : map indices in : 1493: : : `start_indices` to legal : 1494: : : indices into operand. : 1495| `indices_are_sorted` | `bool` | Whether the indices are | 1496: : : guaranteed to be sorted by : 1497: : : the caller. : 1498| `unique_indices` | `bool` | Whether the indices are | 1499: : : guaranteed to be unique by : 1500: : : the caller. : 1501 1502For convenience, we label dimensions in the output array not in `offset_dims` 1503as `batch_dims`. 1504 1505The output is an array of rank `batch_dims.size` + `offset_dims.size`. 1506 1507The `operand.rank` must equal the sum of `offset_dims.size` and 1508`collapsed_slice_dims.size`. Also, `slice_sizes.size` has to be equal to 1509`operand.rank`. 1510 1511If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider 1512`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of 1513shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the 1514shape of `start_indices` to be `[6,7,1]`). 1515 1516The bounds for the output array along dimension `i` is computed as follows: 1517 15181. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for 1519some `k`) then we pick the corresponding dimension bounds out of 1520`start_indices.shape`, skipping `index_vector_dim` (i.e. pick 1521`start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and 1522`start_indices.shape.dims`[`k`+`1`] otherwise). 1523 15242. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for 1525some `k`) then we pick the corresponding bound out of `slice_sizes` after 1526accounting for `collapsed_slice_dims` (i.e. we pick 1527`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` 1528with the bounds at indices `collapsed_slice_dims` removed). 1529 1530Formally, the operand index `In` corresponding to a given output index `Out` is 1531calculated as follows: 1532 15331. 
Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out a vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at position `index_vector_dim` into A. Note that this is well defined even if `G` is empty -- if `G` is empty then `S` = `start_indices`.

2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by scattering `S` using `start_index_map`. More precisely:

    1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` < `start_index_map.size`.

    2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.

3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices at the offset dimensions in `Out` according to the `collapsed_slice_dims` set. More precisely:

    1. `O`<sub>`in`</sub>[`remapped_offset_dims`(`k`)] = `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` (`remapped_offset_dims` is defined below).

    2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.

4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise addition.

`remapped_offset_dims` is a monotonic function with domain [`0`, `offset_dims.size`) and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., `offset_dims.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`, `2`} then `remapped_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.

If `indices_are_sorted` is set to true then XLA can assume that `start_indices` are sorted (in ascending `start_index_map` order) by the user. If they are not then the semantics is implementation defined.

If `unique_indices` is set to true then XLA can assume that all of the indices accessed are unique, so XLA can use non-atomic operations. If `unique_indices` is set to true and the indices are not unique then the semantics is implementation defined.

### Informal Description and Examples

Informally, every index `Out` in the output array corresponds to an element `E` in the operand array, computed as follows:

- We use the batch dimensions in `Out` to look up a starting index from `start_indices`.

- We use `start_index_map` to map the starting index (whose size may be less than operand.rank) to a "full" starting index into the `operand`.

- We dynamic-slice out a slice with size `slice_sizes` using the full starting index.

- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. Since all collapsed slice dimensions must have a bound of 1, this reshape is always legal.

- We use the offset dimensions in `Out` to index into this slice to get the input element, `E`, corresponding to output index `Out`.

`index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples that follow. More interesting values for `index_vector_dim` do not change the operation fundamentally, but make the visual representation more cumbersome.

To get an intuition on how all of the above fits together, let's look at an example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The position of a slice into the `[16,11]` array can be represented as an index vector of shape `S64[2]`, so the set of 5 positions can be represented as a `S64[5,2]` array.

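
The same example can be written down concretely. The sketch below uses JAX's `lax.gather`, whose `GatherDimensionNumbers` fields mirror the arguments above; in JAX the index vector dimension is implicitly the last dimension of `start_indices`, and the particular starting positions chosen here are made up for illustration:

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.arange(16 * 11, dtype=jnp.float32).reshape(16, 11)
# Five starting positions into the [16,11] array, shape s32[5,2].
start_indices = jnp.array([[0, 0], [2, 3], [4, 1], [7, 2], [8, 5]], dtype=jnp.int32)

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1, 2),       # output dims 1 and 2 walk over each [8,6] slice
    collapsed_slice_dims=(),  # no slice dimension is collapsed away
    start_index_map=(0, 1),   # index vector components map to operand dims 0 and 1
)
result = lax.gather(operand, start_indices, dnums, slice_sizes=(8, 6))
print(result.shape)  # (5, 8, 6)
```

Each row of `start_indices` selects one `[8,6]` slice; output dimension 0 is the batch dimension and dimensions 1 and 2 are the offset dimensions, matching the figure that follows.
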
1605 1606The behavior of the gather operation can then be depicted as an index 1607transformation that takes [`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>], an index in 1608the output shape, and maps it to an element in the input array in the following 1609way: 1610 1611<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1612<img style="width:100%" src="./images/ops_xla_gather_0.svg"> 1613</div> 1614 1615We first select an (`X`,`Y`) vector from the gather indices array using `G`. 1616The element in the output array at index 1617[`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>] is then the element in the input 1618array at index [`X`+`O`<sub>`0`</sub>,`Y`+`O`<sub>`1`</sub>]. 1619 1620`slice_sizes` is `[8,6]`, which decides the range of O<sub>`0`</sub> and 1621O<sub>`1`</sub>, and this in turn decides the bounds of the slice. 1622 1623This gather operation acts as a batch dynamic slice with `G` as the batch 1624dimension. 1625 1626The gather indices may be multidimensional. For instance, a more general 1627version of the example above using a "gather indices" array of shape `[4,5,2]` 1628would translate indices like this: 1629 1630<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1631<img style="width:100%" src="./images/ops_xla_gather_1.svg"> 1632</div> 1633 1634Again, this acts as a batch dynamic slice `G`<sub>`0`</sub> and 1635`G`<sub>`1`</sub> as the batch dimensions. The slice size is still `[8,6]`. 1636 1637The gather operation in XLA generalizes the informal semantics outlined above in 1638the following ways: 1639 16401. We can configure which dimensions in the output shape are the offset 1641dimensions (dimensions containing `O`<sub>`0`</sub>, `O`<sub>`1`</sub> in 1642the last example). The output batch dimensions (dimensions containing 1643`G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be 1644the output dimensions that are not offset dimensions. 1645 16462. The number of output offset dimensions explicitly present in the output 1647shape may be smaller than the input rank. These "missing" dimensions, which 1648are listed explicitly as `collapsed_slice_dims`, must have a slice size of 1649`1`. Since they have a slice size of `1` the only valid index for them is 1650`0` and eliding them does not introduce ambiguity. 1651 16523. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last 1653example) may have fewer elements than the input array rank, and an explicit 1654mapping dictates how the index should be expanded to have the same rank as 1655the input. 1656 1657As a final example, we use (2) and (3) to implement `tf.gather_nd`: 1658 1659<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1660<img style="width:100%" src="./images/ops_xla_gather_2.svg"> 1661</div> 1662 1663`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index 1664from the gather indices array as usual, except the starting index has only one 1665element, `X`. Similarly, there is only one output offset index with the value 1666`O`<sub>`0`</sub>. However, before being used as indices into the input array, 1667these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in 1668the formal description) and "Offset Mapping" (`remapped_offset_dims` in the 1669formal description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively, 1670adding up to [`X`,`O`<sub>`0`</sub>]. 
In other words, the output index 1671[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index 1672[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us 1673the semantics for `tf.gather_nd`. 1674 1675`slice_sizes` for this case is `[1,11]`. Intuitively this means that every 1676index `X` in the gather indices array picks an entire row and the result is the 1677concatenation of all these rows. 1678 1679## GetDimensionSize 1680 1681See also 1682[`XlaBuilder::GetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1683 1684Returns the size of the given dimension of the operand. The operand must be 1685array shaped. 1686 1687<b> `GetDimensionSize(operand, dimension)` </b> 1688 1689| Arguments | Type | Semantics | 1690| ----------- | ------- | --------------------------------------------------- | 1691| `operand` | `XlaOp` | n dimensional input array | 1692| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the | 1693: : : dimension : 1694 1695## SetDimensionSize 1696 1697See also 1698[`XlaBuilder::SetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1699 1700Sets the dynamic size of XlaOp's given dimension. The operand must be 1701array shaped. 1702 1703<b> `SetDimensionSize(operand, size, dimension)` </b> 1704 1705| Arguments | Type | Semantics | 1706| ----------- | ------- | --------------------------------------------------- | 1707| `operand` | `XlaOp` | n dimensional input array. | 1708| `size` | `XlaOp` | int32 representing the runtime dynamic size. | 1709| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the | 1710: : : dimension. : 1711 1712Pass through the operand as result, with dynamic dimension tracked by the 1713compiler. 1714 1715Padded values will be ignored by downstream reduction ops. 1716 1717``` 1718let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; 1719let five: s32 = 5; 1720let six: s32 = 6; 1721 1722// Setting dynamic dimension size doesn't change the upper bound of the static 1723// shape. 1724let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0); 1725let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0); 1726 1727// sum == 1 + 2 + 3 + 4 + 5 1728let sum:f32[] = reduce_sum(padded_v_five); 1729// product == 1 * 2 * 3 * 4 * 5 1730let product:f32[] = reduce_product(padded_v_five); 1731 1732// Changing padding size will yield different result. 1733// sum == 1 + 2 + 3 + 4 + 5 + 6 1734let sum':f32[] = reduce_sum(padded_v_six); 1735``` 1736 1737## GetTupleElement 1738 1739See also 1740[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1741 1742Indexes into a tuple with a compile-time-constant value. 1743 1744The value must be a compile-time-constant so that shape inference can determine 1745the type of the resulting value. 1746 1747This is analogous to `std::get<int N>(t)` in C++. Conceptually: 1748 1749``` 1750let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; 1751let s: s32 = 5; 1752let t: (f32[10], s32) = tuple(v, s); 1753let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32. 1754``` 1755 1756See also `tf.tuple`. 1757 1758## Infeed 1759 1760See also 1761[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 
1762 1763<b> `Infeed(shape)` </b> 1764 1765| Argument | Type | Semantics | 1766| -------- | ------- | ----------------------------------------------------- | 1767| `shape` | `Shape` | Shape of the data read from the Infeed interface. The | 1768: : : layout field of the shape must be set to match the : 1769: : : layout of the data sent to the device; otherwise its : 1770: : : behavior is undefined. : 1771 1772Reads a single data item from the implicit Infeed streaming interface of the 1773device, interpreting the data as the given shape and its layout, and returns a 1774`XlaOp` of the data. Multiple Infeed operations are allowed in a 1775computation, but there must be a total order among the Infeed operations. For 1776example, two Infeeds in the code below have a total order since there is a 1777dependency between the while loops. 1778 1779``` 1780result1 = while (condition, init = init_value) { 1781 Infeed(shape) 1782} 1783 1784result2 = while (condition, init = result1) { 1785 Infeed(shape) 1786} 1787``` 1788 1789Nested tuple shapes are not supported. For an empty tuple shape, the Infeed 1790operation is effectively a no-op and proceeds without reading any data from the 1791Infeed of the device. 1792 1793> Note: We plan to allow multiple Infeed operations without a total order, in 1794> which case the compiler will provide information about how the Infeed 1795> operations are serialized in the compiled program. 1796 1797## Iota 1798 1799See also 1800[`XlaBuilder::Iota`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1801 1802<b> `Iota(shape, iota_dimension)` </b> 1803 1804Builds a constant literal on device rather than a potentially large host 1805transfer. Creates an array that has specified shape and holds values starting at 1806zero and incrementing by one along the specified dimension. For floating-point 1807types, the produced array is equivalent to `ConvertElementType(Iota(...))` where 1808the `Iota` is of integral type and the conversion is to the floating-point type. 1809 1810Arguments | Type | Semantics 1811---------------- | ------- | -------------------------------------- 1812`shape` | `Shape` | Shape of the array created by `Iota()` 1813`iota_dimension` | `int64` | The dimension to increment along. 1814 1815For example, `Iota(s32[4, 8], 0)` returns 1816 1817``` 1818 [[0, 0, 0, 0, 0, 0, 0, 0 ], 1819 [1, 1, 1, 1, 1, 1, 1, 1 ], 1820 [2, 2, 2, 2, 2, 2, 2, 2 ], 1821 [3, 3, 3, 3, 3, 3, 3, 3 ]] 1822``` 1823 1824`Iota(s32[4, 8], 1)` returns 1825 1826``` 1827 [[0, 1, 2, 3, 4, 5, 6, 7 ], 1828 [0, 1, 2, 3, 4, 5, 6, 7 ], 1829 [0, 1, 2, 3, 4, 5, 6, 7 ], 1830 [0, 1, 2, 3, 4, 5, 6, 7 ]] 1831``` 1832 1833## Map 1834 1835See also 1836[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1837 1838<b> `Map(operands..., computation)` </b> 1839 1840| Arguments | Type | Semantics | 1841| ----------------- | ---------------------- | ------------------------------ | 1842| `operands` | sequence of N `XlaOp`s | N arrays of types T_0..T_{N-1} | 1843| `computation` | `XlaComputation` | computation of type `T_0, T_1, | 1844: : : ..., T_{N + M -1} -> S` with N : 1845: : : parameters of type T and M of : 1846: : : arbitrary type : 1847| `dimensions` | `int64` array | array of map dimensions | 1848 1849Applies a scalar function over the given `operands` arrays, producing an array 1850of the same dimensions where each element is the result of the mapped function 1851applied to the corresponding elements in the input arrays. 
1852 1853The mapped function is an arbitrary computation with the restriction that it has 1854N inputs of scalar type `T` and a single output with type `S`. The output has 1855the same dimensions as the operands except that the element type T is replaced 1856with S. 1857 1858For example: `Map(op1, op2, op3, computation, par1)` maps `elem_out <- 1859computation(elem1, elem2, elem3, par1)` at each (multi-dimensional) index in the 1860input arrays to produce the output array. 1861 1862## OptimizationBarrier 1863 1864Blocks any optimization pass from moving computations across the barrier. 1865 1866Ensures that all inputs are evaluated before any operators that depend on the 1867barrier's outputs. 1868 1869## Pad 1870 1871See also 1872[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1873 1874<b> `Pad(operand, padding_value, padding_config)` </b> 1875 1876| Arguments | Type | Semantics | 1877| ---------------- | --------------- | --------------------------------------- | 1878| `operand` | `XlaOp` | array of type `T` | 1879| `padding_value` | `XlaOp` | scalar of type `T` to fill in the added | 1880: : : padding : 1881| `padding_config` | `PaddingConfig` | padding amount on both edges (low, | 1882: : : high) and between the elements of each : 1883: : : dimension : 1884 1885Expands the given `operand` array by padding around the array as well as between 1886the elements of the array with the given `padding_value`. `padding_config` 1887specifies the amount of edge padding and the interior padding for each 1888dimension. 1889 1890`PaddingConfig` is a repeated field of `PaddingConfigDimension`, which contains 1891three fields for each dimension: `edge_padding_low`, `edge_padding_high`, and 1892`interior_padding`. 1893 1894`edge_padding_low` and `edge_padding_high` specify the amount of padding added 1895at the low-end (next to index 0) and the high-end (next to the highest index) of 1896each dimension respectively. The amount of edge padding can be negative -- the 1897absolute value of negative padding indicates the number of elements to remove 1898from the specified dimension. 1899 1900`interior_padding` specifies the amount of padding added between any two 1901elements in each dimension; it may not be negative. Interior padding occurs 1902logically before edge padding, so in the case of negative edge padding, elements 1903are removed from the interior-padded operand. 1904 1905This operation is a no-op if the edge padding pairs are all (0, 0) and the 1906interior padding values are all 0. The figure below shows examples of different 1907`edge_padding` and `interior_padding` values for a two-dimensional array. 1908 1909<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1910 <img style="width:100%" src="./images/ops_pad.png"> 1911</div> 1912 1913## Recv 1914 1915See also 1916[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1917 1918<b> `Recv(shape, channel_handle)` </b> 1919 1920| Arguments | Type | Semantics | 1921| ---------------- | --------------- | ------------------------------------ | 1922| `shape` | `Shape` | shape of the data to receive | 1923| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair | 1924 1925Receives data of the given shape from a `Send` instruction in another 1926computation that shares the same channel handle. Returns a 1927XlaOp for the received data. 

The client API of the `Recv` operation represents synchronous communication. However, the instruction is internally decomposed into 2 HLO instructions (`Recv` and `RecvDone`) to enable asynchronous data transfers. See also [`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).

<b>`Recv(const Shape& shape, int64 channel_id)`</b>

Allocates resources required to receive data from a `Send` instruction with the same channel_id. Returns a context for the allocated resources, which is used by a following `RecvDone` instruction to wait for the completion of the data transfer. The context is a tuple of {receive buffer (shape), request identifier (U32)} and it can only be used by a `RecvDone` instruction.

<b> `RecvDone(HloInstruction context)` </b>

Given a context created by a `Recv` instruction, waits for the data transfer to complete and returns the received data.

## Reduce

See also
[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Applies a reduction function to one or more arrays in parallel.

<b> `Reduce(operands..., init_values..., computation, dimensions)` </b>

| Arguments     | Type                  | Semantics                                                                                  |
| ------------- | --------------------- | ------------------------------------------------------------------------------------------ |
| `operands`    | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_{N-1}`.                                                      |
| `init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_{N-1}`.                                                     |
| `computation` | `XlaComputation`      | computation of type `T_0, ..., T_{N-1}, T_0, ..., T_{N-1} ->` `Collate(T_0, ..., T_{N-1})`. |
| `dimensions`  | `int64` array         | unordered array of dimensions to reduce.                                                    |

Where:

* N is required to be greater than or equal to 1.
* The computation has to be "roughly" associative (see below).
* All input arrays must have the same dimensions.
* All initial values have to form an identity under `computation`.
* If `N = 1`, `Collate(T)` is `T`.
* If `N > 1`, `Collate(T_0, ..., T_{N-1})` is a tuple of `N` elements of type `T`.

This operation reduces one or more dimensions of each input array into scalars. The rank of each returned array is `rank(operand) - len(dimensions)`. The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type `T_i`, the dimensions of which are described below.

Different backends are allowed to reassociate the reduction computation. This can lead to numerical differences, as some reduction functions like addition are not associative for floats. However, if the range of the data is limited, floating-point addition is close enough to being associative for most practical uses.

### Examples

When reducing across one dimension in a single 1D array with values `[10, 11, 12, 13]`, with reduction function `f` (this is `computation`) then that could be computed as

`f(10, f(11, f(12, f(init_value, 13))))`

but there are also many other possibilities, e.g.

`f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))`

The following is a rough pseudo-code example of how reduction could be implemented, using summation as the reduction computation with an initial value of 0.

2004 2005```python 2006result_shape <- remove all dims in dimensions from operand_shape 2007 2008# Iterate over all elements in result_shape. The number of r's here is equal 2009# to the rank of the result 2010for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...: 2011 # Initialize this result element 2012 result[r0, r1...] <- 0 2013 2014 # Iterate over all the reduction dimensions 2015 for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...: 2016 # Increment the result element with the value of the operand's element. 2017 # The index of the operand's element is constructed from all ri's and di's 2018 # in the right order (by construction ri's and di's together index over the 2019 # whole operand shape). 2020 result[r0, r1...] += operand[ri... di] 2021``` 2022 2023Here's an example of reducing a 2D array (matrix). The shape has rank 2, 2024dimension 0 of size 2 and dimension 1 of size 3: 2025 2026<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 2027 <img style="width:35%" src="./images/ops_2d_matrix.png"> 2028</div> 2029 2030Results of reducing dimensions 0 or 1 with an "add" function: 2031 2032<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 2033 <img style="width:35%" src="./images/ops_reduce_from_2d_matrix.png"> 2034</div> 2035 2036Note that both reduction results are 1D arrays. The diagram shows one as column 2037and another as row just for visual convenience. 2038 2039For a more complex example, here is a 3D array. Its rank is 3, dimension 0 of 2040size 4, dimension 1 of size 2 and dimension 2 of size 3. For simplicity, the 2041values 1 to 6 are replicated across dimension 0. 2042 2043<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 2044 <img style="width:35%" src="./images/ops_reduce_from_3d_matrix.png"> 2045</div> 2046 2047Similarly to the 2D example, we can reduce just one dimension. If we reduce 2048dimension 0, for example, we get a rank-2 array where all values across 2049dimension 0 were folded into a scalar: 2050 2051```text 2052| 4 8 12 | 2053| 16 20 24 | 2054``` 2055 2056If we reduce dimension 2, we also get a rank-2 array where all values across 2057dimension 2 were folded into a scalar: 2058 2059```text 2060| 6 15 | 2061| 6 15 | 2062| 6 15 | 2063| 6 15 | 2064``` 2065 2066Note that the relative order between the remaining dimensions in the input is 2067preserved in the output, but some dimensions may get assigned new numbers (since 2068the rank changes). 2069 2070We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces 2071the 1D array `[20, 28, 36]`. 2072 2073Reducing the 3D array over all its dimensions produces the scalar `84`. 2074 2075### Variadic Reduce 2076 2077When `N > 1`, reduce function application is slightly more complex, as it is 2078applied simultaneously to all inputs. The operands are supplied to the 2079computation in the following order: 2080 2081* Running reduced value for the first operand 2082* ... 2083* Running reduced value for the N'th operand 2084* Input value for the first operand 2085* ... 
2086* Input value for the N'th operand 2087 2088For example, consider the following reduction function, which can be used to 2089compute the max and the argmax of a 1-D array in parallel: 2090 2091```python 2092f: (Float, Int, Float, Int) -> Float, Int 2093f(max, argmax, value, index): 2094 if value >= max: 2095 return (value, index) 2096 else: 2097 return (max, argmax) 2098``` 2099 2100For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values 2101`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only 2102input dimension is equivalent to the following recursive application: 2103 2104``` 2105f_0 = f(I_V, I_K, V_0, K_0) 2106f_1 = f(f_0.first, f_0.second, V_1, K_1) 2107... 2108f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1)) 2109``` 2110 2111Applying this reduction to an array of values, and an array of sequential 2112indices (i.e. iota), will co-iterate over the arrays, and return a tuple 2113containing the maximal value and the matching index. 2114 2115## ReducePrecision 2116 2117See also 2118[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2119 2120Models the effect of converting floating-point values to a lower-precision 2121format (such as IEEE-FP16) and back to the original format. The number of 2122exponent and mantissa bits in the lower-precision format can be specified 2123arbitrarily, although all bit sizes may not be supported on all hardware 2124implementations. 2125 2126<b> `ReducePrecision(operand, mantissa_bits, exponent_bits)` </b> 2127 2128Arguments | Type | Semantics 2129--------------- | ------- | ------------------------------------------------- 2130`operand` | `XlaOp` | array of floating-point type `T`. 2131`exponent_bits` | `int32` | number of exponent bits in lower-precision format 2132`mantissa_bits` | `int32` | number of mantissa bits in lower-precision format 2133 2134The result is an array of type `T`. The input values are rounded to the nearest 2135value representable with the given number of mantissa bits (using "ties to even" 2136semantics), and any values that exceed the range specified by the number of 2137exponent bits are clamped to positive or negative infinity. `NaN` values are 2138retained, although they may be converted to canonical `NaN` values. 2139 2140The lower-precision format must have at least one exponent bit (in order to 2141distinguish a zero value from an infinity, since both have a zero mantissa), and 2142must have a non-negative number of mantissa bits. The number of exponent or 2143mantissa bits may exceed the corresponding value for type `T`; the corresponding 2144portion of the conversion is then simply a no-op. 2145 2146## ReduceScatter 2147 2148See also 2149[`XlaBuilder::ReduceScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2150 2151ReduceScatter is a collective operation that effectively does an AllReduce and 2152then scatters the result by splitting it into `shard_count` blocks along the 2153`scatter_dimension` and replica `i` in the replica group receives the `ith` 2154shard. 2155 2156<b> `ReduceScatter(operand, computation, scatter_dim, shard_count, 2157replica_group_ids, channel_id)` </b> 2158 2159| Arguments | Type | Semantics | 2160| ------------------- | -------------------- | ----------------------------- | 2161| `operand` | `XlaOp` | Array or a non-empty tuple of | 2162: : : arrays to reduce across : 2163: : : replicas. 
: 2164| `computation` | `XlaComputation` | Reduction computation | 2165| `scatter_dimension` | `int64` | Dimension to scatter. | 2166| `shard_count` | `int64` | Number of blocks to split | 2167: : : `scatter_dimension` : 2168| `replica_groups` | vector of vectors of | Groups between which the | 2169: : `int64` : reductions are performed : 2170| `channel_id` | optional `int64` | Optional channel ID for | 2171: : : cross-module communication : 2172 2173- When `operand` is a tuple of arrays, the reduce-scatter is performed on each 2174 element of the tuple. 2175- `replica_groups` is a list of replica groups between which the reduction is 2176 performed (replica id for the current replica can be retrieved using 2177 [`ReplicaId`](#replicaid)). The order of replicas in each group determines 2178 the order in which the all-reduce result will be scattered. `replica_groups` 2179 must either be empty (in which case all replicas belong to a single group), 2180 or contain the same number of elements as the number of replicas. When there 2181 are more than one replica groups, they all must be of the same size. For 2182 example, `replica_groups = {0, 2}, {1, 3}` performs reduction between the 2183 replicas `0` and `2`, and `1` and `3` and then scatters the result. 2184- `shard_count` is the size of each replica group. We need this in cases where 2185 `replica_groups` are empty. If `replica_groups` is not empty, `shard_count` 2186 must be equal to the size of each replica group. 2187- `channel_id` is used for cross-module communication: only `reduce-scatter` 2188 operations with the same `channel_id` can communicate with each other. 2189 2190The output shape is the input shape with the `scatter_dimension` made 2191`shard_count` times smaller. For example, if there are two replicas and the 2192operand has the value `[1.0, 2.25]` and `[3.0, 5.25]` respectively on the two 2193replicas, then the output value from this op where `scatter_dim` is `0` will be 2194`[4.0]` for the first replica and `[7.5]` for the second replica. 2195 2196## ReduceWindow 2197 2198See also 2199[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2200 2201Applies a reduction function to all elements in each window of a sequence of N 2202multi-dimensional arrays, producing a single or a tuple of N multi-dimensional 2203arrays as output. Each output array has the same number of elements as the 2204number of valid positions of the window. A pooling layer can be expressed as a 2205`ReduceWindow`. Similar to [`Reduce`](#reduce), the applied `computation` is 2206always passed the `init_values` on the left-hand side. 2207 2208<b> `ReduceWindow(operands..., init_values..., computation, window_dimensions, 2209window_strides, padding)` </b> 2210 2211| Arguments | Type | Semantics | 2212| ------------------- | ------------------- | -------------------------------- | 2213| `operands` | `N XlaOps` | A sequence of N | 2214: : : multi-dimensional arrays of : 2215: : : types `T_0,..., T_{N-1}`, each : 2216: : : representing the base area on : 2217: : : which the window is placed. : 2218| `init_values` | `N XlaOps` | The N starting values for the | 2219: : : reduction, one for each of the N : 2220: : : operands. See [Reduce](#reduce) : 2221: : : for details. 
| `computation` | `XlaComputation` | Reduction function of type `T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1})`, to apply to elements in each window of all the input operands. |
| `window_dimensions` | `ArraySlice<int64>` | array of integers for window dimension values |
| `window_strides` | `ArraySlice<int64>` | array of integers for window stride values |
| `base_dilations` | `ArraySlice<int64>` | array of integers for base dilation values |
| `window_dilations` | `ArraySlice<int64>` | array of integers for window dilation values |
| `padding` | `Padding` | padding type for window (Padding\:\:kSame, which pads so as to have the same output shape as input if the stride is 1, or Padding\:\:kValid, which uses no padding and "stops" the window once it no longer fits) |

Where:

* N is required to be greater than or equal to 1.
* All input arrays must have the same dimensions.
* If `N = 1`, `Collate(T)` is `T`.
* If `N > 1`, `Collate(T_0, ..., T_{N-1})` is a tuple of `N` elements of type `(T0,...T{N-1})`.

The code and figure below show an example of using `ReduceWindow`. Input is a matrix of size `[4x6]` and both `window_dimensions` and `window_stride_dimensions` are `[2x3]`.

```
// Create a computation for the reduction (maximum).
XlaComputation max;
{
  XlaBuilder builder(client_, "max");
  auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
  auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
  builder.Max(y, x);
  max = builder.Build().value();
}

// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
    input,
    /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
    max,
    /*window_dimensions=*/{2, 3},
    /*window_stride_dimensions=*/{2, 3},
    Padding::kValid);
```

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:35%" src="./images/ops_reduce_window.png">
</div>

Stride of 1 in a dimension specifies that the position of a window in the dimension is 1 element away from its adjacent window. In order to specify that no windows overlap with each other, window_stride_dimensions should be equal to window_dimensions. The figure below illustrates the use of two different stride values. Padding is applied to each dimension of the input and the calculations are the same as though the input came in with the dimensions it has after padding.

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:75%" src="./images/ops_reduce_window_stride.png">
</div>

For a non-trivial padding example, consider computing reduce-window minimum (initial value is `MAX_FLOAT`) with dimension `3` and stride `2` over the input array `[10000, 1000, 100, 10, 1]`. Padding `kValid` computes minimums over two _valid_ windows: `[10000, 1000, 100]` and `[100, 10, 1]`, resulting in the output `[100, 1]`.
Padding `kSame` first pads the array so that the shape after 2301the reduce-window would be the _same_ as input for stride one by adding initial 2302elements on both sides, getting `[MAX_VALUE, 10000, 1000, 100, 10, 1, 2303MAX_VALUE]`. Running reduce-window over the padded array operates on three 2304windows `[MAX_VALUE, 10000, 1000]`, `[1000, 100, 10]`, `[10, 1, MAX_VALUE]`, and 2305yields `[1000, 10, 1]`. 2306 2307The evaluation order of the reduction function is arbitrary and may be 2308non-deterministic. Therefore, the reduction function should not be overly 2309sensitive to reassociation. See the discussion about associativity in the 2310context of [`Reduce`](#reduce) for more details. 2311 2312## ReplicaId 2313 2314See also 2315[`XlaBuilder::ReplicaId`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2316 2317Returns the unique ID (U32 scalar) of the replica. 2318 2319<b> `ReplicaId()` </b> 2320 2321The unique ID of each replica is an unsigned integer in the interval `[0, N)`, 2322where `N` is the number of replicas. Since all the replicas are running the same 2323program, a `ReplicaId()` call in the program will return a different value on 2324each replica. 2325 2326## Reshape 2327 2328See also 2329[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h) 2330and the [`Collapse`](#collapse) operation. 2331 2332Reshapes the dimensions of an array into a new configuration. 2333 2334<b> `Reshape(operand, new_sizes)` </b> 2335<b> `Reshape(operand, dimensions, new_sizes)` </b> 2336 2337Arguments | Type | Semantics 2338------------ | -------------- | --------------------------------------- 2339`operand` | `XlaOp` | array of type T 2340`dimensions` | `int64` vector | order in which dimensions are collapsed 2341`new_sizes` | `int64` vector | vector of sizes of new dimensions 2342 2343Conceptually, reshape first flattens an array into a one-dimensional vector of 2344data values, and then refines this vector into a new shape. The input arguments 2345are an arbitrary array of type T, a compile-time-constant vector of dimension 2346indices, and a compile-time-constant vector of dimension sizes for the result. 2347The values in the `dimension` vector, if given, must be a permutation of all of 2348T's dimensions; the default if not given is `{0, ..., rank - 1}`. The order of 2349the dimensions in `dimensions` is from slowest-varying dimension (most major) to 2350fastest-varying dimension (most minor) in the loop nest which collapses the 2351input array into a single dimension. The `new_sizes` vector determines the size 2352of the output array. The value at index 0 in `new_sizes` is the size of 2353dimension 0, the value at index 1 is the size of dimension 1, and so on. The 2354product of the `new_size` dimensions must equal the product of the operand's 2355dimension sizes. When refining the collapsed array into the multidimensional 2356array defined by `new_sizes`, the dimensions in `new_sizes` are ordered from 2357slowest varying (most major) and to fastest varying (most minor). 

For example, let v be an array of 24 elements:

```
let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
                    {{20, 21, 22}, {25, 26, 27}},
                    {{30, 31, 32}, {35, 36, 37}},
                    {{40, 41, 42}, {45, 46, 47}}};

In-order collapse:
let v012_24 = Reshape(v, {0,1,2}, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
                         30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};

let v012_83 = Reshape(v, {0,1,2}, {8,3});
then v012_83 == f32[8x3] {{10, 11, 12}, {15, 16, 17},
                          {20, 21, 22}, {25, 26, 27},
                          {30, 31, 32}, {35, 36, 37},
                          {40, 41, 42}, {45, 46, 47}};

Out-of-order collapse:
let v021_24 = Reshape(v, {1,2,0}, {24});
then v021_24 == f32[24] {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
                         15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};

let v021_83 = Reshape(v, {1,2,0}, {8,3});
then v021_83 == f32[8x3] {{10, 20, 30}, {40, 11, 21},
                          {31, 41, 12}, {22, 32, 42},
                          {15, 25, 35}, {45, 16, 26},
                          {36, 46, 17}, {27, 37, 47}};


let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
then v021_262 == f32[2x6x2] {{{10, 20}, {30, 40},
                              {11, 21}, {31, 41},
                              {12, 22}, {32, 42}},
                             {{15, 25}, {35, 45},
                              {16, 26}, {36, 46},
                              {17, 27}, {37, 47}}};
```

As a special case, reshape can transform a single-element array to a scalar and vice versa. For example,

```
Reshape(f32[1x1] {{5}}, {0,1}, {}) == 5;
Reshape(5, {}, {1,1}) == f32[1x1] {{5}};
```

## Rev (reverse)

See also
[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b>`Rev(operand, dimensions)`</b>

Arguments    | Type                | Semantics
------------ | ------------------- | ---------------------
`operand`    | `XlaOp`             | array of type T
`dimensions` | `ArraySlice<int64>` | dimensions to reverse

Reverses the order of elements in the `operand` array along the specified `dimensions`, generating an output array of the same shape. Each element of the operand array at a multidimensional index is stored into the output array at a transformed index. The multidimensional index is transformed by reversing the index in each dimension to be reversed (i.e., if a dimension of size N is one of the reversing dimensions, its index i is transformed into N - 1 - i).

One use for the `Rev` operation is to reverse the convolution weight array along the two window dimensions during the gradient computation in neural networks.

## RngNormal

See also
[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Constructs an output of a given shape with random numbers generated following the $$N(\mu, \sigma)$$ normal distribution. The parameters $$\mu$$ and $$\sigma$$, and output shape have to have a floating point elemental type. The parameters furthermore have to be scalar valued.

<b>`RngNormal(mu, sigma, shape)`</b>

| Arguments | Type    | Semantics                                                            |
| --------- | ------- | -------------------------------------------------------------------- |
| `mu`      | `XlaOp` | Scalar of type T specifying mean of generated numbers                |
| `sigma`   | `XlaOp` | Scalar of type T specifying standard deviation of generated numbers  |
| `shape`   | `Shape` | Output shape of type T                                               |

## RngUniform

See also
[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Constructs an output of a given shape with random numbers generated following the uniform distribution over the interval $$[a,b)$$. The parameters and output element type have to be a boolean type, an integral type, or a floating point type, and the types have to be consistent. The CPU and GPU backends currently only support F64, F32, F16, BF16, S64, U64, S32 and U32. Furthermore, the parameters need to be scalar valued. If $$b <= a$$ the result is implementation-defined.

<b>`RngUniform(a, b, shape)`</b>

| Arguments | Type    | Semantics                                           |
| --------- | ------- | --------------------------------------------------- |
| `a`       | `XlaOp` | Scalar of type T specifying lower limit of interval |
| `b`       | `XlaOp` | Scalar of type T specifying upper limit of interval |
| `shape`   | `Shape` | Output shape of type T                              |

## RngBitGenerator

Generates an output with a given shape filled with uniform random bits using the specified algorithm (or backend default) and returns an updated state (with the same shape as the initial state) and the generated random data.

The initial state is the starting state of the current random number generation. Its required shape and valid values depend on the algorithm used.

The output is guaranteed to be a deterministic function of the initial state but it is *not* guaranteed to be deterministic between backends and different compiler versions.

<b>`RngBitGenerator(algorithm, initial_state, shape)`</b>

Arguments       | Type              | Semantics
--------------- | ----------------- | -------------------------------------
`algorithm`     | `RandomAlgorithm` | PRNG algorithm to be used.
`initial_state` | `XlaOp`           | Initial state for the PRNG algorithm.
`shape`         | `Shape`           | Output shape for generated data.

Available values for `algorithm`:

- `rng_default`: Backend specific algorithm with backend specific shape requirements.

- `rng_three_fry`: ThreeFry counter-based PRNG algorithm. The `initial_state` shape is `u64[2]` with arbitrary values. [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)

- `rng_philox`: Philox algorithm to generate random numbers in parallel. The `initial_state` shape is `u64[3]` with arbitrary values. [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)

## Scatter

The XLA scatter operation generates a sequence of results which are the values of the input array `operands`, with several slices (at indices specified by `scatter_indices`) updated with the sequence of values in `updates` using `update_computation`.
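
Before the formal definition, the following is a minimal sketch of that behavior using JAX's `lax.scatter_add`, i.e. a scatter whose `update_computation` is addition. Its `ScatterDimensionNumbers` fields mirror the arguments in the table below; the shapes and values are made up for illustration:

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros(4, dtype=jnp.float32)
scatter_indices = jnp.array([[1], [3]], dtype=jnp.int32)  # two 1-element index vectors
updates = jnp.array([10.0, 20.0], dtype=jnp.float32)      # one scalar update per index

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),              # each update is a scalar window
    inserted_window_dims=(0,),          # operand dim 0 has no window dim in `updates`
    scatter_dims_to_operand_dims=(0,),  # index vector component 0 addresses operand dim 0
)
print(lax.scatter_add(operand, scatter_indices, updates, dnums))
# [ 0. 10.  0. 20.]
```

Index vector `[1]` routes the update `10.0` into `operand[1]`, and `[3]` routes `20.0` into `operand[3]`.
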
2512 2513See also 2514[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2515 2516<b> `scatter(operands..., scatter_indices, updates..., update_computation, 2517index_vector_dim, update_window_dims, inserted_window_dims, 2518scatter_dims_to_operand_dims)` </b> 2519 2520Arguments | Type | Semantics 2521------------------------------ | --------------------- | --------- 2522`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N` to be scattered into. 2523`scatter_indices` | `XlaOp` | Array containing the starting indices of the slices that must be scattered to. 2524`updates` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`. `updates[i]` contains the values that must be used for scattering `operands[i]`. 2525`update_computation` | `XlaComputation` | Computation to be used for combining the existing values in the input array and the updates during scatter. This computation should be of type `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)`. 2526`index_vector_dim` | `int64` | The dimension in `scatter_indices` that contains the starting indices. 2527`update_window_dims` | `ArraySlice<int64>` | The set of dimensions in `updates` shape that are *window dimensions*. 2528`inserted_window_dims` | `ArraySlice<int64>` | The set of *window dimensions* that must be inserted into `updates` shape. 2529`scatter_dims_to_operand_dims` | `ArraySlice<int64>` | A dimensions map from the scatter indices to the operand index space. This array is interpreted as mapping `i` to `scatter_dims_to_operand_dims[i]` . It has to be one-to-one and total. 2530`indices_are_sorted` | `bool` | Whether the indices are guaranteed to be sorted by the caller. 2531 2532Where: 2533 2534* N is required to be greater or equal to 1. 2535* `operands`[`0`], ..., `operands`[`N-1`] must all have the same dimensions. 2536* `updates`[`0`], ..., `updates`[`N-1`] must all have the same dimensions. 2537* If `N = 1`, `Collate(T)` is `T`. 2538* If `N > 1`, `Collate(T_0, ..., T_N)` is a tuple of `N` elements of type `T`. 2539 2540If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider 2541`scatter_indices` to have a trailing `1` dimension. 2542 2543We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of 2544dimensions in `updates` shape that are not in `update_window_dims`, in ascending 2545order. 2546 2547The arguments of scatter should follow these constraints: 2548 2549- Each `updates` array must be of rank `update_window_dims.size + 2550 scatter_indices.rank - 1`. 2551 2552- Bounds of dimension `i` in each `updates` array must conform to the 2553 following: 2554 2555 - If `i` is present in `update_window_dims` (i.e. equal to 2556 `update_window_dims`[`k`] for some `k`), then the bound of dimension `i` 2557 in `updates` must not exceed the corresponding bound of `operand` after 2558 accounting for the `inserted_window_dims` (i.e. 2559 `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains 2560 the bounds of `operand` with the bounds at indices 2561 `inserted_window_dims` removed). 2562 - If `i` is present in `update_scatter_dims` (i.e. equal to 2563 `update_scatter_dims`[`k`] for some `k`), then the bound of dimension 2564 `i` in `updates` must be equal to the corresponding bound of 2565 `scatter_indices`, skipping `index_vector_dim` (i.e. 2566 `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and 2567 `scatter_indices.shape.dims`[`k+1`] otherwise). 
2568 2569- `update_window_dims` must be in ascending order, not have any repeating 2570 dimension numbers, and be in the range `[0, updates.rank)`. 2571 2572- `inserted_window_dims` must be in ascending order, not have any repeating 2573 dimension numbers, and be in the range `[0, operand.rank)`. 2574 2575- `operand.rank` must equal the sum of `update_window_dims.size` and 2576 `inserted_window_dims.size`. 2577 2578- `scatter_dims_to_operand_dims.size` must be equal to 2579 `scatter_indices.shape.dims`[`index_vector_dim`], and its values must be in 2580 the range `[0, operand.rank)`. 2581 2582For a given index `U` in each `updates` array, the corresponding index `I` in 2583the corresponding `operands` array into which this update has to be applied is 2584computed as follows: 2585 25861. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up 2587 an index vector `S` in the `scatter_indices` array such that `S`[`i`] = 2588 `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at 2589 positions `index_vector_dim` into A. 25902. Create an index `S`<sub>`in`</sub> into `operand` using `S` by scattering 2591 `S` using the `scatter_dims_to_operand_dims` map. More formally: 2592 1. `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if 2593 `k` < `scatter_dims_to_operand_dims.size`. 2594 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise. 25953. Create an index `W`<sub>`in`</sub> into each `operands` array by scattering 2596 the indices at `update_window_dims` in `U` according to 2597 `inserted_window_dims`. More formally: 2598 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if `k` 2599 is in `update_window_dims`, where `window_dims_to_operand_dims` is the 2600 monotonic function with domain [`0`, `update_window_dims.size`) and 2601 range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For example, if 2602 `update_window_dims.size` is `4`, `operand.rank` is `6`, and 2603 `inserted_window_dims` is {`0`, `2`} then `window_dims_to_operand_dims` 2604 is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}). 2605 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise. 26064. `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise 2607 addition. 2608 2609In summary, the scatter operation can be defined as follows. 2610 2611- Initialize `output` with `operands`, i.e. for all indices `J`, for all 2612 indices `O` in the `operands`[`J`] array: \ 2613 `output`[`J`][`O`] = `operands`[`J`][`O`] 2614- For every index `U` in the `updates`[`J`] array and the corresponding index 2615 `O` in the `operand`[`J`] array, if `O` is a valid index for `output`: \ 2616 `(output`[`0`][`O`], ..., output`[`N-1`][`O`]) 2617 =`update_computation`(`output`[`0`][`O`], ..., 2618 ,`output`[`N-1`][`O`],`updates`[`0`][`U`], ...,`updates`[`N-1`][`U`]) 2619 2620The order in which updates are applied is non-deterministic. So, when multiple 2621indices in `updates` refer to the same index in `operands`, the corresponding 2622value in `output` will be non-deterministic. 2623 2624Note that the first parameter that is passed into the `update_computation` will 2625always be the current value from the `output` array and the second parameter 2626will always be the value from the `updates` array. This is important 2627specifically for cases when the `update_computation` is _not commutative_. 2628 2629If `indices_are_sorted` is set to true then XLA can assume that `start_indices` 2630are sorted (in ascending `start_index_map` order) by the user. 

## SelectAndScatter

See also
[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

This operation can be considered a composite operation that first computes
`ReduceWindow` on the `operand` array to select an element from each window, and
then scatters the `source` array to the indices of the selected elements to
construct an output array with the same shape as the `operand` array. The binary
`select` function is used to select an element from each window by applying it
across the window; it is called with the property that the first parameter's
index vector is lexicographically less than the second parameter's index
vector. The `select` function returns `true` if the first parameter is selected
and `false` if the second parameter is selected, and it must satisfy
transitivity (i.e., if `select(a, b)` and `select(b, c)` are `true`, then
`select(a, c)` is also `true`) so that the selected element does not depend on
the order in which the elements of a given window are traversed.

The function `scatter` is applied at each selected index in the output array. It
takes two scalar parameters:

1. The current value at the selected index in the output array
2. The scatter value from `source` that applies to the selected index

It combines the two parameters and returns a scalar value that is used to update
the value at the selected index in the output array. Initially, all indices of
the output array are set to `init_value`.

The output array has the same shape as the `operand` array, and the `source`
array must have the same shape as the result of applying a `ReduceWindow`
operation on the `operand` array. `SelectAndScatter` can be used to
backpropagate the gradient values for a pooling layer in a neural network.

<b>`SelectAndScatter(operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter)`</b>

| Arguments           | Type                | Semantics                        |
| ------------------- | ------------------- | -------------------------------- |
| `operand` | `XlaOp` | array of type T over which the windows slide |
| `select` | `XlaComputation` | binary computation of type `T, T -> PRED`, to apply to all elements in each window; returns `true` if the first parameter is selected and returns `false` if the second parameter is selected |
| `window_dimensions` | `ArraySlice<int64>` | array of integers for window dimension values |
| `window_strides` | `ArraySlice<int64>` | array of integers for window stride values |
| `padding` | `Padding` | padding type for window (`Padding::kSame` or `Padding::kValid`) |
| `source` | `XlaOp` | array of type T with the values to scatter |
| `init_value` | `XlaOp` | scalar value of type T for the initial value of the output array |
| `scatter` | `XlaComputation` | binary computation of type `T, T -> T`, to apply to each scatter source element and its destination element |

The figure below shows examples of using `SelectAndScatter`, with the `select`
function computing the maximal value among its parameters. Note that when the
windows overlap, as in figure (2) below, an index of the `operand` array may
be selected multiple times by different windows. In the figure, the element of
value 9 is selected by both of the top windows (blue and red), and the binary
addition `scatter` function produces the output element of value 8 (2 + 6).

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
    src="./images/ops_scatter_to_selected_window_element.png">
</div>

The evaluation order of the `scatter` function is arbitrary and may be
non-deterministic. Therefore, the `scatter` function should not be overly
sensitive to reassociation. See the discussion about associativity in the
context of [`Reduce`](#reduce) for more details.
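
The following is a minimal NumPy sketch of these semantics for a rank-2
operand with valid padding, using maximum as the `select` function and
addition as the `scatter` function (the pooling-gradient case discussed
above). The function and parameter names are illustrative only, not the
builder API.

```
import numpy as np


def select_and_scatter_max_add(operand, window_dimensions, window_strides,
                               source, init_value):
    # `source` must have the shape of ReduceWindow(operand) with the same
    # window/stride configuration and valid padding.
    output = np.full(operand.shape, init_value, dtype=operand.dtype)
    win_r, win_c = window_dimensions
    str_r, str_c = window_strides
    for i in range(source.shape[0]):          # one source element per window
        for j in range(source.shape[1]):
            r0, c0 = i * str_r, j * str_c
            window = operand[r0:r0 + win_r, c0:c0 + win_c]
            # `select`: pick the index of the maximal element in the window.
            r, c = np.unravel_index(np.argmax(window), window.shape)
            # `scatter`: combine the source value into the selected position.
            output[r0 + r, c0 + c] += source[i, j]
    return output
```

With overlapping windows, an operand index selected by several windows
accumulates several `source` values, matching the figure above.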

## Send

See also
[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `Send(operand, channel_handle)` </b>

Arguments | Type | Semantics
---------------- | --------------- | -----------------------------------------
`operand` | `XlaOp` | data to send (array of type T)
`channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair

Sends the given operand data to a `Recv` instruction in another computation
that shares the same channel handle. Does not return any data.

Similar to the `Recv` operation, the client API of the `Send` operation
represents synchronous communication, and it is internally decomposed into two
HLO instructions (`Send` and `SendDone`) to enable asynchronous data transfers.
See also
[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).

<b>`Send(HloInstruction operand, int64 channel_id)`</b>

Initiates an asynchronous transfer of the operand to the resources allocated by
the `Recv` instruction with the same channel id. Returns a context, which is
used by a following `SendDone` instruction to wait for the completion of the
data transfer. The context is a tuple of {operand (shape), request identifier
(U32)} and it can only be used by a `SendDone` instruction.

<b> `SendDone(HloInstruction context)` </b>

Given a context created by a `Send` instruction, waits for the data transfer to
complete. The instruction does not return any data.

<b> Scheduling of channel instructions </b>

The execution order of the four instructions for each channel (`Recv`,
`RecvDone`, `Send`, `SendDone`) is as follows.

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:70%" src="./images/send_recv_order.png">
</div>

* `Recv` happens before `Send`
* `Send` happens before `RecvDone`
* `Recv` happens before `RecvDone`
* `Send` happens before `SendDone`

When the backend compilers generate a linear schedule for each computation that
communicates via channel instructions, there must not be cycles across the
computations. For example, the schedules below lead to deadlocks.

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="./images/send_recv_schedule.png">
</div>

## Slice

See also
[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Slicing extracts a sub-array from the input array. The sub-array has the same
rank as the input and contains the values inside a bounding box within the
input array, where the dimensions and indices of the bounding box are given as
arguments to the slice operation.

<b> `Slice(operand, start_indices, limit_indices, strides)` </b>

| Arguments       | Type                | Semantics                            |
| --------------- | ------------------- | ------------------------------------ |
| `operand` | `XlaOp` | N-dimensional array of type T |
| `start_indices` | `ArraySlice<int64>` | List of N integers containing the starting indices of the slice for each dimension. Values must be greater than or equal to zero. |
| `limit_indices` | `ArraySlice<int64>` | List of N integers containing the ending indices (exclusive) of the slice for each dimension. Each value must be greater than or equal to the respective `start_indices` value for the dimension and less than or equal to the size of the dimension. |
| `strides` | `ArraySlice<int64>` | List of N integers that decide the input stride of the slice. The slice picks every `strides[d]` element in dimension `d`. |

1-dimensional example:

```
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
  {2.0, 3.0}
```

2-dimensional example:

```
let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }

Slice(b, {2, 1}, {4, 3}) produces:
  { { 7.0,  8.0},
    {10.0, 11.0} }
```
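
For reference, `Slice` corresponds to basic strided indexing in NumPy; a
sketch of the 2-dimensional example above, plus a strided variant:

```
import numpy as np

b = np.arange(12.0).reshape(4, 3)  # the matrix from the 2-dimensional example

print(b[2:4, 1:3])      # Slice(b, {2, 1}, {4, 3}, {1, 1})
# [[ 7.  8.]
#  [10. 11.]]

print(b[0:4:2, 0:3:2])  # Slice(b, {0, 0}, {4, 3}, {2, 2}): every 2nd element per dimension
# [[0. 2.]
#  [6. 8.]]
```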

## Sort

See also
[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b>`Sort(operands, comparator, dimension, is_stable)`</b>

Arguments | Type | Semantics
------------ | ------------------- | --------------------
`operands` | `ArraySlice<XlaOp>` | The operands to sort.
`comparator` | `XlaComputation` | The comparator computation to use.
`dimension` | `int64` | The dimension along which to sort.
`is_stable` | `bool` | Whether stable sorting should be used.

If only one operand is provided:

* If the operand is a rank-1 tensor (an array), the result is a sorted array.
  If you want to sort the array into ascending order, the comparator should
  perform a less-than comparison. Formally, after the array is sorted, it holds
  for all index positions `i, j` with `i < j` that either
  `comparator(value[i], value[j]) = comparator(value[j], value[i]) = false` or
  `comparator(value[i], value[j]) = true`.

* If the operand has higher rank, the operand is sorted along the provided
  dimension. For example, for a rank-2 tensor (a matrix), a dimension value of
  `0` will independently sort every column, and a dimension value of `1` will
  independently sort each row. If no dimension number is provided, then the
  last dimension is chosen by default. For the dimension which is sorted, the
  same sorting order applies as in the rank-1 case.

If `n > 1` operands are provided:

* All `n` operands must be tensors with the same dimensions. The element types
  of the tensors may be different.

* All operands are sorted together, not individually. Conceptually the operands
  are treated as a tuple. When checking whether the elements of each operand at
  index positions `i` and `j` need to be swapped, the comparator is called with
  `2 * n` scalar parameters, where parameter `2 * k` corresponds to the value at
  position `i` from the `k-th` operand, and parameter `2 * k + 1` corresponds to
  the value at position `j` from the `k-th` operand. Usually, the comparator
  would thus compare parameters `2 * k` and `2 * k + 1` with each other and
  possibly use other parameter pairs as tie breakers.

* The result is a tuple that consists of the operands in sorted order (along
  the provided dimension, as above). The `i-th` operand of the tuple
  corresponds to the `i-th` operand of `Sort`.

For example, if there are three operands `operand0 = [3, 1]`,
`operand1 = [42, 50]`, `operand2 = [-3.0, 1.1]`, and the comparator compares
only the values of `operand0` with less-than, then the output of the sort is the
tuple `([1, 3], [50, 42], [1.1, -3.0])`.

If `is_stable` is set to true, the sort is guaranteed to be stable, that is, if
there are elements which are considered to be equal by the comparator, the
relative order of the equal values is preserved. By default, `is_stable` is set
to false.
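
The three-operand example above can be reproduced in plain Python by treating
the operands as a tuple of rows and sorting by the `operand0` values (a sketch
of the semantics, not of the XLA comparator mechanism):

```
operand0 = [3, 1]
operand1 = [42, 50]
operand2 = [-3.0, 1.1]

# Sorting the zipped rows by the operand0 value plays the role of a
# less-than comparator that only inspects operand0.
rows = sorted(zip(operand0, operand1, operand2), key=lambda row: row[0])
print([list(col) for col in zip(*rows)])  # [[1, 3], [50, 42], [1.1, -3.0]]
```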

## Transpose

See also the `tf.reshape` operation.

<b>`Transpose(operand, permutation)`</b>

Arguments | Type | Semantics
------------- | ------------------- | ------------------------------
`operand` | `XlaOp` | The operand to transpose.
`permutation` | `ArraySlice<int64>` | How to permute the dimensions.

Permutes the operand dimensions with the given permutation, so
`∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]`.

This is the same as `Reshape(operand, permutation,
Permute(permutation, operand.shape.dimensions))`.
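
This matches `numpy.transpose` with an explicit axis permutation, where output
dimension `i` takes its size from input dimension `permutation[i]` (a small
sketch):

```
import numpy as np

x = np.zeros((2, 3, 5))
permutation = (1, 2, 0)

y = np.transpose(x, axes=permutation)
print(y.shape)  # (3, 5, 2), i.e. tuple(x.shape[p] for p in permutation)
```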

## TriangularSolve

See also
[`XlaBuilder::TriangularSolve`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Solves systems of linear equations with lower or upper triangular coefficient
matrices by forward- or back-substitution. Broadcasting along leading
dimensions, this routine solves one of the matrix systems `op(a) * x = b`, or
`x * op(a) = b`, for the variable `x`, given `a` and `b`, where `op(a)` is
either `op(a) = a`, or `op(a) = Transpose(a)`, or `op(a) = Conj(Transpose(a))`.

<b> `TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)` </b>

| Arguments       | Type        | Semantics                                    |
| --------------- | ----------- | -------------------------------------------- |
| `a` | `XlaOp` | an array of rank ≥ 2 of a complex or floating-point type with shape `[..., M, M]`. |
| `b` | `XlaOp` | an array of rank ≥ 2 of the same type with shape `[..., M, K]` if `left_side` is true, `[..., K, M]` otherwise. |
| `left_side` | `bool` | indicates whether to solve a system of the form `op(a) * x = b` (`true`) or `x * op(a) = b` (`false`). |
| `lower` | `bool` | whether to use the upper or lower triangle of `a`. |
| `unit_diagonal` | `bool` | if `true`, the diagonal elements of `a` are assumed to be `1` and not accessed. |
| `transpose_a` | `Transpose` | whether to use `a` as is, transpose it, or take its conjugate transpose. |

Input data is read only from the lower/upper triangle of `a`, depending on the
value of `lower`. Values from the other triangle are ignored. Output data is
returned in the same triangle; the values in the other triangle are
implementation-defined and may be anything.

If the ranks of `a` and `b` are greater than 2, they are treated as batches of
matrices, where all except the minor 2 dimensions are batch dimensions. `a` and
`b` must have equal batch dimensions.
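
As an illustration of the forward-substitution case (`left_side` true, `lower`
true, no transpose), the following NumPy sketch solves `a @ x = b` row by row;
the names and the batching convention are assumptions of this example, not the
builder API:

```
import numpy as np


def lower_triangular_solve(a, b, unit_diagonal=False):
    # a: [..., M, M] lower triangular, b: [..., M, K]; returns x with a @ x = b.
    m = a.shape[-1]
    x = np.zeros_like(b)
    for i in range(m):
        residual = b[..., i, :] - np.einsum('...k,...kj->...j',
                                            a[..., i, :i], x[..., :i, :])
        x[..., i, :] = residual if unit_diagonal else residual / a[..., i, i:i + 1]
    return x


# Check against a general dense solve on a well-conditioned triangular matrix.
a = np.tril(np.random.rand(4, 4)) + np.eye(4)
b = np.random.rand(4, 2)
np.testing.assert_allclose(lower_triangular_solve(a, b), np.linalg.solve(a, b),
                           rtol=1e-8)
```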

## Tuple

See also
[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

A tuple containing a variable number of data handles, each of which has its own
shape.

This is analogous to `std::tuple` in C++. Conceptually:

```
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
```

Tuples can be deconstructed (accessed) via the
[`GetTupleElement`](#gettupleelement) operation.

## While

See also
[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `While(condition, body, init)` </b>

| Arguments   | Type             | Semantics                                  |
| ----------- | ---------------- | ------------------------------------------ |
| `condition` | `XlaComputation` | XlaComputation of type `T -> PRED` which defines the termination condition of the loop. |
| `body` | `XlaComputation` | XlaComputation of type `T -> T` which defines the body of the loop. |
| `init` | `T` | Initial value for the parameter of `condition` and `body`. |

Sequentially executes the `body` until the `condition` fails. This is similar to
a typical while loop in many other languages, except for the differences and
restrictions listed below.

* A `While` node returns a value of type `T`, which is the result from the
  last execution of the `body`.
* The shape of the type `T` is statically determined and must be the same
  across all iterations.

The `T` parameters of the computations are initialized with the `init` value in
the first iteration and are automatically updated to the new result from `body`
in each subsequent iteration.

One main use case of the `While` node is to implement the repeated execution of
training in neural networks. Simplified pseudocode is shown below with a graph
that represents the computation. The code can be found in
[`while_test.cc`](https://www.tensorflow.org/code/tensorflow/compiler/xla/tests/while_test.cc).
The type `T` in this example is a `Tuple` consisting of an `int32` for the
iteration count and a `vector[10]` for the accumulator. For 1000 iterations, the
loop keeps adding a constant vector to the accumulator.

```
// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
  iteration = result(0) + 1;
  new_vector = result(1) + constant_vector[10];
  result = {iteration, new_vector};
}
```

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="./images/ops_while.png">
</div>
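
The same computation in plain Python, with the loop state `T` as a
`(counter, accumulator)` tuple threaded through `condition` and `body`; the
vector of ones used as `constant_vector` is an assumption of this sketch:

```
import numpy as np

constant_vector = np.ones(10, dtype=np.float32)


def condition(state):
    return state[0] < 1000


def body(state):
    iteration, accumulator = state
    return iteration + 1, accumulator + constant_vector


state = (0, np.zeros(10, dtype=np.float32))  # init
while condition(state):
    state = body(state)

print(state[0])       # 1000
print(state[1][:3])   # [1000. 1000. 1000.]
```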