# TOSA Lowerings

## Introduction

### Overview

This document provides pseudo-code lowerings from TensorFlow and TensorFlow Lite
MLIR Dialects (https://www.tensorflow.org/mlir/dialects) to the TOSA Dialect
(https://mlir.llvm.org/docs/Dialects/TOSA/).

The documentation is a work-in-progress: sections with missing legalizations are
in the process of being written.

## Syntax

The pseudo-code syntax used in this document is described below.

### Primitive Datatypes

*   int8: signed 8-bit integer
*   uint8: unsigned 8-bit integer
*   int16: signed 16-bit integer
*   int32: signed 32-bit integer
*   int64: signed 64-bit integer
*   uint32: unsigned 32-bit integer
*   float32: IEEE-754 32-bit floating point format
*   float64: IEEE-754 64-bit floating point format
*   bool: boolean

### Value

In pseudo-code, a symbol starting with "%" indicates a value. A value is
evaluated by an operator at run time, and an operator can consume, and can only
consume, a list of values as operands. Note that a value's tensor type is
determined at compile time; only the evaluation happens at run time. One can
easily construct a data flow subgraph by looking at producer/consumer
relationships between values.
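
A short illustration (hypothetical ops and values) of how producers and
consumers chain through values:

```
%0 = tosa.CONST() {value={1.0}}   // %0 : tensor<{1}, float32>, type fixed at compile time
%1 = tosa.EXP(%0)                 // tosa.EXP consumes %0 and produces %1 at run time
%2 = tosa.ADD(%0, %1)             // %0 has two consumers: tosa.EXP and tosa.ADD
```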

### Tensor Type

Tensor type is an attribute determined by legalization at compile time,
describing the shape and element data type. It's noted as tensor<shape,
dtype>, or shortened to tensor<%t.type>.

### Operator Prototype

In pseudocode, a TOSA operator is prototyped in the following format:

%<output_value> = tosa.<OPERATOR>(%<input_value>)
{<attribute> = ...}

### Value Attributes

For the purposes of brevity and clarity in this document, the pseudocode allows
the following notation on value attributes.

Shorthand         | Description
----------------- | ---------------------------------------------------
`%t.shape`        | Shape vector for the tensor
`%t.shape[i]`     | Size of dimension i for the tensor
`%t.rank`         | Rank of the tensor
`%t.dtype`        | Datatype of the tensor
`%t.scale`        | Quantized scaling parameter (float64)
`%t.zp`           | Quantized zero-point (int64)
`%t.signed`       | Boolean indicating the type is signed
`%t.num_bits`     | Number of bits in the datatype
`%t.num_elements` | Number of elements in the tensor
`%t.type`         | Tuple of `tensor<%t.shape, %t.dtype>`
`%t.size`         | For tensor lists: the number of tensors in the list

### Tensor Dimension Shorthand

Where the TOSA Specification allows the use of named dimensions, the following
names may be used.

Name | Description
---- | --------------------
`N`  | Batch dimension
`H`  | Height dimension
`W`  | Width dimension
`C`  | Channel dimension
`M`  | Depthwise multiplier

Each of these may be prefixed with `I` for the input dimension, `O` for the
output dimension, or `K` for kernel dimensions. For example, `IH` is the input
height and `KW` is the kernel width.

## Common Legalization Functions

The following pseudocode helper functions are used to canonicalize arguments
from different frameworks to the TOSA dialect.

### .as_constant(): Matched as Constant

Wherever %tensor.as_constant() is specified, a constant vector will be created
to hold the value in the %tensor at compile time. This only succeeds if %tensor
is fed by a constant type operator. If constant matching fails, the lowering
will fail and be terminated.
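
A minimal sketch of how a lowering uses this (hypothetical values):

```
// Succeeds only because %perm is produced by a constant operator
%perm = tosa.CONST() {value={1, 0}}
vector<int32> perm_vec = %perm.as_constant()   // perm_vec = {1, 0} at compile time
```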
### get_padding_values_from_explicit_pad_attr()

```
vector<int64> get_padding_values_from_explicit_pad_attr(vector<int64> explicit_pad,
                                                        tensorflow::TensorFormat data_format_tf)
{
    int64 pad_before, pad_after
    vector<int64> computed_paddings

    for (int32 i = 0; i < 2; i++) {
        int64 dim = GetTensorSpatialDimIndex(4, data_format_tf, i)
        pad_before = explicit_pad[dim * 2]
        pad_after  = explicit_pad[dim * 2 + 1]
        computed_paddings.push_back(pad_before)
        computed_paddings.push_back(pad_after)
    }

    return computed_paddings
}
```

### get_padding_values_from_pad_type()

Calculate the explicit padding array based on pad type.

```
vector<int64> get_padding_values_from_pad_type(tensorflow::Padding padding, tensorflow::TensorFormat data_format,
                                        uint32 first_filter_spatial_dim, type input_type, type filter_type,
                                        vector strides, vector dilations)
{
    assert(padding != tensorflow::Padding::EXPLICIT);

    vector<int64> computed_paddings;

    // Padding over H and W dimensions
    for (int32 i = 0; i < 2; i++) {
        int32 ifm_dim = get_tensor_spatial_dim_index(4, data_format, i);

        int32 filter_dim = first_filter_spatial_dim + i;

        int32 dim_dilation = dilations[ifm_dim];
        int32 dim_stride   = strides[ifm_dim];

        int64 op_size, pad_before_tf, pad_after_tf;

        tensorflow::GetWindowedOutputSizeVerboseV2(input_type.shape[ifm_dim], filter_type.shape[filter_dim],
                                                   dim_dilation, dim_stride, padding,
                                                   // Outputs
                                                   &op_size, &pad_before_tf, &pad_after_tf);
        computed_paddings.push_back(pad_before_tf);
        computed_paddings.push_back(pad_after_tf);
    }

    return computed_paddings;
}
```

### positive_axis()

```
// Canonicalize scalar axis attributes to a scalar positive axis attribute
int32 positive_axis(int32 axis, int32 rank)
{
   if (axis < 0)
       axis += rank;

   return axis;
}
```
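
For example:

```
positive_axis(-1, 4)   // returns 3: axis -1 refers to the last of four dimensions
positive_axis(2, 4)    // returns 2: already positive, unchanged
```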

### compute_scale_32()

```
void compute_scale_32(float64 scale, int32& multiplier, int32& shift)
{
    /* Generates mantissa and shift values where mantissa is in [-1.0, -0.5] or
    [0.5, 1.0] such that
    multiplier = mantissa * 2^shift */

    const float64 mantissa = std::frexp(scale, &shift);
    auto shifted_m = std::round(mantissa * (int64(1) << 31));

    assert(shifted_m <= (int64(1) << 31)); // can't be greater than 1.0
    if (shifted_m == (int64(1) << 31)) {
        shifted_m /= 2;
        shift++;
    }
    // TOSA expects the right shift to be positive, and embeds (1 << 31) into
    // the right shift bits
    shift = (-shift) + 31;

    assert(shifted_m <= std::numeric_limits<int32>::max());

    multiplier = static_cast<int32>(shifted_m);
}
```
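
A worked example (values computed by hand), for `scale = 0.75`:

```
mantissa  = std::frexp(0.75, &shift)    // mantissa = 0.75, shift = 0
shifted_m = round(0.75 * (1 << 31))     // 1610612736, below the 1 << 31 cap
shift     = (-0) + 31                   // 31
// check: 1610612736 * 2^-31 == 0.75, the original scale
```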

### lower_batch_to_space_nd_op()

```
Value lower_batch_to_space_nd_op(Value %input, Value %block_shape, Value %crops, shape_t output_shape)
{
    size_t block_rank = %block_shape.shape[0]
    size_t remaining_shape_rank = %input.rank - block_rank - 1
    size_t crops_dim = %crops.shape[0]

    size_t block_num_elems = 1
    for (int32 i = 0; i < block_rank; i++) {
        block_num_elems *= %block_shape.as_constant()[i]
    }

    vector<std::pair<size_t, size_t>> crops_arr(crops_dim)

    for (int32 i = 0; i < crops_dim; i++) {
        crops_arr[i] = std::make_pair(%crops.as_constant()[i * 2 + 0],
                                      %crops.as_constant()[i * 2 + 1])
    }

    // Step 1: Reshape input to
    // [ block_shape[0],
    //   ...,
    //   block_shape[M-1],
    //   batch / prod(block_shape),
    //   input_shape[1],
    //   ...,
    //   input_shape[N-1] ]

    vector<size_t> a1_shape(block_rank + %input.rank)

    for (int32 i = 0; i < block_rank; i++) {
        a1_shape[i] = %block_shape.as_constant()[i]
    }

    a1_shape[block_rank] = %input.shape[0] / block_num_elems

    for (int32 i = 1; i < %input.rank; i++) {
        a1_shape[i + block_rank] = %input.shape[i]
    }

    // Step 2. Permute to shape:
    // [ batch / prod(block_shape) ],
    // [ input_shape[1] ], [ block_shape[0] ],
    //   ...,
    // [ input_shape[M] ], [ block_shape[M-1] ]
    // + remaining input shapes input_shape[M+1 .. N-1]
    vector<size_t> a2_perm(block_rank + %input.rank)

    a2_perm[0] = block_rank
    for (int32 i = 0; i < block_rank; i++) {
        a2_perm[1 + i * 2 + 0] = block_rank + 1 + i
        a2_perm[1 + i * 2 + 1] = i
    }

    for (int32 i = 0; i < remaining_shape_rank; i++) {
        a2_perm[1 + block_rank * 2 + i] = 1 + block_rank * 2 + i
    }

    // Step 3. Reshape to
    // [ batch / prod(block_shape) ],
    // [ input_shape[1] * block_shape[0] ],
    //   ...,
    // [ input_shape[M] * block_shape[M-1] ]
    // + remaining input shapes [ input_shape[M+1 .. N-1] ]
    vector<size_t> a3_shape(%input.rank)

    a3_shape[0] = %input.shape[0] / block_num_elems
    for (int32 i = 0; i < block_rank; i++) {
        a3_shape[i + 1] = %input.shape[i + 1] * %block_shape.as_constant()[i]
    }

    for (int32 i = 0; i < remaining_shape_rank; i++) {
        a3_shape[1 + block_rank + i] = %input.shape[1 + block_rank + i]
    }

    // Step 4. Crop the start/end dimensions using slice
    vector<size_t> a4_begin(%input.rank), a4_size(%input.rank)

    for (int32 i = 0; i < %input.rank; i++) {
        if (i == 0 || i > crops_dim) {
           a4_begin[i] = 0
           a4_size[i] = output_shape[i]
        } else {
           a4_begin[i] = crops_arr[i - 1].first
           a4_size[i] = a3_shape[i] - crops_arr[i - 1].first - crops_arr[i - 1].second
        }
    }

    %a1_reshape = tosa.RESHAPE(%input) {new_shape=a1_shape}
    %a2_transpose = tosa.TRANSPOSE(%a1_reshape) {perms=a2_perm}
    %a3_reshape = tosa.RESHAPE(%a2_transpose) {new_shape=a3_shape}
    %output = tosa.SLICE(%a3_reshape) {begin=a4_begin, size=a4_size}

    return %output
}
```

### lower_concatv2_op()

```
Value lower_concatv2_op(Type output_type, Value %values, int32 axis)
{
    int32 tosa_axis = positive_axis(axis, %values:0.rank)

    assert(%values.size >= 2)

    // Convert scalar inputs to tensors of rank 1
    if (%values:0.rank == 0) {
       for (int32 i = 0; i < %values.size; i++) {
          %values:i = tosa.RESHAPE(%values:i) {new_shape=1}
       }
    }

    for (int32 i = 0; i < %values.size; i++) {
        %val = %values:i
        if (%val.zp != output_type.zp || %val.scale != output_type.scale) {
            float64 rescale_scale = %val.scale / output_type.scale
            %values:i = tosa.RESCALE(%val) {scale=rescale_scale, input_zp=%val.zp, output_zp=output_type.zp}
        }
    }

    %concat_op = tosa.CONCAT(%values:0, %values:1) {axis=tosa_axis}

    for (int32 i = 2; i < %values.size; i++) {
        %concat_op = tosa.CONCAT(%concat_op, %values:i) {axis=tosa_axis}
    }

    return %concat_op
}
```

### lower_depth_to_space_op()

```
Value lower_depth_to_space_op(Value %input, size_t block_size[], Format_t data_format)
{
    assert(data_format == 'NHWC')

    vector<size_t> a2_shape = {%input.shape[0],
                               %input.shape[1],
                               %input.shape[2],
                               block_size[0],
                               block_size[1],
                               %input.shape[3] / (block_size[0] * block_size[1])}

    vector<size_t> a4_shape = {%input.shape[0],
                               %input.shape[1] * block_size[0],
                               %input.shape[2] * block_size[1],
                               %input.shape[3] / (block_size[0] * block_size[1])}

    %a2_reshape = tosa.RESHAPE(%input) {new_shape=a2_shape}
    %a3_transpose = tosa.TRANSPOSE(%a2_reshape) {perms={0, 1, 3, 2, 4, 5}}
    %output = tosa.RESHAPE(%a3_transpose) {new_shape=a4_shape}

    return %output
}
```
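
As a shape walk-through (example shapes assumed for illustration), for %input
of shape [1, 2, 2, 4] with block_size {2, 2}:

```
// a2_shape        = [1, 2, 2, 2, 2, 1]  (channels split into a 2x2 block)
// after transpose = [1, 2, 2, 2, 2, 1]  (perms {0,1,3,2,4,5} interleaves W with block height)
// a4_shape        = [1, 4, 4, 1]        (H*2, W*2, C/4)
```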

### lower_elu_op()

```
Value lower_elu_op(Value %value)
{
    // elu(x) = x < 0 ? (exp(x) - 1) : x
    // Create constants for 0/1 and reshape them to match the rank
    // of %value
    %one_const = tosa.CONST() {value={1}}
    %zero_const = tosa.CONST() {value={0}}

    vector bcast_shape
    for (int32 i = 0; i < %value.rank; i++) {
        bcast_shape.push_back(1)
    }

    %one_reshape = tosa.RESHAPE(%one_const) {new_shape=bcast_shape}
    %zero_reshape = tosa.RESHAPE(%zero_const) {new_shape=bcast_shape}

    %exp_in = tosa.EXP(%value)
    %sub = tosa.SUB(%exp_in, %one_reshape)
    %ge  = tosa.GREATER_EQUAL(%value, %zero_reshape)
    %output = tosa.SELECT(%ge, %value, %sub)
    return %output
}
```

### lower_expand_dims()

```
Value lower_expand_dims(Value %input, int32 axis)
{
    vector<size_t> reshape_dims

    if (axis < 0 || axis >= %input.rank) {
        // Insert the new dimension at the end of the tensor
        for (int32 i = 0; i < %input.rank; i++) {
           reshape_dims.push_back(%input.shape[i])
        }
        reshape_dims.push_back(1)
    } else {
        for (int32 i = 0; i < %input.rank; i++) {
            if (i == axis) {
                reshape_dims.push_back(1)
            }
            reshape_dims.push_back(%input.shape[i])
        }
    }

    %output = tosa.RESHAPE(%input) {new_shape=reshape_dims}
    return %output
}
```
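
For example (shapes illustrative), expanding a tensor of shape [3, 4]:

```
lower_expand_dims(%input, 1)    // reshape_dims = {3, 1, 4}
lower_expand_dims(%input, -1)   // insert at the end: reshape_dims = {3, 4, 1}
```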

### lower_fake_quant_op()

```
Value lower_fake_quant_op(Value %inputs, type output_type, float64 min, float64 max,
                            int64 num_bits, bool narrow_range)
{
    assert(num_bits == 8 || num_bits == 16)

    int64 qmax = (1L << (num_bits - 1)) - 1
    int64 qmin = -(1L << (num_bits - 1))

    if (narrow_range) {
       qmin = qmin + 1
    }

    float64 scale = (max - min) / float64(qmax - qmin)

    int64 zeropoint = (int64)std::round((-min) / scale + float64(qmin))

    %quantized = lower_quantize_op(%inputs.type, %inputs, 1.0 / scale, zeropoint)

    %dequantized = lower_dequantize_op(output_type, %quantized, scale, zeropoint)

    return %dequantized
}
```

### lower_floor_div()

```
Value lower_floor_div(Value %lhs, Value %rhs)
{
    %recip = tosa.RECIPROCAL(%rhs)
    %mul = tosa.MUL(%lhs, %recip)
    %output = tosa.FLOOR(%mul)

    return %output
}
```

### lower_floor_mod()

```
Value lower_floor_mod(Value %lhs, Value %rhs)
{
    %recip = tosa.RECIPROCAL(%rhs)
    %mul = tosa.MUL(%lhs, %recip)
    %floor = tosa.FLOOR(%mul)
    %output = tosa.SUB(%mul, %floor)
    return %output
}
```

### lower_quantize_op()

```
Value lower_quantize_op(Type output_type, Value %input, float64 scale, int64 zeropoint)
{
    %const_scale = tosa.CONST() {value={scale}}
    %const_zp = tosa.CONST() {value={zeropoint}}
    %op1_mul_in_scale = tosa.MUL(%input, %const_scale)
    %op2_add_op1_zp = tosa.ADD(%op1_mul_in_scale, %const_zp)
    %op3_cast_op2 = tosa.CAST(%op2_add_op1_zp) // f32->output_type.dtype
    return %op3_cast_op2
}
```

### lower_dequantize_op()

```
Value lower_dequantize_op(Type output_type, Value %input, float64 scale, int64 zeropoint)
{
    %const_scale = tosa.CONST() {value={scale}}
    %const_zp = tosa.CONST() {value={(float64)zeropoint}}
    %op1_cast_in = tosa.CAST(%input) // %input.dtype->f32
    %op2_sub_op1_zp = tosa.SUB(%op1_cast_in, %const_zp)
    %op3_mul_op2_scale = tosa.MUL(%op2_sub_op1_zp, %const_scale)
    return %op3_mul_op2_scale
}
```
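
A worked round trip (values assumed for illustration), with quantization step
s = 0.5 and zeropoint 10. Note that lower_quantize_op is passed the reciprocal
of the step, as in lower_fake_quant_op() above:

```
// quantize:   q = cast(3.0 * (1.0 / 0.5) + 10) = 16
// dequantize: x = (16 - 10) * 0.5              = 3.0
```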

### lower_log_softmax_op()

```
Value lower_log_softmax_op(Value %logits)
{
    %op1 = tosa.EXP(%logits)
    %op2 = tosa.REDUCE_SUM(%op1) {axis=(%logits.rank-1)}
    %op3 = tosa.RECIPROCAL(%op2)
    %op4 = tosa.MUL(%op1, %op3)
    %op5 = tosa.LOG(%op4)

    return %op5
}
```

### lower_pack_op()

```
Value lower_pack_op(Value %input[], size_t axis)
{
    size_t concat_axis = positive_axis(axis, %input[0].rank)

    size_t input_tensor_rank = %input[0].rank

    // Convert any rank 0 to rank 1 with reshape
    if (input_tensor_rank == 0) {
       for (int32 i = 0; i < %input.size; i++) {
           %input[i] = tosa.RESHAPE(%input[i], {1})
       }
   }

   vector<size_t> output_shape
   for (int32 i = 0; i < input_tensor_rank; i++) {
       output_shape.push_back(%input[0].shape[i])
   }

   output_shape[concat_axis] = output_shape[concat_axis] * %input.size

   // First pair of tensors
   %concat = tosa.CONCAT(%input[0], %input[1]) {axis=concat_axis}

   // Remaining tensors
   for (int32 i = 2; i < %input.size; i++) {
      %concat = tosa.CONCAT(%concat, %input[i]) {axis=concat_axis}
   }

   if (input_tensor_rank == 0) {
      // No reshape needed for rank 0, already done
      %output = %concat
   } else {

      %reshape = tosa.RESHAPE(%concat) {new_shape=output_shape}

      if (concat_axis == input_tensor_rank) {
         // Output shape is [A, B, C, .. n] in this case,
         // need to reshape to [N, A, B, C, ..] with perm [1, 2, 3, .. 0]
         concat_axis = 0

         vector<size_t> perms
         for (int32 i = 0; i < %input[0].rank; i++)
            perms.push_back(i + 1)
         perms.push_back(0)

         %output = tosa.TRANSPOSE(%reshape) {perms=perms}
     } else {
         %output = %reshape
     }
   }

   return %output
}
```

### lower_reduce_op()

```
Value lower_reduce_op<tosa_op_t OP>(Value %input, shape_t output_shape, Value %axes, bool keep_dims, float64 input_scale=1.0f, int32 input_zp=0, float64 output_scale=1.0f, int32 output_zp=0)
{
    vector axes_vec = %axes.as_constant();

    // Special case of no axes means no transformation
    if (axes_vec.size() == 0) {
       return tosa.IDENTITY(%input)
    }

    bool is_quantized = isa<QuantizedType>(%input.dtype) ? true : false

    shape_t shape = %input.shape;
    %output = %input;

    if (is_quantized) {
        %output = tosa.RESCALE(%output) {scale=input_scale, input_zp=input_zp, output_zp=0}
    }

    for (int32 i = 0; i < axes_vec.size(); i++) {
        int32 axis = positive_axis(axes_vec[i], %input.rank);

        shape[axis] = 1;
        %output = tosa.OP(%output) {axis=axis}
    }

    if (!keep_dims) {
       %output = tosa.RESHAPE(%output) {new_shape=output_shape}
    }

    if (is_quantized) {
        %output = tosa.RESCALE(%output) {scale=output_scale, input_zp=0, output_zp=output_zp}
    }

    return %output;
}
```

### lower_resize_op()

```
Value lower_resize_op(Value %input, Value %size, shape output_shape, dtype output_dtype, mode_t mode,
                      bool align_corners, bool half_pixel_centers)
{
    int32 input_height  = %input.shape[1]
    int32 input_width   = %input.shape[2]
    int32 output_height = output_shape[1]
    int32 output_width  = output_shape[2]

    float64 fp_stride_y, fp_stride_x
    if (align_corners && output_height > 1)
        fp_stride_y = static_cast<float64>(input_height - 1) / static_cast<float64>(output_height - 1)
    else
        fp_stride_y = static_cast<float64>(input_height) / static_cast<float64>(output_height)
    if (align_corners && output_width > 1)
        fp_stride_x = static_cast<float64>(input_width - 1) / static_cast<float64>(output_width - 1)
    else
        fp_stride_x = static_cast<float64>(input_width) / static_cast<float64>(output_width)

    float64 fp_offset_y = 0.0, fp_offset_x = 0.0
    if (half_pixel_centers) {
        fp_offset_y = fp_stride_y * 0.5 - 0.5
        fp_offset_x = fp_stride_x * 0.5 - 0.5
    }

    if (output_dtype == float32) {
        %op1_resize_in = tosa.RESIZE(%input) {stride={fp_stride_y, fp_stride_x}, offset={fp_offset_y, fp_offset_x}, shift=0, resize_mode=mode}
        return %op1_resize_in
    } else {
        int32 shift = 10
        float64 unit = static_cast<float64>(1 << shift)
        int32 stride_y = fp_stride_y * unit
        int32 stride_x = fp_stride_x * unit
        int32 offset_y = fp_offset_y * unit
        int32 offset_x = fp_offset_x * unit

        %op1_resize_in = tosa.RESIZE(%input) {stride={stride_y, stride_x}, offset={offset_y, offset_x}, shift=shift, resize_mode=mode}

        if (mode == "BILINEAR") {
            // Scale the accumulated result back down by 2*shift bits,
            // preserving the sign through abs/negate/select
            %const_zero = tosa.CONST() {value={0}}
            %const_twenty = tosa.CONST() {value={20}}
            %op2_ge_op1 = tosa.GREATER_EQUAL(%op1_resize_in, %const_zero)
            %op3_abs_op1 = tosa.ABS(%op1_resize_in)
            %op4_rshift_op3 = tosa.ARITHMETIC_RIGHT_SHIFT(%op3_abs_op1, %const_twenty)
            %op5_negate_op4 = tosa.NEGATE(%op4_rshift_op3)
            %op6_select_op2_op4_op5 = tosa.SELECT(%op2_ge_op1, %op4_rshift_op3, %op5_negate_op4)
            %op7_cast_op6 = tosa.CAST(%op6_select_op2_op4_op5) // i32/i48->%output.dtype
            return %op7_cast_op6
        }

        return %op1_resize_in
    }
}
```

### lower_reversev2_op()

```
Value lower_reversev2_op(Value %tensor, Value %axis)
{
    Value %output = %tensor

    if (%axis.num_elements == 0) {
       %output = tosa.IDENTITY(%tensor)
    } else {
        for (int32 i = 0; i < %axis.shape[0]; i++) {
            size_t axis_val = positive_axis(%axis.as_constant()[i], %tensor.rank)
            %output = tosa.REVERSE(%output) {axis=axis_val}
        }
    }

    return %output
}
```

### lower_round_op()

```
Value lower_round_op(Value %x)
{
    %half = tosa.CONST() {value={0.5}}
    %add = tosa.ADD(%x, %half)
    %output = tosa.FLOOR(%add)

    return %output
}
```

### lower_selectv2_op()

```
Value lower_selectv2_op(Value %condition, Value %t, Value %e, shape output_shape)
{
    // Reshape condition so that ranks match to support
    // broadcasting (if necessary)

    if (%condition.rank != output_shape.size) {
       vector<size_t> cond_shape = %condition.shape
       for (int32 i = 0; i < (output_shape.size - %condition.rank); i++) {
           cond_shape.push_front(1)
       }

       %condition = tosa.RESHAPE(%condition) {new_shape=cond_shape}
    }

    %output = tosa.SELECT(%condition, %t, %e)

    return %output
}
```

### lower_shape_op()

```
Value lower_shape_op(Value %input)
{
    vector<size_t> input_shape = %input.shape

    %shape = tosa.CONST() {value={input_shape}}
    return %shape
}
```

### lower_space_to_batch_nd_op()

```
Value lower_space_to_batch_nd_op(Value %input, Value %block_shape, Value %padding)
{
    size_t block_rank = %block_shape.shape[0]
    size_t remaining_shape_rank = %input.rank - block_rank - 1

    size_t block_num_elems = 1
    for (int32 i = 0; i < block_rank; i++) {
        block_num_elems *= %block_shape.as_constant()[i]
    }

    // Step 1. Pad based on the paddings operand (flattened representation of a
    // [input.rank][2]-shaped array)
    vector<size_t> a1_padding
    a1_padding[0] = 0
    a1_padding[1] = 0

    for (int32 i = 0; i < %padding.num_elements; i++) {
        a1_padding[i + 2] = %padding.as_constant()[i]
    }

    %a1_pad = tosa.PAD(%input) {padding=a1_padding}

    // Step 2. Reshape to
    // [ batch,
    //   padded_shape[1] / block_shape[0], block_shape[0],
    //   ...,
    //   padded_shape[M] / block_shape[M-1], block_shape[M-1] ] +
    //   remaining_shape

    vector<size_t> a2_shape(1 + block_rank * 2 + remaining_shape_rank)
    a2_shape[0] = %input.shape[0]
    for (int32 i = 0; i < block_rank; i++) {
        a2_shape[1 + i * 2 + 0] = %a1_pad.shape[1 + i] / %block_shape.as_constant()[i]
        a2_shape[1 + i * 2 + 1] = %block_shape.as_constant()[i]
    }

    for (int32 i = 0; i < remaining_shape_rank; i++) {
        a2_shape[1 + block_rank * 2 + i] = %input.shape[1 + block_rank + i]
    }

    %a2_reshape = tosa.RESHAPE(%a1_pad) {new_shape=a2_shape}

    // Step 3. Transpose to
    //  block_shape +
    //  [ batch ] +
    //  [ padded_shape[1] / block_shape[0],
    //    ...,
    //    padded_shape[M] / block_shape[M-1] ] +
    //  remaining_shape
    vector<size_t> a3_perm(%a2_reshape.rank)

    for (int32 i = 0; i < block_rank; i++) {
        a3_perm[i] = 1 + 2 * i + 1
        a3_perm[block_rank + 1 + i] = 2 * i + 1
    }

    a3_perm[block_rank] = 0
    for (int32 i = (1 + block_rank * 2); i < %a2_reshape.rank; i++) {
        a3_perm[i] = i
    }

    %a3_transpose = tosa.TRANSPOSE(%a2_reshape) {perms=a3_perm}

    // Step 4. Reshape the transposed tensor to
    // [ batch * prod(block_shape) ] +
    // [ padded_shape[1] / block_shape[0],
    //   ...,
    //   padded_shape[M] / block_shape[M-1] ] +
    // remaining_shape

    vector<size_t> a4_shape(%input.rank)
    a4_shape[0] = %input.shape[0] * block_num_elems

    for (int32 i = 0; i < block_rank; i++) {
        a4_shape[i + 1] = %a1_pad.shape[i + 1] / %block_shape.as_constant()[i]
    }

    for (int32 i = 0; i < remaining_shape_rank; i++) {
        a4_shape[1 + block_rank + i] = %input.shape[1 + block_rank + i]
    }

    %output = tosa.RESHAPE(%a3_transpose) {new_shape=a4_shape}

    return %output
}
```

### lower_space_to_depth_op()

```
Value lower_space_to_depth_op(Value %input, size_t block_size[], Format_t data_format)
{
    assert(data_format == 'NHWC')

    vector<size_t> a2_shape = {%input.shape[0],
                               %input.shape[1] / block_size[0],
                               block_size[0],
                               %input.shape[2] / block_size[1],
                               block_size[1],
                               %input.shape[3]}
    %a2_reshape = tosa.RESHAPE(%input) {new_shape=a2_shape}
    %a3_transpose = tosa.TRANSPOSE(%a2_reshape) {perms={0, 1, 3, 2, 4, 5}}

    vector<size_t> a4_shape = {%input.shape[0],
                               %input.shape[1] / block_size[0],
                               %input.shape[2] / block_size[1],
                               %input.shape[3] * block_size[0] * block_size[1]}
    %output = tosa.RESHAPE(%a3_transpose) {new_shape=a4_shape}
    return %output
}
```

### lower_split_op()

```
Value lower_split_op(Value %value, size_t axis, size_t num_split)
{
    Value %output[]

    size_t slice_size = %value.shape[axis] / num_split

    for (int32 i = 0; i < num_split; i++) {
        vector<size_t> begin_vals, size_vals

        for (int32 j = 0; j < %value.rank; j++) {
            if (j == axis) {
               begin_vals.push_back(slice_size * i)
               size_vals.push_back(slice_size)
            } else {
               begin_vals.push_back(0)
               size_vals.push_back(%value.shape[j])
            }
        }

        %output[i] = tosa.SLICE(%value) {start=begin_vals, size=size_vals}
    }

    %output_list = tosa.IDENTITYN(%output)
    return %output_list
}
```

### lower_splitv_op()

```
Value lower_splitv_op(Value %value, vector<size_t> size_split, size_t axis)
{
   Value %output[]

   size_t curr_split_start = 0

   for (int32 i = 0; i < size_split.size(); i++) {
       vector<size_t> begin_vals, size_vals

       for (int32 j = 0; j < %value.rank; j++) {
           if (j == axis) {
              begin_vals.push_back(curr_split_start)
              size_vals.push_back(size_split[i])
           } else {
              begin_vals.push_back(0)
              size_vals.push_back(%value.shape[j])
           }
       }

       %output[i] = tosa.SLICE(%value) {start=begin_vals, size=size_vals}

       curr_split_start += size_split[i]
   }

   %output_list = tosa.IDENTITYN(%output)
   return %output_list
}
```

### lower_squeeze_op()

```
Value lower_squeeze_op(Value %input, vector<size_t> squeeze_dims)
{
    vector<size_t> reshape_dims

    if (squeeze_dims.size() == 0) {
       // Remove all 1-dims
       for (int32 i = 0; i < %input.rank; i++) {
           if (%input.shape[i] != 1) {
              reshape_dims.push_back(%input.shape[i])
           }
       }
    } else {
      // Remove only the specified dimensions, which must be of size 1
      for (int32 i = 0; i < %input.rank; i++) {
          if (!squeeze_dims.contains(i) || %input.shape[i] != 1) {
              reshape_dims.push_back(%input.shape[i])
          }
      }
    }

    %output = tosa.RESHAPE(%input) {new_shape=reshape_dims}

    return %output
}
```
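
For example (shapes illustrative), for %input of shape [1, 3, 1, 2]:

```
lower_squeeze_op(%input, {})    // reshape_dims = {3, 2}: all size-1 dims removed
lower_squeeze_op(%input, {0})   // reshape_dims = {3, 1, 2}: only dim 0 removed
```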

### lower_strided_slice_op()

```
Value lower_strided_slice_op(Value %input, Value %begin_val, Value %end_val, Value %strides_val,
                               size_t begin_mask, size_t end_mask, size_t ellipsis_mask,
                               size_t new_axis_mask, size_t shrink_axis_mask)
{
    // Note: does not implement ellipsis_mask or reverse stride at this time
    assert(ellipsis_mask == 0)

    vector<size_t> begin(%begin_val.as_constant()), end(%end_val.as_constant()), strides(%strides_val.as_constant())
    vector<size_t> a1_start, a1_size, a2_shape, a3_start, a3_size, a4_shape

    for (int32 i = 0; i < %input.rank; i++) {
        if (begin_mask & (1 << i)) {
           begin[i] = 0
        }

        if (end_mask & (1 << i)) {
           end[i] = %input.shape[i]
        }

        // Wrap around the index if begin and end are negative
        if (begin[i] < 0) {
           begin[i] += %input.shape[i]
        }

        if (end[i] < 0) {
           end[i] += %input.shape[i]
        }

        a1_start[i] = begin[i]
        a1_size[i] = end[i] - begin[i]

        a2_shape[i*2 + 0] = a1_size[i] / strides[i]
        a2_shape[i*2 + 1] = strides[i]

        a3_start[i*2 + 0] = 0
        a3_start[i*2 + 1] = 0

        if (shrink_axis_mask & (1 << i)) {
           a3_size[i*2 + 0] = 1
        } else {
           a3_size[i*2 + 0] = a1_size[i] / strides[i]
        }
        a3_size[i*2 + 1] = 1

        if (!(shrink_axis_mask & (1 << i))) {
           if (new_axis_mask & (1 << i)) {
              a4_shape.push_back(1)
           }
           a4_shape.push_back(a1_size[i] / strides[i])
        }
    }

    // Step 1: Slice the input array
    %a1_slice = tosa.SLICE(%input) {start=a1_start, size=a1_size}

    // Step 2: Reshape the sliced array: 2x as many dimensions as %input
    %a2_reshape = tosa.RESHAPE(%a1_slice) {new_shape=a2_shape}

    // Step 3: Take a slice of the [0] index along each of the strided dimensions (even dimensions)
    %a3_slice = tosa.SLICE(%a2_reshape) {start=a3_start, size=a3_size}

    // Step 4: Reshape the now-strided tensor back down to the desired number of dimensions
    %output = tosa.RESHAPE(%a3_slice) {new_shape=a4_shape}

    return %output
}
```
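
A worked example (shapes assumed for illustration): slicing a rank-1 tensor of
shape [9] with begin=0, end=9, stride=3:

```
// Step 1: a1_start={0}, a1_size={9}       -> shape [9]
// Step 2: a2_shape={3, 3}                 -> rows {0,1,2}, {3,4,5}, {6,7,8}
// Step 3: a3_start={0,0}, a3_size={3,1}   -> first column: elements 0, 3, 6
// Step 4: a4_shape={3}                    -> final strided result of shape [3]
```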

### lower_unpack_op()

```
Value lower_unpack_op(Value %value, size_t axis, uint64_t num)
{
    axis = positive_axis(axis, %value.rank)

    Value %output_arr[]

    // Step 1: transpose 'axis' to the left-most dimension, if necessary
    Value %transposed_value

    if (axis != 0) {
       vector<size_t> perms

       perms.push_back(axis)
       for (int32 i = 0; i < %value.rank; i++) {
           if (i != axis)
              perms.push_back(i)
       }

       %transposed_value = tosa.TRANSPOSE(%value) {perms=perms}

   } else {
      %transposed_value = %value
   }

   // Step 2: Slice [N, A, B, C] into N tensors of shape [A, B, C]
   for (int32 i = 0; i < %transposed_value.shape[0]; i++) {
       vector<size_t> begin_vals, size_vals, shape_vals

       begin_vals.push_back(i)
       size_vals.push_back(1)

       for (int32 j = 1; j < %transposed_value.rank; j++) {
           begin_vals.push_back(0)
           size_vals.push_back(%transposed_value.shape[j])
           shape_vals.push_back(%transposed_value.shape[j])
       }

       %slice = tosa.SLICE(%transposed_value) {begin=begin_vals, size=size_vals}
       %output_arr[i] = tosa.RESHAPE(%slice) {new_shape=shape_vals}
   }

   // Combine the array of sliced tensors into a list of tensors
   %output = tosa.IDENTITYN(%output_arr)
   return %output
}
```

### get_transpose_conv2d_padding_values_from_pad_type()

```
vector<int64> get_transpose_conv2d_padding_values_from_pad_type(tensorflow::Padding padding, tensorflow::TensorFormat data_format,
                                                         uint32 first_filter_spatial_dim, type input_type, type filter_type,
                                                         vector<int64> output_dims, vector strides, vector dilations)
{
    int64 pad_before, pad_after;
    vector<int64> computed_padding

    for (int32 i = 0; i < 2; i++) {
        int64 ifm_dim = GetTensorSpatialDimIndex(4, data_format, i);
        int64 ofm_dim = GetTensorSpatialDimIndex(4, data_format, i);
        int64 filter_dim = first_filter_spatial_dim + i

        int64 ifm_size = input_type.shape[ifm_dim]
        int64 ofm_size = output_dims[ofm_dim]
        int64 filter_size = filter_type.shape[filter_dim]
        int64 dim_dilation = dilations[i]
        int64 dim_stride = strides[i]
        int64 effective_filter_size = (filter_size - 1) * dim_dilation + 1
        int64 total_padding = ((ifm_size - 1) * dim_stride + effective_filter_size - ofm_size)
        total_padding = total_padding > 0 ? total_padding : 0

        pad_before = total_padding / 2
        pad_after = total_padding - pad_before

        computed_padding.push_back(pad_before)
        computed_padding.push_back(pad_after)
    }

    return computed_padding
}
```

### lower_fused_activation()

```
Value lower_fused_activation(Value %input, string activation)
{
    bool is_quantized = isa<QuantizedType>(%input.dtype) ? true : false

    if (is_quantized) {
        if (activation == "NONE") {
            return %input
        }
        else if (activation == "RELU") {
            int32 quantized_0 = %input.zp
            int32 quantized_max = %input.storage_max
            return tosa.CLAMP(%input) {min_int=quantized_0, max_int=quantized_max}
        }
        else if (activation == "RELU6") {
            int32 quantized_0 = %input.zp
            int32 quantized_6 = %input.zp + (6.0 / %input.scale)
            return tosa.CLAMP(%input) {min_int=quantized_0, max_int=quantized_6}
        }
        else if (activation == "RELU_N1_TO_1") {
            int32 quantized_n1 = %input.zp + (-1.0 / %input.scale)
            int32 quantized_1 = %input.zp + (1.0 / %input.scale)
            return tosa.CLAMP(%input) {min_int=quantized_n1, max_int=quantized_1}
        }
    }
    else {
        if (activation == "NONE") {
            return %input
        }
        else if (activation == "RELU") {
            return tosa.RELUN(%input) {max_fp=numeric_limits<float32>::max()}
        }
        else if (activation == "RELU6") {
            return tosa.RELUN(%input) {max_fp=6.0}
        }
        else if (activation == "RELU_N1_TO_1") {
            return tosa.CLAMP(%input) {min_fp=-1.0, max_fp=1.0}
        }
        else if (activation == "TANH") {
            return tosa.TANH(%input)
        }
    }
}
```

### get_table_const_tensor()

```
Value get_table_const_tensor(function func)
{
    array<int16, 513> table_array
    for (int32 i = -256; i <= 256; i++) {
        table_array[i + 256] = func(i)
    }

    return tosa.CONST() {value=table_array}
}
```
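
A hypothetical use, building a 513-entry table for tosa.TABLE that approximates
tanh; the mapping of the index `i` onto the real input range and the int16
output scaling are assumptions for illustration only:

```
%tanh_table = get_table_const_tensor(
    [](int32 i) -> int16 { return std::round(std::tanh(i / 256.0) * 32767.0) })
```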

### lower_gather_op()

```
Value lower_gather_op(Value %params, Value %indices, int32 batch_dims, int32 axis)
{
    assert(batch_dims <= %indices.rank)
    assert(axis >= batch_dims)

    int32 N = W = K = C = 1

    for (int32 i = 0; i < batch_dims; i++) N *= %params.shape[i]
    for (int32 i = batch_dims; i < %indices.rank; i++) W *= %indices.shape[i]
    K = %params.shape[axis]
    for (int32 i = batch_dims; i < axis; i++) C *= %params.shape[i]
    for (int32 i = (axis + 1); i < %params.rank; i++) C *= %params.shape[i]

    vector<int32> params_idx_batch, params_idx_left, params_idx_indices, params_idx_right
    for (int32 i = 0; i < %params.rank; i++) {
        if (i < batch_dims && i < axis)
            params_idx_batch.push_back(i)
        else if (i < axis)
            params_idx_left.push_back(i)
        else if (i < (axis + 1))
            params_idx_indices.push_back(i)
        else
            params_idx_right.push_back(i)
    }

    vector<int32> params_perm = {params_idx_batch, params_idx_left, params_idx_indices, params_idx_right}
    vector<int32> result_perm
    for (int32 i = 0; i < batch_dims; i++)
        result_perm.push_back(i)
    for (int32 i = 0; i < params_idx_left.size(); i++)
        result_perm.push_back(params_idx_left[i])
    for (int32 i = batch_dims; i < %indices.rank; i++)
        result_perm.push_back(i)
    for (int32 i = 0; i < params_idx_right.size(); i++)
        result_perm.push_back(params_idx_right[i])

    %const_params_perm = tosa.CONST() {value=params_perm}
    %const_result_perm = tosa.CONST() {value=result_perm}

    %op1_transpose_params = tosa.TRANSPOSE(%params, %const_params_perm)
    %op2_reshape_op1 = tosa.RESHAPE(%op1_transpose_params) {shape={N,K,C}}
    %op3_reshape_indices = tosa.RESHAPE(%indices) {shape={N,W}}
    %op4_gather_op2_op3 = tosa.GATHER(%op2_reshape_op1, %op3_reshape_indices)
    %op5_reshape_op4 = tosa.RESHAPE(%op4_gather_op2_op3) {shape={N,W,C}}
    %op6_transpose_op5 = tosa.TRANSPOSE(%op5_reshape_op4, %const_result_perm)

    return %op6_transpose_op5
}
```
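
A worked sizing example (shapes assumed for illustration): %params of shape
[4, 3] and %indices of shape [2], with batch_dims=0 and axis=0:

```
// N=1 (no batch dims), W=2 (indices elements), K=4 (params.shape[axis]), C=3
// GATHER input:  params [1, 4, 3], indices [1, 2]
// GATHER output: [1, 2, 3] -> reshape/transpose -> final shape [2, 3]
```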

### lower_gather_nd_op()

```
Value lower_gather_nd_op(Value %params, Value %indices)
{
    int32 N = W = K = C = ND = 1

    ND = %indices.shape[%indices.rank - 1]

    assert(ND < %params.rank)

    for (int32 i = 0; i < (%indices.rank - 1); i++) W *= %indices.shape[i]
    for (int32 i = 0; i < ND; i++) K *= %params.shape[i]
    for (int32 i = ND; i < %params.rank; i++) C *= %params.shape[i]

    // Flattening coefficients that map an ND index onto a flat index over K
    vector<int32> flatten_coeff_vec
    for (int32 i = 1; i < ND; i++) flatten_coeff_vec.push_back(%params.shape[i])
    flatten_coeff_vec.push_back(1)
    for (int32 i = ND - 2; i >= 0; i--) flatten_coeff_vec[i] *= flatten_coeff_vec[i + 1]

    %const_flatten_coeff = tosa.CONST() {value=flatten_coeff_vec}
    %op1_reshape_params = tosa.RESHAPE(%params) {shape={N,K,C}}
    %op2_reshape_indices = tosa.RESHAPE(%indices) {shape={W,ND}}
    %op3_mul_op2_flatten_coeff = tosa.MUL(%op2_reshape_indices, %const_flatten_coeff)
    %op4_rsum_op3 = tosa.REDUCE_SUM(%op3_mul_op2_flatten_coeff) {axis=1}
    %op5_reshape_op4 = tosa.RESHAPE(%op4_rsum_op3) {shape={N,W}}
    %op6_gather_op1_op5 = tosa.GATHER(%op1_reshape_params, %op5_reshape_op4)
    %op7_reshape_op6 = tosa.RESHAPE(%op6_gather_op1_op5) {shape={N,W,C}}

    return %op7_reshape_op6
}
```

### lower_one_hot_op()

```
Value lower_one_hot_op(Value %indices, Value %depth, Value %on_value, Value %off_value, int32 axis)
{
    int32 N = W = C = 1
    int32 K = %depth.as_constant()
    int32 left_dim = right_dim = 1
    for (int32 i = 0; i < %indices.rank; i++) {
        int32 dim = %indices.shape[i]
        N *= dim
        if (i >= axis)
            right_dim *= dim
        else
            left_dim *= dim
    }

    %perm_const = tosa.CONST() {value={0, 2, 1}}
    %op1_reshape_on_value = tosa.RESHAPE(%on_value) {shape={1, 1, 1}}
    %op2_tile_op1 = tosa.TILE(%op1_reshape_on_value) {multiples={N, W, C}}
    %op3_reshape_off_value = tosa.RESHAPE(%off_value) {shape={1, 1, 1}}
    %op4_tile_op3 = tosa.TILE(%op3_reshape_off_value) {multiples={N, K, C}}
    %op5_reshape_indices = tosa.RESHAPE(%indices) {shape={N, W}}
    %op6_scatter_op4_op5_op2 = tosa.SCATTER(%op4_tile_op3, %op5_reshape_indices, %op2_tile_op1)
    %op7_reshape_op6 = tosa.RESHAPE(%op6_scatter_op4_op5_op2) {shape={left_dim, right_dim, K}}
    %op8_transpose_op7 = tosa.TRANSPOSE(%op7_reshape_op6, %perm_const)
    %op9_reshape_op8 = tosa.RESHAPE(%op8_transpose_op7) {shape=%output.shape}

    return %op9_reshape_op8
}
```

## MLIR Passes Management

Legalization is built on multiple MLIR passes.

| MLIR Pass Name    | Input Dialect   | Output Dialect  | Description                                      |
| ----------------- | --------------- | --------------- | ------------------------------------------------ |
| legalize_tf       | TensorFlow      | TOSA            | Legalize TensorFlow dialect to TOSA dialect      |
| fuse_tf_bias      | TensorFlow      | TOSA            | Map tf.BiasAdd + tf.Conv2D to tosa.CONV2D        |
| legalize_tfl      | TensorFlow Lite | TOSA            | Legalize TensorFlow Lite dialect to TOSA dialect |
| convert_tfl_uint8 | TensorFlow Lite | TensorFlow Lite | Convert quantized uint8 graph to int8 graph      |

TF-to-TOSA legalization can be summarized by the following pseudocode:

```
void legalize_tf_to_tosa(mlir::Module module)
{
    mlir::PassManager pm

    // other MLIR passes to optimize TF

    pm.addPass(fuse_tf_bias)
    pm.addPass(legalize_tf)

    // other MLIR passes to optimize TOSA
}
```

TFLite-to-TOSA legalization can be summarized by the following pseudocode:

```
void legalize_tfl_to_tosa(mlir::Module module)
{
    mlir::PassManager pm

    // other MLIR passes to optimize TFLite

    pm.addPass(convert_tfl_uint8)
    pm.addPass(legalize_tfl)

    // other MLIR passes to optimize TOSA
}
```

Each of the passes is described in more detail in the subsequent chapters.

## TensorFlow MLIR Dialect Legalization (legalize_tf)

### tf.Abs

This operator is trivially lowered to tosa.ABS.
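
As an illustration of the pattern that all "trivially lowered" entries in this
document follow, the operands and result map across one-to-one:

```
%output = tf.Abs(%x)
// becomes:
%output = tosa.ABS(%x)
```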

### tf.AddN

**TensorFlow Dialect**

```
%output = tf.AddN(%inputs)
```

**TOSA Lowering**

```
%output = tosa.ADD(%inputs:0, %inputs:1)
for (int32 i = 2; i < %inputs.size; i++) {
    %output = tosa.ADD(%inputs:i, %output)
}
```

### tf.Add

Element-wise addition.

**TensorFlow Dialect**

```
%output = tf.Add(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.ADD.

### tf.AddV2

Element-wise addition.

**TensorFlow Dialect**

```
%output = tf.AddV2(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.ADD.

### tf.All

Computes the "logical and" of elements across dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.All(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

```
%output = lower_reduce_op<tosa.REDUCE_ALL>(%input, %output.shape, %reduction_indices, keep_dims)
```

### tf.Any

Computes the "logical or" of elements across dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.Any(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

```
%output = lower_reduce_op<tosa.REDUCE_ANY>(%input, %output.shape, %reduction_indices, keep_dims)
```

### tf.ArgMax

Returns the index with the largest value across the given axis of the input
tensor.

**TensorFlow Dialect**

```
%output = tf.ArgMax(%input, %dimension)
```

**TOSA Lowering**

```
int64 axis = positive_axis(%dimension.to_constant(), %input.rank)
%output = tosa.ARGMAX(%input) {axis=axis}
```

### tf.ArgMin

Returns the index with the smallest value across the given axis of the input
tensor.

**TensorFlow Dialect**

```
%output = tf.ArgMin(%input, %dimension)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.Assert

Asserts that the given condition is true.

**TensorFlow Dialect**

```
%output = tf.Assert(%condition, %summarize)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.AssignAddVariableOp

Adds a value to the current value of a variable.

**TensorFlow Dialect**

```
%output = tf.AssignAddVariableOp(%resource, %value, %dtype)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.AssignSubVariableOp

Subtracts a value from the current value of a variable.

**TensorFlow Dialect**

```
%output = tf.AssignSubVariableOp(%resource, %value, %dtype)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.AssignVariableOp

Assigns a new value to a variable.

**TensorFlow Dialect**

```
%output = tf.AssignVariableOp(%resource, %value, %dtype)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.AvgPool

Performs average pooling on the input.

**TensorFlow Dialect**

```
%output = tf.AvgPool(%value) {ksize, strides, padding, data_format}
```

**TOSA Lowering**

```
assert(data_format == "NHWC")

tosa_padding =
     get_padding_values_from_pad_type(%value, ksize, padding, data_format,
                                      FORMAT_OHWI, strides, {1, 1, 1, 1})
%output = tosa.AVG_POOL2D(%value) {ksize=ksize, strides=strides, padding=tosa_padding}
```

### tf.BatchMatMul

Multiplies slices of two tensors in batches.

**TensorFlow Dialect**

```
%output = tf.BatchMatMul(%x, %y, %adj_x, %adj_y)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.BatchMatMulV2

Multiplies slices of two tensors in batches.

**TensorFlow Dialect**

```
%output = tf.BatchMatMulV2(%x, %y, %adj_x, %adj_y)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.BatchNormWithGlobalNormalization

✗ Deprecated operator.

### tf.BatchToSpaceND

BatchToSpaceND for N-D tensors of type T.

**TensorFlow Dialect**

```
%output = tf.BatchToSpaceND(%input, %block_shape, %crops)
```

**TOSA Lowering**

```
%output = lower_batch_to_space_nd_op(%input, %block_shape, %crops, output.shape)
```

### tf.BiasAddGrad

Training profile: TOSA lowering not yet defined.

### tf.BiasAdd

Add bias to value.

**TensorFlow Dialect**

```
%output = tf.BiasAdd(%bias, %value) {data_format}
```

**TOSA Lowering**

```
assert(data_format == 'NHWC')
%output = tosa.ADD(%value, %bias)
```

### tf.BitCast

Bitcasts a tensor from one type to another without copying data.

**TensorFlow Dialect**

```
%output = tf.BitCast(%input, %dtype)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.BitwiseAnd

This operator is trivially lowered to tosa.BITWISE_AND.

### tf.BitwiseOr

This operator is trivially lowered to tosa.BITWISE_OR.

### tf.BroadcastGradientArgs

Training profile: TOSA lowering not yet defined.

### tf.BroadcastTo

No TOSA lowering defined.

### tf.Cast

This operator is trivially lowered to tosa.CAST.

### tf.Ceil

This operator is trivially lowered to tosa.CEIL.

### tf.CheckNumerics

No TOSA lowering defined.

### tf.ComplexAbs

No TOSA lowering defined.

### tf.Complex

No TOSA lowering defined.

### tf.ConcatOffset

No TOSA lowering defined. Training profile: TOSA lowering not yet defined.

### tf.Concat

No TOSA lowering defined.

### tf.ConcatV2

Concatenates tensors along one dimension.

**TensorFlow Dialect**

```
%output = tf.ConcatV2(%values, %axis)
```

**TOSA Lowering**

```
%output = lower_concatv2_op(%output.type, %values, %axis)
```

### tf.Conj

No TOSA lowering defined.

### tf.Const

This operator is trivially lowered to tosa.CONST.

### tf.Conv2DBackpropFilter

No TOSA lowering defined.

### tf.Conv2DBackpropInput

Computes the gradients of convolution with respect to the input.

**TensorFlow Dialect**

```
%output = tf.Conv2DBackpropInput(%input_sizes, %filter, %out_backprop) {strides, use_cudnn_on_gpu, padding, explicit_paddings, data_format, dilations}
```

**TOSA Lowering**

```
// Transpose filter from HWIO to OHWI
%tosa_filter = tosa.TRANSPOSE(%filter) {perms={2, 0, 1, 3}}

vector output_shape

for (int32 i = 0; i < %input_sizes.size; i++) {
   output_shape.push_back(%input_sizes[i])
}

if (padding == "EXPLICIT") {
   tosa_padding =
       get_padding_values_from_explicit_pad_attr(explicit_paddings, data_format)
} else {
    tosa_padding =
        get_transpose_conv2d_padding_values_from_pad_type(%input_sizes, %filter, output_shape, padding, data_format, FORMAT_HWIO, strides, dilations)
}

// Create a zero bias tensor
%zero_bias = tosa.CONST() {value={0}}
%output = tosa.TRANSPOSE_CONV2D(%out_backprop) {weight=%tosa_filter, bias=%zero_bias, outpad=tosa_padding, stride=strides, dilation=dilations, out_shape=output_shape}
```

### tf.Conv2D

Computes a 2-D convolution given 4-D input and filter tensors.

**TensorFlow Dialect**

```
%output = tf.Conv2D(%input, %filter) {strides, padding, explicit_paddings, data_format, dilations}
```

**TOSA Lowering**

```
assert(data_format == "NHWC")

// Transpose filter from HWIO to OHWI
%filter_transpose = tosa.TRANSPOSE(%filter) {perms={3, 0, 1, 2}}

if (padding == "EXPLICIT") {
   tosa_padding =
       get_padding_values_from_explicit_pad_attr(explicit_paddings, data_format)
} else {
    tosa_padding =
        get_padding_values_from_pad_type(%input, %filter.shape, padding, data_format,
                                         FORMAT_HWIO, strides, dilations)
}

// Create a zero bias tensor
%zero_bias = tosa.CONST() {value={0}}

%output = tosa.CONV2D(%input, %filter_transpose, %zero_bias) {padding=tosa_padding, stride=strides, dilation=dilations}
```

### tf.Conv3D

TOSA lowering to tosa.CONV3D to be defined.

### tf.Cos

No TOSA lowering defined.

### tf.CrossReplicaSum

No TOSA lowering defined.

### tf.DepthToSpace

DepthToSpace for tensors of type T.

**TensorFlow Dialect**

```
%output = tf.DepthToSpace(%input) {block_size, data_format}
```

**TOSA Lowering**

```
%output = lower_depth_to_space_op(%input, block_size, data_format)
```

### tf.DepthwiseConv2dNative

Computes a 2-D depthwise convolution given 4-D input and filter tensors.

**TensorFlow Dialect**

```
%output = tf.DepthwiseConv2dNative(%input, %filter) {strides, padding, data_format, dilations}
```

**TOSA Lowering**

```
if (padding == "EXPLICIT") {
   tosa_padding =
       get_padding_values_from_explicit_pad_attr(explicit_paddings, data_format)
} else {
    tosa_padding =
        get_padding_values_from_pad_type(%input, %filter.shape, padding, data_format,
                                         FORMAT_HWIO, strides, dilations)
}

bias_dim = %filter.shape[2] * %filter.shape[3]

// Create a zero-bias tensor of size bias_dim
%zero_bias = tosa.CONST() {value={0} * bias_dim}

%output = tosa.DEPTHWISE_CONV2D(%input, %filter, %zero_bias) {stride=strides, dilation=dilations, padding=tosa_padding}
```

### tf.DivNoNan

No TOSA lowering defined.

### tf.Div

No TOSA lowering defined.

### tf.DynamicStitch

No TOSA lowering defined.

### tf.Einsum

No TOSA lowering defined.

### tf.Elu

Computes exponential linear: exp(features) - 1 if < 0, features otherwise.

**TensorFlow Dialect**

```
%output = tf.Elu(%features)
```

**TOSA Lowering**

```
%output = lower_elu_op(%features)
```

### tf.EmptyTensorList

No TOSA lowering defined.

### tf.Equal

Returns the truth value of (x == y) element-wise with broadcasting.

**TensorFlow Dialect**

```
%output = tf.Equal(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.EQUAL.

### tf.Exp

This operator is trivially lowered to tosa.EXP.

### tf.ExpandDims

Inserts a dimension of 1 into a tensor's shape.

**TensorFlow Dialect**

```
%output = tf.ExpandDims(%input, %axis)
```

**TOSA Lowering**

```
%output = lower_expand_dims(%input, %axis.to_constant())
```

### tf.FakeQuantWithMinMaxArgs

Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.

**TensorFlow Dialect**

```
%output = tf.FakeQuantWithMinMaxArgs(%inputs) {min, max, num_bits, narrow_range}
```

**TOSA Lowering**

```
%output = lower_fake_quant_op(%inputs, %output.type, min, max, num_bits, narrow_range)
```

### tf.FakeQuantWithMinMaxVars

Fake-quantize the 'inputs' tensor of type float via global float scalars min
and max.

**TensorFlow Dialect**

```
%output = tf.FakeQuantWithMinMaxVars(%inputs, %min, %max) {num_bits, narrow_range}
```

**TOSA Lowering**

```
%output = lower_fake_quant_op(%inputs, %output.type, %min.to_constant(), %max.to_constant(), num_bits, narrow_range)
```

### tf.FakeQuantWithMinMaxVarsPerChannel

Fake-quantize the 'inputs' tensor of type float and one of the shapes \[d\].

**TensorFlow Dialect**

```
%output = tf.FakeQuantWithMinMaxVarsPerChannel(%inputs, %min, %max) {num_bits, narrow_range}
```

No TOSA lowering defined.
1886
1887### tf.Fill
1888
1889Creates a tensor filled with a scalar value
1890
1891**TensorFlow Dialect**
1892
1893```
1894%output = tf.Fill(%dims, %value)
1895```
1896
1897**TOSA Lowering**
1898
1899```
1900int64 total_size = 1
1901
1902for (int32 i = 0; i < %dims.shape[0]; i++) {
1903    total_size *= %dims[i]
1904}
1905
1906vector<%value.dtype> fill_arr(total_size, %value)
1907
1908%output = tosa.CONST() {value={fill_arr}}
1909```
1910
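The constant folding above is simple enough to mirror in host code. A minimal sketch, with illustrative values standing in for `%dims` and `%value`:

```
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the tf.Fill folding: the output is a tosa.CONST holding %value
// repeated prod(%dims) times. The concrete values here are illustrative.
int main() {
  std::vector<int64_t> dims = {2, 3};  // stands in for %dims.as_constant()
  float value = 0.5f;                  // stands in for %value

  int64_t total_size = 1;
  for (int64_t d : dims) total_size *= d;

  std::vector<float> fill_arr(total_size, value);  // payload of tosa.CONST
  std::cout << "elements: " << fill_arr.size() << "\n";  // 6
  return 0;
}
```
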
1911### tf.FloorDiv
1912
1913Returns x // y element-wise.
1914
1915**TensorFlow Dialect**
1916
1917```
1918%output = tf.FloorDiv(%x, %y)
1919```
1920
1921**TOSA Lowering**
1922
1923```
%output = lower_floor_div(%x, %y)
1925```
1926
1927### tf.FloorMod
1928
Returns element-wise remainder of division when x < 0 xor y < 0 is true.
1930
1931**TensorFlow Dialect**
1932
1933```
1934%output = tf.FloorMod(%x, %y)
1935```
1936
1937**TOSA Lowering**
1938
1939```
%output = lower_floor_mod(%x, %y)
1941```
1942
1943### tf.Floor
1944
1945This operator is trivially lowered to tosa.FLOOR.
1946
1947### tf.FusedBatchNormGrad
1948
1949Training profile: TOSA lowering not yet defined.
1950
1951### tf.FusedBatchNormGradV2
1952
1953Training profile: TOSA lowering not yet defined.
1954
1955### tf.FusedBatchNormGradV3
1956
1957Training profile: TOSA lowering not yet defined.
1958
1959### tf.FusedBatchNorm
1960
1961Batch normalization.
1962
1963**TensorFlow Dialect**
1964
```
%output = tf.FusedBatchNorm(%x, %scale, %offset, %mean, %variance) {epsilon, data_format, is_training}
```

**TOSA Lowering**

```
assert(data_format == 'NHWC')
assert(is_training == false)

%epsilon_const = tosa.CONST() {value={epsilon}}

%op1 = tosa.SUB(%x, %mean)
1975%op2 = tosa.ADD(%variance, %epsilon_const)
1976%op3 = tosa.RSQRT(%op2)
1977%op4 = tosa.MUL(%op1, %op3)
1978%op5 = tosa.MUL(%op4, %scale)
1979%output = tosa.ADD(%op5, %offset)
1980```
1981
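As a sanity check, the sketch below evaluates the same operator graph for a single element and compares it with the usual inference-time batch-norm formula `y = (x - mean) * rsqrt(variance + epsilon) * scale + offset`; all values are arbitrary.

```
#include <cmath>
#include <cstdio>

// One element pushed through the SUB/ADD/RSQRT/MUL/MUL/ADD graph above,
// checked against the closed-form batch-norm expression.
int main() {
  float x = 1.5f, mean = 0.25f, variance = 4.0f;
  float scale = 0.5f, offset = -1.0f, epsilon = 1e-3f;

  float op1 = x - mean;               // tosa.SUB
  float op2 = variance + epsilon;     // tosa.ADD
  float op3 = 1.0f / std::sqrt(op2);  // tosa.RSQRT
  float op4 = op1 * op3;              // tosa.MUL
  float op5 = op4 * scale;            // tosa.MUL
  float y = op5 + offset;             // tosa.ADD

  float ref = (x - mean) / std::sqrt(variance + epsilon) * scale + offset;
  std::printf("y=%f ref=%f\n", y, ref);  // the two agree
  return 0;
}
```
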
1982### tf.FusedBatchNormV3
1983
1984Batch normalization.
1985
1986**TensorFlow Dialect**
1987
1988```
1989%output = tf.FusedBatchNormV3(%x, %scale, %offset, %mean, %variance) {epsilon, data_format, is_training}
1990```
1991
1992**TOSA Lowering**
1993
1994```
1995assert(data_format == 'NHWC')
1996assert(is_training == false)
1997
1998%epsilon_const = tosa.CONST() {value={epsilon}}
1999
%op1 = tosa.SUB(%x, %mean)
%op2 = tosa.ADD(%variance, %epsilon_const)
%op3 = tosa.RSQRT(%op2)
%op4 = tosa.MUL(%op1, %op3)
2004%op5 = tosa.MUL(%op4, %scale)
2005%output = tosa.ADD(%op5, %offset)
2006```
2007
2008### tf.GatherNd
2009
2010Gather slices from params into a Tensor with shape specified by indices.
2011
2012**TensorFlow Dialect**
2013
2014```
2015%output = tf.GatherNd(%params, %indices)
2016```
2017
2018**TOSA Lowering**
2019
2020```
2021%output = lower_gather_nd_op(%params, %indices)
2022```
2023
2024### tf.Gather
2025
2026Gathers slices from params according to indices.
2027
2028**TensorFlow Dialect**
2029
2030```
2031%output = tf.Gather(%params, %indices)
2032```
2033
2034**TOSA Lowering**
2035
2036```
2037%output = lower_gather_op(%params, %indices, 0, 0)
2038```
2039
2040### tf.GatherV2
2041
2042Gathers slices from params axis according to indices.
2043
2044**TensorFlow Dialect**
2045
2046```
2047%output = tf.GatherV2(%params, %indices, %axis) {batch_dims}
2048```
2049
2050**TOSA Lowering**
2051
2052```
%output = lower_gather_op(%params, %indices, batch_dims, %axis.as_constant())
2054```
2055
2056### tf.GreaterEqual
2057
Returns the truth value of (x >= y) element-wise with broadcasting.
2059
2060**TensorFlow Dialect**
2061
2062```
2063%output = tf.GreaterEqual(%x, %y)
2064```
2065
2066**TOSA Lowering** This operator is trivially lowered to tosa.GREATER_EQUAL.
2067
2068### tf.Greater
2069
Returns the truth value of (x > y) element-wise with broadcasting.
2071
2072**TensorFlow Dialect**
2073
2074```
2075%output = tf.Greater(%x, %y)
2076```
2077
2078**TOSA Lowering** This operator is trivially lowered to tosa.GREATER.
2079
2080### tf.HashTableV2
2081
2082No TOSA lowering defined.
2083
2084### tf.IdentityN
2085
2086Returns a list of tensors with the same shapes and contents as the input.
2087
2088**TensorFlow Dialect**
2089
2090```
2091%output = tf.IdentityN(%input)
2092```
2093
2094**TOSA Lowering**
2095
2096```
2097%output = tosa.IDENTITYN(%input)
2098```
2099
2100### tf.Identity
2101
2102Returns a tensor with the same shape and contents as the input.
2103
2104**TensorFlow Dialect**
2105
2106```
2107%output = tf.Identity(%input)
2108```
2109
2110**TOSA Lowering**
2111
2112```
2113%output = tosa.IDENTITY(%input)
2114```
2115
2116### tf.If
2117
2118No TOSA lowering defined.
2119
2120### tf.Imag
2121
2122No TOSA lowering defined.
2123
2124### tf.InfeedDequeueTuple
2125
2126No TOSA lowering defined.
2127
2128### tf.Invert
2129
2130This operator is trivially lowered to tosa.BITWISE_NOT.
2131
2132### tf.InvertPermutation
2133
2134No TOSA lowering defined.
2135
2136### tf.IsFinite
2137
2138No TOSA lowering defined.
2139
2140### tf.IteratorGetNext
2141
2142No TOSA lowering defined.
2143
2144### tf.L2Loss
2145
2146Training profile: TOSA lowering not yet defined.
2147
2148### tf.LRN
2149
2150No TOSA lowering defined.
2151
2152### tf.LeakyRelu
2153
2154Computes rectified linear: max(features, features \* alpha).
2155
2156**TensorFlow Dialect**
2157
2158```
2159%output = tf.LeakyRelu(%features) {alpha}
2160```
2161
2162**TOSA Lowering**
2163
2164```
2165%alpha_tensor = tosa.CONST() {value={alpha}}
2166%features_alpha = tosa.MUL(%features, %alpha_tensor)
2167%greater = tosa.GREATER(%features, %features_alpha)
2168%output = tosa.SELECT(%greater, %features, %features_alpha)
2169```
2170
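A scalar walk-through of the graph above, assuming the usual `0 <= alpha < 1`: for non-negative features, `features >= features * alpha`, so the SELECT passes the features through; otherwise it picks the scaled copy.

```
#include <cstdio>

// LeakyRelu via MUL/GREATER/SELECT, evaluated for a few sample values.
int main() {
  const float alpha = 0.2f;
  for (float f : {-2.0f, 0.0f, 3.0f}) {
    float fa = f * alpha;          // tosa.MUL
    bool greater = f > fa;         // tosa.GREATER
    float out = greater ? f : fa;  // tosa.SELECT
    std::printf("f=% .1f -> % .2f\n", f, out);  // -0.40, 0.00, 3.00
  }
  return 0;
}
```
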
2171### tf.LeftShift
2172
2173Computes the bitwise left-shift of x by y bits, element-wise.
2174
2175**TensorFlow Dialect**
2176
2177```
2178%output = tf.LeftShift(%x, %y)
2179```
2180
2181**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_LEFT_SHIFT.
2182
2183### tf.LegacyCall
2184
2185No TOSA lowering defined.
2186
2187### tf.LessEqual
2188
Returns the truth value of (x <= y) element-wise with broadcasting.
2190
2191**TensorFlow Dialect**
2192
2193```
2194%output = tf.LessEqual(%x, %y)
2195```
2196
2197**TOSA Lowering**
2198
2199```
2200%output_greater = tosa.GREATER(%x, %y)
2201%output = tosa.LOGICAL_NOT(%output_greater)
2202```
2203
2204### tf.Less
2205
Returns the truth value of (x < y) element-wise with broadcasting.
2207
2208**TensorFlow Dialect**
2209
2210```
%output = tf.Less(%x, %y)
2212```
2213
2214**TOSA Lowering**
2215
2216```
2217%output_greater_equal = tosa.GREATER_EQUAL(%x, %y)
2218%output = tosa.LOGICAL_NOT(%output_greater_equal)
2219```
2220
### tf.LinSpace
2222
2223No TOSA lowering defined.
2224
2225### tf.Log1p
2226
2227No TOSA lowering defined.
2228
2229### tf.Log
2230
2231This operator is trivially lowered to tosa.LOG.
2232
2233### tf.LogSoftmax
2234
2235Computes log softmax activations.
2236
2237**TensorFlow Dialect**
2238
2239```
2240%output = tf.LogSoftmax(%logits)
2241```
2242
2243**TOSA Lowering**
2244
2245```
2246%output = lower_log_softmax_op(%logits)
2247```
2248
2249### tf.LogicalAnd
2250
2251Returns the truth value of x AND y, element-wise.
2252
2253**TensorFlow Dialect**
2254
2255```
2256%output = tf.LogicalAnd(%x, %y)
2257```
2258
2259**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_AND.
2260
2261### tf.LogicalNot
2262
2263This operator is trivially lowered to tosa.LOGICAL_NOT.
2264
2265### tf.LogicalOr
2266
2267Returns the truth value of x OR y, element-wise.
2268
2269**TensorFlow Dialect**
2270
2271```
2272%output = tf.LogicalOr(%x, %y)
2273```
2274
2275**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_OR.
2276
2277### tf.LookupTableFindV2
2278
2279No TOSA lowering defined.
2280
### tf.LookupTableImportV2
2282
2283No TOSA lowering defined.
2284
2285### tf.LookupTableSizeV2
2286
2287No TOSA lowering defined.
2288
2289### tf.MatMul
2290
2291Multiply the matrix a by the matrix b
2292
2293**TensorFlow Dialect**
2294
2295```
2296%output = tf.MatMul(%a, %b)
2297```
2298
2299**TOSA Lowering**
2300
2301```
2302%output = tosa.MATMUL(%a, %b)
2303```
2304
2305### tf.MatrixDiag
2306
2307No TOSA lowering defined.
2308
2309### tf.MatrixDiagV2
2310
2311No TOSA lowering defined.
2312
2313### tf.MatrixDiagV3
2314
2315No TOSA lowering defined.
2316
2317### tf.MatrixSetDiag
2318
2319No TOSA lowering defined.
2320
2321### tf.MatrixSetDiagV2
2322
2323No TOSA lowering defined.
2324
2325### tf.MatrixSetDiagV3
2326
2327No TOSA lowering defined.
2328
2329### tf.Max
2330
2331Computes the maximum of elements across dimensions of a tensor.
2332
2333**TensorFlow Dialect**
2334
2335```
2336%output = tf.Max(%input, %reduction_indices) {keep_dims}
2337```
2338
2339**TOSA Lowering**
2340
2341```
2342%output = lower_reduce_op<tosa.REDUCE_MAX>(%input, %output.shape, %reduction_indices, keep_dims)
2343```
2344
2345### tf.MaxPoolGrad
2346
2347Training profile: TOSA lowering not yet defined.
2348
2349### tf.MaxPool
2350
2351Performs max pooling on the input.
2352
2353**TensorFlow Dialect**
2354
2355```
2356%output = tf.MaxPool(%input) {ksize, strides, padding, data_format}
2357```
2358
2359**TOSA Lowering**
2360
2361```
2362assert(data_format == "NHWC")
2363
2364tosa_padding =
2365     get_padding_values_from_pad_type(%input, ksize, padding, data_format,
2366                                      FORMAT_OHWI, strides, {1, 1, 1, 1})
%output = tosa.MAX_POOL2D(%input) {kernel=ksize, stride=strides, padding=tosa_padding}
2368```
2369
2370### tf.Maximum
2371
2372This operator is trivially lowered to tosa.MAXIMUM.
2373
2374### tf.Mean
2375
2376Computes the mean of elements across dimensions of a tensor.
2377
2378**TensorFlow Dialect**
2379
2380```
2381%output = tf.Mean(%input, %reduction_indices) {keep_dims}
2382```
2383
2384**TOSA Lowering**
2385
2386```
2387int32 num_elements_on_axis = 1
2388for (int32 axis : %reduction_indices) {
2389    num_elements_on_axis *= %input.shape[axis]
2390}
2391float32 div_scale = 1.0 / num_elements_on_axis
2392
2393%cst_div_scale = tosa.CONST() {value={div_scale}}
2394%op1_rsum_in = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %reduction_indices, keep_dims)
2395%op2_mul_op1 = tosa.MUL(%op1_rsum_in, %cst_div_scale)
2396```
2397
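The lowering folds the division by the element count into a single constant multiply: `mean = reduce_sum(x) * (1/N)`. A scalar sketch with illustrative data:

```
#include <cstdio>

// Mean as a scaled sum, mirroring the REDUCE_SUM + MUL graph above.
int main() {
  float input[] = {1.0f, 2.0f, 3.0f, 6.0f};
  int num_elements_on_axis = 4;
  float div_scale = 1.0f / num_elements_on_axis;  // %cst_div_scale

  float sum = 0.0f;  // tosa.REDUCE_SUM
  for (float v : input) sum += v;

  float mean = sum * div_scale;    // tosa.MUL
  std::printf("mean=%f\n", mean);  // 3.0
  return 0;
}
```
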
2398### tf.Min
2399
2400Computes the minimum of elements across dimensions of a tensor.
2401
2402**TensorFlow Dialect**
2403
2404```
2405%output = tf.Min(%input, %reduction_indices) {keep_dims}
2406```
2407
2408**TOSA Lowering**
2409
2410```
2411%output = lower_reduce_op<tosa.REDUCE_MIN>(%input, %output.shape, %reduction_indices, keep_dims)
2412```
2413
2414### tf.Minimum
2415
This operator is trivially lowered to tosa.MINIMUM.
2417
2418### tf.MirrorPad
2419
2420No TOSA lowering defined.
2421
2422### tf.MlirPassthroughOp
2423
2424No TOSA lowering defined.
2425
2426### tf.MulNoNan
2427
2428No TOSA lowering defined.
2429
2430### tf.Mul
2431
2432Returns the product of x and y, element-wise.
2433
2434**TensorFlow Dialect**
2435
2436```
2437%output = tf.Mul(%x, %y)
2438```
2439
2440**TOSA Lowering** This operator is trivially lowered to tosa.MUL.
2441
2442### tf.Neg
2443
2444This operator is trivially lowered to tosa.NEGATE.
2445
2446### tf.NoOp
2447
2448No TOSA lowering defined.
2449
2450### tf.NonMaxSuppressionV4
2451
2452No TOSA lowering defined.
2453
2454### tf.NonMaxSuppressionV5
2455
2456No TOSA lowering defined.
2457
2458### tf.NotEqual
2459
2460Returns the truth value of (x != y) element-wise with broadcasting.
2461
2462**TensorFlow Dialect**
2463
2464```
2465%output = tf.NotEqual(%x, %y)
2466```
2467
2468**TOSA Lowering**
2469
2470```
2471%equal = tosa.EQUAL(%x, %y)
%output = tosa.LOGICAL_NOT(%equal)
2473```
2474
2475### tf.OneHot
2476
2477OneHot operator.
2478
**TensorFlow Dialect**
2480
2481```
2482%output = tf.OneHot(%indices, %depth, %on_value, %off_value) {axis}
2483```
2484
2485**TOSA Lowering**
2486
2487```
2488%output = lower_one_hot_op(%indices, %depth, %on_value, %off_value, axis)
2489```
2490
2491### tf.OutputEnqueueTuple
2492
2493No TOSA lowering defined.
2494
2495### tf.Pack
2496
2497Packs a list of N rank-R tensors into one rank-(R+1) tensor.
2498
2499**TensorFlow Dialect**
2500
2501```
2502%output = tf.Pack(%values) {axis}
2503```
2504
2505**TOSA Lowering**
2506
2507```
2508%output = lower_pack_op(%values, axis)
2509```
2510
2511### tf.Pad
2512
2513This operator is trivially lowered to tosa.PAD.
2514
2515### tf.PadV2
2516
2517No TOSA lowering defined.
2518
2519### tf.ParseExampleV2
2520
2521No TOSA lowering defined.
2522
2523### tf.PartitionedCall
2524
2525No TOSA lowering defined.
2526
2527### tf.Placeholder
2528
2529Not seen in practice. No lowering needed.
2530
2531### tf.PlaceholderWithDefault
2532
2533Not seen in practice. No lowering needed.
2534
2535### tf.Pow
2536
2537This operator is trivially lowered to tosa.POW.
2538
2539### tf.PreventGradient
2540
2541Training profile: TOSA lowering not yet defined.
2542
2543### tf.Prod
2544
2545Computes the product of elements across dimensions of a tensor.
2546
2547**TensorFlow Dialect**
2548
2549```
2550%output = tf.Prod(%input, %reduction_indices) {keep_dims}
2551```
2552
2553**TOSA Lowering**
2554
2555```
2556%output = lower_reduce_op<tosa.REDUCE_PRODUCT>(%input, %output.shape, %reduction_indices, keep_dims)
2557```
2558
2559### tf.QuantizeAndDequantize
2560
2561No TOSA lowering defined.
2562
2563### tf.QuantizeAndDequantizeV2
2564
2565No TOSA lowering defined.
2566
2567### tf.QuantizeAndDequantizeV3
2568
2569No TOSA lowering defined.
2570
2571### tf.RFFT
2572
2573No TOSA lowering defined.
2574
2575### tf.RandomShuffle
2576
2577No TOSA lowering defined.
2578
2579### tf.RandomStandardNormal
2580
2581No TOSA lowering defined.
2582
2583### tf.RandomUniform
2584
2585No TOSA lowering defined.
2586
2587### tf.Range
2588
2589No TOSA lowering defined.
2590
2591### tf.Rank
2592
2593Returns the rank of the tensor.
2594
2595**TensorFlow Dialect**
2596
2597```
2598%output = tf.Rank(%input)
2599```
2600
2601**TOSA Lowering**
2602
2603```
2604%output = tosa.CONST() {value={%input.rank}}
2605```
2606
2607### tf.ReadVariableOp
2608
2609No TOSA lowering defined.
2610
2611### tf.RealDiv
2612
2613Returns x / y element-wise for real types.
2614
2615**TensorFlow Dialect**
2616
2617```
2618%output = tf.RealDiv(%x, %y)
2619```
2620
2621**TOSA Lowering**
2622
2623```
2624%recip = tosa.RECIPROCAL(%y)
2625%output = tosa.MUL(%x, %recip)
2626```
2627
2628### tf.Real
2629
2630No TOSA lowering defined.
2631
2632### tf.Reciprocal
2633
2634This operator is trivially lowered to tosa.RECIPROCAL.
2635
2636### tf.Relu6
2637
2638Computes rectified linear 6: min(max(features, 0), 6).
2639
2640**TensorFlow Dialect**
2641
2642```
2643%output = tf.Relu6(%features)
2644```
2645
2646**TOSA Lowering**
2647
2648```
2649%output = tosa.RELUN(%features) {max_val=6}
2650```
2651
2652### tf.ReluGrad
2653
2654Training profile: TOSA lowering not yet defined.
2655
2656### tf.Relu
2657
Computes rectified linear: max(features, 0).
2659
2660**TensorFlow Dialect**
2661
2662```
2663%output = tf.Relu(%features)
2664```
2665
2666**TOSA Lowering**
2667
2668```
%output = tosa.RELUN(%features) {max_val=std::numeric_limits<float>::max()}
2670```
2671
2672### tf.Reshape
2673
2674Reshapes a tensor.
2675
2676**TensorFlow Dialect**
2677
2678```
2679%output = tf.Reshape(%tensor, %shape)
2680```
2681
2682**TOSA Lowering**
2683
2684```
2685%output = tosa.RESHAPE(%tensor) {new_shape=%shape.as_constant}
2686```
2687
2688### tf.ResizeBilinear
2689
2690Resizes images to size using bilinear interpolation.
2691
2692**TensorFlow Dialect**
2693
2694```
2695%output = tf.ResizeBilinear(%images, %size) {align_corners, half_pixel_centers}
2696```
2697
The resize scale parameters are inferred from the output shape.

**TOSA Lowering**
2699
2700```
2701%output = lower_resize_op(%images, %size, float, "BILINEAR")
2702```
2703
2704### tf.ResizeNearestNeighbor
2705
2706Resizes images to size using nearest neighbor interpolation.
2707
2708**TensorFlow Dialect**
2709
2710```
2711%output = tf.ResizeNearestNeighbor(%images, %size) {align_corners, half_pixel_centers}
2712```
2713
The resize scale parameters are inferred from the output shape.

**TOSA Lowering**
2715
2716```
%output = lower_resize_op(%images, %size, float, "NEAREST_NEIGHBOR")
2718```
2719
2720### tf.ResourceApplyAdam
2721
2722Training profile: TOSA lowering not yet defined.
2723
2724### tf.ResourceApplyGradientDescent
2725
2726Training profile: TOSA lowering not yet defined.
2727
2728### tf.ResourceApplyKerasMomentum
2729
2730Training profile: TOSA lowering not yet defined.
2731
2732### tf.ResourceGather
2733
2734Training profile: TOSA lowering not yet defined.
2735
2736### tf.ResourceScatterUpdate
2737
2738Training profile: TOSA lowering not yet defined.
2739
2740### tf.ReverseSequence
2741
2742No TOSA lowering defined.
2743
2744### tf.ReverseV2
2745
2746Reverses specific dimensions of a tensor.
2747
2748**TensorFlow Dialect**
2749
2750```
2751%output = tf.ReverseV2(%tensor, %axis)
2752```
2753
2754**TOSA Lowering**
2755
2756```
2757%output = lower_reversev2_op(%tensor, %axis)
2758```
2759
2760### tf.RightShift
2761
Computes the bitwise right-shift of x by y bits, element-wise.
2763
2764**TensorFlow Dialect**
2765
2766```
%output = tf.RightShift(%x, %y)
2768```
2769
2770**TOSA Lowering**
2771
2772```
2773if (is_unsigned(%x.dtype)) {
2774  %output = tosa.LOGICAL_RIGHT_SHIFT(%x, %y)
2775} else {
2776  %output = tosa.ARITHMETIC_RIGHT_SHIFT(%x, %y)
2777}
2778```
2779
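The branch on signedness matters because the two shifts disagree on negative values: a logical shift fills vacated bits with zeros, while an arithmetic shift replicates the sign bit. A small illustration (right-shifting a negative signed integer is implementation-defined before C++20, but arithmetic on common targets):

```
#include <cstdint>
#include <cstdio>

int main() {
  int32_t x_signed = -8;
  uint32_t x_unsigned = 0xFFFFFFF8u;  // the same bit pattern

  // ARITHMETIC_RIGHT_SHIFT keeps the sign: -8 >> 1 == -4.
  std::printf("arithmetic: %d\n", (int)(x_signed >> 1));
  // LOGICAL_RIGHT_SHIFT shifts in zeros: 0xFFFFFFF8 >> 1 == 0x7FFFFFFC.
  std::printf("logical:    0x%X\n", (unsigned)(x_unsigned >> 1));
  return 0;
}
```
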
2780### tf.Round
2781
2782Rounds the values of a tensor to the nearest integer, element-wise.
2783
2784**TensorFlow Dialect**
2785
2786```
2787%output = tf.Round(%x)
2788```
2789
2790**TOSA Lowering**
2791
2792```
2793%output = lower_round_op(%x)
2794```
2795
2796### tf.RsqrtGrad
2797
2798Training profile: TOSA lowering not yet defined.
2799
2800### tf.Rsqrt
2801
2802This operator is trivially lowered to tosa.RSQRT.
2803
2804### tf.SegmentMax
2805
2806No TOSA lowering defined.
2807
2808### tf.SegmentMean
2809
2810No TOSA lowering defined.
2811
2812### tf.SegmentMin
2813
2814No TOSA lowering defined.
2815
2816### tf.SegmentProd
2817
2818No TOSA lowering defined.
2819
2820### tf.SegmentSum
2821
2822No TOSA lowering defined.
2823
2824### tf.Select
2825
2826No TOSA lowering defined.
2827
2828### tf.SelectV2
2829
2830Selects elements from t or e depending on condition.
2831
2832**TensorFlow Dialect**
2833
2834```
2835%output = tf.SelectV2(%condition, %t, %e)
2836```
2837
2838**TOSA Lowering**
2839
2840```
2841%output = lower_selectv2_op(%condition, %t, %e, %output.shape)
2842```
2843
2844### tf.ShapeN
2845
2846No TOSA lowering defined.
2847
2848### tf.Shape
2849
2850Returns the shape of a tensor.
2851
2852**TensorFlow Dialect**
2853
2854```
2855%output = tf.Shape(%input)
2856```
2857
2858**TOSA Lowering**
2859
2860```
2861%output = lower_shape_op(%input)
2862```
2863
2864### tf.Sigmoid
2865
2866This operator is trivially lowered to tosa.SIGMOID.
2867
2868### tf.Sign
2869
2870No TOSA lowering defined.
2871
2872### tf.Sin
2873
2874No TOSA lowering defined.
2875
2876### tf.Size
2877
2878No TOSA lowering defined.
2879
2880### tf.Slice
2881
2882Returns a slice from input.
2883
2884**TensorFlow Dialect**
2885
2886```
2887%output = tf.Slice(%input, %begin, %size)
2888```
2889
2890**TOSA Lowering**
2891
2892```
2893vector <size_t> output_size
2894try {
2895  output_size = %size.as_constant()
2896} except(ConversionFailed) {
2897  output_size = %output.shape
2898}
2899
2900%output = tosa.SLICE(%input) {start=begin, size=output_size}
2901```
2902
2903### tf.Snapshot
2904
2905No TOSA lowering defined.
2906
2907### tf.SoftmaxCrossEntropyWithLogits
2908
2909Training profile: TOSA lowering not yet defined.
2910
2911### tf.Softmax
2912
2913Computes softmax activations
2914
2915**TensorFlow Dialect**
2916
2917```
2918%output = tf.Softmax(%logits)
2919```
2920
2921**TOSA Lowering**
2922
2923```
2924%op1 = tosa.EXP(%logits)
%op2 = tosa.REDUCE_SUM(%op1) {reduce_axis=(%logits.rank - 1)}
2926%op3 = tosa.RECIPROCAL(%op2)
2927%output = tosa.MUL(%op1, %op3)
2928```
2929
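A scalar sketch of the four-op graph above, computing `softmax(x)_i = exp(x_i) / sum_j exp(x_j)` along the last axis for illustrative logits:

```
#include <cmath>
#include <cstdio>

int main() {
  float logits[3] = {1.0f, 2.0f, 3.0f};
  float op1[3];      // tosa.EXP
  float op2 = 0.0f;  // tosa.REDUCE_SUM

  for (int i = 0; i < 3; i++) {
    op1[i] = std::exp(logits[i]);
    op2 += op1[i];
  }
  float op3 = 1.0f / op2;  // tosa.RECIPROCAL
  for (int i = 0; i < 3; i++)
    std::printf("%f ", op1[i] * op3);  // tosa.MUL: ~0.090 0.245 0.665
  std::printf("\n");
  return 0;
}
```
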
2930### tf.Softplus
2931
2932No TOSA lowering defined.
2933
2934### tf.SpaceToBatchND
2935
2936SpaceToBatch for N-D tensors of type T.
2937
2938**TensorFlow Dialect**
2939
2940```
2941%output = tf.SpaceToBatchND(%input, %block_shape, %paddings)
2942```
2943
2944**TOSA Lowering**
2945
2946```
2947%output = lower_space_to_batch_nd_op(%input, %block_shape, %paddings)
2948```
2949
2950### tf.SpaceToDepth
2951
2952SpaceToDepth for tensors of type T.
2953
2954**TensorFlow Dialect**
2955
2956```
2957%output = tf.SpaceToDepth(%input) {block_size, data_format}
2958```
2959
2960**TOSA Lowering**
2961
2962```
2963%output = lower_space_to_depth_op(%input, block_size, data_format)
2964```
2965
2966### tf.SparseMatMul
2967
2968No TOSA lowering defined.
2969
2970### tf.SparseSoftmaxCrossEntropyWithLogits
2971
2972No TOSA lowering defined.
2973
2974### tf.SparseToDense
2975
2976No TOSA lowering defined.
2977
2978### tf.Split
2979
Splits a tensor into num_split tensors along one dimension.
2981
2982**TensorFlow Dialect**
2983
2984```
2985%output = tf.Split(%split_dim, %value) {num_split}
2986```
2987
2988**TOSA Lowering**
2989
2990```
2991%output = lower_split_op(%value, %split_dim.as_constant(), num_split)
2992```
2993
2994### tf.SplitV
2995
Splits a tensor into num_split tensors along one dimension.
2997
2998**TensorFlow Dialect**
2999
3000```
3001%output = tf.SplitV(%value, %size_splits, %split_dim) {num_split}
3002```
3003
3004**TOSA Lowering**
3005
3006```
3007%output = lower_splitv_op(%value, %size_splits.as_constant(), %split_dim.as_constant())
3008```
3009
3010### tf.Sqrt
3011
3012No TOSA lowering defined.
3013
3014### tf.Square
3015
3016Computes the square of x, element-wise.
3017
3018**TensorFlow Dialect**
3019
3020```
3021%output = tf.Square(%x)
3022```
3023
3024**TOSA Lowering**
3025
3026```
3027%output = tosa.MUL(%x, %x)
3028```
3029
### tf.SquaredDifference

Computes (x-y)\*(x-y) element-wise.

**TensorFlow Dialect**

```
%output = tf.SquaredDifference(%x, %y)
3038```
3039
3040**TOSA Lowering**
3041
3042```
3043%diff = tosa.SUB(%x, %y)
3044%output = tosa.MUL(%diff, %diff)
3045```
3046
3047### tf.Squeeze
3048
3049Removes dimensions of size 1 from the shape of a tensor.
3050
3051**TensorFlow Dialect**
3052
3053```
3054%output = tf.Squeeze(%input) {squeeze_dims}
3055```
3056
3057**TOSA Lowering**
3058
3059```
3060%output = lower_squeeze_op(%input, squeeze_dims)
3061```
3062
3063### tf.StatefulPartitionedCall
3064
3065No TOSA lowering defined.
3066
3067### tf.StopGradient
3068
3069Training profile: TOSA lowering not yet defined.
3070
3071### tf.StridedSliceGrad
3072
3073Training profile: TOSA lowering not yet defined.
3074
3075### tf.StridedSlice
3076
3077Return a strided slice from input.
3078
3079**TensorFlow Dialect**
3080
3081```
3082%output = tf.StridedSlice(%input, %begin, %end, %strides) {begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask}
3083```
3084
3085**TOSA Lowering**
3086
3087```
3088%output = lower_strided_slice_op(%input, %begin, %end, %strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
3089```
3090
3091### tf.Sub
3092
3093This operator is trivially lowered to tosa.SUB.
3094
3095### tf.Sum
3096
3097Computes the sum of elements across dimensions of a tensor.
3098
3099**TensorFlow Dialect**
3100
3101```
3102%output = tf.Sum(%input, %reduction_indices) {keep_dims}
3103```
3104
3105**TOSA Lowering**
3106
3107```
3108%output = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %reduction_indices, keep_dims)
3109```
3110
3111### tf.TPUCompilationResult
3112
3113No TOSA lowering defined.
3114
3115### tf.TPUCopyWithLayout
3116
3117No TOSA lowering defined.
3118
3119### tf.TPUExecuteAndUpdateVariables
3120
3121No TOSA lowering defined.
3122
3123### tf.TPUExecute
3124
3125No TOSA lowering defined.
3126
3127### tf.TPUGetLayout
3128
3129No TOSA lowering defined.
3130
3131### tf.TPUReplicateMetadata
3132
3133No TOSA lowering defined.
3134
3135### tf.TPUReplicatedInput
3136
3137No TOSA lowering defined.
3138
3139### tf.TPUReplicatedOutput
3140
3141No TOSA lowering defined.
3142
3143### tf.TPUReshardVariables
3144
3145No TOSA lowering defined.
3146
3147### tf.TanhGrad
3148
3149Training profile: TOSA lowering not yet defined.
3150
3151### tf.Tanh
3152
3153This operator is trivially lowered to tosa.TANH.
3154
3155### tf.TensorListFromTensor
3156
3157No TOSA lowering defined.
3158
3159### tf.TensorListGetItem
3160
3161No TOSA lowering defined.
3162
3163### tf.TensorListLength
3164
3165No TOSA lowering defined.
3166
3167### tf.TensorListPushBack
3168
3169No TOSA lowering defined.
3170
3171### tf.TensorListReserve
3172
3173No TOSA lowering defined.
3174
3175### tf.TensorListResize
3176
3177No TOSA lowering defined.
3178
3179### tf.TensorListSetItem
3180
3181No TOSA lowering defined.
3182
3183### tf.TensorListStack
3184
3185No TOSA lowering defined.
3186
3187### tf.TensorScatterUpdate
3188
3189No TOSA lowering defined.
3190
3191### tf.Tile
3192
3193Constructs a tensor by tiling a given tensor.
3194
3195**TensorFlow Dialect**
3196
3197```
3198%output = tf.Tile(%input, %multiples)
3199```
3200
3201**TOSA Lowering**
3202
3203```
3204%output = tosa.TILE(%input) {multiples=%multiples.as_constant()}
3205```
3206
3207### tf.ToBool
3208
3209No TOSA lowering defined.
3210
3211### tf.TopKV2
3212
3213No TOSA lowering defined.
3214
3215### tf.Transpose
3216
3217Shuffle dimensions of x according to a permutation.
3218
3219**TensorFlow Dialect**
3220
3221```
3222%output = tf.Transpose(%x, %perm)
3223```
3224
3225**TOSA Lowering**
3226
3227```
3228%output = tosa.TRANSPOSE(%x) {perm=%perm.as_constant()}
3229```
3230
3231### tf.TruncateDiv
3232
3233No TOSA lowering defined.
3234
3235### tf.Unique
3236
3237No TOSA lowering defined.
3238
3239### tf.Unpack
3240
3241Unpacks a given dimension of a rank-R tensor into num rank-(R-1) tensors.
3242
3243**TensorFlow Dialect**
3244
3245```
3246%output = tf.Unpack(%value) {axis, num}
3247```
3248
3249**TOSA Lowering**
3250
3251```
3252%output = lower_unpack_op(%value, axis, num)
3253```
3254
3255### tf.UnsortedSegmentMax
3256
3257No TOSA lowering defined.
3258
3259### tf.UnsortedSegmentMin
3260
No TOSA lowering defined.

### tf.UnsortedSegmentProd

No TOSA lowering defined.

### tf.UnsortedSegmentSum
3264
3265No TOSA lowering defined.
3266
3267### tf.VarHandle
3268
3269No TOSA lowering defined.
3270
3271### tf.VariableShape
3272
3273No TOSA lowering defined.
3274
3275### tf.Where
3276
3277No TOSA lowering defined.
3278
3279### tf.While
3280
3281No TOSA lowering defined.
3282
3283### tf.Xdivy
3284
3285No TOSA lowering defined.
3286
3287### tf.XlaDynamicUpdateSlice
3288
3289No TOSA lowering defined.
3290
3291### tf.XlaSharding
3292
3293No TOSA lowering defined.
3294
3295### tf.ZerosLike
3296
3297Returns a tensor of zeros with the same shape and type as x.
3298
3299**TensorFlow Dialect**
3300
3301```
3302%output = tf.ZerosLike(%x)
3303```
3304
3305**TOSA Lowering**
3306
3307```
3308%output = tosa.CONST() {value={0} * %x.num_elements}
3309```
3310
3311## TensorFlow Lite MLIR Dialect Legalization (legalize_tfl)
3312
3313### tfl.abs
3314
3315This operator is trivially lowered to tosa.ABS
3316
3317### tfl.add_n
3318
3319add_n operator.
3320
3321**TensorFlow Lite Dialect**
3322
3323```
3324%sum = tfl.add_n(%inputs)
3325```
3326
3327**TOSA Lowering**
3328
3329```
3330%output = tosa.ADD(%inputs:0, %inputs:1)
for (int32 i = 2; i < %inputs.size; i++) {
3332    %output = tosa.ADD(%inputs:i, %output)
3333}
3334```
3335
3336### tfl.add
3337
3338Element-wise addition operation.
3339
3340**TensorFlow Lite Dialect**
3341
3342```
3343%output = tfl.add(%lhs, %rhs)
3344```
3345
3346**TOSA Lowering**
3347
3348If input/output tensors are all non-quantized typed,
3349
3350Legalization:
3351
3352```
3353%result = tosa.ADD(%lhs, %rhs)
3354```
3355
3356If input/output tensors are all quantized typed,
3357
3358Prepare:
3359
3360```
3361float64 max_scale_2x = 2.0 * max(%lhs.scale, %rhs.scale)
3362float64 lhs_scale = float64(1 << input_shift) * %lhs.scale / max_scale_2x
3363float64 rhs_scale = float64(1 << input_shift) * %rhs.scale / max_scale_2x
3364float64 output_scale = max_scale_2x / (%output.scale * float64(1 << input_shift))
3365
3366```
3367
3368Legalization:
3369
3370```
3371%op1_rescale_lhs = tosa.RESCALE(%lhs) {scale=lhs_scale, input_zp=%lhs.zp, output_zp=0} // %lhs.dtype->i32
3372%op2_rescale_rhs = tosa.RESCALE(%rhs) {scale=rhs_scale, input_zp=%rhs.zp, output_zp=0} // %rhs.dtype->i32
3373%op3_add_op1_op2 = tosa.ADD(%op1_rescale_lhs, %op2_rescale_rhs)
3374%op4_rescale_op3 = tosa.RESCALE(%op3_add_op1_op2) {scale=output_scale} // i32->%output.dtype
3375```
3376
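A worked example of the Prepare arithmetic above. `input_shift` is never pinned down in the pseudocode; the value 20 below is an assumed headroom shift chosen purely for illustration. The example also checks the key identity: because `lhs_scale * output_scale == %lhs.scale / %output.scale` (and likewise for the rhs), rescaling into the shifted i32 domain and back out reproduces the real-valued sum.

```
#include <cstdio>

int main() {
  const int input_shift = 20;  // assumption, not from the pseudocode
  double lhs_q_scale = 0.02, rhs_q_scale = 0.01, out_q_scale = 0.025;

  double max_scale_2x =
      2.0 * (lhs_q_scale > rhs_q_scale ? lhs_q_scale : rhs_q_scale);
  double lhs_scale = double(1 << input_shift) * lhs_q_scale / max_scale_2x;
  double rhs_scale = double(1 << input_shift) * rhs_q_scale / max_scale_2x;
  double output_scale =
      max_scale_2x / (out_q_scale * double(1 << input_shift));

  // a and b stand for zero-point-corrected integer inputs.
  double a = 10, b = -4;
  double out = (lhs_scale * a + rhs_scale * b) * output_scale;
  double ref = (lhs_q_scale * a + rhs_q_scale * b) / out_q_scale;
  std::printf("out=%f ref=%f\n", out, ref);  // both 6.4
  return 0;
}
```
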
3377### tfl.arg_max
3378
3379ArgMax operator.
3380
3381**TensorFlow Lite Dialect**
3382
3383```
3384%output = tfl.arg_max(%input, %dim)
3385```
3386
3387**TOSA Lowering**
3388
3389```
3390%result = tosa.ARGMAX(%input) {axis=positive_axis(%dim_const.as_constant(), %input.rank)}
3391```
3392
3393### tfl.arg_min
3394
3395No TOSA lowering defined.
3396
3397### tfl.average_pool_2d
3398
3399Average_pool_2d operator.
3400
3401**TensorFlow Lite Dialect**
3402
3403```
3404%output = tfl.average_pool_2d(%input) {filter_height, filter_width, padding, stride_h, stride_w, fused_activation_function}
3405```
3406
3407**TOSA Lowering**
3408
3409Prepare:
3410
3411```
3412tosa_padding =
3413     get_padding_values_from_pad_type(padding, NHWC, 1,
3414                                      %input.type, tensor<{filter_height, filter_width}, tosa.int32>,
3415                                      {1, stride_h, stride_w, 1}, {1, 1, 1, 1})
3416```
3417
3418If input/output tensors are all non-quantized typed,
3419
3420Legalization:
3421
3422```
3423%avgpool2d = tosa.AVG_POOL2D(%input) {kernel={filter_height, filter_width}, stride={stride_h, stride_w}, padding=tosa_padding}
3424if(fused_activation != NONE) {
3425    %result = convert_fused_activation(%avgpool2d, fused_activation)
3426}
3427else {
3428    %result = %avgpool2d
3429}
3430```
3431
3432If input/output tensors are all quantized typed,
3433
3434Legalization:
3435
3436```
3437%avgpool2d = tosa.AVG_POOL2D(%input) {kernel={filter_height, filter_width}, stride={stride_h, stride_w}, padding=tosa_padding, quantization_info={input_zp=%input.zp, output_zp=%output.zp}}
3438if(fused_activation != NONE) {
3439    %result = convert_fused_activation(%avgpool2d, fused_activation)
3440}
3441else {
3442    %result = %avgpool2d
3443}
3444```
3445
3446### tfl.basic_lstm
3447
3448No TOSA lowering defined.
3449
3450### tfl.batch_to_space_nd
3451
3452BatchToSpaceNd operator.
3453
3454**TensorFlow Lite Dialect**
3455
3456```
3457%output = tfl.batch_to_space_nd(%input, %block_shape, %indices)
3458```
3459
3460**TOSA Lowering**
3461
3462```
3463%result = convert_batch_to_space_nd_op(%input, %block_shape, %indices)
3464```
3465
3466### tfl.cast
3467
3468This operator is trivially lowered to tosa.CAST
3469
3470### tfl.ceil
3471
3472Ceil operator.
3473
3474**TensorFlow Lite Dialect**
3475
3476```
3477%y = tfl.ceil(%x)
3478```
3479
3480**TOSA Lowering**
3481
3482If input/output tensors are all non-quantized typed,
3483
3484```
3485%result = tosa.CEIL(%x)
3486```
3487
3488### tfl.concatenation
3489
3490Concatenation operator.
3491
3492**TensorFlow Lite Dialect**
3493
3494```
3495%output = tfl.concatenation(%values) {axis}
3496```
3497
3498**TOSA Lowering**
3499
3500```
3501%result = lower_concatv2_op(%values, axis)
3502```
3503
3504### tfl.pseudo_const
3505
3506This operator is trivially lowered to tosa.CONST
3507
3508### tfl.conv_2d
3509
3510Convolution operator.
3511
3512**TensorFlow Lite Dialect**
3513
3514```
3515%output = tfl.conv_2d(%input, %filter, %bias) {dilation_h_factor, dilation_w_factor, fused_activation_function, padding, stride_h, stride_w}
3516```
3517
3518**TOSA Lowering**
3519
3520If input/output tensors are all non-quantized typed,
3521
3522Prepare:
3523
3524```
3525tosa_padding =
3526     get_padding_values_from_pad_type(padding, NHWC, 1,
3527                                      %input.type, %filter.type,
3528                                      {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
3529```
3530
3531Legalization:
3532
3533```
3534%conv2d = tosa.CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}}
3535if(fused_activation != NONE) {
3536    %result = convert_fused_activation(%conv2d, fused_activation_function)
3537}
3538else {
3539    %result = %conv2d
3540}
3541```
3542
3543If input/output tensors are all quantized typed,
3544
3545Prepare:
3546
3547```
3548float64 output_rescale_scale = (%input.scale * %filter.scale) / %output.scale
3549
3550tosa_padding =
3551     get_padding_values_from_pad_type(padding, NHWC, 1,
3552                                      %input.type, %filter.type,
3553                                      {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
3554```
3555
3556Legalization:
3557
3558```
3559%conv2d = tosa.CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}, quantization_info={input_zp=%input.zp, weight_zp=%filter.zp}}
3560%rescale = tosa.RESCALE(%conv2d) {scale=output_rescale_scale, input_zp=0, output_zp=%output.zp} // %conv2d.dtype->%output.dtype
3561if(fused_activation != NONE) {
3562    %result = convert_fused_activation(%rescale, fused_activation_function)
3563}
3564else {
3565    %result = %rescale
3566}
3567```
3568
3569### tfl.convolution_2d_transpose_bias
3570
3571No TOSA lowering defined.
3572
3573### tfl.cos
3574
3575No TOSA lowering defined.
3576
3577### tfl.densify
3578
3579No TOSA lowering defined.
3580
3581### tfl.depth_to_space
3582
3583DepthToSpace operator.
3584
**TensorFlow Lite Dialect**
3586
3587```
3588%output = tfl.depth_to_space(%input) {block_size}
3589```
3590
3591**TOSA Lowering**
3592
3593```
3594%output = lower_depth_to_space_op(%input, block_size, "NHWC")
3595```
3596
3597### tfl.depthwise_conv_2d
3598
3599Depthwise-separable convolution operator.
3600
3601**TensorFlow Lite Dialect**
3602
3603```
3604%output = tfl.depthwise_conv_2d(%input, %filter, %bias) {dilation_h_factor, dilation_w_factor, fused_activation_function, padding, stride_h, stride_w, depth_multiplier}
3605```
3606
3607**TOSA Lowering**
3608
3609If input/output tensors are all non-quantized typed,
3610
3611Prepare:
3612
3613```
3614tosa_padding =
3615     get_padding_values_from_pad_type(padding, NHWC, 1,
3616                                      %input.type, %filter.type,
3617                                      {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
3618```
3619
3620Legalization:
3621
3622```
3623%depthwise_conv2d = tosa.DEPTHWISE_CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}}
3624if(fused_activation != NONE) {
3625    %result = convert_fused_activation(%depthwise_conv2d, fused_activation_function)
3626}
3627else {
3628    %result = %depthwise_conv2d
3629}
3630```
3631
3632If input/output tensors are all quantized typed,
3633
3634Prepare:
3635
3636```
3637float64 output_rescale_scale = (%input.scale * %filter.scale) / %output.scale
3638
3639tosa_padding =
3640     get_padding_values_from_pad_type(padding, NHWC, 1,
3641                                      %input.type, %filter.type,
3642                                      {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
3643```
3644
3645Legalization:
3646
3647```
3648%depthwise_conv2d = tosa.DEPTHWISE_CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}, quantization_info={input_zp=%input.zp, weight_zp=%filter.zp}}
%rescale = tosa.RESCALE(%depthwise_conv2d) {scale=output_rescale_scale, input_zp=0, output_zp=%output.zp} // %depthwise_conv2d.dtype->%output.dtype
3650if(fused_activation != NONE) {
3651    %result = convert_fused_activation(%rescale, fused_activation_function)
3652}
3653else {
3654    %result = %rescale
3655}
3656```
3657
3658### tfl.dequantize
3659
3660Dequantize operator.
3661
3662**TensorFlow Lite Dialect**
3663
3664```
3665%output = tfl.dequantize(%input)
3666```
3667
3668**TOSA Lowering**
3669
3670```
3671%result = lower_dequantize_op(%input, %input.scale, %input.zp)
3672```
3673
3674### tfl.div
3675
3676Division operator.
3677
3678**TensorFlow Lite Dialect**
3679
3680```
3681%output = tfl.div(%lhs, %rhs)
3682```
3683
3684**TOSA Lowering**
3685
3686If input/output tensors are all non-quantized typed,
3687
3688```
3689%rcp = tosa.RECIPROCAL(%rhs)
3690%mul = tosa.MUL(%lhs, %rcp)
3691```
3692
3693### tfl.elu
3694
3695Exponential Linear Unit operator.
3696
3697**TensorFlow Lite Dialect**
3698
3699```
3700%y = tfl.elu(%x)
3701```
3702
3703**TOSA Lowering**
3704
3705If input/output tensors are all non-quantized typed,
3706
3707```
%result = lower_elu_op(%x)
3709```
3710
3711### tfl.embedding_lookup
3712
3713Embedding lookup operator.
3714
3715**TensorFlow Lite Dialect**
3716
3717```
3718%output = tfl.embedding_lookup(%lookup, %value)
3719```
3720
3721### tfl.equal
3722
3723This operator is trivially lowered to tosa.EQUAL
3724
3725### tfl.exp
3726
3727Natural exponentiation operator.
3728
3729**TensorFlow Lite Dialect**
3730
3731```
3732%y = tfl.exp(%x)
3733```
3734
3735**TOSA Lowering**
3736
3737If input/output tensors are all non-quantized typed,
3738
3739```
3740%result = tosa.EXP(%x)
3741```
3742
3743### tfl.expand_dims
3744
3745Inserts a dimension of 1 into a tensor’s shape.
3746
3747**TensorFlow Lite Dialect**
3748
3749```
3750%output = tfl.expand_dims(%input, %dim)
3751```
3752
3753**TOSA Lowering**
3754
3755```
3756%result = lower_expand_dims(%input, %dim.as_constant())
3757```
3758
3759### tfl.external_const
3760
3761No TOSA lowering defined.
3762
3763### tfl.fake_quant
3764
FakeQuant operator.
3766
3767**TensorFlow Lite Dialect**
3768
3769```
3770%output = tfl.fake_quant(%input) {min, max, num_bits, narrow_range}
3771```
3772
3773**TOSA Lowering**
3774
3775```
3776%result = convert_fake_quant_op(%input, min, max, num_bits, narrow_range)
3777```
3778
3779### tfl.fill
3780
3781Fill the tensor with given value.
3782
3783**TensorFlow Lite Dialect**
3784
3785```
3786%res = tfl.fill(%dims, %value)
3787```
3788
3789**TOSA Lowering**
3790
3791Prepare:
3792
3793```
3794total_size = 1
dim_vec = %dims.as_constant()
for(int32 i = 0; i < dim_vec.size(); i++) {
3797    total_size *= dim_vec[i]
3798}
3799filled_val = %value.as_constant()[0]
3800output_type = tensor<dim_vec, filled_val.dtype>
3801```
3802
3803Legalization:
3804
3805```
3806%result = tosa.CONST() {value={filled_val} * total_size}
3807```
3808
3809### tfl.floor_div
3810
3811Floor div operator.
3812
3813**TensorFlow Lite Dialect**
3814
3815```
3816%output = tfl.floor_div(%lhs, %rhs)
3817```
3818
3819**TOSA Lowering**
3820
3821If input/output tensors are all non-quantized typed,
3822
3823```
3824%recip = tosa.RECIPROCAL(%rhs)
3825%mul = tosa.MUL(%lhs, %recip)
3826%result = tosa.FLOOR(%mul)
3827```
3828
3829### tfl.floor_mod
3830
3831Division remainder.
3832
3833**TensorFlow Lite Dialect**
3834
3835```
3836%output = tfl.floor_mod(%lhs, %rhs)
3837```
3838
3839**TOSA Lowering**
3840
3841If input/output tensors are all non-quantized typed,
3842
3843```
3844%recip = tosa.RECIPROCAL(%rhs)
3845%mul = tosa.MUL(%lhs, %recip)
3846%floor = tosa.FLOOR(%mul)
3847%result = tosa.SUB(%mul, %floor)
3848```
3849
3850### tfl.floor
3851
3852This operator is trivially lowered to tosa.FLOOR
3853
3854### tfl.fully_connected
3855
3856Fully connected op.
3857
3858**TensorFlow Lite Dialect**
3859
3860```
3861%output = tfl.fully_connected(%input, %filter, %bias) {fused_activation_function}
3862```
3863
3864**TOSA Lowering**
3865
3866If input/output tensors are all non-quantized typed,
3867
3868Prepare:
3869
3870```
3871// input[N, IC] x filter[OC, IC] + bias[OC] -> output[N, OC]
3872auto input_reshape_shape = {%input.num_elements / %filter.shape[1], %filter.shape[1]}
3873```
3874
3875Legalization:
3876
3877```
3878if(!(%bias)) {
    %bias_val = tosa.CONST() {value={0} * %filter.shape[0]}
3880}
3881else {
3882    %bias_val = %bias
3883}
3884if(%input.rank != 2) {
3885    %input_val = tosa.RESHAPE(%input) {shape=input_reshape_shape}
3886}
3887else {
3888    %input_val = %input
3889}
3890%fc = tosa.FULLY_CONNECTED(%input_val, %filter, %bias_val)
3891if(fused_activation != NONE) {
3892    %result = convert_fused_activation(%fc, fused_activation_function)
3893}
3894else {
3895    %result = %fc
3896}
3897```
3898
3899If input/output tensors are all quantized typed,
3900
3901Prepare:
3902
3903```
3904auto input_reshape_shape = {%input.num_elements / %filter.shape[1], %filter.shape[1]}
3905float64 output_rescale_scale = (%input.scale * %filter.scale) / %output.scale
3906```
3907
3908Legalization:
3909
3910```
3911if(!(%bias)) {
    %bias_val = tosa.CONST() {value={0} * %filter.shape[0]}
3913}
3914else {
3915    %bias_val = %bias
3916}
3917if(%input.rank != 2) {
3918    %input_val = tosa.RESHAPE(%input) {shape=input_reshape_shape}
3919}
3920else {
3921    %input_val = %input
3922}
3923%fc = tosa.FULLY_CONNECTED(%input_val, %filter, %bias_val)
3924%rescale = tosa.RESCALE(%fc) {scale=output_rescale_scale, input_zp=0, output_zp=%output.zp} // %fc.dtype->%output.dtype
3925if(fused_activation != NONE) {
3926    %result = convert_fused_activation(%rescale, fused_activation_function)
3927}
3928else {
3929    %result = %rescale
3930}
3931```
3932
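A quick illustration of the `input_reshape_shape` computation above: any input whose rank is not 2 is flattened to `[num_elements / IC, IC]` so that it lines up with the filter's `[OC, IC]` layout. The shapes below are illustrative.

```
#include <cstdio>

int main() {
  int input_shape[] = {2, 4, 3};  // rank-3 input, 24 elements
  int filter_shape[] = {8, 6};    // [OC=8, IC=6]

  int num_elements = 1;
  for (int d : input_shape) num_elements *= d;

  int ic = filter_shape[1];
  int reshape[2] = {num_elements / ic, ic};
  std::printf("reshape to [%d, %d]\n", reshape[0], reshape[1]);  // [4, 6]
  return 0;
}
```
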
3933### tfl.gather_nd
3934
3935Gather_nd operator.
3936
**TensorFlow Lite Dialect**
3938
3939```
3940%output = tfl.gather_nd(%params, %indices)
3941```
3942
3943**TOSA Lowering**
3944
3945```
3946%output = lower_gather_nd_op(%params, %indices)
3947```
3948
3949### tfl.gather
3950
3951Gather operator.
3952
**TensorFlow Lite Dialect**
3954
3955```
3956%output = tfl.gather(%params, %indices) {axis}
3957```
3958
3959**TOSA Lowering**
3960
3961```
3962%output = lower_gather_op(%params, %indices, 0, axis)
3963```
3964
3965### tfl.greater_equal
3966
3967This operator is trivially lowered to tosa.GREATER_EQUAL
3968
3969### tfl.greater
3970
3971This operator is trivially lowered to tosa.GREATER
3972
3973### tfl.hard_swish
3974
3975Hardswish activation function.
3976
3977**TensorFlow Lite Dialect**
3978
3979```
3980%output = tfl.hard_swish(%input)
3981```
3982
3983**TOSA Lowering**
3984
3985If input/output tensors are all non-quantized typed,
3986
3987```
3988%const_3 = tosa.CONST() {value={3.0}}
3989%const_rcp6 = tosa.CONST() {value={1.0 / 6.0}}
3990%op1_add_in_3 = tosa.ADD(%input, %const_3)
%op2_relun_op1 = tosa.RELUN(%op1_add_in_3) {max_fp=6.0}
%op3_mul_in_op2 = tosa.MUL(%input, %op2_relun_op1)
%op4_mul_op3_rcp6 = tosa.MUL(%op3_mul_in_op2, %const_rcp6)
3994```
3995
3996If input/output tensors are all quantized typed,
3997
3998Prepare:
3999
4000```
4001float64 input_sample_grain = 1.0 / 64.0
4002auto hardswish_func = [input_sample_grain](int32 x) -> int32 {
4003    float64 v = (float64)x * input_sample_grain
4004    float64 w = v + 3.0
4005    w = (w < 0.0) ? 0.0 : ((w > 6.0) ? 6.0 : w)
4006    v = (v * w) / 6.0
4007    return std::lround(32768.0 * v)
4008}
4009float64 input_rescale_scale = (%input.scale * 128.0) / input_sample_grain
4010float64 output_rescale_scale = 1.0 / (128.0 * 32768.0 * %output.scale)
4011int32 quantized_3 = (int32)(std::ceil(3.0 / %input.scale)) + %input.zp
4012```
4013
4014Legalization:
4015
4016```
4017%table_const = get_table_const_tensor(hardswish_func)
4018%const_3 = tosa.CONST() {value={quantized_3}}
4019%op1_rescale_in = tosa.RESCALE(%input) {scale=input_rescale_scale, input_zp=%input.zp, output_zp=0} // %input.dtype->i16
4020%op2_table_op1 = tosa.TABLE(%op1_rescale_in, %table_const)
4021%op3_rescale_op2 = tosa.RESCALE(%op2_table_op1) {scale=output_rescale_scale, input_zp=0, output_zp=%output.zp} // i32->%output.dtype
%op4_rescale_in = tosa.RESCALE(%input) {scale=1.0, input_zp=0, output_zp=0} // %input.dtype->i32
4023%op5_ge_op4 = tosa.GREATER_EQUAL(%op4_rescale_in, %const_3)
4024%op6_select_op5_in_op3 = tosa.SELECT(%op5_ge_op4, %input, %op3_rescale_op2)
4025```
4026
4027### tfl.l2_normalization
4028
4029No TOSA lowering defined.
4030
4031### tfl.lstm
4032
4033No TOSA lowering defined.
4034
4035### tfl.leaky_relu
4036
4037Leaky Relu Operator.
4038
4039**TensorFlow Lite Dialect**
4040
4041```
4042%output = tfl.leaky_relu(%input) {alpha}
4043```
4044
4045**TOSA Lowering**
4046
4047If input/output tensors are all non-quantized typed,
4048
4049Legalization:
4050
4051```
4052%const_0 = tosa.CONST() {value={0.0}}
4053%const_alpha = tosa.CONST() {value={alpha}}
4054%op1_mul_in_alpha = tosa.MUL(%input, %const_alpha)
4055%op2_ge_in_0 = tosa.GREATER_EQUAL(%input, %const_0)
%op3_select_op2_in_op1 = tosa.SELECT(%op2_ge_in_0, %input, %op1_mul_in_alpha)
4057```
4058
4059If input/output tensors are all quantized typed,
4060
4061Prepare:
4062
4063```
4064float32 scaled_alpha = (%input.scale * alpha) / %output.scale
4065float32 scaled_identity = %input.scale / %output.scale
4066```
4067
4068Legalization:
4069
4070```
4071%const_0 = tosa.CONST() {value={0}}
4072%op1_rescale_in = tosa.RESCALE(%input) {scale=1.0, input_zp=%input.zp} // %input.dtype->i32
4073%op2_ge_in_0 = tosa.GREATER_EQUAL(%input, %const_0)
%op3_rescale_in_alpha = tosa.RESCALE(%input) {scale=scaled_alpha, input_zp=%input.zp, output_zp=%output.zp} // %input.dtype->%output.dtype
%op4_rescale_in_identity = tosa.RESCALE(%input) {scale=scaled_identity, input_zp=%input.zp, output_zp=%output.zp} // %input.dtype->%output.dtype
4076%op5_select_op2_op3_op4 = tosa.SELECT(%op2_ge_in_0, %op4_rescale_in_identity, %op3_rescale_in_alpha)
4077```
4078
4079### tfl.less_equal
4080
4081Less_equal operator.
4082
4083**TensorFlow Lite Dialect**
4084
4085```
4086%output = tfl.less_equal(%lhs, %rhs)
4087```
4088
4089**TOSA Lowering**
4090
4091If input/output tensors are all non-quantized typed,
4092
4093Legalization:
4094
4095```
4096%op1_greater_lhs_rhs = tosa.GREATER(%lhs, %rhs)
4097%op2_not_op1 = tosa.LOGICAL_NOT(%op1_greater_lhs_rhs)
4098```
4099
4100If input/output tensors are all quantized typed,
4101
4102Legalization:
4103
4104```
4105assert (%lhs.scale == %rhs.scale) && (%lhs.zp == %rhs.zp)
4106
4107%op1_rescale_lhs = tosa.RESCALE(%lhs) {scale=1.0, input_zp=%lhs.zp, output_zp=0} // %lhs.dtype->i32
4108%op2_rescale_rhs = tosa.RESCALE(%rhs) {scale=1.0, input_zp=%rhs.zp, output_zp=0} // %rhs.dtype->i32
4109%op3_greater_op1_op2 = tosa.GREATER(%op1_rescale_lhs, %op2_rescale_rhs)
4110%op4_not_op3 = tosa.LOGICAL_NOT(%op3_greater_op1_op2)
4111```
4112
4113### tfl.less
4114
4115Less operator.
4116
4117**TensorFlow Lite Dialect**
4118
4119```
4120%output = tfl.less(%lhs, %rhs)
4121```
4122
4123**TOSA Lowering**
4124
4125If input/output tensors are all non-quantized typed,
4126
4127Legalization:
4128
4129```
4130%op1_ge_lhs_rhs = tosa.GREATER_EQUAL(%lhs, %rhs)
4131%op2_not_op1 = tosa.LOGICAL_NOT(%op1_ge_lhs_rhs)
4132```
4133
4134If input/output tensors are all quantized typed,
4135
4136Legalization:
4137
4138```
4139assert (%lhs.scale == %rhs.scale) && (%lhs.zp == %rhs.zp)
4140
4141%op1_rescale_lhs = tosa.RESCALE(%lhs) {scale=1.0, input_zp=%lhs.zp, output_zp=0} // %lhs.dtype->i32
4142%op2_rescale_rhs = tosa.RESCALE(%rhs) {scale=1.0, input_zp=%rhs.zp, output_zp=0} // %rhs.dtype->i32
4143%op3_ge_op1_op2 = tosa.GREATER_EQUAL(%op1_rescale_lhs, %op2_rescale_rhs)
4144%op4_not_op3 = tosa.LOGICAL_NOT(%op3_ge_op1_op2)
4145```
4146
4147### tfl.local_response_normalization
4148
4149No TOSA lowering defined.
4150
4151### tfl.log
4152
4153No TOSA lowering defined.
4154
4155### tfl.log_softmax
4156
4157Log softmax operator.
4158
4159**TensorFlow Lite Dialect**
4160
4161```
4162%output = tfl.log_softmax(%input)
4163```
4164
4165**TOSA Lowering**
4166
4167If input/output tensors are all non-quantized typed,
4168
4169Legalization:
4170
4171```
4172%output = lower_log_softmax_op(%logits)
4173```
4174
4175No TOSA lowering defined if input/output tensors are all quantized typed.
4176
4177### tfl.logical_and
4178
4179This operator is trivially lowered to tosa.LOGICAL_AND
4180
4181### tfl.logical_not
4182
4183This operator is trivially lowered to tosa.LOGICAL_NOT
4184
4185### tfl.logical_or
4186
4187This operator is trivially lowered to tosa.LOGICAL_OR
4188
4189### tfl.logistic
4190
4191Logistic operator.
4192
4193**TensorFlow Lite Dialect**
4194
4195```
4196%y = tfl.logistic(%x)
4197```
4198
4199**TOSA Lowering**
4200
4201If input/output tensors are all non-quantized typed,
4202
4203Legalization:
4204
4205```
4206%op1_sigmoid_in = tosa.SIGMOID(%x)
4207```
4208
4209If input/output tensors are all quantized typed,
4210
4211Prepare:
4212
4213```
4214float64 input_sample_grain = 1.0 / 16.0
4215auto sigmoid_func = [input_sample_grain](int32 x) -> int32 {
4216  float64 v = static_cast<float64>(x) * input_sample_grain
4217  v = 1.0 / (1.0 + std::exp(-v))
4218  return std::lround(32768.0 * v)
4219}
4220
4221float32 input_rescale_scale = (%x.scale * 128.0) / input_sample_grain
4222float32 output_rescale_scale = 1.0 / (%y.scale * 32768.0 * 128.0);
4223```
4224
4225Legalization:
4226
4227```
4228%table_const = get_table_const_tensor(sigmoid_func)
4229%op1_rescale_in = tosa.RESCALE(%x) {scale=input_rescale_scale, input_zp=%x.zp, output_zp=0} // %x.dtype->i16
4230%op2_table_op1 = tosa.TABLE(%op1_rescale_in, %table_const)
%op3_rescale_op2 = tosa.RESCALE(%op2_table_op1) {scale=output_rescale_scale, input_zp=0, output_zp=%y.zp} // i32->%y.dtype
4232```
4233
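A scalar sketch of the table construction above: the table's int16 input is interpreted on a `1/16` grid, and its output is the sigmoid value scaled by 32768. The surrounding RESCALEs, omitted here, map `%x` into that domain and the table output back to `%y.dtype`.

```
#include <cmath>
#include <cstdio>

int main() {
  const double input_sample_grain = 1.0 / 16.0;

  auto sigmoid_func = [=](int x) -> long {
    double v = double(x) * input_sample_grain;
    v = 1.0 / (1.0 + std::exp(-v));
    return std::lround(32768.0 * v);
  };

  for (int x : {-64, 0, 64})  // i.e. v = -4.0, 0.0, 4.0
    std::printf("table[%4d] = %ld\n", x, sigmoid_func(x));
  // ~589 (sigmoid(-4) * 32768), 16384 (0.5 * 32768), ~32179
  return 0;
}
```
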
4234### tfl.matrix_diag
4235
4236No TOSA lowering defined.
4237
4238### tfl.matrix_set_diag
4239
4240No TOSA lowering defined.
4241
4242### tfl.max_pool_2d
4243
4244Max Pool 2d op.
4245
4246**TensorFlow Lite Dialect**
4247
4248```
4249%output = tfl.max_pool_2d(%input) {filter_height, filter_width, padding, stride_h, stride_w, fused_activation_function}
4250```
4251
4252**TOSA Lowering**
4253
4254Prepare:
4255
4256```
4257tosa_padding =
4258     get_padding_values_from_pad_type(padding, NHWC, 1,
4259                                      %input.type, tensor<{filter_height, filter_width}, tosa.int32>,
4260                                      {1, stride_h, stride_w, 1}, {1, 1, 1, 1})
4261```
4262
4263If input/output tensors are all non-quantized typed,
4264
4265Legalization:
4266
4267```
4268%maxpool2d = tosa.MAX_POOL2D(%input) {kernel={filter_height, filter_width}, stride={stride_h, stride_w}, padding=tosa_padding}
4269if(fused_activation != NONE) {
4270    %result = convert_fused_activation(%maxpool2d, fused_activation)
4271}
4272else {
4273    %result = %maxpool2d
4274}
4275```
4276
4277If input/output tensors are all quantized typed,
4278
4279Legalization:
4280
4281```
4282%maxpool2d = tosa.MAX_POOL2D(%input) {kernel={filter_height, filter_width}, stride={stride_h, stride_w}, padding=tosa_padding, quantization_info={input_zp=%input.zp, output_zp=%output.zp}}
4283if(fused_activation != NONE) {
4284    %result = convert_fused_activation(%maxpool2d, fused_activation)
4285}
4286else {
4287    %result = %maxpool2d
4288}
4289```
4290
4291### tfl.max_pooling_with_argmax_2d
4292
4293No TOSA lowering defined.
4294
4295### tfl.max_unpooling_2d
4296
4297No TOSA lowering defined.
4298
4299### tfl.maximum
4300
4301This operator is trivially lowered to tosa.MAXIMUM
4302
4303### tfl.mean
4304
4305Mean operator.
4306
4307**TensorFlow Lite Dialect**
4308
4309```
4310%output = tfl.mean(%input, %axis) {keep_dims}
4311```
4312
4313**TOSA Lowering**
4314
4315Prepare:
4316
4317```
4318int32 num_elements_on_axis = 1
for (int32 axis : %axis.as_constant()) {
4320    num_elements_on_axis *= %input.shape[axis]
4321}
4322float32 div_scale = 1.0 / num_elements_on_axis
4323```
4324
4325If input/output tensors are all non-quantized typed,
4326
4327Legalization:
4328
4329```
4330%cst_div_scale = tosa.CONST() {value={div_scale}}
4331%op1_rsum_in = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %axis, keep_dims)
4332%op2_mul_op1 = tosa.MUL(%op1_rsum_in, %cst_div_scale)
4333```
4334
4335If input/output tensors are all quantized typed,
4336
4337Legalization:
4338
4339```
%rsum = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %axis, keep_dims, 1.0f, %input.zp, div_scale * %input.scale / %output.scale, %output.zp)
4341```
4342
4343### tfl.minimum
4344
4345This operator is trivially lowered to tosa.MINIMUM
4346
4347### tfl.mirror_pad
4348
4349No TOSA lowering defined.
4350
4351### tfl.mul
4352
4353Mul operator.
4354
4355**TensorFlow Lite Dialect**
4356
4357```
4358%output = tfl.mul(%lhs, %rhs)
4359```
4360
4361**TOSA Lowering**
4362
4363If input/output tensors are all non-quantized typed,
4364
4365Legalization:
4366
4367```
4368%op1_mul_in = tosa.MUL(%lhs, %rhs)
4369```
4370
4371If input/output tensors are all quantized typed,
4372
4373Legalization:
4374
4375```
4376%op1_rescale_lhs = tosa.RESCALE(%lhs) {scale=1.0f, input_zp=%lhs.zp, output_zp=0} // %lhs.dtype->i32
4377%op2_rescale_rhs = tosa.RESCALE(%rhs) {scale=1.0f, input_zp=%rhs.zp, output_zp=0} // %rhs.dtype->i32
4378%op3_mul_op1_op2 = tosa.MUL(%op1_rescale_lhs, %op2_rescale_rhs)
4379%op4_rescale_op3 = tosa.RESCALE(%op3_mul_op1_op2) {scale=%lhs.scale * %rhs.scale / %output.scale, input_zp=0, output_zp=%output.zp} // i32->%output.dtype
4380```
4381
4382### tfl.neg
4383
4384This operator is trivially lowered to tosa.NEGATE
4385
4386### tfl.non_max_suppression_v4
4387
4388No TOSA lowering defined.
4389
4390### tfl.non_max_suppression_v5
4391
4392No TOSA lowering defined.
4393
4394### tfl.not_equal
4395
4396Not_equal operator.
4397
4398**TensorFlow Lite Dialect**
4399
4400```
4401%output = tfl.not_equal(%lhs, %rhs)
4402```
4403
4404**TOSA Lowering**
4405
4406If input/output tensors are all non-quantized typed,
4407
4408Legalization:
4409
4410```
4411%op1_equal_lhs_rhs = tosa.EQUAL(%lhs, %rhs)
4412%op2_not_op1 = tosa.LOGICAL_NOT(%op1_equal_lhs_rhs)
4413```
4414
4415If input/output tensors are all quantized typed,
4416
4417Legalization:
4418
4419```
4420assert (%lhs.scale == %rhs.scale) && (%lhs.zp == %rhs.zp)
4421
4422%op1_rescale_lhs = tosa.RESCALE(%lhs) {scale=1.0f, input_zp=%lhs.zp, output_zp=0} // %lhs.dtype->i32
4423%op2_rescale_rhs = tosa.RESCALE(%rhs) {scale=1.0f, input_zp=%rhs.zp, output_zp=0} // %rhs.dtype->i32
4424%op3_equal_op1_op2 = tosa.EQUAL(%op1_rescale_lhs, %op2_rescale_rhs)
4425%op4_not_op3 = tosa.LOGICAL_NOT(%op3_equal_op1_op2) // i32->%output.dtype
4426```
4427
4428### tfl.NumericVerify
4429
4430No TOSA lowering defined.
4431
4432### tfl.one_hot
4433
4434OneHot operator.
4435
4436**TensorFlow Lite Dialect**
4437
4438```
4439%output = tfl.one_hot(%indices, %depth, %on_value, %off_value) {axis}
4440```
4441
4442**TOSA Lowering**
4443
4444```
4445%output = lower_one_hot_op(%indices, %depth, %on_value, %off_value, axis)
4446```
4447
4448### tfl.prelu
4449
4450No TOSA lowering defined.
4451
4452### tfl.pack
4453
4454Packs a list of tensors along a dimension into one tensor.
4455
**TensorFlow Lite Dialect**

```
%output = tfl.pack(%values) {axis}
4460```
4461
4462**TOSA Lowering**
4463
4464```
4465%output = lower_pack_op(%values, axis)
4466```
4467
4468### tfl.pad
4469
4470This operator is trivially lowered to tosa.PAD
4471
4472### tfl.padv2
4473
4474No TOSA lowering defined.
4475
4476### tfl.pow
4477
4478No TOSA lowering defined.
4479
4480### tfl.pseudo_qconst
4481
4482This operator is trivially lowered to tosa.CONST
4483
4484### tfl.quantize
4485
Quantize operator.
4487
4488**TensorFlow Lite Dialect**
4489
4490```
4491%output = tfl.quantize(%input)
4492```
4493
4494**TOSA Lowering**
4495
4496Legalization:
4497
4498```
4499if (isa<QuantizedType>(%input.dtype)) {
4500    %op1_rescale_in = tosa.RESCALE(%input) {scale=%input.scale / %output.scale, input_zp=%input.zp, output_zp=%output.zp}
4501}
4502else {
4503    %output = lower_quantize_op(%output.dtype, %input, %output.zp, %output.scale)
4504}
4505```
4506
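`lower_quantize_op()` is not expanded in this document; the affine quantization it stands for is conventionally `q = clamp(round(x / scale) + zp)`. A minimal sketch for an int8 output, with illustrative `scale` and `zp`:

```
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  double scale = 0.05;
  long zp = -10;

  for (double x : {-1.0, 0.0, 0.3}) {
    long q = std::lround(x / scale) + zp;
    q = std::max(-128L, std::min(127L, q));  // clamp to int8 range
    std::printf("%.2f -> %ld\n", x, q);      // -30, -10, -4
  }
  return 0;
}
```
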
4507### tfl.range
4508
4509No TOSA lowering defined.
4510
4511### tfl.rank
4512
Rank operator.
4514
4515**TensorFlow Lite Dialect**
4516
4517```
4518%output = tfl.rank(%input)
4519```
4520
4521**TOSA Lowering**
4522
4523Legalization:
4524
4525```
4526%const = tosa.CONST() {value={%input.rank}}
4527```
4528
4529### tfl.reduce_any
4530
4531Computes the "logical or" of elements across dimensions of a tensor.
4532
4533**TensorFlow Lite Dialect**
4534
4535```
4536%output = tfl.reduce_any(%input, %reduction_indices) {keep_dims}
4537```
4538
4539**TOSA Lowering**
4540
4541Legalization:
4542
4543```
4544%op1_rsum_in = lower_reduce_op<tosa.REDUCE_ANY>(%input, %output.shape, %reduction_indices, keep_dims)
4545```

### tfl.reduce_max

Max-reduction operator.

**TensorFlow Lite Dialect**

```
%output = tfl.reduce_max(%input, %axes) {keep_dims}
```

**TOSA Lowering**

Legalization:

```
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_MAX>(%input, %output.shape, %axes, keep_dims)
```

### tfl.reduce_min

Computes the min reduction along the specified axes.

**TensorFlow Lite Dialect**

```
%output = tfl.reduce_min(%input, %axes) {keep_dims}
```

**TOSA Lowering**

Legalization:

```
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_MIN>(%input, %output.shape, %axes, keep_dims)
```

### tfl.reduce_prod

Prod-reduction operator.

**TensorFlow Lite Dialect**

```
%output = tfl.reduce_prod(%input, %axes) {keep_dims}
```

**TOSA Lowering**

If input/output tensors are all float typed,

Legalization:

```
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_PROD>(%input, %output.shape, %axes, keep_dims)
```

### tfl.relu_n1_to_1

No TOSA lowering defined.

### tfl.relu6

Relu6 operator.

**TensorFlow Lite Dialect**

```
%y = tfl.relu6(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_relun_in = tosa.RELUN(%x) {max_int=0, max_fp=6.0}
```

If input/output tensors are all quantized typed,

Legalization:

```
%op1_rescale_in = tosa.RESCALE(%x) {scale=%x.scale / %y.scale, input_zp=%x.zp, output_zp=0} // %x.dtype->i32
%op2_relun_op1 = tosa.RELUN(%op1_rescale_in) {max_int=(6.0 / %y.scale), max_fp=0.0}
%op3_rescale_op2 = tosa.RESCALE(%op2_relun_op1) {scale=1.0, input_zp=0, output_zp=%y.zp} // i32->%y.dtype
```
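
As a worked example (values assumed for illustration): with `%x.scale = 0.1`
and `%y.scale = 6.0/255`, the first RESCALE multiplies by `0.1 / (6.0/255) =
4.25` to express the input on the output grid, the clamp ceiling becomes
`max_int = 6.0 / (6.0/255) = 255`, and the final RESCALE only re-centers the
result onto `%y.zp`.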

### tfl.relu

Relu operator.

**TensorFlow Lite Dialect**

```
%y = tfl.relu(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_relun_in = tosa.RELUN(%x) {max_int=0, max_fp=std::numeric_limits<float>::max()}
```

If input/output tensors are all quantized typed,

Legalization:

```
%op1_rescale_in = tosa.RESCALE(%x) {scale=%x.scale / %y.scale, input_zp=%x.zp, output_zp=0} // %x.dtype->i32
%op2_relun_op1 = tosa.RELUN(%op1_rescale_in) {max_int=std::numeric_limits<int32>::max(), max_fp=0.0}
%op3_rescale_op2 = tosa.RESCALE(%op2_relun_op1) {scale=1.0, input_zp=0, output_zp=%y.zp} // i32->%y.dtype
```

### tfl.reshape

This operator is trivially lowered to tosa.RESHAPE

### tfl.resize_bilinear

ResizeBilinear Op.

**TensorFlow Lite Dialect**

```
%output = tfl.resize_bilinear(%input, %size) {aligned_corners, half_pixel_centers}
```

**TOSA Lowering**

```
%output = lower_resize_op(%input, %size, %input.dtype, "BILINEAR")
```

### tfl.resize_nearest_neighbor

ResizeNearestNeighbor Op.

**TensorFlow Lite Dialect**

```
%output = tfl.resize_nearest_neighbor(%input, %size) {aligned_corners, half_pixel_centers}
```

**TOSA Lowering**

```
%output = lower_resize_op(%input, %size, %input.dtype, "NEAREST_NEIGHBOR")
```

### tfl.reverse_sequence

No TOSA lowering defined.

### tfl.reverse_v2

ReverseV2 Operator.

**TensorFlow Lite Dialect**

```
%output = tfl.reverse_v2(%input, %axis)
```

**TOSA Lowering**

```
%output = lower_reversev2_op(%input, %axis)
```

### tfl.round

Round operator.

**TensorFlow Lite Dialect**

```
%output = tfl.round(%input)
```

**TOSA Lowering**

```
%const_half = tosa.CONST() {value={0.5}}
%op1_add_in_half = tosa.ADD(%input, %const_half)
%op2_floor_op1 = tosa.FLOOR(%op1_add_in_half)
```
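
Note that the `FLOOR(x + 0.5)` decomposition implements round-half-up:
halfway values always round toward positive infinity, e.g.
`floor(-1.5 + 0.5) = -1`, where round-half-to-even would produce `-2`. The
two conventions agree everywhere except at exact .5 inputs.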

### tfl.rsqrt

No TOSA lowering defined.

### tfl.svdf

No TOSA lowering defined.

### tfl.segment_sum

No TOSA lowering defined.

### tfl.select

This operator is trivially lowered to tosa.SELECT

### tfl.select_v2

This operator is trivially lowered to tosa.SELECT

### tfl.shape

Shape operator

**TensorFlow Lite Dialect**

```
%output = tfl.shape(%input)
```

**TOSA Lowering**

Legalization:

```
%const = tosa.CONST() {value=%input.shape}
```

### tfl.sin

No TOSA lowering defined.

### tfl.slice

This operator is trivially lowered to tosa.SLICE

### tfl.softmax

Softmax operator.

**TensorFlow Lite Dialect**

```
%output = tfl.softmax(%input)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_exp_in = tosa.EXP(%input)
%op2_rsum_op1 = tosa.REDUCE_SUM(%op1_exp_in) {axis=(%input.rank-1)}
%op3_rcp_op2 = tosa.RECIPROCAL(%op2_rsum_op1)
%op4_mul_op1_op3 = tosa.MUL(%op1_exp_in, %op3_rcp_op2)
```

If input/output tensors are all quantized typed,

Prepare:

```
float64 exp_sample_grain = 1.0 / 16.0
auto exp_func = [exp_sample_grain](int32 x) -> int32 {
  float64 v = static_cast<float64>(x) * exp_sample_grain
  // Inputs are non-positive after max subtraction, so saturate exp() to 1.0 for v >= 0.
  v = v < 0.0 ? std::exp(v) : 1.0
  return std::lround(32768.0 * v)
}

float64 one_over_one_plus_x_sample_grain = 1.0 / 256.0
auto one_over_one_plus_x_func = [one_over_one_plus_x_sample_grain](int32 x) -> int32 {
  float64 v = static_cast<float64>(x) * one_over_one_plus_x_sample_grain
  v = v < 0.0 ? 1.0 : 1.0 / (1.0 + v)
  return std::lround(32768.0 * v)
}

float64 op4_rescale_scale = (%input.scale * 128.0) / exp_sample_grain
float64 op19_rescale_scale = 1.0 / (%output.scale * 256.0)
```
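
`get_table_const_tensor()` materializes the constant lookup table consumed by
tosa.TABLE. Below is a minimal sketch in the same pseudocode, assuming the
513-entry signed 16-bit table layout of the TOSA TABLE operator; the slot-unit
sampling shown here is an assumption about the helper, not taken from this
document.

```
// Hypothetical sketch of get_table_const_tensor(): sample `func` once per
// table slot and clamp into int16.
get_table_const_tensor(func) {
    int16 table[513]
    for (int32 i = 0; i <= 512; i++) {
        // Slot i covers int16 input (i - 256) * 128; func is evaluated in
        // slot units, so one slot step corresponds to one sample grain.
        int32 v = func(i - 256)
        table[i] = std::min(std::max(v, -32768), 32767)
    }
    return tosa.CONST() {value=table}
}
```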

Legalization:

```
%const_exp_table = get_table_const_tensor(exp_func)
%const_one_over_one_plus_x_table = get_table_const_tensor(one_over_one_plus_x_func)
%const_3 = tosa.CONST() {value={3}}
%const_34 = tosa.CONST() {value={12+20-8}}
%const_2_to_31 = tosa.CONST() {value={1<<31}}
%const_16 = tosa.CONST() {value={16}}

%op1_rescale_in = tosa.RESCALE(%input) {scale=1.0f, input_zp=%input.zp, output_zp=0} // %input.dtype->i32
%op2_rmax_op1 = tosa.REDUCE_MAX(%op1_rescale_in) {axis=(%input.rank-1)}
%op3_sub_op1_op2 = tosa.SUB(%op1_rescale_in, %op2_rmax_op1)
%op4_rescale_op3 = tosa.RESCALE(%op3_sub_op1_op2) {scale=op4_rescale_scale, input_zp=0, output_zp=0} // i32->i16
%op5_table_op4 = tosa.TABLE(%op4_rescale_op3, %const_exp_table)
%op6_rshift_op5_3 = tosa.ARITHMETIC_RIGHT_SHIFT(%op5_table_op4, %const_3)
%op7_rsum_op6 = tosa.REDUCE_SUM(%op6_rshift_op5_3) {axis=(%input.rank-1)}
%op8_clz_op7 = tosa.CLZ(%op7_rsum_op6)
%op9_sub_34_op8 = tosa.SUB(%const_34, %op8_clz_op7)
%op10_lshift_op7_op8 = tosa.LOGICAL_LEFT_SHIFT(%op7_rsum_op6, %op8_clz_op7)
%op11_sub_op10 = tosa.SUB(%op10_lshift_op7_op8, %const_2_to_31)
%op12_rshift_op11_16 = tosa.ARITHMETIC_RIGHT_SHIFT(%op11_sub_op10, %const_16)
%op13_cast_op12 = tosa.CAST(%op12_rshift_op11_16) // i32->i16
%op14_table_op13 = tosa.TABLE(%op13_cast_op12, %const_one_over_one_plus_x_table)
%op15_rescale_op14 = tosa.RESCALE(%op14_table_op13) {scale=1.0/128.0, input_zp=0, output_zp=0} // i32->i16
%op16_rescale_op5 = tosa.RESCALE(%op5_table_op4) {scale=1.0/128.0, input_zp=0, output_zp=0} // i32->i16
%op17_mul_op16_op15 = tosa.MUL(%op15_rescale_op14, %op16_rescale_op5)
%op18_rshift_op17_op9 = tosa.ARITHMETIC_RIGHT_SHIFT(%op17_mul_op16_op15, %op9_sub_34_op8)
%op19_rescale_op18 = tosa.RESCALE(%op18_rshift_op17_op9) {scale=op19_rescale_scale, input_zp=0, output_zp=%output.zp}
```
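
In outline: ops 1-3 subtract the per-axis maximum so every value entering the
exponential table is non-positive; ops 4-5 map the difference onto the table
grid and evaluate exp() through the 16-bit table; ops 6-12 sum the
exponentials, then use CLZ and a left shift to normalize the sum into the
fixed-point range the reciprocal table expects; ops 13-15 evaluate 1/(1+x)
through the second table; and ops 16-19 multiply each exponential by the
reciprocal of the sum and rescale the result onto the output quantization
grid.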

### tfl.space_to_batch_nd

SpaceToBatchNd operator.

**TensorFlow Lite Dialect**

```
%output = tfl.space_to_batch_nd(%input, %block_shape, %paddings)
```

**TOSA Lowering**

```
%output = lower_space_to_batch_nd_op(%input, %block_shape, %paddings)
```

### tfl.space_to_depth

SpaceToDepth operator.

**TensorFlow Lite Dialect**

```
%output = tfl.space_to_depth(%input) {block_size}
```

**TOSA Lowering**

```
%output = lower_space_to_depth_op(%input, block_size, "NHWC")
```

### tfl.pseudo_sparse_const

No TOSA lowering defined.

### tfl.pseudo_sparse_qconst

No TOSA lowering defined.

### tfl.sparse_to_dense

No TOSA lowering defined.

### tfl.split

Splits a tensor into num_split tensors along one dimension.

**TensorFlow Lite Dialect**

```
%output = tfl.split(%split_dim, %value) {num_split}
```

**TOSA Lowering**

```
%output = lower_split_op(%value, %split_dim.as_constant(), num_split)
```

### tfl.split_v

Splits a tensor into num_split tensors along one dimension.

**TensorFlow Lite Dialect**

```
%output = tfl.split_v(%value, %size_splits, %split_dim) {num_splits}
```

**TOSA Lowering**

```
%output = lower_splitv_op(%value, %size_splits.as_constant(), %split_dim.as_constant())
```

### tfl.sqrt

No TOSA lowering defined.

### tfl.square

Square operator.

**TensorFlow Lite Dialect**

```
%y = tfl.square(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_mul_in = tosa.MUL(%x, %x)
```

If input/output tensors are all quantized typed,

Legalization:

```
%op1_rescale_x = tosa.RESCALE(%x) {scale=1.0f, input_zp=%x.zp, output_zp=0} // %x.dtype->i32
%op2_mul_op1_op1 = tosa.MUL(%op1_rescale_x, %op1_rescale_x)
%op3_rescale_op2 = tosa.RESCALE(%op2_mul_op1_op1) {scale=(%x.scale * %x.scale) / %y.scale, input_zp=0, output_zp=%y.zp} // i32->%y.dtype
```

### tfl.squared_difference

Squared difference operator.

**TensorFlow Lite Dialect**

```
%output = tfl.squared_difference(%lhs, %rhs)
```

**TOSA Lowering**

Legalization:

```
%op1_sub_in = tosa.SUB(%lhs, %rhs)
%op2_mul_op1 = tosa.MUL(%op1_sub_in, %op1_sub_in)
```

### tfl.squeeze

Removes dimensions of size 1 from the shape of a tensor.

**TensorFlow Lite Dialect**

```
%output = tfl.squeeze(%input) {squeeze_dims}
```

**TOSA Lowering**

```
%output = lower_squeeze_op(%input, squeeze_dims)
```

### tfl.strided_slice

StridedSlice Op.

**TensorFlow Lite Dialect**

```
%output = tfl.strided_slice(%input, %begin, %end, %strides) {begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask}
```

**TOSA Lowering**

```
%output = lower_strided_slice_op(%input, %begin, %end, %strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
```

### tfl.sub

This operator is trivially lowered to tosa.SUB

### tfl.sum

Sum operator.

**TensorFlow Lite Dialect**

```
%output = tfl.sum(%input, %axis) {keep_dims}
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %axis, keep_dims)
```

If input/output tensors are all quantized typed,

Legalization:

```
%rsum = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %axis, keep_dims, 1.0f, %input.zp, (%input.scale / %output.scale), %output.zp)
```
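
Reading off the trailing arguments: they direct `lower_reduce_op` to wrap the
reduction in RESCALE ops, bringing the input to int32 with scale 1.0 and its
zero point removed, summing, and then rescaling the result by
`%input.scale / %output.scale` onto `%output.zp`, mirroring the standalone
quantized rescale patterns used elsewhere in this document.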

### tfl.tanh

Hyperbolic tangent operator.

**TensorFlow Lite Dialect**

```
%y = tfl.tanh(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_tanh_in = tosa.TANH(%x)
```

If input/output tensors are all quantized typed,

Prepare:

```
float64 input_sample_grain = 1.0 / 32.0
auto tanh_func = [input_sample_grain](int32 x) -> int32 {
  float64 v = static_cast<float64>(x) * input_sample_grain
  // tanh(v) = (1 - e^(-2v)) / (1 + e^(-2v))
  v = std::exp(-2.0 * v)
  v = (1.0 - v) / (1.0 + v)
  return std::lround(32768.0 * v)
}

float64 input_rescale_scale = (%x.scale * 128.0) / input_sample_grain
float64 output_rescale_scale = 1.0 / (%y.scale * 32768.0 * 128.0)
```

Legalization:

```
%table_const = get_table_const_tensor(tanh_func)
%op1_rescale_in = tosa.RESCALE(%x) {scale=input_rescale_scale, input_zp=%x.zp, output_zp=0} // %x.dtype->i16
%op2_table_op1 = tosa.TABLE(%op1_rescale_in, %table_const)
%op3_rescale_op2 = tosa.RESCALE(%op2_table_op1) {scale=output_rescale_scale, input_zp=0, output_zp=%y.zp} // i32->%y.dtype
```
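
The scale factors can be read off the table format: the input RESCALE maps one
step of the 1/32 sampling grid onto one table slot, hence
`%x.scale * 128.0 / input_sample_grain`, and the table entries are
`lround(32768.0 * tanh(v))`, so the table output carries a factor of 32768
plus, assuming the 7 fractional interpolation bits of the 16-bit TABLE
operator, a further factor of 128; dividing by `%y.scale * 32768.0 * 128.0`
therefore lands the result on the output quantization grid.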

### tfl.tile

This operator is trivially lowered to tosa.TILE

### tfl.topk_v2

No TOSA lowering defined.

### tfl.transpose_conv

Transpose convolution operator.

**TensorFlow Lite Dialect**

```
%output = tfl.transpose_conv(%output_shape, %weights, %input) {padding, stride_h, stride_w}
```

**TOSA Lowering**

Prepare:

```
tosa_padding =
    get_transpose_conv2d_padding_values_from_pad_type(%input.type, %weights.type, %output_shape, padding, "NHWC", FORMAT_HWIO, {stride_h, stride_w}, {1, 1})
```
5129
5130If input/output tensors are all non-quantized typed,
5131
5132Legalization:
5133
5134```
5135%bias = tosa.CONST() {value={0.0} * %output.shape[3]}
5136%conv2d = tosa.TRANSPOSE_CONV2D(%input, %weight, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={1, 1}}
5137```
5138
5139If input/output tensors are all quantized typed,
5140
5141Prepare:
5142
5143```
5144float64 output_rescale_scale = (%input.scale * %weights.scale) / %output.scale
5145```
5146
5147Legalization:
5148
5149```
5150%bias = tosa.CONST() {value={0} * %output.shape[3]}
5151%conv2d = tosa.TRANSPOSE_CONV2D(%input, %weight, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={1, 1}}
5152%rescale = tosa.RESCALE(%conv2d) {scale=output_rescale_scale, input_zp=0, output_zp=%output.zp} // %conv2d.dtype->%output.dtype
5153```
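
The rescale factor follows the same product-of-scales argument as tfl.mul: the
int32 accumulator of TRANSPOSE_CONV2D carries an effective scale of
`%input.scale * %weights.scale`, so dividing by `%output.scale` maps it onto
the output grid.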

### tfl.transpose

This operator is trivially lowered to tosa.TRANSPOSE

### tfl.unidirectional_sequence_lstm

No TOSA lowering defined.

### tfl.unidirectional_sequence_rnn

No TOSA lowering defined.

### tfl.unique

No TOSA lowering defined.

### tfl.unpack

Unpacks a tensor along a dimension into multiple tensors.

**TensorFlow Lite Dialect**

```
%output = tfl.unpack(%input) {num, axis}
```

**TOSA Lowering**

```
%output = lower_unpack_op(%input, axis, num)
```

### tfl.where

No TOSA lowering defined.

### tfl.while

No TOSA lowering defined.

### tfl.yield

This operator is trivially lowered to tosa.YIELD

### tfl.zeros_like

ZerosLike operator.

**TensorFlow Lite Dialect**

```
%output = tfl.zeros_like(%input)
```

**TOSA Lowering**

```
%output = tosa.CONST() {value={0} * %input.num_elements}
```

## fuse_tf_bias

Legalize (tf.Conv2D + tf.BiasAdd) to tosa.CONV2D. This is currently the only
N:1 mapping in TOSA legalization.

From:

```
%conv2d = tf.Conv2D(%input, %filter) {...}
%bias_add = tf.BiasAdd(%conv2d, %bias)
```

To:

```
%conv2d = tosa.CONV2D(%input, %filter, %bias)
```

## convert_tfl_uint8

This pass does three things:

1.  Convert constants from quantized uint8 to quantized int8, remapping the
    stored values as well.
2.  If an input placeholder is quantized uint8 typed, insert "tosa.RESCALE()
    {scale=1.0, input_zp=input_zp, output_zp=input_zp-128} // qu8->qi8" in
    between.
3.  If an output tensor is quantized uint8 typed, insert "tosa.RESCALE()
    {scale=1.0, input_zp=output_zp-128, output_zp=output_zp} // qi8->qu8" in
    between, as sketched below.

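For illustration, a hypothetical graph with a quantized uint8 input
placeholder `%in` and a quantized uint8 output `%out` (the value names here
are invented) would be bracketed as:

```
// Inserted at the input: shift the uint8 placeholder onto the int8 grid.
%in_qi8 = tosa.RESCALE(%in) {scale=1.0, input_zp=%in.zp, output_zp=%in.zp - 128} // qu8->qi8
// ... the converted int8 graph consumes %in_qi8 and produces %result_qi8 ...
// Inserted at the output: shift the int8 result back onto the uint8 grid.
%out = tosa.RESCALE(%result_qi8) {scale=1.0, input_zp=%out.zp - 128, output_zp=%out.zp} // qi8->qu8
```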