xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/risc/risc_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""RISC Operations."""
16
17
18from tensorflow.python.ops import gen_risc_ops
19
20
def risc_abs(x, name='RISC_ABS'):
  """Forwards `x` to `gen_risc_ops.risc_abs`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_ABS'.

  Returns:
    Whatever `gen_risc_ops.risc_abs` returns.
  """
  out = gen_risc_ops.risc_abs(x, name=name)
  return out
23
24
def risc_add(input_lhs, input_rhs, name='RISC_ADD'):
  """Forwards both operands to `gen_risc_ops.risc_add`.

  Args:
    input_lhs: Left-hand operand, passed positionally.
    input_rhs: Right-hand operand, passed positionally.
    name: Op name; defaults to 'RISC_ADD'.

  Returns:
    Whatever `gen_risc_ops.risc_add` returns.
  """
  out = gen_risc_ops.risc_add(input_lhs, input_rhs, name=name)
  return out
30
31
def risc_binary_arithmetic(x, y, op_type, name='RISC_BinaryArithmetic'):
  """Forwards two operands plus an `op_type` selector to the generated op.

  Args:
    x: First operand, passed positionally.
    y: Second operand, passed positionally.
    op_type: Passed as the `op_type` keyword (presumably chooses which
      arithmetic operation is applied).
    name: Op name; defaults to 'RISC_BinaryArithmetic'.

  Returns:
    Whatever `gen_risc_ops.risc_binary_arithmetic` returns.
  """
  out = gen_risc_ops.risc_binary_arithmetic(
      x, y, op_type=op_type, name=name)
  return out
34
35
def risc_binary_comparison(x, y, op_type, name='RISC_BinaryComparison'):
  """Forwards two operands plus an `op_type` selector to the generated op.

  Args:
    x: First operand, passed positionally.
    y: Second operand, passed positionally.
    op_type: Passed as the `op_type` keyword (presumably chooses which
      comparison is applied).
    name: Op name; defaults to 'RISC_BinaryComparison'.

  Returns:
    Whatever `gen_risc_ops.risc_binary_comparison` returns.
  """
  out = gen_risc_ops.risc_binary_comparison(
      x, y, op_type=op_type, name=name)
  return out
38
39
def risc_bitcast(x, dtype, name='RISC_BITCAST'):
  """Forwards `x` and a target `dtype` to `gen_risc_ops.risc_bitcast`.

  Args:
    x: Input value, passed positionally.
    dtype: Target dtype, passed positionally as the second argument.
    name: Op name; defaults to 'RISC_BITCAST'.

  Returns:
    Whatever `gen_risc_ops.risc_bitcast` returns.
  """
  out = gen_risc_ops.risc_bitcast(x, dtype, name=name)
  return out
42
43
def risc_broadcast(x, shape, name='RISC_BROADCAST'):
  """Forwards `x` and a target `shape` to `gen_risc_ops.risc_broadcast`.

  Args:
    x: Input value, passed positionally.
    shape: Target shape, passed positionally.
    name: Op name; defaults to 'RISC_BROADCAST'.

  Returns:
    Whatever `gen_risc_ops.risc_broadcast` returns.
  """
  out = gen_risc_ops.risc_broadcast(x, shape, name=name)
  return out
46
47
def risc_cast(x, dtype, name='RISC_CAST'):
  """Forwards `x` and a target `dtype` to `gen_risc_ops.risc_cast`.

  Args:
    x: Input value, passed positionally.
    dtype: Target dtype, passed positionally as the second argument.
    name: Op name; defaults to 'RISC_CAST'.

  Returns:
    Whatever `gen_risc_ops.risc_cast` returns.
  """
  out = gen_risc_ops.risc_cast(x, dtype, name=name)
  return out
50
51
def risc_ceil(x, name='RISC_CEIL'):
  """Forwards `x` to `gen_risc_ops.risc_ceil`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_CEIL'.

  Returns:
    Whatever `gen_risc_ops.risc_ceil` returns.
  """
  out = gen_risc_ops.risc_ceil(x, name=name)
  return out
54
55
def risc_cos(x, name='RISC_COS'):
  """Forwards `x` to `gen_risc_ops.risc_cos`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_COS'.

  Returns:
    Whatever `gen_risc_ops.risc_cos` returns.
  """
  out = gen_risc_ops.risc_cos(x, name=name)
  return out
58
59
def risc_cholesky(x, name='RISC_CHOLESKY'):
  """Forwards `x` to `gen_risc_ops.risc_cholesky`.

  Args:
    x: Operand (presumably a matrix or batch of matrices), passed
      positionally.
    name: Op name; defaults to 'RISC_CHOLESKY'.

  Returns:
    Whatever `gen_risc_ops.risc_cholesky` returns.
  """
  out = gen_risc_ops.risc_cholesky(x, name=name)
  return out
62
63
def risc_concat(x, axis, name='RISC_CONCAT'):
  """Forwards `x` and `axis` to `gen_risc_ops.risc_concat`.

  Args:
    x: Values to join (presumably a list of tensors), passed positionally.
    axis: Concatenation axis, passed positionally.
    name: Op name; defaults to 'RISC_CONCAT'.

  Returns:
    Whatever `gen_risc_ops.risc_concat` returns.
  """
  out = gen_risc_ops.risc_concat(x, axis, name=name)
  return out
66
67
def risc_condition(pred, input_true, input_false, func_true, func_false,
                   name='RISC_CONDITION'):
  """Forwards a predicate, two inputs and two branch functions to the op.

  Args:
    pred: Predicate, passed positionally first.
    input_true: Passed positionally; presumably the input for `func_true`.
    input_false: Passed positionally; presumably the input for `func_false`.
    func_true: Passed as the `func_true` keyword.
    func_false: Passed as the `func_false` keyword.
    name: Op name; defaults to 'RISC_CONDITION'.

  Returns:
    Whatever `gen_risc_ops.risc_condition` returns.
  """
  branch_fns = dict(func_true=func_true, func_false=func_false)
  return gen_risc_ops.risc_condition(
      pred, input_true, input_false, name=name, **branch_fns)
81
82
def risc_conv(x, kernel, strides, data_format='NHWC', dilations=None,
              name='RISC_CONV'):
  """Forwards input, kernel and convolution settings to the generated op.

  Args:
    x: Input, passed positionally.
    kernel: Filter, passed positionally.
    strides: Strides, passed positionally.
    data_format: Layout string, passed as a keyword; defaults to 'NHWC'.
    dilations: Passed as a keyword unchanged, including when it is None
      (how None is interpreted is up to the generated op).
    name: Op name; defaults to 'RISC_CONV'.

  Returns:
    Whatever `gen_risc_ops.risc_conv` returns.
  """
  out = gen_risc_ops.risc_conv(
      x, kernel, strides,
      name=name, dilations=dilations, data_format=data_format)
  return out
96
97
def risc_div(input_lhs, input_rhs, name='RISC_DIV'):
  """Forwards both operands to `gen_risc_ops.risc_div`.

  Args:
    input_lhs: Left-hand operand (presumably the dividend).
    input_rhs: Right-hand operand (presumably the divisor).
    name: Op name; defaults to 'RISC_DIV'.

  Returns:
    Whatever `gen_risc_ops.risc_div` returns.
  """
  out = gen_risc_ops.risc_div(input_lhs, input_rhs, name=name)
  return out
100
101
def risc_dot(input_lhs, input_rhs, transpose_a=False, transpose_b=False,
             name='RISC_DOT'):
  """Forwards two operands and transpose flags to `gen_risc_ops.risc_dot`.

  Args:
    input_lhs: Left-hand operand, passed positionally.
    input_rhs: Right-hand operand, passed positionally.
    transpose_a: Passed as the `transpose_a` keyword; defaults to False.
    transpose_b: Passed as the `transpose_b` keyword; defaults to False.
    name: Op name; defaults to 'RISC_DOT'.

  Returns:
    Whatever `gen_risc_ops.risc_dot` returns.
  """
  flags = {'transpose_a': transpose_a, 'transpose_b': transpose_b}
  return gen_risc_ops.risc_dot(input_lhs, input_rhs, name=name, **flags)
113
114
def risc_exp(x, name='RISC_EXP'):
  """Forwards `x` to `gen_risc_ops.risc_exp`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_EXP'.

  Returns:
    Whatever `gen_risc_ops.risc_exp` returns.
  """
  out = gen_risc_ops.risc_exp(x, name=name)
  return out
117
118
def risc_fft(x, name='RISC_FFT'):
  """Forwards `x` to `gen_risc_ops.risc_fft`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_FFT'.

  Returns:
    Whatever `gen_risc_ops.risc_fft` returns.
  """
  out = gen_risc_ops.risc_fft(x, name=name)
  return out
121
122
def risc_floor(x, name='RISC_FLOOR'):
  """Forwards `x` to `gen_risc_ops.risc_floor`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_FLOOR'.

  Returns:
    Whatever `gen_risc_ops.risc_floor` returns.
  """
  out = gen_risc_ops.risc_floor(x, name=name)
  return out
125
126
def risc_gather(params, indices, validate_indices=None, axis=None,
                batch_dims=0, name='RISC_GATHER'):
  """Forwards params, indices and gather options to the generated op.

  NOTE(review): `validate_indices` is forwarded and `axis` may be None.
  Confirm that `gen_risc_ops.risc_gather` accepts both keywords and a None
  axis; this wrapper does not check either.

  Args:
    params: Source values, passed positionally.
    indices: Indices, passed positionally.
    validate_indices: Passed as a keyword unchanged; defaults to None.
    axis: Passed as a keyword unchanged; defaults to None.
    batch_dims: Passed as a keyword; defaults to 0.
    name: Op name; defaults to 'RISC_GATHER'.

  Returns:
    Whatever `gen_risc_ops.risc_gather` returns.
  """
  options = dict(
      validate_indices=validate_indices,
      axis=axis,
      batch_dims=batch_dims,
      name=name)
  return gen_risc_ops.risc_gather(params, indices, **options)
140
141
def risc_imag(x, name='RISC_IMAG'):
  """Forwards `x` to `gen_risc_ops.risc_imag`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_IMAG'.

  Returns:
    Whatever `gen_risc_ops.risc_imag` returns.
  """
  out = gen_risc_ops.risc_imag(x, name=name)
  return out
144
145
def risc_is_finite(x, name='RISC_IS_FINITE'):
  """Forwards `x` to `gen_risc_ops.risc_is_finite`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_IS_FINITE'.

  Returns:
    Whatever `gen_risc_ops.risc_is_finite` returns.
  """
  out = gen_risc_ops.risc_is_finite(x, name=name)
  return out
148
149
def risc_log(x, name='RISC_LOG'):
  """Forwards `x` to `gen_risc_ops.risc_log`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_LOG'.

  Returns:
    Whatever `gen_risc_ops.risc_log` returns.
  """
  out = gen_risc_ops.risc_log(x, name=name)
  return out
152
153
def risc_logical_and(a, b, name='RISC_LOGICAL_AND'):
  """Forwards both operands to `gen_risc_ops.risc_logical_and`.

  Args:
    a: First operand, passed positionally.
    b: Second operand, passed positionally.
    name: Op name; defaults to 'RISC_LOGICAL_AND'.

  Returns:
    Whatever `gen_risc_ops.risc_logical_and` returns.
  """
  out = gen_risc_ops.risc_logical_and(a, b, name=name)
  return out
156
157
def risc_logical_not(a, b=None, name='RISC_LOGICAL_NOT'):
  """Forwards the single operand `a` to `gen_risc_ops.risc_logical_not`.

  Logical NOT is unary. The previous body called
  `gen_risc_ops.risc_logical_not(a, b, name=name)`. Assuming the usual
  generated signature for a one-input op, `(x, name=None)`, `b` was bound to
  the generated wrapper's `name` parameter. The explicit `name=` keyword then
  collided with it, so every call raised a TypeError.

  Args:
    a: Operand of the logical NOT, passed positionally.
    b: Unused. It is kept (now optional) only so that existing calls which
      pass a second positional argument still bind `name` to the same slot.
    name: Op name; defaults to 'RISC_LOGICAL_NOT'.

  Returns:
    Whatever `gen_risc_ops.risc_logical_not` returns.
  """
  del b  # Unary op: the second operand has nowhere to go.
  return gen_risc_ops.risc_logical_not(a, name=name)
160
161
def risc_logical_or(a, b, name='RISC_LOGICAL_OR'):
  """Forwards both operands to `gen_risc_ops.risc_logical_or`.

  Args:
    a: First operand, passed positionally.
    b: Second operand, passed positionally.
    name: Op name; defaults to 'RISC_LOGICAL_OR'.

  Returns:
    Whatever `gen_risc_ops.risc_logical_or` returns.
  """
  out = gen_risc_ops.risc_logical_or(a, b, name=name)
  return out
164
165
def risc_max(input_lhs, input_rhs, name='RISC_MAX'):
  """Forwards both operands to `gen_risc_ops.risc_max`.

  Args:
    input_lhs: Left-hand operand, passed positionally.
    input_rhs: Right-hand operand, passed positionally.
    name: Op name; defaults to 'RISC_MAX'.

  Returns:
    Whatever `gen_risc_ops.risc_max` returns.
  """
  out = gen_risc_ops.risc_max(input_lhs, input_rhs, name=name)
  return out
168
169
def risc_min(input_lhs, input_rhs, name='RISC_MIN'):
  """Forwards both operands to `gen_risc_ops.risc_min`.

  Args:
    input_lhs: Left-hand operand, passed positionally.
    input_rhs: Right-hand operand, passed positionally.
    name: Op name; defaults to 'RISC_MIN'.

  Returns:
    Whatever `gen_risc_ops.risc_min` returns.
  """
  out = gen_risc_ops.risc_min(input_lhs, input_rhs, name=name)
  return out
172
173
def risc_mul(input_lhs, input_rhs, name='RISC_MUL'):
  """Forwards both operands to `gen_risc_ops.risc_mul`.

  Args:
    input_lhs: Left-hand operand, passed positionally.
    input_rhs: Right-hand operand, passed positionally.
    name: Op name; defaults to 'RISC_MUL'.

  Returns:
    Whatever `gen_risc_ops.risc_mul` returns.
  """
  out = gen_risc_ops.risc_mul(input_lhs, input_rhs, name=name)
  return out
176
177
def risc_neg(x, name='RISC_NEG'):
  """Forwards `x` to `gen_risc_ops.risc_neg`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_NEG'.

  Returns:
    Whatever `gen_risc_ops.risc_neg` returns.
  """
  out = gen_risc_ops.risc_neg(x, name=name)
  return out
180
181
def risc_pad(x, padding, constant_values, name='RISC_PAD'):
  """Forwards input, padding spec and fill value to `gen_risc_ops.risc_pad`.

  Args:
    x: Input, passed positionally.
    padding: Padding specification, passed positionally.
    constant_values: Fill value, passed positionally.
    name: Op name; defaults to 'RISC_PAD'.

  Returns:
    Whatever `gen_risc_ops.risc_pad` returns.
  """
  out = gen_risc_ops.risc_pad(x, padding, constant_values, name=name)
  return out
184
185
def risc_pool(x, ksize, strides, pooling_type='MAX', name='RISC_POOL'):
  """Forwards input, window size, strides and pooling type to the op.

  Args:
    x: Input, passed positionally.
    ksize: Window size, passed positionally.
    strides: Strides, passed positionally.
    pooling_type: Passed as the `pooling_type` keyword; defaults to 'MAX'.
    name: Op name; defaults to 'RISC_POOL'.

  Returns:
    Whatever `gen_risc_ops.risc_pool` returns.
  """
  out = gen_risc_ops.risc_pool(
      x,
      ksize,
      strides,
      name=name,
      pooling_type=pooling_type,
  )
  return out
189
190
def risc_pow(input_lhs, input_rhs, name='RISC_POW'):
  """Forwards both operands to `gen_risc_ops.risc_pow`.

  Args:
    input_lhs: Left-hand operand (presumably the base).
    input_rhs: Right-hand operand (presumably the exponent).
    name: Op name; defaults to 'RISC_POW'.

  Returns:
    Whatever `gen_risc_ops.risc_pow` returns.
  """
  out = gen_risc_ops.risc_pow(input_lhs, input_rhs, name=name)
  return out
193
194
def risc_random_uniform(shape, seed, name='RISC_RANDOM_UNIFORM'):
  """Forwards `shape` and `seed` to `gen_risc_ops.risc_random_uniform`.

  Args:
    shape: Output shape, passed positionally.
    seed: Seed value, passed positionally.
    name: Op name; defaults to 'RISC_RANDOM_UNIFORM'.

  Returns:
    Whatever `gen_risc_ops.risc_random_uniform` returns.
  """
  out = gen_risc_ops.risc_random_uniform(shape, seed, name=name)
  return out
197
198
def risc_real(x, name='RISC_REAL'):
  """Forwards `x` to `gen_risc_ops.risc_real`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_REAL'.

  Returns:
    Whatever `gen_risc_ops.risc_real` returns.
  """
  out = gen_risc_ops.risc_real(x, name=name)
  return out
201
202
def risc_reduce(x, axis, reduce_type, name='RISC_REDUCE'):
  """Forwards input, axis and reduction kind to `gen_risc_ops.risc_reduce`.

  Args:
    x: Input, passed positionally.
    axis: Reduction axis, passed positionally.
    reduce_type: Passed as the `reduce_type` keyword (presumably selects
      the reduction).
    name: Op name; defaults to 'RISC_REDUCE'.

  Returns:
    Whatever `gen_risc_ops.risc_reduce` returns.
  """
  out = gen_risc_ops.risc_reduce(
      x, axis, name=name, reduce_type=reduce_type)
  return out
205
206
def risc_rem(x, name='RISC_REM'):
  """Forwards `x` to `gen_risc_ops.risc_rem`.

  NOTE(review): a remainder is normally a binary operation. The sibling
  wrappers `risc_div` and `risc_pow` both take two operands, yet only `x` is
  passed here. If the generated wrapper requires a second operand, this call
  fails. Confirm against gen_risc_ops before relying on it.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_REM'.

  Returns:
    Whatever `gen_risc_ops.risc_rem` returns.
  """
  out = gen_risc_ops.risc_rem(x, name=name)
  return out
209
210
def risc_reshape(x, shape, name='RISC_RESHAPE'):
  """Forwards `x` and a target `shape` to `gen_risc_ops.risc_reshape`.

  Args:
    x: Input, passed positionally.
    shape: Target shape, passed positionally.
    name: Op name; defaults to 'RISC_RESHAPE'.

  Returns:
    Whatever `gen_risc_ops.risc_reshape` returns.
  """
  out = gen_risc_ops.risc_reshape(x, shape, name=name)
  return out
213
214
def risc_reverse(x, axis, name='RISC_REVERSE'):
  """Forwards `x` and `axis` to `gen_risc_ops.risc_reverse`.

  Args:
    x: Input, passed positionally.
    axis: Axis specification, passed positionally.
    name: Op name; defaults to 'RISC_REVERSE'.

  Returns:
    Whatever `gen_risc_ops.risc_reverse` returns.
  """
  out = gen_risc_ops.risc_reverse(x, axis, name=name)
  return out
217
218
def risc_scatter(indices, updates, shape, name='RISC_SCATTER'):
  """Forwards indices, updates and output shape to the generated op.

  Args:
    indices: Indices, passed positionally.
    updates: Update values, passed positionally.
    shape: Output shape, passed positionally.
    name: Op name; defaults to 'RISC_SCATTER'.

  Returns:
    Whatever `gen_risc_ops.risc_scatter` returns.
  """
  out = gen_risc_ops.risc_scatter(indices, updates, shape, name=name)
  return out
221
222
def risc_shape(x, name='RISC_SHAPE'):
  """Forwards `x` to `gen_risc_ops.risc_shape`.

  Args:
    x: Input, passed through positionally.
    name: Op name; defaults to 'RISC_SHAPE'.

  Returns:
    Whatever `gen_risc_ops.risc_shape` returns.
  """
  out = gen_risc_ops.risc_shape(x, name=name)
  return out
225
226
def risc_sign(x, name='RISC_SIGN'):
  """Forwards `x` to `gen_risc_ops.risc_sign`.

  Args:
    x: Operand, passed through positionally.
    name: Op name; defaults to 'RISC_SIGN'.

  Returns:
    Whatever `gen_risc_ops.risc_sign` returns.
  """
  out = gen_risc_ops.risc_sign(x, name=name)
  return out
229
230
def risc_slice(x, begin, size, name='RISC_SLICE'):
  """Forwards input, start offsets and sizes to `gen_risc_ops.risc_slice`.

  Args:
    x: Input, passed positionally.
    begin: Start offsets, passed positionally.
    size: Slice sizes, passed positionally.
    name: Op name; defaults to 'RISC_SLICE'.

  Returns:
    Whatever `gen_risc_ops.risc_slice` returns.
  """
  out = gen_risc_ops.risc_slice(x, begin, size, name=name)
  return out
233
234
def risc_sub(input_lhs, input_rhs, name='RISC_SUB'):
  """Forwards both operands to `gen_risc_ops.risc_sub`.

  Args:
    input_lhs: Left-hand operand, passed positionally.
    input_rhs: Right-hand operand, passed positionally.
    name: Op name; defaults to 'RISC_SUB'.

  Returns:
    Whatever `gen_risc_ops.risc_sub` returns.
  """
  out = gen_risc_ops.risc_sub(input_lhs, input_rhs, name=name)
  return out
237
238
def risc_sort(x, axis, direction='ASCENDING', name='RISC_SORT'):
  """Forwards input, axis and sort direction to `gen_risc_ops.risc_sort`.

  Args:
    x: Input, passed positionally.
    axis: Sort axis, passed positionally.
    direction: Passed as the `direction` keyword; defaults to 'ASCENDING'.
    name: Op name; defaults to 'RISC_SORT'.

  Returns:
    Whatever `gen_risc_ops.risc_sort` returns.
  """
  out = gen_risc_ops.risc_sort(x, axis, name=name, direction=direction)
  return out
241
242
def risc_squeeze(x, axis=None, name='RISC_SQUEEZE'):
  """Forwards `x` and `axis` to `gen_risc_ops.risc_squeeze`.

  Args:
    x: Input, passed positionally.
    axis: Passed positionally as the second argument, unchanged even when it
      is None (how None is interpreted is up to the generated op).
    name: Op name; defaults to 'RISC_SQUEEZE'.

  Returns:
    Whatever `gen_risc_ops.risc_squeeze` returns.
  """
  out = gen_risc_ops.risc_squeeze(x, axis, name=name)
  return out
245
246
def risc_transpose(x, perm=None, name='RISC_TRANSPOSE'):
  """Forwards `x` and `perm` to `gen_risc_ops.risc_transpose`.

  Args:
    x: Input, passed positionally.
    perm: Permutation, passed positionally and unchanged even when it is
      None (how None is interpreted is up to the generated op).
    name: Op name; defaults to 'RISC_TRANSPOSE'.

  Returns:
    Whatever `gen_risc_ops.risc_transpose` returns.
  """
  out = gen_risc_ops.risc_transpose(x, perm, name=name)
  return out
249
250
def risc_triangular_solve(matrix, rhs, lower=True, adjoint=False,
                          name='RISC_TRIANGULAR_SOLVE'):
  """Forwards a matrix, right-hand side and solve flags to the op.

  Args:
    matrix: Coefficient matrix, passed positionally.
    rhs: Right-hand side, passed positionally.
    lower: Passed as the `lower` keyword; defaults to True.
    adjoint: Passed as the `adjoint` keyword; defaults to False.
    name: Op name; defaults to 'RISC_TRIANGULAR_SOLVE'.

  Returns:
    Whatever `gen_risc_ops.risc_triangular_solve` returns.
  """
  out = gen_risc_ops.risc_triangular_solve(
      matrix,
      rhs,
      adjoint=adjoint,
      lower=lower,
      name=name,
  )
  return out
258
259
def risc_unary(x, op_type='ABL', name='RISC_UNARY'):
  """Forwards `x` and an `op_type` selector to `gen_risc_ops.risc_unary`.

  NOTE(review): the default `op_type` 'ABL' may be a typo for 'ABS'. It has
  to match whatever values the op registration accepts, so it is left as-is.
  Confirm before changing it.

  Args:
    x: Operand, passed positionally.
    op_type: Passed as the `op_type` keyword; defaults to 'ABL'.
    name: Op name; defaults to 'RISC_UNARY'.

  Returns:
    Whatever `gen_risc_ops.risc_unary` returns.
  """
  out = gen_risc_ops.risc_unary(x, name=name, op_type=op_type)
  return out
262
263
def risc_while(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               maximum_iterations=None,
               name='RISC_WHILE'):
  """Forwards loop functions, loop state and options to the generated op.

  Every argument is passed by keyword, and `return_same_structure=True` is
  always added.

  NOTE(review): this keyword set mirrors `tf.while_loop`. Confirm that
  `gen_risc_ops.risc_while` actually accepts all of these keywords.

  Args:
    cond: Passed as the `cond` keyword.
    body: Passed as the `body` keyword.
    loop_vars: Passed as the `loop_vars` keyword.
    shape_invariants: Passed through; defaults to None.
    parallel_iterations: Passed through; defaults to 10.
    back_prop: Passed through; defaults to True.
    swap_memory: Passed through; defaults to False.
    maximum_iterations: Passed through; defaults to None.
    name: Op name; defaults to 'RISC_WHILE'.

  Returns:
    Whatever `gen_risc_ops.risc_while` returns.
  """
  loop_kwargs = dict(
      cond=cond,
      body=body,
      loop_vars=loop_vars,
      shape_invariants=shape_invariants,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      maximum_iterations=maximum_iterations,
      return_same_structure=True,
  )
  return gen_risc_ops.risc_while(name=name, **loop_kwargs)
284