1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ 18 19 // Operations calling functions are becoming ubiquitous in TF 2.0. 20 // Examples include PartitionedCallOp, functional If/While, and Dataset ops. 21 // Such operations might require deep inspection - looking at the body of the 22 // called function - to place them and surrounding ops correctly. 23 24 // This file contains some utilities for placer to correctly place such ops 25 // including: 26 // - PlacerInspectionRequiredOpChecker: A simple class with a single 27 // IsPlacerInspectionRequired method. 28 // - IsolatePlacerInspectionRequiredOps: This function adds Identity ops for 29 // each input/output of ops requiring placer inspection. It greatly simplifies 30 // the implementation of placing such ops. 31 32 #include <vector> 33 34 #include "absl/types/optional.h" 35 #include "tensorflow/core/framework/function.h" 36 #include "tensorflow/core/graph/graph.h" 37 #include "tensorflow/core/lib/core/status.h" 38 39 namespace tensorflow { 40 41 // PlacerInspectionRequiredOpChecker allows one to check if Placer needs to 42 // look deeply into the op to place ops consuming the outputs correctly. 43 // 44 // It is a class instead of a standalone method because checking whether 45 // a function returns a resource takes non-trivial time and we cache the 46 // results. 47 class PlacerInspectionRequiredOpChecker { 48 public: 49 // Constructs a PlacerInspectionRequiredOpChecker for nodes of `graph`. 50 // The functions referenced by nodes in `graph` will be looked up in 51 // `flib_def` 52 PlacerInspectionRequiredOpChecker(const Graph* graph, 53 const FunctionLibraryDefinition* flib_def); 54 55 // If `node` is considered a deep op, sets `*is_deep` to true and returns 56 // Status::OK(). If an error occurs, returns that error, and the value of 57 // `*is_deep` is undefined. 58 // Currently, an op is considered deep, if it is a calling a function 59 // returning a resource. This definition is driven by Placer's need to 60 // look inside the op. 61 // REQUIRES: `node` is part of `graph` passed into constructor. 62 Status IsPlacerInspectionRequired(const Node& node, bool* is_deep); 63 64 private: 65 const Graph& graph_; 66 const FunctionLibraryDefinition& flib_def_; 67 // Indexed by the node id. 68 // If cache_[node_id] is empty, the deepness of the node with id `node_id` has 69 // not been computed yet. Else, it contains the value already computed. 70 std::vector<absl::optional<bool>> cache_; 71 }; 72 73 // Extracts `fdef` and `func` from `flib_def` for the function identified 74 // in "f" attribute of `node`. 75 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def, 76 const Node& node, const FunctionDef** fdef, 77 NameAttrList* func); 78 79 // The "call" stack of functions. 80 // Useful for better error messages as well as for detecting recursion. 81 // Stores references to graph nodes. These references must outlive this. 82 class FunctionStack { 83 public: 84 explicit FunctionStack(const string& function_name); 85 86 // `node_in_current_function` must outlive this. 87 FunctionStack Push(const Node* node_in_current_function, 88 const string& new_current_function) const; 89 90 // Returns true iff this stack already includes `function_name`. 91 bool HasFunction(const string& function_name) const; 92 current_function_name()93 const string& current_function_name() const { return current_function_name_; } 94 95 // Format's this suitable for error interpolation that retrieves 96 // Python files and line numbers. 97 string FormatForError() const; 98 99 private: 100 struct Frame { FrameFrame101 Frame(const string& function, const Node* node) 102 : function_name(function), node(node) {} 103 104 string function_name; 105 const Node* node; 106 }; 107 108 // The function at the top of the stack. In other words, the function 109 // that is currently being inspected for placement. 110 string current_function_name_; 111 112 // The stack of frames that got the placement to the current_function_name_. 113 // frames_[0].function_name is the top function that Placer was constructed 114 // with. frames_[0].function_name can be empty if placer was constructed with 115 // a nameless graph, not a function. frames_[0].node_name is a name of a node 116 // in frames_[0].function_name that required deep inspection (e.g. a 117 // PartitionedCallOp). The function that this node invoked is 118 // frames_[1].function_name, if frames_.size() > 1. Else, the function that 119 // this node invoked is current_function_name_. 120 std::vector<Frame> frames_; 121 }; 122 123 // Adds Identities for each input and output of function-calling ops in `graph` 124 // 125 // For example, the following graph calling a function on inputs `a` and `b` 126 // and producing output `y` will be rewritten to include identities on all 127 // edges: 128 // 129 // a b 130 // | | 131 // v v 132 // f (PartitionedCallOp) 133 // | 134 // v 135 // y 136 // 137 // is transformed to 138 // 139 // a b 140 // | | 141 // a_f (Identity) b_f (Identity) 142 // | | 143 // v v 144 // f (PartitionedCallOp) 145 // | 146 // f_y (Identity) 147 // | 148 // v 149 // y 150 // 151 Status IsolatePlacerInspectionRequiredOps( 152 const FunctionLibraryDefinition& flib_def, Graph* graph); 153 154 } // namespace tensorflow 155 156 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ 157