/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/core/kernels/fused_eigen_output_kernels.h"
17
18 #include "absl/strings/str_join.h"
19 #include "absl/strings/substitute.h"
20
21 namespace tensorflow {
22
InitializeFusedComputation(OpKernelConstruction * context,const string & kernel_name,const std::vector<FusedComputationPattern> & patterns,FusedComputationType * fused_computation,FusedComputationArgs * fused_computation_args)23 Status InitializeFusedComputation(
24 OpKernelConstruction* context, const string& kernel_name,
25 const std::vector<FusedComputationPattern>& patterns,
26 FusedComputationType* fused_computation,
27 FusedComputationArgs* fused_computation_args) {
28 // 'fused_ops' and 'num_args' attributes are specified by the Grappler
29 // Remapper optimizer (see grappler/optimizers/remapper.cc).
30
31 std::vector<string> fused_ops;
32 TF_RETURN_IF_ERROR(context->GetAttr("fused_ops", &fused_ops));
33 if (fused_ops.empty()) {
34 return errors::InvalidArgument("Fused ", kernel_name,
35 " must have at least one fused op.");
36 }
37
38 int num_args;
39 TF_RETURN_IF_ERROR(context->GetAttr("num_args", &num_args));
40
41 // TODO(ezhulenev): Add support for fusion element-wise op chains defined
42 // at runtime, e.g. Relu+Sqrt+Tanh+etc.
43
44 // Reset fused computation type.
45 *fused_computation = FusedComputationType::kUndefined;
46
47 // Match op fusion to one of the supported patterns.
48 for (const auto& pattern : patterns) {
49 if (fused_ops == pattern.fused_ops) {
50 *fused_computation = pattern.fused_computation;
51 break;
52 }
53 }
54 if (*fused_computation == FusedComputationType::kUndefined) {
55 return errors::Unimplemented("Fusion is not implemented: [",
56 absl::StrJoin(fused_ops, ","), "]");
57 }
58
59 // Depending on a picked fusion type validate fusion-specific arguments.
60 if (*fused_computation == FusedComputationType::kBiasAdd ||
61 *fused_computation == FusedComputationType::kBiasAddWithRelu ||
62 *fused_computation == FusedComputationType::kBiasAddWithRelu6 ||
63 *fused_computation == FusedComputationType::kBiasAddWithElu ||
64 *fused_computation == FusedComputationType::kBiasAddWithLeakyRelu ||
65 *fused_computation == FusedComputationType::kBiasAddWithGeluApproximate) {
66 if (num_args != 1) {
67 return errors::InvalidArgument(
68 "Fused ", kernel_name,
69 " with BiasAdd must have one extra argument: bias.");
70 }
71 if (*fused_computation == FusedComputationType::kBiasAddWithLeakyRelu) {
72 TF_RETURN_IF_ERROR(context->GetAttr(
73 "leakyrelu_alpha", &fused_computation_args->leakyrelu_alpha));
74 }
75 }
76
77 if (*fused_computation == FusedComputationType::kFusedBatchNorm ||
78 *fused_computation == FusedComputationType::kFusedBatchNormWithRelu ||
79 *fused_computation == FusedComputationType::kFusedBatchNormWithRelu6 ||
80 *fused_computation == FusedComputationType::kFusedBatchNormWithElu ||
81 *fused_computation ==
82 FusedComputationType::kFusedBatchNormWithLeakyRelu) {
83 if (num_args != 4) {
84 return errors::InvalidArgument(
85 "Fused ", kernel_name,
86 " with FusedBatchNorm must have four extra arguments: scale, offset, "
87 "mean, variance.");
88 }
89 TF_RETURN_IF_ERROR(
90 context->GetAttr("epsilon", &fused_computation_args->epsilon));
91 if (*fused_computation ==
92 FusedComputationType::kFusedBatchNormWithLeakyRelu) {
93 TF_RETURN_IF_ERROR(context->GetAttr(
94 "leakyrelu_alpha", &fused_computation_args->leakyrelu_alpha));
95 }
96 }
97
98 return OkStatus();
99 }
100
101 } // namespace tensorflow
102