1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_
18 
19 // Operations calling functions are becoming ubiquitous in TF 2.0.
20 // Examples include PartitionedCallOp, functional If/While, and Dataset ops.
21 // Such operations might require deep inspection - looking at the body of the
22 // called function - to place them and surrounding ops correctly.
23 
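// This file contains some utilities for placer to correctly place such ops,
// including:
// - PlacerInspectionRequiredOpChecker: A simple class with a single
//   IsPlacerInspectionRequired method.
// - GetFunctionDefAndAttrs: Looks up the function called by a node (via its
//   "f" attribute) in a FunctionLibraryDefinition.
// - FunctionStack: The "call" stack of functions being inspected, used for
//   better error messages and for detecting recursion.
// - IsolatePlacerInspectionRequiredOps: This function adds Identity ops for
//   each input/output of ops requiring placer inspection. It greatly
//   simplifies the implementation of placing such ops.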
31 
32 #include <vector>
33 
34 #include "absl/types/optional.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/core/status.h"
38 
39 namespace tensorflow {
40 
41 // PlacerInspectionRequiredOpChecker allows one to check if Placer needs to
42 // look deeply into the op to place ops consuming the outputs correctly.
43 //
44 // It is a class instead of a standalone method because checking whether
45 // a function returns a resource takes non-trivial time and we cache the
46 // results.
47 class PlacerInspectionRequiredOpChecker {
48  public:
49   // Constructs a PlacerInspectionRequiredOpChecker for nodes of `graph`.
50   // The functions referenced by nodes in `graph` will be looked up in
51   // `flib_def`
52   PlacerInspectionRequiredOpChecker(const Graph* graph,
53                                     const FunctionLibraryDefinition* flib_def);
54 
55   // If `node` is considered a deep op, sets `*is_deep` to true and returns
56   // Status::OK(). If an error occurs, returns that error, and the value of
57   // `*is_deep` is undefined.
58   // Currently, an op is considered deep, if it is a calling a function
59   // returning a resource. This definition is driven by Placer's need to
60   // look inside the op.
61   // REQUIRES: `node` is part of `graph` passed into constructor.
62   Status IsPlacerInspectionRequired(const Node& node, bool* is_deep);
63 
64  private:
65   const Graph& graph_;
66   const FunctionLibraryDefinition& flib_def_;
67   // Indexed by the node id.
68   // If cache_[node_id] is empty, the deepness of the node with id `node_id` has
69   // not been computed yet. Else, it contains the value already computed.
70   std::vector<absl::optional<bool>> cache_;
71 };
72 
73 // Extracts `fdef` and `func` from `flib_def` for the function identified
74 // in "f" attribute of `node`.
75 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
76                               const Node& node, const FunctionDef** fdef,
77                               NameAttrList* func);
78 
79 // The "call" stack of functions.
80 // Useful for better error messages as well as for detecting recursion.
81 // Stores references to graph nodes. These references must outlive this.
82 class FunctionStack {
83  public:
84   explicit FunctionStack(const string& function_name);
85 
86   // `node_in_current_function` must outlive this.
87   FunctionStack Push(const Node* node_in_current_function,
88                      const string& new_current_function) const;
89 
90   // Returns true iff this stack already includes `function_name`.
91   bool HasFunction(const string& function_name) const;
92 
current_function_name()93   const string& current_function_name() const { return current_function_name_; }
94 
95   // Format's this suitable for error interpolation that retrieves
96   // Python files and line numbers.
97   string FormatForError() const;
98 
99  private:
100   struct Frame {
FrameFrame101     Frame(const string& function, const Node* node)
102         : function_name(function), node(node) {}
103 
104     string function_name;
105     const Node* node;
106   };
107 
108   // The function at the top of the stack. In other words, the function
109   // that is currently being inspected for placement.
110   string current_function_name_;
111 
112   // The stack of frames that got the placement to the current_function_name_.
113   // frames_[0].function_name is the top function that Placer was constructed
114   // with. frames_[0].function_name can be empty if placer was constructed with
115   // a nameless graph, not a function.  frames_[0].node_name is a name of a node
116   // in frames_[0].function_name that required deep inspection (e.g. a
117   // PartitionedCallOp). The function that this node invoked is
118   // frames_[1].function_name, if frames_.size() > 1.  Else, the function that
119   // this node invoked is current_function_name_.
120   std::vector<Frame> frames_;
121 };
122 
123 // Adds Identities for each input and output of function-calling ops in `graph`
124 //
125 // For example, the following graph calling a function on inputs `a` and `b`
126 // and producing output `y` will be rewritten to include identities on all
127 // edges:
128 //
129 //      a             b
130 //      |             |
131 //      v             v
132 //    f (PartitionedCallOp)
133 //         |
134 //         v
135 //         y
136 //
137 // is transformed to
138 //
139 //      a             b
140 //      |             |
141 //  a_f (Identity)   b_f (Identity)
142 //      |             |
143 //      v             v
144 //    f (PartitionedCallOp)
145 //         |
146 //      f_y (Identity)
147 //         |
148 //         v
149 //         y
150 //
151 Status IsolatePlacerInspectionRequiredOps(
152     const FunctionLibraryDefinition& flib_def, Graph* graph);
153 
154 }  // namespace tensorflow
155 
156 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_
157