xref: /aosp_15_r20/external/tensorflow/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h"

#include <string>
#include <vector>

#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "tensorflow/c/experimental/ops/gen/common/case_format.h"
#include "tensorflow/c/experimental/ops/gen/common/view_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
23 
24 namespace tensorflow {
25 namespace generator {
26 namespace cpp {
27 
VariableName() const28 string AttrView::VariableName() const { return attr_.name(); }
29 
VariableType() const30 string AttrView::VariableType() const {
31   // Completely special cases (e.g. strings are different when lists)
32   if (attr_.full_type() == "string") {
33     return "const char*";
34   }
35   if (attr_.full_type() == "list(string)") {
36     return "absl::Span<string const>";
37   }
38 
39   // Normal path: translate base type to C++ ...
40   string c_base_type = attr_.base_type();
41   if (attr_.base_type() == "type") {
42     c_base_type = "DataType";
43   } else if (attr_.base_type() == "shape") {
44     c_base_type = "const PartialTensorShape";
45   }
46 
47   // ... and wrap in a Span<> if it's a list.
48   if (attr_.is_list()) {
49     return absl::Substitute("absl::Span<$0>", c_base_type);
50   } else {
51     return c_base_type;
52   }
53 
54   return attr_.full_type();
55 }
56 
AttrNameString() const57 string AttrView::AttrNameString() const { return Quoted(attr_.name()); }
58 
DefaultValue() const59 string AttrView::DefaultValue() const {
60   const AttrValue &attr_value = attr_.default_value();
61   switch (attr_value.value_case()) {
62     case AttrValue::VALUE_NOT_SET:
63       return "";
64     case AttrValue::kType:
65       return DataType_Name(attr_value.type());
66     case AttrValue::kS:
67       return "\"" + attr_value.s() + "\"";
68     case AttrValue::kI:
69       return std::to_string(attr_value.i());
70     case AttrValue::kF:
71       return std::to_string(attr_value.f());
72     case AttrValue::kB:
73       return attr_value.b() ? "true" : "false";
74     case AttrValue::kList:
75       if (attr_.full_type() == "list(string)" &&
76           attr_value.list().s_size() == 0) {
77         return "{}";
78       }
79       LOG(WARNING) << "Unimplemented: default value of list-typed attribute.";
80       return "/* UNIMPLEMENTED */";
81     case AttrValue::kShape:
82     case AttrValue::kTensor:
83     case AttrValue::kFunc:
84     case AttrValue::kPlaceholder:
85       LOG(ERROR) << "Unexpected non-primitive attribute value.";
86       return "/* ERROR */";
87   }
88 }
89 
VariableStrLen() const90 string AttrView::VariableStrLen() const {
91   return Call("strlen", {VariableName()});
92 }
93 
VariableSpanData() const94 string AttrView::VariableSpanData() const {
95   return Call(VariableName(), "data", {}, ".");
96 }
97 
VariableSpanLen() const98 string AttrView::VariableSpanLen() const {
99   return Call(VariableName(), "length", {}, ".");
100 }
101 
InputArg(bool with_default_value) const102 string AttrView::InputArg(bool with_default_value) const {
103   string default_value = DefaultValue();
104   if (!with_default_value || default_value.empty()) {
105     return absl::Substitute("$0 $1", VariableType(), attr_.name());
106   }
107   return absl::Substitute("$0 $1 = $2", VariableType(), attr_.name(),
108                           default_value);
109 }
110 
SetterMethod() const111 string AttrView::SetterMethod() const {
112   if (!attr_.is_list()) {
113     return absl::StrCat("SetAttr", toUpperCamel(attr_.full_type()));
114   } else {
115     return absl::StrCat("SetAttr", toUpperCamel(attr_.base_type()), "List");
116   }
117 }
118 
SetterArgs() const119 std::vector<string> AttrView::SetterArgs() const {
120   if (attr_.full_type() == "string") {
121     return {AttrNameString(), VariableName(), VariableStrLen()};
122   } else if (attr_.full_type() == "list(string)") {
123     return {AttrNameString(), VariableName()};  // accepts span directly
124   } else if (attr_.is_list()) {
125     return {AttrNameString(), VariableSpanData(), VariableSpanLen()};
126   } else {
127     return {AttrNameString(), VariableName()};
128   }
129 }
130 
131 }  // namespace cpp
132 }  // namespace generator
133 }  // namespace tensorflow
134