xref: /aosp_15_r20/external/tensorflow/tensorflow/core/transforms/cse/tests/cse.mlir (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1// RUN: tfg-transforms-opt -pass-pipeline='tfg.func(tfg-cse)' %s | FileCheck %s
2
3// CHECK-LABEL: tfg.func @test_simple_cse
4// CHECK-SAME: %[[A:.*]]: tensor
5tfg.func @test_simple_cse(%a: tensor<i32> {tfg.name = "a0"})
6    -> (tensor<i32> {tfg.name = "b0"},
7        tensor<i32> {tfg.name = "b1"}) {
8  // CHECK: %[[ADD1:.*]], %{{.*}} = Add(%[[A]], %[[A]]) name("add1")
9  %Add1, %ctl1 = Add(%a, %a) name("add1") : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
10  // CHECK-NOT: Add(%[[A]], %[[A]]) name("add0")
11  %Add0, %ctl0 = Add(%a, %a) name("add0") : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
12  // CHECK: return(%[[ADD1]], %[[ADD1]])
13  return(%Add0, %Add1) : tensor<i32>, tensor<i32>
14}
15
16// CHECK-LABEL: tfg.func @test_cse_across_regions
17// CHECK-SAME: %[[A:.*]]: tensor
18// CHECK-NEXT: %[[COND:.*]]: tensor
19tfg.func @test_cse_across_regions(%a: tensor<i32> {tfg.name = "a0"},
20                                  %cond: tensor<i1> {tfg.name = "cond"})
21    -> (tensor<i32> {tfg.name = "b0"},
22        tensor<i32> {tfg.name = "b1"}) {
23  // CHECK: %[[ADD0:.*]], %{{.*}} = Add(%[[A]], %[[A]]) name("add0")
24  %Add0, %ctl0 = Add(%a, %a) name("add0") : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
25  // CHECK: %[[IF:.*]], %{{.*}} = StatelessIfRegion
26  %If, %ctl = StatelessIfRegion %cond then {
27    // CHECK-NOT: Add(%[[A]], %[[A]]) name("add1")
28    %Add1, %ctl1 = Add(%a, %a) name("add1") : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
29    // CHECK: yield(%[[ADD0]])
30    yield(%Add1) : tensor<i32>
31  } else {
32    // CHECK-NOT: Add(%[[A]], %[[A]]) name("add2")
33    %Add2, %ctl2 = Add(%a, %a) name("add1") : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
34    // CHECK: yield(%[[ADD0]])
35    yield(%Add2) : tensor<i32>
36  } {_mlir_name = "if"} : (tensor<i1>) -> (tensor<i32>)
37  // CHECK: return(%[[IF]], %[[ADD0]])
38  return(%If, %Add0) : tensor<i32>, tensor<i32>
39}
40
41// CHECK-LABEL: tfg.func @test_cse_control_tokens
42// CHECK-SAME: %[[A:.*]]: tensor
43tfg.func @test_cse_control_tokens(%a: tensor<i32> {tfg.name = "a0"})
44    -> (tensor<i32> {tfg.name = "b0"}) {
45  // CHECK: %[[ADD1:.*]], %[[CTL1:.*]] = Add(%[[A]], %[[A]]) name("add1")
46  %Add1, %ctl1 = Add(%a, %a) name("add1") : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
47  // CHECK-NOT: Add(%[[A]], %[[A]]) name("add0")
48  %Add0, %ctl0 = Add(%a, %a) name("add0") : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
49  // CHECK: %[[CTL:.*]] = NoOp [%[[CTL1]], %[[CTL1]]]
50  %ctl = NoOp [%ctl1, %ctl0]
51  // CHECK: return(%[[ADD1]]) [%[[CTL]] {
52  return(%Add1) [%ctl {tfg.name = "noop"}] : tensor<i32>
53}