xref: /aosp_15_r20/external/libtextclassifier/native/utils/regex-match.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "utils/regex-match.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <memory>
20*993b0882SAndroid Build Coastguard Worker 
21*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
22*993b0882SAndroid Build Coastguard Worker 
23*993b0882SAndroid Build Coastguard Worker #ifndef TC3_DISABLE_LUA
24*993b0882SAndroid Build Coastguard Worker #include "utils/lua-utils.h"
25*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
26*993b0882SAndroid Build Coastguard Worker extern "C" {
27*993b0882SAndroid Build Coastguard Worker #endif
28*993b0882SAndroid Build Coastguard Worker #include "lauxlib.h"
29*993b0882SAndroid Build Coastguard Worker #include "lualib.h"
30*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
31*993b0882SAndroid Build Coastguard Worker }
32*993b0882SAndroid Build Coastguard Worker #endif
33*993b0882SAndroid Build Coastguard Worker #endif
34*993b0882SAndroid Build Coastguard Worker 
35*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
36*993b0882SAndroid Build Coastguard Worker namespace {
37*993b0882SAndroid Build Coastguard Worker 
38*993b0882SAndroid Build Coastguard Worker #ifndef TC3_DISABLE_LUA
39*993b0882SAndroid Build Coastguard Worker // Provide a lua environment for running regex match post verification.
40*993b0882SAndroid Build Coastguard Worker // It sets up and exposes the match data as well as the context.
41*993b0882SAndroid Build Coastguard Worker class LuaVerifier : public LuaEnvironment {
42*993b0882SAndroid Build Coastguard Worker  public:
43*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<LuaVerifier> Create(
44*993b0882SAndroid Build Coastguard Worker       const std::string& context, const std::string& verifier_code,
45*993b0882SAndroid Build Coastguard Worker       const UniLib::RegexMatcher* matcher);
46*993b0882SAndroid Build Coastguard Worker 
47*993b0882SAndroid Build Coastguard Worker   bool Verify(bool* result);
48*993b0882SAndroid Build Coastguard Worker 
49*993b0882SAndroid Build Coastguard Worker  private:
LuaVerifier(const std::string & context,const std::string & verifier_code,const UniLib::RegexMatcher * matcher)50*993b0882SAndroid Build Coastguard Worker   explicit LuaVerifier(const std::string& context,
51*993b0882SAndroid Build Coastguard Worker                        const std::string& verifier_code,
52*993b0882SAndroid Build Coastguard Worker                        const UniLib::RegexMatcher* matcher)
53*993b0882SAndroid Build Coastguard Worker       : context_(context), verifier_code_(verifier_code), matcher_(matcher) {}
54*993b0882SAndroid Build Coastguard Worker   bool Initialize();
55*993b0882SAndroid Build Coastguard Worker 
56*993b0882SAndroid Build Coastguard Worker   // Provides details of a capturing group to lua.
57*993b0882SAndroid Build Coastguard Worker   int GetCapturingGroup();
58*993b0882SAndroid Build Coastguard Worker 
59*993b0882SAndroid Build Coastguard Worker   const std::string& context_;
60*993b0882SAndroid Build Coastguard Worker   const std::string& verifier_code_;
61*993b0882SAndroid Build Coastguard Worker   const UniLib::RegexMatcher* matcher_;
62*993b0882SAndroid Build Coastguard Worker };
63*993b0882SAndroid Build Coastguard Worker 
Initialize()64*993b0882SAndroid Build Coastguard Worker bool LuaVerifier::Initialize() {
65*993b0882SAndroid Build Coastguard Worker   // Run protected to not lua panic in case of setup failure.
66*993b0882SAndroid Build Coastguard Worker   return RunProtected([this] {
67*993b0882SAndroid Build Coastguard Worker            LoadDefaultLibraries();
68*993b0882SAndroid Build Coastguard Worker 
69*993b0882SAndroid Build Coastguard Worker            // Expose context of the match as `context` global variable.
70*993b0882SAndroid Build Coastguard Worker            PushString(context_);
71*993b0882SAndroid Build Coastguard Worker            lua_setglobal(state_, "context");
72*993b0882SAndroid Build Coastguard Worker 
73*993b0882SAndroid Build Coastguard Worker            // Expose match array as `match` global variable.
74*993b0882SAndroid Build Coastguard Worker            // Each entry `match[i]` exposes the ith capturing group as:
75*993b0882SAndroid Build Coastguard Worker            //   * `begin`: span start
76*993b0882SAndroid Build Coastguard Worker            //   * `end`: span end
77*993b0882SAndroid Build Coastguard Worker            //   * `text`: the text
78*993b0882SAndroid Build Coastguard Worker            PushLazyObject(&LuaVerifier::GetCapturingGroup);
79*993b0882SAndroid Build Coastguard Worker            lua_setglobal(state_, "match");
80*993b0882SAndroid Build Coastguard Worker            return LUA_OK;
81*993b0882SAndroid Build Coastguard Worker          }) == LUA_OK;
82*993b0882SAndroid Build Coastguard Worker }
83*993b0882SAndroid Build Coastguard Worker 
Create(const std::string & context,const std::string & verifier_code,const UniLib::RegexMatcher * matcher)84*993b0882SAndroid Build Coastguard Worker std::unique_ptr<LuaVerifier> LuaVerifier::Create(
85*993b0882SAndroid Build Coastguard Worker     const std::string& context, const std::string& verifier_code,
86*993b0882SAndroid Build Coastguard Worker     const UniLib::RegexMatcher* matcher) {
87*993b0882SAndroid Build Coastguard Worker   auto verifier = std::unique_ptr<LuaVerifier>(
88*993b0882SAndroid Build Coastguard Worker       new LuaVerifier(context, verifier_code, matcher));
89*993b0882SAndroid Build Coastguard Worker   if (!verifier->Initialize()) {
90*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not initialize lua environment.";
91*993b0882SAndroid Build Coastguard Worker     return nullptr;
92*993b0882SAndroid Build Coastguard Worker   }
93*993b0882SAndroid Build Coastguard Worker   return verifier;
94*993b0882SAndroid Build Coastguard Worker }
95*993b0882SAndroid Build Coastguard Worker 
GetCapturingGroup()96*993b0882SAndroid Build Coastguard Worker int LuaVerifier::GetCapturingGroup() {
97*993b0882SAndroid Build Coastguard Worker   if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
98*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
99*993b0882SAndroid Build Coastguard Worker                    << lua_type(state_, /*idx=*/-1);
100*993b0882SAndroid Build Coastguard Worker     lua_error(state_);
101*993b0882SAndroid Build Coastguard Worker     return 0;
102*993b0882SAndroid Build Coastguard Worker   }
103*993b0882SAndroid Build Coastguard Worker   const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
104*993b0882SAndroid Build Coastguard Worker   int status = UniLib::RegexMatcher::kNoError;
105*993b0882SAndroid Build Coastguard Worker   const CodepointSpan span = {matcher_->Start(group_id, &status),
106*993b0882SAndroid Build Coastguard Worker                               matcher_->End(group_id, &status)};
107*993b0882SAndroid Build Coastguard Worker   std::string text = matcher_->Group(group_id, &status).ToUTF8String();
108*993b0882SAndroid Build Coastguard Worker   if (status != UniLib::RegexMatcher::kNoError) {
109*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not extract span from capturing group.";
110*993b0882SAndroid Build Coastguard Worker     lua_error(state_);
111*993b0882SAndroid Build Coastguard Worker     return 0;
112*993b0882SAndroid Build Coastguard Worker   }
113*993b0882SAndroid Build Coastguard Worker   lua_newtable(state_);
114*993b0882SAndroid Build Coastguard Worker   lua_pushinteger(state_, span.first);
115*993b0882SAndroid Build Coastguard Worker   lua_setfield(state_, /*idx=*/-2, "begin");
116*993b0882SAndroid Build Coastguard Worker   lua_pushinteger(state_, span.second);
117*993b0882SAndroid Build Coastguard Worker   lua_setfield(state_, /*idx=*/-2, "end");
118*993b0882SAndroid Build Coastguard Worker   PushString(text);
119*993b0882SAndroid Build Coastguard Worker   lua_setfield(state_, /*idx=*/-2, "text");
120*993b0882SAndroid Build Coastguard Worker   return 1;
121*993b0882SAndroid Build Coastguard Worker }
122*993b0882SAndroid Build Coastguard Worker 
Verify(bool * result)123*993b0882SAndroid Build Coastguard Worker bool LuaVerifier::Verify(bool* result) {
124*993b0882SAndroid Build Coastguard Worker   if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
125*993b0882SAndroid Build Coastguard Worker                       /*name=*/nullptr) != LUA_OK) {
126*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not load verifier snippet.";
127*993b0882SAndroid Build Coastguard Worker     return false;
128*993b0882SAndroid Build Coastguard Worker   }
129*993b0882SAndroid Build Coastguard Worker 
130*993b0882SAndroid Build Coastguard Worker   if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
131*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not run verifier snippet.";
132*993b0882SAndroid Build Coastguard Worker     return false;
133*993b0882SAndroid Build Coastguard Worker   }
134*993b0882SAndroid Build Coastguard Worker 
135*993b0882SAndroid Build Coastguard Worker   if (RunProtected(
136*993b0882SAndroid Build Coastguard Worker           [this, result] {
137*993b0882SAndroid Build Coastguard Worker             if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
138*993b0882SAndroid Build Coastguard Worker               TC3_LOG(ERROR) << "Unexpected verification result type: "
139*993b0882SAndroid Build Coastguard Worker                              << lua_type(state_, /*idx=*/-1);
140*993b0882SAndroid Build Coastguard Worker               lua_error(state_);
141*993b0882SAndroid Build Coastguard Worker               return LUA_ERRRUN;
142*993b0882SAndroid Build Coastguard Worker             }
143*993b0882SAndroid Build Coastguard Worker             *result = lua_toboolean(state_, /*idx=*/-1);
144*993b0882SAndroid Build Coastguard Worker             return LUA_OK;
145*993b0882SAndroid Build Coastguard Worker           },
146*993b0882SAndroid Build Coastguard Worker           /*num_args=*/1) != LUA_OK) {
147*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not read lua result.";
148*993b0882SAndroid Build Coastguard Worker     return false;
149*993b0882SAndroid Build Coastguard Worker   }
150*993b0882SAndroid Build Coastguard Worker   return true;
151*993b0882SAndroid Build Coastguard Worker }
152*993b0882SAndroid Build Coastguard Worker #endif  // TC3_DISABLE_LUA
153*993b0882SAndroid Build Coastguard Worker 
154*993b0882SAndroid Build Coastguard Worker }  // namespace
155*993b0882SAndroid Build Coastguard Worker 
GetCapturingGroupText(const UniLib::RegexMatcher * matcher,const int group_id)156*993b0882SAndroid Build Coastguard Worker Optional<std::string> GetCapturingGroupText(const UniLib::RegexMatcher* matcher,
157*993b0882SAndroid Build Coastguard Worker                                             const int group_id) {
158*993b0882SAndroid Build Coastguard Worker   int status = UniLib::RegexMatcher::kNoError;
159*993b0882SAndroid Build Coastguard Worker   std::string group_text = matcher->Group(group_id, &status).ToUTF8String();
160*993b0882SAndroid Build Coastguard Worker   if (status != UniLib::RegexMatcher::kNoError || group_text.empty()) {
161*993b0882SAndroid Build Coastguard Worker     return Optional<std::string>();
162*993b0882SAndroid Build Coastguard Worker   }
163*993b0882SAndroid Build Coastguard Worker   return Optional<std::string>(group_text);
164*993b0882SAndroid Build Coastguard Worker }
165*993b0882SAndroid Build Coastguard Worker 
VerifyMatch(const std::string & context,const UniLib::RegexMatcher * matcher,const std::string & lua_verifier_code)166*993b0882SAndroid Build Coastguard Worker bool VerifyMatch(const std::string& context,
167*993b0882SAndroid Build Coastguard Worker                  const UniLib::RegexMatcher* matcher,
168*993b0882SAndroid Build Coastguard Worker                  const std::string& lua_verifier_code) {
169*993b0882SAndroid Build Coastguard Worker   bool status = false;
170*993b0882SAndroid Build Coastguard Worker #ifndef TC3_DISABLE_LUA
171*993b0882SAndroid Build Coastguard Worker   auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
172*993b0882SAndroid Build Coastguard Worker   if (verifier == nullptr) {
173*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not create verifier.";
174*993b0882SAndroid Build Coastguard Worker     return false;
175*993b0882SAndroid Build Coastguard Worker   }
176*993b0882SAndroid Build Coastguard Worker   if (!verifier->Verify(&status)) {
177*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not create verifier.";
178*993b0882SAndroid Build Coastguard Worker     return false;
179*993b0882SAndroid Build Coastguard Worker   }
180*993b0882SAndroid Build Coastguard Worker #endif  // TC3_DISABLE_LUA
181*993b0882SAndroid Build Coastguard Worker   return status;
182*993b0882SAndroid Build Coastguard Worker }
183*993b0882SAndroid Build Coastguard Worker 
184*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
185