xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/math_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/numeric_op.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 
21 namespace tensorflow {
22 
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26 
27 REGISTER_OP("AddN")
28     .Input("inputs: N * T")
29     .Output("sum: T")
30     .Attr("N: int >= 1")
31     .Attr("T: {numbertype, variant}")
32     .SetIsCommutative()
33     .SetIsAggregate()
__anon247ac1530102(InferenceContext* c) 34     .SetShapeFn([](InferenceContext* c) {
35       ShapeHandle cur = c->input(c->num_inputs() - 1);
36       for (int i = c->num_inputs() - 2; i >= 0; --i) {
37         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
38                                         "From merging shape ", i,
39                                         " with other shapes.");
40       }
41       c->set_output(0, cur);
42 
43       DataType dtype;
44       TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
45 
46       if (dtype != DT_VARIANT) {
47         // Exit early if not DT_VARIANT.
48         return OkStatus();
49       } else {
50         // DT_VARIANT shape handle shape inference.  All sizes and dtypes must
51         // be the same; all shapes must be compatible via Merge.
52         std::vector<shape_inference::ShapeAndType> cur_shapes_and_types;
53         auto* shapes_and_types =
54             c->input_handle_shapes_and_types(c->num_inputs() - 1);
55         if (shapes_and_types) {
56           cur_shapes_and_types = *shapes_and_types;
57         }
58 
59         for (int i = c->num_inputs() - 2; i >= 0; --i) {
60           auto shapes_and_types_i = c->input_handle_shapes_and_types(i);
61           if (!shapes_and_types && shapes_and_types_i) {
62             // TODO(ebrevdo): Find cases where this happens and fix their shape
63             // inference.  If we are calling AddN on variant types, they should
64             // all have consistent shape_and_type info.
65             shapes_and_types = shapes_and_types_i;
66           } else if (shapes_and_types && shapes_and_types_i) {
67             if (shapes_and_types_i->size() != shapes_and_types->size()) {
68               return errors::InvalidArgument(
69                   "shapes_and_types[", i,
70                   "].size() == ", shapes_and_types_i->size(),
71                   " != shapes_and_types[0].size() == ",
72                   shapes_and_types->size());
73             }
74             for (int j = 0; j < shapes_and_types->size(); ++j) {
75               if (shapes_and_types->at(j).dtype !=
76                   shapes_and_types_i->at(j).dtype) {
77                 return errors::InvalidArgument(
78                     "shapes_and_types[", i, "][", j, "].dtype() == ",
79                     DataTypeString(shapes_and_types_i->at(j).dtype),
80                     " != shapes_and_types[0][", j, "].dtype == ",
81                     DataTypeString(shapes_and_types->at(j).dtype));
82               }
83               TF_RETURN_WITH_CONTEXT_IF_ERROR(
84                   c->Merge(shapes_and_types_i->at(j).shape,
85                            cur_shapes_and_types.at(j).shape,
86                            &cur_shapes_and_types.at(j).shape),
87                   "From merging shapes_and_types[", i, "][", j, "].shape with ",
88                   "shapes_and_types[0][", j, "].shape");
89             }
90           }
91         }
92         if (shapes_and_types) {
93           c->set_output_handle_shapes_and_types(0, cur_shapes_and_types);
94         }
95         return OkStatus();
96       }
97     });
98 
99 // --------------------------------------------------------------------------
100 
101 // Note that the following operator is just a placeholder and has no
102 // associated kernel. The code in accumulate_n_optimizer.cc replaces
103 // this placeholder with a graph of operators that do have kernels.
104 // The Python code that generates instances of this op is currently in
105 // contrib/framework/python/ops/accumulate_n_v2.py
106 REGISTER_OP("AccumulateNV2")
107     .Input("inputs: N * T")
108     .Output("sum: T")
109     .Attr("N: int >= 1")
110     .Attr("T: numbertype")
111     .Attr("shape: shape")
112     .SetIsCommutative()
113     .SetIsAggregate()
114     .SetShapeFn(shape_inference::ExplicitShape);
115 
116 // --------------------------------------------------------------------------
117 
118 REGISTER_OP("BatchMatMul")
119     .Input("x: T")
120     .Input("y: T")
121     .Output("output: T")
122     .Attr(
123         "T: {bfloat16, half, float, double, int32, int64, complex64, "
124         "complex128}")
125     .Attr("adj_x: bool = false")
126     .Attr("adj_y: bool = false")
127     .SetShapeFn(shape_inference::BatchMatMulShape);
128 
129 REGISTER_OP("BatchMatMulV2")
130     .Input("x: T")
131     .Input("y: T")
132     .Output("output: T")
133     .Attr(
134         "T: {bfloat16, half, float, double, int16, int32, int64, complex64, "
135         "complex128}")
136     .Attr("adj_x: bool = false")
137     .Attr("adj_y: bool = false")
138     .SetShapeFn(shape_inference::BatchMatMulV2Shape);
139 
140 REGISTER_OP("BatchMatMulV3")
141     .Input("x: Ta")
142     .Input("y: Tb")
143     .Output("output: Tout")
144     .Attr(
145         "Ta: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
146         "complex64, complex128}")
147     .Attr(
148         "Tb: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
149         "complex64, complex128}")
150     .Attr(
151         "Tout: {bfloat16, half, float, double, int16, int32, int64, complex64, "
152         "complex128}")
153     .Attr("adj_x: bool = false")
154     .Attr("adj_y: bool = false")
155     .SetShapeFn(shape_inference::BatchMatMulV2Shape);
156 
157 #ifdef INTEL_MKL
158 REGISTER_OP("_MklBatchMatMul")
159     .Input("x: T")
160     .Input("y: T")
161     .Output("output: T")
162     .Attr("T: {bfloat16, float}")
163     .Attr("adj_x: bool = false")
164     .Attr("adj_y: bool = false")
165     .SetShapeFn(shape_inference::BatchMatMulShape);
166 
167 REGISTER_OP("_MklBatchMatMulV2")
168     .Input("x: T")
169     .Input("y: T")
170     .Output("output: T")
171     .Attr("T: {bfloat16, float}")
172     .Attr("adj_x: bool = false")
173     .Attr("adj_y: bool = false")
174     .SetShapeFn(shape_inference::BatchMatMulV2Shape);
175 #endif  // INTEL_MKL
176 
177 // --------------------------------------------------------------------------
178 // Casting Ops
179 //
180 // NOTE: Only a smaller number of types are supported by
181 // Cast. The exact casting rule is TBD. The current
182 // implementation uses C++ static cast rules for numeric
183 // types, which may be changed in the future.
184 REGISTER_OP("Cast")
185     .Input("x: SrcT")
186     .Output("y: DstT")
187     .Attr("SrcT: type")
188     .Attr("DstT: type")
189     .Attr("Truncate: bool = false")
190     .SetTypeConstructor(full_type::NoOp())
191     .SetForwardTypeFn(full_type::KeepExisting())
192     .SetShapeFn(shape_inference::UnchangedShape);
193 
194 REGISTER_OP("_HostCast")
195     .Input("x: SrcT")
196     .Output("y: DstT")
197     .Attr("SrcT: type")
198     .Attr("DstT: type")
199     .Attr("Truncate: bool = false")
200     .SetTypeConstructor(full_type::NoOp())
201     .SetForwardTypeFn(full_type::KeepExisting())
202     .SetShapeFn(shape_inference::UnchangedShape)
203     .Doc(R"doc(
204 Cast x of type SrcT to y of DstT.
205 
206 _HostCast requires its input and produces its output in host memory.
207 )doc");
208 
209 // --------------------------------------------------------------------------
210 
211 REGISTER_OP("Abs")
212     .Input("x: T")
213     .Output("y: T")
214     .Attr("T: {bfloat16, half, float, double, int8, int16, int32, int64}")
215     .SetShapeFn(shape_inference::UnchangedShape);
216 
217 REGISTER_OP("ComplexAbs")
218     .Input("x: T")
219     .Output("y: Tout")
220     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
221     .Attr("Tout: {float, double} = DT_FLOAT")
222     .SetShapeFn(shape_inference::UnchangedShape);
223 
224 // Declares cwise unary operations signature: 't -> 't
225 #define UNARY()                                                            \
226   Input("x: T")                                                            \
227       .Output("y: T")                                                      \
228       .Attr(                                                               \
229           "T: {bfloat16, half, float, double, int8, int16, int32, int64, " \
230           "complex64, complex128}")                                        \
231       .SetShapeFn(shape_inference::UnchangedShape)
232 
233 #define UNARY_UNSIGNED()                                                   \
234   Input("x: T")                                                            \
235       .Output("y: T")                                                      \
236       .Attr(                                                               \
237           "T: {bfloat16, half, float, double, int8, int16, int32, int64, " \
238           "uint8, uint16, uint32, uint64, complex64, complex128}")         \
239       .SetShapeFn(shape_inference::UnchangedShape)
240 
241 #define UNARY_REAL()                              \
242   Input("x: T")                                   \
243       .Output("y: T")                             \
244       .Attr("T: {bfloat16, half, float, double}") \
245       .SetShapeFn(shape_inference::UnchangedShape)
246 
247 #define UNARY_COMPLEX()                                                  \
248   Input("x: T")                                                          \
249       .Output("y: T")                                                    \
250       .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
251       .SetShapeFn(shape_inference::UnchangedShape)
252 
253 #define UNARY_GRADIENT_COMPLEX()                                         \
254   Input("y: T")                                                          \
255       .Input("dy: T")                                                    \
256       .Output("z: T")                                                    \
257       .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
258       .SetShapeFn(shape_inference::UnchangedShape)
259 
260 REGISTER_OP("Neg").UNARY();
261 
262 REGISTER_OP("Inv").UNARY();
263 
264 REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX();
265 
266 REGISTER_OP("Reciprocal").UNARY();
267 
268 REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX();
269 
270 REGISTER_OP("Square").UNARY_UNSIGNED();
271 
272 REGISTER_OP("Sqrt").UNARY_COMPLEX();
273 
274 REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX();
275 
276 REGISTER_OP("Rsqrt").UNARY_COMPLEX();
277 
278 REGISTER_OP("Round").UNARY();
279 
280 REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX();
281 
282 REGISTER_OP("Exp").UNARY_COMPLEX();
283 
284 REGISTER_OP("Expm1").UNARY_COMPLEX();
285 
286 REGISTER_OP("Log").UNARY_COMPLEX();
287 
288 REGISTER_OP("Log1p").UNARY_COMPLEX();
289 
290 REGISTER_OP("Sinh").UNARY_COMPLEX();
291 
292 REGISTER_OP("Cosh").UNARY_COMPLEX();
293 
294 REGISTER_OP("Tanh").UNARY_COMPLEX();
295 
296 REGISTER_OP("Asinh").UNARY_COMPLEX();
297 
298 REGISTER_OP("Acosh").UNARY_COMPLEX();
299 
300 REGISTER_OP("Atanh").UNARY_COMPLEX();
301 
302 REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX();
303 
304 REGISTER_OP("Lgamma").UNARY_REAL();
305 
306 REGISTER_OP("Digamma").UNARY_REAL();
307 
308 REGISTER_OP("Erf").UNARY_REAL();
309 REGISTER_OP("Erfinv").UNARY_REAL();
310 REGISTER_OP("Ndtri").UNARY_REAL();
311 REGISTER_OP("Erfc").UNARY_REAL();
312 
313 REGISTER_OP("Sigmoid").UNARY_COMPLEX();
314 
315 REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX();
316 
317 REGISTER_OP("Sin").UNARY_COMPLEX();
318 
319 REGISTER_OP("Cos").UNARY_COMPLEX();
320 
321 REGISTER_OP("Tan").UNARY();
322 
323 REGISTER_OP("Asin").UNARY();
324 
325 REGISTER_OP("Acos").UNARY();
326 
327 REGISTER_OP("Atan").UNARY();
328 
329 REGISTER_OP("_UnaryOpsComposition")
330     .Input("x: T")
331     .Output("y: T")
332     .Attr("T: {float, half, double}")
333     .Attr("op_names: list(string)")
334     .SetShapeFn(shape_inference::UnchangedShape)
335     .Doc(R"doc(
336 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
337 expected to create these operators.
338 )doc");
339 
340 #undef UNARY
341 #undef UNARY_REAL
342 #undef UNARY_COMPLEX
343 
344 REGISTER_OP("IsNan")
345     .Input("x: T")
346     .Output("y: bool")
347     .Attr("T: {bfloat16, half, float, double}")
348     .SetShapeFn(shape_inference::UnchangedShape);
349 
350 REGISTER_OP("IsInf")
351     .Input("x: T")
352     .Output("y: bool")
353     .Attr("T: {bfloat16, half, float, double}")
354     .SetShapeFn(shape_inference::UnchangedShape);
355 
356 REGISTER_OP("IsFinite")
357     .Input("x: T")
358     .Output("y: bool")
359     .Attr("T: {bfloat16, half, float, double}")
360     .SetShapeFn(shape_inference::UnchangedShape);
361 
362 REGISTER_OP("Sign")
363     .Input("x: T")
364     .Output("y: T")
365     .Attr(
366         "T: {bfloat16, half, float, double, int8, int16, int32, int64, "
367         "complex64, complex128}")
368     .SetShapeFn(shape_inference::UnchangedShape);
369 
370 REGISTER_OP("Floor")
371     .Input("x: T")
372     .Output("y: T")
373     .Attr("T: {bfloat16, half, float, double}")
374     .SetShapeFn(shape_inference::UnchangedShape);
375 
376 REGISTER_OP("Ceil")
377     .Input("x: T")
378     .Output("y: T")
379     .Attr("T: {bfloat16, half, float, double}")
380     .SetShapeFn(shape_inference::UnchangedShape);
381 
382 REGISTER_OP("Rint")
383     .Input("x: T")
384     .Output("y: T")
385     .Attr("T: {bfloat16, half, float, double}")
386     .SetShapeFn(shape_inference::UnchangedShape);
387 
// Cwise binary operation signatures: 't, 't -> 't.
// BINARY_MORE additionally admits int8/uint8/int16/uint16/uint32/uint64;
// BINARY_FEWER is limited to float, int32/int64 and complex types.
// Neither sets a shape function; each registration chains its own.
#define BINARY_MORE()                                                        \
  Input("x: T")                                                              \
      .Input("y: T")                                                         \
      .Output("z: T")                                                        \
      .Attr("T: {bfloat16, half, float, double, uint8, int8, uint16, int16, " \
            "int32, uint32, uint64, int64, complex64, complex128}")

#define BINARY_FEWER()                        \
  Input("x: T")                               \
      .Input("y: T")                          \
      .Output("z: T")                         \
      .Attr("T: {bfloat16, half, float, double, " \
            "int32, int64, complex64, complex128}")
399 
400 REGISTER_OP("Add")
401     .Input("x: T")
402     .Input("y: T")
403     .Output("z: T")
404     .Attr(
405         "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
406         "complex64, complex128, string}")
407     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
408 
409 REGISTER_OP("AddV2")
410     .Input("x: T")
411     .Input("y: T")
412     .Output("z: T")
413     .Attr(
414         "T: {bfloat16, half, float, double, uint8, uint16, uint32, uint64, "
415         "int8, int16, int32, int64, complex64, complex128}")
416     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
417     .SetIsAggregate()
418     .SetIsCommutative();
419 
420 #ifdef INTEL_MKL
421 REGISTER_OP("_MklAdd")
422     .Input("x: T")
423     .Input("y: T")
424     .Input("mkl_x: uint8")
425     .Input("mkl_y: uint8")
426     .Output("z: T")
427     .Output("mkl_z: uint8")
428     .Attr(
429         "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
430         "complex128, string, bfloat16}")
431     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
432     .Doc(R"doc(
433 Returns `x` + `y` element-wise.
434 
435 *NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
436 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
437 )doc");
438 
439 REGISTER_OP("_MklAddV2")
440     .Input("x: T")
441     .Input("y: T")
442     .Input("mkl_x: uint8")
443     .Input("mkl_y: uint8")
444     .Output("z: T")
445     .Output("mkl_z: uint8")
446     .Attr(
447         "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
448         "complex64, complex128}")
449     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
450     .SetIsAggregate()
451     .SetIsCommutative()
452     .Doc(R"doc(
453 Returns `x` + `y` element-wise.
454 *NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
455 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
456 )doc");
457 #endif  // INTEL_MKL
458 
459 REGISTER_OP("Sub")
460     .Input("x: T")
461     .Input("y: T")
462     .Output("z: T")
463     .Attr(
464         "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, "
465         "int64, complex64, complex128, uint32, uint64}")
466     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
467 
468 REGISTER_OP("_MklSub")
469     .BINARY_FEWER()
470     .Input("mkl_x: uint8")
471     .Input("mkl_y: uint8")
472     .Output("mkl_z: uint8")
473     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
474     .Doc(R"doc(
475 Returns x - y element-wise.
476 
477 *NOTE*: `Sub` supports broadcasting. More about broadcasting
478 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
479 )doc");
480 
481 REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn(
482     shape_inference::BroadcastBinaryOpShapeFn);
483 
484 REGISTER_OP("MulNoNan")
485     .Input("x: T")
486     .Input("y: T")
487     .Output("z: T")
488     .Attr("T: {bfloat16, half, float, double, complex64, complex128}")
489     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
490 
491 // Note: This op is not commutative w.r.t. to all its inputs.
492 REGISTER_OP("_MklMul")
493     .BINARY_MORE()
494     .Input("mkl_x: uint8")
495     .Input("mkl_y: uint8")
496     .Output("mkl_z: uint8")
497     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
498     .Doc(R"doc(
499 Returns x * y element-wise.
500 
501 *NOTE*: `Mul` supports broadcasting. More about broadcasting
502 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
503 )doc");
504 
505 REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
506     shape_inference::BroadcastBinaryOpShapeFn);
507 
508 REGISTER_OP("DivNoNan")
509     .Input("x: T")
510     .Input("y: T")
511     .Output("z: T")
512     .Attr("T: {half, float, bfloat16, double, complex64, complex128}")
513     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
514 
515 REGISTER_OP("FloorDiv")
516     .BINARY_MORE()
517     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
518 
519 REGISTER_OP("TruncateDiv")
520     .BINARY_MORE()
521     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
522 
523 REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn(
524     shape_inference::BroadcastBinaryOpShapeFn);
525 
526 // Note SquaredDifference implements conj(x - y)*(x - y).
527 REGISTER_OP("SquaredDifference")
528     .BINARY_FEWER()
529     .SetIsCommutative()
530     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
531 
532 // Note: This op is not commutative w.r.t. to all its inputs.
533 REGISTER_OP("_MklSquaredDifference")
534     .BINARY_FEWER()
535     .Input("mkl_x: uint8")
536     .Input("mkl_y: uint8")
537     .Output("mkl_z: uint8")
538     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
539     .Doc(R"doc(
540 Returns (x - y)(x - y) element-wise.
541 
542 *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
543 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
544 )doc");
545 
546 REGISTER_OP("Xlogy")
547     .Input("x: T")
548     .Input("y: T")
549     .Output("z: T")
550     .Attr("T: {half, float, double, complex64, complex128}")
551     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
552 
553 REGISTER_OP("Xlog1py")
554     .Input("x: T")
555     .Input("y: T")
556     .Output("z: T")
557     .Attr("T: {half, float, double, complex64, complex128}")
558     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
559 
560 REGISTER_OP("Xdivy")
561     .Input("x: T")
562     .Input("y: T")
563     .Output("z: T")
564     .Attr("T: {half, float, double, complex64, complex128}")
565     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
566 
567 #undef BINARY_FEWER
568 #undef BINARY_MORE
569 
570 REGISTER_OP("Maximum")
571     .Input("x: T")
572     .Input("y: T")
573     .Output("z: T")
574     .Attr(
575         "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, "
576         "int32, uint32, int64, uint64}")
577     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
578 
579 // Note: This op is not commutative w.r.t. to all its inputs.
580 REGISTER_OP("_MklMaximum")
581     .Input("x: T")
582     .Input("y: T")
583     .Input("mkl_x: uint8")
584     .Input("mkl_y: uint8")
585     .Output("z: T")
586     .Output("mkl_z: uint8")
587     .Attr("T: {half, float, double, int32, int64, bfloat16}")
588     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
589     .Doc(R"doc(
590 Returns the max of x and y (i.e. x > y ? x : y) element-wise.
591 
592 *NOTE*: `Maximum` supports broadcasting. More about broadcasting
593 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
594 )doc");
595 
596 REGISTER_OP("Minimum")
597     .Input("x: T")
598     .Input("y: T")
599     .Output("z: T")
600     .Attr(
601         "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, "
602         "int32, uint32, int64, uint64}")
603     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
604 
605 REGISTER_OP("Mod")
606     .Input("x: T")
607     .Input("y: T")
608     .Output("z: T")
609     .Attr("T: {int32, int64, float16, half, bfloat16, float, double}")
610     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
611 
612 REGISTER_OP("FloorMod")
613     .Input("x: T")
614     .Input("y: T")
615     .Output("z: T")
616     .Attr(
617         "T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64, "
618         "bfloat16, half, float, double}")
619     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
620 
621 REGISTER_OP("TruncateMod")
622     .Input("x: T")
623     .Input("y: T")
624     .Output("z: T")
625     .Attr("T: {int32, int64, bfloat16, half, float, double}")
626     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
627 
628 REGISTER_OP("Pow")
629     .Input("x: T")
630     .Input("y: T")
631     .Output("z: T")
632     .Attr(
633         "T: {bfloat16, float, half, double, int8, int16, int32, int64, "
634         "complex64, complex128}")
635     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
636 
637 REGISTER_OP("Igammac")
638     .Input("a: T")
639     .Input("x: T")
640     .Output("z: T")
641     .Attr("T: {bfloat16, half, float, double}")
642     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
643 
644 REGISTER_OP("Igamma")
645     .Input("a: T")
646     .Input("x: T")
647     .Output("z: T")
648     .Attr("T: {bfloat16, half, float, double}")
649     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
650 
651 REGISTER_OP("IgammaGradA")
652     .Input("a: T")
653     .Input("x: T")
654     .Output("z: T")
655     .Attr("T: {float, double}")
656     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
657 
658 REGISTER_OP("Zeta")
659     .Input("x: T")
660     .Input("q: T")
661     .Output("z: T")
662     .Attr("T: {float, double}")
663     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
664 
665 REGISTER_OP("Polygamma")
666     .Input("a: T")
667     .Input("x: T")
668     .Output("z: T")
669     .Attr("T: {float, double}")
670     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
671 
672 REGISTER_OP("Atan2")
673     .Input("y: T")
674     .Input("x: T")
675     .Output("z: T")
676     .Attr("T: {bfloat16, half, float, double}")
677     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
678 
679 REGISTER_OP("Betainc")
680     .Input("a: T")
681     .Input("b: T")
682     .Input("x: T")
683     .Output("z: T")
684     .Attr("T: {float, double}")
__anon247ac1530202(InferenceContext* c) 685     .SetShapeFn([](InferenceContext* c) {
686       const int num_inputs = 3;
687       ShapeHandle output = c->UnknownShape();
688       int num_scalars = 0;
689       ShapeHandle some_non_scalar;
690       for (int i = 0; i < num_inputs; ++i) {
691         ShapeHandle in = c->input(i);
692         if (!c->RankKnown(in)) {
693           some_non_scalar = in;
694           // An input with unknown rank could be either a scalar (to be
695           // broadcast) or some other shape.
696         } else if (c->Rank(in) == 0) {
697           // Input is a scalar, it will be broadcast to the output shape.
698           ++num_scalars;
699         } else {
700           TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
701           some_non_scalar = output;
702         }
703       }
704 
705       if (num_scalars == num_inputs - 1) {
706         // If all but one input is known to be a scalar, then output is the
707         // remaining input.
708         output = some_non_scalar;
709       } else if (num_scalars == num_inputs) {
710         // If all are scalars, output is scalar; pick the first one arbitrarily.
711         output = c->input(0);
712       }
713 
714       c->set_output(0, output);
715       return OkStatus();
716     });
717 
718 // --------------------------------------------------------------------------
719 
720 // Declares cwise binary comparison operations signature: 't, 't -> bool,
721 // where 't has a natural total order.
722 #define COMPARISON()             \
723   Input("x: T")                  \
724       .Input("y: T")             \
725       .Output("z: bool")         \
726       .Attr("T: realnumbertype") \
727       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
728 
729 REGISTER_OP("Less").COMPARISON();
730 
731 REGISTER_OP("LessEqual").COMPARISON();
732 
733 REGISTER_OP("Greater").COMPARISON();
734 
735 REGISTER_OP("GreaterEqual").COMPARISON();
736 
737 #undef COMPARISON
738 
739 // --------------------------------------------------------------------------
740 
741 #define EQUALITY_COMPARISON()                                      \
742   Input("x: T")                                                    \
743       .Input("y: T")                                               \
744       .Output("z: bool")                                           \
745       .SetIsCommutative()                                          \
746       .Attr("T: type")                                             \
747       .Attr("incompatible_shape_error: bool = true")               \
748       .SetShapeFn([](InferenceContext* c) {                        \
749         ShapeHandle x = c->input(0);                               \
750         ShapeHandle y = c->input(1);                               \
751         ShapeHandle output;                                        \
752         bool incompatible_shape_error;                             \
753         TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error",  \
754                                       &incompatible_shape_error)); \
755         TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(   \
756             c, x, y, incompatible_shape_error, &output));          \
757         c->set_output(0, output);                                  \
758         return OkStatus();                                         \
759       })
760 
761 REGISTER_OP("Equal").EQUALITY_COMPARISON();
762 
763 REGISTER_OP("NotEqual").EQUALITY_COMPARISON();
764 
765 #undef EQUALITY_COMPARISON
766 
767 REGISTER_OP("ApproximateEqual")
768     .Input("x: T")
769     .Input("y: T")
770     .Output("z: bool")
771     .SetIsCommutative()
772     .Attr("T: numbertype")
773     .Attr("tolerance: float = 0.00001")
__anon247ac1530302(InferenceContext* c) 774     .SetShapeFn([](InferenceContext* c) {
775       // The inputs 'x' and 'y' must have the same shape.
776       ShapeHandle data_x = c->input(0);
777       ShapeHandle data_y = c->input(1);
778       TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
779       return shape_inference::UnchangedShape(c);
780     });
781 
782 // --------------------------------------------------------------------------
783 
784 REGISTER_OP("LogicalNot")
785     .Input("x: bool")
786     .Output("y: bool")
787     .SetShapeFn(shape_inference::UnchangedShape);
788 
789 #define BINARY_LOGICAL()  \
790   Input("x: bool")        \
791       .Input("y: bool")   \
792       .Output("z: bool")  \
793       .SetIsCommutative() \
794       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
795 
796 REGISTER_OP("LogicalAnd").BINARY_LOGICAL();
797 
798 REGISTER_OP("LogicalOr").BINARY_LOGICAL();
799 
800 #undef BINARY_LOGICAL
801 
802 // --------------------------------------------------------------------------
803 
804 REGISTER_OP("Select")
805     .Input("condition: bool")
806     .Input("t: T")
807     .Input("e: T")
808     .Output("output: T")
809     .Attr("T: type")
__anon247ac1530402(InferenceContext* c) 810     .SetShapeFn([](InferenceContext* c) {
811       auto* handle_data_1 = c->input_handle_shapes_and_types(1);
812       auto* handle_data_2 = c->input_handle_shapes_and_types(2);
813       // Merge handle shape and dtype if applicable.
814       if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
815         const auto size = handle_data_1->size();
816         std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
817         if (size != handle_data_2->size()) {
818           return errors::InvalidArgument(
819               "Trying to merge handles pointing to different numbers of "
820               "tensors.");
821         }
822 
823         for (int i = 0; i < size; ++i) {
824           const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
825           const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
826           if (s1.dtype != s2.dtype) {
827             // TODO(apassos) resolve this in the manner of b/32476923
828             return errors::InvalidArgument(
829                 "Trying to merge handles pointing to different dtypes.");
830           }
831           merged_handle_data[i].dtype = s1.dtype;
832           TF_RETURN_IF_ERROR(
833               c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
834         }
835 
836         c->set_output_handle_shapes_and_types(0, merged_handle_data);
837       }
838 
839       // The inputs 'then' and 'else' must have the same shape.
840       ShapeHandle data = c->input(1);
841       ShapeHandle other = c->input(2);
842       TF_RETURN_IF_ERROR(c->Merge(data, other, &data));
843 
844       // The input 'cond' must either have the same shape as 'then' and
845       // 'else', or be a vector if 'then' and 'else' are at least vectors.
846       ShapeHandle cond = c->input(0);
847 
848       if (!c->RankKnown(cond) || !c->RankKnown(data)) {
849         c->set_output(0, data);
850         return OkStatus();
851       }
852 
853       // rank of shape and data is known.
854 
855       const int32_t cond_rank = c->Rank(cond);
856       const int32_t data_rank = c->Rank(data);
857 
858       if (cond_rank == 0) {
859         // The rank of 'cond' is a scalar.
860         // t and e can have any shape.
861         c->set_output(0, data);
862         return OkStatus();
863       }
864 
865       if (cond_rank != 1) {
866         // If 'cond' is not a vector, and not a scalar,
867         // then shape must match 'then' and 'else'
868         TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
869         c->set_output(0, data);
870         return OkStatus();
871       }
872 
873       if (data_rank == 0) {
874         // if 'then' and 'else' are scalar also the cond must be
875         TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
876         c->set_output(0, data);
877         return OkStatus();
878       }
879 
880       if (cond_rank == 1) {
881         // if the cond is a vector and the 'then' is not a scalar,
882         // the first dimension of 'then' and 'else'
883         TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond));
884         c->set_output(0, data);
885         return OkStatus();
886       }
887 
888       c->set_output(0, data);
889 
890       return OkStatus();
891     });
892 
893 REGISTER_OP("SelectV2")
894     .Input("condition: bool")
895     .Input("t: T")
896     .Input("e: T")
897     .Output("output: T")
898     .Attr("T: type")
__anon247ac1530502(InferenceContext* c) 899     .SetShapeFn([](InferenceContext* c) {
900       auto* handle_data_1 = c->input_handle_shapes_and_types(1);
901       auto* handle_data_2 = c->input_handle_shapes_and_types(2);
902       // Merge handle shape and dtype if applicable.
903       if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
904         const auto size = handle_data_1->size();
905         std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
906         if (size != handle_data_2->size()) {
907           return errors::InvalidArgument(
908               "Trying to merge handles pointing to different numbers of "
909               "tensors.");
910         }
911 
912         for (int i = 0; i < size; ++i) {
913           const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
914           const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
915           if (s1.dtype != s2.dtype) {
916             // TODO(apassos) resolve this in the manner of b/32476923
917             return errors::InvalidArgument(
918                 "Trying to merge handles pointing to different dtypes.");
919           }
920           merged_handle_data[i].dtype = s1.dtype;
921           TF_RETURN_IF_ERROR(
922               c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
923         }
924 
925         c->set_output_handle_shapes_and_types(0, merged_handle_data);
926       }
927 
928       // The inputs 'cond', 'then', and 'else' must be broadcastable.
929       // TODO (yongtang): Consolidate 3-ary broadcast instead of
930       // multiple 2-ary broadcast.
931       ShapeHandle cond = c->input(0);
932       ShapeHandle then = c->input(1);
933       ShapeHandle else_ = c->input(2);
934       ShapeHandle other;
935       TF_RETURN_IF_ERROR(
936           BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, true, &other));
937       ShapeHandle output;
938       TF_RETURN_IF_ERROR(
939           BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, true, &output));
940       c->set_output(0, output);
941       return OkStatus();
942     });
943 
944 // --------------------------------------------------------------------------
945 
946 REGISTER_OP("MatMul")
947     .Input("a: T")
948     .Input("b: T")
949     .Output("product: T")
950     .Attr("transpose_a: bool = false")
951     .Attr("transpose_b: bool = false")
952     .Attr(
953         "T: {bfloat16, half, float, double, int32, int64, complex64, "
954         "complex128}")
955     .SetShapeFn(shape_inference::MatMulShape);
956 
#ifdef INTEL_MKL
// Registered only when INTEL_MKL is defined. It has the same inputs, outputs
// and transpose attrs as MatMul, with T restricted to {bfloat16, float}.
REGISTER_OP("_MklMatMul")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {bfloat16, float}")
    .SetShapeFn(shape_inference::MatMulShape);
#endif  // INTEL_MKL
967 
968 REGISTER_OP("SparseMatMul")
969     .Input("a: Ta")
970     .Input("b: Tb")
971     .Output("product: float")
972     .Attr("transpose_a: bool = false")
973     .Attr("transpose_b: bool = false")
974     .Attr("a_is_sparse: bool = false")
975     .Attr("b_is_sparse: bool = false")
976     .Attr("Ta: {float, bfloat16} = DT_FLOAT")
977     .Attr("Tb: {float, bfloat16} = DT_FLOAT")
978     .SetShapeFn(shape_inference::MatMulShape);
979 
980 REGISTER_OP("_FusedMatMul")
981     .Input("a: T")
982     .Input("b: T")
983     .Input("args: num_args * T")
984     .Output("product: T")
985     .Attr("transpose_a: bool = false")
986     .Attr("transpose_b: bool = false")
987     .Attr("T: {bfloat16, half, float}")
988     .Attr("num_args: int >= 0")
989     .Attr("fused_ops: list(string) = []")
990     // Attributes for the FusedBatchNorm ----------- //
991     .Attr("epsilon: float = 0.0001")
992     // Attributes for the LeakyRelu ---------------- //
993     .Attr("leakyrelu_alpha: float = 0.2")
994     // --------------------------------------------- //
995     .SetShapeFn(shape_inference::MatMulShape)
996     .Doc(R"doc(
997 Performs a MatMul followed by a specified series of operations.
998 
999 The inputs to the MatMul are specified by `a` and `b`. The series of operations
1000 that follows is specified by the `fused_ops` attribute, which is a list of TF op
1001 names specified as strings (e.g. "Relu"). They are performed in order, where the
1002 (first) input to each op is the output of the preceding op. The first input and
1003 the output of each fused_op must be of type T.
1004 
1005 Currently supported fused_op combinations are: ["BiasAdd"] and ["BiasAdd",A],
1006 where A is one of {"Elu","Relu","Relu6"}.
1007 
1008 * The first input to BiasAdd is the MatMul result, and the additional BiasAdd
1009 input is specified by `args`.
1010 * If there is an op A specified, the output of the BiasAdd is the input to op A,
1011 and op A produces the _FusedConv2D output. Otherwise, the BiasAdd produces the
1012 _FusedConv2D output.
1013 
1014 *NOTE*: Do not invoke this operator directly in Python. Grappler is
1015 expected to create these operators.
1016 )doc");
1017 
1018 // --------------------------------------------------------------------------
1019 
1020 // For operations where the output is a reduction function along some
1021 // dimensions of the input.
1022 REGISTER_OP("Sum")
1023     .Input("input: T")
1024     .Input("reduction_indices: Tidx")
1025     .Output("output: T")
1026     .Attr("keep_dims: bool = false")
1027     .Attr("T: numbertype")
1028     .Attr("Tidx: {int32, int64} = DT_INT32")
1029     .SetShapeFn(shape_inference::ReductionShape);
1030 
1031 REGISTER_OP("EuclideanNorm")
1032     .Input("input: T")
1033     .Input("reduction_indices: Tidx")
1034     .Output("output: T")
1035     .Attr("keep_dims: bool = false")
1036     .Attr("T: numbertype")
1037     .Attr("Tidx: {int32, int64} = DT_INT32")
1038     .SetShapeFn(shape_inference::ReductionShape);
1039 
1040 REGISTER_OP("Mean")
1041     .Input("input: T")
1042     .Input("reduction_indices: Tidx")
1043     .Output("output: T")
1044     .Attr("keep_dims: bool = false")
1045     .Attr("T: numbertype")
1046     .Attr("Tidx: {int32, int64} = DT_INT32")
1047     .SetShapeFn(shape_inference::ReductionShape);
1048 
1049 REGISTER_OP("Prod")
1050     .Input("input: T")
1051     .Input("reduction_indices: Tidx")
1052     .Output("output: T")
1053     .Attr("keep_dims: bool = false")
1054     .Attr("T: numbertype")
1055     .Attr("Tidx: {int32, int64} = DT_INT32")
1056     .SetShapeFn(shape_inference::ReductionShape);
1057 
1058 REGISTER_OP("Min")
1059     .Input("input: T")
1060     .Input("reduction_indices: Tidx")
1061     .Output("output: T")
1062     .Attr("keep_dims: bool = false")
1063     .Attr("T: {realnumbertype, quantizedtype}")
1064     .Attr("Tidx: {int32, int64} = DT_INT32")
1065     .SetShapeFn(shape_inference::ReductionShape);
1066 
1067 REGISTER_OP("Max")
1068     .Input("input: T")
1069     .Input("reduction_indices: Tidx")
1070     .Output("output: T")
1071     .Attr("keep_dims: bool = false")
1072     .Attr("T: {realnumbertype, quantizedtype}")
1073     .Attr("Tidx: {int32, int64} = DT_INT32")
1074     .SetShapeFn(shape_inference::ReductionShape);
1075 
1076 namespace {
1077 
ArgOpShape(shape_inference::InferenceContext * c)1078 Status ArgOpShape(shape_inference::InferenceContext* c) {
1079   ShapeHandle dimension_shape;
1080   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape));
1081 
1082   ShapeHandle input_shape = c->input(0);
1083   if (!c->RankKnown(input_shape)) {
1084     return shape_inference::UnknownShape(c);
1085   }
1086 
1087   const int32_t input_rank = c->Rank(input_shape);
1088   if (input_rank <= 1) {
1089     // Reducing a scalar/vector must return a scalar.
1090     return shape_inference::ScalarShape(c);
1091   }
1092 
1093   const Tensor* dim_t = c->input_tensor(1);
1094   if (dim_t == nullptr) {
1095     // We don't know the value of the dimension, but we
1096     // know the rank of the input, so return the correct
1097     // rank with unknown dimensions.
1098     std::vector<DimensionHandle> dims(input_rank - 1);
1099     for (int i = 0; i < dims.size(); ++i) {
1100       dims[i] = c->UnknownDim();
1101     }
1102 
1103     c->set_output(0, c->MakeShape(dims));
1104     return OkStatus();
1105   }
1106 
1107   int64_t dimension_val;
1108   if (dim_t->dtype() == DT_INT32) {
1109     dimension_val = dim_t->scalar<int32>()();
1110   } else {
1111     dimension_val = dim_t->scalar<int64_t>()();
1112   }
1113 
1114   int64_t axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val;
1115   if (axis < 0 || axis >= input_rank) {
1116     return errors::InvalidArgument(
1117         "Dimension (", dimension_val, ") must be in the range [", -input_rank,
1118         ", ", input_rank, "), where ", input_rank,
1119         " is the number of dimensions in the input.");
1120   }
1121 
1122   // Return the input shape without the dimension being reduced.
1123   std::vector<DimensionHandle> dims;
1124   for (int i = 0; i < input_rank; ++i) {
1125     if (axis != i) {
1126       dims.emplace_back(c->Dim(input_shape, i));
1127     }
1128   }
1129   c->set_output(0, c->MakeShape(dims));
1130   return OkStatus();
1131 }
1132 
1133 }  // namespace
1134 
1135 REGISTER_OP("ArgMax")
1136     .Input("input: T")
1137     .Input("dimension: Tidx")
1138     .Output("output: output_type")
1139     .Attr("T: {numbertype, bool}")
1140     .Attr("Tidx: {int16, int32, int64} = DT_INT32")
1141     .Attr("output_type: {int16, uint16, int32, int64} = DT_INT64")
1142     .SetShapeFn(ArgOpShape);
1143 
1144 REGISTER_OP("ArgMin")
1145     .Input("input: T")
1146     .Input("dimension: Tidx")
1147     .Output("output: output_type")
1148     .Attr("T: {numbertype, bool}")
1149     .Attr("Tidx: {int32, int64} = DT_INT32")
1150     .Attr("output_type: {int32, int64} = DT_INT64")
1151     .SetShapeFn(ArgOpShape);
1152 
1153 namespace {
1154 
SegmentReductionShapeFn(InferenceContext * c)1155 Status SegmentReductionShapeFn(InferenceContext* c) {
1156   ShapeHandle data_shape;
1157   ShapeHandle segment_ids_shape;
1158   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1159   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape));
1160 
1161   ShapeHandle subshape;
1162   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1163 
1164   ShapeHandle out;
1165   TF_RETURN_IF_ERROR(
1166       c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
1167   c->set_output(0, out);
1168   return OkStatus();
1169 }
1170 
SparseSegmentReductionShapeFn(InferenceContext * c)1171 Status SparseSegmentReductionShapeFn(InferenceContext* c) {
1172   ShapeHandle data_shape;
1173   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1174 
1175   ShapeHandle indices_shape;
1176   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1177 
1178   ShapeHandle segment_ids_shape;
1179   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
1180 
1181   // indices and segment_ids should merge cleanly.
1182   ShapeHandle unused;
1183   TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
1184 
1185   ShapeHandle subshape;
1186   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1187 
1188   ShapeHandle out;
1189   TF_RETURN_IF_ERROR(
1190       c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
1191   c->set_output(0, out);
1192   return OkStatus();
1193 }
1194 
SparseSegmentReductionGradShapeFn(InferenceContext * c)1195 Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
1196   ShapeHandle data_shape;
1197   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1198 
1199   ShapeHandle indices_shape;
1200   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1201 
1202   // indices and segment_ids should merge cleanly.
1203   ShapeHandle unused;
1204   TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused));
1205 
1206   // output_dim0 should be a scalar
1207   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1208 
1209   ShapeHandle subshape;
1210   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1211 
1212   const Tensor* dim0 = c->input_tensor(3);
1213   ShapeHandle dim0_shape;
1214   if (dim0 == nullptr) {
1215     // We don't have the value at inference time, so the output
1216     // shape is unknown.
1217     dim0_shape = c->Vector(InferenceContext::kUnknownDim);
1218   } else {
1219     auto dim0_value = dim0->scalar<int32>()();
1220     if (dim0_value < 0) {
1221       return errors::InvalidArgument(
1222           "Cannot specify a negative value for output_dim0");
1223     }
1224     dim0_shape = c->Vector(dim0_value);
1225   }
1226 
1227   ShapeHandle out;
1228   TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out));
1229   c->set_output(0, out);
1230   return OkStatus();
1231 }
1232 
SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext * c)1233 Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
1234   ShapeHandle data_shape;
1235   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1236 
1237   ShapeHandle indices_shape;
1238   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1239 
1240   ShapeHandle segment_ids_shape;
1241   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
1242 
1243   ShapeHandle num_segments_shape;
1244   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape));
1245 
1246   // indices and segment_ids should merge cleanly.
1247   ShapeHandle unused;
1248   TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
1249 
1250   ShapeHandle subshape;
1251   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1252 
1253   ShapeHandle out;
1254   const Tensor* dim0 = c->input_tensor(3);
1255   if (dim0 == nullptr) {
1256     // We don't have the value at inference time, so the output
1257     // shape is unknown.
1258     TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim),
1259                                       subshape, &out));
1260   } else {
1261     auto dim0_value = dim0->scalar<int32>()();
1262     if (dim0_value < 0) {
1263       return errors::InvalidArgument(
1264           "Cannot specify a negative value for num_segments");
1265     }
1266     TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out));
1267   }
1268   c->set_output(0, out);
1269   return OkStatus();
1270 }
1271 }  // namespace
1272 
1273 REGISTER_OP("SegmentSum")
1274     .Input("data: T")
1275     .Input("segment_ids: Tindices")
1276     .Output("output: T")
1277     .Attr("T: numbertype")
1278     .Attr("Tindices: {int32,int64}")
1279     .SetShapeFn(SegmentReductionShapeFn);
1280 
1281 REGISTER_OP("SegmentMean")
1282     .Input("data: T")
1283     .Input("segment_ids: Tindices")
1284     .Output("output: T")
1285     .Attr("T: numbertype")
1286     .Attr("Tindices: {int32,int64}")
1287     .SetShapeFn(SegmentReductionShapeFn);
1288 
1289 REGISTER_OP("SegmentProd")
1290     .Input("data: T")
1291     .Input("segment_ids: Tindices")
1292     .Output("output: T")
1293     .Attr("T: numbertype")
1294     .Attr("Tindices: {int32,int64}")
1295     .SetShapeFn(SegmentReductionShapeFn);
1296 
1297 REGISTER_OP("SegmentMin")
1298     .Input("data: T")
1299     .Input("segment_ids: Tindices")
1300     .Output("output: T")
1301     .Attr("T: realnumbertype")
1302     .Attr("Tindices: {int32,int64}")
1303     .SetShapeFn(SegmentReductionShapeFn);
1304 
1305 REGISTER_OP("SegmentMax")
1306     .Input("data: T")
1307     .Input("segment_ids: Tindices")
1308     .Output("output: T")
1309     .Attr("T: realnumbertype")
1310     .Attr("Tindices: {int32,int64}")
1311     .SetShapeFn(SegmentReductionShapeFn);
1312 
1313 REGISTER_OP("UnsortedSegmentSum")
1314     .Input("data: T")
1315     .Input("segment_ids: Tindices")
1316     .Input("num_segments: Tnumsegments")
1317     .Output("output: T")
1318     .Attr("T: numbertype")
1319     .Attr("Tindices: {int32,int64}")
1320     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1321     .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
1322 
1323 REGISTER_OP("UnsortedSegmentMax")
1324     .Input("data: T")
1325     .Input("segment_ids: Tindices")
1326     .Input("num_segments: Tnumsegments")
1327     .Output("output: T")
1328     .Attr("T: realnumbertype")
1329     .Attr("Tindices: {int32,int64}")
1330     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1331     .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
1332 
1333 REGISTER_OP("UnsortedSegmentMin")
1334     .Input("data: T")
1335     .Input("segment_ids: Tindices")
1336     .Input("num_segments: Tnumsegments")
1337     .Output("output: T")
1338     .Attr("T: realnumbertype")
1339     .Attr("Tindices: {int32,int64}")
1340     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1341     .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
1342 
1343 REGISTER_OP("UnsortedSegmentProd")
1344     .Input("data: T")
1345     .Input("segment_ids: Tindices")
1346     .Input("num_segments: Tnumsegments")
1347     .Output("output: T")
1348     .Attr("T: numbertype")
1349     .Attr("Tindices: {int32,int64}")
1350     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1351     .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
1352 
1353 REGISTER_OP("SparseSegmentSum")
1354     .Input("data: T")
1355     .Input("indices: Tidx")
1356     .Input("segment_ids: Tsegmentids")
1357     .Output("output: T")
1358     .Attr("T: realnumbertype")
1359     .Attr("Tidx: {int32, int64} = DT_INT32")
1360     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1361     .SetShapeFn(SparseSegmentReductionShapeFn);
1362 
1363 REGISTER_OP("SparseSegmentSumWithNumSegments")
1364     .Input("data: T")
1365     .Input("indices: Tidx")
1366     .Input("segment_ids: Tsegmentids")
1367     .Input("num_segments: Tnumsegments")
1368     .Output("output: T")
1369     .Attr("T: realnumbertype")
1370     .Attr("Tidx: {int32, int64} = DT_INT32")
1371     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1372     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1373     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
1374 
1375 REGISTER_OP("SparseSegmentSumGrad")
1376     .Input("grad: T")
1377     .Input("indices: Tidx")
1378     .Input("segment_ids: Tsegmentids")
1379     .Input("output_dim0: int32")
1380     .Output("output: T")
1381     .Attr("T: {bfloat16, half, float, double}")
1382     .Attr("Tidx: {int32, int64} = DT_INT32")
1383     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1384     .SetShapeFn(SparseSegmentReductionGradShapeFn);
1385 
1386 REGISTER_OP("SparseSegmentMean")
1387     .Input("data: T")
1388     .Input("indices: Tidx")
1389     .Input("segment_ids: Tsegmentids")
1390     .Output("output: T")
1391     .Attr("T: {bfloat16, half, float, double}")
1392     .Attr("Tidx: {int32, int64} = DT_INT32")
1393     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1394     .SetShapeFn(SparseSegmentReductionShapeFn);
1395 
1396 REGISTER_OP("SparseSegmentMeanWithNumSegments")
1397     .Input("data: T")
1398     .Input("indices: Tidx")
1399     .Input("segment_ids: Tsegmentids")
1400     .Input("num_segments: Tnumsegments")
1401     .Output("output: T")
1402     .Attr("T: {bfloat16, half, float, double}")
1403     .Attr("Tidx: {int32, int64} = DT_INT32")
1404     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1405     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1406     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
1407 
1408 REGISTER_OP("SparseSegmentMeanGrad")
1409     .Input("grad: T")
1410     .Input("indices: Tidx")
1411     .Input("segment_ids: Tsegmentids")
1412     .Input("output_dim0: int32")
1413     .Output("output: T")
1414     .Attr("T: {bfloat16, half, float, double}")
1415     .Attr("Tidx: {int32, int64} = DT_INT32")
1416     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1417     .SetShapeFn(SparseSegmentReductionGradShapeFn);
1418 
1419 REGISTER_OP("SparseSegmentSqrtN")
1420     .Input("data: T")
1421     .Input("indices: Tidx")
1422     .Input("segment_ids: Tsegmentids")
1423     .Output("output: T")
1424     .Attr("T: {bfloat16, half, float, double}")
1425     .Attr("Tidx: {int32, int64} = DT_INT32")
1426     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1427     .SetShapeFn(SparseSegmentReductionShapeFn);
1428 
1429 REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
1430     .Input("data: T")
1431     .Input("indices: Tidx")
1432     .Input("segment_ids: Tsegmentids")
1433     .Input("num_segments: Tnumsegments")
1434     .Output("output: T")
1435     .Attr("T: {bfloat16, half, float, double}")
1436     .Attr("Tidx: {int32, int64} = DT_INT32")
1437     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1438     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1439     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
1440 
1441 REGISTER_OP("SparseSegmentSqrtNGrad")
1442     .Input("grad: T")
1443     .Input("indices: Tidx")
1444     .Input("segment_ids: Tsegmentids")
1445     .Input("output_dim0: int32")
1446     .Output("output: T")
1447     .Attr("T: {bfloat16, half, float, double}")
1448     .Attr("Tidx: {int32, int64} = DT_INT32")
1449     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1450     .SetShapeFn(SparseSegmentReductionGradShapeFn);
1451 
1452 REGISTER_OP("All")
1453     .Input("input: bool")
1454     .Input("reduction_indices: Tidx")
1455     .Output("output: bool")
1456     .Attr("keep_dims: bool = false")
1457     .Attr("Tidx: {int32, int64} = DT_INT32")
1458     .SetShapeFn(shape_inference::ReductionShape);
1459 
1460 REGISTER_OP("Any")
1461     .Input("input: bool")
1462     .Input("reduction_indices: Tidx")
1463     .Attr("keep_dims: bool = false")
1464     .Output("output: bool")
1465     .Attr("Tidx: {int32, int64} = DT_INT32")
1466     .SetShapeFn(shape_inference::ReductionShape);
1467 
1468 // --------------------------------------------------------------------------
1469 
1470 namespace {
1471 
1472 template <typename T>
RangeSize(const Tensor * start_t,const Tensor * limit_t,const Tensor * delta_t,InferenceContext * const c)1473 Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
1474                  const Tensor* delta_t, InferenceContext* const c) {
1475   T start = start_t->scalar<T>()();
1476   T limit = limit_t->scalar<T>()();
1477   T delta = delta_t->scalar<T>()();
1478   if (start > limit && delta > T(0)) {
1479     return errors::InvalidArgument(
1480         "Requires start <= limit when delta > 0: ", start, "/", limit);
1481   }
1482   if (start < limit && delta < T(0)) {
1483     return errors::InvalidArgument(
1484         "Requires start >= limit when delta < 0: ", start, "/", limit);
1485   }
1486   if (delta == T(0)) {
1487     return errors::InvalidArgument("Requires delta != 0");
1488   }
1489 
1490   int64_t size;
1491   if (std::is_integral<T>::value) {
1492     size = Eigen::divup(static_cast<int64_t>(Eigen::numext::abs(limit - start)),
1493                         static_cast<int64_t>(Eigen::numext::abs(delta)));
1494   } else {
1495     auto size_auto =
1496         Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta));
1497     if (size_auto > std::numeric_limits<int64_t>::max()) {
1498       return errors::InvalidArgument("Requires ((limit - start) / delta) <= ",
1499                                      std::numeric_limits<int64_t>::max());
1500     }
1501     size = static_cast<int64_t>(size_auto);
1502   }
1503 
1504   c->set_output(0, c->Vector(static_cast<int64_t>(size)));
1505   return OkStatus();
1506 }
1507 
1508 }  // namespace
1509 
1510 REGISTER_OP("Range")
1511     .Input("start: Tidx")
1512     .Input("limit: Tidx")
1513     .Input("delta: Tidx")
1514     .Output("output: Tidx")
1515     .Attr(
1516         "Tidx: "
1517         "{bfloat16, half, float, double, int8, int16, int32, int64, uint16, "
1518         "uint32} = "
1519         "DT_INT32")
__anon247ac1530902(InferenceContext* c) 1520     .SetShapeFn([](InferenceContext* c) {
1521       ShapeHandle unused;
1522       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
1523                                       " for 'start'");
1524       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
1525                                       " for 'limit'");
1526       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
1527                                       " for 'delta'");
1528       const Tensor* start_t = c->input_tensor(0);
1529       const Tensor* limit_t = c->input_tensor(1);
1530       const Tensor* delta_t = c->input_tensor(2);
1531       DataType dtype;
1532       TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1533       if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) {
1534         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1535         return OkStatus();
1536       }
1537       if (dtype == DT_INT32) {
1538         return RangeSize<int32>(start_t, limit_t, delta_t, c);
1539       } else if (dtype == DT_INT16) {
1540         return RangeSize<int16>(start_t, limit_t, delta_t, c);
1541       } else if (dtype == DT_INT8) {
1542         return RangeSize<int8>(start_t, limit_t, delta_t, c);
1543       } else if (dtype == DT_INT64) {
1544         return RangeSize<int64_t>(start_t, limit_t, delta_t, c);
1545       } else if (dtype == DT_UINT16) {
1546         return RangeSize<uint16>(start_t, limit_t, delta_t, c);
1547       } else if (dtype == DT_UINT32) {
1548         return RangeSize<uint32>(start_t, limit_t, delta_t, c);
1549       } else if (dtype == DT_FLOAT) {
1550         return RangeSize<float>(start_t, limit_t, delta_t, c);
1551       } else if (dtype == DT_DOUBLE) {
1552         return RangeSize<double>(start_t, limit_t, delta_t, c);
1553       } else if (dtype == DT_BFLOAT16) {
1554         return RangeSize<bfloat16>(start_t, limit_t, delta_t, c);
1555       } else {
1556         return errors::InvalidArgument("Unsupported dtype", dtype);
1557       }
1558       return OkStatus();
1559     });
1560 
1561 REGISTER_OP("LinSpace")
1562     .Input("start: T")
1563     .Input("stop: T")
1564     .Input("num: Tidx")
1565     .Output("output: T")
1566     .Attr("T: {bfloat16, half, float, double}")
1567     .Attr("Tidx: {int32, int64} = DT_INT32")
__anon247ac1530a02(InferenceContext* c) 1568     .SetShapeFn([](InferenceContext* c) {
1569       ShapeHandle unused;
1570       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
1571                                       " for 'start'");
1572       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
1573                                       " for 'stop'");
1574       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
1575                                       " for 'num'");
1576       const Tensor* num_t = c->input_tensor(2);
1577       if (num_t == nullptr) {
1578         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1579         return OkStatus();
1580       }
1581 
1582       int64_t num;
1583       if (num_t->dtype() == DT_INT32) {
1584         num = num_t->scalar<int32>()();
1585       } else {
1586         num = num_t->scalar<int64_t>()();
1587       }
1588       if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num);
1589       c->set_output(0, c->Vector(num));
1590       return OkStatus();
1591     });
1592 
1593 REGISTER_OP("Complex")
1594     .Input("real: T")
1595     .Input("imag: T")
1596     .Output("out: Tout")
1597     .Attr("T: {float, double} = DT_FLOAT")
1598     .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
1599     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
1600 
1601 REGISTER_OP("Real")
1602     .Input("input: T")
1603     .Output("output: Tout")
1604     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
1605     .Attr("Tout: {float, double} = DT_FLOAT")
1606     .SetShapeFn(shape_inference::UnchangedShape);
1607 
1608 REGISTER_OP("Imag")
1609     .Input("input: T")
1610     .Output("output: Tout")
1611     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
1612     .Attr("Tout: {float, double} = DT_FLOAT")
1613     .SetShapeFn(shape_inference::UnchangedShape);
1614 
1615 REGISTER_OP("Angle")
1616     .Input("input: T")
1617     .Output("output: Tout")
1618     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
1619     .Attr("Tout: {float, double} = DT_FLOAT")
1620     .SetShapeFn(shape_inference::UnchangedShape);
1621 
1622 REGISTER_OP("Conj")
1623     .Input("input: T")
1624     .Output("output: T")
1625     .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64")
__anon247ac1530b02(InferenceContext* c) 1626     .SetShapeFn([](InferenceContext* c) {
1627       c->set_output(0, c->input(0));
1628       auto* handle_data = c->input_handle_shapes_and_types(0);
1629       if (handle_data != nullptr) {
1630         c->set_output_handle_shapes_and_types(0, *handle_data);
1631       }
1632       return OkStatus();
1633     });
1634 
1635 // --------------------------------------------------------------------------
1636 
1637 REGISTER_OP("Cross")
1638     .Input("a: T")
1639     .Input("b: T")
1640     .Output("product: T")
1641     .Attr("T: realnumbertype")
__anon247ac1530c02(InferenceContext* c) 1642     .SetShapeFn([](InferenceContext* c) {
1643       ShapeHandle a_shape;
1644       ShapeHandle b_shape;
1645       // * Input rank >= 1.
1646       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape));
1647       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape));
1648 
1649       // * Both inputs have the same shape.
1650       TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape));
1651 
1652       // * input_shape[-1] == 3.
1653       if (c->RankKnown(a_shape)) {
1654         int rank = c->Rank(a_shape);
1655         auto dim = c->Dim(a_shape, rank - 1);
1656         TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim));
1657       }
1658       c->set_output(0, a_shape);
1659       return OkStatus();
1660     });
1661 
1662 // --------------------------------------------------------------------------
1663 
1664 REGISTER_OP("HistogramFixedWidth")
1665     .Input("values: T")
1666     .Input("value_range: T")
1667     .Input("nbins: int32")
1668     .Output("out: dtype")
1669     .Attr("T: {int32, int64, float32, float64}")
1670     .Attr("dtype: {int32, int64} = DT_INT32")
__anon247ac1530d02(InferenceContext* c) 1671     .SetShapeFn([](InferenceContext* c) {
1672       // value_range should be a vector.
1673       ShapeHandle value_range_shape;
1674       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape));
1675       // value_range should have two elements.
1676       DimensionHandle unused;
1677       TF_RETURN_IF_ERROR(
1678           c->WithValue(c->Dim(value_range_shape, 0), 2, &unused));
1679       // nbins should be a scalar.
1680       ShapeHandle nbins_shape;
1681       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape));
1682 
1683       // If nbins is available, set the shape from nbins.
1684       const Tensor* nbins_input = c->input_tensor(2);
1685       if (nbins_input != nullptr) {
1686         int64_t nbins;
1687         TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins));
1688         // nbins has to be positive.
1689         if (nbins <= 0) {
1690           return errors::InvalidArgument("Requires nbins > 0: ", nbins);
1691         }
1692         c->set_output(0, c->Vector(nbins));
1693       } else {
1694         c->set_output(0, c->UnknownShapeOfRank(1));
1695       }
1696       return OkStatus();
1697     });
1698 
1699 REGISTER_OP("Bincount")
1700     .Input("arr: int32")
1701     .Input("size: int32")
1702     .Input("weights: T")
1703     .Attr("T: {int32, int64, float32, float64}")
1704     .Output("bins: T")
__anon247ac1530e02(InferenceContext* c) 1705     .SetShapeFn([](InferenceContext* c) {
1706       ShapeHandle unused;
1707       // The input `size` must be a scalar.
1708       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1709 
1710       const Tensor* size_tensor = c->input_tensor(1);
1711       if (size_tensor == nullptr) {
1712         // Return unknown shape if size is not known.
1713         c->set_output(0, c->UnknownShapeOfRank(1));
1714         return OkStatus();
1715       }
1716 
1717       if (size_tensor->dims() != 0) {
1718         return errors::InvalidArgument("Shape must be rank 0 but is rank ",
1719                                        size_tensor->dims());
1720       }
1721 
1722       // Return `[size]` shape if size is known.
1723       int32_t size_val = size_tensor->scalar<int32>()();
1724       if (size_val < 0) {
1725         return errors::InvalidArgument("size (", size_val,
1726                                        ") must be non-negative");
1727       }
1728       c->set_output(0, c->MakeShape({size_val}));
1729       return OkStatus();
1730     });
1731 
1732 REGISTER_OP("DenseBincount")
1733     .Input("input: Tidx")
1734     .Input("size: Tidx")
1735     .Input("weights: T")
1736     .Attr("Tidx: {int32, int64}")
1737     .Attr("T: {int32, int64, float32, float64}")
1738     .Attr("binary_output: bool = false")
1739     .Output("output: T")
__anon247ac1530f02(InferenceContext* c) 1740     .SetShapeFn([](InferenceContext* c) {
1741       ShapeHandle unused;
1742       // The input `input` must be at most matrix.
1743       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 2, &unused));
1744       // The input `size` must be a scalar.
1745       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1746 
1747       const Tensor* size_tensor = c->input_tensor(1);
1748       if (size_tensor == nullptr) {
1749         // Return unknown shape if size is not known.
1750         c->set_output(0, c->UnknownShape());
1751         return OkStatus();
1752       }
1753       if (size_tensor->dims() != 0) {
1754         return errors::InvalidArgument("Shape must be rank 0 but is rank ",
1755                                        size_tensor->dims());
1756       }
1757 
1758       int64_t size_val;
1759       DataType dtype;
1760       TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1761       if (dtype == DT_INT32) {
1762         size_val = static_cast<int64_t>(size_tensor->scalar<int32>()());
1763       } else if (dtype == DT_INT64) {
1764         size_val = size_tensor->scalar<int64_t>()();
1765       } else {
1766         return errors::InvalidArgument("size dtype must be int32 or int64");
1767       }
1768       // Return `[size]` shape if size is known.
1769       if (size_val < 0) {
1770         return errors::InvalidArgument("size (", size_val,
1771                                        ") must be non-negative");
1772       }
1773       if (c->Rank(c->input(0)) == 1) {
1774         c->set_output(0, c->MakeShape({size_val}));
1775       } else if (c->Rank(c->input(0)) == 2) {
1776         c->set_output(0, c->MakeShape({c->Dim(c->input(0), 0), size_val}));
1777       }
1778       return OkStatus();
1779     });
1780 
1781 REGISTER_OP("SparseBincount")
1782     .Input("indices: int64")
1783     .Input("values: Tidx")
1784     .Input("dense_shape: int64")
1785     .Input("size: Tidx")
1786     .Input("weights: T")
1787     .Attr("Tidx: {int32, int64}")
1788     .Attr("T: {int32, int64, float32, float64}")
1789     .Attr("binary_output: bool = false")
1790     .Output("output: T")
__anon247ac1531002(InferenceContext* c) 1791     .SetShapeFn([](InferenceContext* c) {
1792       const Tensor* size_tensor = c->input_tensor(3);
1793       if (size_tensor == nullptr) {
1794         // Return unknown shape if size is not known.
1795         c->set_output(0, c->UnknownShape());
1796         return OkStatus();
1797       }
1798       if (size_tensor->dims() != 0) {
1799         return errors::InvalidArgument("Shape must be rank 0 but is rank ",
1800                                        size_tensor->dims());
1801       }
1802 
1803       int64_t size_val;
1804       DataType dtype;
1805       TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1806       if (dtype == DT_INT32) {
1807         size_val = static_cast<int64_t>(size_tensor->scalar<int32>()());
1808       } else if (dtype == DT_INT64) {
1809         size_val = size_tensor->scalar<int64_t>()();
1810       } else {
1811         return errors::InvalidArgument("size dtype must be int32 or int64");
1812       }
1813       // Return `[size]` shape if size is known.
1814       if (size_val < 0) {
1815         return errors::InvalidArgument("size (", size_val,
1816                                        ") must be non-negative");
1817       }
1818 
1819       const Tensor* shape_tensor = c->input_tensor(2);
1820       if (shape_tensor == nullptr) {
1821         // Return unknown shape if size is not known.
1822         c->set_output(0, c->UnknownShape());
1823         return OkStatus();
1824       }
1825       if (shape_tensor->NumElements() == 1) {
1826         c->set_output(0, c->MakeShape({size_val}));
1827       } else if (shape_tensor->NumElements() == 2) {
1828         c->set_output(
1829             0, c->MakeShape({shape_tensor->flat<int64_t>()(0), size_val}));
1830       } else {
1831         return errors::InvalidArgument("Input must be less than rank 2");
1832       }
1833       return OkStatus();
1834     });
1835 
1836 REGISTER_OP("RaggedBincount")
1837     .Input("splits: int64")
1838     .Input("values: Tidx")
1839     .Input("size: Tidx")
1840     .Input("weights: T")
1841     .Attr("Tidx: {int32, int64}")
1842     .Attr("T: {int32, int64, float32, float64}")
1843     .Attr("binary_output: bool = false")
1844     .Output("output: T")
__anon247ac1531102(InferenceContext* c) 1845     .SetShapeFn([](InferenceContext* c) {
1846       c->set_output(0, c->UnknownShape());
1847       return OkStatus();
1848     });
1849 
1850 REGISTER_OP("Cumsum")
1851     .Input("x: T")
1852     .Input("axis: Tidx")
1853     .Attr("exclusive: bool = false")
1854     .Attr("reverse: bool = false")
1855     .Output("out: T")
1856     .Attr("T: numbertype")
1857     .Attr("Tidx: {int32, int64} = DT_INT32")
1858     .SetShapeFn(shape_inference::UnchangedShape);
1859 
1860 REGISTER_OP("Cumprod")
1861     .Input("x: T")
1862     .Input("axis: Tidx")
1863     .Attr("exclusive: bool = false")
1864     .Attr("reverse: bool = false")
1865     .Output("out: T")
1866     .Attr("T: numbertype")
1867     .Attr("Tidx: {int32, int64} = DT_INT32")
1868     .SetShapeFn(shape_inference::UnchangedShape);
1869 
1870 REGISTER_OP("CumulativeLogsumexp")
1871     .Input("x : T")
1872     .Input("axis: Tidx")
1873     .Attr("exclusive: bool = false")
1874     .Attr("reverse: bool = false")
1875     .Output("out: T")
1876     .Attr("T: {float16, float32, float64}")
1877     .Attr("Tidx: {int32, int64} = DT_INT32")
1878     .SetShapeFn(shape_inference::UnchangedShape);
1879 
1880 REGISTER_OP("QuantizedMatMul")
1881     .Input("a: T1")
1882     .Input("b: T2")
1883     .Input("min_a: float")
1884     .Input("max_a: float")
1885     .Input("min_b: float")
1886     .Input("max_b: float")
1887     .Output("out: Toutput")
1888     .Output("min_out: float")
1889     .Output("max_out: float")
1890     .Attr("T1: quantizedtype")
1891     .Attr("T2: quantizedtype")
1892     .Attr("Toutput: quantizedtype = DT_QINT32")
1893     .Attr("transpose_a: bool = false")
1894     .Attr("transpose_b: bool = false")
1895     .Attr("Tactivation: quantizedtype = DT_QUINT8")
__anon247ac1531202(InferenceContext* c) 1896     .SetShapeFn([](InferenceContext* c) {
1897       TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
1898       ShapeHandle unused;
1899       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1900       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1901       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1902       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1903 
1904       c->set_output(1, c->Scalar());
1905       c->set_output(2, c->Scalar());
1906       return OkStatus();
1907     });
1908 
1909 // Note: This op is not commutative w.r.t. to all its inputs.
1910 REGISTER_OP("QuantizedMul")
1911     .Input("x: T1")
1912     .Input("y: T2")
1913     .Input("min_x: float")
1914     .Input("max_x: float")
1915     .Input("min_y: float")
1916     .Input("max_y: float")
1917     .Output("z: Toutput")
1918     .Output("min_z: float")
1919     .Output("max_z: float")
1920     .Attr("T1: quantizedtype")
1921     .Attr("T2: quantizedtype")
1922     .Attr("Toutput: quantizedtype = DT_QINT32")
__anon247ac1531302(InferenceContext* c) 1923     .SetShapeFn([](InferenceContext* c) {
1924       TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
1925       c->set_output(1, c->Scalar());
1926       c->set_output(2, c->Scalar());
1927       return OkStatus();
1928     });
1929 
1930 // Note: This op is not commutative w.r.t. to all its inputs.
1931 REGISTER_OP("QuantizedAdd")
1932     .Input("x: T1")
1933     .Input("y: T2")
1934     .Input("min_x: float")
1935     .Input("max_x: float")
1936     .Input("min_y: float")
1937     .Input("max_y: float")
1938     .Output("z: Toutput")
1939     .Output("min_z: float")
1940     .Output("max_z: float")
1941     .Attr("T1: quantizedtype")
1942     .Attr("T2: quantizedtype")
1943     .Attr("Toutput: quantizedtype = DT_QINT32")
__anon247ac1531402(InferenceContext* c) 1944     .SetShapeFn([](InferenceContext* c) {
1945       TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
1946       // min_x, max_x, min_y, max_y should be scalar.
1947       ShapeHandle unused;
1948       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1949       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1950       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1951       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1952 
1953       c->set_output(1, c->Scalar());
1954       c->set_output(2, c->Scalar());
1955       return OkStatus();
1956     });
1957 
1958 REGISTER_OP("QuantizeDownAndShrinkRange")
1959     .Input("input: Tinput")
1960     .Input("input_min: float")
1961     .Input("input_max: float")
1962     .Output("output: out_type")
1963     .Output("output_min: float")
1964     .Output("output_max: float")
1965     .Attr("Tinput: quantizedtype")
1966     .Attr("out_type: quantizedtype")
__anon247ac1531502(InferenceContext* c) 1967     .SetShapeFn([](InferenceContext* c) {
1968       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1969       ShapeHandle unused;
1970       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1971       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1972       c->set_output(1, c->Scalar());
1973       c->set_output(2, c->Scalar());
1974       return OkStatus();
1975     });
1976 
1977 REGISTER_OP("Requantize")
1978     .Input("input: Tinput")
1979     .Input("input_min: float")
1980     .Input("input_max: float")
1981     .Input("requested_output_min: float")
1982     .Input("requested_output_max: float")
1983     .Output("output: out_type")
1984     .Output("output_min: float")
1985     .Output("output_max: float")
1986     .Attr("Tinput: quantizedtype")
1987     .Attr("out_type: quantizedtype")
__anon247ac1531602(InferenceContext* c) 1988     .SetShapeFn([](InferenceContext* c) {
1989       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1990       ShapeHandle unused;
1991       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1992       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1993       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1994       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1995       c->set_output(1, c->Scalar());
1996       c->set_output(2, c->Scalar());
1997       return OkStatus();
1998     });
1999 
2000 REGISTER_OP("RequantizationRange")
2001     .Input("input: Tinput")
2002     .Input("input_min: float")
2003     .Input("input_max: float")
2004     .Output("output_min: float")
2005     .Output("output_max: float")
2006     .Attr("Tinput: quantizedtype")
__anon247ac1531702(InferenceContext* c) 2007     .SetShapeFn([](InferenceContext* c) {
2008       ShapeHandle unused;
2009       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2010       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2011       c->set_output(0, c->Scalar());
2012       c->set_output(1, c->Scalar());
2013       return OkStatus();
2014     });
2015 
2016 // --------------------------------------------------------------------------
2017 
2018 REGISTER_OP("Bucketize")
2019     .Input("input: T")
2020     .Output("output: int32")
2021     .Attr("T: {int32, int64, float, double}")
2022     .Attr("boundaries: list(float)")
2023     .SetShapeFn(shape_inference::UnchangedShape);
2024 
2025 REGISTER_OP("ClipByValue")
2026     .Input("t: T")
2027     .Input("clip_value_min: T")
2028     .Input("clip_value_max: T")
2029     .Output("output: T")
2030     .Attr("T: numbertype")
2031     .SetShapeFn(shape_inference::UnchangedShape);
2032 
#ifdef INTEL_MKL
// Note: This op is not commutative with respect to all of its inputs.
REGISTER_OP("_MklAddN")
    .Input("inputs: N * T")
    .Input("mkl_input: N * uint8")
    .Output("sum: T")
    .Output("mkl_sum: uint8")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .SetShapeFn([](InferenceContext* c) {
      // Inputs are laid out as the N `inputs` followed by the N uint8
      // `mkl_input` tensors (presumably MKL layout metadata -- confirm).
      // Only the first N are summands, so only they are merged into the sum's
      // shape; previously all 2N inputs were merged.
      int32_t n;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
      ShapeHandle cur = c->input(n - 1);
      for (int32_t i = n - 2; i >= 0; --i) {
        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
                                        "From merging shape ", i,
                                        " with other shapes.");
      }
      c->set_output(0, cur);
      // Set the `mkl_sum` output explicitly instead of leaving it unset; its
      // shape is not derived here.
      c->set_output(1, c->UnknownShape());
      return OkStatus();
    })
    .Doc(R"doc(
Add two input tensors element wise using mkl kernel sum.
inputs: Must all be the same size and shape.
)doc");

#endif  // INTEL_MKL
2058 
2059 REGISTER_OP("RequantizePerChannel")
2060     .Input("input: T")
2061     .Input("input_min: float")
2062     .Input("input_max: float")
2063     .Input("requested_output_min: float")
2064     .Input("requested_output_max: float")
2065     .Output("output: out_type")
2066     .Output("output_min: float")
2067     .Output("output_max: float")
2068     .Attr("T: quantizedtype = DT_QINT32")
2069     .Attr("out_type: quantizedtype = DT_QUINT8")
__anon247ac1531902(InferenceContext* c) 2070     .SetShapeFn([](InferenceContext* c) {
2071       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2072       ShapeHandle unused;
2073       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2074       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2075       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2076       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2077       c->set_output(1, c->Scalar());
2078       c->set_output(2, c->Scalar());
2079       return OkStatus();
2080     });
2081 REGISTER_OP("RequantizationRangePerChannel")
2082     .Input("input: T")
2083     .Input("input_min: float")
2084     .Input("input_max: float")
2085     .Output("output_min: float")
2086     .Output("output_max: float")
2087     .Attr("T: quantizedtype = DT_QINT32")
2088     .Attr("clip_value_max: float")
__anon247ac1531a02(InferenceContext* c) 2089     .SetShapeFn([](InferenceContext* c) {
2090       ShapeHandle unused;
2091       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2092       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2093       c->set_output(0, c->Scalar());
2094       c->set_output(1, c->Scalar());
2095       return OkStatus();
2096     });
2097 
2098 REGISTER_OP("NextAfter")
2099     .Attr("T: {float64, float32} = DT_FLOAT")
2100     .Input("x1: T")
2101     .Input("x2: T")
2102     .Output("output: T")
2103     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
2104 
2105 REGISTER_OP("SobolSample")
2106     .Input("dim: int32")
2107     .Input("num_results: int32")
2108     .Input("skip: int32")
2109     .Attr("dtype: {float, double} = DT_FLOAT")
2110     .Output("samples: dtype")
__anon247ac1531b02(shape_inference::InferenceContext* c) 2111     .SetShapeFn([](shape_inference::InferenceContext* c) {
2112       ShapeHandle unused;
2113 
2114       // inputs must be scalars
2115       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
2116       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2117       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2118 
2119       const Tensor* dim_t = c->input_tensor(0);
2120       const Tensor* num_results_t = c->input_tensor(1);
2121 
2122       int32_t dim = dim_t == nullptr ? InferenceContext::kUnknownDim
2123                                      : dim_t->scalar<int32>()();
2124 
2125       int32_t num_results = num_results_t == nullptr
2126                                 ? InferenceContext::kUnknownDim
2127                                 : num_results_t->scalar<int32>()();
2128 
2129       c->set_output(0, c->Matrix(num_results, dim));
2130       return OkStatus();
2131     });
2132 
2133 }  // namespace tensorflow
2134