# Operation Semantics

The following describes the semantics of operations defined in the
[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
interface. Typically, these operations map one-to-one to operations defined in
the RPC interface in
[`xla_data.proto`](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto).

A note on nomenclature: the generalized data type XLA deals with is an
N-dimensional array holding elements of some uniform type (such as 32-bit
float). Throughout the documentation, *array* is used to denote an
arbitrary-dimensional array. For convenience, special cases have more specific
and familiar names; for example a *vector* is a 1-dimensional array and a
*matrix* is a 2-dimensional array.

## AfterAll

See also
[`XlaBuilder::AfterAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

AfterAll takes a variadic number of tokens and produces a single token. Tokens
are primitive types which can be threaded between side-effecting operations to
enforce ordering. `AfterAll` can be used as a join of tokens for ordering an
operation after a set of operations.

<b> `AfterAll(operands)` </b>

Arguments  | Type    | Semantics
---------- | ------- | -------------------------
`operands` | `XlaOp` | variadic number of tokens

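A minimal sketch of joining tokens, assuming the token-based `CreateToken` /
`InfeedWithToken` API (the shapes are illustrative):

```
XlaBuilder b("afterall");
// Two independent side-effecting ops, each threaded through its own token.
XlaOp token0 = CreateToken(&b);
XlaOp token1 = CreateToken(&b);
// InfeedWithToken returns a (data, token) tuple.
XlaOp infeed0 = InfeedWithToken(token0, ShapeUtil::MakeShape(F32, {2}));
XlaOp infeed1 = InfeedWithToken(token1, ShapeUtil::MakeShape(F32, {2}));
// Join the result tokens so that a later side-effecting op can be ordered
// after both infeeds.
XlaOp joined = AfterAll(&b, {GetTupleElement(infeed0, 1),
                             GetTupleElement(infeed1, 1)});
```
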
## AllGather

See also
[`XlaBuilder::AllGather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Performs concatenation across replicas.

<b> `AllGather(operand, all_gather_dim, shard_count, replica_groups,
channel_id)` </b>

| Arguments        | Type                 | Semantics                   |
| ---------------- | -------------------- | --------------------------- |
| `operand`        | `XlaOp`              | Array to concatenate across |
:                  :                      : replicas.                   :
| `all_gather_dim` | `int64`              | Concatenation dimension.    |
| `shard_count`    | `int64`              | Size of each replica group. |
| `replica_groups` | vector of vectors of | Groups between which the    |
:                  : `int64`              : concatenation is performed. :
| `channel_id`     | optional `int64`     | Optional channel ID for     |
:                  :                      : cross-module communication. :

-   `replica_groups` is a list of replica groups between which the concatenation
    is performed (replica id for the current replica can be retrieved using
    [`ReplicaId`](#replicaid)). The order of replicas in each group determines
    the order in which their inputs are located in the result. `replica_groups`
    must either be empty (in which case all replicas belong to a single group,
    ordered from `0` to `N - 1`), or contain the same number of elements as the
    number of replicas. For example, `replica_groups = {{0, 2}, {1, 3}}`
    performs concatenation between the replicas `0` and `2`, and `1` and `3`.
-   `shard_count` is the size of each replica group. We need this in cases where
    `replica_groups` is empty.
-   `channel_id` is used for cross-module communication: only `all-gather`
    operations with the same `channel_id` can communicate with each other.

The output shape is the input shape with the `all_gather_dim` made `shard_count`
times larger. For example, if there are two replicas and the operand has the
value `[1.0, 2.5]` and `[3.0, 5.25]` respectively on the two replicas, then the
output value from this op where `all_gather_dim` is `0` will be `[1.0, 2.5, 3.0,
5.25]` on both replicas.

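A small sketch of the two-replica example above (the parameter shape is
illustrative):

```
XlaBuilder b("all_gather");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");
// With two replicas, the result has shape f32[4]: replica 0's operand
// followed by replica 1's operand along dimension 0.
AllGather(x, /*all_gather_dimension=*/0, /*shard_count=*/2);
```
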
## AllReduce

See also
[`XlaBuilder::AllReduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Performs a custom computation across replicas.

<b> `AllReduce(operand, computation, replica_groups, channel_id)` </b>

| Arguments        | Type                 | Semantics                         |
| ---------------- | -------------------- | --------------------------------- |
| `operand`        | `XlaOp`              | Array or a non-empty tuple of     |
:                  :                      : arrays to reduce across replicas. :
| `computation`    | `XlaComputation`     | Reduction computation             |
| `replica_groups` | vector of vectors of | Groups between which the          |
:                  : `int64`              : reductions are performed          :
| `channel_id`     | optional `int64`     | Optional channel ID for           |
:                  :                      : cross-module communication        :

-   When `operand` is a tuple of arrays, the all-reduce is performed on each
    element of the tuple.
-   `replica_groups` is a list of replica groups between which the reduction is
    performed (replica id for the current replica can be retrieved using
    [`ReplicaId`](#replicaid)). `replica_groups` must either be empty (in which
    case all replicas belong to a single group), or contain the same number of
    elements as the number of replicas. For example, `replica_groups = {{0, 2},
    {1, 3}}` performs reduction between the replicas `0` and `2`, and `1` and
    `3`.
-   `channel_id` is used for cross-module communication: only `all-reduce`
    operations with the same `channel_id` can communicate with each other.

The output shape is the same as the input shape. For example, if there are two
replicas and the operand has the value `[1.0, 2.5]` and `[3.0, 5.25]`
respectively on the two replicas, then the output value from this op and
summation computation will be `[4.0, 7.75]` on both replicas. If the input is a
tuple, the output is a tuple as well.

Computing the result of `AllReduce` requires having one input from each replica,
so if one replica executes an `AllReduce` node more times than another, then the
former replica will wait forever. Since the replicas are all running the same
program, there are not a lot of ways for that to happen, but it is possible when
a while loop's condition depends on data from infeed and the data that is infed
causes the while loop to iterate more times on one replica than another.

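A minimal sketch of the summation example above, building the reduction
computation as a scalar `Add` subcomputation (using `Build().ValueOrDie()` for
brevity):

```
XlaBuilder b("all_reduce_sum");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");

// Reduction computation: scalar f32 addition.
XlaComputation add;
{
  XlaBuilder sub("add");
  auto lhs = Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "lhs");
  auto rhs = Parameter(&sub, 1, ShapeUtil::MakeShape(F32, {}), "rhs");
  Add(lhs, rhs);
  add = sub.Build().ValueOrDie();
}

// Element-wise sum of `x` across all replicas (empty `replica_groups`).
AllReduce(x, add);
```
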
## AllToAll

See also
[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

AllToAll is a collective operation that sends data from all cores to all cores.
It has two phases:

1.  The scatter phase. On each core, the operand is split into `split_count`
    blocks along the `split_dimension`, and the blocks are scattered to all
    cores, e.g., the ith block is sent to the ith core.
2.  The gather phase. Each core concatenates the received blocks along the
    `concat_dimension`.

The participating cores can be configured by:

-   `replica_groups`: each ReplicaGroup contains a list of replica ids
    participating in the computation (replica id for the current replica can be
    retrieved using [`ReplicaId`](#replicaid)). AllToAll will be applied within
    subgroups in the specified order. For example, `replica_groups = {{1,2,3},
    {4,5,0}}` means that an AllToAll will be applied within replicas `{1, 2,
    3}`, and in the gather phase, the received blocks will be concatenated in
    the same order of 1, 2, 3. Then, another AllToAll will be applied within
    replicas 4, 5, 0, and the concatenation order is also 4, 5, 0. If
    `replica_groups` is empty, all replicas belong to one group, in the
    concatenation order of their appearance.

Prerequisites:

-   The dimension size of the operand on the `split_dimension` is divisible by
    `split_count`.
-   The operand's shape is not a tuple.

<b> `AllToAll(operand, split_dimension, concat_dimension, split_count,
replica_groups)` </b>

| Arguments          | Type                  | Semantics                       |
| ------------------ | --------------------- | ------------------------------- |
| `operand`          | `XlaOp`               | n dimensional input array       |
| `split_dimension`  | `int64`               | A value in the interval `[0,    |
:                    :                       : n)` that names the dimension    :
:                    :                       : along which the operand is      :
:                    :                       : split                           :
| `concat_dimension` | `int64`               | a value in the interval `[0,    |
:                    :                       : n)` that names the dimension    :
:                    :                       : along which the split blocks    :
:                    :                       : are concatenated                :
| `split_count`      | `int64`               | the number of cores that        |
:                    :                       : participate in this operation.  :
:                    :                       : If `replica_groups` is empty,   :
:                    :                       : this should be the number of    :
:                    :                       : replicas; otherwise, this       :
:                    :                       : should be equal to the number   :
:                    :                       : of replicas in each group.      :
| `replica_groups`   | `ReplicaGroup` vector | each group contains a list of   |
:                    :                       : replica ids.                    :

Below is an example of AllToAll.

```
XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
```

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="./images/ops_alltoall.png">
</div>

In this example, there are 4 cores participating in the AllToAll. On each core,
the operand is split into 4 blocks along dimension 1, so each block has shape
f32[4,4]. The 4 blocks are scattered to all cores. Then each core concatenates
the received blocks along dimension 0, in the order of cores 0-3. So the output
on each core has shape f32[16,4].

## BatchNormGrad

See also
[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.

Calculates gradients of batch norm.

<b> `BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)` </b>

| Arguments       | Type                    | Semantics                        |
| --------------- | ----------------------- | -------------------------------- |
| `operand`       | `XlaOp`                 | n dimensional array to be        |
:                 :                         : normalized (x)                   :
| `scale`         | `XlaOp`                 | 1 dimensional array              |
:                 :                         : (\\(\gamma\\))                   :
| `mean`          | `XlaOp`                 | 1 dimensional array (\\(\mu\\))  |
| `variance`      | `XlaOp`                 | 1 dimensional array              |
:                 :                         : (\\(\sigma^2\\))                 :
| `grad_output`   | `XlaOp`                 | Gradients passed to              |
:                 :                         : `BatchNormTraining`              :
:                 :                         : (\\( \nabla y\\))                :
| `epsilon`       | `float`                 | Epsilon value (\\(\epsilon\\))   |
| `feature_index` | `int64`                 | Index to feature dimension in    |
:                 :                         : `operand`                        :

For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the gradients with
respect to `operand`, `offset` and `scale` across all the other dimensions. The
`feature_index` must be a valid index for the feature dimension in `operand`.

The three gradients are defined by the following formulas (assuming a
4-dimensional array as `operand` and with feature dimension index `l`, batch
size `m` and spatial sizes `w` and `h`):

\\[ \begin{split} c_l&=
\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h
\left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right)
\\\\
\nabla x_{ijkl} &= \frac{\gamma_{l}}{\sqrt{\sigma^2_{l}+\epsilon}}
\left( \nabla y_{ijkl} - \mathrm{mean}(\nabla y) - c_l (x_{ijkl} - \mu_{l})
\right)
\\\\
\nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl}
\frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon}} \right)
\\\\
\nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl}
\end{split} \\]

The inputs `mean` and `variance` represent moment values across batch and
spatial dimensions.

The output type is a tuple of three handles:

| Outputs        | Type                    | Semantics                         |
| -------------  | ----------------------- | --------------------------------- |
| `grad_operand` | `XlaOp`                 | gradient with respect to input    |
:                :                         : `operand` (\\( \nabla x\\))       :
| `grad_scale`   | `XlaOp`                 | gradient with respect to input    |
:                :                         : `scale` (\\( \nabla \gamma\\))    :
| `grad_offset`  | `XlaOp`                 | gradient with respect to input    |
:                :                         : `offset` (\\( \nabla \beta\\))    :

## BatchNormInference

See also
[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.

Normalizes an array across batch and spatial dimensions.

<b> `BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)` </b>

Arguments       | Type    | Semantics
--------------- | ------- | ---------------------------------------
`operand`       | `XlaOp` | n dimensional array to be normalized
`scale`         | `XlaOp` | 1 dimensional array
`offset`        | `XlaOp` | 1 dimensional array
`mean`          | `XlaOp` | 1 dimensional array
`variance`      | `XlaOp` | 1 dimensional array
`epsilon`       | `float` | Epsilon value
`feature_index` | `int64` | Index to feature dimension in `operand`

For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the mean and variance
across all the other dimensions and uses the mean and variance to normalize each
element in `operand`. The `feature_index` must be a valid index for the feature
dimension in `operand`.

`BatchNormInference` is equivalent to calling `BatchNormTraining` without
computing `mean` and `variance` for each batch. It uses the input `mean` and
`variance` instead as estimated values. The purpose of this op is to reduce
latency in inference, hence the name `BatchNormInference`.

The output is an n-dimensional, normalized array with the same shape as input
`operand`.

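Concretely, each element is normalized with the provided statistics rather than
with batch statistics; using the notation of the `BatchNormTraining` section
below:

\\[ y_{ijkl} = \frac{\gamma_l (x_{ijkl} - \mu_l)}{\sqrt{\sigma^2_l + \epsilon}} + \beta_l \\]

where \\(\mu_l\\) and \\(\sigma^2_l\\) are taken from the `mean` and `variance`
inputs for each feature `l`.
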
## BatchNormTraining

See also
[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.

Normalizes an array across batch and spatial dimensions.

<b> `BatchNormTraining(operand, scale, offset, epsilon, feature_index)` </b>

Arguments       | Type    | Semantics
--------------- | ------- | ----------------------------------------
`operand`       | `XlaOp` | n dimensional array to be normalized (x)
`scale`         | `XlaOp` | 1 dimensional array (\\(\gamma\\))
`offset`        | `XlaOp` | 1 dimensional array (\\(\beta\\))
`epsilon`       | `float` | Epsilon value (\\(\epsilon\\))
`feature_index` | `int64` | Index to feature dimension in `operand`

For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the mean and variance
across all the other dimensions and uses the mean and variance to normalize each
element in `operand`. The `feature_index` must be a valid index for the feature
dimension in `operand`.

The algorithm goes as follows for each batch in `operand` \\(x\\) that
contains `m` elements with `w` and `h` as the size of spatial dimensions
(assuming `operand` is a 4 dimensional array):

- Calculates batch mean \\(\mu_l\\) for each feature `l` in feature dimension:
\\(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\\)

- Calculates batch variance \\(\sigma^2_l\\):
\\(\sigma^2_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h (x_{ijkl} - \mu_l)^2\\)

- Normalizes, scales and shifts:
\\(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt{\sigma^2_l+\epsilon}}+\beta_l\\)

The epsilon value, usually a small number, is added to avoid divide-by-zero errors.

The output type is a tuple of three `XlaOp`s:

| Outputs      | Type                    | Semantics                            |
| ------------ | ----------------------- | -------------------------------------|
| `output`     | `XlaOp`                 | n dimensional array with the same    |
:              :                         : shape as input `operand` (y)         :
| `batch_mean` | `XlaOp`                 | 1 dimensional array (\\(\mu\\))      |
| `batch_var`  | `XlaOp`                 | 1 dimensional array (\\(\sigma^2\\)) |

The `batch_mean` and `batch_var` are moments calculated across the batch and
spatial dimensions using the formulas above.

## BitcastConvertType

See also
[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast
operation from a data shape to a target shape. The total bit sizes of the input
and output must match: e.g. `s32` elements become `f32` elements via a bitcast
routine, and one `s32` element becomes four `s8` elements. Bitcast is
implemented as a low-level cast, so machines with different floating-point
representations will give different results.

<b> `BitcastConvertType(operand, new_element_type)` </b>

Arguments          | Type            | Semantics
------------------ | --------------- | ---------------------------
`operand`          | `XlaOp`         | array of type T with dims D
`new_element_type` | `PrimitiveType` | type U

The dimensions of the operand and the target shape must match, apart from the
last dimension which will change by the ratio of the primitive size before and
after the conversion.

The source and destination element types must not be tuples.

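A minimal sketch (the shape is illustrative). The bits of each `s32` element are
reinterpreted as `f32`, so for example the integer `1` becomes the smallest
positive denormal float rather than `1.0f`:

```
XlaBuilder b("bitcast_convert");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {3}), "x");
// Same element width (32 bits), so the shape stays {3}; only the
// interpretation of the bits changes.
BitcastConvertType(x, F32);
```
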
### Bitcast-converting to primitive type of different width

The `BitcastConvert` HLO instruction supports the case where the size of the
output element type `T'` is not equal to the size of the input element type
`T`. As the whole operation is conceptually a bitcast and does not change the
underlying bytes, the shape of the output element has to change. For `B =
sizeof(T), B' = sizeof(T')`, there are two possible cases.

First, when `B > B'`, the output shape gets a new minor-most dimension of size
`B/B'`. For example:

```
  f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)
```

The rule remains the same for effective scalars:

```
  f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)
```

Alternatively, for `B' > B` the instruction requires the last logical dimension
of the input shape to be equal to `B'/B`, and this dimension is dropped during
the conversion:

```
  f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)
```

Note that conversions between different bitwidths are not elementwise.

## Broadcast

See also
[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Adds dimensions to an array by duplicating the data in the array.

<b> `Broadcast(operand, broadcast_sizes)` </b>

Arguments         | Type                | Semantics
----------------- | ------------------- | -------------------------------
`operand`         | `XlaOp`             | The array to duplicate
`broadcast_sizes` | `ArraySlice<int64>` | The sizes of the new dimensions

The new dimensions are inserted on the left, i.e. if `broadcast_sizes` has
values `{a0, ..., aN}` and the operand shape has dimensions `{b0, ..., bM}` then
the shape of the output has dimensions `{a0, ..., aN, b0, ..., bM}`.

The new dimensions index into copies of the operand, i.e.

```
output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
```

For example, if `operand` is a scalar `f32` with value `2.0f`, and
`broadcast_sizes` is `{2, 3}`, then the result will be an array with shape
`f32[2, 3]` and all the values in the result will be `2.0f`.

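The scalar example above as a small sketch:

```
XlaBuilder b("broadcast");
auto x = ConstantR0<float>(&b, 2.0f);
// Result shape is f32[2,3]; every element is 2.0f.
Broadcast(x, /*broadcast_sizes=*/{2, 3});
```
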
## BroadcastInDim

See also
[`XlaBuilder::BroadcastInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Expands the size and rank of an array by duplicating the data in the array.

<b> `BroadcastInDim(operand, out_dim_size, broadcast_dimensions)` </b>

| Arguments              | Type                | Semantics                     |
| ---------------------- | ------------------- | ----------------------------- |
| `operand`              | `XlaOp`             | The array to duplicate        |
| `out_dim_size`         | `ArraySlice<int64>` | The sizes of the dimensions   |
:                        :                     : of the target shape           :
| `broadcast_dimensions` | `ArraySlice<int64>` | Which dimension in the target |
:                        :                     : shape each dimension of the   :
:                        :                     : operand shape corresponds to  :

Similar to Broadcast, but allows adding dimensions anywhere and expanding
existing dimensions with size 1.

The `operand` is broadcast to the shape described by `out_dim_size`.
`broadcast_dimensions` maps the dimensions of `operand` to the dimensions of the
target shape, i.e. the i'th dimension of the operand is mapped to the
broadcast_dimension\[i\]'th dimension of the output shape. The dimensions of
`operand` must have size 1 or be the same size as the dimension in the output
shape they are mapped to. The remaining dimensions are filled with dimensions of
size 1. Degenerate-dimension broadcasting then broadcasts along these degenerate
dimensions to reach the output shape. The semantics are described in detail on
the [broadcasting page](broadcasting.md).

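A small sketch (shapes illustrative) that maps a vector onto the rows of a
matrix:

```
XlaBuilder b("broadcast_in_dim");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "x");
// The operand's only dimension is mapped to dimension 1 of the output, so the
// result is an f32[2,3] array whose two rows are both copies of `x`.
BroadcastInDim(x, /*out_dim_size=*/{2, 3}, /*broadcast_dimensions=*/{1});
```
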
## Call

See also
[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Invokes a computation with the given arguments.

<b> `Call(computation, args...)` </b>

| Arguments     | Type                   | Semantics                           |
| ------------- | ---------------------- | ----------------------------------- |
| `computation` | `XlaComputation`       | computation of type `T_0, T_1, ..., |
:               :                        : T_{N-1} -> S` with N parameters of  :
:               :                        : arbitrary type                      :
| `args`        | sequence of N `XlaOp`s | N arguments of arbitrary type       |

The arity and types of the `args` must match the parameters of the
`computation`. It is allowed to have no `args`.

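A minimal sketch (using `Build().ValueOrDie()` for brevity) that builds a
one-parameter computation and calls it:

```
XlaBuilder b("call");

// A computation of type f32[] -> f32[] that adds one to its parameter.
XlaComputation add_one;
{
  XlaBuilder sub("add_one");
  auto p = Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "p");
  Add(p, ConstantR0<float>(&sub, 1.0f));
  add_one = sub.Build().ValueOrDie();
}

auto x = ConstantR0<float>(&b, 41.0f);
Call(&b, add_one, {x});  // computes 42.0f
```
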
## Cholesky

See also
[`XlaBuilder::Cholesky`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Computes the
[Cholesky decomposition](https://en.wikipedia.org/wiki/Cholesky_decomposition)
of a batch of symmetric (Hermitian) positive definite matrices.

<b> `Cholesky(a, lower)` </b>

Arguments | Type    | Semantics
--------- | ------- | ------------------------------------------------------
`a`       | `XlaOp` | a rank >= 2 array of a complex or floating-point type.
`lower`   | `bool`  | whether to use the upper or lower triangle of `a`.

If `lower` is `true`, computes lower-triangular matrices `l` such that $$ a = l
. l^T $$. If `lower` is `false`, computes upper-triangular matrices `u` such
that $$ a = u^T . u $$.

Input data is read only from the lower/upper triangle of `a`, depending on the
value of `lower`. Values from the other triangle are ignored. Output data is
returned in the same triangle; the values in the other triangle are
implementation-defined and may be anything.

If the rank of `a` is greater than 2, `a` is treated as a batch of matrices,
where all except the minor 2 dimensions are batch dimensions.

If `a` is not symmetric (Hermitian) positive definite, the result is
implementation-defined.

## Clamp

See also
[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Clamps an operand to within the range between a minimum and maximum value.

<b> `Clamp(min, operand, max)` </b>

Arguments | Type    | Semantics
--------- | ------- | ---------------
`min`     | `XlaOp` | array of type T
`operand` | `XlaOp` | array of type T
`max`     | `XlaOp` | array of type T

Given an operand and minimum and maximum values, returns the operand if it is in
the range between the minimum and maximum, else returns the minimum value if the
operand is below this range or the maximum value if the operand is above this
range. That is, `clamp(a, x, b) = min(max(a, x), b)`.

All three arrays must be the same shape. Alternatively, as a restricted form of
[broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`.

Example with scalar `min` and `max`:

```
let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};
```

## Collapse

See also
[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and the `tf.reshape` operation.

Collapses dimensions of an array into one dimension.

<b> `Collapse(operand, dimensions)` </b>

Arguments    | Type           | Semantics
------------ | -------------- | -----------------------------------------------
`operand`    | `XlaOp`        | array of type T
`dimensions` | `int64` vector | in-order, consecutive subset of T's dimensions.

Collapse replaces the given subset of the operand's dimensions by a single
dimension. The input arguments are an arbitrary array of type T and a
compile-time-constant vector of dimension indices. The dimension indices must be
an in-order (low to high dimension numbers), consecutive subset of T's
dimensions. Thus, {0, 1, 2}, {0, 1}, or {1, 2} are all valid dimension sets, but
{1, 0} or {0, 2} are not. They are replaced by a single new dimension, in the
same position in the dimension sequence as those they replace, with the new
dimension size equal to the product of original dimension sizes. The lowest
dimension number in `dimensions` is the slowest varying dimension (most major)
in the loop nest which collapses these dimensions, and the highest dimension
number is fastest varying (most minor). See the `tf.reshape` operator
if more general collapse ordering is needed.

For example, let v be an array of 24 elements:

```
let v = f32[4x2x3] {{{10, 11, 12},  {15, 16, 17}},
                    {{20, 21, 22},  {25, 26, 27}},
                    {{30, 31, 32},  {35, 36, 37}},
                    {{40, 41, 42},  {45, 46, 47}}};

// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
                      20, 21, 22, 25, 26, 27,
                      30, 31, 32, 35, 36, 37,
                      40, 41, 42, 45, 46, 47};

// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[8x3] {{10, 11, 12},
                      {15, 16, 17},
                      {20, 21, 22},
                      {25, 26, 27},
                      {30, 31, 32},
                      {35, 36, 37},
                      {40, 41, 42},
                      {45, 46, 47}};

// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[4x6] {{10, 11, 12, 15, 16, 17},
                      {20, 21, 22, 25, 26, 27},
                      {30, 31, 32, 35, 36, 37},
                      {40, 41, 42, 45, 46, 47}};
```

## CollectivePermute

See also
[`XlaBuilder::CollectivePermute`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

CollectivePermute is a collective operation that sends and receives data across
replicas.

<b> `CollectivePermute(operand, source_target_pairs)` </b>

| Arguments             | Type                    | Semantics                  |
| --------------------- | ----------------------- | -------------------------- |
| `operand`             | `XlaOp`                 | n dimensional input array  |
| `source_target_pairs` | `<int64, int64>` vector | A list of                  |
:                       :                         : (source_replica_id,        :
:                       :                         : target_replica_id) pairs.  :
:                       :                         : For each pair, the operand :
:                       :                         : is sent from source        :
:                       :                         : replica to target replica. :

Note that there are the following restrictions on `source_target_pairs`:

-   Any two pairs should not have the same target replica id, and they should
    not have the same source replica id.
-   If a replica id is not a target in any pair, then the output on that replica
    is a tensor consisting of 0(s) with the same shape as the input.

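A small sketch (the shape is illustrative) for four replicas:

```
XlaBuilder b("collective_permute");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");
// Replica 0 sends to 1, 1 sends to 2, and 2 sends to 3. Replica 0 is not a
// target of any pair, so its output is all zeros.
CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}});
```
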
## Concatenate

See also
[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Concatenate composes an array from multiple array operands. The array is of the
same rank as each of the input array operands (which must be of the same rank as
each other) and contains the arguments in the order that they were specified.

<b> `Concatenate(operands..., dimension)` </b>

| Arguments   | Type                  | Semantics                              |
| ----------- | --------------------- | -------------------------------------- |
| `operands`  | sequence of N `XlaOp` | N arrays of type T with dimensions     |
:             :                       : [L0, L1, ...]. Requires N >= 1.        :
| `dimension` | `int64`               | A value in the interval `[0, rank)`    |
:             :                       : that names the dimension to be         :
:             :                       : concatenated between the `operands`.   :

With the exception of `dimension`, all dimensions must be the same. This is
because XLA does not support "ragged" arrays. Also note that rank-0 values
cannot be concatenated (as it's impossible to name the dimension along which the
concatenation occurs).

1-dimensional example:

```
Concat({{2, 3}, {4, 5}, {6, 7}}, 0)
>>> {2, 3, 4, 5, 6, 7}
```

2-dimensional example:

```
let a = {
  {1, 2},
  {3, 4},
  {5, 6},
};
let b = {
  {7, 8},
};
Concat({a, b}, 0)
>>> {
  {1, 2},
  {3, 4},
  {5, 6},
  {7, 8},
}
```

Diagram:
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="./images/ops_concatenate.png">
</div>

## Conditional

See also
[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `Conditional(pred, true_operand, true_computation, false_operand,
false_computation)` </b>

<!-- mdformat off(disable mdformat for proper MathJax formatting) -->

Arguments           | Type             | Semantics
------------------- | ---------------- | ------------------------------------
`pred`              | `XlaOp`          | Scalar of type `PRED`
`true_operand`      | `XlaOp`          | Argument of type \\(T_0\\)
`true_computation`  | `XlaComputation` | XlaComputation of type \\(T_0 \to S\\)
`false_operand`     | `XlaOp`          | Argument of type \\(T_1\\)
`false_computation` | `XlaComputation` | XlaComputation of type \\(T_1 \to S\\)

Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
is `false`, and returns the result.

The `true_computation` must take in a single argument of type \\(T_0\\) and will
be invoked with `true_operand` which must be of the same type. The
`false_computation` must take in a single argument of type \\(T_1\\) and will be
invoked with `false_operand` which must be of the same type. The type of the
returned value of `true_computation` and `false_computation` must be the same.

<!-- mdformat on -->

Note that only one of `true_computation` and `false_computation` will be
executed depending on the value of `pred`.

<b> `Conditional(branch_index, branch_computations, branch_operands)` </b>

<!-- mdformat off(disable mdformat for proper MathJax formatting) -->

| Arguments             | Type                  | Semantics                    |
| --------------------- | --------------------- | ---------------------------- |
| `branch_index`        | `XlaOp`               | Scalar of type `S32`         |
| `branch_computations` | sequence of N         | XlaComputations of type \\(  |
:                       : `XlaComputation`      : T_0 \to S , T_1 \to S , ..., :
:                       :                       : T_{N-1} \to S \\)            :
| `branch_operands`     | sequence of N `XlaOp` | Arguments of type \\( T_0 ,  |
:                       :                       : T_1 , ..., T_{N-1} \\)       :

<!-- mdformat on -->

Executes `branch_computations[branch_index]`, and returns the result. If
`branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]`
is executed as the default branch.

Each `branch_computations[b]` must take in a single argument of type `T_b` and
will be invoked with `branch_operands[b]` which must be of the same type. The
type of the returned value of each `branch_computations[b]` must be the same.

Note that only one of the `branch_computations` will be executed depending on
the value of `branch_index`.

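A minimal sketch of the two-branch form (using `Build().ValueOrDie()` for
brevity); the branch computations are illustrative:

```
XlaBuilder b("conditional");

// true branch: negation; false branch: identity. Both are f32[] -> f32[].
XlaComputation neg, id;
{
  XlaBuilder sub("neg");
  Neg(Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "p"));
  neg = sub.Build().ValueOrDie();
}
{
  XlaBuilder sub("id");
  Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "p");
  id = sub.Build().ValueOrDie();
}

auto pred = ConstantR0<bool>(&b, true);
auto x = ConstantR0<float>(&b, 3.0f);
Conditional(pred, x, neg, x, id);  // computes -3.0f
```
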
## Conv (convolution)

See also
[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

As ConvWithGeneralPadding, but the padding is specified in a short-hand way as
either SAME or VALID. SAME padding pads the input (`lhs`) with zeroes so that
the output has the same shape as the input when not taking striding into
account. VALID padding simply means no padding.

## ConvWithGeneralPadding (convolution)

See also
[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Computes a convolution of the kind used in neural networks. Here, a convolution
can be thought of as an n-dimensional window moving across an n-dimensional base
area and a computation is performed for each possible position of the window.

| Arguments             | Type                     | Semantics                |
| --------------------- | ------------------------ | ------------------------ |
| `lhs`                 | `XlaOp`                  | rank n+2 array of inputs |
| `rhs`                 | `XlaOp`                  | rank n+2 array of kernel |
:                       :                          : weights                  :
| `window_strides`      | `ArraySlice<int64>`      | n-d array of kernel      |
:                       :                          : strides                  :
| `padding`             | `ArraySlice< pair<int64, | n-d array of (low, high) |
:                       : int64>>`                 : padding                  :
| `lhs_dilation`        | `ArraySlice<int64>`      | n-d lhs dilation factor  |
:                       :                          : array                    :
| `rhs_dilation`        | `ArraySlice<int64>`      | n-d rhs dilation factor  |
:                       :                          : array                    :
| `feature_group_count` | int64                    | the number of feature    |
:                       :                          : groups                   :
| `batch_group_count`   | int64                    | the number of batch      |
:                       :                          : groups                   :

Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2
array describing the base area. This is called the input, even though of course
the rhs is also an input. In a neural network, these are the input activations.
The n+2 dimensions are, in this order:

*   `batch`: Each coordinate in this dimension represents an independent input
for which convolution is carried out.
*   `z/depth/features`: Each (y,x) position in the base area has a vector
associated to it, which goes into this dimension.
*   `spatial_dims`: Describes the `n` spatial dimensions that define the base
area that the window moves across.

The `rhs` argument is a rank n+2 array describing the convolutional
filter/kernel/window. The dimensions are, in this order:

*   `output-z`: The `z` dimension of the output.
*   `input-z`: The size of this dimension times `feature_group_count` should
equal the size of the `z` dimension in lhs.
*   `spatial_dims`: Describes the `n` spatial dimensions that define the n-d
window that moves across the base area.

The `window_strides` argument specifies the stride of the convolutional window
in the spatial dimensions. For example, if the stride in the first spatial
dimension is 3, then the window can only be placed at coordinates where the
first spatial index is divisible by 3.

The `padding` argument specifies the amount of zero padding to be applied to the
base area. The amount of padding can be negative -- the absolute value of
negative padding indicates the number of elements to remove from the specified
dimension before doing the convolution. `padding[0]` specifies the padding for
dimension `y` and `padding[1]` specifies the padding for dimension `x`. Each
pair has the low padding as the first element and the high padding as the second
element. The low padding is applied in the direction of lower indices while the
high padding is applied in the direction of higher indices. For example, if
`padding[1]` is `(2,3)` then there will be a padding by 2 zeroes on the left and
by 3 zeroes on the right in the second spatial dimension. Using padding is
equivalent to inserting those same zero values into the input (`lhs`) before
doing the convolution.

The `lhs_dilation` and `rhs_dilation` arguments specify the dilation factor to
be applied to the lhs and rhs, respectively, in each spatial dimension. If the
dilation factor in a spatial dimension is d, then d-1 holes are implicitly
placed between each of the entries in that dimension, increasing the size of the
array. The holes are filled with a no-op value, which for convolution means
zeroes.

Dilation of the rhs is also called atrous convolution. For more details, see
`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed
convolution. For more details, see `tf.nn.conv2d_transpose`.

The `feature_group_count` argument (default value 1) can be used for grouped
convolutions. `feature_group_count` needs to be a divisor of both the input and
the output feature dimension. If `feature_group_count` is greater than 1, it
means that conceptually the input feature dimension and the `rhs` output
feature dimension are each split evenly into `feature_group_count` many groups,
each group consisting of a consecutive subsequence of features. The
input feature dimension of `rhs` needs to be equal to the `lhs` input feature
dimension divided by `feature_group_count` (so it already has the size of a
group of input features). The i-th groups are used together to compute
`feature_group_count` many separate convolutions. The results of these
convolutions are concatenated together in the output feature dimension.

For depthwise convolution the `feature_group_count` argument would be set to the
input feature dimension, and the filter would be reshaped from
`[filter_height, filter_width, in_channels, channel_multiplier]` to
`[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more
details, see `tf.nn.depthwise_conv2d`.

The `batch_group_count` (default value 1) argument can be used for grouped
filters during backpropagation. `batch_group_count` needs to be a divisor of the
size of the `lhs` (input) batch dimension. If `batch_group_count` is greater
than 1, it means that the output batch dimension should be of size `input batch
/ batch_group_count`. The `batch_group_count` must be a divisor of the output
feature size.

The output shape has these dimensions, in this order:

*   `batch`: The size of this dimension times `batch_group_count` should equal
    the size of the `batch` dimension in lhs.
*   `z`: Same size as `output-z` on the kernel (`rhs`).
*   `spatial_dims`: One value for each valid placement of the convolutional
    window.

The valid placements of the convolutional window are determined by the strides
and the size of the base area after padding.

To describe what a convolution does, consider a 2d convolution, and pick some
fixed `batch`, `z`, `y`, `x` coordinates in the output. Then `(y,x)` is a
position of a corner of the window within the base area (e.g. the upper left
corner, depending on how you interpret the spatial dimensions). We now have a 2d
window, taken from the base area, where each 2d point is associated to a 1d
vector, so we get a 3d box. From the convolutional kernel, since we fixed the
output coordinate `z`, we also have a 3d box. The two boxes have the same
dimensions, so we can take the sum of the element-wise products between the two
boxes (similar to a dot product). That is the output value.

Note that if `output-z` is e.g., 5, then each position of the window produces 5
values in the output into the `z` dimension of the output. These values differ
in what part of the convolutional kernel is used - there is a separate 3d box of
values used for each `output-z` coordinate. So you could think of it as 5
separate convolutions with a different filter for each of them.

Here is pseudo-code for a 2d convolution with padding and striding:

```
for (b, oz, oy, ox) {  // output coordinates
  value = 0;
  for (iz, ky, kx) {  // kernel coordinates and input z
    iy = oy*stride_y + ky - pad_low_y;
    ix = ox*stride_x + kx - pad_low_x;
    if ((iy, ix) inside the base area considered without padding) {
      value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
    }
  }
  output(b, oz, oy, ox) = value;
}
```

## ConvertElementType

See also
[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Similar to an element-wise `static_cast` in C++, performs an element-wise
conversion operation from a data shape to a target shape. The dimensions must
match, and the conversion is an element-wise one; e.g. `s32` elements become
`f32` elements via an `s32`-to-`f32` conversion routine.

<b> `ConvertElementType(operand, new_element_type)` </b>

Arguments          | Type            | Semantics
------------------ | --------------- | ---------------------------
`operand`          | `XlaOp`         | array of type T with dims D
`new_element_type` | `PrimitiveType` | type U

The dimensions of the operand and the target shape must match. The source and
destination element types must not be tuples.

A conversion such as `T=s32` to `U=f32` will perform a normalizing int-to-float
conversion routine such as round-to-nearest-even.

> Note: The precise float-to-int and vice-versa conversions are currently
> unspecified, but may become additional arguments to the convert operation in
> the future. Not all possible conversions have been implemented for all
> targets.

```
let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}
```

## CrossReplicaSum

Performs `AllReduce` with a summation computation.

## CustomCall

See also
[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Call a user-provided function within a computation.

<b> `CustomCall(target_name, args..., shape)` </b>

| Arguments     | Type                   | Semantics                         |
| ------------- | ---------------------- | --------------------------------- |
| `target_name` | `string`               | Name of the function. A call      |
:               :                        : instruction will be emitted which :
:               :                        : targets this symbol name.         :
| `args`        | sequence of N `XlaOp`s | N arguments of arbitrary type,    |
:               :                        : which will be passed to the       :
:               :                        : function.                         :
| `shape`       | `Shape`                | Output shape of the function      |

The function signature is the same, regardless of the arity or type of args:

```
extern "C" void target_name(void* out, void** in);
```

For example, if CustomCall is used as follows:

```
let x = f32[2] {1,2};
let y = f32[2x3] {{10, 20, 30}, {40, 50, 60}};

CustomCall("myfunc", {x, y}, f32[3x3])
```

Here is an example of an implementation of `myfunc`:

```
extern "C" void myfunc(void* out, void** in) {
  float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
  float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
  EXPECT_EQ(1, x[0]);
  EXPECT_EQ(2, x[1]);
  EXPECT_EQ(10, y[0][0]);
  EXPECT_EQ(20, y[0][1]);
  EXPECT_EQ(30, y[0][2]);
  EXPECT_EQ(40, y[1][0]);
  EXPECT_EQ(50, y[1][1]);
  EXPECT_EQ(60, y[1][2]);
  float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
  z[0][0] = x[1] + y[1][0];
  // ...
}
```

The user-provided function must not have side-effects and its execution must be
idempotent.

> Note: The opaque nature of the user-provided function restricts optimization
> opportunities for the compiler. Try to express your computation in terms of
> native XLA ops whenever possible; only use CustomCall as a last resort.

## Dot

See also
[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `Dot(lhs, rhs)` </b>

Arguments | Type    | Semantics
--------- | ------- | ---------------
`lhs`     | `XlaOp` | array of type T
`rhs`     | `XlaOp` | array of type T

The exact semantics of this operation depend on the ranks of the operands:

| Input                   | Output                | Semantics               |
| ----------------------- | --------------------- | ----------------------- |
| vector [n] `dot` vector | scalar                | vector dot product      |
: [n]                     :                       :                         :
| matrix [m x k] `dot`    | vector [m]            | matrix-vector           |
: vector [k]              :                       : multiplication          :
| matrix [m x k] `dot`    | matrix [m x n]        | matrix-matrix           |
: matrix [k x n]          :                       : multiplication          :

The operation performs a sum of products over the second dimension of `lhs` (or
the first if it has rank 1) and the first dimension of `rhs`. These are the
"contracted" dimensions. The contracted dimensions of `lhs` and `rhs` must be of
the same size. In practice, it can be used to perform dot products between
vectors, vector/matrix multiplications or matrix/matrix multiplications.

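A small sketch of the matrix-matrix case (shapes illustrative):

```
XlaBuilder b("dot");
auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "lhs");
auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {3, 4}), "rhs");
// Contracts dimension 1 of `lhs` with dimension 0 of `rhs`; the result has
// shape f32[2,4].
Dot(lhs, rhs);
```
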
## DotGeneral

See also
[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>

Arguments           | Type                  | Semantics
------------------- | --------------------- | ---------------
`lhs`               | `XlaOp`               | array of type T
`rhs`               | `XlaOp`               | array of type T
`dimension_numbers` | `DotDimensionNumbers` | contracting and batch dimension numbers

As Dot, but allows contracting and batch dimension numbers to be specified for
both the 'lhs' and 'rhs'.

| DotDimensionNumbers Fields   | Type           | Semantics                           |
| ---------------------------- | -------------- | ----------------------------------- |
| 'lhs_contracting_dimensions' | repeated int64 | 'lhs' contracting dimension numbers |
| 'rhs_contracting_dimensions' | repeated int64 | 'rhs' contracting dimension numbers |
| 'lhs_batch_dimensions'       | repeated int64 | 'lhs' batch dimension numbers       |
| 'rhs_batch_dimensions'       | repeated int64 | 'rhs' batch dimension numbers       |

DotGeneral performs the sum of products over contracting dimensions specified
in 'dimension_numbers'.

Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need
to be the same but must have the same dimension sizes.

Example with contracting dimension numbers:

```
lhs = { {1.0, 2.0, 3.0},
        {4.0, 5.0, 6.0} }

rhs = { {1.0, 1.0, 1.0},
        {2.0, 2.0, 2.0} }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);

DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
                                 {15.0, 30.0} }
```

Associated batch dimension numbers from the 'lhs' and 'rhs' must
have the same dimension sizes.

Example with batch dimension numbers (batch size 2, 2x2 matrices):

```
lhs = { { {1.0, 2.0},
          {3.0, 4.0} },
        { {5.0, 6.0},
          {7.0, 8.0} } }

rhs = { { {1.0, 0.0},
          {0.0, 1.0} },
        { {1.0, 0.0},
          {0.0, 1.0} } }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);

DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
                                   {3.0, 4.0} },
                                 { {5.0, 6.0},
                                   {7.0, 8.0} } }
```

| Input                               | Output            | Semantics        |
| ----------------------------------- | ----------------- | ---------------- |
| [b0, m, k] `dot` [b0, k, n]         | [b0, m, n]        |  batch matmul    |
| [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n]    |  batch matmul    |

It follows that the resulting dimensions start with the batch dimensions, then
the 'lhs' non-contracting/non-batch dimensions, and finally the 'rhs'
non-contracting/non-batch dimensions.

## DynamicSlice

See also
[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

DynamicSlice extracts a sub-array from the input array at dynamic
`start_indices`. The size of the slice in each dimension is passed in
`size_indices`, which specify the end point of exclusive slice intervals in each
dimension: [start, start + size). `start_indices` is a list of N scalar integer
values, where N is the rank of `operand`.

<b> `DynamicSlice(operand, start_indices, size_indices)` </b>

| Arguments       | Type                  | Semantics                          |
| --------------- | --------------------- | ---------------------------------- |
| `operand`       | `XlaOp`               | N dimensional array of type T      |
| `start_indices` | sequence of N `XlaOp` | List of N scalar integers          |
:                 :                       : containing the starting indices of :
:                 :                       : the slice for each dimension.      :
:                 :                       : Value must be greater than or      :
:                 :                       : equal to zero.                     :
| `size_indices`  | `ArraySlice<int64>`   | List of N integers containing the  |
:                 :                       : slice size for each dimension.     :
:                 :                       : Each value must be strictly        :
:                 :                       : greater than zero, and start +     :
:                 :                       : size must be less than or equal to :
:                 :                       : the size of the dimension to avoid :
:                 :                       : wrapping modulo dimension size.    :

The effective slice indices are computed by applying the following
transformation for each index `i` in `[0, N)` before performing the slice:

```
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
```

This ensures that the extracted slice is always in-bounds with respect to the
operand array. If the slice is in-bounds before the transformation is applied,
the transformation has no effect.

1-dimensional example:

```
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}

DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}
```

2-dimensional example:

```
let b =
{ {0.0,  1.0,  2.0},
  {3.0,  4.0,  5.0},
  {6.0,  7.0,  8.0},
  {9.0, 10.0, 11.0} }
let s = {2, 1}

DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0,  8.0},
  {10.0, 11.0} }
```
1174## DynamicUpdateSlice
1175
1176See also
1177[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1178
1179DynamicUpdateSlice generates a result which is the value of the input array
1180`operand`, with a slice `update` overwritten at `start_indices`.
1181The shape of `update` determines the shape of the sub-array of the result which
1182is updated.
1183The shape of `start_indices` must be rank == 1, with dimension size equal to
1184the rank of `operand`.
1185
1186<b> `DynamicUpdateSlice(operand, update, start_indices)` </b>
1187
1188| Arguments       | Type                  | Semantics                          |
1189| --------------- | --------------------- | ---------------------------------- |
1190| `operand`       | `XlaOp`               | N dimensional array of type T      |
1191| `update`        | `XlaOp`               | N dimensional array of type T      |
1192:                 :                       : containing the slice update. Each  :
1193:                 :                       : dimension of update shape must be  :
1194:                 :                       : strictly greater than zero, and    :
1195:                 :                       : start + update must be less than   :
1196:                 :                       : or equal to the operand size for   :
1197:                 :                       : each dimension to avoid generating :
1198:                 :                       : out-of-bounds update indices.      :
1199| `start_indices` | sequence of N `XlaOp` | List of N scalar integers          |
1200:                 :                       : containing the starting indices of :
1201:                 :                       : the slice for each dimension.      :
1202:                 :                       : Value must be greater than or      :
1203:                 :                       : equal to zero.                     :
1204
1205The effective slice indices are computed by applying the following
transformation for each index `i` in `[0, N)` before performing the slice:
1207
1208```
1209start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
1210```
1211
1212This ensures that the updated slice is always in-bounds with respect to the
1213operand array. If the slice is in-bounds before the transformation is applied,
1214the transformation has no effect.
1215
12161-dimensional example:
1217
1218```
1219let a = {0.0, 1.0, 2.0, 3.0, 4.0}
1220let u = {5.0, 6.0}
1221let s = {2}
1222
1223DynamicUpdateSlice(a, u, s) produces:
1224{0.0, 1.0, 5.0, 6.0, 4.0}
1225```
1226
12272-dimensional example:
1228
1229```
let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }
let u =
 { {12.0,  13.0},
   {14.0,  15.0},
   {16.0,  17.0} }

let s = {1, 1}

DynamicUpdateSlice(b, u, s) produces:
 { {0.0,  1.0,  2.0},
   {3.0, 12.0, 13.0},
   {6.0, 14.0, 15.0},
   {9.0, 16.0, 17.0} }
1247```
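
As with `DynamicSlice`, the update semantics can be modeled in NumPy. This is a
hedged sketch, not the XLA API; the helper name `dynamic_update_slice` is ours,
and it copies the operand rather than updating in place.

```python
import numpy as np

def dynamic_update_slice(operand, update, start_indices):
    # Clamp each starting index so the update stays in bounds, then overwrite.
    starts = [
        min(max(int(s), 0), operand.shape[i] - update.shape[i])
        for i, s in enumerate(start_indices)
    ]
    result = operand.copy()
    index = tuple(slice(s, s + d) for s, d in zip(starts, update.shape))
    result[index] = update
    return result

a = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
u = np.array([5.0, 6.0])
dynamic_update_slice(a, u, [2])   # -> [0.0, 1.0, 5.0, 6.0, 4.0]
```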
1248
1249## Element-wise binary arithmetic operations
1250
1251See also
1252[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1253
1254A set of element-wise binary arithmetic operations is supported.
1255
1256<b> `Op(lhs, rhs)` </b>
1257
1258Where `Op` is one of `Add` (addition), `Sub` (subtraction), `Mul`
1259(multiplication), `Div` (division), `Rem` (remainder), `Max` (maximum), `Min`
1260(minimum), `LogicalAnd` (logical AND), or `LogicalOr` (logical OR).
1261
1262Arguments | Type    | Semantics
1263--------- | ------- | ----------------------------------------
1264`lhs`     | `XlaOp` | left-hand-side operand: array of type T
1265`rhs`     | `XlaOp` | right-hand-side operand: array of type T
1266
1267The arguments' shapes have to be either similar or compatible. See the
1268[broadcasting](broadcasting.md) documentation about what it means for shapes to
1269be compatible. The result of an operation has a shape which is the result of
1270broadcasting the two input arrays. In this variant, operations between arrays of
1271different ranks are *not* supported, unless one of the operands is a scalar.
1272
1273When `Op` is `Rem`, the sign of the result is taken from the dividend, and the
1274absolute value of the result is always less than the divisor's absolute value.
1275
Integer division overflow (signed/unsigned division/remainder by zero or signed
division/remainder of `INT_MIN` with `-1`) produces an implementation-defined
value.
1279
1280An alternative variant with different-rank broadcasting support exists for these
1281operations:
1282
1283<b> `Op(lhs, rhs, broadcast_dimensions)` </b>
1284
1285Where `Op` is the same as above. This variant of the operation should be used
1286for arithmetic operations between arrays of different ranks (such as adding a
1287matrix to a vector).
1288
1289The additional `broadcast_dimensions` operand is a slice of integers used to
1290expand the rank of the lower-rank operand up to the rank of the higher-rank
1291operand. `broadcast_dimensions` maps the dimensions of the lower-rank shape to
1292the dimensions of the higher-rank shape. The unmapped dimensions of the expanded
1293shape are filled with dimensions of size one. Degenerate-dimension broadcasting
1294then broadcasts the shapes along these degenerate dimensions to equalize the
1295shapes of both operands. The semantics are described in detail on the
1296[broadcasting page](broadcasting.md).
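
For intuition, the effect of `broadcast_dimensions` can be modeled in NumPy by
reshaping the lower-rank operand so its dimensions land where the mapping says,
with size-1 dimensions everywhere else. This is a hedged sketch; the helper name
`add_with_broadcast_dims` is ours, not an XLA API.

```python
import numpy as np

def add_with_broadcast_dims(lhs, rhs, broadcast_dimensions):
    # `rhs` is the lower-rank operand; `broadcast_dimensions[k]` says which
    # dimension of `lhs` the k-th dimension of `rhs` maps to.
    expanded_shape = [1] * lhs.ndim
    for k, dim in enumerate(broadcast_dimensions):
        expanded_shape[dim] = rhs.shape[k]
    # Degenerate (size-1) dimensions then broadcast against `lhs`.
    return lhs + rhs.reshape(expanded_shape)

matrix = np.arange(6.0).reshape(2, 3)   # shape [2, 3]
row = np.array([10.0, 20.0, 30.0])      # shape [3]

add_with_broadcast_dims(matrix, row, [1])                       # added to every row
add_with_broadcast_dims(matrix, np.array([100.0, 200.0]), [0])  # added along dimension 0
```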
1297
1298## Element-wise comparison operations
1299
1300See also
1301[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1302
1303A set of standard element-wise binary comparison operations is supported. Note
1304that standard IEEE 754 floating-point comparison semantics apply when comparing
1305floating-point types.
1306
1307<b> `Op(lhs, rhs)` </b>
1308
1309Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
1310(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
(less-than). Another set of operators, `EqTotalOrder`, `NeTotalOrder`,
`GeTotalOrder`, `GtTotalOrder`, `LeTotalOrder`, and `LtTotalOrder`, provides the
same functionality, except that they additionally support a total order over the
floating-point numbers, by enforcing
`-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
1315
1316Arguments | Type    | Semantics
1317--------- | ------- | ----------------------------------------
1318`lhs`     | `XlaOp` | left-hand-side operand: array of type T
1319`rhs`     | `XlaOp` | right-hand-side operand: array of type T
1320
1321The arguments' shapes have to be either similar or compatible. See the
1322[broadcasting](broadcasting.md) documentation about what it means for shapes to
1323be compatible. The result of an operation has a shape which is the result of
1324broadcasting the two input arrays with the element type `PRED`. In this variant,
1325operations between arrays of different ranks are *not* supported, unless one of
1326the operands is a scalar.
1327
1328An alternative variant with different-rank broadcasting support exists for these
1329operations:
1330
1331<b> `Op(lhs, rhs, broadcast_dimensions)` </b>
1332
1333Where `Op` is the same as above. This variant of the operation should be used
for comparison operations between arrays of different ranks (such as comparing a
matrix to a vector).
1336
1337The additional `broadcast_dimensions` operand is a slice of integers specifying
1338the dimensions to use for broadcasting the operands. The semantics are described
1339in detail on the [broadcasting page](broadcasting.md).
1340
1341## Element-wise unary functions
1342
1343XlaBuilder supports these element-wise unary functions:
1344
1345<b>`Abs(operand)`</b> Element-wise abs `x -> |x|`.
1346
1347<b>`Ceil(operand)`</b> Element-wise ceil `x -> ⌈x⌉`.
1348
1349<b>`Cos(operand)`</b> Element-wise cosine `x -> cos(x)`.
1350
1351<b>`Exp(operand)`</b> Element-wise natural exponential `x -> e^x`.
1352
1353<b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`.
1354
1355<b>`Imag(operand)`</b> Element-wise imaginary part of a complex (or real)
1356shape. `x -> imag(x)`. If the operand is a floating point type, returns 0.
1357
1358<b>`IsFinite(operand)`</b> Tests whether each element of `operand` is finite,
1359i.e., is not positive or negative infinity, and is not `NaN`. Returns an array
1360of `PRED` values with the same shape as the input, where each element is `true`
1361if and only if the corresponding input element is finite.
1362
1363<b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`.
1364
1365<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
1366
1367<b>`Logistic(operand)`</b> Element-wise logistic function computation `x ->
1368logistic(x)`.
1369
1370<b>`PopulationCount(operand)`</b> Computes the number of bits set in each
1371element of `operand`.
1372
1373<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
1374
1375<b>`Real(operand)`</b> Element-wise real part of a complex (or real) shape.
1376`x -> real(x)`. If the operand is a floating point type, returns the same value.
1377
1378<b>`Rsqrt(operand)`</b> Element-wise reciprocal of square root operation
1379`x -> 1.0 / sqrt(x)`.
1380
1381<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
1382
1383$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$
1384
1385using the comparison operator of the element type of `operand`.
1386
1387<b>`Sqrt(operand)`</b> Element-wise square root operation `x -> sqrt(x)`.
1388
1389<b>`Cbrt(operand)`</b> Element-wise cubic root operation `x -> cbrt(x)`.
1390
1391<b>`Tanh(operand)`</b> Element-wise hyperbolic tangent `x -> tanh(x)`.
1392
1393<b>`Round(operand)`</b> Element-wise rounding, ties away from zero.
1394
1395<b>`RoundNearestEven(operand)`</b> Element-wise rounding, ties to nearest even.
1396
1397Arguments | Type    | Semantics
1398--------- | ------- | ---------------------------
1399`operand` | `XlaOp` | The operand to the function
1400
1401The function is applied to each element in the `operand` array, resulting in an
1402array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
1403
1404## Fft
1405
1406The XLA FFT operation implements the forward and inverse Fourier Transforms for
1407real and complex inputs/outputs. Multidimensional FFTs on up to 3 axes are
1408supported.
1409
1410See also
1411[`XlaBuilder::Fft`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1412
1413| Arguments    | Type                | Semantics                |
1414| ------------ | ------------------- | ------------------------ |
1415| `operand`    | `XlaOp`             | The array we are Fourier |
1416:              :                     : transforming.            :
1417| `fft_type`   | `FftType`           | See the table below.     |
1418| `fft_length` | `ArraySlice<int64>` | The time-domain lengths  |
1419:              :                     : of the axes being        :
1420:              :                     : transformed. This is     :
1421:              :                     : needed in particular for :
1422:              :                     : IRFFT to right-size the  :
1423:              :                     : innermost axis, since    :
1424:              :                     : `RFFT(fft_length=[16])`  :
1425:              :                     : has the same output      :
1426:              :                     : shape as                 :
1427:              :                     : `RFFT(fft_length=[17])`. :
1428
1429| `FftType` | Semantics                                                        |
1430| --------- | ---------------------------------------------------------------- |
1431| `FFT`     | Forward complex-to-complex FFT. Shape is unchanged.              |
1432| `IFFT`    | Inverse complex-to-complex FFT. Shape is unchanged.              |
1433| `RFFT`    | Forward real-to-complex FFT. Shape of the innermost axis is      |
1434:           : reduced to `fft_length[-1] // 2 + 1` if `fft_length[-1]` is a    :
1435:           : non-zero value, omitting the reversed conjugate part of the      :
1436:           : transformed signal beyond the Nyquist frequency.                 :
1437| `IRFFT`   | Inverse real-to-complex FFT (i.e. takes complex, returns real).  |
1438:           : Shape of the innermost axis is expanded to `fft_length[-1]` if   :
1439:           : `fft_length[-1]` is a non-zero value, inferring the part of the  :
1440:           : transformed signal beyond the Nyquist frequency from the reverse :
1441:           : conjugate of the `1` to `fft_length[-1] // 2 + 1` entries.       :
1442
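The `fft_length` argument matters because the real-to-complex output shape alone
does not determine the input length. NumPy's FFT helpers illustrate the shape
relationship (this is only a shape illustration, not a statement about XLA's
implementation):

```python
import numpy as np

x16 = np.random.randn(16).astype(np.float32)
x17 = np.random.randn(17).astype(np.float32)

np.fft.rfft(x16).shape                        # (9,)  -- 16 // 2 + 1
np.fft.rfft(x17).shape                        # (9,)  -- 17 // 2 + 1, same output length
np.fft.irfft(np.fft.rfft(x17), n=17).shape    # (17,) -- n plays the role of fft_length
```
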
1443#### Multidimensional FFT
1444
1445When more than 1 `fft_length` is provided, this is equivalent to applying a
1446cascade of FFT operations to each of the innermost axes. Note that for the
1447real->complex and complex->real cases, the innermost axis transform is
1448(effectively) performed first (RFFT; last for IRFFT), which is why the innermost
1449axis is the one which changes size. Other axis transforms will then be
1450complex->complex.
1451
1452#### Implementation details
1453
1454CPU FFT is backed by Eigen's TensorFFT. GPU FFT uses cuFFT.
1455
1456## Gather
1457
1458The XLA gather operation stitches together several slices (each slice at a
1459potentially different runtime offset) of an input array.
1460
1461### General Semantics
1462
1463See also
1464[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1465For a more intuitive description, see the "Informal Description" section below.
1466
1467<b> `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` </b>
1468
1469| Arguments              | Type                | Semantics                     |
1470| ---------------------- | ------------------- | ----------------------------- |
1471| `operand`              | `XlaOp`             | The array we’re gathering     |
1472:                        :                     : from.                         :
1473| `start_indices`        | `XlaOp`             | Array containing the starting |
1474:                        :                     : indices of the slices we      :
1475:                        :                     : gather.                       :
1476| `index_vector_dim`     | `int64`             | The dimension in              |
1477:                        :                     : `start_indices` that          :
1478:                        :                     : "contains" the starting       :
1479:                        :                     : indices. See below for a      :
1480:                        :                     : detailed description.         :
1481| `offset_dims`          | `ArraySlice<int64>` | The set of dimensions in the  |
1482:                        :                     : output shape that offset into :
1483:                        :                     : an array sliced from operand. :
1484| `slice_sizes`          | `ArraySlice<int64>` | `slice_sizes[i]` is the       |
1485:                        :                     : bounds for the slice on       :
1486:                        :                     : dimension `i`.                :
1487| `collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each |
1488:                        :                     : slice that are collapsed      :
1489:                        :                     : away. These dimensions must   :
1490:                        :                     : have size 1.                  :
1491| `start_index_map`      | `ArraySlice<int64>` | A map that describes how to   |
1492:                        :                     : map indices in                :
1493:                        :                     : `start_indices` to legal      :
1494:                        :                     : indices into operand.         :
1495| `indices_are_sorted`   | `bool`              | Whether the indices are       |
1496:                        :                     : guaranteed to be sorted by    :
1497:                        :                     : the caller.                   :
1498| `unique_indices`       | `bool`              | Whether the indices are       |
1499:                        :                     : guaranteed to be unique by    :
1500:                        :                     : the caller.                   :
1501
1502For convenience, we label dimensions in the output array not in `offset_dims`
1503as `batch_dims`.
1504
1505The output is an array of rank `batch_dims.size` + `offset_dims.size`.
1506
1507The `operand.rank` must equal the sum of `offset_dims.size` and
1508`collapsed_slice_dims.size`. Also, `slice_sizes.size` has to be equal to
1509`operand.rank`.
1510
1511If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
1512`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
1513shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the
1514shape of `start_indices` to be `[6,7,1]`).
1515
1516The bounds for the output array along dimension `i` is computed as follows:
1517
15181. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for
1519some `k`) then we pick the corresponding dimension bounds out of
1520`start_indices.shape`, skipping `index_vector_dim` (i.e. pick
1521`start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and
1522`start_indices.shape.dims`[`k`+`1`] otherwise).
1523
15242. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for
1525some `k`) then we pick the corresponding bound out of `slice_sizes` after
1526accounting for `collapsed_slice_dims` (i.e. we pick
1527`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
1528with the bounds at indices `collapsed_slice_dims` removed).
1529
1530Formally, the operand index `In` corresponding to a given output index `Out` is
1531calculated as follows:
1532
15331.  Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out a
1534    vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
1535    Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
1536    this is well defined even if `G` is empty -- if `G` is empty then `S` =
1537    `start_indices`.
1538
15392.  Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
1540    scattering `S` using `start_index_map`. More precisely:
1541
1542    1.  `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
1543        `start_index_map.size`.
1544
1545    2.  `S`<sub>`in`</sub>[`_`] = `0` otherwise.
1546
15473.  Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
1548    at the offset dimensions in `Out` according to the `collapsed_slice_dims`
1549    set. More precisely:
1550
1551    1.  `O`<sub>`in`</sub>[`remapped_offset_dims`(`k`)] =
1552        `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
1553        (`remapped_offset_dims` is defined below).
1554
1555    2.  `O`<sub>`in`</sub>[`_`] = `0` otherwise.
1556
15574.  `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
1558    addition.
1559
1560`remapped_offset_dims` is a monotonic function with domain [`0`,
1561`offset_dims.size`) and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So
1562if, e.g., `offset_dims.size` is `4`, `operand.rank` is `6` and
1563`collapsed_slice_dims` is {`0`, `2`} then `remapped_offset_dims` is {`0`→`1`,
1564`1`→`3`, `2`→`4`, `3`→`5`}.
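
The index mapping in steps 1–4 can be written out as a small Python sketch for a
single output index. This is an illustrative model under the assumptions above
(it assumes `index_vector_dim < start_indices.ndim` and omits the bounds
handling XLA applies to out-of-range starting indices); the helper name
`gather_element` is ours.

```python
import numpy as np

def gather_element(operand, start_indices, out_index, *, offset_dims,
                   collapsed_slice_dims, start_index_map, index_vector_dim):
    out_rank = len(out_index)
    batch_dims = [d for d in range(out_rank) if d not in offset_dims]

    # Step 1: G is the batch part of Out; S is the starting index vector
    # sliced out of start_indices along index_vector_dim.
    G = [out_index[d] for d in batch_dims]
    S = []
    for i in range(len(start_index_map)):
        idx = list(G)
        idx.insert(index_vector_dim, i)          # Combine(G, i)
        S.append(int(start_indices[tuple(idx)]))

    # Step 2: scatter S into an operand-rank index using start_index_map.
    S_in = [0] * operand.ndim
    for k, dim in enumerate(start_index_map):
        S_in[dim] = S[k]

    # Step 3: scatter the offset part of Out using remapped_offset_dims.
    remapped_offset_dims = [d for d in range(operand.ndim)
                            if d not in collapsed_slice_dims]
    O_in = [0] * operand.ndim
    for k, out_dim in enumerate(offset_dims):
        O_in[remapped_offset_dims[k]] = out_index[out_dim]

    # Step 4: In = O_in + S_in, element-wise.
    return operand[tuple(o + s for o, s in zip(O_in, S_in))]
```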
1565
1566If `indices_are_sorted` is set to true then XLA can assume that `start_indices`
1567are sorted (in ascending `start_index_map` order) by the user. If they are not
1568then the semantics is implementation defined.
1569
If `unique_indices` is set to true then XLA can assume that all elements
scattered to are unique. So XLA could use non-atomic operations. If
`unique_indices` is set to true and the indices being scattered to are not
unique, then the semantics is implementation defined.
1574
1575### Informal Description and Examples
1576
1577Informally, every index `Out` in the output array corresponds to an element `E`
1578in the operand array, computed as follows:
1579
1580-   We use the batch dimensions in `Out` to look up a starting index from
1581    `start_indices`.
1582
1583-   We use `start_index_map` to map the starting index (whose size may be less
1584    than operand.rank) to a "full" starting index into the `operand`.
1585
1586-   We dynamic-slice out a slice with size `slice_sizes` using the full starting
1587    index.
1588
1589-   We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
1590    Since all collapsed slice dimensions must have a bound of 1, this reshape is
1591    always legal.
1592
1593-   We use the offset dimensions in `Out` to index into this slice to get the
1594    input element, `E`, corresponding to output index `Out`.
1595
1596`index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples
1597that follow. More interesting values for `index_vector_dim` do not change the
1598operation fundamentally, but make the visual representation more cumbersome.
1599
1600To get an intuition on how all of the above fits together, let's look at an
1601example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array.  The
1602position of a slice into the `[16,11]` array can be represented as an index
1603vector of shape `S64[2]`, so the set of 5 positions can be represented as a
1604`S64[5,2]` array.
1605
1606The behavior of the gather operation can then be depicted as an index
1607transformation that takes [`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>], an index in
1608the output shape, and maps it to an element in the input array in the following
1609way:
1610
1611<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1612<img style="width:100%" src="./images/ops_xla_gather_0.svg">
1613</div>
1614
1615We first select an (`X`,`Y`) vector from the gather indices array using `G`.
1616The element in the output array at index
1617[`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>] is then the element in the input
1618array at index [`X`+`O`<sub>`0`</sub>,`Y`+`O`<sub>`1`</sub>].
1619
1620`slice_sizes` is `[8,6]`, which decides the range of O<sub>`0`</sub> and
1621O<sub>`1`</sub>, and this in turn decides the bounds of the slice.
1622
1623This gather operation acts as a batch dynamic slice with `G` as the batch
1624dimension.
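
A NumPy sketch of this particular example (a hedged illustration, not the XLA
API; the start indices are chosen to be in bounds so no clamping is needed):

```python
import numpy as np

operand = np.arange(16 * 11).reshape(16, 11)                         # the [16,11] input
start_indices = np.array([[0, 0], [2, 1], [8, 5], [7, 0], [6, 2]])   # S64[5,2]
slice_sizes = (8, 6)

output = np.empty((5,) + slice_sizes, dtype=operand.dtype)
for g in range(start_indices.shape[0]):
    x, y = start_indices[g]
    # Dynamic-slice an [8,6] window starting at (x, y) for batch position g.
    output[g] = operand[x:x + slice_sizes[0], y:y + slice_sizes[1]]

output.shape   # (5, 8, 6): batch dimension G, then offset dimensions O0, O1
```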
1625
1626The gather indices may be multidimensional.  For instance, a more general
1627version of the example above using a "gather indices" array of shape `[4,5,2]`
1628would translate indices like this:
1629
1630<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1631<img style="width:100%" src="./images/ops_xla_gather_1.svg">
1632</div>
1633
Again, this acts as a batch dynamic slice with `G`<sub>`0`</sub> and
1635`G`<sub>`1`</sub> as the batch dimensions.  The slice size is still `[8,6]`.
1636
1637The gather operation in XLA generalizes the informal semantics outlined above in
1638the following ways:
1639
16401. We can configure which dimensions in the output shape are the offset
1641dimensions (dimensions containing `O`<sub>`0`</sub>, `O`<sub>`1`</sub> in
1642the last example).  The output batch dimensions (dimensions containing
1643`G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be
1644the output dimensions that are not offset dimensions.
1645
16462. The number of output offset dimensions explicitly present in the output
1647shape may be smaller than the input rank.  These "missing" dimensions, which
1648are listed explicitly as `collapsed_slice_dims`, must have a slice size of
1649`1`.  Since they have a slice size of `1` the only valid index for them is
1650`0` and eliding them does not introduce ambiguity.
1651
16523. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last
1653example) may have fewer elements than the input array rank, and an explicit
1654mapping dictates how the index should be expanded to have the same rank as
1655the input.
1656
1657As a final example, we use (2) and (3) to implement `tf.gather_nd`:
1658
1659<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1660<img style="width:100%" src="./images/ops_xla_gather_2.svg">
1661</div>
1662
1663`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
1664from the gather indices array as usual, except the starting index has only one
1665element, `X`. Similarly, there is only one output offset index with the value
1666`O`<sub>`0`</sub>. However, before being used as indices into the input array,
1667these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in
1668the formal description) and "Offset Mapping" (`remapped_offset_dims` in the
1669formal description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively,
1670adding up to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
1671[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
1672[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
1673the semantics for `tf.gather_nd`.
1674
1675`slice_sizes` for this case is `[1,11]`.  Intuitively this means that every
1676index `X` in the gather indices array picks an entire row and the result is the
1677concatenation of all these rows.
1678
1679## GetDimensionSize
1680
1681See also
1682[`XlaBuilder::GetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1683
1684Returns the size of the given dimension of the operand. The operand must be
1685array shaped.
1686
1687<b> `GetDimensionSize(operand, dimension)` </b>
1688
1689| Arguments   | Type    | Semantics                                           |
1690| ----------- | ------- | --------------------------------------------------- |
1691| `operand`   | `XlaOp` | n dimensional input array                           |
1692| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the |
1693:             :         : dimension                                           :
1694
1695## SetDimensionSize
1696
1697See also
1698[`XlaBuilder::SetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1699
Sets the dynamic size of the given dimension of the operand. The operand must be
array shaped.
1702
1703<b> `SetDimensionSize(operand, size, dimension)` </b>
1704
1705| Arguments   | Type    | Semantics                                           |
1706| ----------- | ------- | --------------------------------------------------- |
1707| `operand`   | `XlaOp` | n dimensional input array.                          |
1708| `size`      | `XlaOp` | int32 representing the runtime dynamic size.        |
1709| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the |
1710:             :         : dimension.                                          :
1711
Passes the operand through as the result, with the dynamic dimension tracked by
the compiler.
1714
1715Padded values will be ignored by downstream reduction ops.
1716
1717```
1718let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
1719let five: s32 = 5;
1720let six: s32 = 6;
1721
1722// Setting dynamic dimension size doesn't change the upper bound of the static
1723// shape.
1724let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
1725let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);
1726
1727// sum == 1 + 2 + 3 + 4 + 5
1728let sum:f32[] = reduce_sum(padded_v_five);
1729// product == 1 * 2 * 3 * 4 * 5
1730let product:f32[] = reduce_product(padded_v_five);
1731
1732// Changing padding size will yield different result.
1733// sum == 1 + 2 + 3 + 4 + 5 + 6
1734let sum':f32[] = reduce_sum(padded_v_six);
1735```
1736
1737## GetTupleElement
1738
1739See also
1740[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1741
1742Indexes into a tuple with a compile-time-constant value.
1743
1744The value must be a compile-time-constant so that shape inference can determine
1745the type of the resulting value.
1746
1747This is analogous to `std::get<int N>(t)` in C++. Conceptually:
1748
1749```
1750let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
1751let s: s32 = 5;
1752let t: (f32[10], s32) = tuple(v, s);
1753let element_1: s32 = gettupleelement(t, 1);  // Inferred shape matches s32.
1754```
1755
1756See also `tf.tuple`.
1757
1758## Infeed
1759
1760See also
1761[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1762
1763<b> `Infeed(shape)` </b>
1764
1765| Argument | Type    | Semantics                                             |
1766| -------- | ------- | ----------------------------------------------------- |
1767| `shape`  | `Shape` | Shape of the data read from the Infeed interface. The |
1768:          :         : layout field of the shape must be set to match the    :
1769:          :         : layout of the data sent to the device; otherwise its  :
1770:          :         : behavior is undefined.                                :
1771
1772Reads a single data item from the implicit Infeed streaming interface of the
device, interpreting the data as the given shape and its layout, and returns an
`XlaOp` for the data. Multiple Infeed operations are allowed in a
1775computation, but there must be a total order among the Infeed operations. For
1776example, two Infeeds in the code below have a total order since there is a
1777dependency between the while loops.
1778
1779```
1780result1 = while (condition, init = init_value) {
1781  Infeed(shape)
1782}
1783
1784result2 = while (condition, init = result1) {
1785  Infeed(shape)
1786}
1787```
1788
1789Nested tuple shapes are not supported. For an empty tuple shape, the Infeed
1790operation is effectively a no-op and proceeds without reading any data from the
1791Infeed of the device.
1792
1793> Note: We plan to allow multiple Infeed operations without a total order, in
1794> which case the compiler will provide information about how the Infeed
1795> operations are serialized in the compiled program.
1796
1797## Iota
1798
1799See also
1800[`XlaBuilder::Iota`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1801
1802<b> `Iota(shape, iota_dimension)` </b>
1803
1804Builds a constant literal on device rather than a potentially large host
1805transfer. Creates an array that has specified shape and holds values starting at
1806zero and incrementing by one along the specified dimension. For floating-point
1807types, the produced array is equivalent to `ConvertElementType(Iota(...))` where
1808the `Iota` is of integral type and the conversion is to the floating-point type.
1809
1810Arguments        | Type    | Semantics
1811---------------- | ------- | --------------------------------------
1812`shape`          | `Shape` | Shape of the array created by `Iota()`
1813`iota_dimension` | `int64` | The dimension to increment along.
1814
1815For example, `Iota(s32[4, 8], 0)` returns
1816
1817```
1818  [[0, 0, 0, 0, 0, 0, 0, 0 ],
1819   [1, 1, 1, 1, 1, 1, 1, 1 ],
1820   [2, 2, 2, 2, 2, 2, 2, 2 ],
1821   [3, 3, 3, 3, 3, 3, 3, 3 ]]
1822```
1823
1824`Iota(s32[4, 8], 1)` returns
1825
1826```
1827  [[0, 1, 2, 3, 4, 5, 6, 7 ],
1828   [0, 1, 2, 3, 4, 5, 6, 7 ],
1829   [0, 1, 2, 3, 4, 5, 6, 7 ],
1830   [0, 1, 2, 3, 4, 5, 6, 7 ]]
1831```
1832
1833## Map
1834
1835See also
1836[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1837
1838<b> `Map(operands..., computation)` </b>
1839
1840| Arguments         | Type                   | Semantics                      |
1841| ----------------- | ---------------------- | ------------------------------ |
1842| `operands`        | sequence of N `XlaOp`s | N arrays of types T_0..T_{N-1} |
1843| `computation`     | `XlaComputation`       | computation of type `T_0, T_1, |
1844:                   :                        : ..., T_{N + M -1} -> S` with N :
1845:                   :                        : parameters of type T and M of  :
1846:                   :                        : arbitrary type                 :
1847| `dimensions`      | `int64` array          | array of map dimensions        |
1848
1849Applies a scalar function over the given `operands` arrays, producing an array
1850of the same dimensions where each element is the result of the mapped function
1851applied to the corresponding elements in the input arrays.
1852
1853The mapped function is an arbitrary computation with the restriction that it has
1854N inputs of scalar type `T` and a single output with type `S`. The output has
1855the same dimensions as the operands except that the element type T is replaced
1856with S.
1857
1858For example: `Map(op1, op2, op3, computation, par1)` maps `elem_out <-
1859computation(elem1, elem2, elem3, par1)` at each (multi-dimensional) index in the
1860input arrays to produce the output array.
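
A minimal NumPy model of the elementwise mapping (ignoring the extra static
operands shown above); `np.vectorize` is used purely for illustration and is not
how XLA executes `Map`:

```python
import numpy as np

def map_op(computation, *operands):
    # Apply the scalar computation at each index of the (same-shaped) operands.
    return np.vectorize(computation)(*operands)

a = np.array([1.0, 2.0, 3.0])
b = np.array([10.0, 20.0, 30.0])
map_op(lambda x, y: x * y + 1.0, a, b)   # -> [11.0, 41.0, 91.0]
```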
1861
1862## OptimizationBarrier
1863
1864Blocks any optimization pass from moving computations across the barrier.
1865
1866Ensures that all inputs are evaluated before any operators that depend on the
1867barrier's outputs.
1868
1869## Pad
1870
1871See also
1872[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1873
1874<b> `Pad(operand, padding_value, padding_config)` </b>
1875
1876| Arguments        | Type            | Semantics                               |
1877| ---------------- | --------------- | --------------------------------------- |
1878| `operand`        | `XlaOp`         | array of type `T`                       |
1879| `padding_value`  | `XlaOp`         | scalar of type `T` to fill in the added |
1880:                  :                 : padding                                 :
1881| `padding_config` | `PaddingConfig` | padding amount on both edges (low,      |
1882:                  :                 : high) and between the elements of each  :
1883:                  :                 : dimension                               :
1884
1885Expands the given `operand` array by padding around the array as well as between
1886the elements of the array with the given `padding_value`. `padding_config`
1887specifies the amount of edge padding and the interior padding for each
1888dimension.
1889
1890`PaddingConfig` is a repeated field of `PaddingConfigDimension`, which contains
1891three fields for each dimension: `edge_padding_low`, `edge_padding_high`, and
1892`interior_padding`.
1893
1894`edge_padding_low` and `edge_padding_high` specify the amount of padding added
1895at the low-end (next to index 0) and the high-end (next to the highest index) of
1896each dimension respectively. The amount of edge padding can be negative -- the
1897absolute value of negative padding indicates the number of elements to remove
1898from the specified dimension.
1899
1900`interior_padding` specifies the amount of padding added between any two
1901elements in each dimension; it may not be negative.  Interior padding occurs
1902logically before edge padding, so in the case of negative edge padding, elements
1903are removed from the interior-padded operand.
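
A 1-D NumPy sketch of this ordering (interior padding first, then edge padding,
with negative edge padding removing elements). The helper `pad_1d` is ours, not
an XLA API:

```python
import numpy as np

def pad_1d(operand, padding_value, edge_low, edge_high, interior):
    out = []
    # Interior padding first: `interior` copies of padding_value between elements.
    for i, x in enumerate(operand):
        if i > 0:
            out.extend([padding_value] * interior)
        out.append(x)
    # Edge padding; a negative amount removes elements instead.
    out = [padding_value] * edge_low + out if edge_low >= 0 else out[-edge_low:]
    out = out + [padding_value] * edge_high if edge_high >= 0 else out[:edge_high]
    return np.array(out)

pad_1d([1.0, 2.0, 3.0], 0.0, edge_low=1, edge_high=2, interior=1)
# -> [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0]
pad_1d([1.0, 2.0, 3.0], 0.0, edge_low=-1, edge_high=0, interior=0)
# -> [2.0, 3.0]
```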
1904
1905This operation is a no-op if the edge padding pairs are all (0, 0) and the
1906interior padding values are all 0. The figure below shows examples of different
1907`edge_padding` and `interior_padding` values for a two-dimensional array.
1908
1909<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1910  <img style="width:100%" src="./images/ops_pad.png">
1911</div>
1912
1913## Recv
1914
1915See also
1916[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1917
1918<b> `Recv(shape, channel_handle)` </b>
1919
1920| Arguments        | Type            | Semantics                            |
1921| ---------------- | --------------- | ------------------------------------ |
1922| `shape`          | `Shape`         | shape of the data to receive         |
1923| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
1924
1925Receives data of the given shape from a `Send` instruction in another
computation that shares the same channel handle. Returns an `XlaOp` for the
received data.
1928
The `Recv` operation in the client API represents synchronous communication.
1930However, the instruction is internally decomposed into 2 HLO instructions
1931(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also
1932[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
1933
1934<b>`Recv(const Shape& shape, int64 channel_id)`</b>
1935
1936Allocates resources required to receive data from a `Send` instruction with the
1937same channel_id. Returns a context for the allocated resources, which is used
1938by a following `RecvDone` instruction to wait for the completion of the data
1939transfer. The context is a tuple of {receive buffer (shape), request identifier
1940(U32)} and it can only be used by a `RecvDone` instruction.
1941
1942<b> `RecvDone(HloInstruction context)` </b>
1943
1944Given a context created by a `Recv` instruction, waits for the data transfer to
1945complete and returns the received data.
1946
1947## Reduce
1948
1949See also
1950[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1951
1952Applies a reduction function to one or more arrays in parallel.
1953
1954<b> `Reduce(operands..., init_values..., computation, dimensions)` </b>
1955
1956| Arguments     | Type                  | Semantics                        |
1957| ------------- | --------------------- | -------------------------------- |
1958| `operands`    | Sequence of N `XlaOp` | N arrays of types `T_0, ...,     |
1959:               :                       : T_{N-1}`.                        :
1960| `init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ...,    |
1961:               :                       : T_{N-1}`.                        :
1962| `computation` | `XlaComputation`      | computation of type `T_0, ...,   |
1963:               :                       : T_{N-1}, T_0, ..., T_{N-1} ->`   :
1964:               :                       : `Collate(T_0, ..., T_{N-1})`.    :
1965| `dimensions`  | `int64` array         | unordered array of dimensions to |
1966:               :                       : reduce.                          :
1967
1968Where:
1969
*   N is required to be greater than or equal to 1.
1971*   The computation has to be "roughly" associative (see below).
1972*   All input arrays must have the same dimensions.
1973*   All initial values have to form an identity under `computation`.
1974*   If `N = 1`, `Collate(T)` is `T`.
1975*   If `N > 1`, `Collate(T_0, ..., T_{N-1})` is a tuple of `N` elements of type
1976    `T`.
1977
1978This operation reduces one or more dimensions of each input array into scalars.
1979The rank of each returned array is `rank(operand) - len(dimensions)`. The output
of the op is `Collate(Q_0, ..., Q_{N-1})` where `Q_i` is an array of type `T_i`, the
1981dimensions of which are described below.
1982
1983Different backends are allowed to reassociate the reduction computation.  This
1984can lead to numerical differences, as some reduction functions like addition are
1985not associative for floats.
1986However, if the range of the data is limited, floating-point addition is close
1987enough to being associative for most practical uses.
1988
1989### Examples
1990
1991When reducing across one dimension in a single 1D array with values `[10, 11,
199212, 13]`, with reduction function `f` (this is `computation`) then that could be
1993computed as
1994
`f(10, f(11, f(12, f(init_value, 13))))`
1996
1997but there are also many other possibilities, e.g.
1998
1999`f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))`
2000
2001The following is a rough pseudo-code example of how reduction could be
2002implemented, using summation as the reduction computation with an initial value
2003of 0.
2004
2005```python
2006result_shape <- remove all dims in dimensions from operand_shape
2007
2008# Iterate over all elements in result_shape. The number of r's here is equal
2009# to the rank of the result
2010for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
2011  # Initialize this result element
2012  result[r0, r1...] <- 0
2013
2014  # Iterate over all the reduction dimensions
2015  for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
2016    # Increment the result element with the value of the operand's element.
2017    # The index of the operand's element is constructed from all ri's and di's
2018    # in the right order (by construction ri's and di's together index over the
2019    # whole operand shape).
2020    result[r0, r1...] += operand[ri... di]
2021```
2022
2023Here's an example of reducing a 2D array (matrix). The shape has rank 2,
2024dimension 0 of size 2 and dimension 1 of size 3:
2025
2026<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2027  <img style="width:35%" src="./images/ops_2d_matrix.png">
2028</div>
2029
2030Results of reducing dimensions 0 or 1 with an "add" function:
2031
2032<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2033  <img style="width:35%" src="./images/ops_reduce_from_2d_matrix.png">
2034</div>
2035
2036Note that both reduction results are 1D arrays. The diagram shows one as column
2037and another as row just for visual convenience.
2038
2039For a more complex example, here is a 3D array. Its rank is 3, dimension 0 of
2040size 4, dimension 1 of size 2 and dimension 2 of size 3. For simplicity, the
2041values 1 to 6 are replicated across dimension 0.
2042
2043<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2044  <img style="width:35%" src="./images/ops_reduce_from_3d_matrix.png">
2045</div>
2046
2047Similarly to the 2D example, we can reduce just one dimension. If we reduce
2048dimension 0, for example, we get a rank-2 array where all values across
2049dimension 0 were folded into a scalar:
2050
2051```text
2052|  4   8  12 |
2053| 16  20  24 |
2054```
2055
2056If we reduce dimension 2, we also get a rank-2 array where all values across
2057dimension 2 were folded into a scalar:
2058
2059```text
2060| 6  15 |
2061| 6  15 |
2062| 6  15 |
2063| 6  15 |
2064```
2065
2066Note that the relative order between the remaining dimensions in the input is
2067preserved in the output, but some dimensions may get assigned new numbers (since
2068the rank changes).
2069
2070We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces
2071the 1D array `[20, 28, 36]`.
2072
2073Reducing the 3D array over all its dimensions produces the scalar `84`.
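
These results can be checked with NumPy sums (shown only to make the numbers
above concrete; XLA's `Reduce` is more general since it takes an arbitrary
computation and init value):

```python
import numpy as np

# The 3-D example array: values 1..6 replicated across dimension 0.
a = np.tile(np.array([[1, 2, 3], [4, 5, 6]]), (4, 1, 1))   # shape (4, 2, 3)

a.sum(axis=0)        # [[ 4  8 12], [16 20 24]]             -- reduce dimension 0
a.sum(axis=2)        # [[ 6 15], [ 6 15], [ 6 15], [ 6 15]] -- reduce dimension 2
a.sum(axis=(0, 1))   # [20 28 36]                           -- reduce dimensions 0 and 1
a.sum()              # 84                                   -- reduce all dimensions
```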
2074
2075### Variadic Reduce
2076
2077When `N > 1`, reduce function application is slightly more complex, as it is
2078applied simultaneously to all inputs. The operands are supplied to the
2079computation in the following order:
2080
2081*   Running reduced value for the first operand
2082*   ...
2083*   Running reduced value for the N'th operand
2084*   Input value for the first operand
2085*   ...
2086*   Input value for the N'th operand
2087
2088For example, consider the following reduction function, which can be used to
2089compute the max and the argmax of a 1-D array in parallel:
2090
2091```python
2092f: (Float, Int, Float, Int) -> Float, Int
2093f(max, argmax, value, index):
2094  if value >= max:
2095    return (value, index)
2096  else:
2097    return (max, argmax)
2098```
2099
For 1-D input arrays `V = Float[N], K = Int[N]`, and init values
`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only
2102input dimension is equivalent to the following recursive application:
2103
2104```
2105f_0 = f(I_V, I_K, V_0, K_0)
2106f_1 = f(f_0.first, f_0.second, V_1, K_1)
2107...
2108f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
2109```
2110
2111Applying this reduction to an array of values, and an array of sequential
2112indices (i.e. iota), will co-iterate over the arrays, and return a tuple
2113containing the maximal value and the matching index.
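
A Python sketch of this recursion, with an assumed init value of `(-inf, -1)`
(the init values, like the helper name, are illustrative):

```python
import numpy as np

def max_and_argmax(values, indices, init=(-np.inf, -1)):
    # Sequentially apply the reduction function f defined above.
    max_v, argmax = init
    for v, k in zip(values, indices):
        if v >= max_v:
            max_v, argmax = v, k
    return max_v, argmax

v = np.array([3.0, 7.0, 1.0, 7.0])
max_and_argmax(v, np.arange(len(v)))   # -> (7.0, 3); ties keep the later index
```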
2114
2115## ReducePrecision
2116
2117See also
2118[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2119
2120Models the effect of converting floating-point values to a lower-precision
2121format (such as IEEE-FP16) and back to the original format.  The number of
2122exponent and mantissa bits in the lower-precision format can be specified
2123arbitrarily, although all bit sizes may not be supported on all hardware
2124implementations.
2125
2126<b> `ReducePrecision(operand, mantissa_bits, exponent_bits)` </b>
2127
2128Arguments       | Type    | Semantics
2129--------------- | ------- | -------------------------------------------------
2130`operand`       | `XlaOp` | array of floating-point type `T`.
2131`exponent_bits` | `int32` | number of exponent bits in lower-precision format
2132`mantissa_bits` | `int32` | number of mantissa bits in lower-precision format
2133
2134The result is an array of type `T`.  The input values are rounded to the nearest
2135value representable with the given number of mantissa bits (using "ties to even"
2136semantics), and any values that exceed the range specified by the number of
2137exponent bits are clamped to positive or negative infinity.  `NaN` values are
2138retained, although they may be converted to canonical `NaN` values.
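
For example, `ReducePrecision(operand, mantissa_bits=10, exponent_bits=5)`
corresponds to rounding through IEEE half precision and back, which can be
modeled in NumPy (a rounding/overflow illustration only; exact NaN bit patterns
may differ):

```python
import numpy as np

x = np.array([1.0000001, 65504.0, 1.0e6, np.nan], dtype=np.float32)
# Round-trip through 10 mantissa / 5 exponent bits (IEEE half precision).
y = x.astype(np.float16).astype(np.float32)
# y == [1.0, 65504.0, inf, nan]  (1e6 overflows the 5-bit exponent range)
```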
2139
2140The lower-precision format must have at least one exponent bit (in order to
2141distinguish a zero value from an infinity, since both have a zero mantissa), and
2142must have a non-negative number of mantissa bits.  The number of exponent or
2143mantissa bits may exceed the corresponding value for type `T`; the corresponding
2144portion of the conversion is then simply a no-op.
2145
2146## ReduceScatter
2147
2148See also
2149[`XlaBuilder::ReduceScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2150
2151ReduceScatter is a collective operation that effectively does an AllReduce and
2152then scatters the result by splitting it into `shard_count` blocks along the
`scatter_dimension`; replica `i` in the replica group receives the `i`th
shard.
2155
2156<b> `ReduceScatter(operand, computation, scatter_dim, shard_count,
2157replica_group_ids, channel_id)` </b>
2158
2159| Arguments           | Type                 | Semantics                     |
2160| ------------------- | -------------------- | ----------------------------- |
2161| `operand`           | `XlaOp`              | Array or a non-empty tuple of |
2162:                     :                      : arrays to reduce across       :
2163:                     :                      : replicas.                     :
2164| `computation`       | `XlaComputation`     | Reduction computation         |
2165| `scatter_dimension` | `int64`              | Dimension to scatter.         |
2166| `shard_count`       | `int64`              | Number of blocks to split     |
2167:                     :                      : `scatter_dimension`           :
2168| `replica_groups`    | vector of vectors of | Groups between which the      |
2169:                     : `int64`              : reductions are performed      :
2170| `channel_id`        | optional `int64`     | Optional channel ID for       |
2171:                     :                      : cross-module communication    :
2172
2173-   When `operand` is a tuple of arrays, the reduce-scatter is performed on each
2174    element of the tuple.
2175-   `replica_groups` is a list of replica groups between which the reduction is
2176    performed (replica id for the current replica can be retrieved using
2177    [`ReplicaId`](#replicaid)). The order of replicas in each group determines
2178    the order in which the all-reduce result will be scattered. `replica_groups`
2179    must either be empty (in which case all replicas belong to a single group),
    or contain the same number of elements as the number of replicas. When there
    is more than one replica group, they must all be of the same size. For
2182    example, `replica_groups = {0, 2}, {1, 3}` performs reduction between the
2183    replicas `0` and `2`, and `1` and `3` and then scatters the result.
2184-   `shard_count` is the size of each replica group. We need this in cases where
2185    `replica_groups` are empty. If `replica_groups` is not empty, `shard_count`
2186    must be equal to the size of each replica group.
2187-   `channel_id` is used for cross-module communication: only `reduce-scatter`
2188    operations with the same `channel_id` can communicate with each other.
2189
2190The output shape is the input shape with the `scatter_dimension` made
2191`shard_count` times smaller. For example, if there are two replicas and the
2192operand has the value `[1.0, 2.25]` and `[3.0, 5.25]` respectively on the two
2193replicas, then the output value from this op where `scatter_dim` is `0` will be
2194`[4.0]` for the first replica and `[7.5]` for the second replica.
2195
2196## ReduceWindow
2197
2198See also
2199[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2200
2201Applies a reduction function to all elements in each window of a sequence of N
2202multi-dimensional arrays, producing a single or a tuple of N multi-dimensional
2203arrays as output. Each output array has the same number of elements as the
2204number of valid positions of the window. A pooling layer can be expressed as a
2205`ReduceWindow`. Similar to [`Reduce`](#reduce), the applied `computation` is
2206always passed the `init_values` on the left-hand side.
2207
2208<b> `ReduceWindow(operands..., init_values..., computation, window_dimensions,
2209window_strides, padding)` </b>
2210
2211| Arguments           | Type                | Semantics                        |
2212| ------------------- | ------------------- | -------------------------------- |
2213| `operands`          | `N XlaOps`          | A sequence of N                  |
2214:                     :                     : multi-dimensional arrays of      :
2215:                     :                     : types `T_0,..., T_{N-1}`, each   :
2216:                     :                     : representing the base area on    :
2217:                     :                     : which the window is placed.      :
2218| `init_values`       | `N XlaOps`          | The N starting values for the    |
2219:                     :                     : reduction, one for each of the N :
2220:                     :                     : operands. See [Reduce](#reduce)  :
2221:                     :                     : for details.                     :
2222| `computation`       | `XlaComputation`    | Reduction function of type `T_0, |
2223:                     :                     : ..., T_{N-1}, T_0, ..., T_{N-1}  :
2224:                     :                     : -> Collate(T_0, ..., T_{N-1})`,  :
2225:                     :                     : to apply to elements in each     :
2226:                     :                     : window of all the input          :
2227:                     :                     : operands.                        :
2228| `window_dimensions` | `ArraySlice<int64>` | array of integers for window     |
2229:                     :                     : dimension values                 :
2230| `window_strides`    | `ArraySlice<int64>` | array of integers for window     |
2231:                     :                     : stride values                    :
2232| `base_dilations`    | `ArraySlice<int64>` | array of integers for base       |
2233:                     :                     : dilation values                  :
2234| `window_dilations`  | `ArraySlice<int64>` | array of integers for window     |
2235:                     :                     : dilation values                  :
2236| `padding`           | `Padding`           | padding type for window          |
2237:                     :                     : (Padding\:\:kSame, which pads so :
2238:                     :                     : as to have the same output shape :
2239:                     :                     : as input if the stride is 1, or  :
2240:                     :                     : Padding\:\:kValid, which uses no :
2241:                     :                     : padding and "stops" the window   :
2242:                     :                     : once it no longer fits)          :
2243
2244Where:
2245
*   N is required to be greater than or equal to 1.
2247*   All input arrays must have the same dimensions.
2248*   If `N = 1`, `Collate(T)` is `T`.
2249*   If `N > 1`, `Collate(T_0, ..., T_{N-1})` is a tuple of `N` elements of type
2250    `(T0,...T{N-1})`.
2251
The code and figure below show an example of using `ReduceWindow`. The input is
a matrix of size [4x6], and both `window_dimensions` and
`window_stride_dimensions` are [2x3].
2255
2256```
// Create a computation for the reduction (maximum).
XlaComputation max;
{
  XlaBuilder builder(client_, "max");
  auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
  auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
  builder.Max(y, x);
  max = builder.Build().value();
}

// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
    input,
    /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
    max,
    /*window_dimensions=*/{2, 3},
    /*window_stride_dimensions=*/{2, 3},
    Padding::kValid);
2278```
2279
2280<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2281  <img style="width:35%" src="./images/ops_reduce_window.png">
2282</div>
2283
2284Stride of 1 in a dimension specifies that the position of a window in the
2285dimension is 1 element away from its adjacent window. In order to specify that
2286no windows overlap with each other, window_stride_dimensions should be equal to
2287window_dimensions. The figure below illustrates the use of two different stride
2288values. Padding is applied to each dimension of the input and the calculations
2289are the same as though the input came in with the dimensions it has after
2290padding.
2291
2292<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2293  <img style="width:75%" src="./images/ops_reduce_window_stride.png">
2294</div>
2295
2296For a non-trivial padding example, consider computing reduce-window minimum
2297(initial value is `MAX_FLOAT`) with dimension `3` and stride `2` over the input
2298array `[10000, 1000, 100, 10, 1]`. Padding `kValid` computes minimums over two
2299_valid_ windows: `[10000, 1000, 100]` and `[100, 10, 1]`, resulting in the
2300output `[100, 1]`. Padding `kSame` first pads the array so that the shape after
2301the reduce-window would be the _same_ as input for stride one by adding initial
2302elements on both sides, getting `[MAX_VALUE, 10000, 1000, 100, 10, 1,
2303MAX_VALUE]`. Running reduce-window over the padded array operates on three
2304windows `[MAX_VALUE, 10000, 1000]`, `[1000, 100, 10]`, `[10, 1, MAX_VALUE]`, and
2305yields `[1000, 10, 1]`.
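
A 1-D Python sketch reproducing that example (the `same_padding` rule here is
specialized to this particular window and stride, so treat it as an illustration
rather than XLA's general padding computation):

```python
def reduce_window_1d(x, init, window, stride, computation, same_padding=False):
    x = list(x)
    if same_padding:
        # Pad with the init value so that, for stride 1, the output would have
        # the input's length (here: one element on each side for window 3).
        total = window - 1
        x = [init] * (total // 2) + x + [init] * (total - total // 2)
    out = []
    for start in range(0, len(x) - window + 1, stride):
        acc = init
        for v in x[start:start + window]:
            acc = computation(acc, v)
        out.append(acc)
    return out

a = [10000, 1000, 100, 10, 1]
reduce_window_1d(a, float('inf'), 3, 2, min)                     # [100, 1]
reduce_window_1d(a, float('inf'), 3, 2, min, same_padding=True)  # [1000, 10, 1]
```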
2306
2307The evaluation order of the reduction function is arbitrary and may be
2308non-deterministic. Therefore, the reduction function should not be overly
2309sensitive to reassociation. See the discussion about associativity in the
2310context of [`Reduce`](#reduce) for more details.
2311
2312## ReplicaId
2313
2314See also
2315[`XlaBuilder::ReplicaId`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2316
2317Returns the unique ID (U32 scalar) of the replica.
2318
2319<b> `ReplicaId()` </b>
2320
2321The unique ID of each replica is an unsigned integer in the interval `[0, N)`,
2322where `N` is the number of replicas. Since all the replicas are running the same
2323program, a `ReplicaId()` call in the program will return a different value on
2324each replica.
2325
2326## Reshape
2327
2328See also
2329[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
2330and the [`Collapse`](#collapse) operation.
2331
2332Reshapes the dimensions of an array into a new configuration.
2333
2334<b> `Reshape(operand, new_sizes)` </b>
2335<b> `Reshape(operand, dimensions, new_sizes)` </b>
2336
2337Arguments    | Type           | Semantics
2338------------ | -------------- | ---------------------------------------
2339`operand`    | `XlaOp`        | array of type T
2340`dimensions` | `int64` vector | order in which dimensions are collapsed
2341`new_sizes`  | `int64` vector | vector of sizes of new dimensions
2342
2343Conceptually, reshape first flattens an array into a one-dimensional vector of
2344data values, and then refines this vector into a new shape. The input arguments
2345are an arbitrary array of type T, a compile-time-constant vector of dimension
2346indices, and a compile-time-constant vector of dimension sizes for the result.
The values in the `dimensions` vector, if given, must be a permutation of all of
2348T's dimensions; the default if not given is `{0, ..., rank - 1}`. The order of
2349the dimensions in `dimensions` is from slowest-varying dimension (most major) to
2350fastest-varying dimension (most minor) in the loop nest which collapses the
2351input array into a single dimension. The `new_sizes` vector determines the size
2352of the output array. The value at index 0 in `new_sizes` is the size of
2353dimension 0, the value at index 1 is the size of dimension 1, and so on. The
2354product of the `new_size` dimensions must equal the product of the operand's
2355dimension sizes. When refining the collapsed array into the multidimensional
2356array defined by `new_sizes`, the dimensions in `new_sizes` are ordered from
2357slowest varying (most major) and to fastest varying (most minor).
2358
2359For example, let v be an array of 24 elements:
2360
2361```
2362let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
2363                    {{20, 21, 22}, {25, 26, 27}},
2364                    {{30, 31, 32}, {35, 36, 37}},
2365                    {{40, 41, 42}, {45, 46, 47}}};
2366
2367In-order collapse:
2368let v012_24 = Reshape(v, {0,1,2}, {24});
2369then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
2370                         30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};
2371
2372let v012_83 = Reshape(v, {0,1,2}, {8,3});
2373then v012_83 == f32[8x3] {{10, 11, 12}, {15, 16, 17},
2374                          {20, 21, 22}, {25, 26, 27},
2375                          {30, 31, 32}, {35, 36, 37},
2376                          {40, 41, 42}, {45, 46, 47}};
2377
2378Out-of-order collapse:
2379let v021_24 = Reshape(v, {1,2,0}, {24});
2380then v021_24 == f32[24]  {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
2381                          15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};
2382
2383let v021_83 = Reshape(v, {1,2,0}, {8,3});
2384then v021_83 == f32[8x3] {{10, 20, 30}, {40, 11, 21},
2385                          {31, 41, 12}, {22, 32, 42},
2386                          {15, 25, 35}, {45, 16, 26},
2387                          {36, 46, 17}, {27, 37, 47}};
2388
2389
2390let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
2391then v021_262 == f32[2x6x2] {{{10, 20}, {30, 40},
2392                              {11, 21}, {31, 41},
2393                              {12, 22}, {32, 42}},
2394                             {{15, 25}, {35, 45},
2395                              {16, 26}, {36, 46},
2396                              {17, 27}, {37, 47}}};
2397```
2398
2399As a special case, reshape can transform a single-element array to a scalar and
2400vice versa. For example,
2401
2402```
2403Reshape(f32[1x1] {{5}}, {0,1}, {}) == 5;
2404Reshape(5, {}, {1,1}) == f32[1x1] {{5}};
2405```
2406
2407## Rev (reverse)
2408
2409See also
2410[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2411
2412<b>`Rev(operand, dimensions)`</b>
2413
2414Arguments    | Type                | Semantics
2415------------ | ------------------- | ---------------------
2416`operand`    | `XlaOp`             | array of type T
2417`dimensions` | `ArraySlice<int64>` | dimensions to reverse
2418
2419Reverses the order of elements in the `operand` array along the specified
2420`dimensions`, generating an output array of the same shape. Each element of the
2421operand array at a multidimensional index is stored into the output array at a
2422transformed index. The multidimensional index is transformed by reversing the
2423index in each dimension to be reversed (i.e., if a dimension of size N is one of
2424the reversing dimensions, its index i is transformed into N - 1 - i).
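
For example (a sketch in the notation used by the other examples in this
document):

```
let m = f32[2x3] {{1, 2, 3},
                  {4, 5, 6}};

Rev(m, {1}) == f32[2x3] {{3, 2, 1},
                         {6, 5, 4}};

Rev(m, {0, 1}) == f32[2x3] {{6, 5, 4},
                            {3, 2, 1}};
```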
2425
2426One use for the `Rev` operation is to reverse the convolution weight array along
2427the two window dimensions during the gradient computation in neural networks.
2428
2429## RngNormal
2430
2431See also
2432[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2433
2434Constructs an output of a given shape with random numbers generated following
2435the $$N(\mu, \sigma)$$ normal distribution. The parameters $$\mu$$ and
2436$$\sigma$$ and the output shape must have a floating point element type. The
2437parameters furthermore have to be scalar valued.
2438
2439<b>`RngNormal(mu, sigma, shape)`</b>
2440
2441| Arguments | Type    | Semantics                                           |
2442| --------- | ------- | --------------------------------------------------- |
2443| `mu`      | `XlaOp` | Scalar of type T specifying mean of generated       |
2444:           :         : numbers                                             :
2445| `sigma`   | `XlaOp` | Scalar of type T specifying standard deviation of   |
2446:           :         : generated numbers                                   :
2447| `shape`   | `Shape` | Output shape of type T                              |
2448
2449## RngUniform
2450
2451See also
2452[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2453
2454Constructs an output of a given shape with random numbers generated following
2455the uniform distribution over the interval $$[a,b)$$. The parameters and output
2456element type have to be a boolean type, an integral type or a floating point
2457type, and the types have to be consistent. The CPU and GPU backends currently
2458only support F64, F32, F16, BF16, S64, U64, S32 and U32. Furthermore, the
2459parameters need to be scalar valued. If $$b \leq a$$ the result is
2460implementation-defined.
2461
2462<b>`RngUniform(a, b, shape)`</b>
2463
2464| Arguments | Type                    | Semantics                         |
2465| --------- | ----------------------- | --------------------------------- |
2466| `a`       | `XlaOp`                 | Scalar of type T specifying lower |
2467:           :                         : limit of interval                 :
2468| `b`       | `XlaOp`                 | Scalar of type T specifying upper |
2469:           :                         : limit of interval                 :
2470| `shape`   | `Shape`                 | Output shape of type T            |
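
As an illustrative sketch in the notation used by the other examples in this
document (the concrete values drawn are hypothetical; only the shape, element
type and half-open interval are guaranteed):

```
let a: f32 = 0.0;
let b: f32 = 1.0;

RngUniform(a, b, f32[2x2])
// might produce, for example:
//   f32[2x2] {{0.12, 0.87},
//             {0.55, 0.03}}
// Each element is drawn independently and uniformly from [0.0, 1.0).
```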
2471
2472## RngBitGenerator
2473
2474Generates an output with a given shape filled with uniform random bits using the
2475specified algorithm (or backend default) and returns an updated state (with the
2476same shape as initial state) and the generated random data.
2477
2478`initial_state` is the initial state of the current random number generation.
2479Its required shape and valid values depend on the algorithm used.
2480
2481The output is guaranteed to be a deterministic function of the initial state but
2482it is *not* guaranteed to be deterministic between backends and different
2483compiler versions.
2484
2485<b>`RngBitGenerator(algorithm, initial_state, shape)`</b>
2486
2487Arguments       | Type              | Semantics
2488--------------- | ----------------- | -------------------------------------
2489`algorithm`     | `RandomAlgorithm` | PRNG algorithm to be used.
2490`initial_state` | `XlaOp`           | Initial state for the PRNG algorithm.
2491`shape`         | `Shape`           | Output shape for generated data.
2492
2493Available values for `algorithm`:
2494
2495-   `rng_default`: Backend specific algorithm with backend specific shape
2496    requirements.
2497
2498-   `rng_three_fry`: ThreeFry counter-based PRNG algorithm. The `initial_state`
2499    shape is `u64[2]` with arbitrary values.
2500    [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
2501
2502-   `rng_philox`: Philox algorithm to generate random numbers in parallel. The
2503    `initial_state` shape is `u64[3]` with arbitrary values.
2504    [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
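
As a sketch of the intended usage (pseudocode in the style of the other
examples in this document; the tuple-destructuring syntax is assumed for
illustration only):

```
// Generate four random u32 values with the ThreeFry algorithm.
let state = u64[2] {42, 0};   // arbitrary initial state of shape u64[2]
let (new_state, bits) = RngBitGenerator(rng_three_fry, state, u32[4]);
// new_state has shape u64[2]; bits has shape u32[4], filled with random bits.
```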
2505
2506## Scatter
2507
2508The XLA scatter operation generates a sequence of results which are the values
2509of the input array `operands`, with several slices (at indices specified by
2510`scatter_indices`) updated with the sequence of values in `updates` using
2511`update_computation`.
2512
2513See also
2514[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2515
2516<b> `scatter(operands..., scatter_indices, updates..., update_computation,
2517index_vector_dim, update_window_dims, inserted_window_dims,
2518scatter_dims_to_operand_dims)` </b>
2519
2520Arguments                      | Type                  | Semantics
2521------------------------------ | --------------------- | ---------
2522`operands`                     | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N` to be scattered into.
2523`scatter_indices`              | `XlaOp`               | Array containing the starting indices of the slices that must be scattered to.
2524`updates`                      | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`. `updates[i]` contains the values that must be used for scattering `operands[i]`.
2525`update_computation`           | `XlaComputation`      | Computation to be used for combining the existing values in the input array and the updates during scatter. This computation should be of type `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)`.
2526`index_vector_dim`             | `int64`               | The dimension in `scatter_indices` that contains the starting indices.
2527`update_window_dims`           | `ArraySlice<int64>`   | The set of dimensions in `updates` shape that are *window dimensions*.
2528`inserted_window_dims`         | `ArraySlice<int64>`   | The set of *window dimensions* that must be inserted into `updates` shape.
2529`scatter_dims_to_operand_dims` | `ArraySlice<int64>`   | A dimensions map from the scatter indices to the operand index space. This array is interpreted as mapping `i` to `scatter_dims_to_operand_dims[i]` . It has to be one-to-one and total.
2530`indices_are_sorted`           | `bool`                | Whether the indices are guaranteed to be sorted by the caller.
2531
2532Where:
2533
2534* N is required to be greater than or equal to 1.
2535* `operands`[`0`], ..., `operands`[`N-1`] must all have the same dimensions.
2536* `updates`[`0`], ..., `updates`[`N-1`] must all have the same dimensions.
2537* If `N = 1`, `Collate(T)` is `T`.
2538* If `N > 1`, `Collate(T_0, ..., T_N)` is a tuple of `N` elements of type `T`.
2539
2540If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider
2541`scatter_indices` to have a trailing `1` dimension.
2542
2543We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of
2544dimensions in `updates` shape that are not in `update_window_dims`, in ascending
2545order.
2546
2547The arguments of scatter should follow these constraints:
2548
2549-   Each `updates` array must be of rank `update_window_dims.size +
2550    scatter_indices.rank - 1`.
2551
2552-   Bounds of dimension `i` in each `updates` array must conform to the
2553    following:
2554
2555    -   If `i` is present in `update_window_dims` (i.e. equal to
2556        `update_window_dims`[`k`] for some `k`), then the bound of dimension `i`
2557        in `updates` must not exceed the corresponding bound of `operand` after
2558        accounting for the `inserted_window_dims` (i.e.
2559        `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains
2560        the bounds of `operand` with the bounds at indices
2561        `inserted_window_dims` removed).
2562    -   If `i` is present in `update_scatter_dims` (i.e. equal to
2563        `update_scatter_dims`[`k`] for some `k`), then the bound of dimension
2564        `i` in `updates` must be equal to the corresponding bound of
2565        `scatter_indices`, skipping `index_vector_dim` (i.e.
2566        `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and
2567        `scatter_indices.shape.dims`[`k+1`] otherwise).
2568
2569-   `update_window_dims` must be in ascending order, not have any repeating
2570    dimension numbers, and be in the range `[0, updates.rank)`.
2571
2572-   `inserted_window_dims` must be in ascending order, not have any repeating
2573    dimension numbers, and be in the range `[0, operand.rank)`.
2574
2575-   `operand.rank` must equal the sum of `update_window_dims.size` and
2576    `inserted_window_dims.size`.
2577
2578-   `scatter_dims_to_operand_dims.size` must be equal to
2579    `scatter_indices.shape.dims`[`index_vector_dim`], and its values must be in
2580    the range `[0, operand.rank)`.
2581
2582For a given index `U` in each `updates` array, the corresponding index `I` in
2583the corresponding `operands` array into which this update has to be applied is
2584computed as follows:
2585
25861.  Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up
2587    an index vector `S` in the `scatter_indices` array such that `S`[`i`] =
2588    `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at
2589    position `index_vector_dim` into A.
25902.  Create an index `S`<sub>`in`</sub> into `operand` using `S` by scattering
2591    `S` using the `scatter_dims_to_operand_dims` map. More formally:
2592    1.  `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if
2593        `k` < `scatter_dims_to_operand_dims.size`.
2594    2.  `S`<sub>`in`</sub>[`_`] = `0` otherwise.
25953.  Create an index `W`<sub>`in`</sub> into each `operands` array by scattering
2596    the indices at `update_window_dims` in `U` according to
2597    `inserted_window_dims`. More formally:
2598    1.  `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if `k`
2599        is in `update_window_dims`, where `window_dims_to_operand_dims` is the
2600        monotonic function with domain [`0`, `update_window_dims.size`) and
2601        range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For example, if
2602        `update_window_dims.size` is `4`, `operand.rank` is `6`, and
2603        `inserted_window_dims` is {`0`, `2`} then `window_dims_to_operand_dims`
2604        is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}).
2605    2.  `W`<sub>`in`</sub>[`_`] = `0` otherwise.
26064.  `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
2607    addition.
2608
2609In summary, the scatter operation can be defined as follows.
2610
2611-   Initialize `output` with `operands`, i.e. for all indices `J`, for all
2612    indices `O` in the `operands`[`J`] array: \
2613    `output`[`J`][`O`] = `operands`[`J`][`O`]
2614-   For every index `U` in the `updates`[`J`] array and the corresponding index
2615    `O` in the `operands`[`J`] array, if `O` is a valid index for `output`: \
2616    (`output`[`0`][`O`], ..., `output`[`N-1`][`O`]) =
2617    `update_computation`(`output`[`0`][`O`], ..., `output`[`N-1`][`O`],
2618    `updates`[`0`][`U`], ..., `updates`[`N-1`][`U`])
2619
2620The order in which updates are applied is non-deterministic. So, when multiple
2621indices in `updates` refer to the same index in `operands`, the corresponding
2622value in `output` will be non-deterministic.
2623
2624Note that the first parameter that is passed into the `update_computation` will
2625always be the current value from the `output` array and the second parameter
2626will always be the value from the `updates` array. This is important
2627specifically for cases when the `update_computation` is _not commutative_.
2628
2629If `indices_are_sorted` is set to true then XLA can assume that
2630`scatter_indices` are sorted (in ascending `scatter_dims_to_operand_dims` order)
2631by the user. If they are not then the semantics are implementation-defined.
2632
2633Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e.
2634the scatter op updates the elements in the input that are extracted by the
2635corresponding gather op.
2636
2637For a detailed informal description and examples, refer to the
2638"Informal Description" section under `Gather`.
2639
2640## Select
2641
2642See also
2643[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2644
2645Constructs an output array from elements of two input arrays, based on the
2646values of a predicate array.
2647
2648<b> `Select(pred, on_true, on_false)` </b>
2649
2650Arguments  | Type    | Semantics
2651---------- | ------- | ------------------
2652`pred`     | `XlaOp` | array of type PRED
2653`on_true`  | `XlaOp` | array of type T
2654`on_false` | `XlaOp` | array of type T
2655
2656The arrays `on_true` and `on_false` must have the same shape. This is also the
2657shape of the output array. The array `pred` must have the same dimensionality as
2658`on_true` and `on_false`, with the `PRED` element type.
2659
2660For each element `P` of `pred`, the corresponding element of the output array is
2661taken from `on_true` if the value of `P` is `true`, and from `on_false` if the
2662value of `P` is `false`. As a restricted form of [broadcasting](broadcasting.md),
2663`pred` can be a scalar of type `PRED`. In this case, the output array is taken
2664wholly from `on_true` if `pred` is `true`, and from `on_false` if `pred` is `false`.
2665
2666Example with non-scalar `pred`:
2667
2668```
2669let pred: PRED[4] = {true, false, false, true};
2670let v1: s32[4] = {1, 2, 3, 4};
2671let v2: s32[4] = {100, 200, 300, 400};
2672==>
2673Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};
2674```
2675
2676Example with scalar `pred`:
2677
2678```
2679let pred: PRED = true;
2680let v1: s32[4] = {1, 2, 3, 4};
2681let v2: s32[4] = {100, 200, 300, 400};
2682==>
2683Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};
2684```
2685
2686Selections between tuples are supported. Tuples are considered to be scalar
2687types for this purpose. If `on_true` and `on_false` are tuples (which must have
2688the same shape!) then `pred` has to be a scalar of type `PRED`.
2689
2690## SelectAndScatter
2691
2692See also
2693[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2694
2695This operation can be considered as a composite operation that first computes
2696`ReduceWindow` on the `operand` array to select an element from each window, and
2697then scatters the `source` array to the indices of the selected elements to
2698construct an output array with the same shape as the operand array. The binary
2699`select` function selects an element from each window by being applied pairwise
2700across the window's elements; it is called with the property that the first
2701parameter's index vector is lexicographically less than the second parameter's
2702index vector. The `select` function returns `true` if the first parameter is
2703selected and returns `false` if the second parameter is selected. The function
2704must be transitive (i.e., if `select(a, b)` and `select(b, c)` are `true`, then
2705`select(a, c)` is also `true`) so that the selected element does not depend on
2706the order in which the elements of a given window are traversed.
2707
2708The function `scatter` is applied at each selected index in the output array. It
2709takes two scalar parameters:
2710
27111.  Current value at the selected index in the output array
27122.  The scatter value from `source` that applies to the selected index
2713
2714It combines the two parameters and returns a scalar value that's used to update
2715the value at the selected index in the output array. Initially, all indices of
2716the output array are set to `init_value`.
2717
2718The output array has the same shape as the `operand` array and the `source`
2719array must have the same shape as the result of applying a `ReduceWindow`
2720operation on the `operand` array. `SelectAndScatter` can be used to
2721backpropagate the gradient values for a pooling layer in a neural network.
2722
2723<b>`SelectAndScatter(operand, select, window_dimensions, window_strides,
2724padding, source, init_value, scatter)`</b>
2725
2726| Arguments           | Type                | Semantics                        |
2727| ------------------- | ------------------- | -------------------------------- |
2728| `operand`           | `XlaOp`             | array of type T over which the   |
2729:                     :                     : windows slide                    :
2730| `select`            | `XlaComputation`    | binary computation of type `T, T |
2731:                     :                     : -> PRED`, to apply to all        :
2732:                     :                     : elements in each window; returns :
2733:                     :                     : `true` if the first parameter is :
2734:                     :                     : selected and returns `false` if  :
2735:                     :                     : the second parameter is selected :
2736| `window_dimensions` | `ArraySlice<int64>` | array of integers for window     |
2737:                     :                     : dimension values                 :
2738| `window_strides`    | `ArraySlice<int64>` | array of integers for window     |
2739:                     :                     : stride values                    :
2740| `padding`           | `Padding`           | padding type for window          |
2741:                     :                     : (Padding\:\:kSame or             :
2742:                     :                     : Padding\:\:kValid)               :
2743| `source`            | `XlaOp`             | array of type T with the values  |
2744:                     :                     : to scatter                       :
2745| `init_value`        | `XlaOp`             | scalar value of type T for the   |
2746:                     :                     : initial value of the output      :
2747:                     :                     : array                            :
2748| `scatter`           | `XlaComputation`    | binary computation of type `T, T |
2749:                     :                     : -> T`, to apply each scatter     :
2750:                     :                     : source element with its          :
2751:                     :                     : destination element              :
2752
2753The figure below shows examples of using `SelectAndScatter`, with the `select`
2754function computing the maximal value among its parameters. Note that when the
2755windows overlap, as in the figure (2) below, an index of the `operand` array may
2756be selected multiple times by different windows. In the figure, the element of
2757value 9 is selected by both of the top windows (blue and red) and the binary
2758addition `scatter` function produces the output element of value 8 (2 + 6).
2759
2760<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2761  <img style="width:100%"
2762    src="./images/ops_scatter_to_selected_window_element.png">
2763</div>
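
A one-dimensional sketch of these semantics (in the pseudocode notation used by
the other examples in this document; `GE` stands for a greater-or-equal
comparison computation, so the largest element of each window is selected, and
`Add` is a scalar addition computation; both names are assumptions for
illustration):

```
let operand = f32[5] {1, 5, 3, 6, 4};
let source  = f32[2] {10, 20};

SelectAndScatter(operand, GE, /*window_dimensions=*/{3},
                 /*window_strides=*/{2}, Padding::kValid,
                 source, /*init_value=*/0, Add)
  == f32[5] {0, 10, 0, 20, 0}
// Window [1, 5, 3] selects index 1 and window [3, 6, 4] selects index 3;
// each selected index receives the corresponding `source` value added to
// the init_value of 0.
```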
2764
2765The evaluation order of the `scatter` function is arbitrary and may be
2766non-deterministic. Therefore, the `scatter` function should not be overly
2767sensitive to reassociation. See the discussion about associativity in the
2768context of [`Reduce`](#reduce) for more details.
2769
2770## Send
2771
2772See also
2773[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2774
2775<b> `Send(operand, channel_handle)` </b>
2776
2777Arguments        | Type            | Semantics
2778---------------- | --------------- | -----------------------------------------
2779`operand`        | `XlaOp`         | data to send (array of type T)
2780`channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair
2781
2782Sends the given operand data to a `Recv` instruction in another computation
2783that shares the same channel handle. Does not return any data.
2784
2785Similar to the `Recv` operation, the client API of the `Send` operation represents
2786synchronous communication, and is internally decomposed into 2 HLO instructions
2787(`Send` and `SendDone`) to enable asynchronous data transfers. See also
2788[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
2789
2790<b>`Send(HloInstruction operand, int64 channel_id)`</b>
2791
2792Initiates an asynchronous transfer of the operand to the resources allocated by
2793the `Recv` instruction with the same channel id. Returns a context, which is
2794used by a following `SendDone` instruction to wait for the completion of the
2795data transfer. The context is a tuple of {operand (shape), request identifier
2796(U32)} and it can only be used by a `SendDone` instruction.
2797
2798<b> `SendDone(HloInstruction context)` </b>
2799
2800Given a context created by a `Send` instruction, waits for the data transfer to
2801complete.  The instruction does not return any data.
2802
2803<b> Scheduling of channel instructions </b>
2804
2805The execution order of the 4 instructions for each channel (`Recv`, `RecvDone`,
2806`Send`, `SendDone`) is as below.
2807
2808<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2809  <img style="width:70%" src="./images/send_recv_order.png">
2810</div>
2811
2812* `Recv` happens before `Send`
2813* `Send` happens before `RecvDone`
2814* `Recv` happens before `RecvDone`
2815* `Send` happens before `SendDone`
2816
2817When the backend compilers generate a linear schedule for each computation that
2818communicates via channel instructions, there must not be cycles across the
2819computations. For example, the schedules below lead to deadlocks.
2820
2821<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2822  <img style="width:100%" src="./images/send_recv_schedule.png">
2823</div>
2824
2825## Slice
2826
2827See also
2828[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2829
2830Slicing extracts a sub-array from the input array. The sub-array is of the same
2831rank as the input and contains the values inside a bounding box within the input
2832array where the dimensions and indices of the bounding box are given as
2833arguments to the slice operation.
2834
2835<b> `Slice(operand, start_indices, limit_indices, strides)` </b>
2836
2837| Arguments       | Type                | Semantics                            |
2838| --------------- | ------------------- | ------------------------------------ |
2839| `operand`       | `XlaOp`             | N dimensional array of type T        |
2840| `start_indices` | `ArraySlice<int64>` | List of N integers containing the    |
2841:                 :                     : starting indices of the slice for    :
2842:                 :                     : each dimension. Values must be       :
2843:                 :                     : greater than or equal to zero.       :
2844| `limit_indices` | `ArraySlice<int64>` | List of N integers containing the    |
2845:                 :                     : ending indices (exclusive) for the   :
2846:                 :                     : slice for each dimension. Each value :
2847:                 :                     : must be greater than or equal to the :
2848:                 :                     : respective `start_indices` value for :
2849:                 :                     : the dimension and less than or equal :
2850:                 :                     : to the size of the dimension.        :
2851| `strides`       | `ArraySlice<int64>` | List of N integers that decide the   |
2852:                 :                     : input stride of the slice. The slice :
2853:                 :                     : picks every `strides[d]` element in  :
2854:                 :                     : dimension `d`.                       :
2855
2856
28571-dimensional example:
2858
2859```
2860let a = {0.0, 1.0, 2.0, 3.0, 4.0}
2861Slice(a, {2}, {4}) produces:
2862  {2.0, 3.0}
2863```
2864
28652-dimensional example:
2866
2867```
2868let b =
2869 { {0.0,  1.0,  2.0},
2870   {3.0,  4.0,  5.0},
2871   {6.0,  7.0,  8.0},
2872   {9.0, 10.0, 11.0} }
2873
2874Slice(b, {2, 1}, {4, 3}) produces:
2875  { { 7.0,  8.0},
2876    {10.0, 11.0} }
2877```
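
1-dimensional example with strides (a sketch illustrating the `strides`
argument described in the table above):

```
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {0}, {5}, {2}) produces:
  {0.0, 2.0, 4.0}
```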
2878
2879## Sort
2880
2881See also
2882[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2883
2884<b>`Sort(operands, comparator, dimension, is_stable)`</b>
2885
2886Arguments    | Type                | Semantics
2887------------ | ------------------- | --------------------
2888`operands`   | `ArraySlice<XlaOp>` | The operands to sort.
2889`comparator` | `XlaComputation`    | The comparator computation to use.
2890`dimension`  | `int64`             | The dimension along which to sort.
2891`is_stable`  | `bool`              | Whether stable sorting should be used.
2892
2893If only one operand is provided:
2894
2895* If the operand is a rank-1 tensor (an array), the result is a sorted array.
2896  If you want to sort the array into ascending order, the comparator should
2897  perform a less-than comparison. Formally, after the array is sorted, it holds
2898  for all index positions `i, j` with `i < j` that either
2899  `comparator(value[i], value[j]) = comparator(value[j], value[i]) = false` or
2900  `comparator(value[i], value[j]) = true`.
2901
2902* If the operand has higher rank, the operand is sorted along the provided
2903  dimension. For example, for a rank-2 tensor (a matrix), a dimension value of
2904  `0` will independently sort every column, and a dimension value of `1` will
2905  independently sort each row. If no dimension number is provided, then the last
2906  dimension is chosen by default. For the dimension which is sorted, the same
2907  sorting order applies as in the rank-1 case.
2908
2909If `n > 1` operands are provided:
2910
2911* All `n` operands must be tensors with the same dimensions. The element types
2912  of the tensors may be different.
2913
2914* All operands are sorted together, not individually. Conceptually the operands
2915  are treated as a tuple. When checking whether the elements of each operand at
2916  index positions `i` and `j` need to be swapped, the comparator is called with
2917  `2 * n` scalar parameters, where parameter `2 * k` corresponds to the value at
2918  position `i` from the `k-th` operand, and parameter `2 * k + 1` corresponds to
2919  the value at position `j` from the `k-th` operand. Usually, the comparator
2920  would thus compare parameters `2 * k` and `2 * k + 1` with each other and
2921  possibly use other parameter pairs as tie breakers.
2922
2923* The result is a tuple that consists of the operands in sorted order (along
2924  the provided dimension, as above). The `i-th` operand of the tuple corresponds
2925  to the `i-th` operand of Sort.
2926
2927For example, if there are three operands `operand0 = [3, 1]`,
2928`operand1 = [42, 50]`, `operand2 = [-3.0, 1.1]`, and the comparator compares
2929only the values of `operand0` with less-than, then the output of the sort is the
2930tuple `([1, 3], [50, 42], [1.1, -3.0])`.
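
The same example, written as a sketch in the pseudocode notation used by the
other examples in this document (`comparator` only looks at its first two
parameters, i.e. the values of `operand0`, and `Lt` is a scalar less-than
comparison; both names are assumptions for illustration):

```
let operand0 = s32[2] {3, 1};
let operand1 = s32[2] {42, 50};
let operand2 = f32[2] {-3.0, 1.1};

// comparator(p0, p1, p2, p3, p4, p5) = Lt(p0, p1)
//   p0, p1: values of operand0 at positions i and j
//   p2, p3: values of operand1 at positions i and j
//   p4, p5: values of operand2 at positions i and j

Sort({operand0, operand1, operand2}, comparator, /*dimension=*/0)
  == ([1, 3], [50, 42], [1.1, -3.0])
```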
2931
2932If `is_stable` is set to true, the sort is guaranteed to be stable, that is, if
2933there are elements which are considered to be equal by the comparator, the
2934relative order of the equal values is preserved. By default, `is_stable` is set
2935to false.
2936
2937## Transpose
2938
2939See also [`XlaBuilder::Transpose`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2940
2941<b>`Transpose(operand, permutation)`</b>
2942
2943Arguments     | Type                | Semantics
2944------------- | ------------------- | ------------------------------
2945`operand`     | `XlaOp`             | The operand to transpose.
2946`permutation` | `ArraySlice<int64>` | How to permute the dimensions.
2947
2948
2949Permutes the operand dimensions with the given permutation, so
2950`∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]`.
2951
2952This is the same as `Reshape(operand, permutation,
2953Permute(permutation, operand.shape.dimensions))`.
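
For example (a sketch in the notation used by the other examples in this
document):

```
let m = f32[2x3] {{1, 2, 3},
                  {4, 5, 6}};

Transpose(m, {1, 0}) == f32[3x2] {{1, 4},
                                  {2, 5},
                                  {3, 6}};
```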
2954
2955## TriangularSolve
2956
2957See also
2958[`XlaBuilder::TriangularSolve`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2959
2960Solves systems of linear equations with lower or upper triangular coefficient
2961matrices by forward- or back-substitution. Broadcasting along leading
2962dimensions, this routine solves one of the matrix systems `op(a) * x =
2963b`, or `x * op(a) = b`, for the variable `x`, given `a` and `b`, where `op(a)` is
2964either `op(a) = a`, or `op(a) = Transpose(a)`, or `op(a) = Conj(Transpose(a))`.
2965
2966<b> `TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)` </b>
2967
2968| Arguments       | Type        | Semantics                                    |
2969| --------------- | ----------- | -------------------------------------------- |
2970| `a`             | `XlaOp`     | a rank >= 2 array of a complex or            |
2971:                 :             : floating-point type with shape `[..., M,     :
2972:                 :             : M]`.                                         :
2973| `b`             | `XlaOp`     | a rank >= 2 array of the same type with      |
2974:                 :             : shape `[..., M, K]` if `left_side` is true,  :
2975:                 :             : `[..., K, M]` otherwise.                     :
2976| `left_side`     | `bool`      | indicates whether to solve a system of the   |
2977:                 :             : form `op(a) * x = b` (`true`) or `x *        :
2978:                 :             : op(a) = b` (`false`).                        :
2979| `lower`         | `bool`      | whether to use the upper or lower triangle   |
2980:                 :             : of `a`.                                      :
2981| `unit_diagonal` | `bool`      | if `true`, the diagonal elements of `a` are  |
2982:                 :             : assumed to be `1` and not accessed.          :
2983| `transpose_a`   | `Transpose` | whether to use `a` as is, transpose it or    |
2984:                 :             : take its conjugate transpose.                :
2985
2986Input data is read only from the lower/upper triangle of `a`, depending on the
2987value of `lower`. Values from the other triangle are ignored. Output data is
2988returned in the same triangle; the values in the other triangle are
2989implementation-defined and may be anything.
2990
2991If the ranks of `a` and `b` are greater than 2, they are treated as batches of
2992matrices, where all except the minor 2 dimensions are batch dimensions. `a` and
2993`b` must have equal batch dimensions.
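
A minimal sketch (batch dimensions omitted; `NO_TRANSPOSE` is assumed here as
the `Transpose` value meaning `op(a) = a`):

```
let a = f32[2x2] {{2, 0},
                  {3, 1}};   // lower triangular
let b = f32[2x1] {{4},
                  {11}};

TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
                /*unit_diagonal=*/false, /*transpose_a=*/NO_TRANSPOSE)
  == f32[2x1] {{2},
               {5}}          // because a * {{2}, {5}} == b
```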
2994
2995## Tuple
2996
2997See also
2998[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2999
3000A tuple containing a variable number of data handles, each of which has its own
3001shape.
3002
3003This is analogous to `std::tuple` in C++. Conceptually:
3004
3005```
3006let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3007let s: s32 = 5;
3008let t: (f32[10], s32) = tuple(v, s);
3009```
3010
3011Tuples can be deconstructed (accessed) via the
3012[`GetTupleElement`](#gettupleelement) operation.
3013
3014## While
3015
3016See also
3017[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
3018
3019<b> `While(condition, body, init)` </b>
3020
3021| Arguments   | Type             | Semantics                                |
3022| ----------- | ---------------- | ---------------------------------------- |
3023| `condition` | `XlaComputation` | XlaComputation of type `T -> PRED` which |
3024:             :                  : defines the termination condition of the :
3025:             :                  : loop.                                    :
3026| `body`      | `XlaComputation` | XlaComputation of type `T -> T` which    |
3027:             :                  : defines the body of the loop.            :
3028| `init`      | `T`              | Initial value for the parameter of       |
3029:             :                  : `condition` and `body`.                  :
3030
3031Sequentially executes the `body` until the `condition` fails. This is similar to
3032a typical while loop in many other languages except for the differences and
3033restrictions listed below.
3034
3035*   A `While` node returns a value of type `T`, which is the result from the
3036    last execution of the `body`.
3037*   The shape of the type `T` is statically determined and must be the same
3038    across all iterations.
3039
3040The T parameters of the computations are initialized with the `init` value in
3041the first iteration and are automatically updated to the new result from `body`
3042in each subsequent iteration.
3043
3044One main use case of the `While` node is to implement the repeated execution of
3045training in neural networks. Simplified pseudocode is shown below with a graph
3046that represents the computation. The code can be found in
3047[`while_test.cc`](https://www.tensorflow.org/code/tensorflow/compiler/xla/tests/while_test.cc).
3048The type `T` in this example is a `Tuple` consisting of an `int32` for the
3049iteration count and a `vector[10]` for the accumulator. For 1000 iterations, the
3050loop keeps adding a constant vector to the accumulator.
3051
3052```
3053// Pseudocode for the computation.
3054init = {0, zero_vector[10]} // Tuple of int32 and float[10].
3055result = init;
3056while (result(0) < 1000) {
3057  iteration = result(0) + 1;
3058  new_vector = result(1) + constant_vector[10];
3059  result = {iteration, new_vector};
3060}
3061```
3062
3063<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
3064  <img style="width:100%" src="./images/ops_while.png">
3065</div>
3066