1// RUN: tfg-transforms-opt --tfg-eliminate-passthrough-iter-args %s | FileCheck %s
2
// Every loop-carried value is handed back to `yield` unchanged, so the pass
// should strip all four iteration arguments: the loop keeps only its induction
// variable and that variable's control token, the body reads the function
// arguments directly, the unrelated `_some_attr` attribute survives, and the
// function returns its own arguments instead of the loop results.
// CHECK-LABEL: @test_uncapture_all
// CHECK: %[[INDEX:.*]]: tensor
// CHECK-NEXT: %[[A0:.*]]: tensor
// CHECK-NEXT: %[[A1:.*]]: tensor
// CHECK-NEXT: %[[A2:.*]]: tensor
// CHECK-NEXT: %[[A3:.*]]: tensor
tfg.func @test_uncapture_all(%idx: tensor<i32> {tfg.name = "index"},
                             %v0: tensor<i8> {tfg.name = "a0"},
                             %v1: tensor<i16> {tfg.name = "a1"},
                             %v2: tensor<i32> {tfg.name = "a2"},
                             %v3: tensor<i64> {tfg.name = "a3"})
    -> (tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>) {
  // CHECK: %{{.*}} = ForRegion from %[[INDEX]]
  // CHECK: ^bb0(%{{.*}}: tensor<i32>, %{{.*}}: !tf_type.control):
  // CHECK:   %[[USE:.*]]:2, {{.*}} = Use(%[[A0]], %[[A1]], %[[A2]], %[[A3]])
  // CHECK:   yield
  // CHECK: _some_attr
  // CHECK: return(%[[A0]], %[[A1]], %[[A2]], %[[A3]])
  %loop:4, %loop_ctl = ForRegion(%v0, %v1, %v2, %v3) from %idx to %idx by %idx {
  ^bb0(%iv: tensor<i32>,
       %it0: tensor<i8>, %it1: tensor<i16>, %it2: tensor<i32>, %it3: tensor<i64>,
       %iv_ctl: !tf_type.control,
       %it0_ctl: !tf_type.control, %it1_ctl: !tf_type.control,
       %it2_ctl: !tf_type.control, %it3_ctl: !tf_type.control):
    // Read every carried value, then yield each one back untouched.
    %used:2, %used_ctl = Use(%it0, %it1, %it2, %it3) : (tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>) -> (tensor<i8>, tensor<i32>)
    yield(%it0, %it1, %it2, %it3) : tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>
  } {_some_attr}
  : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>)
  -> (tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>)
  return(%loop#0, %loop#1, %loop#2, %loop#3) : tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>
}
31
// Two of the four loop-carried values (the i16 and i64 ones) are yielded back
// unchanged, while the i8 and i32 ones are replaced by results of `Use`. Only
// the pass-through pair should be removed: the loop keeps two iteration
// arguments, the body reads the function arguments in place of the removed
// ones, and the function returns those arguments directly for the matching
// result positions.
// CHECK-LABEL: @test_uncapture_some
// CHECK: %[[INDEX:.*]]: tensor
// CHECK-NEXT: %[[A0:.*]]: tensor
// CHECK-NEXT: %[[A1:.*]]: tensor
// CHECK-NEXT: %[[A2:.*]]: tensor
// CHECK-NEXT: %[[A3:.*]]: tensor
tfg.func @test_uncapture_some(%idx: tensor<i32> {tfg.name = "index"},
                              %v0: tensor<i8> {tfg.name = "a0"},
                              %v1: tensor<i16> {tfg.name = "a1"},
                              %v2: tensor<i32> {tfg.name = "a2"},
                              %v3: tensor<i64> {tfg.name = "a3"})
    -> (tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>) {
  // CHECK: %[[FOR:.*]]:2, %{{.*}} = ForRegion(%[[A0]], %[[A2]]) from %[[INDEX]]
  // CHECK: ^bb0(%{{.*}}: tensor<i32>, %[[ARG0:.*]]: tensor<i8>, %[[ARG1:.*]]: tensor<i32>
  // CHECK:   %[[USE:.*]]:2, {{.*}} = Use(%[[ARG0]], %[[A1]], %[[ARG1]], %[[A3]])
  // CHECK:   yield(%[[USE]]#0, %[[USE]]#1)
  // CHECK: return(%[[FOR]]#0, %[[A1]], %[[FOR]]#1, %[[A3]])
  %loop:4, %loop_ctl = ForRegion(%v0, %v1, %v2, %v3) from %idx to %idx by %idx {
  ^bb0(%iv: tensor<i32>,
       %it0: tensor<i8>, %it1: tensor<i16>, %it2: tensor<i32>, %it3: tensor<i64>,
       %iv_ctl: !tf_type.control,
       %it0_ctl: !tf_type.control, %it1_ctl: !tf_type.control,
       %it2_ctl: !tf_type.control, %it3_ctl: !tf_type.control):
    %used:2, %used_ctl = Use(%it0, %it1, %it2, %it3) : (tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>) -> (tensor<i8>, tensor<i32>)
    // Positions 0 and 2 receive fresh values; positions 1 and 3 pass through.
    yield(%used#0, %it1, %used#1, %it3) : tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>
  } : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>)
  -> (tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>)
  return(%loop#0, %loop#1, %loop#2, %loop#3) : tensor<i8>, tensor<i16>, tensor<i32>, tensor<i64>
}
58
// The while body yields its first loop value (`%arg0`) unchanged and only
// recomputes the second, so the pass should drop the first iteration argument.
// In the condition region, uses of the dropped block argument and of its
// control token are rewritten to the function argument `%a0` and its control
// token. The matching per-argument entry is removed from `cond_region_attrs`,
// and unrelated attributes such as `parallel_iterations` must be preserved.
// CHECK-LABEL: @test_uncapture_while
// CHECK: %[[A0:.*]]: tensor
// CHECK-NEXT: %[[A1:.*]]: tensor
tfg.func @test_uncapture_while(%a0: tensor<i8> {tfg.name = "a0"},
                               %a1: tensor<i16> {tfg.name = "a1"})
    -> (tensor<i8>, tensor<i16>) {
  // CHECK: %[[WHILE:.*]], {{.*}} = WhileRegion(%[[A1]])
  // CHECK: ^bb0(%[[ARG0:.*]]: tensor<i16>
  // CHECK:  %[[COND:.*]], %{{.*}} = Cond(%[[A0]], %[[ARG0]])
  // CHECK:  %{{.*}} = NoOp [%[[A0]].ctl]
  // CHECK:  condition %[[COND]] : tensor<i1> (%[[ARG0]])
  // CHECK: ^bb0(%[[ARG0:.*]]: tensor<i16>, %[[CTL0:.*]]: !tf_type.control
  // CHECK:   %[[THING:.*]], %{{.*}} = Thing(%[[A0]], %[[ARG0]]) [%[[CTL0]]]
  // CHECK:   yield(%[[THING]])
  // CHECK: cond_region_attrs = #tfg.region_attrs<{tf._a} [{}] [{}]>
  // CHECK-SAME: parallel_iterations = 10 : i64
  %While:2, %ctl = WhileRegion(%a0, %a1) {
  ^bb0(%arg0: tensor<i8>, %arg1: tensor<i16>, %ctl0: !tf_type.control, %ctl1: !tf_type.control):
    %Cond, %ctl_Cond = Cond(%arg0, %arg1) : (tensor<i8>, tensor<i16>) -> (tensor<i1>)
    // Depends on the control token of the argument that is about to be removed.
    %ctl_NoOp = NoOp [%ctl0] : () -> ()
    condition %Cond : tensor<i1> (%arg0, %arg1) [%ctl_NoOp] : tensor<i8>, tensor<i16>
  } do {
  ^bb0(%arg0: tensor<i8>, %arg1: tensor<i16>, %ctl0: !tf_type.control, %ctl1: !tf_type.control):
    %Thing, %ctl_Thing = Thing(%arg0, %arg1) [%ctl1] : (tensor<i8>, tensor<i16>) -> (tensor<i16>)
    // %arg0 is forwarded unchanged, making it a pass-through iteration argument.
    yield(%arg0, %Thing) : tensor<i8>, tensor<i16>
  } {parallel_iterations = 10 : i64,
     cond_region_attrs = #tfg.region_attrs<{tf._a} [{}, {}] [{}]>}
  : (tensor<i8>, tensor<i16>) -> (tensor<i8>, tensor<i16>)
  // CHECK: return(%[[A0]], %[[WHILE]])
  return(%While#0, %While#1) : tensor<i8>, tensor<i16>
}
89