xref: /aosp_15_r20/external/tensorflow/tensorflow/core/platform/cloud/google_auth_provider_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
17 
18 #include <stdlib.h>
19 
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/platform/cloud/http_request_fake.h"
22 #include "tensorflow/core/platform/path.h"
23 #include "tensorflow/core/platform/resource_loader.h"
24 #include "tensorflow/core/platform/test.h"
25 
26 namespace tensorflow {
27 
28 namespace {
29 
TestData()30 string TestData() {
31   return io::JoinPath("tensorflow", "core", "platform", "cloud", "testdata");
32 }
33 
34 class FakeEnv : public EnvWrapper {
35  public:
FakeEnv()36   FakeEnv() : EnvWrapper(Env::Default()) {}
37 
NowSeconds() const38   uint64 NowSeconds() const override { return now; }
39   uint64 now = 10000;
40 };
41 
42 class FakeOAuthClient : public OAuthClient {
43  public:
GetTokenFromServiceAccountJson(Json::Value json,StringPiece oauth_server_uri,StringPiece scope,string * token,uint64 * expiration_timestamp_sec)44   Status GetTokenFromServiceAccountJson(
45       Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
46       string* token, uint64* expiration_timestamp_sec) override {
47     provided_credentials_json = json;
48     *token = return_token;
49     *expiration_timestamp_sec = return_expiration_timestamp;
50     return OkStatus();
51   }
52 
53   /// Retrieves a bearer token using a refresh token.
GetTokenFromRefreshTokenJson(Json::Value json,StringPiece oauth_server_uri,string * token,uint64 * expiration_timestamp_sec)54   Status GetTokenFromRefreshTokenJson(
55       Json::Value json, StringPiece oauth_server_uri, string* token,
56       uint64* expiration_timestamp_sec) override {
57     provided_credentials_json = json;
58     *token = return_token;
59     *expiration_timestamp_sec = return_expiration_timestamp;
60     return OkStatus();
61   }
62 
63   string return_token;
64   uint64 return_expiration_timestamp;
65   Json::Value provided_credentials_json;
66 };
67 
68 }  // namespace
69 
70 class GoogleAuthProviderTest : public ::testing::Test {
71  protected:
SetUp()72   void SetUp() override { ClearEnvVars(); }
73 
TearDown()74   void TearDown() override { ClearEnvVars(); }
75 
ClearEnvVars()76   void ClearEnvVars() {
77     unsetenv("CLOUDSDK_CONFIG");
78     unsetenv("GOOGLE_APPLICATION_CREDENTIALS");
79     unsetenv("GOOGLE_AUTH_TOKEN_FOR_TESTING");
80     unsetenv("NO_GCE_CHECK");
81   }
82 };
83 
// Credentials named by GOOGLE_APPLICATION_CREDENTIALS are exchanged for a
// token, which is then reused until it is about to expire.
TEST_F(GoogleAuthProviderTest, EnvironmentVariable_Caching) {
  const string credentials_path = GetDataDependencyFilepath(
      io::JoinPath(TestData(), "service_account_credentials.json"));
  setenv("GOOGLE_APPLICATION_CREDENTIALS", credentials_path.c_str(), 1);
  // CLOUDSDK_CONFIG is also set, but it will not be used.
  const string cloudsdk_config_dir = GetDataDependencyFilepath(TestData());
  setenv("CLOUDSDK_CONFIG", cloudsdk_config_dir.c_str(), 1);

  // The provider takes ownership; keep a non-owning handle to configure it.
  auto owned_oauth_client = std::make_unique<FakeOAuthClient>();
  FakeOAuthClient* const oauth_client = owned_oauth_client.get();
  std::vector<HttpRequest*> requests;  // Empty: no HTTP exchanges scripted.

  FakeEnv env;

  std::shared_ptr<HttpRequest::Factory> http_request_factory =
      std::make_shared<FakeHttpRequestFactory>(&requests);
  auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
      http_request_factory, RetryConfig(0 /* init_delay_time_us */));
  GoogleAuthProvider provider(std::move(owned_oauth_client), metadata_client,
                              &env);
  oauth_client->return_token = "fake-token";
  oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600;

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("fake-token", token);
  EXPECT_EQ("fake_key_id",
            oauth_client->provided_credentials_json.get("private_key_id", "")
                .asString());

  // 3000s after issue the cached token is still returned, not the new one.
  oauth_client->return_token = "new-fake-token";
  env.now += 3000;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("fake-token", token);

  // 3598s after issue (2 seconds before expiration) a fresh token is fetched.
  env.now += 598;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("new-fake-token", token);
}
125 
// With only CLOUDSDK_CONFIG set, the gcloud refresh-token credentials found
// there are used to obtain a token.
TEST_F(GoogleAuthProviderTest, GCloudRefreshToken) {
  const string cloudsdk_config_dir = GetDataDependencyFilepath(TestData());
  setenv("CLOUDSDK_CONFIG", cloudsdk_config_dir.c_str(), 1);

  // The provider takes ownership; keep a non-owning handle to configure it.
  auto owned_oauth_client = std::make_unique<FakeOAuthClient>();
  FakeOAuthClient* const oauth_client = owned_oauth_client.get();
  std::vector<HttpRequest*> requests;  // Empty: no HTTP exchanges scripted.

  FakeEnv env;
  std::shared_ptr<HttpRequest::Factory> http_request_factory =
      std::make_shared<FakeHttpRequestFactory>(&requests);
  auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
      http_request_factory, RetryConfig(0 /* init_delay_time_us */));

  GoogleAuthProvider provider(std::move(owned_oauth_client), metadata_client,
                              &env);
  oauth_client->return_token = "fake-token";
  oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600;

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("fake-token", token);
  EXPECT_EQ("fake-refresh-token",
            oauth_client->provided_credentials_json.get("refresh_token", "")
                .asString());
}
150 
// With no credential environment variables set, the token is fetched from the
// GCE metadata server, cached, and refreshed (with one retried 503) when it
// is about to expire.
TEST_F(GoogleAuthProviderTest, RunningOnGCE) {
  // The provider takes ownership; the fake OAuth client is not configured here.
  auto owned_oauth_client = std::make_unique<FakeOAuthClient>();
  std::vector<HttpRequest*> requests(
      {new FakeHttpRequest(
           "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
           "/default/token\n"
           "Header Metadata-Flavor: Google\n",
           R"(
          {
            "access_token":"fake-gce-token",
            "expires_in": 3920,
            "token_type":"Bearer"
          })"),
       // The first token refresh request fails and will be retried.
       new FakeHttpRequest(
           "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
           "/default/token\n"
           "Header Metadata-Flavor: Google\n",
           "", errors::Unavailable("503"), 503),
       new FakeHttpRequest(
           "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
           "/default/token\n"
           "Header Metadata-Flavor: Google\n",
           R"(
              {
                "access_token":"new-fake-gce-token",
                "expires_in": 3920,
                "token_type":"Bearer"
              })")});

  FakeEnv env;
  std::shared_ptr<HttpRequest::Factory> http_request_factory =
      std::make_shared<FakeHttpRequestFactory>(&requests);
  auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
      http_request_factory, RetryConfig(0 /* init_delay_time_us */));
  GoogleAuthProvider provider(std::move(owned_oauth_client), metadata_client,
                              &env);

  // First fetch happens at t=10000 with expires_in=3920, so (assuming expiry
  // is fetch time + expires_in) the token expires at t=13920.
  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("fake-gce-token", token);

  // Check that the token is re-used if not expired (t=13700, 220s left).
  env.now += 3700;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("fake-gce-token", token);

  // Check that the token is re-generated when almost expired: t=13918 is
  // 2 seconds before expiration, the same margin EnvironmentVariable_Caching
  // uses. (Advancing further would test an already-expired token instead.)
  env.now += 218;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("new-fake-gce-token", token);
}
203 
// GOOGLE_AUTH_TOKEN_FOR_TESTING, when set, is returned verbatim as the token.
TEST_F(GoogleAuthProviderTest, OverrideForTesting) {
  setenv("GOOGLE_AUTH_TOKEN_FOR_TESTING", "tokenForTesting", 1);

  auto owned_oauth_client = std::make_unique<FakeOAuthClient>();
  std::vector<HttpRequest*> no_requests;  // Empty: no HTTP exchanges scripted.
  FakeEnv env;
  std::shared_ptr<HttpRequest::Factory> http_request_factory =
      std::make_shared<FakeHttpRequestFactory>(&no_requests);
  auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
      http_request_factory, RetryConfig(0 /* init_delay_time_us */));
  GoogleAuthProvider provider(std::move(owned_oauth_client), metadata_client,
                              &env);

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("tokenForTesting", token);
}
221 
// When no credentials are configured and the metadata server answers 404,
// GetToken still succeeds but yields an empty token.
TEST_F(GoogleAuthProviderTest, NothingAvailable) {
  auto owned_oauth_client = std::make_unique<FakeOAuthClient>();

  std::vector<HttpRequest*> requests({new FakeHttpRequest(
      "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
      "/default/token\n"
      "Header Metadata-Flavor: Google\n",
      "", errors::NotFound("404"), 404)});

  FakeEnv env;
  std::shared_ptr<HttpRequest::Factory> http_request_factory =
      std::make_shared<FakeHttpRequestFactory>(&requests);
  auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
      http_request_factory, RetryConfig(0 /* init_delay_time_us */));
  GoogleAuthProvider provider(std::move(owned_oauth_client), metadata_client,
                              &env);

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("", token);
}
243 
// NO_GCE_CHECK (case-insensitive "true") suppresses the metadata-server
// lookup, and the resulting empty token is short-lived.
TEST_F(GoogleAuthProviderTest, NoGceCheckEnvironmentVariable) {
  setenv("NO_GCE_CHECK", "True", 1);
  auto owned_oauth_client = std::make_unique<FakeOAuthClient>();

  FakeEnv env;
  // The metadata client is deliberately null: if NO_GCE_CHECK were ignored,
  // fetching a token from GCE would dereference it and crash.
  GoogleAuthProvider provider(std::move(owned_oauth_client), nullptr, &env);

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("", token);

  // The variable's value is matched case-insensitively.
  setenv("NO_GCE_CHECK", "true", 1);
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("", token);

  // The empty token must not stay cached: once a testing token appears, the
  // very next call (clock unchanged) returns it instead.
  setenv("GOOGLE_AUTH_TOKEN_FOR_TESTING", "newToken", 1);
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("newToken", token);
}
270 
271 }  // namespace tensorflow
272