1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h"

#include <utility>

#include "tensorflow/c/experimental/ops/gen/common/view_util.h"
#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h"
#include "tensorflow/c/experimental/ops/gen/cpp/views/arg_view.h"
#include "tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h"
#include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h"
22 
23 namespace tensorflow {
24 namespace generator {
25 namespace cpp {
26 
OpImplementationRenderer(RendererContext context,OpView op)27 OpImplementationRenderer::OpImplementationRenderer(RendererContext context,
28                                                    OpView op)
29     : Renderer(context), op_(op) {}
30 
Render()31 void OpImplementationRenderer::Render() {
32   RenderInitialization();
33 
34   if (op_.IsListOp()) {
35     RenderExecutionListOp();
36   } else if (op_.NumOutputs() == 0) {
37     RenderExecutionZeroOutputs();
38   } else if (op_.NumOutputs() == 1) {
39     RenderExecutionSingleOutput();
40   } else {
41     RenderExecutionMultipleOutputs();
42   }
43 }
44 
RenderInitialization()45 void OpImplementationRenderer::RenderInitialization() {
46   // Create Op variable and initialize it
47   Statement("AbstractOperationPtr $0(ctx->CreateOperation())",
48             op_.VariableName());
49   TFStatement(Call(op_.VariableName(), "Reset",
50                    {op_.OpNameString(), "raw_device_name"}));
51   TFStatement(Call("MaybeSetOpName", {op_.VariableName() + ".get()", "name"}));
52   // Set each input
53   for (const ArgView& ar : op_.Inputs()) {
54     TFStatement(Call(op_.VariableName(), ar.SetterMethod(), ar.SetterArgs()));
55   }
56   // Set each attribute
57   for (const AttrView& ar : op_.Attributes()) {
58     TFStatement(Call(op_.VariableName(), ar.SetterMethod(), ar.SetterArgs()));
59   }
60 }
61 
RenderExecutionListOp()62 void OpImplementationRenderer::RenderExecutionListOp() {
63   ArgView output_arg = op_.OnlyOutput();
64   Statement("int num_retvals = $0.size()", output_arg.VariableName());
65   Statement("return " + Call(op_.VariableName(), "Execute",
66                              {output_arg.VariableName(), "&num_retvals"}));
67 }
68 
RenderExecutionSingleOutput()69 void OpImplementationRenderer::RenderExecutionSingleOutput() {
70   ArgView output_arg = op_.OnlyOutput();
71   Statement("int num_retvals = 1");
72   Statement("return $0->Execute(absl::MakeSpan($1, 1), &num_retvals)",
73             op_.VariableName(), output_arg.VariableName());
74 }
75 
RenderExecutionMultipleOutputs()76 void OpImplementationRenderer::RenderExecutionMultipleOutputs() {
77   Statement("int num_retvals = $0", op_.NumOutputs());
78   Statement("AbstractTensorHandle* temp_outputs[$0]", op_.NumOutputs());
79   Statement("Status status = $0->Execute(temp_outputs, &num_retvals)",
80             op_.VariableName());
81 
82   for (const ArgView& arg : op_.Outputs()) {
83     Statement("*$0 = temp_outputs[$1]", arg.VariableName(), arg.Position());
84   }
85 
86   Statement("return status");
87 }
88 
RenderExecutionZeroOutputs()89 void OpImplementationRenderer::RenderExecutionZeroOutputs() {
90   Statement("int num_retvals = 0");
91   Statement("std::vector<AbstractTensorHandle*> dummy_outputs");
92   Statement("return $0->Execute(absl::MakeSpan(dummy_outputs), &num_retvals)",
93             op_.VariableName());
94 }
95 
96 }  // namespace cpp
97 }  // namespace generator
98 }  // namespace tensorflow
99