xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/random_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25 
// RandomUniform: stateful op (SetIsStateful) that takes an integer `shape`
// input (int32 or int64) and produces an output of element type `dtype`
// (half, bfloat16, float or double). The output shape is inferred from the
// `shape` input by shape_inference::RandomShape.
// NOTE(review): inputs are positional and the attr order is presumably
// reflected in the generated OpDef / language wrappers — keep the order of the
// builder calls below stable.
REGISTER_OP("RandomUniform")
    .Input("shape: T")
    .SetIsStateful()
    .Output("output: dtype")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("dtype: {half,bfloat16,float,double}")
    .Attr("T: {int32, int64}")
    .SetShapeFn(shape_inference::RandomShape);
35 
36 REGISTER_OP("RandomUniformInt")
37     .Input("shape: T")
38     .Input("minval: Tout")
39     .Input("maxval: Tout")
40     .SetIsStateful()
41     .Output("output: Tout")
42     .Attr("seed: int = 0")
43     .Attr("seed2: int = 0")
44     .Attr("Tout: {int32, int64}")
45     .Attr("T: {int32, int64}")
__anond8b80b8a0102(InferenceContext* c) 46     .SetShapeFn([](InferenceContext* c) {
47       ShapeHandle unused;
48       Status s = c->WithRank(c->input(1), 0, &unused);
49       if (!s.ok()) {
50         return errors::InvalidArgument(
51             "minval must be a scalar; got a tensor of shape ",
52             c->DebugString(c->input(1)));
53       }
54       s = c->WithRank(c->input(2), 0, &unused);
55       if (!s.ok()) {
56         return errors::InvalidArgument(
57             "maxval must be a scalar; got a tensor of shape ",
58             c->DebugString(c->input(2)));
59       }
60       return shape_inference::RandomShape(c);
61     });
62 
// RandomStandardNormal: stateful op (SetIsStateful) that takes an integer
// `shape` input (int32 or int64) and produces an output of element type
// `dtype` (half, bfloat16, float or double). The output shape is inferred from
// the `shape` input by shape_inference::RandomShape.
// NOTE(review): inputs are positional and the attr order is presumably
// reflected in the generated OpDef / language wrappers — keep the order of the
// builder calls below stable.
REGISTER_OP("RandomStandardNormal")
    .Input("shape: T")
    .SetIsStateful()
    .Output("output: dtype")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("dtype: {half,bfloat16,float,double}")
    .Attr("T: {int32, int64}")
    .SetShapeFn(shape_inference::RandomShape);
72 
73 REGISTER_OP("ParameterizedTruncatedNormal")
74     .Input("shape: T")
75     .Input("means: dtype")
76     .Input("stdevs: dtype")
77     .Input("minvals: dtype")
78     .Input("maxvals: dtype")
79     .SetIsStateful()
80     .Output("output: dtype")
81     .Attr("seed: int = 0")
82     .Attr("seed2: int = 0")
83     .Attr("dtype: {half,bfloat16,float,double}")
84     .Attr("T: {int32, int64}")
__anond8b80b8a0202(InferenceContext* c) 85     .SetShapeFn([](InferenceContext* c) {
86       ShapeHandle unused;
87       // Parameters must be 0-d or 1-d.
88       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &unused));
89       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
90       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused));
91       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
92       return shape_inference::RandomShape(c);
93     });
94 
// TruncatedNormal: stateful op (SetIsStateful) that takes an integer `shape`
// input (int32 or int64) and produces an output of element type `dtype`
// (half, bfloat16, float or double). The output shape is inferred from the
// `shape` input by shape_inference::RandomShape.
// NOTE(review): inputs are positional and the attr order is presumably
// reflected in the generated OpDef / language wrappers — keep the order of the
// builder calls below stable.
REGISTER_OP("TruncatedNormal")
    .Input("shape: T")
    .SetIsStateful()
    .Output("output: dtype")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("dtype: {half,bfloat16,float,double}")
    .Attr("T: {int32, int64}")
    .SetShapeFn(shape_inference::RandomShape);
104 
// RandomShuffle: stateful op (SetIsStateful) over a `value` tensor of any type
// `T`; the output has the same type `T`. The output shape is inferred by
// shape_inference::UnchangedShape (presumably copied from the `value` input —
// confirm in common_shape_fns).
// NOTE(review): the attr order is presumably reflected in the generated
// OpDef / language wrappers — keep the order of the builder calls stable.
REGISTER_OP("RandomShuffle")
    .Input("value: T")
    .SetIsStateful()
    .Output("output: T")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
113 
114 REGISTER_OP("Multinomial")
115     .SetIsStateful()
116     .Input("logits: T")
117     .Input("num_samples: int32")
118     .Output("output: output_dtype")
119     .Attr("seed: int = 0")
120     .Attr("seed2: int = 0")
121     .Attr("T: realnumbertype")
122     .Attr("output_dtype: {int32, int64} = DT_INT64")
__anond8b80b8a0302(InferenceContext* c) 123     .SetShapeFn([](InferenceContext* c) {
124       ShapeHandle logits_shape;
125       ShapeHandle unused;
126       DimensionHandle num_samples;
127       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape));
128       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
129       TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples));
130       c->set_output(0, c->Matrix(c->Dim(logits_shape, 0), num_samples));
131       return OkStatus();
132     });
133 
134 REGISTER_OP("RandomGamma")
135     .SetIsStateful()
136     .Input("shape: S")
137     .Input("alpha: T")
138     .Output("output: T")
139     .Attr("seed: int = 0")
140     .Attr("seed2: int = 0")
141     .Attr("S: {int32, int64}")
142     .Attr("T: {half, float, double}")
__anond8b80b8a0402(InferenceContext* c) 143     .SetShapeFn([](InferenceContext* c) {
144       ShapeHandle out;
145       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
146       TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
147       c->set_output(0, out);
148       return OkStatus();
149     });
150 
// RandomGammaGrad: op taking `alpha` and `sample` tensors of the same type `T`
// (float or double) and producing an output of type `T`. Unlike the other ops
// in this file it is not marked stateful (no SetIsStateful call). The output
// shape comes from shape_inference::BroadcastBinaryOpShapeFn (presumably the
// broadcast of the two input shapes).
REGISTER_OP("RandomGammaGrad")
    .Input("alpha: T")
    .Input("sample: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
157 
158 REGISTER_OP("RandomPoisson")
159     .SetIsStateful()
160     .Input("shape: S")
161     .Input("rate: dtype")
162     .Output("output: dtype")
163     .Attr("seed: int = 0")
164     .Attr("seed2: int = 0")
165     .Attr("S: {int32, int64}")
166     .Attr("dtype: {half, float, double}")
__anond8b80b8a0502(InferenceContext* c) 167     .SetShapeFn([](InferenceContext* c) {
168       ShapeHandle out;
169       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
170       TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
171       c->set_output(0, out);
172       return OkStatus();
173     })
174     .Deprecated(25, "Replaced by RandomPoissonV2");
175 
176 REGISTER_OP("RandomPoissonV2")
177     .SetIsStateful()
178     .Input("shape: S")
179     .Input("rate: R")
180     .Output("output: dtype")
181     .Attr("seed: int = 0")
182     .Attr("seed2: int = 0")
183     .Attr("S: {int32, int64}")
184     .Attr("R: {half, float, double, int32, int64} = DT_DOUBLE")
185     .Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
__anond8b80b8a0602(InferenceContext* c) 186     .SetShapeFn([](InferenceContext* c) {
187       ShapeHandle out;
188       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
189       TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
190       c->set_output(0, out);
191       return OkStatus();
192     });
193 
194 }  // namespace tensorflow
195