1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/platform/cloud/google_auth_provider.h"

#include <stdlib.h>

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/cloud/http_request_fake.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
25
26 namespace tensorflow {
27
28 namespace {
29
TestData()30 string TestData() {
31 return io::JoinPath("tensorflow", "core", "platform", "cloud", "testdata");
32 }
33
34 class FakeEnv : public EnvWrapper {
35 public:
FakeEnv()36 FakeEnv() : EnvWrapper(Env::Default()) {}
37
NowSeconds() const38 uint64 NowSeconds() const override { return now; }
39 uint64 now = 10000;
40 };
41
42 class FakeOAuthClient : public OAuthClient {
43 public:
GetTokenFromServiceAccountJson(Json::Value json,StringPiece oauth_server_uri,StringPiece scope,string * token,uint64 * expiration_timestamp_sec)44 Status GetTokenFromServiceAccountJson(
45 Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
46 string* token, uint64* expiration_timestamp_sec) override {
47 provided_credentials_json = json;
48 *token = return_token;
49 *expiration_timestamp_sec = return_expiration_timestamp;
50 return OkStatus();
51 }
52
53 /// Retrieves a bearer token using a refresh token.
GetTokenFromRefreshTokenJson(Json::Value json,StringPiece oauth_server_uri,string * token,uint64 * expiration_timestamp_sec)54 Status GetTokenFromRefreshTokenJson(
55 Json::Value json, StringPiece oauth_server_uri, string* token,
56 uint64* expiration_timestamp_sec) override {
57 provided_credentials_json = json;
58 *token = return_token;
59 *expiration_timestamp_sec = return_expiration_timestamp;
60 return OkStatus();
61 }
62
63 string return_token;
64 uint64 return_expiration_timestamp;
65 Json::Value provided_credentials_json;
66 };
67
68 } // namespace
69
70 class GoogleAuthProviderTest : public ::testing::Test {
71 protected:
SetUp()72 void SetUp() override { ClearEnvVars(); }
73
TearDown()74 void TearDown() override { ClearEnvVars(); }
75
ClearEnvVars()76 void ClearEnvVars() {
77 unsetenv("CLOUDSDK_CONFIG");
78 unsetenv("GOOGLE_APPLICATION_CREDENTIALS");
79 unsetenv("GOOGLE_AUTH_TOKEN_FOR_TESTING");
80 unsetenv("NO_GCE_CHECK");
81 }
82 };
83
// Service-account credentials named by GOOGLE_APPLICATION_CREDENTIALS are used
// in preference to CLOUDSDK_CONFIG. The resulting token is reused while fresh
// and re-fetched once it is about to expire.
TEST_F(GoogleAuthProviderTest, EnvironmentVariable_Caching) {
  setenv("GOOGLE_APPLICATION_CREDENTIALS",
         GetDataDependencyFilepath(
             io::JoinPath(TestData(), "service_account_credentials.json"))
             .c_str(),
         1);
  setenv("CLOUDSDK_CONFIG", GetDataDependencyFilepath(TestData()).c_str(),
         1);  // Will not be used.

  // The provider takes ownership of the fake. Keep a non-owning pointer so the
  // test can reconfigure the fake after handing it over.
  auto owned_oauth_client = std::make_unique<FakeOAuthClient>();
  FakeOAuthClient* oauth_client = owned_oauth_client.get();
  std::vector<HttpRequest*> requests;

  FakeEnv env;

  std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
      std::make_shared<FakeHttpRequestFactory>(&requests);
  auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
      fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
  GoogleAuthProvider provider(std::move(owned_oauth_client), metadataClient,
                              &env);
  oauth_client->return_token = "fake-token";
  oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600;

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("fake-token", token);
  EXPECT_EQ("fake_key_id",
            oauth_client->provided_credentials_json.get("private_key_id", "")
                .asString());

  // Check that the token is re-used if not expired.
  oauth_client->return_token = "new-fake-token";
  env.now += 3000;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("fake-token", token);

  // Check that the token is re-generated when almost expired.
  env.now += 598;  // 2 seconds before expiration
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("new-fake-token", token);
}
125
// With only CLOUDSDK_CONFIG set, the provider should pick up the user
// credentials found under that directory and pass their refresh token to the
// OAuth client.
TEST_F(GoogleAuthProviderTest, GCloudRefreshToken) {
  setenv("CLOUDSDK_CONFIG", GetDataDependencyFilepath(TestData()).c_str(), 1);

  // The provider takes ownership of the fake. Keep a non-owning pointer so the
  // test can configure and inspect it afterwards.
  auto owned_oauth_client = std::make_unique<FakeOAuthClient>();
  FakeOAuthClient* oauth_client = owned_oauth_client.get();
  std::vector<HttpRequest*> requests;

  FakeEnv env;
  std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
      std::make_shared<FakeHttpRequestFactory>(&requests);
  auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
      fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));

  GoogleAuthProvider provider(std::move(owned_oauth_client), metadataClient,
                              &env);
  oauth_client->return_token = "fake-token";
  oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600;

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("fake-token", token);
  EXPECT_EQ("fake-refresh-token",
            oauth_client->provided_credentials_json.get("refresh_token", "")
                .asString());
}
150
// With no local credentials, tokens come from the GCE metadata server. A token
// is reused while fresh and re-fetched shortly before it expires; the refetch
// survives one transient 503.
TEST_F(GoogleAuthProviderTest, RunningOnGCE) {
  std::vector<HttpRequest*> requests(
      {new FakeHttpRequest(
           "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
           "/default/token\n"
           "Header Metadata-Flavor: Google\n",
           R"(
          {
            "access_token":"fake-gce-token",
            "expires_in": 3920,
            "token_type":"Bearer"
          })"),
       // The first token refresh request fails and will be retried.
       new FakeHttpRequest(
           "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
           "/default/token\n"
           "Header Metadata-Flavor: Google\n",
           "", errors::Unavailable("503"), 503),
       new FakeHttpRequest(
           "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
           "/default/token\n"
           "Header Metadata-Flavor: Google\n",
           R"(
          {
            "access_token":"new-fake-gce-token",
            "expires_in": 3920,
            "token_type":"Bearer"
          })")});

  FakeEnv env;
  std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
      std::make_shared<FakeHttpRequestFactory>(&requests);
  auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
      fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
  // The OAuth client is never consulted on the GCE path, so nothing needs to
  // keep a handle to it.
  GoogleAuthProvider provider(std::make_unique<FakeOAuthClient>(),
                              metadataClient, &env);

  // Expiry of the first token: fetched at the current fake time with the
  // "expires_in" of 3920 seconds served above.
  const uint64 first_token_expiration = env.NowSeconds() + 3920;

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("fake-gce-token", token);

  // Check that the token is re-used if not expired (220 seconds remain).
  env.now = first_token_expiration - 220;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("fake-gce-token", token);

  // Check that the token is re-generated when almost expired. Previously the
  // clock was advanced past the expiry, so this path was never exercised.
  env.now = first_token_expiration - 2;  // 2 seconds before expiration
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("new-fake-gce-token", token);
}
203
// GOOGLE_AUTH_TOKEN_FOR_TESTING short-circuits token lookup: its value is
// returned directly. No metadata-server requests are queued for this test.
TEST_F(GoogleAuthProviderTest, OverrideForTesting) {
  setenv("GOOGLE_AUTH_TOKEN_FOR_TESTING", "tokenForTesting", 1);

  std::vector<HttpRequest*> empty_requests;
  FakeEnv env;
  std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
      std::make_shared<FakeHttpRequestFactory>(&empty_requests);
  auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
      fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
  // The fake OAuth client is not inspected here, so it is handed straight to
  // the provider with no raw pointer kept around.
  GoogleAuthProvider provider(std::make_unique<FakeOAuthClient>(),
                              metadataClient, &env);

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("tokenForTesting", token);
}
221
// No credentials are configured and the metadata server answers 404. GetToken
// still returns OK, but with an empty token.
TEST_F(GoogleAuthProviderTest, NothingAvailable) {
  std::vector<HttpRequest*> requests({new FakeHttpRequest(
      "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
      "/default/token\n"
      "Header Metadata-Flavor: Google\n",
      "", errors::NotFound("404"), 404)});

  FakeEnv env;
  std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
      std::make_shared<FakeHttpRequestFactory>(&requests);
  auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
      fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
  // The fake OAuth client is not inspected here, so it is handed straight to
  // the provider with no raw pointer kept around.
  GoogleAuthProvider provider(std::make_unique<FakeOAuthClient>(),
                              metadataClient, &env);

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("", token);
}
243
// NO_GCE_CHECK suppresses the metadata-server lookup entirely. The empty token
// produced in that mode must not be cached for long.
TEST_F(GoogleAuthProviderTest, NoGceCheckEnvironmentVariable) {
  setenv("NO_GCE_CHECK", "True", 1);

  FakeEnv env;
  // If the env var above isn't respected, attempting to fetch a token
  // from GCE will segfault (as the metadata client is null).
  // The fake OAuth client is not inspected, so no raw pointer is kept.
  GoogleAuthProvider provider(std::make_unique<FakeOAuthClient>(), nullptr,
                              &env);

  string token;
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("", token);

  // We confirm that our env var is case insensitive: "true" must be honored
  // just like "True".
  setenv("NO_GCE_CHECK", "true", 1);
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("", token);

  // We also want to confirm that our empty token has a short expiration set: we
  // now set a testing token, and confirm that it's returned instead of our
  // empty token.
  setenv("GOOGLE_AUTH_TOKEN_FOR_TESTING", "newToken", 1);
  TF_EXPECT_OK(provider.GetToken(&token));
  EXPECT_EQ("newToken", token);
}
270
271 } // namespace tensorflow
272