1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h"
16
17 #include "tensorflow/c/experimental/ops/gen/common/view_util.h"
18 #include "tensorflow/c/experimental/ops/gen/cpp/views/arg_view.h"
19 #include "tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h"
20 #include "tensorflow/core/lib/strings/str_util.h"
21 #include "tensorflow/core/platform/logging.h"
22
23 namespace tensorflow {
24 namespace generator {
25 namespace cpp {
26
OpView(OpSpec op)27 OpView::OpView(OpSpec op)
28 : op_(op),
29 input_args_(op_.Inputs().begin(), op_.Inputs().end()),
30 output_args_(op_.Outputs().begin(), op_.Outputs().end()),
31 argument_attrs_(op_.Attributes().begin(), op_.Attributes().end()) {
32 // Initialize function arguments
33 all_arguments_.push_back(OpArgumentView("AbstractContext*", "ctx"));
34 for (const auto& arg : op_.Inputs()) {
35 all_arguments_.push_back(OpArgumentView(arg));
36 }
37 for (const auto& arg : op_.Outputs()) {
38 all_arguments_.push_back(OpArgumentView(arg));
39 }
40 for (const auto& attr : op.Attributes()) {
41 all_arguments_.push_back(OpArgumentView(attr));
42 }
43 all_arguments_.push_back(OpArgumentView("const char*", "name", "nullptr"));
44 all_arguments_.push_back(
45 OpArgumentView("const char*", "raw_device_name", "nullptr"));
46 }
47
Inputs() const48 const std::vector<ArgView>& OpView::Inputs() const { return input_args_; }
49
Outputs() const50 const std::vector<ArgView>& OpView::Outputs() const { return output_args_; }
51
Attributes() const52 const std::vector<AttrView>& OpView::Attributes() const {
53 return argument_attrs_;
54 }
55
AllArguments() const56 const std::vector<OpArgumentView>& OpView::AllArguments() const {
57 return all_arguments_;
58 }
59
NumInputs() const60 int OpView::NumInputs() const { return input_args_.size(); }
61
NumOutputs() const62 int OpView::NumOutputs() const { return output_args_.size(); }
63
OnlyInput() const64 ArgView OpView::OnlyInput() const {
65 CHECK_EQ(input_args_.size(), 1); // Crash OK
66 return input_args_.front();
67 }
68
OnlyOutput() const69 ArgView OpView::OnlyOutput() const {
70 CHECK_EQ(output_args_.size(), 1); // Crash OK
71 return output_args_.front();
72 }
73
FunctionName() const74 string OpView::FunctionName() const { return op_.name(); }
75
OpNameString() const76 string OpView::OpNameString() const { return Quoted(op_.name()); }
77
VariableName() const78 string OpView::VariableName() const { return "op_ptr"; }
79
Description() const80 std::vector<string> OpView::Description() const {
81 return str_util::Split(op_.description(), "\n");
82 }
83
Summary() const84 string OpView::Summary() const { return op_.summary(); }
85
86 // Context
IsListOp() const87 bool OpView::IsListOp() const {
88 return NumOutputs() == 1 && OnlyOutput().IsList();
89 }
90
91 } // namespace cpp
92 } // namespace generator
93 } // namespace tensorflow
94