# TOSA Lowerings

## Introduction

### Overview

This document provides pseudo-code lowerings from TensorFlow and TensorFlow Lite
MLIR Dialects (https://www.tensorflow.org/mlir/dialects) to the TOSA Dialect
(https://mlir.llvm.org/docs/Dialects/TOSA/).

The documentation is a work-in-progress: sections with missing legalizations are
in the process of being written.

## Syntax

The pseudo-code syntax used in this document is described below.

### Primitive Datatypes

int8: signed 8-bit integer
uint8: unsigned 8-bit integer
int16: signed 16-bit integer
int32: signed 32-bit integer
int64: signed 64-bit integer
uint32: unsigned 32-bit integer
float32: IEEE-754 32-bit floating point format
float64: IEEE-754 64-bit floating point format
bool: boolean

### Value

In pseudo-code, a symbol starting with "%" indicates a value. A value is
evaluated by an operator at run time, and an operator can consume, and can only
consume, a list of values as operands. Note that a value's tensor type is
determined at compile time; only the evaluation happens at run time. One can
easily construct a data-flow subgraph by looking at the producers and consumers
of values.

### Tensor Type

Tensor type is an attribute determined by legalization at compile time,
describing the shape and element data type. It is noted as tensor<shape,
dtype>, or shorthanded as tensor<%t.type>.

### Operator Prototype

In pseudocode, a TOSA operator is prototyped in the following format.

```
%<output_value> = tosa.<OPERATOR>(%<input_value>) {<attribute> = ...}
```

### Value Attributes

For the purposes of brevity and clarity in this document, the pseudocode allows
the following notation on value attributes.

Shorthand         | Description
----------------- | ---------------------------------------------------
`%t.shape`        | Shape vector for the tensor
`%t.shape[i]`     | Size of dimension i for the tensor
`%t.rank`         | Rank of the tensor
`%t.dtype`        | Datatype of the tensor
`%t.scale`        | Quantized scaling parameter (float64)
`%t.zp`           | Quantized zero-point (int64)
`%t.signed`       | Boolean indicating the type is signed
`%t.num_bits`     | Number of bits in the datatype
`%t.num_elements` | Number of elements in the tensor
`%t.type`         | Tuple of `tensor<%t.shape, %t.dtype>`
`%t.size`         | For tensor lists: the number of tensors in the list

### Tensor Dimension Shorthand

Where the TOSA Specification allows the use of named dimensions, the following
names may be used.

Name | Description
---- | --------------------
`N`  | Batch dimension
`H`  | Height dimension
`W`  | Width dimension
`C`  | Channel dimension
`M`  | Depthwise multiplier

Each of these may be prefixed with `I` for the input dimension, `O` for the
output dimension, or `K` for kernel dimensions.

## Common Legalization Functions

The following pseudocode helper functions are used to canonicalize arguments
from different frameworks to the TOSA dialect.

### .as_constant(): Matched as Constant

Wherever %tensor.as_constant() is specified, a constant vector will be created
to hold the value in the %tensor at compile time. This only succeeds if %tensor
is fed by a constant-type operator. If constant matching fails, the lowering
will fail and be terminated.
97 98### get_padding_values_from_explicit_pad_attr() 99 100``` 101vector<int64> get_padding_values_from_explict_pad_attr(vector<int64> explicit_pad, 102 tensorflow::TensorFormat data_format_tf) 103{ 104 int64 pad_before, pad_after 105 vector<int64> computed_paddings 106 107 for (int32 i = 0; i < 2; i++) { 108 int64 dim = GetTensorSpatialDimIndex(4, data_format_tf, i) 109 pad_before = explicit_pad[dim * 2] 110 pad_after = explicit_pad[dim * 2 + 1] 111 computed_paddings.push_back(pad_before) 112 computed_paddings.push_back(pad_after) 113 } 114 115 return computed_paddings 116} 117``` 118 119### get_padding_values_from_pad_type() 120 121Calculate explicit padding array based on pad type 122 123``` 124vector<int64> get_padding_values_from_pad_type(tensorflow::Padding padding, tensorflow::TensorFormat data_format, 125 uint32 first_filter_spatial_dim, type input_type, type filter_type 126 vector strides, vector dilations) 127{ 128 assert(padding != tensorflow::Padding::EXPLICIT); 129 130 vector<int64> computed_padding; 131 132 // Padding over H and W dimensions 133 for (int32 i = 0; i < 2; i++) { 134 int32 ifm_dim = get_tensor_spatial_dim_index(4, data_format, i); 135 136 int32 filter_dim = first_filter_spatial_dim + i; 137 138 int32 dim_dilation = dilations[ifm_dim]; 139 int32 dim_stride = strides[ifm_dim]; 140 141 int64 op_size, pad_before_tf, pad_after_tf; 142 143 tensorflow::GetWindowedOutputSizeVerboseV2(input_type.shape[ifm_dim], filter_type.shape[filter_dim], 144 dim_dilation, dim_stride, padding, 145 // Outputs 146 &op_size, &pad_before_tf, &pad_after_tf); 147 computed_paddings.push_back(pad_before_tf); 148 computed_paddings.push_back(pad_after_tf); 149 } 150 151 return computed_paddings; 152} 153``` 154 155### positive_axis() 156 157``` 158// Cannonicalize scalar axis attributes to a scalar positive axis attribute 159int32 positive_axis(int32 axis, int32 rank) 160{ 161 if (axis < 0) 162 axis += rank; 163 164 return axis; 165} 166``` 167 168### compute_scale_32() 169 170``` 171void compute_scale_32(float64 scale, int32& multiplier, int32& shift) 172{ 173 /* Generates mantissa and shift values where mantissa is in [-1.0,-0.5] or 174 [0.5, 1.0] such that 175 multiplier = mantissa*2^shift */ 176 177 const float64 mantissa = std::frexp(scale, &shift); 178 auto shifted_m = std::round(mantissa * (int64(1) << 31)); 179 180 assert(shifted_m <= (int64(1) << 31)); // can't be greater that 1.0 181 if (shifted_m == (int64(1) << 31)) { 182 shifted_m /= 2; 183 shift++; 184 } 185 // TOSA expect right shift to be positive, and embed (1 << 31) into right 186 // shift bits 187 shift = (-shift) + 31; 188 189 assert(shifted_m <= std::numeric_limits<int32>::max()); 190 191 multiplier = static_cast<int32>(shifted_m); 192 193} 194``` 195 196### lower_batch_to_space_nd_op() 197 198``` 199Value lower_batch_to_space_nd_op(Value %input, Value %block_shape, Value %crops, shape_t output_shape) 200{ 201 202 vector <size_t> block_shape(%block_shape.rank) 203 vector std::pair<size_t, size_t> crops_arr 204 205 size_t remaining_shape_rank = %input.rank - %block.rank - 1 206 size_t crops_dim = %crops.shape[0] 207 208 for (int32 i = 0; i < crops_dim; i++) { 209 crops[i] = std::make_pair(%crops.as_constant()[i * crops_dim + 0], 210 %crops.as_constant()[i * crops_dim + 1]) 211 } 212 213 // Step 1: Reshape input to 214 // [block_shape[0], 215 // ... 216 // [block_shape[M-1], 217 // [batch / prod(block_shape)] 218 // [input_shape[1], 219 // ... 
### lower_batch_to_space_nd_op()

```
Value lower_batch_to_space_nd_op(Value %input, Value %block_shape, Value %crops, shape_t output_shape)
{
    vector<std::pair<size_t, size_t>> crops_arr

    size_t remaining_shape_rank = %input.rank - %block_shape.rank - 1
    size_t crops_dims = %crops.shape[0]

    for (int32 i = 0; i < crops_dims; i++) {
        crops_arr[i] = std::make_pair(%crops.as_constant()[i * 2 + 0],
                                      %crops.as_constant()[i * 2 + 1])
    }

    // Step 1: Reshape input to
    // [ block_shape[0],
    //   ...
    //   block_shape[M-1],
    //   batch / prod(block_shape),
    //   input_shape[1],
    //   ...
    //   input_shape[N-1] ]

    vector<size_t> a1_shape(%block_shape.rank + %input.rank)

    for (int32 i = 0; i < %block_shape.rank; i++) {
        a1_shape[i] = %block_shape.shape[i]
    }

    a1_shape[%block_shape.rank] = %input.shape[0] / %block_shape.num_elements

    for (int32 i = 1; i < %input.rank; i++) {
        a1_shape[i + %block_shape.rank] = %input.shape[i]
    }

    // Step 2. Permute to shape:
    // [ batch / prod(block_shape),
    //   input_shape[1], block_shape[0],
    //   ...
    //   input_shape[M], block_shape[M-1],
    //   remaining input shapes input_shape[M+1 .. N-1] ]
    vector<size_t> a2_perm(%block_shape.rank + %input.rank)

    a2_perm[0] = %block_shape.rank
    for (int32 i = 0; i < %block_shape.rank; i++) {
        a2_perm[1 + i * 2 + 0] = %block_shape.rank + 1 + i
        a2_perm[1 + i * 2 + 1] = i
    }

    for (int32 i = 0; i < remaining_shape_rank; i++) {
        a2_perm[1 + %block_shape.rank * 2 + i] = 1 + %block_shape.rank * 2 + i
    }

    // Step 3. Reshape to
    // [ batch / prod(block_shape),
    //   input_shape[1] * block_shape[0],
    //   ...
    //   input_shape[M] * block_shape[M-1],
    //   remaining input shapes input_shape[M+1 .. N-1] ]
    vector<size_t> a3_shape(%input.rank)

    a3_shape[0] = %input.shape[0] / %block_shape.num_elements
    for (int32 i = 0; i < %block_shape.rank; i++) {
        a3_shape[i + 1] = %input.shape[i + 1] * %block_shape.shape[i]
    }

    for (int32 i = 0; i < remaining_shape_rank; i++) {
        a3_shape[1 + %block_shape.rank + i] = %input.shape[%block_shape.rank + 1 + i]
    }

    // Step 4. Crop the start/end dimensions using slice
    vector<size_t> a4_begin(%input.rank), a4_size(%input.rank)

    for (int32 i = 0; i < %input.rank; i++) {
        if (i == 0 || i > crops_dims) {
            a4_begin[i] = 0
            a4_size[i] = output_shape[i]
        } else {
            a4_begin[i] = crops_arr[i - 1].first
            a4_size[i]  = a3_shape[i] - crops_arr[i - 1].first - crops_arr[i - 1].second
        }
    }

    %a1_reshape   = tosa.RESHAPE(%input) {new_shape=a1_shape}
    %a2_transpose = tosa.TRANSPOSE(%a1_reshape) {perms=a2_perm}
    %a3_reshape   = tosa.RESHAPE(%a2_transpose) {new_shape=a3_shape}
    %output       = tosa.SLICE(%a3_reshape) {begin=a4_begin, size=a4_size}

    return %output
}
```
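To make the reshape/permute steps in lower_batch_to_space_nd_op() above
concrete, here is a small, self-contained C++ sketch (illustrative sample
shapes, not part of the lowering) that computes the step-1 and step-3
intermediate shapes for an NHWC input:

```
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
    const std::vector<int> input = {8, 2, 2, 1}; // [batch, H, W, C]
    const std::vector<int> block = {2, 2};       // block_shape, M = 2
    const int M = block.size();
    const int prod =
        std::accumulate(block.begin(), block.end(), 1, std::multiplies<int>());

    // Step 1: [block..., batch / prod(block), input[1:]...]
    std::vector<int> a1(block);
    a1.push_back(input[0] / prod);
    a1.insert(a1.end(), input.begin() + 1, input.end());

    // Step 2's permutation interleaves each spatial dim with its block dim;
    // Step 3 then merges each pair:
    // [batch / prod(block), input[1]*block[0], ..., remaining dims]
    std::vector<int> a3 = {input[0] / prod};
    for (int i = 0; i < M; i++) a3.push_back(input[i + 1] * block[i]);
    for (size_t i = M + 1; i < input.size(); i++) a3.push_back(input[i]);

    for (int d : a1) std::printf("%d ", d);
    std::printf(" <- a1_shape\n");
    for (int d : a3) std::printf("%d ", d);
    std::printf(" <- a3_shape\n");
}
```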
### lower_concatv2_op()

```
Value lower_concatv2_op(Type output_type, Value %values, int32 axis)
{
    int32 tosa_axis = positive_axis(axis, %values:0.rank)

    assert(%values.size >= 2)

    // Convert scalar inputs to tensors
    if (%values:0.rank == 0) {
        for (int32 i = 0; i < %values.size; i++) {
            %values:i = tosa.RESHAPE(%values:i) {new_shape=1}
        }
    }

    for (int32 i = 0; i < %values.size; i++) {
        %val = %values:i
        if (%val.zp != output_type.zp || %val.scale != output_type.scale) {
            float64 rescale_scale = %val.scale / output_type.scale
            %values:i = tosa.RESCALE(%val) {scale=rescale_scale, input_zp=%val.zp, output_zp=output_type.zp}
        }
    }

    %concat_op = tosa.CONCAT(%values:0, %values:1) {axis=tosa_axis}

    for (int32 i = 2; i < %values.size; i++) {
        %concat_op = tosa.CONCAT(%concat_op, %values:i) {axis=tosa_axis}
    }

    return %concat_op
}
```

### lower_depth_to_space_op()

```
Value lower_depth_to_space_op(Value %input, size_t block_size[], Format_t data_format)
{
    assert(data_format == 'NHWC')

    vector<size_t> a2_shape = {%input.shape[0],
                               %input.shape[1],
                               %input.shape[2],
                               block_size[0],
                               block_size[1],
                               %input.shape[3] / (block_size[0] * block_size[1])}

    vector<size_t> a4_shape = {%input.shape[0],
                               %input.shape[1] * block_size[0],
                               %input.shape[2] * block_size[1],
                               %input.shape[3] / (block_size[0] * block_size[1])}

    %a2_reshape   = tosa.RESHAPE(%input) {new_shape=a2_shape}
    %a3_transpose = tosa.TRANSPOSE(%a2_reshape) {perms={0, 1, 3, 2, 4, 5}}
    %output       = tosa.RESHAPE(%a3_transpose) {new_shape=a4_shape}

    return %output
}
```

### lower_elu_op()

```
Value lower_elu_op(Value %value)
{
    // elu(x) = x < 0 ? (exp(x) - 1) : x
    // Create constants for 0/1 and reshape them to match the rank
    // of %value
    %one_const = tosa.CONST() {value={1}}
    %zero_const = tosa.CONST() {value={0}}

    vector<size_t> bcast_shape
    for (int32 i = 0; i < %value.rank; i++) {
        bcast_shape.push_back(1)
    }

    %one_reshape  = tosa.RESHAPE(%one_const) {new_shape=bcast_shape}
    %zero_reshape = tosa.RESHAPE(%zero_const) {new_shape=bcast_shape}

    %exp_in = tosa.EXP(%value)
    %sub    = tosa.SUB(%exp_in, %one_reshape)
    %ge     = tosa.GREATER_EQUAL(%value, %zero_reshape)
    %output = tosa.SELECT(%ge, %value, %sub)
    return %output
}
```

### lower_expand_dims()

```
Value lower_expand_dims(Value %input, int32 axis)
{
    vector<size_t> reshape_dims

    if (axis < 0 || axis >= %input.rank) {
        // Insert the new dimension at the end of the tensor
        axis += %input.rank
        for (int32 i = 0; i < %input.rank; i++) {
            reshape_dims.push_back(%input.shape[i])
        }
        reshape_dims.push_back(1)
    } else {
        for (int32 i = 0; i < %input.rank; i++) {
            if (i == axis) {
                reshape_dims.push_back(1)
            }
            reshape_dims.push_back(%input.shape[i])
        }
    }

    %output = tosa.RESHAPE(%input) {new_shape=reshape_dims}
    return %output
}
```

### lower_fake_quant_op()

```
Value lower_fake_quant_op(Value %inputs, type output_type, float64 min, float64 max,
                          int64 num_bits, bool narrow_range)
{
    assert(num_bits == 8 || num_bits == 16)

    int64 qmax = (1L << (num_bits - 1)) - 1
    int64 qmin = -(1L << (num_bits - 1))

    if (narrow_range) {
        qmin = qmin + 1
    }

    float64 scale = (max - min) / float64(qmax - qmin)

    int64 zeropoint = (int64)std::round((-min) / scale + float64(qmin))

    %quantized = lower_quantize_op(%inputs.type, %inputs, 1.0 / scale, zeropoint)

    %dequantized = lower_dequantize_op(output_type, %quantized, scale, zeropoint)

    return %dequantized
}
```

### lower_floor_div()

```
Value lower_floor_div(Value %lhs, Value %rhs)
{
    %recip  = tosa.RECIPROCAL(%rhs)
    %mul    = tosa.MUL(%lhs, %recip)
    %output = tosa.FLOOR(%mul)

    return %output
}
```

### lower_floor_mod()

```
Value lower_floor_mod(Value %lhs, Value %rhs)
{
    %recip  = tosa.RECIPROCAL(%rhs)
    %mul    = tosa.MUL(%lhs, %recip)
    %floor  = tosa.FLOOR(%mul)
    %output = tosa.SUB(%mul, %floor)
    return %output
}
```

### lower_quantize_op()

```
Value lower_quantize_op(Type output_type, Value %input, float64 scale, int64 zeropoint)
{
    %const_scale = tosa.CONST() {value={scale}}
    %const_zp = tosa.CONST() {value={zeropoint}}
    %op1_mul_in_scale = tosa.MUL(%input, %const_scale)
    %op2_add_op1_zp = tosa.ADD(%op1_mul_in_scale, %const_zp)
    %op3_cast_op2 = tosa.CAST(%op2_add_op1_zp) // f32 -> %output.dtype

    return %op3_cast_op2
}
```

### lower_dequantize_op()

```
Value lower_dequantize_op(Type output_type, Value %input, float64 scale, int64 zeropoint)
{
    %const_scale = tosa.CONST() {value={scale}}
    %const_zp = tosa.CONST() {value={(float64)zeropoint}}
    %op1_cast_in = tosa.CAST(%input) // %input.dtype -> f32
    %op2_sub_op1_zp = tosa.SUB(%op1_cast_in, %const_zp)
    %op3_mul_op2_scale = tosa.MUL(%op2_sub_op1_zp, %const_scale)

    return %op3_mul_op2_scale
}
```
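To make the fake-quant arithmetic above concrete, here is a small,
self-contained C++ sketch (illustrative values; the explicit round and clamp
stand in for the rounding behavior of tosa.CAST) that derives scale and
zeropoint the way lower_fake_quant_op() does and runs one value through the
quantize-then-dequantize round trip:

```
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    // Parameters as derived in lower_fake_quant_op() for num_bits = 8,
    // narrow_range = false, min = -1.0, max = 1.0.
    const int num_bits = 8;
    const int64_t qmax = (1LL << (num_bits - 1)) - 1; //  127
    const int64_t qmin = -(1LL << (num_bits - 1));    // -128
    const double min = -1.0, max = 1.0;

    const double scale = (max - min) / static_cast<double>(qmax - qmin);
    const int64_t zeropoint =
        static_cast<int64_t>(std::round((-min) / scale + static_cast<double>(qmin)));

    // lower_quantize_op(): q = round(x * (1/scale)) + zp, cast/clamp to int8.
    const double x = 0.3;
    int64_t q = static_cast<int64_t>(std::round(x * (1.0 / scale))) + zeropoint;
    q = std::clamp(q, qmin, qmax);

    // lower_dequantize_op(): y = (q - zp) * scale.
    const double y = (static_cast<double>(q) - zeropoint) * scale;

    std::printf("scale=%f zp=%lld q=%lld y=%f\n",
                scale, (long long)zeropoint, (long long)q, y);
}
```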
### lower_log_softmax_op()

```
Value lower_log_softmax_op(Value %logits)
{
    %op1 = tosa.EXP(%logits)
    %op2 = tosa.REDUCE_SUM(%op1) {axis=(%logits.rank-1)}
    %op3 = tosa.RECIPROCAL(%op2)
    %op4 = tosa.MUL(%op1, %op3)
    %op5 = tosa.LOG(%op4)

    return %op5
}
```

### lower_pack_op()

```
Value lower_pack_op(Value %input[], size_t axis)
{
    size_t concat_axis = positive_axis(axis, %input[0].rank)

    size_t input_tensor_rank = %input[0].rank

    // Convert any rank-0 inputs to rank 1 with a reshape
    if (input_tensor_rank == 0) {
        for (int32 i = 0; i < %input.size; i++) {
            %input[i] = tosa.RESHAPE(%input[i]) {new_shape={1}}
        }
    }

    vector<size_t> output_shape
    for (int32 i = 0; i < input_tensor_rank; i++) {
        output_shape.push_back(%input[0].shape[i])
    }

    output_shape[concat_axis] = output_shape[concat_axis] * %input.size

    // First pair of tensors
    %concat = tosa.CONCAT(%input[0], %input[1]) {axis=concat_axis}

    // Remaining tensors
    for (int32 i = 2; i < %input.size; i++) {
        %concat = tosa.CONCAT(%concat, %input[i]) {axis=concat_axis}
    }

    if (input_tensor_rank == 0) {
        // No reshape needed for rank 0, already done
        %output = %concat
    } else {
        %reshape = tosa.RESHAPE(%concat) {new_shape=output_shape}

        if (concat_axis == input_tensor_rank) {
            // Output shape is [A, B, C, .. N] in this case,
            // need to reshape to [N, A, B, C, ..] with perm [1, 2, 3, .. 0]
            concat_axis = 0

            vector<size_t> perms
            for (int32 i = 0; i < %input[0].rank; i++)
                perms.push_back(i + 1)
            perms.push_back(0)

            %output = tosa.TRANSPOSE(%reshape) {perms=perms}
        } else {
            %output = %reshape
        }
    }

    return %output
}
```

### lower_reduce_op()

```
Value lower_reduce_op<tosa_op_t OP>(Value %input, shape_t output_shape, Value %axes, bool keep_dims,
                                    float64 input_scale=1.0f, int32 input_zp=0,
                                    float64 output_scale=1.0f, int32 output_zp=0)
{
    vector axes_vec = %axes.as_constant();

    // Special case of no axes means no transformation
    if (axes_vec.size() == 0) {
        return tosa.IDENTITY(%input)
    }

    bool is_quantized = isa<QuantizedType>(%input.dtype) ? true : false

    shape_t shape = %input.shape;
    %output = %input;

    if (is_quantized) {
        %output = tosa.RESCALE(%output) {scale=input_scale, input_zp=input_zp, output_zp=0}
    }

    for (int32 i = 0; i < axes_vec.size(); i++) {
        int32 axis = positive_axis(axes_vec[i], %input.rank);

        shape[axis] = 1;
        %output = tosa.OP(%output) {axis=axis}
    }

    if (!keep_dims) {
        %output = tosa.RESHAPE(%output) {new_shape=output_shape}
    }

    if (is_quantized) {
        %output = tosa.RESCALE(%output) {scale=output_scale, input_zp=0, output_zp=output_zp}
    }

    return %output;
}
```
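lower_reduce_op() relies on the fact that a reduction over multiple axes can be
expressed as a chain of single-axis reductions. A minimal C++ sketch of that
equivalence for a sum over both axes of a 2x3 tensor (illustrative only):

```
#include <cstdio>

int main() {
    const float in[2][3] = {{1, 2, 3}, {4, 5, 6}};

    // Reduce axis 1 first: [2][3] -> [2][1]
    float r1[2] = {0, 0};
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 3; j++) r1[i] += in[i][j];

    // Then reduce axis 0: [2][1] -> [1][1]
    float r0 = 0;
    for (int i = 0; i < 2; i++) r0 += r1[i];

    // Same result as reducing both axes at once: 1+2+...+6 = 21.
    std::printf("chained reduction = %f\n", r0);
}
```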
### lower_resize_op()

```
Value lower_resize_op(Value %images, Value %size, shape output_shape, dtype output_dtype, mode_t mode)
{
    int32 input_height  = %images.shape[1]
    int32 input_width   = %images.shape[2]
    int32 output_height = output_shape[1]
    int32 output_width  = output_shape[2]

    float64 in_center_h  = static_cast<float64>(input_height - 1) / 2.0
    float64 in_center_w  = static_cast<float64>(input_width - 1) / 2.0
    float64 out_center_h = static_cast<float64>(output_height - 1) / 2.0
    float64 out_center_w = static_cast<float64>(output_width - 1) / 2.0

    float64 fp_stride_y, fp_stride_x
    if (align_corners && output_height > 1)
        fp_stride_y = static_cast<float64>(input_height - 1) / static_cast<float64>(output_height - 1)
    else
        fp_stride_y = static_cast<float64>(input_height) / static_cast<float64>(output_height)
    if (align_corners && output_width > 1)
        fp_stride_x = static_cast<float64>(input_width - 1) / static_cast<float64>(output_width - 1)
    else
        fp_stride_x = static_cast<float64>(input_width) / static_cast<float64>(output_width)

    float64 fp_offset_y = 0.0, fp_offset_x = 0.0
    if (half_pixel_centers) {
        fp_offset_y = fp_stride_y * 0.5f - 0.5f
        fp_offset_x = fp_stride_x * 0.5f - 0.5f
    }

    if (output_dtype == float32) {
        %output = tosa.RESIZE(%images) {stride={fp_stride_y, fp_stride_x}, offset={fp_offset_y, fp_offset_x}, shift=0, resize_mode=mode}
    } else {
        int32 shift = 10
        float64 unit = static_cast<float64>(1 << shift)
        int32 stride_y = fp_stride_y * unit
        int32 stride_x = fp_stride_x * unit
        int32 offset_y = fp_offset_y * unit
        int32 offset_x = fp_offset_x * unit

        %op1_resize_in = tosa.RESIZE(%images) {stride={stride_y, stride_x}, offset={offset_y, offset_x}, shift=shift, resize_mode=mode}

        if (mode == "BILINEAR") {
            %const_zero = tosa.CONST() {value={0}}
            %const_twenty = tosa.CONST() {value={20}}
            %op2_ge_op1 = tosa.GREATER_EQUAL(%op1_resize_in, %const_zero)
            %op3_abs_op1 = tosa.ABS(%op1_resize_in)
            %op4_rshift_op3 = tosa.ARITHMETIC_RIGHT_SHIFT(%op3_abs_op1, %const_twenty)
            %op5_negate_op4 = tosa.NEGATE(%op4_rshift_op3)
            %op6_select_op2_op4_op5 = tosa.SELECT(%op2_ge_op1, %op4_rshift_op3, %op5_negate_op4)
            %output = tosa.CAST(%op6_select_op2_op4_op5) // i32/i48 -> %output.dtype
        } else {
            %output = %op1_resize_in
        }
    }

    return %output
}
```

### lower_reversev2_op()

```
Value lower_reverse_v2_op(Value %tensor, Value %axis)
{
    Value %output = %tensor

    if (%axis.num_elements == 0) {
        %output = tosa.IDENTITY(%tensor)
    } else {
        for (int32 i = 0; i < %axis.shape[0]; i++) {
            size_t axis_val = positive_axis(%axis.as_constant()[i], %tensor.rank)
            %output = tosa.REVERSE(%output) {axis=axis_val}
        }
    }

    return %output
}
```

### lower_round_op()

```
Value lower_round_op(Value %x)
{
    %half   = tosa.CONST() {value={0.5}}
    %add    = tosa.ADD(%x, %half)
    %output = tosa.FLOOR(%add)

    return %output
}
```
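The integer path of lower_resize_op() above encodes the fractional stride and
offset in Q.10 fixed point (shift = 10). A small self-contained C++ sketch of
that encoding for a 2x upscale with half_pixel_centers (illustrative values,
not part of the lowering):

```
#include <cstdint>
#include <cstdio>

int main() {
    const int in_h = 8, out_h = 16; // 2x upscale, align_corners = false
    const double fp_stride_y = double(in_h) / double(out_h);  // 0.5
    const double fp_offset_y = fp_stride_y * 0.5 - 0.5;       // -0.25 (half_pixel_centers)

    const int shift = 10; // Q.10 fixed point, unit = 1024
    const double unit = double(1 << shift);
    const int32_t stride_y = int32_t(fp_stride_y * unit); //  512
    const int32_t offset_y = int32_t(fp_offset_y * unit); // -256

    // The RESIZE op then evaluates source y-coordinates as
    // (oy * stride_y + offset_y) >> shift.
    std::printf("stride_y=%d offset_y=%d\n", stride_y, offset_y);
}
```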
### lower_selectv2_op()

```
Value lower_selectv2_op(Value %condition, Value %t, Value %e, shape output_shape)
{
    // Reshape condition so that ranks match to support
    // broadcasting (if necessary)

    if (%condition.rank != output_shape.size) {
        vector<size_t> cond_shape = %condition.shape
        for (int32 i = 0; i < (output_shape.size - %condition.rank); i++) {
            cond_shape.push_front(1)
        }

        %condition = tosa.RESHAPE(%condition) {new_shape=cond_shape}
    }

    %output = tosa.SELECT(%condition, %t, %e)

    return %output
}
```

### lower_shape_op()

```
Value lower_shape_op(Value %input)
{
    vector<size_t> input_shape = %input.shape

    %shape = tosa.CONST() {value={input_shape}}
    return %shape
}
```

### lower_space_to_batch_nd_op()

```
Value lower_space_to_batch_nd_op(Value %input, Value %block_shape, Value %paddings)
{
    size_t block_rank = %block_shape.shape[0]
    size_t remaining_shape_rank = %input.rank - block_rank - 1
    size_t block_num_elems = 1

    // Step 1. Pad based on the paddings operand (flattened representation of a
    // [input.rank][2]-shaped array)
    vector<size_t> a1_padding
    a1_padding[0] = 0
    a1_padding[1] = 0

    for (int32 i = 0; i < %paddings.shape[0]; i++) {
        a1_padding[i + 2] = %paddings.as_constant()[i]
    }

    %a1_pad = tosa.PAD(%input) {padding=a1_padding}

    // Step 2. Reshape to
    // [ batch,
    //   padded_shape[1] / block_shape[0], block_shape[0],
    //   ...
    //   padded_shape[M] / block_shape[M-1], block_shape[M-1],
    //   remaining_shape ]

    vector<size_t> a2_shape(1 + block_rank * 2 + remaining_shape_rank)
    a2_shape[0] = %input.shape[0]
    for (int32 i = 0; i < block_rank; i++) {
        a2_shape[1 + i * 2 + 0] = %a1_pad.shape[1 + i] / %block_shape.as_constant()[i]
        a2_shape[1 + i * 2 + 1] = %block_shape.as_constant()[i]
    }

    for (int32 i = 0; i < remaining_shape_rank; i++) {
        a2_shape[1 + block_rank * 2 + i] = %input.shape[1 + block_rank + i]
    }

    %a2_reshape = tosa.RESHAPE(%a1_pad) {new_shape=a2_shape}

    // Step 3. Transpose to
    // block_shape +
    // [ batch ] +
    // [ padded_shape[1] / block_shape[0],
    //   ...
    //   padded_shape[M] / block_shape[M-1] ] +
    // remaining_shape
    vector<size_t> a3_perm(%a2_reshape.rank)

    for (int32 i = 0; i < block_rank; i++) {
        a3_perm[i] = 1 + 2 * i + 1
        a3_perm[block_rank + 1 + i] = 2 * i + 1
        block_num_elems *= %block_shape.as_constant()[i]
    }

    a3_perm[block_rank] = 0
    for (int32 i = (1 + block_rank * 2); i < %a2_reshape.rank; i++) {
        a3_perm[i] = i
    }

    %a3_transpose = tosa.TRANSPOSE(%a2_reshape) {perms=a3_perm}

    // Step 4. Reshape the transposed tensor to
    // [ batch * prod(block_shape) ] +
    // [ padded_shape[1] / block_shape[0],
    //   ...,
    //   padded_shape[M] / block_shape[M-1] ] +
    // remaining_shape

    vector<size_t> a4_shape(%input.rank)
    a4_shape[0] = %input.shape[0] * block_num_elems

    for (int32 i = 0; i < block_rank; i++) {
        a4_shape[i + 1] = %a1_pad.shape[i + 1] / %block_shape.as_constant()[i]
    }

    for (int32 i = 0; i < remaining_shape_rank; i++) {
        a4_shape[1 + block_rank + i] = %input.shape[1 + block_rank + i]
    }

    %output = tosa.RESHAPE(%a3_transpose) {new_shape=a4_shape}

    return %output
}
```
### lower_space_to_depth_op()

```
Value lower_space_to_depth_op(Value %input, size_t block_size[], Format_t data_format)
{
    assert(data_format == 'NHWC')

    vector<size_t> a2_shape = {%input.shape[0],
                               %input.shape[1] / block_size[0],
                               block_size[0],
                               %input.shape[2] / block_size[1],
                               block_size[1],
                               %input.shape[3]}
    %a2_reshape   = tosa.RESHAPE(%input) {new_shape=a2_shape}
    %a3_transpose = tosa.TRANSPOSE(%a2_reshape) {perms={0, 1, 3, 2, 4, 5}}

    vector<size_t> a4_shape = {%input.shape[0],
                               %input.shape[1] / block_size[0],
                               %input.shape[2] / block_size[1],
                               %input.shape[3] * block_size[0] * block_size[1]}
    %output = tosa.RESHAPE(%a3_transpose) {new_shape=a4_shape}
    return %output
}
```

### lower_split_op()

```
Value lower_split_op(Value %value, size_t axis, size_t num_split)
{
    Value %output[]

    size_t slice_size = %value.shape[axis] / num_split

    for (int32 i = 0; i < num_split; i++) {
        vector<size_t> begin_vals, size_vals

        for (int32 j = 0; j < %value.rank; j++) {
            if (j == axis) {
                begin_vals.push_back(slice_size * i)
                size_vals.push_back(slice_size)
            } else {
                begin_vals.push_back(0)
                size_vals.push_back(%value.shape[j])
            }
        }

        %output[i] = tosa.SLICE(%value) {start=begin_vals, size=size_vals}
    }

    %output_list = tosa.IDENTITYN(%output)
    return %output_list
}
```

### lower_splitv_op()

```
Value lower_splitv_op(Value %value, vector<size_t> size_split, size_t axis)
{
    Value %output[]

    size_t curr_split_start = 0

    for (int32 i = 0; i < size_split.size(); i++) {
        vector<size_t> begin_vals, size_vals

        for (int32 j = 0; j < %value.rank; j++) {
            if (j == axis) {
                begin_vals.push_back(curr_split_start)
                size_vals.push_back(size_split[i])
            } else {
                begin_vals.push_back(0)
                size_vals.push_back(%value.shape[j])
            }
        }

        %output[i] = tosa.SLICE(%value) {start=begin_vals, size=size_vals}

        curr_split_start += size_split[i]
    }

    %output_list = tosa.IDENTITYN(%output)
    return %output_list
}
```

### lower_squeeze_op()

```
Value lower_squeeze_op(Value %input, vector<size_t> squeeze_dims)
{
    vector<size_t> reshape_dims

    if (squeeze_dims.size() == 0) {
        // Remove all dimensions of size 1
        for (int32 i = 0; i < %input.rank; i++) {
            if (%input.shape[i] != 1) {
                reshape_dims.push_back(%input.shape[i])
            }
        }
    } else {
        // Remove only the specified dimensions
        for (int32 i = 0; i < %input.rank; i++) {
            if (!squeeze_dims.contains(i)) {
                reshape_dims.push_back(%input.shape[i])
            }
        }
    }

    %output = tosa.RESHAPE(%input) {new_shape=reshape_dims}

    return %output
}
```
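The split lowerings above turn each output into an independent tosa.SLICE. A
short C++ sketch (illustrative sample shapes) of the begin/size attribute
computation in lower_splitv_op() for size_split = {2, 3, 5} on axis 1 of a
[4, 10] tensor:

```
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> shape = {4, 10};       // %value.shape
    const std::vector<int> size_split = {2, 3, 5};
    const size_t axis = 1;

    int curr_split_start = 0;
    for (size_t i = 0; i < size_split.size(); i++) {
        std::vector<int> begin_vals, size_vals;
        for (size_t j = 0; j < shape.size(); j++) {
            begin_vals.push_back(j == axis ? curr_split_start : 0);
            size_vals.push_back(j == axis ? size_split[i] : shape[j]);
        }
        std::printf("slice %zu: start={%d,%d} size={%d,%d}\n", i,
                    begin_vals[0], begin_vals[1], size_vals[0], size_vals[1]);
        curr_split_start += size_split[i];
    }
}
```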
### lower_strided_slice_op()

```
Value lower_strided_slice_op(Value %input, Value %begin_val, Value %end_val, Value %strides_val,
                             size_t begin_mask, size_t end_mask, size_t ellipsis_mask,
                             size_t new_axis_mask, size_t shrink_axis_mask)
{
    // Note: does not implement ellipsis_mask or reverse stride at this time
    assert(ellipsis_mask == 0)

    vector<size_t> begin(%begin_val.as_constant()), end(%end_val.as_constant()), strides(%strides_val.as_constant())
    vector<size_t> a1_start, a1_size, a2_shape, a3_start, a3_size, a4_shape

    for (int32 i = 0; i < %input.rank; i++) {
        if (begin_mask & (1 << i)) {
            begin[i] = 0
        }

        if (end_mask & (1 << i)) {
            end[i] = %input.shape[i]
        }

        // Wrap around indices if begin and end are negative
        if (begin[i] < 0) {
            begin[i] += %input.shape[i]
        }

        if (end[i] < 0) {
            end[i] += %input.shape[i]
        }

        a1_start[i] = begin[i]
        a1_size[i] = end[i] - begin[i]

        a2_shape[i*2 + 0] = a1_size[i] / strides[i]
        a2_shape[i*2 + 1] = strides[i]

        a3_start[i*2 + 0] = 0
        a3_start[i*2 + 1] = 0

        if (shrink_axis_mask & (1 << i)) {
            a3_size[i*2 + 0] = 1
        } else {
            a3_size[i*2 + 0] = a1_size[i] / strides[i]
        }
        a3_size[i*2 + 1] = 1

        if (new_axis_mask & (1 << i)) {
            a4_shape.push_back(1)
        }
        if (!(shrink_axis_mask & (1 << i))) {
            a4_shape.push_back(a1_size[i] / strides[i])
        }
    }

    // Step 1: Slice the input array
    %a1_slice = tosa.SLICE(%input) {start=a1_start, size=a1_size}

    // Step 2: Reshape the sliced array: 2x as many dimensions as %input
    %a2_reshape = tosa.RESHAPE(%a1_slice) {new_shape=a2_shape}

    // Step 3: Take a slice of the [0] index along each of the strided dimensions (even dimensions)
    %a3_slice = tosa.SLICE(%a2_reshape) {start=a3_start, size=a3_size}

    // Step 4: Reshape the now-strided tensor back down to the desired number of dimensions
    %output = tosa.RESHAPE(%a3_slice) {new_shape=a4_shape}

    return %output
}
```

### lower_unpack_op()

```
Value lower_unpack_op(Value %value, size_t axis, uint64_t num)
{
    axis = positive_axis(axis, %value.rank)

    Value %output_arr[]

    // Step 1: transpose 'axis' to the left-most dimension, if necessary
    Value %transposed_value

    if (axis != 0) {
        vector<size_t> perms

        perms.push_back(axis)
        for (int32 i = 0; i < %value.rank; i++) {
            if (i != axis)
                perms.push_back(i)
        }

        %transposed_value = tosa.TRANSPOSE(%value) {perms=perms}

    } else {
        %transposed_value = %value
    }

    // Step 2: Slice [N, A, B, C] into N tensors of [A, B, C]
    for (int32 i = 0; i < %transposed_value.shape[0]; i++) {
        vector<size_t> begin_vals, size_vals, shape_vals

        begin_vals.push_back(i)
        size_vals.push_back(1)

        for (int32 j = 1; j < %transposed_value.rank; j++) {
            begin_vals.push_back(0)
            size_vals.push_back(%transposed_value.shape[j])
            shape_vals.push_back(%transposed_value.shape[j])
        }

        %slice = tosa.SLICE(%transposed_value) {begin=begin_vals, size=size_vals}
        %output_arr[i] = tosa.RESHAPE(%slice) {new_shape=shape_vals}
    }

    // Combine the array of sliced tensors into a list of tensors
    %output = tosa.IDENTITYN(%output_arr)
    return %output
}
```
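The reshape-then-slice trick in steps 2-3 of lower_strided_slice_op() is worth
seeing with numbers. A minimal C++ sketch (illustrative) that strides a 1-D
range [0..11] by 3 using exactly that decomposition:

```
#include <cstdio>
#include <vector>

int main() {
    // %input: 12 elements, begin = 0, end = 12, stride = 3.
    std::vector<int> input(12);
    for (int i = 0; i < 12; i++) input[i] = i;

    const int size = 12, stride = 3;

    // Step 2: view the slice as shape [size/stride, stride] = [4, 3].
    // Step 3: keep only column 0 of each row (a3_start = {0, 0}, a3_size = {4, 1}).
    std::vector<int> output;
    for (int row = 0; row < size / stride; row++)
        output.push_back(input[row * stride + 0]);

    for (int v : output) std::printf("%d ", v); // prints: 0 3 6 9
    std::printf("\n");
}
```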
### get_transpose_conv2d_padding_values_from_pad_type()

```
vector<int64> get_transpose_conv2d_padding_values_from_pad_type(tensorflow::Padding padding, tensorflow::TensorFormat data_format,
                                                                uint32 first_filter_spatial_dim, type input_type, type filter_type,
                                                                vector strides, vector dilations)
{
    int64 pad_before, pad_after
    vector<int64> computed_padding

    for (int32 i = 0; i < 2; i++) {
        int64 ifm_dim = GetTensorSpatialDimIndex(4, data_format, i)
        int64 ofm_dim = GetTensorSpatialDimIndex(4, data_format, i)
        int64 filter_dim = first_filter_spatial_dim + i

        int64 ifm_size = input_type.shape[ifm_dim]
        int64 ofm_size = output_dims[ofm_dim]
        int64 filter_size = filter_type.shape[filter_dim]
        int64 dim_dilation = dilations[i]
        int64 dim_stride = strides[i]
        int32 effective_filter_size = (filter_size - 1) * dim_dilation + 1
        int32 total_padding = ((ifm_size - 1) * dim_stride + effective_filter_size - ofm_size)
        total_padding = total_padding > 0 ? total_padding : 0

        pad_before = total_padding / 2
        pad_after = total_padding - pad_before

        computed_padding.push_back(pad_before)
        computed_padding.push_back(pad_after)
    }

    return computed_padding
}
```

### lower_fused_activation()

```
Value lower_fused_activation(Value %input, string activation)
{
    bool is_quantized = isa<QuantizedType>(%input.dtype) ? true : false

    if (is_quantized) {
        if (activation == "NONE") {
            return %input
        }
        else if (activation == "RELU") {
            int32 quantized_0 = %input.zp
            int32 quantized_max = %input.storage_max
            return tosa.CLAMP(%input) {min_int=quantized_0, max_int=quantized_max}
        }
        else if (activation == "RELU6") {
            int32 quantized_0 = %input.zp
            int32 quantized_6 = %input.zp + (6.0 / %input.scale)
            return tosa.CLAMP(%input) {min_int=quantized_0, max_int=quantized_6}
        }
        else if (activation == "RELU_N1_TO_1") {
            int32 quantized_n1 = %input.zp + (-1.0 / %input.scale)
            int32 quantized_1 = %input.zp + (1.0 / %input.scale)
            return tosa.CLAMP(%input) {min_int=quantized_n1, max_int=quantized_1}
        }
    }
    else {
        if (activation == "NONE") {
            return %input
        }
        else if (activation == "RELU") {
            return tosa.RELUN(%input) {max_fp=numeric_limits<float32>::max()}
        }
        else if (activation == "RELU6") {
            return tosa.RELUN(%input) {max_fp=6.0}
        }
        else if (activation == "RELU_N1_TO_1") {
            return tosa.CLAMP(%input) {min_fp=-1.0, max_fp=1.0}
        }
        else if (activation == "TANH") {
            return tosa.TANH(%input)
        }
    }
}
```

### get_table_const_tensor()

```
Value get_table_const_tensor(function func)
{
    array<int16, 513> table_array
    for (int32 i = -256; i <= 256; i++) {
        table_array[i + 256] = func(i)
    }

    return tosa.CONST() {value=table_array}
}
```
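In the quantized branch of lower_fused_activation() above, the float activation
bounds are mapped into the integer domain using the input's scale and zero
point. A small C++ sketch (illustrative quantization parameters) for RELU6 on
an int8 tensor with scale = 0.05 and zp = -10:

```
#include <cmath>
#include <cstdio>

int main() {
    const double scale = 0.05; // %input.scale
    const int zp = -10;        // %input.zp

    // RELU6 clamps the real-valued activation to [0.0, 6.0]; in the quantized
    // domain each bound becomes zp + bound / scale.
    const int quantized_0 = zp;                                // 0.0 -> -10
    const int quantized_6 = zp + (int)std::round(6.0 / scale); // 6.0 -> 110

    std::printf("CLAMP min_int=%d max_int=%d\n", quantized_0, quantized_6);
}
```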
### lower_gather_op()

```
Value lower_gather_op(Value %params, Value %indices, int32 batch_dims, int32 axis)
{
    assert(batch_dims <= %indices.rank)
    assert(axis >= batch_dims)

    int32 N = W = K = C = 1

    for (int32 i = 0; i < batch_dims; i++) N *= %params.shape[i]
    for (int32 i = batch_dims; i < %indices.rank; i++) W *= %indices.shape[i]
    K = %params.shape[axis]
    for (int32 i = batch_dims; i < axis; i++) C *= %params.shape[i]
    for (int32 i = (axis + 1); i < %params.rank; i++) C *= %params.shape[i]

    vector<int32> params_idx_batch, params_idx_left, params_idx_indices, params_idx_right
    for (int32 i = 0; i < %params.rank; i++) {
        if (i < batch_dims && i < axis)
            params_idx_batch.push_back(i)
        else if (i < axis)
            params_idx_left.push_back(i)
        else if (i < (axis + 1))
            params_idx_indices.push_back(i)
        else
            params_idx_right.push_back(i)
    }

    vector<int32> params_perm = {params_idx_batch, params_idx_left, params_idx_indices, params_idx_right}
    vector<int32> result_perm
    for (int32 i = 0; i < batch_dims; i++)
        result_perm.push_back(i)
    for (int32 i = 0; i < params_idx_left.size(); i++)
        result_perm.push_back(params_idx_left[i])
    for (int32 i = batch_dims; i < %indices.rank; i++)
        result_perm.push_back(i)
    for (int32 i = 0; i < params_idx_right.size(); i++)
        result_perm.push_back(params_idx_right[i])

    %const_params_perm = tosa.CONST() {value=params_perm}
    %const_result_perm = tosa.CONST() {value=result_perm}

    %op1_transpose_params = tosa.TRANSPOSE(%params, %const_params_perm)
    %op2_reshape_op1 = tosa.RESHAPE(%op1_transpose_params) {shape={N,K,C}}
    %op3_reshape_indices = tosa.RESHAPE(%indices) {shape={N,W}}
    %op4_gather_op2_op3 = tosa.GATHER(%op2_reshape_op1, %op3_reshape_indices)
    %op5_reshape_op4 = tosa.RESHAPE(%op4_gather_op2_op3) {shape={N,W,C}}
    %op6_transpose_op5 = tosa.TRANSPOSE(%op5_reshape_op4, %const_result_perm)

    return %op6_transpose_op5
}
```

### lower_gather_nd_op()

```
Value lower_gather_nd_op(Value %params, Value %indices)
{
    int32 N = W = K = C = ND = 1

    ND = %indices.shape[%indices.rank - 1]

    assert(ND < %params.rank)

    for (int32 i = 0; i < (%indices.rank - 1); i++) W *= %indices.shape[i]
    for (int32 i = 0; i < ND; i++) K *= %params.shape[i]
    for (int32 i = ND; i < %params.rank; i++) C *= %params.shape[i]

    // Flatten the ND-dimensional index into a single row-major index over the
    // first ND dimensions of %params:
    // flatten_coeff_vec[i] = prod(%params.shape[i+1 .. ND-1]), and
    // flatten_coeff_vec[ND-1] = 1
    vector<int32> flatten_coeff_vec
    for (int32 i = 0; i < ND; i++) flatten_coeff_vec.push_back(prod(%params.shape[i+1 .. ND-1]))

    %const_flatten_coeff = tosa.CONST() {value=flatten_coeff_vec}
    %op1_reshape_params = tosa.RESHAPE(%params) {shape={N,K,C}}
    %op2_reshape_indices = tosa.RESHAPE(%indices) {shape={W,ND}}
    %op3_mul_op2_flatten_coeff = tosa.MUL(%op2_reshape_indices, %const_flatten_coeff)
    %op4_rsum_op3 = tosa.REDUCE_SUM(%op3_mul_op2_flatten_coeff) {axis=1}
    %op5_reshape_op4 = tosa.RESHAPE(%op4_rsum_op3) {shape={N,W}}
    %op6_gather_op1_op5 = tosa.GATHER(%op1_reshape_params, %op5_reshape_op4)
    %op7_reshape_op6 = tosa.RESHAPE(%op6_gather_op1_op5) {shape={N,W,C}}

    return %op7_reshape_op6
}
```
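The index flattening in lower_gather_nd_op() is ordinary row-major
linearization, performed on-graph by the MUL and REDUCE_SUM pair. A short C++
sketch (illustrative shapes) for ND = 2 indices into a params tensor whose
first two dimensions are [4, 5]:

```
#include <cstdio>

int main() {
    const int dims[2] = {4, 5};  // first ND dims of %params, K = 4*5 = 20

    // flatten_coeff[i] = product of dims[i+1 .. ND-1]: here {5, 1}.
    const int coeff[2] = {dims[1], 1};

    const int idx[2] = {2, 3};   // one ND-dimensional index row from %indices

    // tosa.MUL + tosa.REDUCE_SUM compute this dot product per index row.
    const int flat = idx[0] * coeff[0] + idx[1] * coeff[1]; // 2*5 + 3 = 13

    std::printf("flattened index = %d of K = %d\n", flat, dims[0] * dims[1]);
}
```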
### lower_one_hot_op()

```
Value lower_one_hot_op(Value %indices, Value %depth, Value %on_value, Value %off_value, int32 axis)
{
    int32 N = W = C = 1
    int32 K = %depth.as_constant()
    int32 left_dim = right_dim = 1
    for (int32 i = 0; i < %indices.rank; i++) {
        int32 dim = %indices.shape[i]
        N *= dim
        if (i >= axis)
            right_dim *= dim
        else
            left_dim *= dim
    }

    %perm_const = tosa.CONST() {value={0, 2, 1}}
    %op1_reshape_on_value = tosa.RESHAPE(%on_value) {shape={1, 1, 1}}
    %op2_tile_op1 = tosa.TILE(%op1_reshape_on_value) {multiples={N, W, C}}
    %op3_reshape_off_value = tosa.RESHAPE(%off_value) {shape={1, 1, 1}}
    %op4_tile_op3 = tosa.TILE(%op3_reshape_off_value) {multiples={N, K, C}}
    %op5_reshape_indices = tosa.RESHAPE(%indices) {shape={N, W}}
    %op6_scatter_op4_op5_op2 = tosa.SCATTER(%op4_tile_op3, %op5_reshape_indices, %op2_tile_op1)
    %op7_reshape_op6 = tosa.RESHAPE(%op6_scatter_op4_op5_op2) {shape={left_dim, right_dim, K}}
    %op8_transpose_op7 = tosa.TRANSPOSE(%op7_reshape_op6, %perm_const)
    %op9_reshape_op8 = tosa.RESHAPE(%op8_transpose_op7) {shape=%output.shape}

    return %op9_reshape_op8
}
```

## MLIR Passes Management

Legalization is built on multiple MLIR passes.

MLIR Pass Name    | Input Dialect   | Output Dialect  | Description
----------------- | --------------- | --------------- | -------------------------------------------------
legalize_tf       | TensorFlow      | TOSA            | Legalize TensorFlow dialect to TOSA dialect
fuse_tf_bias      | TensorFlow      | TOSA            | Map tf.BiasAdd + tf.Conv2D to tosa.CONV2D
legalize_tfl      | TensorFlow Lite | TOSA            | Legalize TensorFlow Lite dialect to TOSA dialect
convert_tfl_uint8 | TensorFlow Lite | TensorFlow Lite | Convert a quantized uint8 graph to an int8 graph

TF-to-TOSA legalization can be summarized by the following pseudocode:

```
void legalize_tf_to_tosa(mlir::Module module)
{
    mlir::PassManager pm

    // other MLIR passes to optimize TF

    pm.addPass(fuse_tf_bias)
    pm.addPass(legalize_tf)

    // other MLIR passes to optimize TOSA
}
```

TFLite-to-TOSA legalization can be summarized by the following pseudocode:

```
void legalize_tfl_to_tosa(mlir::Module module)
{
    mlir::PassManager pm

    // other MLIR passes to optimize TFLite

    pm.addPass(convert_tfl_uint8)
    pm.addPass(legalize_tfl)

    // other MLIR passes to optimize TOSA
}
```

Each of the passes is described in more detail in the subsequent chapters.

## TensorFlow MLIR Dialect Legalization (legalize_tf)

### tf.Abs

This operator is trivially lowered to tosa.ABS.

### tf.AddN

**TensorFlow Dialect**

```
%output = tf.AddN(%inputs)
```

**TOSA Lowering**

```
%output = tosa.ADD(%inputs:0, %inputs:1)
for (int32 i = 2; i < %inputs.size; i++) {
    %output = tosa.ADD(%inputs:i, %output)
}
```

### tf.Add

Element-wise addition.

**TensorFlow Dialect**

```
%output = tf.Add(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.ADD.

### tf.AddV2

Element-wise addition.

**TensorFlow Dialect**

```
%output = tf.AddV2(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.ADD.

### tf.All

Computes the "logical and" of elements across dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.All(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

```
%output = lower_reduce_op<tosa.REDUCE_ALL>(%input, %output.shape, %reduction_indices, keep_dims)
```

### tf.Any

Computes the "logical or" of elements across dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.Any(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

```
%output = lower_reduce_op<tosa.REDUCE_ANY>(%input, %output.shape, %reduction_indices, keep_dims)
```
### tf.ArgMax

Returns the index with the largest value across the given axis of the input
tensor.

**TensorFlow Dialect**

```
%output = tf.ArgMax(%input, %dimension)
```

**TOSA Lowering**

```
int64 axis = positive_axis(%dimension.as_constant(), %input.rank)
%output = tosa.ARGMAX(%input) {axis=axis}
```

### tf.ArgMin

Returns the index with the smallest value across the given axis of the input
tensor.

**TensorFlow Dialect**

```
%output = tf.ArgMin(%input, %dimension)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.Assert

Asserts that the given condition is true.

**TensorFlow Dialect**

```
%output = tf.Assert(%condition, %summarize)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.AssignAddVariableOp

Adds a value to the current value of a variable.

**TensorFlow Dialect**

```
%output = tf.AssignAddVariableOp(%resource, %value, %dtype)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.AssignSubVariableOp

Subtracts a value from the current value of a variable.

**TensorFlow Dialect**

```
%output = tf.AssignSubVariableOp(%resource, %value, %dtype)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.AssignVariableOp

Assigns a new value to a variable.

**TensorFlow Dialect**

```
%output = tf.AssignVariableOp(%resource, %value, %dtype)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.AvgPool

Performs average pooling on the input.

**TensorFlow Dialect**

```
%output = tf.AvgPool(%value) {ksize, strides, padding, data_format}
```

**TOSA Lowering**

```
assert(data_format == "NHWC")

tosa_padding =
    get_padding_values_from_pad_type(%value, ksize, padding, data_format,
                                     FORMAT_OHWI, strides, {1, 1, 1, 1})
%output = tosa.AVG_POOL2D(%value) {ksize=ksize, strides=strides, padding=tosa_padding}
```

### tf.BatchMatMul

Multiplies slices of two tensors in batches.

**TensorFlow Dialect**

```
%output = tf.BatchMatMul(%x, %y, %adj_x, %adj_y)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.BatchMatMulV2

Multiplies slices of two tensors in batches.

**TensorFlow Dialect**

```
%output = tf.BatchMatMulV2(%x, %y, %adj_x, %adj_y)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.BatchNormWithGlobalNormalization

✗ Deprecated operator.

### tf.BatchToSpaceND

BatchToSpaceND for N-D tensors of type T.

**TensorFlow Dialect**

```
%output = tf.BatchToSpaceND(%input, %block_shape, %crops)
```

**TOSA Lowering**

```
%output = lower_batch_to_space_nd_op(%input, %block_shape, %crops, %output.shape)
```

### tf.BiasAddGrad

Training profile: TOSA lowering not yet defined.

### tf.BiasAdd

Add bias to value.

**TensorFlow Dialect**

```
%output = tf.BiasAdd(%bias, %value) {data_format}
```

**TOSA Lowering**

```
assert(data_format == 'NHWC')
%output = tosa.ADD(%value, %bias)
```
### tf.BitCast

Bitcasts a tensor from one type to another without copying data.

**TensorFlow Dialect**

```
%output = tf.BitCast(%input, %dtype)
```

**TOSA Lowering**

No TOSA lowering defined.

### tf.BitwiseAnd

This operator is trivially lowered to tosa.BITWISE_AND.

### tf.BitwiseOr

This operator is trivially lowered to tosa.BITWISE_OR.

### tf.BroadcastGradientArgs

Training profile: TOSA lowering not yet defined.

### tf.BroadcastTo

No TOSA lowering defined.

### tf.Cast

This operator is trivially lowered to tosa.CAST.

### tf.Ceil

This operator is trivially lowered to tosa.CEIL.

### tf.CheckNumerics

No TOSA lowering defined.

### tf.ComplexAbs

No TOSA lowering defined.

### tf.Complex

No TOSA lowering defined.

### tf.ConcatOffset

No TOSA lowering defined. Training profile: TOSA lowering not yet defined.

### tf.Concat

No TOSA lowering defined.

### tf.ConcatV2

Concatenates tensors along one dimension.

**TensorFlow Dialect**

```
%output = tf.ConcatV2(%values, %axis)
```

**TOSA Lowering**

```
%output = lower_concatv2_op(%output.type, %values, %axis.as_constant())
```

### tf.Conj

No TOSA lowering defined.

### tf.Const

This operator is trivially lowered to tosa.CONST.

### tf.Conv2DBackpropFilter

No TOSA lowering defined.

### tf.Conv2DBackpropInput

Computes the gradients of convolution with respect to the input.

**TensorFlow Dialect**

```
%output = tf.Conv2DBackpropInput(%input_sizes, %filter, %out_backprop) {strides, use_cudnn_on_gpu, padding, explicit_paddings, data_format, dilations}
```

**TOSA Lowering**

```
// Transpose filter from HWIO to OHWI
%tosa_filter = tosa.TRANSPOSE(%filter) {perms={2, 0, 1, 3}}

vector output_shape

for (int32 i = 0; i < %input_sizes.size; i++) {
    output_shape.push_back(%input_sizes.as_constant()[i])
}

if (padding == "EXPLICIT") {
    tosa_padding =
        get_padding_values_from_explicit_pad_attr(explicit_paddings, data_format)
} else {
    tosa_padding =
        get_transpose_conv2d_padding_values_from_pad_type(%input_sizes, %filter, output_shape, padding, data_format, FORMAT_HWIO, strides, dilations)
}

// Create a zero bias tensor
%zero_bias = tosa.CONST() {value={0}}
%output = tosa.TRANSPOSE_CONV2D(%out_backprop) {weight=%tosa_filter, bias=%zero_bias, outpad=tosa_padding, stride=strides, dilation=dilations, out_shape=output_shape}
```
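The SAME/VALID outpad computation used above (see
get_transpose_conv2d_padding_values_from_pad_type()) inverts the forward
convolution's output-size formula. A small C++ sketch of that arithmetic for
one spatial dimension, with illustrative sizes:

```
#include <algorithm>
#include <cstdio>

int main() {
    // One spatial dimension of a transpose conv: the gradient input (ifm) is
    // the forward output, and ofm is the forward input size being rebuilt.
    const int ifm_size = 14, ofm_size = 28, filter_size = 3;
    const int stride = 2, dilation = 1;

    const int effective_filter = (filter_size - 1) * dilation + 1;
    const int total_padding =
        std::max(0, (ifm_size - 1) * stride + effective_filter - ofm_size);

    const int pad_before = total_padding / 2;
    const int pad_after = total_padding - pad_before;
    std::printf("outpad = {%d, %d}\n", pad_before, pad_after); // {0, 1}
}
```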
### tf.Conv2D

Computes a 2-D convolution given 4-D input and filter tensors.

**TensorFlow Dialect**

```
%output = tf.Conv2D(%input, %filter) {strides, padding, explicit_paddings, data_format, dilations}
```

**TOSA Lowering**

```
assert(data_format == "NHWC")

// Transpose filter from HWIO to OHWI
%filter_transpose = tosa.TRANSPOSE(%filter) {perms={3, 0, 1, 2}}

if (padding == "EXPLICIT") {
    tosa_padding =
        get_padding_values_from_explicit_pad_attr(explicit_paddings, data_format)
} else {
    tosa_padding =
        get_padding_values_from_pad_type(%input, %filter.shape, padding, data_format,
                                         FORMAT_HWIO, strides, dilations)
}

// Create a zero bias tensor
%zero_bias = tosa.CONST() {value={0}}

%output = tosa.CONV2D(%input, %filter_transpose, %zero_bias) {padding=tosa_padding, stride=strides, dilation=dilations}
```

### tf.Conv3D

TOSA lowering to tosa.CONV3D to be defined.

### tf.Cos

No TOSA lowering defined.

### tf.CrossReplicaSum

No TOSA lowering defined.

### tf.DepthToSpace

DepthToSpace for tensors of type T.

**TensorFlow Dialect**

```
%output = tf.DepthToSpace(%input) {block_size, data_format}
```

**TOSA Lowering**

```
%output = lower_depth_to_space_op(%input, block_size, data_format)
```

### tf.DepthwiseConv2dNative

Computes a 2-D depthwise convolution given 4-D input and filter tensors.

**TensorFlow Dialect**

```
%output = tf.DepthwiseConv2dNative(%input, %filter) {strides, padding, data_format, dilations}
```

**TOSA Lowering**

```
if (padding == "EXPLICIT") {
    tosa_padding =
        get_padding_values_from_explicit_pad_attr(explicit_paddings, data_format)
} else {
    tosa_padding =
        get_padding_values_from_pad_type(%input, %filter.shape, padding, data_format,
                                         FORMAT_HWIO, strides, dilations)
}

bias_dim = %filter.shape[2] * %filter.shape[3]

// Create a zero-bias tensor
%zero_bias = tosa.CONST() {value={0} * bias_dim}

%output = tosa.DEPTHWISE_CONV2D(%input, %filter, %zero_bias) {stride=strides, dilation=dilations, padding=tosa_padding}
```

### tf.DivNoNan

No TOSA lowering defined.

### tf.Div

No TOSA lowering defined.

### tf.DynamicStitch

No TOSA lowering defined.

### tf.Einsum

No TOSA lowering defined.

### tf.Elu

Computes exponential linear: exp(features) - 1 if < 0, features otherwise.

**TensorFlow Dialect**

```
%output = tf.Elu(%features)
```

**TOSA Lowering**

```
%output = lower_elu_op(%features)
```

### tf.EmptyTensorList

No TOSA lowering defined.

### tf.Equal

Returns the truth value of (x == y) element-wise with broadcasting.

**TensorFlow Dialect**

```
%output = tf.Equal(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.EQUAL.

### tf.Exp

This operator is trivially lowered to tosa.EXP.
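For the non-EXPLICIT branch of the convolution lowerings above,
get_padding_values_from_pad_type() defers to TensorFlow's windowed-output-size
computation. A small self-contained C++ sketch of the standard SAME-padding
arithmetic performed per spatial dimension (illustrative; it mirrors the usual
TF formula rather than calling the real helper):

```
#include <algorithm>
#include <cstdio>

int main() {
    // One spatial dimension with SAME padding.
    const int in_size = 112, filter_size = 3, stride = 2, dilation = 1;

    const int effective_filter = (filter_size - 1) * dilation + 1;
    const int out_size = (in_size + stride - 1) / stride; // ceil(in / stride)
    const int total_padding =
        std::max(0, (out_size - 1) * stride + effective_filter - in_size);

    const int pad_before = total_padding / 2;
    const int pad_after = total_padding - pad_before;
    std::printf("out=%d pad={%d, %d}\n", out_size, pad_before, pad_after);
}
```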
1826 1827### tf.ExpandDims 1828 1829Inserts a dimension of 1 into a tensor’s shape 1830 1831**TensorFlow Dialect** 1832 1833``` 1834%output = tf.ExpandDims(%input, %axis) 1835``` 1836 1837**TOSA Lowering** 1838 1839``` 1840%output = lower_expand_dims(%input, %axis.to_constant()) 1841``` 1842 1843### tf.FakeQuantWithMinMaxArgs 1844 1845Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. 1846 1847**TensorFlow Dialect** 1848 1849``` 1850%output = tf.FakeQuantWithMinMaxArgs(%inputs) {min, max, num_bits, narrow_range} 1851``` 1852 1853**TOSA Lowering** 1854 1855``` 1856%output = lower_fake_quant_op(%inputs, %min, %max, %num_bits, %narrow_range) 1857``` 1858 1859### tf.FakeQuantWithMinMaxVars 1860 1861Fake-quantize the 'inputs' tensor of type float via global flats scalars min. 1862 1863**TensorFlow Dialect** 1864 1865``` 1866%output = tf.FakeQuantWithMinMaxVars(%inputs, %min, %max) {num_bits, narrow_range} 1867``` 1868 1869**TOSA Lowering** 1870 1871``` 1872%output = lower_fake_quant_op(%inputs, %output.type, %min.to_constant(), %max.to_constant(), num_bits, narrow_range) 1873``` 1874 1875### tf.FakeQuantWithMinMaxVarsPerChannel 1876 1877Fake-quantize the 'inputs' tensor of type float and one of the shapes \[d\]. 1878 1879**TensorFlow Dialect** 1880 1881``` 1882%output = tf.FakeQuantWithMinMaxVarsPerChannel(%inputs, %min, %max) {num_bits, narrow_range} 1883``` 1884 1885No TOSA lowering defined. 1886 1887### tf.Fill 1888 1889Creates a tensor filled with a scalar value 1890 1891**TensorFlow Dialect** 1892 1893``` 1894%output = tf.Fill(%dims, %value) 1895``` 1896 1897**TOSA Lowering** 1898 1899``` 1900int64 total_size = 1 1901 1902for (int32 i = 0; i < %dims.shape[0]; i++) { 1903 total_size *= %dims[i] 1904} 1905 1906vector<%value.dtype> fill_arr(total_size, %value) 1907 1908%output = tosa.CONST() {value={fill_arr}} 1909``` 1910 1911### tf.FloorDiv 1912 1913Returns x // y element-wise. 1914 1915**TensorFlow Dialect** 1916 1917``` 1918%output = tf.FloorDiv(%x, %y) 1919``` 1920 1921**TOSA Lowering** 1922 1923``` 1924%output = lower_floor_div(%lhs, %rhs) 1925``` 1926 1927### tf.FloorMod 1928 1929Returns element-wise remainder of division when x < 0 xor x < y is true. 1930 1931**TensorFlow Dialect** 1932 1933``` 1934%output = tf.FloorMod(%x, %y) 1935``` 1936 1937**TOSA Lowering** 1938 1939``` 1940%output = lower_floor_mod(%lhs, %rhs) 1941``` 1942 1943### tf.Floor 1944 1945This operator is trivially lowered to tosa.FLOOR. 1946 1947### tf.FusedBatchNormGrad 1948 1949Training profile: TOSA lowering not yet defined. 1950 1951### tf.FusedBatchNormGradV2 1952 1953Training profile: TOSA lowering not yet defined. 1954 1955### tf.FusedBatchNormGradV3 1956 1957Training profile: TOSA lowering not yet defined. 1958 1959### tf.FusedBatchNorm 1960 1961Batch normalization. 1962 1963**TensorFlow Dialect** 1964 1965``` 1966%output = tf.FusedBatchNorm(%x, %scale, %offset, %mean, %variance) {epsilon, data_format, is_training} 1967 1968 1969assert(data_format == 'NHWC') 1970assert(is_training == false) 1971 1972%epsilon_const = tosa.CONST() {value={epsilon}} 1973 1974%op1 = tosa.SUB(%x, %bmean) 1975%op2 = tosa.ADD(%variance, %epsilon_const) 1976%op3 = tosa.RSQRT(%op2) 1977%op4 = tosa.MUL(%op1, %op3) 1978%op5 = tosa.MUL(%op4, %scale) 1979%output = tosa.ADD(%op5, %offset) 1980``` 1981 1982### tf.FusedBatchNormV3 1983 1984Batch normalization. 
### tf.GatherNd

Gather slices from params into a Tensor with shape specified by indices.

**TensorFlow Dialect**

```
%output = tf.GatherNd(%params, %indices)
```

**TOSA Lowering**

```
%output = lower_gather_nd_op(%params, %indices)
```

### tf.Gather

Gathers slices from params according to indices.

**TensorFlow Dialect**

```
%output = tf.Gather(%params, %indices)
```

**TOSA Lowering**

```
%output = lower_gather_op(%params, %indices, 0, 0)
```

### tf.GatherV2

Gathers slices from params axis according to indices.

**TensorFlow Dialect**

```
%output = tf.GatherV2(%params, %indices, %axis) {batch_dims}
```

**TOSA Lowering**

```
%output = lower_gather_op(%params, %indices, batch_dims, %axis.as_constant())
```

### tf.GreaterEqual

Returns the truth value of (x >= y) element-wise with broadcasting.

**TensorFlow Dialect**

```
%output = tf.GreaterEqual(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.GREATER_EQUAL.

### tf.Greater

Returns the truth value of (x > y) element-wise with broadcasting.

**TensorFlow Dialect**

```
%output = tf.Greater(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.GREATER.

### tf.HashTableV2

No TOSA lowering defined.

### tf.IdentityN

Returns a list of tensors with the same shapes and contents as the input.

**TensorFlow Dialect**

```
%output = tf.IdentityN(%input)
```

**TOSA Lowering**

```
%output = tosa.IDENTITYN(%input)
```

### tf.Identity

Returns a tensor with the same shape and contents as the input.

**TensorFlow Dialect**

```
%output = tf.Identity(%input)
```

**TOSA Lowering**

```
%output = tosa.IDENTITY(%input)
```

### tf.If

No TOSA lowering defined.

### tf.Imag

No TOSA lowering defined.

### tf.InfeedDequeueTuple

No TOSA lowering defined.

### tf.Invert

This operator is trivially lowered to tosa.BITWISE_NOT.

### tf.InvertPermutation

No TOSA lowering defined.

### tf.IsFinite

No TOSA lowering defined.

### tf.IteratorGetNext

No TOSA lowering defined.

### tf.L2Loss

Training profile: TOSA lowering not yet defined.

### tf.LRN

No TOSA lowering defined.

### tf.LeakyRelu

Computes rectified linear: max(features, features \* alpha).

**TensorFlow Dialect**

```
%output = tf.LeakyRelu(%features) {alpha}
```

**TOSA Lowering**

```
%alpha_tensor = tosa.CONST() {value={alpha}}
%features_alpha = tosa.MUL(%features, %alpha_tensor)
%greater = tosa.GREATER(%features, %features_alpha)
%output = tosa.SELECT(%greater, %features, %features_alpha)
```
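The tf.LeakyRelu lowering above selects between x and alpha*x using a GREATER
comparison (valid for 0 <= alpha < 1). A tiny C++ sketch of the same select
logic (illustrative only):

```
#include <cstdio>

int main() {
    const float alpha = 0.1f;
    const float inputs[3] = {-2.0f, 0.0f, 3.0f};

    for (float x : inputs) {
        const float x_alpha = x * alpha;        // tosa.MUL
        const bool greater = x > x_alpha;       // tosa.GREATER
        const float y = greater ? x : x_alpha;  // tosa.SELECT
        std::printf("leaky_relu(%5.2f) = %5.2f\n", x, y);
    }
}
```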
### tf.LeftShift

Computes the bitwise left-shift of x by y bits, element-wise.

**TensorFlow Dialect**

```
%output = tf.LeftShift(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_LEFT_SHIFT.

### tf.LegacyCall

No TOSA lowering defined.

### tf.LessEqual

Returns the truth value of (x <= y) element-wise with broadcasting.

**TensorFlow Dialect**

```
%output = tf.LessEqual(%x, %y)
```

**TOSA Lowering**

```
%output_greater = tosa.GREATER(%x, %y)
%output = tosa.LOGICAL_NOT(%output_greater)
```

### tf.Less

Returns the truth value of (x < y) element-wise with broadcasting.

**TensorFlow Dialect**

```
%output = tf.Less(%x, %y)
```

**TOSA Lowering**

```
%output_greater_equal = tosa.GREATER_EQUAL(%x, %y)
%output = tosa.LOGICAL_NOT(%output_greater_equal)
```

### tf.LinSpace

No TOSA lowering defined.

### tf.Log1p

No TOSA lowering defined.

### tf.Log

This operator is trivially lowered to tosa.LOG.

### tf.LogSoftmax

Computes log softmax activations.

**TensorFlow Dialect**

```
%output = tf.LogSoftmax(%logits)
```

**TOSA Lowering**

```
%output = lower_log_softmax_op(%logits)
```

### tf.LogicalAnd

Returns the truth value of x AND y, element-wise.

**TensorFlow Dialect**

```
%output = tf.LogicalAnd(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_AND.

### tf.LogicalNot

This operator is trivially lowered to tosa.LOGICAL_NOT.

### tf.LogicalOr

Returns the truth value of x OR y, element-wise.

**TensorFlow Dialect**

```
%output = tf.LogicalOr(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_OR.

### tf.LookupTableFindV2

No TOSA lowering defined.

### tf.LookupTableInputV2

No TOSA lowering defined.

### tf.LookupTableSizeV2

No TOSA lowering defined.

### tf.MatMul

Multiply the matrix a by the matrix b.

**TensorFlow Dialect**

```
%output = tf.MatMul(%a, %b)
```

**TOSA Lowering**

```
%output = tosa.MATMUL(%a, %b)
```

### tf.MatrixDiag

No TOSA lowering defined.

### tf.MatrixDiagV2

No TOSA lowering defined.

### tf.MatrixDiagV3

No TOSA lowering defined.

### tf.MatrixSetDiag

No TOSA lowering defined.

### tf.MatrixSetDiagV2

No TOSA lowering defined.

### tf.MatrixSetDiagV3

No TOSA lowering defined.

### tf.Max

Computes the maximum of elements across dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.Max(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

```
%output = lower_reduce_op<tosa.REDUCE_MAX>(%input, %output.shape, %reduction_indices, keep_dims)
```
### tf.LeftShift

Computes the bitwise left-shift of x by y bits, element-wise.

**TensorFlow Dialect**

```
%output = tf.LeftShift(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_LEFT_SHIFT.

### tf.LegacyCall

No TOSA lowering defined.

### tf.LessEqual

Returns the truth value of (x <= y) element-wise with broadcasting.

**TensorFlow Dialect**

```
%output = tf.LessEqual(%x, %y)
```

**TOSA Lowering**

```
%output_greater = tosa.GREATER(%x, %y)
%output = tosa.LOGICAL_NOT(%output_greater)
```

### tf.Less

Returns the truth value of (x < y) element-wise with broadcasting.

**TensorFlow Dialect**

```
%output = tf.Less(%x, %y)
```

**TOSA Lowering**

```
%output_greater_equal = tosa.GREATER_EQUAL(%x, %y)
%output = tosa.LOGICAL_NOT(%output_greater_equal)
```

### tf.LinSpace

No TOSA lowering defined.

### tf.Log1p

No TOSA lowering defined.

### tf.Log

This operator is trivially lowered to tosa.LOG.

### tf.LogSoftmax

Computes log softmax activations.

**TensorFlow Dialect**

```
%output = tf.LogSoftmax(%logits)
```

**TOSA Lowering**

```
%output = lower_log_softmax_op(%logits)
```

### tf.LogicalAnd

Returns the truth value of x AND y, element-wise.

**TensorFlow Dialect**

```
%output = tf.LogicalAnd(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_AND.

### tf.LogicalNot

This operator is trivially lowered to tosa.LOGICAL_NOT.

### tf.LogicalOr

Returns the truth value of x OR y, element-wise.

**TensorFlow Dialect**

```
%output = tf.LogicalOr(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.LOGICAL_OR.

### tf.LookupTableFindV2

No TOSA lowering defined.

### tf.LookupTableInputV2

No TOSA lowering defined.

### tf.LookupTableSizeV2

No TOSA lowering defined.

### tf.MatMul

Multiplies the matrix a by the matrix b.

**TensorFlow Dialect**

```
%output = tf.MatMul(%a, %b)
```

**TOSA Lowering**

```
%output = tosa.MATMUL(%a, %b)
```

### tf.MatrixDiag

No TOSA lowering defined.

### tf.MatrixDiagV2

No TOSA lowering defined.

### tf.MatrixDiagV3

No TOSA lowering defined.

### tf.MatrixSetDiag

No TOSA lowering defined.

### tf.MatrixSetDiagV2

No TOSA lowering defined.

### tf.MatrixSetDiagV3

No TOSA lowering defined.

### tf.Max

Computes the maximum of elements across dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.Max(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

```
%output = lower_reduce_op<tosa.REDUCE_MAX>(%input, %output.shape, %reduction_indices, keep_dims)
```

### tf.MaxPoolGrad

Training profile: TOSA lowering not yet defined.

### tf.MaxPool

Performs max pooling on the input.

**TensorFlow Dialect**

```
%output = tf.MaxPool(%input) {ksize, strides, padding, data_format}
```

**TOSA Lowering**

```
assert(data_format == "NHWC")

tosa_padding =
     get_padding_values_from_pad_type(padding, data_format, 1,
                                      %input.type, tensor<ksize, tosa.int32>,
                                      strides, {1, 1, 1, 1})
%output = tosa.MAX_POOL2D(%input) {kernel=ksize, stride=strides, padding=tosa_padding}
```

### tf.Maximum

This operator is trivially lowered to tosa.MAXIMUM.

### tf.Mean

Computes the mean of elements across dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.Mean(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

```
int32 num_elements_on_axis = 1
for (int32 axis : %reduction_indices) {
    num_elements_on_axis *= %input.shape[axis]
}
float32 div_scale = 1.0 / num_elements_on_axis

%cst_div_scale = tosa.CONST() {value={div_scale}}
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %reduction_indices, keep_dims)
%op2_mul_op1 = tosa.MUL(%op1_rsum_in, %cst_div_scale)
```
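TOSA provides no mean reduction, so the lowering folds 1/N into a constant and scales the REDUCE_SUM result, where N is the product of the reduced dimension sizes. A scalar sketch of the decomposition (illustrative only):

```
#include <vector>

// Scalar sketch of the decomposition above (illustrative only): a mean is
// a REDUCE_SUM followed by a MUL with the constant 1/N.
float mean_via_reduce_sum(const std::vector<float>& values) {
  float sum = 0.0f;                        // lower_reduce_op<tosa.REDUCE_SUM>
  for (float v : values) sum += v;
  float div_scale = 1.0f / values.size();  // folded into a tosa.CONST
  return sum * div_scale;                  // tosa.MUL
}
```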
### tf.Min

Computes the minimum of elements across dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.Min(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

```
%output = lower_reduce_op<tosa.REDUCE_MIN>(%input, %output.shape, %reduction_indices, keep_dims)
```

### tf.Minimum

This operator is trivially lowered to tosa.MINIMUM.

### tf.MirrorPad

No TOSA lowering defined.

### tf.MlirPassthroughOp

No TOSA lowering defined.

### tf.MulNoNan

No TOSA lowering defined.

### tf.Mul

Returns the product of x and y, element-wise.

**TensorFlow Dialect**

```
%output = tf.Mul(%x, %y)
```

**TOSA Lowering** This operator is trivially lowered to tosa.MUL.

### tf.Neg

This operator is trivially lowered to tosa.NEGATE.

### tf.NoOp

No TOSA lowering defined.

### tf.NonMaxSuppressionV4

No TOSA lowering defined.

### tf.NonMaxSuppressionV5

No TOSA lowering defined.

### tf.NotEqual

Returns the truth value of (x != y) element-wise with broadcasting.

**TensorFlow Dialect**

```
%output = tf.NotEqual(%x, %y)
```

**TOSA Lowering**

```
%equal = tosa.EQUAL(%x, %y)
%output = tosa.LOGICAL_NOT(%equal)
```

### tf.OneHot

OneHot operator.

**TensorFlow Dialect**

```
%output = tf.OneHot(%indices, %depth, %on_value, %off_value) {axis}
```

**TOSA Lowering**

```
%output = lower_one_hot_op(%indices, %depth, %on_value, %off_value, axis)
```

### tf.OutputEnqueueTuple

No TOSA lowering defined.

### tf.Pack

Packs a list of N rank-R tensors into one rank-(R+1) tensor.

**TensorFlow Dialect**

```
%output = tf.Pack(%values) {axis}
```

**TOSA Lowering**

```
%output = lower_pack_op(%values, axis)
```

### tf.Pad

This operator is trivially lowered to tosa.PAD.

### tf.PadV2

No TOSA lowering defined.

### tf.ParseExampleV2

No TOSA lowering defined.

### tf.PartitionedCall

No TOSA lowering defined.

### tf.Placeholder

Not seen in practice. No lowering needed.

### tf.PlaceholderWithDefault

Not seen in practice. No lowering needed.

### tf.Pow

This operator is trivially lowered to tosa.POW.

### tf.PreventGradient

Training profile: TOSA lowering not yet defined.

### tf.Prod

Computes the product of elements across dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.Prod(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

```
%output = lower_reduce_op<tosa.REDUCE_PRODUCT>(%input, %output.shape, %reduction_indices, keep_dims)
```

### tf.QuantizeAndDequantize

No TOSA lowering defined.

### tf.QuantizeAndDequantizeV2

No TOSA lowering defined.

### tf.QuantizeAndDequantizeV3

No TOSA lowering defined.

### tf.RFFT

No TOSA lowering defined.

### tf.RandomShuffle

No TOSA lowering defined.

### tf.RandomStandardNormal

No TOSA lowering defined.

### tf.RandomUniform

No TOSA lowering defined.

### tf.Range

No TOSA lowering defined.

### tf.Rank

Returns the rank of the tensor.

**TensorFlow Dialect**

```
%output = tf.Rank(%input)
```

**TOSA Lowering**

```
%output = tosa.CONST() {value={%input.rank}}
```

### tf.ReadVariableOp

No TOSA lowering defined.

### tf.RealDiv

Returns x / y element-wise for real types.

**TensorFlow Dialect**

```
%output = tf.RealDiv(%x, %y)
```

**TOSA Lowering**

```
%recip = tosa.RECIPROCAL(%y)
%output = tosa.MUL(%x, %recip)
```
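Division is consistently expressed as multiplication by a RECIPROCAL in these lowerings; tfl.div, tfl.floor_div and tfl.floor_mod later in this document follow the same pattern. A scalar equivalent (illustrative only):

```
// Scalar equivalent of the RECIPROCAL + MUL pattern above (illustrative
// only); reused by tfl.div, tfl.floor_div and tfl.floor_mod below.
float real_div(float x, float y) {
  float recip = 1.0f / y;  // tosa.RECIPROCAL
  return x * recip;        // tosa.MUL
}
```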
### tf.Real

No TOSA lowering defined.

### tf.Reciprocal

This operator is trivially lowered to tosa.RECIPROCAL.

### tf.Relu6

Computes rectified linear 6: min(max(features, 0), 6).

**TensorFlow Dialect**

```
%output = tf.Relu6(%features)
```

**TOSA Lowering**

```
%output = tosa.RELUN(%features) {max_val=6}
```

### tf.ReluGrad

Training profile: TOSA lowering not yet defined.

### tf.Relu

Computes rectified linear: max(features, 0).

**TensorFlow Dialect**

```
%output = tf.Relu(%features)
```

**TOSA Lowering**

```
%output = tosa.RELUN(%features) {max_val=std::numeric_limits<float>::max()}
```

### tf.Reshape

Reshapes a tensor.

**TensorFlow Dialect**

```
%output = tf.Reshape(%tensor, %shape)
```

**TOSA Lowering**

```
%output = tosa.RESHAPE(%tensor) {new_shape=%shape.as_constant()}
```

### tf.ResizeBilinear

Resizes images to size using bilinear interpolation. The resize parameters are inferred from the output shape.

**TensorFlow Dialect**

```
%output = tf.ResizeBilinear(%images, %size) {align_corners, half_pixel_centers}
```

**TOSA Lowering**

```
%output = lower_resize_op(%images, %size, float, "BILINEAR")
```

### tf.ResizeNearestNeighbor

Resizes images to size using nearest neighbor interpolation. The resize parameters are inferred from the output shape.

**TensorFlow Dialect**

```
%output = tf.ResizeNearestNeighbor(%images, %size) {align_corners, half_pixel_centers}
```

**TOSA Lowering**

```
%output = lower_resize_op(%images, %size, float, "NEAREST_NEIGHBOR")
```

### tf.ResourceApplyAdam

Training profile: TOSA lowering not yet defined.

### tf.ResourceApplyGradientDescent

Training profile: TOSA lowering not yet defined.

### tf.ResourceApplyKerasMomentum

Training profile: TOSA lowering not yet defined.

### tf.ResourceGather

Training profile: TOSA lowering not yet defined.

### tf.ResourceScatterUpdate

Training profile: TOSA lowering not yet defined.

### tf.ReverseSequence

No TOSA lowering defined.

### tf.ReverseV2

Reverses specific dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.ReverseV2(%tensor, %axis)
```

**TOSA Lowering**

```
%output = lower_reversev2_op(%tensor, %axis)
```

### tf.RightShift

Computes the bitwise right-shift of x by y bits, element-wise.

**TensorFlow Dialect**

```
%output = tf.RightShift(%x, %y)
```

**TOSA Lowering**

```
if (is_unsigned(%x.dtype)) {
  %output = tosa.LOGICAL_RIGHT_SHIFT(%x, %y)
} else {
  %output = tosa.ARITHMETIC_RIGHT_SHIFT(%x, %y)
}
```

### tf.Round

Rounds the values of a tensor to the nearest integer, element-wise.

**TensorFlow Dialect**

```
%output = tf.Round(%x)
```

**TOSA Lowering**

```
%output = lower_round_op(%x)
```

### tf.RsqrtGrad

Training profile: TOSA lowering not yet defined.

### tf.Rsqrt

This operator is trivially lowered to tosa.RSQRT.

### tf.SegmentMax

No TOSA lowering defined.

### tf.SegmentMean

No TOSA lowering defined.

### tf.SegmentMin

No TOSA lowering defined.

### tf.SegmentProd

No TOSA lowering defined.

### tf.SegmentSum

No TOSA lowering defined.

### tf.Select

No TOSA lowering defined.

### tf.SelectV2

Selects elements from t or e depending on condition.

**TensorFlow Dialect**

```
%output = tf.SelectV2(%condition, %t, %e)
```

**TOSA Lowering**

```
%output = lower_selectv2_op(%condition, %t, %e, %output.shape)
```

### tf.ShapeN

No TOSA lowering defined.

### tf.Shape

Returns the shape of a tensor.

**TensorFlow Dialect**

```
%output = tf.Shape(%input)
```

**TOSA Lowering**

```
%output = lower_shape_op(%input)
```

### tf.Sigmoid

This operator is trivially lowered to tosa.SIGMOID.

### tf.Sign

No TOSA lowering defined.

### tf.Sin

No TOSA lowering defined.

### tf.Size

No TOSA lowering defined.

### tf.Slice

Returns a slice from input.

**TensorFlow Dialect**

```
%output = tf.Slice(%input, %begin, %size)
```

**TOSA Lowering**

```
vector<size_t> output_size
try {
  output_size = %size.as_constant()
} except(ConversionFailed) {
  output_size = %output.shape
}

%output = tosa.SLICE(%input) {start=%begin.as_constant(), size=output_size}
```

### tf.Snapshot

No TOSA lowering defined.

### tf.SoftmaxCrossEntropyWithLogits

Training profile: TOSA lowering not yet defined.

### tf.Softmax

Computes softmax activations.

**TensorFlow Dialect**

```
%output = tf.Softmax(%logits)
```

**TOSA Lowering**

```
%op1 = tosa.EXP(%logits)
%op2 = tosa.REDUCE_SUM(%op1) {axis=(%logits.rank - 1)}
%op3 = tosa.RECIPROCAL(%op2)
%output = tosa.MUL(%op1, %op3)
```
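A scalar sketch of the four-operator decomposition above (illustrative only). Note that the quantized tfl.softmax lowering later in this document additionally subtracts the per-axis REDUCE_MAX before exponentiation, which keeps the exponent arguments non-positive:

```
#include <cmath>
#include <vector>

// Scalar sketch of the floating-point softmax decomposition (illustrative
// only): exponentiate, sum along the last axis, multiply by the reciprocal.
std::vector<float> softmax(const std::vector<float>& logits) {
  std::vector<float> out(logits.size());
  float sum = 0.0f;
  for (size_t i = 0; i < logits.size(); i++) {
    out[i] = std::exp(logits[i]);   // tosa.EXP
    sum += out[i];                  // tosa.REDUCE_SUM
  }
  float recip = 1.0f / sum;         // tosa.RECIPROCAL
  for (float& v : out) v *= recip;  // tosa.MUL
  return out;
}
```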
### tf.Softplus

No TOSA lowering defined.

### tf.SpaceToBatchND

SpaceToBatch for N-D tensors of type T.

**TensorFlow Dialect**

```
%output = tf.SpaceToBatchND(%input, %block_shape, %paddings)
```

**TOSA Lowering**

```
%output = lower_space_to_batch_nd_op(%input, %block_shape, %paddings)
```

### tf.SpaceToDepth

SpaceToDepth for tensors of type T.

**TensorFlow Dialect**

```
%output = tf.SpaceToDepth(%input) {block_size, data_format}
```

**TOSA Lowering**

```
%output = lower_space_to_depth_op(%input, block_size, data_format)
```

### tf.SparseMatMul

No TOSA lowering defined.

### tf.SparseSoftmaxCrossEntropyWithLogits

No TOSA lowering defined.

### tf.SparseToDense

No TOSA lowering defined.

### tf.Split

Splits a tensor into num_split tensors along one dimension.

**TensorFlow Dialect**

```
%output = tf.Split(%split_dim, %value) {num_split}
```

**TOSA Lowering**

```
%output = lower_split_op(%value, %split_dim.as_constant(), num_split)
```

### tf.SplitV

Splits a tensor into num_split tensors along one dimension.

**TensorFlow Dialect**

```
%output = tf.SplitV(%value, %size_splits, %split_dim) {num_split}
```

**TOSA Lowering**

```
%output = lower_splitv_op(%value, %size_splits.as_constant(), %split_dim.as_constant())
```

### tf.Sqrt

No TOSA lowering defined.

### tf.Square

Computes the square of x, element-wise.

**TensorFlow Dialect**

```
%output = tf.Square(%x)
```

**TOSA Lowering**

```
%output = tosa.MUL(%x, %x)
```

### tf.SquaredDifference

Computes (x-y)\*(x-y) element-wise.

**TensorFlow Dialect**

```
%output = tf.SquaredDifference(%x, %y)
```

**TOSA Lowering**

```
%diff = tosa.SUB(%x, %y)
%output = tosa.MUL(%diff, %diff)
```

### tf.Squeeze

Removes dimensions of size 1 from the shape of a tensor.

**TensorFlow Dialect**

```
%output = tf.Squeeze(%input) {squeeze_dims}
```

**TOSA Lowering**

```
%output = lower_squeeze_op(%input, squeeze_dims)
```

### tf.StatefulPartitionedCall

No TOSA lowering defined.
### tf.StopGradient

Training profile: TOSA lowering not yet defined.

### tf.StridedSliceGrad

Training profile: TOSA lowering not yet defined.

### tf.StridedSlice

Returns a strided slice from input.

**TensorFlow Dialect**

```
%output = tf.StridedSlice(%input, %begin, %end, %strides) {begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask}
```

**TOSA Lowering**

```
%output = lower_strided_slice_op(%input, %begin, %end, %strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
```

### tf.Sub

This operator is trivially lowered to tosa.SUB.

### tf.Sum

Computes the sum of elements across dimensions of a tensor.

**TensorFlow Dialect**

```
%output = tf.Sum(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

```
%output = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %reduction_indices, keep_dims)
```

### tf.TPUCompilationResult

No TOSA lowering defined.

### tf.TPUCopyWithLayout

No TOSA lowering defined.

### tf.TPUExecuteAndUpdateVariables

No TOSA lowering defined.

### tf.TPUExecute

No TOSA lowering defined.

### tf.TPUGetLayout

No TOSA lowering defined.

### tf.TPUReplicateMetadata

No TOSA lowering defined.

### tf.TPUReplicatedInput

No TOSA lowering defined.

### tf.TPUReplicatedOutput

No TOSA lowering defined.

### tf.TPUReshardVariables

No TOSA lowering defined.

### tf.TanhGrad

Training profile: TOSA lowering not yet defined.

### tf.Tanh

This operator is trivially lowered to tosa.TANH.

### tf.TensorListFromTensor

No TOSA lowering defined.

### tf.TensorListGetItem

No TOSA lowering defined.

### tf.TensorListLength

No TOSA lowering defined.

### tf.TensorListPushBack

No TOSA lowering defined.

### tf.TensorListReserve

No TOSA lowering defined.

### tf.TensorListResize

No TOSA lowering defined.

### tf.TensorListSetItem

No TOSA lowering defined.

### tf.TensorListStack

No TOSA lowering defined.

### tf.TensorScatterUpdate

No TOSA lowering defined.

### tf.Tile

Constructs a tensor by tiling a given tensor.

**TensorFlow Dialect**

```
%output = tf.Tile(%input, %multiples)
```

**TOSA Lowering**

```
%output = tosa.TILE(%input) {multiples=%multiples.as_constant()}
```

### tf.ToBool

No TOSA lowering defined.

### tf.TopKV2

No TOSA lowering defined.

### tf.Transpose

Shuffles dimensions of x according to a permutation.

**TensorFlow Dialect**

```
%output = tf.Transpose(%x, %perm)
```

**TOSA Lowering**

```
%output = tosa.TRANSPOSE(%x) {perm=%perm.as_constant()}
```

### tf.TruncateDiv

No TOSA lowering defined.

### tf.Unique

No TOSA lowering defined.

### tf.Unpack

Unpacks a given dimension of a rank-R tensor into num rank-(R-1) tensors.
**TensorFlow Dialect**

```
%output = tf.Unpack(%value) {axis, num}
```

**TOSA Lowering**

```
%output = lower_unpack_op(%value, axis, num)
```

### tf.UnsortedSegmentMax

No TOSA lowering defined.

### tf.UnsortedSegmentMin

No TOSA lowering defined.

### tf.UnsortedSegmentProd

No TOSA lowering defined.

### tf.UnsortedSegmentSum

No TOSA lowering defined.

### tf.VarHandle

No TOSA lowering defined.

### tf.VariableShape

No TOSA lowering defined.

### tf.Where

No TOSA lowering defined.

### tf.While

No TOSA lowering defined.

### tf.Xdivy

No TOSA lowering defined.

### tf.XlaDynamicUpdateSlice

No TOSA lowering defined.

### tf.XlaSharding

No TOSA lowering defined.

### tf.ZerosLike

Returns a tensor of zeros with the same shape and type as x.

**TensorFlow Dialect**

```
%output = tf.ZerosLike(%x)
```

**TOSA Lowering**

```
%output = tosa.CONST() {value={0} * %x.num_elements}
```

## TensorFlow Lite MLIR Dialect Legalization (legalize_tfl)

### tfl.abs

This operator is trivially lowered to tosa.ABS.

### tfl.add_n

add_n operator.

**TensorFlow Lite Dialect**

```
%sum = tfl.add_n(%inputs)
```

**TOSA Lowering**

```
%output = tosa.ADD(%inputs:0, %inputs:1)
for (int32 i = 2; i < %inputs.size; i++) {
    %output = tosa.ADD(%inputs:i, %output)
}
```

### tfl.add

Element-wise addition operation.

**TensorFlow Lite Dialect**

```
%output = tfl.add(%lhs, %rhs)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%result = tosa.ADD(%lhs, %rhs)
```

If input/output tensors are all quantized typed,

Prepare:

```
float64 max_scale_2x = 2.0 * max(%lhs.scale, %rhs.scale)
float64 lhs_scale = float64(1 << input_shift) * %lhs.scale / max_scale_2x
float64 rhs_scale = float64(1 << input_shift) * %rhs.scale / max_scale_2x
float64 output_scale = max_scale_2x / (%output.scale * float64(1 << input_shift))
```

Legalization:

```
%op1_rescale_lhs = tosa.RESCALE(%lhs) {scale=lhs_scale, input_zp=%lhs.zp, output_zp=0} // %lhs.dtype->i32
%op2_rescale_rhs = tosa.RESCALE(%rhs) {scale=rhs_scale, input_zp=%rhs.zp, output_zp=0} // %rhs.dtype->i32
%op3_add_op1_op2 = tosa.ADD(%op1_rescale_lhs, %op2_rescale_rhs)
%op4_rescale_op3 = tosa.RESCALE(%op3_add_op1_op2) {scale=output_scale} // i32->%output.dtype
```
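`input_shift` is not defined in the Prepare block above; the TFLite reference kernels use a pre-left-shift of 20 bits for 8-bit addition, and that value is assumed in the sketch below (illustrative only). The idea is that both operands are rescaled onto a common scale of max_scale_2x / 2^input_shift, added exactly in int32, and the sum is then rescaled to the output domain:

```
#include <algorithm>
#include <cstdint>

// Sketch of the "Prepare" arithmetic above (illustrative only).
// Assumption: input_shift = 20, matching the TFLite reference kernels.
struct AddRescales {
  double lhs_scale, rhs_scale, output_scale;
};

AddRescales compute_add_rescales(double lhs, double rhs, double out) {
  const int32_t input_shift = 20;  // assumption, not stated in the doc
  double max_scale_2x = 2.0 * std::max(lhs, rhs);
  AddRescales r;
  r.lhs_scale = double(int64_t(1) << input_shift) * lhs / max_scale_2x;
  r.rhs_scale = double(int64_t(1) << input_shift) * rhs / max_scale_2x;
  r.output_scale = max_scale_2x / (out * double(int64_t(1) << input_shift));
  return r;
}
```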
### tfl.arg_max

ArgMax operator.

**TensorFlow Lite Dialect**

```
%output = tfl.arg_max(%input, %dim)
```

**TOSA Lowering**

```
%result = tosa.ARGMAX(%input) {axis=positive_axis(%dim.as_constant(), %input.rank)}
```

### tfl.arg_min

No TOSA lowering defined.

### tfl.average_pool_2d

Average_pool_2d operator.

**TensorFlow Lite Dialect**

```
%output = tfl.average_pool_2d(%input) {filter_height, filter_width, padding, stride_h, stride_w, fused_activation_function}
```

**TOSA Lowering**

Prepare:

```
tosa_padding =
     get_padding_values_from_pad_type(padding, NHWC, 1,
                                      %input.type, tensor<{filter_height, filter_width}, tosa.int32>,
                                      {1, stride_h, stride_w, 1}, {1, 1, 1, 1})
```

If input/output tensors are all non-quantized typed,

Legalization:

```
%avgpool2d = tosa.AVG_POOL2D(%input) {kernel={filter_height, filter_width}, stride={stride_h, stride_w}, padding=tosa_padding}
if(fused_activation_function != NONE) {
    %result = convert_fused_activation(%avgpool2d, fused_activation_function)
}
else {
    %result = %avgpool2d
}
```

If input/output tensors are all quantized typed,

Legalization:

```
%avgpool2d = tosa.AVG_POOL2D(%input) {kernel={filter_height, filter_width}, stride={stride_h, stride_w}, padding=tosa_padding, quantization_info={input_zp=%input.zp, output_zp=%output.zp}}
if(fused_activation_function != NONE) {
    %result = convert_fused_activation(%avgpool2d, fused_activation_function)
}
else {
    %result = %avgpool2d
}
```

### tfl.basic_lstm

No TOSA lowering defined.

### tfl.batch_to_space_nd

BatchToSpaceNd operator.

**TensorFlow Lite Dialect**

```
%output = tfl.batch_to_space_nd(%input, %block_shape, %crops)
```

**TOSA Lowering**

```
%result = lower_batch_to_space_nd_op(%input, %block_shape, %crops)
```

### tfl.cast

This operator is trivially lowered to tosa.CAST.

### tfl.ceil

Ceil operator.

**TensorFlow Lite Dialect**

```
%y = tfl.ceil(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

```
%result = tosa.CEIL(%x)
```

### tfl.concatenation

Concatenation operator.

**TensorFlow Lite Dialect**

```
%output = tfl.concatenation(%values) {axis}
```

**TOSA Lowering**

```
%result = lower_concatv2_op(%values, axis)
```

### tfl.pseudo_const

This operator is trivially lowered to tosa.CONST.

### tfl.conv_2d

Convolution operator.

**TensorFlow Lite Dialect**

```
%output = tfl.conv_2d(%input, %filter, %bias) {dilation_h_factor, dilation_w_factor, fused_activation_function, padding, stride_h, stride_w}
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Prepare:

```
tosa_padding =
     get_padding_values_from_pad_type(padding, NHWC, 1,
                                      %input.type, %filter.type,
                                      {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
```

Legalization:

```
%conv2d = tosa.CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}}
if(fused_activation_function != NONE) {
    %result = convert_fused_activation(%conv2d, fused_activation_function)
}
else {
    %result = %conv2d
}
```

If input/output tensors are all quantized typed,

Prepare:

```
float64 output_rescale_scale = (%input.scale * %filter.scale) / %output.scale

tosa_padding =
     get_padding_values_from_pad_type(padding, NHWC, 1,
                                      %input.type, %filter.type,
                                      {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
```

Legalization:

```
%conv2d = tosa.CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}, quantization_info={input_zp=%input.zp, weight_zp=%filter.zp}}
%rescale = tosa.RESCALE(%conv2d) {scale=output_rescale_scale, input_zp=0, output_zp=%output.zp} // %conv2d.dtype->%output.dtype
if(fused_activation_function != NONE) {
    %result = convert_fused_activation(%rescale, fused_activation_function)
}
else {
    %result = %rescale
}
```
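The int32 accumulator produced by tosa.CONV2D carries an effective scale of %input.scale \* %filter.scale, so a single RESCALE by output_rescale_scale maps it onto the output's quantized domain; compute_scale_32() from the Common Legalization Functions section converts that float64 ratio into the fixed-point multiplier/shift pair a TOSA implementation consumes. A floating-point model of the rescale (illustrative only, ignoring the exact fixed-point rounding rules):

```
#include <algorithm>
#include <cmath>
#include <cstdint>

// Float model of the accumulator rescale above (illustrative only).
int8_t rescale_accumulator(int32_t acc, double input_scale,
                           double filter_scale, double output_scale,
                           int32_t output_zp) {
  double scale = (input_scale * filter_scale) / output_scale;
  double scaled = std::round(double(acc) * scale) + output_zp;
  scaled = std::min(127.0, std::max(-128.0, scaled));  // clamp to int8
  return static_cast<int8_t>(scaled);
}
```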
### tfl.convolution_2d_transpose_bias

No TOSA lowering defined.

### tfl.cos

No TOSA lowering defined.

### tfl.densify

No TOSA lowering defined.

### tfl.depth_to_space

DepthToSpace operator.

**TensorFlow Lite Dialect**

```
%output = tfl.depth_to_space(%input) {block_size}
```

**TOSA Lowering**

```
%output = lower_depth_to_space_op(%input, block_size, "NHWC")
```

### tfl.depthwise_conv_2d

Depthwise-separable convolution operator.

**TensorFlow Lite Dialect**

```
%output = tfl.depthwise_conv_2d(%input, %filter, %bias) {dilation_h_factor, dilation_w_factor, fused_activation_function, padding, stride_h, stride_w, depth_multiplier}
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Prepare:

```
tosa_padding =
     get_padding_values_from_pad_type(padding, NHWC, 1,
                                      %input.type, %filter.type,
                                      {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
```

Legalization:

```
%depthwise_conv2d = tosa.DEPTHWISE_CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}}
if(fused_activation_function != NONE) {
    %result = convert_fused_activation(%depthwise_conv2d, fused_activation_function)
}
else {
    %result = %depthwise_conv2d
}
```

If input/output tensors are all quantized typed,

Prepare:

```
float64 output_rescale_scale = (%input.scale * %filter.scale) / %output.scale

tosa_padding =
     get_padding_values_from_pad_type(padding, NHWC, 1,
                                      %input.type, %filter.type,
                                      {1, stride_h, stride_w, 1}, {1, dilation_h_factor, dilation_w_factor, 1})
```

Legalization:

```
%depthwise_conv2d = tosa.DEPTHWISE_CONV2D(%input, %filter, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={dilation_h_factor, dilation_w_factor}, quantization_info={input_zp=%input.zp, weight_zp=%filter.zp}}
%rescale = tosa.RESCALE(%depthwise_conv2d) {scale=output_rescale_scale, input_zp=0, output_zp=%output.zp} // %depthwise_conv2d.dtype->%output.dtype
if(fused_activation_function != NONE) {
    %result = convert_fused_activation(%rescale, fused_activation_function)
}
else {
    %result = %rescale
}
```

### tfl.dequantize

Dequantize operator.

**TensorFlow Lite Dialect**

```
%output = tfl.dequantize(%input)
```

**TOSA Lowering**

```
%result = lower_dequantize_op(%input, %input.scale, %input.zp)
```
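A scalar sketch of what lower_dequantize_op computes (illustrative only), using the scale and zero-point attributes described in the Value Attributes section:

```
#include <cstdint>

// Scalar sketch of dequantization (illustrative only):
// real_value = (quantized_value - zero_point) * scale
float dequantize(int8_t q, double scale, int64_t zero_point) {
  return static_cast<float>((static_cast<int64_t>(q) - zero_point) * scale);
}
```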
### tfl.div

Division operator.

**TensorFlow Lite Dialect**

```
%output = tfl.div(%lhs, %rhs)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

```
%rcp = tosa.RECIPROCAL(%rhs)
%mul = tosa.MUL(%lhs, %rcp)
```

### tfl.elu

Exponential Linear Unit operator.

**TensorFlow Lite Dialect**

```
%y = tfl.elu(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

```
%result = lower_elu_op(%x)
```

### tfl.embedding_lookup

Embedding lookup operator.

**TensorFlow Lite Dialect**

```
%output = tfl.embedding_lookup(%lookup, %value)
```

### tfl.equal

This operator is trivially lowered to tosa.EQUAL.

### tfl.exp

Natural exponentiation operator.

**TensorFlow Lite Dialect**

```
%y = tfl.exp(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

```
%result = tosa.EXP(%x)
```

### tfl.expand_dims

Inserts a dimension of 1 into a tensor's shape.

**TensorFlow Lite Dialect**

```
%output = tfl.expand_dims(%input, %dim)
```

**TOSA Lowering**

```
%result = lower_expand_dims(%input, %dim.as_constant())
```

### tfl.external_const

No TOSA lowering defined.

### tfl.fake_quant

FakeQuant operator.

**TensorFlow Lite Dialect**

```
%output = tfl.fake_quant(%input) {min, max, num_bits, narrow_range}
```

**TOSA Lowering**

```
%result = convert_fake_quant_op(%input, min, max, num_bits, narrow_range)
```

### tfl.fill

Fills the tensor with the given value.

**TensorFlow Lite Dialect**

```
%res = tfl.fill(%dims, %value)
```

**TOSA Lowering**

Prepare:

```
total_size = 1
dim_vec = %dims.as_constant()
for(int32 i = 0; i < dim_vec.size(); i++) {
    total_size *= dim_vec[i]
}
filled_val = %value.as_constant()[0]
output_type = tensor<dim_vec, filled_val.dtype>
```

Legalization:

```
%result = tosa.CONST() {value={filled_val} * total_size}
```

### tfl.floor_div

Floor div operator.

**TensorFlow Lite Dialect**

```
%output = tfl.floor_div(%lhs, %rhs)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

```
%recip = tosa.RECIPROCAL(%rhs)
%mul = tosa.MUL(%lhs, %recip)
%result = tosa.FLOOR(%mul)
```
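A scalar equivalent of the floor_div lowering (illustrative only); it is the same reciprocal pattern used by tf.RealDiv and tfl.div, followed by a FLOOR:

```
#include <cmath>

// Scalar equivalent of the floating-point floor_div lowering (illustrative
// only): divide via RECIPROCAL + MUL, then FLOOR the quotient.
float floor_div(float lhs, float rhs) {
  float recip = 1.0f / rhs;      // tosa.RECIPROCAL
  float quotient = lhs * recip;  // tosa.MUL
  return std::floor(quotient);   // tosa.FLOOR
}
```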
### tfl.floor_mod

Division remainder.

**TensorFlow Lite Dialect**

```
%output = tfl.floor_mod(%lhs, %rhs)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

```
%recip = tosa.RECIPROCAL(%rhs)
%mul = tosa.MUL(%lhs, %recip)
%floor = tosa.FLOOR(%mul)
%result = tosa.SUB(%mul, %floor)
```

### tfl.floor

This operator is trivially lowered to tosa.FLOOR.

### tfl.fully_connected

Fully connected op.

**TensorFlow Lite Dialect**

```
%output = tfl.fully_connected(%input, %filter, %bias) {fused_activation_function}
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Prepare:

```
// input[N, IC] x filter[OC, IC] + bias[OC] -> output[N, OC]
auto input_reshape_shape = {%input.num_elements / %filter.shape[1], %filter.shape[1]}
```

Legalization:

```
if(!(%bias)) {
    %bias_val = tosa.CONST() {value={0} * %filter.shape[0]}
}
else {
    %bias_val = %bias
}
if(%input.rank != 2) {
    %input_val = tosa.RESHAPE(%input) {shape=input_reshape_shape}
}
else {
    %input_val = %input
}
%fc = tosa.FULLY_CONNECTED(%input_val, %filter, %bias_val)
if(fused_activation_function != NONE) {
    %result = convert_fused_activation(%fc, fused_activation_function)
}
else {
    %result = %fc
}
```

If input/output tensors are all quantized typed,

Prepare:

```
auto input_reshape_shape = {%input.num_elements / %filter.shape[1], %filter.shape[1]}
float64 output_rescale_scale = (%input.scale * %filter.scale) / %output.scale
```

Legalization:

```
if(!(%bias)) {
    %bias_val = tosa.CONST() {value={0} * %filter.shape[0]}
}
else {
    %bias_val = %bias
}
if(%input.rank != 2) {
    %input_val = tosa.RESHAPE(%input) {shape=input_reshape_shape}
}
else {
    %input_val = %input
}
%fc = tosa.FULLY_CONNECTED(%input_val, %filter, %bias_val)
%rescale = tosa.RESCALE(%fc) {scale=output_rescale_scale, input_zp=0, output_zp=%output.zp} // %fc.dtype->%output.dtype
if(fused_activation_function != NONE) {
    %result = convert_fused_activation(%rescale, fused_activation_function)
}
else {
    %result = %rescale
}
```
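tosa.FULLY_CONNECTED operates on a rank-2 input[N, IC], so any higher-rank input is first flattened: IC is fixed by the filter's second dimension, and N is whatever remains of the input element count. A sketch of the input_reshape_shape computation above (illustrative only):

```
#include <cstdint>
#include <vector>

// Sketch of input_reshape_shape (illustrative only).
// The filter is laid out [OC, IC], per the comment in the Prepare block.
std::vector<int64_t> fc_input_reshape_shape(
    int64_t input_num_elements, const std::vector<int64_t>& filter_shape) {
  int64_t ic = filter_shape[1];            // IC from the filter
  return {input_num_elements / ic, ic};    // N inferred from element count
}
```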
### tfl.gather_nd

Gather_nd operator.

**TensorFlow Lite Dialect**

```
%output = tfl.gather_nd(%params, %indices)
```

**TOSA Lowering**

```
%output = lower_gather_nd_op(%params, %indices)
```

### tfl.gather

Gather operator.

**TensorFlow Lite Dialect**

```
%output = tfl.gather(%params, %indices) {axis}
```

**TOSA Lowering**

```
%output = lower_gather_op(%params, %indices, 0, axis)
```

### tfl.greater_equal

This operator is trivially lowered to tosa.GREATER_EQUAL.

### tfl.greater

This operator is trivially lowered to tosa.GREATER.

### tfl.hard_swish

Hardswish activation function.

**TensorFlow Lite Dialect**

```
%output = tfl.hard_swish(%input)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

```
%const_3 = tosa.CONST() {value={3.0}}
%const_rcp6 = tosa.CONST() {value={1.0 / 6.0}}
%op1_add_in_3 = tosa.ADD(%input, %const_3)
%op2_relun_op1 = tosa.RELUN(%op1_add_in_3) {max=6.0}
%op3_mul_in_op2 = tosa.MUL(%input, %op2_relun_op1)
%op4_mul_op3_rcp6 = tosa.MUL(%op3_mul_in_op2, %const_rcp6)
```

If input/output tensors are all quantized typed,

Prepare:

```
float64 input_sample_grain = 1.0 / 64.0
auto hardswish_func = [input_sample_grain](int32 x) -> int32 {
    float64 v = (float64)x * input_sample_grain
    float64 w = v + 3.0
    w = (w < 0.0) ? 0.0 : ((w > 6.0) ? 6.0 : w)
    v = (v * w) / 6.0
    return std::lround(32768.0 * v)
}
float64 input_rescale_scale = (%input.scale * 128.0) / input_sample_grain
float64 output_rescale_scale = 1.0 / (128.0 * 32768.0 * %output.scale)
int32 quantized_3 = (int32)(std::ceil(3.0 / %input.scale)) + %input.zp
```

Legalization:

```
%table_const = get_table_const_tensor(hardswish_func)
%const_3 = tosa.CONST() {value={quantized_3}}
%op1_rescale_in = tosa.RESCALE(%input) {scale=input_rescale_scale, input_zp=%input.zp, output_zp=0} // %input.dtype->i16
%op2_table_op1 = tosa.TABLE(%op1_rescale_in, %table_const)
%op3_rescale_op2 = tosa.RESCALE(%op2_table_op1) {scale=output_rescale_scale, input_zp=0, output_zp=%output.zp} // i32->%output.dtype
%op4_rescale_in = tosa.RESCALE(%input) {scale=1.0, input_zp=0, output_zp=0} // %input.dtype->i32
%op5_ge_op4 = tosa.GREATER_EQUAL(%op4_rescale_in, %const_3)
%op6_select_op5_in_op3 = tosa.SELECT(%op5_ge_op4, %input, %op3_rescale_op2)
```
### tfl.l2_normalization

No TOSA lowering defined.

### tfl.lstm

No TOSA lowering defined.

### tfl.leaky_relu

Leaky Relu operator.

**TensorFlow Lite Dialect**

```
%output = tfl.leaky_relu(%input) {alpha}
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%const_0 = tosa.CONST() {value={0.0}}
%const_alpha = tosa.CONST() {value={alpha}}
%op1_mul_in_alpha = tosa.MUL(%input, %const_alpha)
%op2_ge_in_0 = tosa.GREATER_EQUAL(%input, %const_0)
%op3_select_op2_in_op1 = tosa.SELECT(%op2_ge_in_0, %input, %op1_mul_in_alpha)
```

If input/output tensors are all quantized typed,

Prepare:

```
float32 scaled_alpha = (%input.scale * alpha) / %output.scale
float32 scaled_identity = %input.scale / %output.scale
```

Legalization:

```
%const_0 = tosa.CONST() {value={0}}
%op1_rescale_in = tosa.RESCALE(%input) {scale=1.0, input_zp=%input.zp} // %input.dtype->i32
%op2_ge_op1_0 = tosa.GREATER_EQUAL(%op1_rescale_in, %const_0)
%op3_rescale_in_alpha = tosa.RESCALE(%input) {scale=scaled_alpha, input_zp=%input.zp, output_zp=%output.zp} // %input.dtype->%output.dtype
%op4_rescale_in_identity = tosa.RESCALE(%input) {scale=scaled_identity, input_zp=%input.zp, output_zp=%output.zp} // %input.dtype->%output.dtype
%op5_select_op2_op3_op4 = tosa.SELECT(%op2_ge_op1_0, %op4_rescale_in_identity, %op3_rescale_in_alpha)
```

### tfl.less_equal

Less_equal operator.

**TensorFlow Lite Dialect**

```
%output = tfl.less_equal(%lhs, %rhs)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_greater_lhs_rhs = tosa.GREATER(%lhs, %rhs)
%op2_not_op1 = tosa.LOGICAL_NOT(%op1_greater_lhs_rhs)
```

If input/output tensors are all quantized typed,

Legalization:

```
assert (%lhs.scale == %rhs.scale) && (%lhs.zp == %rhs.zp)

%op1_rescale_lhs = tosa.RESCALE(%lhs) {scale=1.0, input_zp=%lhs.zp, output_zp=0} // %lhs.dtype->i32
%op2_rescale_rhs = tosa.RESCALE(%rhs) {scale=1.0, input_zp=%rhs.zp, output_zp=0} // %rhs.dtype->i32
%op3_greater_op1_op2 = tosa.GREATER(%op1_rescale_lhs, %op2_rescale_rhs)
%op4_not_op3 = tosa.LOGICAL_NOT(%op3_greater_op1_op2)
```

### tfl.less

Less operator.
**TensorFlow Lite Dialect**

```
%output = tfl.less(%lhs, %rhs)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_ge_lhs_rhs = tosa.GREATER_EQUAL(%lhs, %rhs)
%op2_not_op1 = tosa.LOGICAL_NOT(%op1_ge_lhs_rhs)
```

If input/output tensors are all quantized typed,

Legalization:

```
assert (%lhs.scale == %rhs.scale) && (%lhs.zp == %rhs.zp)

%op1_rescale_lhs = tosa.RESCALE(%lhs) {scale=1.0, input_zp=%lhs.zp, output_zp=0} // %lhs.dtype->i32
%op2_rescale_rhs = tosa.RESCALE(%rhs) {scale=1.0, input_zp=%rhs.zp, output_zp=0} // %rhs.dtype->i32
%op3_ge_op1_op2 = tosa.GREATER_EQUAL(%op1_rescale_lhs, %op2_rescale_rhs)
%op4_not_op3 = tosa.LOGICAL_NOT(%op3_ge_op1_op2)
```

### tfl.local_response_normalization

No TOSA lowering defined.

### tfl.log

No TOSA lowering defined.

### tfl.log_softmax

Log softmax operator.

**TensorFlow Lite Dialect**

```
%output = tfl.log_softmax(%input)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%output = lower_log_softmax_op(%input)
```

No TOSA lowering defined if input/output tensors are all quantized typed.

### tfl.logical_and

This operator is trivially lowered to tosa.LOGICAL_AND.

### tfl.logical_not

This operator is trivially lowered to tosa.LOGICAL_NOT.

### tfl.logical_or

This operator is trivially lowered to tosa.LOGICAL_OR.

### tfl.logistic

Logistic operator.

**TensorFlow Lite Dialect**

```
%y = tfl.logistic(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_sigmoid_in = tosa.SIGMOID(%x)
```

If input/output tensors are all quantized typed,

Prepare:

```
float64 input_sample_grain = 1.0 / 16.0
auto sigmoid_func = [input_sample_grain](int32 x) -> int32 {
    float64 v = static_cast<float64>(x) * input_sample_grain
    v = 1.0 / (1.0 + std::exp(-v))
    return std::lround(32768.0 * v)
}

float32 input_rescale_scale = (%x.scale * 128.0) / input_sample_grain
float32 output_rescale_scale = 1.0 / (%y.scale * 32768.0 * 128.0)
```

Legalization:

```
%table_const = get_table_const_tensor(sigmoid_func)
%op1_rescale_in = tosa.RESCALE(%x) {scale=input_rescale_scale, input_zp=%x.zp, output_zp=0} // %x.dtype->i16
%op2_table_op1 = tosa.TABLE(%op1_rescale_in, %table_const)
%op3_rescale_op2 = tosa.RESCALE(%op2_table_op1) {scale=output_rescale_scale, input_zp=0, output_zp=%y.zp} // i32->%y.dtype
```
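The quantized logistic, tanh, hard_swish and softmax lowerings all share this structure: RESCALE the input into an int16 domain sampled every input_sample_grain, evaluate the nonlinearity with tosa.TABLE, then RESCALE the result to the output type. The sketch below shows one plausible construction of the constant returned by get_table_const_tensor(), assuming the 513-entry int16 table form of tosa.TABLE with samples at x = -256..256 (illustrative only; the exact table layout is defined by the TOSA specification):

```
#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>

// One plausible construction of the lookup table consumed by tosa.TABLE
// (illustrative only). Entry i holds f((i - 256) * input_sample_grain)
// scaled by 32768, saturated to int16; the generator lambdas in the
// Prepare blocks (e.g. sigmoid_func, tanh_func) already return the
// scaled value.
std::array<int16_t, 513> build_table(
    const std::function<int32_t(int32_t)>& func) {
  std::array<int16_t, 513> table;
  for (int32_t i = 0; i < 513; i++) {
    int32_t v = func(i - 256);
    table[i] = static_cast<int16_t>(
        std::min<int32_t>(std::max<int32_t>(v, -32768), 32767));
  }
  return table;
}
```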
### tfl.matrix_diag

No TOSA lowering defined.

### tfl.matrix_set_diag

No TOSA lowering defined.

### tfl.max_pool_2d

Max Pool 2d op.

**TensorFlow Lite Dialect**

```
%output = tfl.max_pool_2d(%input) {filter_height, filter_width, padding, stride_h, stride_w, fused_activation_function}
```

**TOSA Lowering**

Prepare:

```
tosa_padding =
     get_padding_values_from_pad_type(padding, NHWC, 1,
                                      %input.type, tensor<{filter_height, filter_width}, tosa.int32>,
                                      {1, stride_h, stride_w, 1}, {1, 1, 1, 1})
```

If input/output tensors are all non-quantized typed,

Legalization:

```
%maxpool2d = tosa.MAX_POOL2D(%input) {kernel={filter_height, filter_width}, stride={stride_h, stride_w}, padding=tosa_padding}
if(fused_activation_function != NONE) {
    %result = convert_fused_activation(%maxpool2d, fused_activation_function)
}
else {
    %result = %maxpool2d
}
```

If input/output tensors are all quantized typed,

Legalization:

```
%maxpool2d = tosa.MAX_POOL2D(%input) {kernel={filter_height, filter_width}, stride={stride_h, stride_w}, padding=tosa_padding, quantization_info={input_zp=%input.zp, output_zp=%output.zp}}
if(fused_activation_function != NONE) {
    %result = convert_fused_activation(%maxpool2d, fused_activation_function)
}
else {
    %result = %maxpool2d
}
```

### tfl.max_pooling_with_argmax_2d

No TOSA lowering defined.

### tfl.max_unpooling_2d

No TOSA lowering defined.

### tfl.maximum

This operator is trivially lowered to tosa.MAXIMUM.

### tfl.mean

Mean operator.

**TensorFlow Lite Dialect**

```
%output = tfl.mean(%input, %axis) {keep_dims}
```

**TOSA Lowering**

Prepare:

```
int32 num_elements_on_axis = 1
for (int32 axis : %axis.as_constant()) {
    num_elements_on_axis *= %input.shape[axis]
}
float32 div_scale = 1.0 / num_elements_on_axis
```

If input/output tensors are all non-quantized typed,

Legalization:

```
%cst_div_scale = tosa.CONST() {value={div_scale}}
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %axis, keep_dims)
%op2_mul_op1 = tosa.MUL(%op1_rsum_in, %cst_div_scale)
```

If input/output tensors are all quantized typed,

Legalization:

```
%rsum = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %axis, keep_dims, 1.0f, %input.zp, div_scale * %input.scale / %output.scale, %output.zp)
```

### tfl.minimum

This operator is trivially lowered to tosa.MINIMUM.

### tfl.mirror_pad

No TOSA lowering defined.

### tfl.mul

Mul operator.
**TensorFlow Lite Dialect**

```
%output = tfl.mul(%lhs, %rhs)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_mul_in = tosa.MUL(%lhs, %rhs)
```

If input/output tensors are all quantized typed,

Legalization:

```
%op1_rescale_lhs = tosa.RESCALE(%lhs) {scale=1.0f, input_zp=%lhs.zp, output_zp=0} // %lhs.dtype->i32
%op2_rescale_rhs = tosa.RESCALE(%rhs) {scale=1.0f, input_zp=%rhs.zp, output_zp=0} // %rhs.dtype->i32
%op3_mul_op1_op2 = tosa.MUL(%op1_rescale_lhs, %op2_rescale_rhs)
%op4_rescale_op3 = tosa.RESCALE(%op3_mul_op1_op2) {scale=(%lhs.scale * %rhs.scale) / %output.scale, input_zp=0, output_zp=%output.zp} // i32->%output.dtype
```

### tfl.neg

This operator is trivially lowered to tosa.NEGATE.

### tfl.non_max_suppression_v4

No TOSA lowering defined.

### tfl.non_max_suppression_v5

No TOSA lowering defined.

### tfl.not_equal

Not_equal operator.

**TensorFlow Lite Dialect**

```
%output = tfl.not_equal(%lhs, %rhs)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_equal_lhs_rhs = tosa.EQUAL(%lhs, %rhs)
%op2_not_op1 = tosa.LOGICAL_NOT(%op1_equal_lhs_rhs)
```

If input/output tensors are all quantized typed,

Legalization:

```
assert (%lhs.scale == %rhs.scale) && (%lhs.zp == %rhs.zp)

%op1_rescale_lhs = tosa.RESCALE(%lhs) {scale=1.0f, input_zp=%lhs.zp, output_zp=0} // %lhs.dtype->i32
%op2_rescale_rhs = tosa.RESCALE(%rhs) {scale=1.0f, input_zp=%rhs.zp, output_zp=0} // %rhs.dtype->i32
%op3_equal_op1_op2 = tosa.EQUAL(%op1_rescale_lhs, %op2_rescale_rhs)
%op4_not_op3 = tosa.LOGICAL_NOT(%op3_equal_op1_op2)
```

### tfl.NumericVerify

No TOSA lowering defined.

### tfl.one_hot

OneHot operator.

**TensorFlow Lite Dialect**

```
%output = tfl.one_hot(%indices, %depth, %on_value, %off_value) {axis}
```

**TOSA Lowering**

```
%output = lower_one_hot_op(%indices, %depth, %on_value, %off_value, axis)
```

### tfl.prelu

No TOSA lowering defined.

### tfl.pack

Packs a list of tensors along a dimension into one tensor.

**TensorFlow Lite Dialect**

```
%output = tfl.pack(%values) {axis}
```

**TOSA Lowering**

```
%output = lower_pack_op(%values, axis)
```

### tfl.pad

This operator is trivially lowered to tosa.PAD.

### tfl.padv2

No TOSA lowering defined.

### tfl.pow

No TOSA lowering defined.

### tfl.pseudo_qconst

This operator is trivially lowered to tosa.CONST.

### tfl.quantize

Quantize operator.

**TensorFlow Lite Dialect**

```
%output = tfl.quantize(%input)
```

**TOSA Lowering**

Legalization:

```
if (isa<QuantizedType>(%input.dtype)) {
    %op1_rescale_in = tosa.RESCALE(%input) {scale=%input.scale / %output.scale, input_zp=%input.zp, output_zp=%output.zp}
}
else {
    %output = lower_quantize_op(%output.dtype, %input, %output.zp, %output.scale)
}
```

### tfl.range

No TOSA lowering defined.
### tfl.rank

Rank operator.

**TensorFlow Lite Dialect**

```
%output = tfl.rank(%input)
```

**TOSA Lowering**

Legalization:

```
%const = tosa.CONST() {value={%input.rank}}
```

### tfl.reduce_any

Computes the "logical or" of elements across dimensions of a tensor.

**TensorFlow Lite Dialect**

```
%output = tfl.reduce_any(%input, %reduction_indices) {keep_dims}
```

**TOSA Lowering**

Legalization:

```
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_ANY>(%input, %output.shape, %reduction_indices, keep_dims)
```

### tfl.reduce_max

Max-reduction operator.

**TensorFlow Lite Dialect**

```
%output = tfl.reduce_max(%input, %axes) {keep_dims}
```

**TOSA Lowering**

Legalization:

```
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_MAX>(%input, %output.shape, %axes, keep_dims)
```

### tfl.reduce_min

Computes the min reduction along the specified axes.

**TensorFlow Lite Dialect**

```
%output = tfl.reduce_min(%input, %axes) {keep_dims}
```

**TOSA Lowering**

Legalization:

```
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_MIN>(%input, %output.shape, %axes, keep_dims)
```

### tfl.reduce_prod

Prod-reduction operator.

**TensorFlow Lite Dialect**

```
%output = tfl.reduce_prod(%input, %axes) {keep_dims}
```

**TOSA Lowering**

If input/output tensors are all float typed,

Legalization:

```
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_PROD>(%input, %output.shape, %axes, keep_dims)
```

### tfl.relu_n1_to_1

No TOSA lowering defined.

### tfl.relu6

Relu6 operator.

**TensorFlow Lite Dialect**

```
%y = tfl.relu6(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_relun_in = tosa.RELUN(%x) {max_int=0, max_fp=6.0}
```

If input/output tensors are all quantized typed,

Legalization:

```
%op1_rescale_in = tosa.RESCALE(%x) {scale=%x.scale / %y.scale, input_zp=%x.zp, output_zp=0} // %x.dtype->i32
%op2_relun_op1 = tosa.RELUN(%op1_rescale_in) {max_int=(6.0 / %y.scale), max_fp=0.0}
%op3_rescale_op2 = tosa.RESCALE(%op2_relun_op1) {scale=1.0, input_zp=0, output_zp=%y.zp} // i32->%y.dtype
```

### tfl.relu

Relu operator.

**TensorFlow Lite Dialect**

```
%y = tfl.relu(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_relun_in = tosa.RELUN(%x) {max_int=0, max_fp=std::numeric_limits<float>::max()}
```

If input/output tensors are all quantized typed,

Legalization:

```
%op1_rescale_in = tosa.RESCALE(%x) {scale=%x.scale / %y.scale, input_zp=%x.zp, output_zp=0} // %x.dtype->i32
%op2_relun_op1 = tosa.RELUN(%op1_rescale_in) {max_int=std::numeric_limits<int32>::max(), max_fp=0.0}
%op3_rescale_op2 = tosa.RESCALE(%op2_relun_op1) {scale=1.0, input_zp=0, output_zp=%y.zp} // i32->%y.dtype
```

### tfl.reshape

This operator is trivially lowered to tosa.RESHAPE.

### tfl.resize_bilinear

ResizeBilinear op.
**TensorFlow Lite Dialect**

```
%output = tfl.resize_bilinear(%input, %size) {align_corners, half_pixel_centers}
```

**TOSA Lowering**

```
%output = lower_resize_op(%input, %size, %input.dtype, "BILINEAR")
```

### tfl.resize_nearest_neighbor

ResizeNearestNeighbor op.

**TensorFlow Lite Dialect**

```
%output = tfl.resize_nearest_neighbor(%input, %size) {align_corners, half_pixel_centers}
```

**TOSA Lowering**

```
%output = lower_resize_op(%input, %size, %input.dtype, "NEAREST_NEIGHBOR")
```

### tfl.reverse_sequence

No TOSA lowering defined.

### tfl.reverse_v2

ReverseV2 operator.

**TensorFlow Lite Dialect**

```
%output = tfl.reverse_v2(%input, %axis)
```

**TOSA Lowering**

```
%output = lower_reversev2_op(%input, %axis)
```

### tfl.round

Round operator.

**TensorFlow Lite Dialect**

```
%output = tfl.round(%input)
```

**TOSA Lowering**

```
%const_half = tosa.CONST() {value={0.5}}
%op1_add_in_half = tosa.ADD(%input, %const_half)
%op2_floor_op1 = tosa.FLOOR(%op1_add_in_half)
```

### tfl.rsqrt

No TOSA lowering defined.

### tfl.svdf

No TOSA lowering defined.

### tfl.segment_sum

No TOSA lowering defined.

### tfl.select

This operator is trivially lowered to tosa.SELECT.

### tfl.select_v2

This operator is trivially lowered to tosa.SELECT.

### tfl.shape

Shape operator.

**TensorFlow Lite Dialect**

```
%output = tfl.shape(%input)
```

**TOSA Lowering**

Legalization:

```
%const = tosa.CONST() {value=%input.shape}
```
### tfl.sin

No TOSA lowering defined.

### tfl.slice

This operator is trivially lowered to tosa.SLICE.

### tfl.softmax

Softmax operator.

**TensorFlow Lite Dialect**

```
%output = tfl.softmax(%input)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_exp_in = tosa.EXP(%input)
%op2_rsum_op1 = tosa.REDUCE_SUM(%op1_exp_in) {axis=(%input.rank - 1)}
%op3_rcp_op2 = tosa.RECIPROCAL(%op2_rsum_op1)
%op4_mul_op1_op3 = tosa.MUL(%op1_exp_in, %op3_rcp_op2)
```

If input/output tensors are all quantized typed,

Prepare:

```
float64 exp_sample_grain = 1.0 / 16.0
auto exp_func = [exp_sample_grain](int32 x) -> int32 {
    float64 v = static_cast<float64>(x) * exp_sample_grain
    v = v < 0.0 ? std::exp(v) : 1.0
    return std::lround(32768.0 * v)
}

float64 one_over_one_plus_x_sample_grain = 1.0 / 256.0
auto one_over_one_plus_x_func = [one_over_one_plus_x_sample_grain](int32 x) -> int32 {
    float64 v = static_cast<float64>(x) * one_over_one_plus_x_sample_grain
    v = v < 0.0 ? 1.0 : 1.0 / (1.0 + v)
    return std::lround(32768.0 * v)
}

float64 op4_rescale_scale = (%input.scale * 128.0) / exp_sample_grain
float64 op19_rescale_scale = 1.0 / (%output.scale * 256.0)
```

Legalization:

```
%const_exp_table = get_table_const_tensor(exp_func)
%const_one_over_one_plus_x_table = get_table_const_tensor(one_over_one_plus_x_func)
%const_3 = tosa.CONST() {value={3}}
%const_34 = tosa.CONST() {value={12+20-8}}
%const_2_to_31 = tosa.CONST() {value={1<<31}}
%const_16 = tosa.CONST() {value={16}}

%op1_rescale_in = tosa.RESCALE(%input) {scale=1.0f, input_zp=%input.zp, output_zp=0} // %input.dtype->i32
%op2_rmax_op1 = tosa.REDUCE_MAX(%op1_rescale_in) {axis=(%input.rank-1)}
%op3_sub_op1_op2 = tosa.SUB(%op1_rescale_in, %op2_rmax_op1)
%op4_rescale_op3 = tosa.RESCALE(%op3_sub_op1_op2) {scale=op4_rescale_scale, input_zp=0, output_zp=0} // i32->i16
%op5_table_op4 = tosa.TABLE(%op4_rescale_op3, %const_exp_table)
%op6_rshift_op5_3 = tosa.ARITHMETIC_RIGHT_SHIFT(%op5_table_op4, %const_3)
%op7_rsum_op6 = tosa.REDUCE_SUM(%op6_rshift_op5_3) {axis=(%input.rank-1)}
%op8_clz_op7 = tosa.CLZ(%op7_rsum_op6)
%op9_sub_34_op8 = tosa.SUB(%const_34, %op8_clz_op7)
%op10_lshift_op7_op8 = tosa.LOGICAL_LEFT_SHIFT(%op7_rsum_op6, %op8_clz_op7)
%op11_sub_op10 = tosa.SUB(%op10_lshift_op7_op8, %const_2_to_31)
%op12_rshift_op11_16 = tosa.ARITHMETIC_RIGHT_SHIFT(%op11_sub_op10, %const_16)
%op13_cast_op12 = tosa.CAST(%op12_rshift_op11_16) // i32->i16
%op14_table_op13 = tosa.TABLE(%op13_cast_op12, %const_one_over_one_plus_x_table)
%op15_rescale_op14 = tosa.RESCALE(%op14_table_op13) {scale=1.0/128.0, input_zp=0, output_zp=0} // i32->i16
%op16_rescale_op5 = tosa.RESCALE(%op5_table_op4) {scale=1.0/128.0, input_zp=0, output_zp=0} // i32->i16
%op17_mul_op16_op15 = tosa.MUL(%op15_rescale_op14, %op16_rescale_op5)
%op18_rshift_op17_op9 = tosa.ARITHMETIC_RIGHT_SHIFT(%op17_mul_op16_op15, %op9_sub_34_op8)
%op19_rescale_op18 = tosa.RESCALE(%op18_rshift_op17_op9) {scale=op19_rescale_scale, input_zp=0, output_zp=%output.zp}
```
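Ops 8-12 above compute the reciprocal of the exponential sum without a division. A scalar model (illustrative only): shifting the positive int32 sum left by its CLZ count sets bit 31, normalizing the sum to m \* 2^k with m in [0.5, 1.0) as a Q32 fraction; the fractional part x = 2m - 1 then indexes the one_over_one_plus_x table (since 1/(1+x) = 1/(2m)), and the 2^k factor is folded into the compensating shift built from %const_34 (op9) and applied at op18:

```
#include <cstdint>

// Scalar model of the CLZ-based normalization above (illustrative only).
int32_t clz32(int32_t x) {                    // tosa.CLZ
  int32_t n = 0;
  for (int32_t i = 31; i >= 0 && !((x >> i) & 1); i--) n++;
  return n;
}

// Returns the table index for 1/(1+x); *headroom_out receives the shift
// count that the later right shift (built from op9) compensates for.
int16_t reciprocal_table_index(int32_t sum, int32_t* headroom_out) {
  int32_t headroom = clz32(sum);              // op8
  uint32_t norm = uint32_t(sum) << headroom;  // op10: bit 31 set, m in [0.5, 1.0)
  uint32_t frac = norm - (1u << 31);          // op11: x = 2m - 1 as a Q31 fraction
  *headroom_out = headroom;
  return int16_t(frac >> 16);                 // op12/op13: Q31 -> int16 index
}
```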
### tfl.split_v

Splits a tensor into num_splits tensors along one dimension, with the size of
each piece given by %size_splits.

**TensorFlow Lite Dialect**

```
%output = tfl.split_v(%value, %size_splits, %split_dim) {num_splits}
```

**TOSA Lowering**

```
%output = lower_splitv_op(%value, %size_splits.as_constant(), %split_dim.as_constant())
```

### tfl.sqrt

No TOSA lowering defined.

### tfl.square

Square operator.

**TensorFlow Lite Dialect**

```
%y = tfl.square(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_mul_in = tosa.MUL(%x, %x)
```

If input/output tensors are all quantized typed,

Legalization:

```
%op1_rescale_x = tosa.RESCALE(%x) {scale=1.0f, input_zp=%x.zp, output_zp=0} // %x.dtype->i32
%op2_mul_op1_op1 = tosa.MUL(%op1_rescale_x, %op1_rescale_x)
%op3_rescale_op2 = tosa.RESCALE(%op2_mul_op1_op1) {scale=(%x.scale * %x.scale) / %y.scale, input_zp=0, output_zp=%y.zp} // i32->%y.dtype
```

### tfl.squared_difference

Squared difference operator.

**TensorFlow Lite Dialect**

```
%output = tfl.squared_difference(%lhs, %rhs)
```

**TOSA Lowering**

Legalization:

```
%op1_sub_in = tosa.SUB(%lhs, %rhs)
%op2_mul_op1 = tosa.MUL(%op1_sub_in, %op1_sub_in)
```

### tfl.squeeze

Removes dimensions of size 1 from the shape of a tensor.

**TensorFlow Lite Dialect**

```
%output = tfl.squeeze(%input) {squeeze_dims}
```

**TOSA Lowering**

```
%output = lower_squeeze_op(%input, squeeze_dims)
```

### tfl.strided_slice

StridedSlice Op.

**TensorFlow Lite Dialect**

```
%output = tfl.strided_slice(%input, %begin, %end, %strides) {begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask}
```

**TOSA Lowering**

```
%output = lower_strided_slice_op(%input, %begin, %end, %strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
```

### tfl.sub

This operator is trivially lowered to tosa.SUB

### tfl.sum

Sum operator.

**TensorFlow Lite Dialect**

```
%output = tfl.sum(%input, %axis) {keep_dims}
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_rsum_in = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %axis, keep_dims)
```

If input/output tensors are all quantized typed,

Legalization:

```
%rsum = lower_reduce_op<tosa.REDUCE_SUM>(%input, %output.shape, %axis, keep_dims, 1.0f, %input.zp, (%input.scale / %output.scale), %output.zp)
```
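In the quantized case, the extra arguments fold the input zero-point removal
and the output requantization into the reduction: summing zero-point-corrected
integers and then applying a single (%input.scale / %output.scale) rescale is
equivalent to dequantizing, summing, and requantizing. A quick numeric check,
using made-up quantization parameters:

```
// Illustrative check of the quantized REDUCE_SUM rescale (scales and
// zero points below are arbitrary example values).
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const double input_scale = 0.05, output_scale = 0.4;  // assumed values
  const int32_t input_zp = -3, output_zp = 10;          // assumed values

  const int8_t q_in[4] = {12, -7, 33, 5};

  // Integer path: remove input_zp, accumulate in i32, rescale once.
  int32_t acc = 0;
  for (int8_t q : q_in) acc += (int32_t)q - input_zp;
  const int32_t q_out =
      (int32_t)std::lround(acc * (input_scale / output_scale)) + output_zp;

  // Float reference: dequantize, sum, requantize.
  double real_sum = 0.0;
  for (int8_t q : q_in) real_sum += ((int32_t)q - input_zp) * input_scale;
  const int32_t q_ref = (int32_t)std::lround(real_sum / output_scale) + output_zp;

  std::printf("integer path: %d, float reference: %d\n", q_out, q_ref);  // both 17
}
```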
### tfl.tanh

Hyperbolic tangent operator.

**TensorFlow Lite Dialect**

```
%y = tfl.tanh(%x)
```

**TOSA Lowering**

If input/output tensors are all non-quantized typed,

Legalization:

```
%op1_tanh_in = tosa.TANH(%x)
```

If input/output tensors are all quantized typed,

Prepare:

```
float64 input_sample_grain = 1.0 / 32.0
auto tanh_func = [input_sample_grain](int32 x) -> int32 {
    float64 v = static_cast<float64>(x) * input_sample_grain
    v = std::exp(-2.0 * v)
    v = (1.0 - v) / (1.0 + v)
    return std::lround(32768.0 * v)
}

float64 input_rescale_scale = (%x.scale * 128.0) / input_sample_grain
float64 output_rescale_scale = 1.0 / (%y.scale * 32768.0 * 128.0)
```

Legalization:

```
%table_const = get_table_const_tensor(tanh_func)
%op1_rescale_in = tosa.RESCALE(%x) {scale=input_rescale_scale, input_zp=%x.zp, output_zp=0} // %x.dtype->i16
%op2_table_op1 = tosa.TABLE(%op1_rescale_in, %table_const)
%op3_rescale_op2 = tosa.RESCALE(%op2_table_op1) {scale=output_rescale_scale, input_zp=0, output_zp=%y.zp} // i32->%y.dtype
```
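get_table_const_tensor() builds the constant table consumed by tosa.TABLE by
sampling the supplied function over the 16-bit input domain. The standalone
sketch below shows that sampling for tanh_func at a few grid points; the actual
table size and interpolation behavior are defined by the TOSA TABLE operator,
not by this snippet:

```
// Sample tanh on the i16 grid (a few points only) and store Q0.15 values,
// mirroring the tanh_func lambda from the Prepare step above.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const double input_sample_grain = 1.0 / 32.0;

  auto tanh_func = [input_sample_grain](int32_t x) -> int32_t {
    double v = (double)x * input_sample_grain;
    v = std::exp(-2.0 * v);
    v = (1.0 - v) / (1.0 + v);                 // tanh via the exp identity
    return (int32_t)std::lround(32768.0 * v);  // Q0.15 encoding
  };

  // A real table covers the whole i16 input range; print a few samples.
  for (int32_t x : {-1024, -32, 0, 32, 1024})
    std::printf("table[%6d] = %6d (tanh(%8.4f) ~ %7.4f)\n", x, tanh_func(x),
                x * input_sample_grain, tanh_func(x) / 32768.0);
}
```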
### tfl.tile

This operator is trivially lowered to tosa.TILE

### tfl.topk_v2

No TOSA lowering defined.

### tfl.transpose_conv

Transpose convolution operator.

**TensorFlow Lite Dialect**

```
%output = tfl.transpose_conv(%output_shape, %weights, %input) {padding, stride_h, stride_w}
```

**TOSA Lowering**

Prepare:

```
tosa_padding =
    get_transpose_conv2d_padding_values_from_pad_type(%input.type, %weights.type, %output_shape, padding, "NHWC", FORMAT_HWIO, {stride_h, stride_w}, {1, 1})
```

If input/output tensors are all non-quantized typed,

Legalization:

```
%bias = tosa.CONST() {value={0.0} * %output.shape[3]}
%conv2d = tosa.TRANSPOSE_CONV2D(%input, %weights, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={1, 1}}
```

If input/output tensors are all quantized typed,

Prepare:

```
float64 output_rescale_scale = (%input.scale * %weights.scale) / %output.scale
```

Legalization:

```
%bias = tosa.CONST() {value={0} * %output.shape[3]}
%conv2d = tosa.TRANSPOSE_CONV2D(%input, %weights, %bias) {padding=tosa_padding, stride={stride_h, stride_w}, dilation={1, 1}}
%rescale = tosa.RESCALE(%conv2d) {scale=output_rescale_scale, input_zp=0, output_zp=%output.zp} // %conv2d.dtype->%output.dtype
```

### tfl.transpose

This operator is trivially lowered to tosa.TRANSPOSE

### tfl.unidirectional_sequence_lstm

No TOSA lowering defined.

### tfl.unidirectional_sequence_rnn

No TOSA lowering defined.

### tfl.unique

No TOSA lowering defined.

### tfl.unpack

Unpacks a tensor along a dimension into multiple tensors.

**TensorFlow Lite Dialect**

```
%output = tfl.unpack(%input) {num, axis}
```

**TOSA Lowering**

```
%output = lower_unpack_op(%input, axis, num)
```

### tfl.where

No TOSA lowering defined.

### tfl.while

No TOSA lowering defined.

### tfl.yield

This operator is trivially lowered to tosa.YIELD

### tfl.zeros_like

ZerosLike operator.

**TensorFlow Lite Dialect**

```
%output = tfl.zeros_like(%input)
```

**TOSA Lowering**

```
%output = tosa.CONST() {value={0} * %input.num_elements}
```

## fuse_tf_bias

Legalize (tf.Conv2D + tf.BiasAdd) to tosa.CONV2D. This is currently the only
N:1 mapping in TOSA legalization.

From:

```
%conv2d = tf.Conv2D(%input, %filter) {...}
%bias_add = tf.BiasAdd(%conv2d, %bias)
```

To:

```
%conv2d = tosa.CONV2D(%input, %filter, %bias)
```

## convert_tfl_uint8

This pass does three things:

1.  Converts constants from quantized uint8 to quantized int8, remapping the
    stored values as well.
2.  If an input placeholder is quantized uint8 typed, inserts "tosa.RESCALE()
    {scale=1.0, input_zp=input_zp, output_zp=input_zp-128} // qu8->qi8" between
    the placeholder and its consumers.
3.  If an output tensor is quantized uint8 typed, inserts "tosa.RESCALE()
    {scale=1.0, input_zp=output_zp-128, output_zp=output_zp} // qi8->qu8"
    between the producer and the graph output.
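The conversion in step 1 preserves the represented real values: the scale is
unchanged while the stored values and the zero point both shift down by 128,
so (value - zp) * scale is identical before and after. A minimal sketch (the
helper name remap_qu8_to_qi8 is illustrative):

```
// Value remap for step 1: qu8 -> qi8 keeps the scale and shifts both the
// stored values and the zero point by -128.
#include <cstdint>
#include <cstdio>
#include <vector>

struct QuantizedConst {
  std::vector<int32_t> values;  // stored quantized values
  double scale;
  int32_t zp;
};

QuantizedConst remap_qu8_to_qi8(const QuantizedConst& u8) {
  QuantizedConst i8 = u8;
  for (int32_t& v : i8.values) v -= 128;  // [0, 255] -> [-128, 127]
  i8.zp -= 128;
  return i8;
}

int main() {
  const QuantizedConst u8{{0, 128, 255}, 0.1, 128};
  const QuantizedConst i8 = remap_qu8_to_qi8(u8);
  for (size_t i = 0; i < u8.values.size(); i++)
    std::printf("real: %g -> %g\n", (u8.values[i] - u8.zp) * u8.scale,
                (i8.values[i] - i8.zp) * i8.scale);  // pairs are identical
}
```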