xref: /aosp_15_r20/external/tensorflow/tensorflow/c/ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/ops.h"
17 
18 #include "tensorflow/c/tf_status_helper.h"
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/op.h"
21 #include "tensorflow/core/framework/op_def_builder.h"
22 #include "tensorflow/core/framework/shape_inference.h"
23 
24 using ::tensorflow::DataType;
25 using ::tensorflow::OpDef;
26 using ::tensorflow::OpDefBuilder;
27 using ::tensorflow::OpDeprecation;
28 using ::tensorflow::OpShapeInferenceFn;
29 using ::tensorflow::Set_TF_Status_from_Status;
30 using ::tensorflow::Status;
31 using ::tensorflow::shape_inference::DimensionHandle;
32 using ::tensorflow::shape_inference::InferenceContext;
33 using ::tensorflow::shape_inference::ShapeHandle;
34 
TF_NewOpDefinitionBuilder(const char * op_name)35 TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) {
36   auto* result = new OpDefBuilder(op_name);
37   return reinterpret_cast<TF_OpDefinitionBuilder*>(result);
38 }
39 
TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder * builder)40 void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) {
41   delete reinterpret_cast<OpDefBuilder*>(builder);
42 }
43 
TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder * builder,const char * input_spec)44 void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder,
45                                     const char* input_spec) {
46   reinterpret_cast<OpDefBuilder*>(builder)->Input(input_spec);
47 }
48 
TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder * builder,const char * output_spec)49 void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder,
50                                      const char* output_spec) {
51   reinterpret_cast<OpDefBuilder*>(builder)->Output(output_spec);
52 }
53 
54 #define DEFINE_BUILDER_BOOL_SETTER(func_name)                             \
55   void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
56                                          bool arg_name) {                 \
57     reinterpret_cast<OpDefBuilder*>(builder)->func_name();                \
58   }
59 
60 DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative)
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate)61 DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate)
62 DEFINE_BUILDER_BOOL_SETTER(SetIsStateful)
63 DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput)
64 
65 void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder,
66                                    const char* attr_spec) {
67   reinterpret_cast<OpDefBuilder*>(builder)->Attr(attr_spec);
68 }
69 
TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder * builder,int version,const char * explanation)70 void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder,
71                                       int version, const char* explanation) {
72   reinterpret_cast<OpDefBuilder*>(builder)->Deprecated(version, explanation);
73 }
74 
TF_RegisterOpDefinition(TF_OpDefinitionBuilder * builder,TF_Status * status)75 void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
76                              TF_Status* status) {
77   auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
78   TF_SetStatus(status, TF_OK, "");
79   ::tensorflow::OpRegistry::Global()->Register(
80       [cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
81         Status result = cc_builder->Finalize(op_reg_data);
82         delete cc_builder;
83         return result;
84       });
85 }
86 
TF_OpDefinitionBuilderSetShapeInferenceFunction(TF_OpDefinitionBuilder * builder,void (* shape_inference_func)(TF_ShapeInferenceContext * ctx,TF_Status * status))87 void TF_OpDefinitionBuilderSetShapeInferenceFunction(
88     TF_OpDefinitionBuilder* builder,
89     void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
90                                  TF_Status* status)) {
91   auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
92   cc_builder->SetShapeFn(
93       [shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
94         TF_Status* c_status = TF_NewStatus();
95         auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
96         shape_inference_func(c_ctx, c_status);
97         tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
98         TF_DeleteStatus(c_status);
99         return result;
100       });
101 }
102 
TF_NewShapeHandle()103 TF_ShapeHandle* TF_NewShapeHandle() {
104   return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
105 }
106 
TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext * ctx)107 TF_ShapeHandle* TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext* ctx) {
108   auto* handle = new ShapeHandle;
109   *handle = reinterpret_cast<InferenceContext*>(ctx)->Scalar();
110   return reinterpret_cast<TF_ShapeHandle*>(handle);
111 }
112 
TF_ShapeInferenceContextVectorFromSize(TF_ShapeInferenceContext * ctx,size_t size)113 TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
114     TF_ShapeInferenceContext* ctx, size_t size) {
115   auto* handle = new ShapeHandle;
116   *handle = reinterpret_cast<InferenceContext*>(ctx)->Vector(size);
117   return reinterpret_cast<TF_ShapeHandle*>(handle);
118 }
119 
TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext * ctx,TF_ShapeHandle * first,TF_ShapeHandle * second,TF_ShapeHandle * result,TF_Status * status)120 void TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext* ctx,
121                                                TF_ShapeHandle* first,
122                                                TF_ShapeHandle* second,
123                                                TF_ShapeHandle* result,
124                                                TF_Status* status) {
125   auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
126   Status s = cc_ctx->Concatenate(*reinterpret_cast<ShapeHandle*>(first),
127                                  *reinterpret_cast<ShapeHandle*>(second),
128                                  reinterpret_cast<ShapeHandle*>(result));
129   Set_TF_Status_from_Status(status, s);
130 }
131 
TF_NewDimensionHandle()132 TF_DimensionHandle* TF_NewDimensionHandle() {
133   return reinterpret_cast<TF_DimensionHandle*>(new DimensionHandle);
134 }
135 
TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext * ctx)136 int64_t TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext* ctx) {
137   auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
138   return cc_ctx->num_inputs();
139 }
140 
TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext * ctx,int i,TF_ShapeHandle * handle,TF_Status * status)141 void TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext* ctx, int i,
142                                       TF_ShapeHandle* handle,
143                                       TF_Status* status) {
144   TF_SetStatus(status, TF_OK, "");
145   auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
146   if (i < 0 || i >= cc_ctx->num_inputs()) {
147     TF_SetStatus(status, TF_INVALID_ARGUMENT, "input index out of range");
148   }
149   if (TF_GetCode(status) == TF_OK) {
150     auto* cc_result = reinterpret_cast<ShapeHandle*>(handle);
151     *cc_result = cc_ctx->input(i);
152   }
153 }
154 
TF_ShapeInferenceContextRankKnown(TF_ShapeInferenceContext * ctx,TF_ShapeHandle * handle)155 int TF_ShapeInferenceContextRankKnown(TF_ShapeInferenceContext* ctx,
156                                       TF_ShapeHandle* handle) {
157   auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
158   return cc_ctx->RankKnown(*reinterpret_cast<ShapeHandle*>(handle));
159 }
160 
TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext * ctx,int i,TF_ShapeHandle * handle,TF_Status * status)161 void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i,
162                                        TF_ShapeHandle* handle,
163                                        TF_Status* status) {
164   TF_SetStatus(status, TF_OK, "");
165   auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
166   if (i < 0 || i >= cc_ctx->num_outputs()) {
167     TF_SetStatus(status, TF_INVALID_ARGUMENT, "output index out of range");
168   }
169   if (TF_GetCode(status) == TF_OK) {
170     cc_ctx->set_output(i, *(reinterpret_cast<ShapeHandle*>(handle)));
171   }
172 }
173 
TF_DeleteShapeHandle(TF_ShapeHandle * handle)174 void TF_DeleteShapeHandle(TF_ShapeHandle* handle) {
175   if (handle == nullptr) {
176     return;
177   }
178 
179   delete reinterpret_cast<ShapeHandle*>(handle);
180 }
181 
TF_DeleteDimensionHandle(TF_DimensionHandle * handle)182 void TF_DeleteDimensionHandle(TF_DimensionHandle* handle) {
183   if (handle == nullptr) {
184     return;
185   }
186 
187   delete reinterpret_cast<DimensionHandle*>(handle);
188 }
189 
190 #define DEFINE_TF_GETATTR(func, c_type, cc_type)                         \
191   void TF_ShapeInferenceContext_GetAttr##func(                           \
192       TF_ShapeInferenceContext* ctx, const char* attr_name, c_type* val, \
193       TF_Status* status) {                                               \
194     TF_SetStatus(status, TF_OK, "");                                     \
195     cc_type v;                                                           \
196     auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);             \
197     Status s = cc_ctx->GetAttr(attr_name, &v);                           \
198     Set_TF_Status_from_Status(status, s);                                \
199     if (s.ok()) {                                                        \
200       *val = static_cast<c_type>(v);                                     \
201     }                                                                    \
202   }
203 
DEFINE_TF_GETATTR(Type,TF_DataType,tensorflow::DataType)204 DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
205 
206 #define DEFINE_RANK_FUNC(func_name)                                        \
207   void TF_ShapeInferenceContext##func_name(                                \
208       TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, \
209       TF_ShapeHandle* result, TF_Status* status) {                         \
210     auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);               \
211     auto* cc_handle = reinterpret_cast<ShapeHandle*>(handle);              \
212     auto* cc_result = reinterpret_cast<ShapeHandle*>(result);              \
213     Status s = cc_ctx->func_name(*cc_handle, rank, cc_result);             \
214     Set_TF_Status_from_Status(status, s);                                  \
215   }
216 
217 DEFINE_RANK_FUNC(WithRank)
218 DEFINE_RANK_FUNC(WithRankAtLeast)
219 DEFINE_RANK_FUNC(WithRankAtMost)
220 
221 int64_t TF_ShapeInferenceContextRank(TF_ShapeInferenceContext* ctx,
222                                      TF_ShapeHandle* handle) {
223   return reinterpret_cast<InferenceContext*>(ctx)->Rank(
224       *reinterpret_cast<ShapeHandle*>(handle));
225 }
226 
TF_ShapeInferenceContextDim(TF_ShapeInferenceContext * ctx,TF_ShapeHandle * shape_handle,int64_t i,TF_DimensionHandle * result)227 void TF_ShapeInferenceContextDim(TF_ShapeInferenceContext* ctx,
228                                  TF_ShapeHandle* shape_handle, int64_t i,
229                                  TF_DimensionHandle* result) {
230   int64_t rank = TF_ShapeInferenceContextRank(ctx, shape_handle);
231   auto* cc_result = reinterpret_cast<DimensionHandle*>(result);
232 
233   if (i < -rank || i >= rank) {
234     *cc_result = DimensionHandle();
235     return;
236   }
237 
238   auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
239   auto* cc_shape_handle = reinterpret_cast<ShapeHandle*>(shape_handle);
240   *cc_result = cc_ctx->Dim(*cc_shape_handle, i);
241 }
242 
TF_DimensionHandleValueKnown(TF_DimensionHandle * dim_handle)243 int TF_DimensionHandleValueKnown(TF_DimensionHandle* dim_handle) {
244   return InferenceContext::ValueKnown(
245       *reinterpret_cast<DimensionHandle*>(dim_handle));
246 }
247 
TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext * ctx,TF_Status * status)248 void TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext* ctx,
249                                              TF_Status* status) {
250   Status s = ::tensorflow::shape_inference::UnknownShape(
251       reinterpret_cast<InferenceContext*>(ctx));
252   Set_TF_Status_from_Status(status, s);
253 }
254 
TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext * ctx,TF_ShapeHandle * shape_handle,int64_t start,int64_t end,TF_ShapeHandle * result,TF_Status * status)255 void TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext* ctx,
256                                       TF_ShapeHandle* shape_handle,
257                                       int64_t start, int64_t end,
258                                       TF_ShapeHandle* result,
259                                       TF_Status* status) {
260   TF_SetStatus(status, TF_OK, "");
261   auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
262   auto* cc_result = reinterpret_cast<ShapeHandle*>(result);
263   Status s = cc_ctx->Subshape(*reinterpret_cast<ShapeHandle*>(shape_handle),
264                               start, end, cc_result);
265   Set_TF_Status_from_Status(status, s);
266 }
267 
TF_DimensionHandleValue(TF_DimensionHandle * dim_handle)268 int64_t TF_DimensionHandleValue(TF_DimensionHandle* dim_handle) {
269   return InferenceContext::Value(
270       *reinterpret_cast<DimensionHandle*>(dim_handle));
271 }
272