// RUN: mlir-hlo-opt %s -split-input-file -canonicalize | FileCheck %s


// CHECK-LABEL: func @loop_invariants
module {
  func.func @loop_invariants(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>) {
    // The body returns the implicit capture %arg0 in place of the first
    // operand, and the third operand is passed through unchanged: both can be
    // eliminated, leaving the second operand as the only real loop-carried
    // value.
    // CHECK: %[[WHILE:.*]] = mhlo.while
    // CHECK-SAME: (%[[ITER_ARG:.*]] = %arg2)
    %0:3 = "mhlo.while"(%arg0, %arg2, %arg3) ({
    ^bb0(%arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>):
      // CHECK: mhlo.compare
      // CHECK-SAME: %[[ITER_ARG]], %arg3
      %1 = "mhlo.compare"(%arg5, %arg6) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
      "mhlo.return"(%1) : (tensor<i1>) -> ()
    }, {
    ^bb0(%arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>):
      // CHECK: %[[ADD:.*]] = mhlo.add %[[ITER_ARG]], %arg0
      %1 = mhlo.add %arg5, %arg4 : tensor<i32>
      // This op is dead; its removal enables the canonicalization of the while op.
      %2 = "mhlo.tuple"(%arg4, %1, %arg6) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
      // CHECK: mhlo.return
      // CHECK-SAME: %[[ADD]]
      "mhlo.return"(%arg0, %1, %arg6) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
    }) : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
    // CHECK: return %arg0, %[[WHILE]], %arg3
    func.return %0#0, %0#1, %0#2 : tensor<i32>, tensor<i32>, tensor<i32>
  }
}

// -----

// CHECK-LABEL: func @dead_loop
module {
  func.func @dead_loop(%arg0: tensor<i32>) -> tensor<i32> {
    // The following loop will always return its operand, which is carried over
    // from one iteration to the next as-is; that is, we assume that loops
    // always terminate.
    // CHECK-NOT: mhlo.while
    %0 = "mhlo.while"(%arg0) ({
    ^bb0(%arg1: tensor<i32>):
      %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
      "mhlo.return"(%1) : (tensor<i1>) -> ()
    }, {
    ^bb0(%arg1: tensor<i32>):
      "mhlo.return"(%arg1) : (tensor<i32>) -> ()
    }) : (tensor<i32>) -> (tensor<i32>)
    // The loop result is its unchanged operand, so it is forwarded directly.
    // CHECK: return %arg0
    func.return %0 : tensor<i32>
  }
}

// -----

// CHECK-LABEL: func @fold_constant_cond
func.func @fold_constant_cond(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
// The condition is constant false, so the body never runs and the results are
// exactly the initial operands.
// CHECK-NOT: while
// CHECK: return %arg0, %arg1
  %0:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor<4xf32>, tensor<4xf32>
  cond {
    %cst = arith.constant dense<false> : tensor<i1>
    "mhlo.return"(%cst) : (tensor<i1>) -> ()
  } do {
    %1 = mhlo.add %iterArg, %iterArg_0 : tensor<4xf32>
    "mhlo.return"(%1, %1) : (tensor<4xf32>, tensor<4xf32>) -> ()
  }
  return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
}