xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convert.mlir (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s
2
3// -----
4
// A convert whose operand and result types are both tensor<f32> is expected
// to be removed, so the function returns its argument directly.
5// CHECK-LABEL: func @same_type
6// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
7func.func @same_type(%arg: tensor<f32>) -> tensor<f32> {
8  %0 = mhlo.convert(%arg) : (tensor<f32>) -> tensor<f32>
9  // CHECK-NEXT: return [[ARG]]
10  func.return %0 : tensor<f32>
11}
12
13// -----
14
// A widening chain f16 -> f32 -> f64 whose intermediate f32 value has no other
// use is expected to collapse into a single f16 -> f64 convert.
15// CHECK-LABEL: func @non_const_chained_convert_unused_parent
16// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
17func.func @non_const_chained_convert_unused_parent(%arg: tensor<f16>) -> tensor<f64> {
18  // CHECK-NEXT: [[RES:%.+]] = mhlo.convert([[ARG]]) : (tensor<f16>) -> tensor<f64>
19  %0 = mhlo.convert(%arg) : (tensor<f16>) -> tensor<f32>
20  %1 = mhlo.convert(%0) : (tensor<f32>) -> tensor<f64>
21  // CHECK-NEXT: return [[RES]]
22  func.return %1 : tensor<f64>
23}
24
25// -----
26
// Integer counterpart of the chained-convert case: ui16 -> i32 -> i64 with an
// otherwise unused intermediate is expected to collapse into one ui16 -> i64
// convert.
27// CHECK-LABEL: func @non_const_chained_convert_unused_parent_integer
28// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
29func.func @non_const_chained_convert_unused_parent_integer(%arg: tensor<ui16>) -> tensor<i64> {
30  // CHECK-NEXT: [[RES:%.+]] = mhlo.convert([[ARG]]) : (tensor<ui16>) -> tensor<i64>
31  %0 = mhlo.convert(%arg) : (tensor<ui16>) -> tensor<i32>
32  %1 = mhlo.convert(%0) : (tensor<i32>) -> tensor<i64>
33  // CHECK-NEXT: return [[RES]]
34  func.return %1 : tensor<i64>
35}
36
37// -----
38
// Negative test: the chain f32 -> f16 -> f32 goes through a narrower
// intermediate type, and both converts are expected to remain in the output.
// NOTE(review): presumably the chain is kept because the f16 round trip can
// change the value -- confirm against the canonicalization pattern.
39// CHECK-LABEL: func @not_convert_float_lower_width
40// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
41func.func @not_convert_float_lower_width(%arg: tensor<f32>) -> tensor<f32> {
42  // CHECK-NEXT: [[VAL0:%.+]] = mhlo.convert([[ARG]]) : (tensor<f32>) -> tensor<f16>
43  // CHECK-NEXT: [[VAL1:%.+]] = mhlo.convert([[VAL0]]) : (tensor<f16>) -> tensor<f32>
44  %0 = mhlo.convert(%arg) : (tensor<f32>) -> tensor<f16>
45  %1 = mhlo.convert(%0) : (tensor<f16>) -> tensor<f32>
46  // CHECK-NEXT: return [[VAL1]]
47  func.return %1 : tensor<f32>
48}
49
50// -----
51
// A round trip through a wider type (f32 -> f64 -> f32) is expected to vanish
// entirely, leaving a function that returns its argument.
52// CHECK-LABEL: func @non_const_chained_convert_becomes_noop
53// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
54func.func @non_const_chained_convert_becomes_noop(%arg: tensor<f32>) -> tensor<f32> {
55  %0 = mhlo.convert(%arg) : (tensor<f32>) -> tensor<f64>
56  %1 = mhlo.convert(%0) : (tensor<f64>) -> tensor<f32>
57  // CHECK-NEXT: return [[ARG]]
58  func.return %1 : tensor<f32>
59}
60
61// -----
// Chained widening converts where the intermediate f32 value is also returned.
// The first convert is expected to stay, and the second is expected to be
// rewritten to read the original f16 argument instead of the intermediate.
62// CHECK-LABEL: func @non_const_chained_convert
63// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
64func.func @non_const_chained_convert(%arg: tensor<f16>) -> (tensor<f32>, tensor<f64>) {
65  // CHECK-NEXT: [[RES0:%.+]] = mhlo.convert([[ARG]]) : (tensor<f16>) -> tensor<f32>
66  // CHECK-NEXT: [[RES1:%.+]] = mhlo.convert([[ARG]]) : (tensor<f16>) -> tensor<f64>
67  %0 = mhlo.convert(%arg) : (tensor<f16>) -> tensor<f32>
68  %1 = mhlo.convert(%0) : (tensor<f32>) -> tensor<f64>
69  // CHECK-NEXT: return [[RES0]], [[RES1]]
70  func.return %0, %1 : tensor<f32>, tensor<f64>
71}
72
73// -----
74
// A single i32 -> i64 convert of a non-constant argument is expected to be
// left unchanged by canonicalization.
75// CHECK-LABEL: func @int_widening
76// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
77func.func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
78  // CHECK-NEXT: [[RES:%.+]] = mhlo.convert([[ARG]]) : (tensor<i32>) -> tensor<i64>
79  %0 = mhlo.convert(%arg) : (tensor<i32>) -> tensor<i64>
80  // CHECK-NEXT: return [[RES]]
81  func.return %0 : tensor<i64>
82}
83
84// -----
85
// A single i32 -> i16 convert of a non-constant argument is expected to be
// left unchanged by canonicalization.
86// CHECK-LABEL: func @int_narrowing
87// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
88func.func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
89  // CHECK-NEXT: [[RES:%.+]] = mhlo.convert([[ARG]]) : (tensor<i32>) -> tensor<i16>
90  %0 = mhlo.convert(%arg) : (tensor<i32>) -> tensor<i16>
91  // CHECK-NEXT: return [[RES]]
92  func.return %0 : tensor<i16>
93}
94
95// -----
96
// A single f32 -> i32 convert of a non-constant argument is expected to be
// left unchanged by canonicalization.
97// CHECK-LABEL: func @float_int
98// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
99func.func @float_int(%arg: tensor<f32>) -> tensor<i32> {
100  // CHECK-NEXT: [[RES:%.+]] = mhlo.convert([[ARG]]) : (tensor<f32>) -> tensor<i32>
101  %0 = mhlo.convert(%arg) : (tensor<f32>) -> tensor<i32>
102  // CHECK-NEXT: return [[RES]]
103  func.return %0 : tensor<i32>
104}
105
106// -----
107
// A single i32 -> f32 convert of a non-constant argument is expected to be
// left unchanged by canonicalization.
108// CHECK-LABEL: func @int_float
109// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
110func.func @int_float(%arg: tensor<i32>) -> tensor<f32> {
111  // CHECK-NEXT: [[RES:%.+]] = mhlo.convert([[ARG]]) : (tensor<i32>) -> tensor<f32>
112  %0 = mhlo.convert(%arg) : (tensor<i32>) -> tensor<f32>
113  // CHECK-NEXT: return [[RES]]
114  func.return %0 : tensor<f32>
115}
116
117// -----
118
// Non-scalar variant: an i32 -> f32 convert on a 2x3 tensor argument is
// expected to be left unchanged by canonicalization.
119// CHECK-LABEL: func @high_rank_tensor
120// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
121func.func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
122  // CHECK-NEXT: [[RES:%.+]] = mhlo.convert([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
123  %0 = mhlo.convert(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
124  // CHECK-NEXT: return [[RES]]
125  func.return %0 : tensor<2x3xf32>
126}
127
128// -----
129
130
// A same-type (i32 -> i32) convert of a constant is expected to disappear,
// leaving the original constant as the returned value.
131// CHECK-LABEL: func @const_same_type
132func.func @const_same_type() -> tensor<i32> {
133  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
134  %cst = mhlo.constant dense<42> : tensor<i32>
135  %0 = mhlo.convert(%cst) : (tensor<i32>) -> tensor<i32>
136  // CHECK-NEXT: return [[CST]]
137  func.return %0 : tensor<i32>
138}
139
140// -----
141
// Constant folding float -> int: the f32 constant 42.0 converted to i32 is
// expected to become the i32 constant 42.
142// CHECK-LABEL: func @const_float_int
143func.func @const_float_int() -> tensor<i32> {
144  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
145  %cst = mhlo.constant dense<42.0> : tensor<f32>
146  %0 = mhlo.convert(%cst) : (tensor<f32>) -> tensor<i32>
147  // CHECK-NEXT: return [[CST]]
148  func.return %0 : tensor<i32>
149}
150
151// -----
152
// Constant folding int -> float: the i32 constant 4 converted to f32 is
// expected to become an f32 constant printed as 4.0...e+00 (the pattern
// tolerates any number of trailing zeros in the mantissa).
153// CHECK-LABEL: func @const_int_float
154func.func @const_int_float() -> tensor<f32> {
155  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<f32>
156  %cst = mhlo.constant dense<4> : tensor<i32>
157  %0 = mhlo.convert(%cst) : (tensor<i32>) -> tensor<f32>
158  // CHECK-NEXT: return [[CST]]
159  func.return %0 : tensor<f32>
160}
161
162// -----
163
// Constant folding of a negative int -> float: the i32 constant -4 converted
// to f32 is expected to become an f32 constant printed as -4.0...e+00.
164// CHECK-LABEL: func @const_negative_int_float
165func.func @const_negative_int_float() -> tensor<f32> {
166  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
167  %cst = mhlo.constant dense<-4> : tensor<i32>
168  %0 = mhlo.convert(%cst) : (tensor<i32>) -> tensor<f32>
169  // CHECK-NEXT: return [[CST]]
170  func.return %0 : tensor<f32>
171}
172
173// -----
174
// Constant folding int -> bf16: the i32 constant 4 converted to bf16 is
// expected to become a bf16 constant printed as 4.0...e+00.
175// CHECK-LABEL: func @const_int_bf16
176func.func @const_int_bf16() -> tensor<bf16> {
177  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
178  %cst = mhlo.constant dense<4> : tensor<i32>
179  %0 = mhlo.convert(%cst) : (tensor<i32>) -> tensor<bf16>
180  // CHECK-NEXT: return [[CST]]
181  func.return %0 : tensor<bf16>
182}
183
184// -----
185
// Constant folding i1 -> f32 on a 2-element tensor: the constant [0, 1] is
// expected to become [0.0, 1.0], i.e. a true bit folds to positive 1.0.
186// CHECK-LABEL: func @const_bool_f32
187func.func @const_bool_f32() -> tensor<2xf32> {
188  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
189  %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
190  %0 = mhlo.convert(%cst) : (tensor<2xi1>) -> tensor<2xf32>
191  // CHECK-NEXT: return [[CST]]
192  func.return %0 : tensor<2xf32>
193}
194
195// -----
196
// Constant folding bf16 -> i16: the bf16 constant 42.0 converted to i16 is
// expected to become the i16 constant 42.
197// CHECK-LABEL: func @const_bf16_int16
198func.func @const_bf16_int16() -> tensor<i16> {
199  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
200  %cst = mhlo.constant dense<42.0> : tensor<bf16>
201  %0 = mhlo.convert(%cst) : (tensor<bf16>) -> tensor<i16>
202  // CHECK-NEXT: return [[CST]]
203  func.return %0 : tensor<i16>
204}
205
206// -----
207
// Constant folding of an integer narrowing: the i64 constant 42 converted to
// i32 is expected to become the i32 constant 42.
208// CHECK-LABEL: func @const_int_narrowing
209func.func @const_int_narrowing() -> tensor<i32> {
210  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
211  %cst = mhlo.constant dense<42> : tensor<i64>
212  %0 = mhlo.convert(%cst) : (tensor<i64>) -> tensor<i32>
213  // CHECK-NEXT: return [[CST]]
214  func.return %0 : tensor<i32>
215}
216
217// -----
218
// Constant folding of an integer widening: the i32 constant 42 converted to
// i64 is expected to become the i64 constant 42. No i1 value is involved, so
// this case carries the "int widening" name.
219// CHECK-LABEL: func @const_int_widening
220func.func @const_int_widening() -> tensor<i64> {
221  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
222  %cst = mhlo.constant dense<42> : tensor<i32>
223  %0 = mhlo.convert(%cst) : (tensor<i32>) -> tensor<i64>
224  // CHECK-NEXT: return [[CST]]
225  func.return %0 : tensor<i64>
226}
227
228// -----
229
// Constant folding of a boolean widening: the i1 constant tensor [0, 1]
// converted to i32 is expected to become the i32 constant [0, 1], i.e. a true
// bit folds to positive 1.
230// CHECK-LABEL: func @const_bool_widening
231func.func @const_bool_widening() -> tensor<2xi32> {
232  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0, 1]> : tensor<2xi32>
233  %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
234  %0 = mhlo.convert(%cst) : (tensor<2xi1>) -> tensor<2xi32>
235  // CHECK-NEXT: return [[CST]]
236  func.return %0 : tensor<2xi32>
237}
238
239// -----
240
// Constant folding of a negative integer widening: the i32 constant -42
// converted to i64 is expected to become the i64 constant -42 (sign kept).
241// CHECK-LABEL: func @const_negative_int_widening
242func.func @const_negative_int_widening() -> tensor<i64> {
243  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
244  %cst = mhlo.constant dense<-42> : tensor<i32>
245  %0 = mhlo.convert(%cst) : (tensor<i32>) -> tensor<i64>
246  // CHECK-NEXT: return [[CST]]
247  func.return %0 : tensor<i64>
248}
249
250// -----
251
// Constant folding of a float narrowing: the f64 constant 4.2 converted to f32
// is expected to become an f32 constant printed as 4.20...e+00.
252// CHECK-LABEL: func @const_float_narrowing
253func.func @const_float_narrowing() -> tensor<f32> {
254  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
255  %cst = mhlo.constant dense<4.2> : tensor<f64>
256  %0 = mhlo.convert(%cst) : (tensor<f64>) -> tensor<f32>
257  // CHECK-NEXT: return [[CST]]
258  func.return %0 : tensor<f32>
259}
260
261// -----
262
// Constant folding f32 -> bf16: the f32 constant 42.0 converted to bf16 is
// expected to become a bf16 constant printed as 4.20...e+01.
263// CHECK-LABEL: func @const_f32_bf16
264func.func @const_f32_bf16() -> tensor<bf16> {
265  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
266  %cst = mhlo.constant dense<42.0> : tensor<f32>
267  %0 = mhlo.convert(%cst) : (tensor<f32>) -> tensor<bf16>
268  // CHECK-NEXT: return [[CST]]
269  func.return %0 : tensor<bf16>
270}
271
272// -----
273
// Constant folding bf16 -> f64: the bf16 constant written as 4.2 is expected
// to fold to the f64 constant 4.187500e+00, so the folded value is the one
// the bf16 constant actually holds rather than the decimal literal 4.2.
274// CHECK-LABEL: func @const_bf16_f64
275func.func @const_bf16_f64() -> tensor<f64> {
276  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.187500e+00> : tensor<f64>
277  %cst = mhlo.constant dense<4.2> : tensor<bf16>
278  %0 = mhlo.convert(%cst) : (tensor<bf16>) -> tensor<f64>
279  // CHECK-NEXT: return [[CST]]
280  func.return %0 : tensor<f64>
281}
282
283// -----
284
// Constant folding bf16 -> i64: the bf16 constant 42.0 converted to i64 is
// expected to become the i64 constant 42.
285// CHECK-LABEL: func @const_bf16_int64
286func.func @const_bf16_int64() -> tensor<i64> {
287  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
288  %cst = mhlo.constant dense<42.0> : tensor<bf16>
289  %0 = mhlo.convert(%cst) : (tensor<bf16>) -> tensor<i64>
290  // CHECK-NEXT: return [[CST]]
291  func.return %0 : tensor<i64>
292}
293
294
295// -----
296
// Constant folding on a 2x3 tensor: the f32 constant [[1.0, 2.0, 3.0],
// [4.0, 5.0, 6.0]] converted to i32 is expected to become the i32 constant
// [[1, 2, 3], [4, 5, 6]]. The expected constant is matched in three pieces
// that must all appear on the same output line.
297// CHECK-LABEL: func @const_high_rank_tensor
298func.func @const_high_rank_tensor() -> tensor<2x3xi32> {
299  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
300  // CHECK-SAME:     [1, 2, 3], [4, 5, 6]
301  // CHECK-SAME: ]> : tensor<2x3xi32>
302  %cst = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
303  %0 = mhlo.convert(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
304  // CHECK-NEXT: return [[CST]]
305  func.return %0 : tensor<2x3xi32>
306}
307
308// -----
309
// Negative test: converting an i1 constant tensor to complex<f32>. The test
// only requires that an mhlo.convert still appears in the output, i.e. the
// convert is expected not to be folded into a constant.
310// CHECK-LABEL: func @const_int_complex
311func.func @const_int_complex() -> tensor<2xcomplex<f32>> {
312  %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
313  // CHECK: mhlo.convert
314  %0 = mhlo.convert(%cst) : (tensor<2xi1>) -> tensor<2xcomplex<f32>>
315  func.return %0 : tensor<2xcomplex<f32>>
316}
317
318// -----
319
// Negative test: converting an f32 constant tensor to complex<f64>. The test
// only requires that an mhlo.convert still appears in the output, i.e. the
// convert is expected not to be folded into a constant.
320// CHECK-LABEL: func @const_float_complex
321func.func @const_float_complex() -> tensor<2xcomplex<f64>> {
322  %cst = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
323  // CHECK: mhlo.convert
324  %0 = mhlo.convert(%cst) : (tensor<2xf32>) -> tensor<2xcomplex<f64>>
325  func.return %0 : tensor<2xcomplex<f64>>
326}
327
328
329// -----
330
// Negative test: converting a complex<f32> constant to i32. The test only
// requires that an mhlo.convert still appears in the output, i.e. the convert
// is expected not to be folded into a constant.
331// CHECK-LABEL: func @const_complex_int
332func.func @const_complex_int() -> tensor<i32> {
333  %cst = mhlo.constant dense<(0.0, 1.0)> : tensor<complex<f32>>
334  // CHECK: mhlo.convert
335  %0 = mhlo.convert(%cst) : (tensor<complex<f32>>) -> tensor<i32>
336  func.return %0 : tensor<i32>
337}
338
339// -----
340
// Negative test: converting a complex<f32> constant to f32. The test only
// requires that an mhlo.convert still appears in the output, i.e. the convert
// is expected not to be folded into a constant.
341// CHECK-LABEL: func @const_complex_float
342func.func @const_complex_float() -> tensor<f32> {
343  %cst = mhlo.constant dense<(0.0, 1.0)> : tensor<complex<f32>>
344  // CHECK: mhlo.convert
345  %0 = mhlo.convert(%cst) : (tensor<complex<f32>>) -> tensor<f32>
346  func.return %0 : tensor<f32>
347}
348
349// -----
350
// Negative test: converting a complex<f32> constant to complex<f64>. The test
// only requires that an mhlo.convert still appears in the output, i.e. the
// convert is expected not to be folded into a constant.
351// CHECK-LABEL: func @const_complex_complex
352func.func @const_complex_complex() -> tensor<complex<f64>> {
353  %cst = mhlo.constant dense<(0.0, 1.0)> : tensor<complex<f32>>
354  // CHECK: mhlo.convert
355  %0 = mhlo.convert(%cst) : (tensor<complex<f32>>) -> tensor<complex<f64>>
356  func.return %0 : tensor<complex<f64>>
357}
358