1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// The full pipeline for converting JAX random ops consists of two steps:
// 1. Rename the JAX random functions to TFLite-wrapped functions with the aid
//    of "jax.named_call". For example, in the dumped HLO,
//    jax.random.uniform will have the name "tfl_wrapped_jax_random_uniform".
// 2. Replace the bodies of "tfl_wrapped_jax_random_uniform" and
//    "tfl_wrapped_jax_random_normal" with tfl.CustomOp("RandomUniform") and
//    tfl.CustomOp("RandomStandardNormal"), respectively.
// This pass implements step 2.
23
24 #include <string>
25
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Support/Debug.h"
31 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
32 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
33 #include "mlir/IR/Attributes.h" // from @llvm-project
34 #include "mlir/IR/Block.h" // from @llvm-project
35 #include "mlir/IR/Builders.h" // from @llvm-project
36 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
37 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
38 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
39 #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
40 #include "mlir/IR/MLIRContext.h" // from @llvm-project
41 #include "mlir/IR/OperationSupport.h" // from @llvm-project
42 #include "mlir/IR/PatternMatch.h" // from @llvm-project
43 #include "mlir/IR/Region.h" // from @llvm-project
44 #include "mlir/IR/TypeRange.h" // from @llvm-project
45 #include "mlir/IR/Visitors.h" // from @llvm-project
46 #include "mlir/Pass/Pass.h" // from @llvm-project
47 #include "mlir/Support/LLVM.h" // from @llvm-project
48 #include "mlir/Support/LogicalResult.h" // from @llvm-project
49 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
50 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
52 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
53 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
54
55 namespace mlir {
56 namespace TFL {
57 namespace {
58 #define GEN_PASS_CLASSES
59 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
60
61 struct LegalizeJaxRandomPass
62 : public LegalizeJaxRandomPassBase<LegalizeJaxRandomPass> {
63 public:
64 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeJaxRandomPass)
65
66 void runOnOperation() override;
67 };
68
CustomOption(ImplicitLocOpBuilder * builder,const std::string & content)69 inline ConstBytesAttr CustomOption(ImplicitLocOpBuilder *builder,
70 const std::string &content) {
71 return ConstBytesAttr::get(builder->getContext(),
72 StringRef(content.data(), content.size()));
73 }
74
IsJaxRandomUniform(mlir::func::FuncOp func)75 inline bool IsJaxRandomUniform(mlir::func::FuncOp func) {
76 return func.getName().contains("tfl_wrapped_jax_random_uniform");
77 }
78
IsJaxRandomNormal(mlir::func::FuncOp func)79 inline bool IsJaxRandomNormal(mlir::func::FuncOp func) {
80 return func.getName().contains("tfl_wrapped_jax_random_normal");
81 }
82
runOnOperation()83 void LegalizeJaxRandomPass::runOnOperation() {
84 auto func = getOperation();
85 if (!IsJaxRandomUniform(func) && !IsJaxRandomNormal(func)) return;
86 auto result_tuple_ty =
87 func.getFunctionType().getResult(0).dyn_cast_or_null<TupleType>();
88 if (!result_tuple_ty) return;
89 if (result_tuple_ty.size() != 1) return;
90 auto result_ty = result_tuple_ty.getType(0).dyn_cast<ShapedType>();
91
92 func.eraseBody();
93 func.addEntryBlock();
94 ImplicitLocOpBuilder builder(func.getLoc(), func.getBody());
95 llvm::SmallVector<int32_t> result_shape_i32;
96 auto result_shape = result_ty.getShape();
97 for (auto element : result_shape) {
98 result_shape_i32.push_back(static_cast<int32_t>(element));
99 }
100 auto result_shape_attr = builder.getI32TensorAttr(result_shape_i32);
101 Value result_shape_tensor =
102 builder.create<mhlo::ConstantOp>(result_shape_attr);
103 auto custom_code =
104 IsJaxRandomUniform(func) ? "RandomUniform" : "RandomStandardNormal";
105
106 llvm::SmallVector<Type> result_ty_vec({result_ty});
107 llvm::SmallVector<Value> result_shape_tensor_vec({result_shape_tensor});
108 auto attr = CustomOption(&builder, "");
109 Value random_result =
110 builder
111 .create<TFL::CustomOp>(TypeRange(result_ty_vec),
112 ValueRange(result_shape_tensor_vec),
113 custom_code, attr)
114 .getResult(0);
115 Value tulple_result = builder.create<mhlo::TupleOp>(random_result);
116 builder.create<mlir::func::ReturnOp>(tulple_result);
117 }
118 } // namespace
119
CreateLegalizeJaxRandomPass()120 std::unique_ptr<OperationPass<func::FuncOp>> CreateLegalizeJaxRandomPass() {
121 return std::make_unique<LegalizeJaxRandomPass>();
122 }
123
124 } // namespace TFL
125 } // namespace mlir
126