// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/while.mlir (revision b6fb3261f9314811a0f4371741dbb8839866f948)
// RUN: mlir-hlo-opt %s -split-input-file -canonicalize | FileCheck %s


// CHECK-LABEL: func @loop_invariants
module {
  func.func @loop_invariants(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>) {
    // Only the second operand (initialized from %arg2) is really a
    // loop-carried value. The other two are loop invariants that the
    // canonicalizer should eliminate:
    //  - the first operand starts as %arg0 and the body returns the implicit
    //    capture %arg0 for it, so it never changes;
    //  - the third operand is returned by the body as-is, so it keeps its
    //    initial value %arg3.
    // CHECK: %[[WHILE:.*]] = mhlo.while
    // CHECK-SAME: (%[[ITER_ARG:.*]] = %arg2)
    %0:3 = "mhlo.while"(%arg0, %arg2, %arg3) ({
    ^bb0(%arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>):
      // The invariant third operand is replaced by its initial value %arg3.
      // CHECK: mhlo.compare
      // CHECK-SAME: %[[ITER_ARG]], %arg3
      %1 = "mhlo.compare"(%arg5, %arg6) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
      "mhlo.return"(%1) : (tensor<i1>) -> ()
    },  {
    ^bb0(%arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>):
      // The invariant first operand is replaced by its initial value %arg0.
      // CHECK: %[[ADD:.*]] = mhlo.add %[[ITER_ARG]], %arg0
      %1 = mhlo.add %arg5, %arg4 : tensor<i32>
      // This op is dead; its removal enables the canonicalization of the
      // while op.
      %2 = "mhlo.tuple"(%arg4, %1, %arg6) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
      // CHECK: mhlo.return
      // CHECK-SAME: %[[ADD]]
      "mhlo.return"(%arg0, %1, %arg6) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
    }) : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
    // The invariant results are forwarded straight from the function arguments.
    // CHECK: return %arg0, %[[WHILE]], %arg3
    func.return %0#0, %0#1, %0#2 : tensor<i32>, tensor<i32>, tensor<i32>
  }
}

// -----

// CHECK-LABEL: func @dead_loop
module {
  func.func @dead_loop(%arg0: tensor<i32>) -> tensor<i32> {
    // The body returns its block argument unchanged, so the operand is
    // carried from one iteration to the next as-is. Loops are assumed to
    // terminate, so the while op can be removed and the function returns its
    // argument directly.
    // CHECK-NOT: mhlo.while
    // CHECK: return %arg0
    %0 = "mhlo.while"(%arg0) ({
    ^bb0(%arg1: tensor<i32>):
      %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
      "mhlo.return"(%1) : (tensor<i1>) -> ()
    },  {
    ^bb0(%arg1: tensor<i32>):
      "mhlo.return"(%arg1) : (tensor<i32>) -> ()
    }) : (tensor<i32>) -> (tensor<i32>)
    func.return %0 : tensor<i32>
  }
}

// -----

// CHECK-LABEL: func @fold_constant_cond
func.func @fold_constant_cond(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
  // The condition is the constant false, so the body never executes and each
  // result is exactly the corresponding initial operand.
  // CHECK-NOT: while
  // CHECK: return %arg0, %arg1
  %0:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor<4xf32>, tensor<4xf32>
  cond {
    %cst = arith.constant dense<false> : tensor<i1>
    "mhlo.return"(%cst) : (tensor<i1>) -> ()
  } do {
    %1 = mhlo.add %iterArg, %iterArg_0 : tensor<4xf32>
    "mhlo.return"(%1, %1) : (tensor<4xf32>, tensor<4xf32>) -> ()
  }
  return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
}