xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/framework/ops.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CC_FRAMEWORK_OPS_H_
17 #define TENSORFLOW_CC_FRAMEWORK_OPS_H_
18 
19 #include <type_traits>
20 
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor.pb.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/lib/hash/hash.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 
27 namespace tensorflow {
28 
/// @defgroup core Core TensorFlow API
30 
31 class Output;
32 
33 /// @addtogroup core
34 /// @{
35 
36 /// Represents a node in the computation graph.
37 class Operation {
38  public:
Operation()39   Operation() : node_(nullptr) {}
40   explicit Operation(Node* n);
41 
num_inputs()42   int32 num_inputs() const { return node_->num_inputs(); }
input_type(int32_t o)43   DataType input_type(int32_t o) const { return node_->input_type(o); }
44   Output input(int32_t i) const;
45 
num_outputs()46   int32 num_outputs() const { return node_->num_outputs(); }
output_type(int32_t o)47   DataType output_type(int32_t o) const { return node_->output_type(o); }
48   Output output(int32_t i) const;
49 
node()50   Node* node() const { return node_; }
51 
52   uint64 hash(int32_t index) const;
53 
54   bool operator==(const Operation& other) const { return node_ == other.node_; }
55 
56  private:
57   typedef std::vector<std::pair<Node*, int32>> Inputs;
58   static Inputs GetInputs(Node* node);
59 
60   Inputs inputs_;
61   Node* node_;
62 };
63 
64 /// Represents a tensor value produced by an Operation.
65 class Output {
66  public:
67   Output() = default;
Output(Node * n)68   explicit Output(Node* n) : op_(n) {}
Output(Node * n,int32_t index)69   Output(Node* n, int32_t index) : op_(n), index_(index) {}
Output(const Operation & op,int32_t index)70   Output(const Operation& op, int32_t index) : op_(op), index_(index) {}
71 
op()72   Operation op() const { return op_; }
node()73   Node* node() const { return op().node(); }
index()74   int32 index() const { return index_; }
type()75   DataType type() const { return op_.output_type(index_); }
name()76   std::string name() const {
77     return strings::StrCat(node()->name(), ":", index());
78   }
79   bool operator==(const Output& other) const {
80     return op_ == other.op_ && index_ == other.index_;
81   }
82 
hash()83   uint64 hash() const { return op_.hash(index_); }
84 
85  private:
86   Operation op_ = Operation(nullptr);
87   int32 index_ = 0;
88 };
89 
90 /// Hash class that can be used for e.g. storing Outputs in an unordered_map
91 struct OutputHash {
operatorOutputHash92   std::size_t operator()(const Output& output) const {
93     return Hash64Combine(std::hash<Node*>()(output.node()),
94                          std::hash<int32>()(output.index()));
95   }
96 };
97 
98 /// Represents a tensor value that can be used as an operand to an Operation.
99 class Input {
100  public:
101   /// Initializer enables constructing an Input object from various kinds of C++
102   /// constants such as simple primitive constants and nested initializer lists
103   /// representing a multi-dimensional array. Initializer constructors are all
104   /// templates, so the aforementioned kinds of C++ constants can be used to
105   /// construct an Initializer. Initializer stores the value it got constructed
106   /// with in a Tensor object.
107   struct Initializer {
108     /// Construct from a scalar value of an arithmetic type or a type that can
109     /// be converted to a string (eg. a string literal).
110     template <typename T, typename = typename std::enable_if<
111                               std::is_arithmetic<T>::value ||
112                               std::is_convertible<T, std::string>::value>::type>
InitializerInitializer113     Initializer(const T& v) {  // NOLINT(runtime/explicit)
114       typedef typename RealType<T>::type RealT;
115       Tensor t(DataTypeToEnum<RealT>::v(), TensorShape());
116       t.flat<RealT>()(0) = RealT(v);
117       tensor = t;
118     }
119 
InitializerInitializer120     Initializer(const Tensor& t) : tensor(t) {}  // NOLINT(runtime/explicit)
121 
122     /// Construct from a scalar value and an explicit shape
123     template <typename T, typename = typename std::enable_if<
124                               std::is_arithmetic<T>::value ||
125                               std::is_convertible<T, std::string>::value>::type>
InitializerInitializer126     Initializer(const T& v, const TensorShape& shape) {
127       typedef typename RealType<T>::type RealT;
128       Tensor t(DataTypeToEnum<RealT>::v(), shape);
129       for (int64_t i = 0; i < t.NumElements(); ++i) {
130         t.flat<RealT>()(i) = RealT(v);
131       }
132       tensor = t;
133     }
134 
135     /// Construct from a initializer list of scalars (a one-dimensional tensor).
136     template <typename T, typename = typename std::enable_if<
137                               std::is_arithmetic<T>::value ||
138                               std::is_convertible<T, std::string>::value>::type>
InitializerInitializer139     Initializer(
140         const std::initializer_list<T>& v) {  // NOLINT(runtime/explicit)
141       typedef typename RealType<T>::type RealT;
142       Tensor t(DataTypeToEnum<RealT>::v(),
143                TensorShape{static_cast<int>(v.size())});
144       std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
145       tensor = t;
146     }
147 
148     /// Construct from a initializer list of scalars and an explicit shape.
149     template <typename T, typename = typename std::enable_if<
150                               std::is_arithmetic<T>::value ||
151                               std::is_convertible<T, std::string>::value>::type>
InitializerInitializer152     Initializer(const std::initializer_list<T>& v, const TensorShape& shape) {
153       typedef typename RealType<T>::type RealT;
154       Tensor t(DataTypeToEnum<RealT>::v(), shape);
155       if (t.NumElements() != static_cast<int64_t>(v.size())) {
156         status = errors::InvalidArgument(
157             "Cannot construct a tensor with ", t.NumElements(),
158             " from an initializer list with ", v.size(), " elements");
159         return;
160       }
161       std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
162       tensor = t;
163     }
164 
165     /// Construct a multi-dimensional tensor from a nested initializer
166     /// list. Note that C++ syntax allows nesting of arbitrarily typed
167     /// initializer lists, so such invalid initializers cannot be disallowed at
168     /// compile time. This function performs checks to make sure that the nested
169     /// initializer list is indeed a valid multi-dimensional tensor.
170     Initializer(const std::initializer_list<Initializer>& v);
171 
172     // START_SKIP_DOXYGEN
173     template <typename T, bool = std::is_convertible<T, std::string>::value>
174     struct RealType {
175       typedef tstring type;
176     };
177 
178     template <typename T>
179     struct RealType<T, false> {
180       typedef T type;
181     };
182     // END_SKIP_DOXYGEN
183 
184     TensorProto AsTensorProto() {
185       TensorProto tensor_proto;
186       if (tensor.NumElements() > 1) {
187         tensor.AsProtoTensorContent(&tensor_proto);
188       } else {
189         tensor.AsProtoField(&tensor_proto);
190       }
191       return tensor_proto;
192     }
193 
194     Status status;
195     Tensor tensor;
196   };
197 
198   /// All of Input's constructors are implicit. Input can be implicitly
199   /// constructed from the following objects :
200   /// * Output: This is so that the output of an Operation can be directly used
201   ///   as the input to a op wrapper, which takes Inputs.
202   /// * A scalar, or a multi-dimensional tensor specified as a recursive
203   ///   initializer list. This enables directly passing constants as
204   ///   inputs to op wrappers.
205   /// * A Tensor object.
206   Input(const Output& o) : output_(o) {}  // NOLINT(runtime/explicit)
207 
208   template <typename T, typename = typename std::enable_if<
209                             std::is_arithmetic<T>::value ||
210                             std::is_convertible<T, std::string>::value>::type>
211   Input(const T& v)  // NOLINT(runtime/explicit)
212       : Input(Initializer(v)) {}
213 
214   Input(const Initializer& init)  // NOLINT(runtime/explicit)
215       : status_(init.status),
216         tensor_(init.tensor) {}
217 
218   Input(const Tensor& t)  // NOLINT(runtime/explicit)
219       : status_(OkStatus()), tensor_(t) {}
220 
221   Input(const std::initializer_list<Initializer>&
222             init) {  // NOLINT(runtime/explicit)
223     for (const auto& i : init) {
224       if (!i.status.ok()) {
225         status_ = i.status;
226         return;
227       }
228     }
229     tensor_ = Initializer(init).tensor;
230   }
231 
232   /// Constructor specifying a node name, index and datatype. This should only
233   /// be used for specifying a backward edge, needed by control flow.
234   Input(const std::string& name, int32_t i, DataType dt)
235       : node_name_(name), index_(i), data_type_(dt) {}
236 
237   Node* node() const { return output_.node(); }
238   std::string node_name() const { return node_name_; }
239   int32 index() const { return node_name_.empty() ? output_.index() : index_; }
240   DataType data_type() const { return data_type_; }
241   Status status() const { return status_; }
242   const Tensor& tensor() const { return tensor_; }
243 
244  private:
245   Status status_;
246   Output output_ = Output(Operation(nullptr), 0);
247   Tensor tensor_;
248   const std::string node_name_ = "";
249   int32 index_ = 0;
250   DataType data_type_ = DT_INVALID;
251 };
252 
253 /// A type for representing the output of ops that produce more than one output,
254 /// or a list of tensors.
255 typedef std::vector<Output> OutputList;
256 
257 /// A type for representing the input to ops that require a list of tensors.
258 class InputList {
259  public:
260   /// Implicitly convert a list of outputs to a list of inputs. This is useful
261   /// to write code such as ops::Concat(ops::Split(x, 4)).
262   InputList(const OutputList& out) {  // NOLINT(runtime/explicit)
263     for (auto const& x : out) {
264       inputs_.push_back(x);
265     }
266   }
267 
268   InputList(
269       const std::initializer_list<Input>& inputs)  // NOLINT(runtime/explicit)
270       : inputs_(inputs.begin(), inputs.end()) {}
271 
272   InputList(const tensorflow::gtl::ArraySlice<Input>&
273                 inputs)  // NOLINT(runtime/explicit)
274       : inputs_(inputs.begin(), inputs.end()) {}
275 
276   InputList(
277       const std::initializer_list<Output>& out) {  // NOLINT(runtime/explicit)
278     for (auto const& x : out) {
279       inputs_.push_back(x);
280     }
281   }
282 
283   typename std::vector<Input>::iterator begin() { return inputs_.begin(); }
284   typename std::vector<Input>::iterator end() { return inputs_.end(); }
285   typename std::vector<Input>::const_iterator begin() const {
286     return inputs_.begin();
287   }
288   typename std::vector<Input>::const_iterator end() const {
289     return inputs_.end();
290   }
291 
292  private:
293   std::vector<Input> inputs_;
294 };
295 
296 /// @}
297 
298 }  // namespace tensorflow
299 
300 #endif  // TENSORFLOW_CC_FRAMEWORK_OPS_H_
301