xref: /aosp_15_r20/external/tensorflow/tensorflow/core/platform/cloud/google_auth_provider.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/platform/cloud/google_auth_provider.h"
#ifndef _WIN32
#include <pwd.h>
#include <unistd.h>
#else
#include <sys/types.h>
#endif
#include <fstream>
#include <memory>
#include <utility>

#include "absl/strings/match.h"
#include "json/json.h"
#include "tensorflow/core/platform/base64.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/retrying_utils.h"
33 
34 namespace tensorflow {
35 
36 namespace {
37 
// ---- Environment variables consulted by GoogleAuthProvider ----

// Names the file that holds local Application Default Credentials.
constexpr char kGoogleApplicationCredentials[] =
    "GOOGLE_APPLICATION_CREDENTIALS";

// When set, its value is used directly as the bearer token (testing hook).
constexpr char kGoogleAuthTokenForTesting[] = "GOOGLE_AUTH_TOKEN_FOR_TESTING";

// When set, replaces the default '~/.config/gcloud' configuration directory.
constexpr char kCloudSdkConfig[] = "CLOUDSDK_CONFIG";

// Setting this to 'true' (case-insensitive) skips querying the GCE metadata
// service for credentials.
constexpr char kNoGceCheck[] = "NO_GCE_CHECK";

// ---- Location of the gcloud credentials file ----

// Default gcloud configuration folder, relative to $HOME.
constexpr char kGCloudConfigFolder[] = ".config/gcloud/";

// Well-known credentials file (produced by 'gcloud auth login') inside the
// gcloud configuration folder.
constexpr char kWellKnownCredentialsFile[] =
    "application_default_credentials.json";

// ---- OAuth token endpoints and scope ----

// Token endpoint used when the credentials file carries a refresh token.
constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token";

// Token endpoint used when the credentials file carries a private key
// (service-account credentials).
constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token";

// Scope requested when exchanging service-account credentials for a token.
constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform";

// ---- Google Compute Engine ----

// Metadata path handed to ComputeEngineMetadataClient::GetMetadata to fetch
// the default service account's token (presumably resolved against the
// metadata server's base URL by that client — verify there).
constexpr char kGceTokenPath[] = "instance/service-accounts/default/token";

// ---- Token caching ----

// A cached token is reused only while it is more than this many seconds away
// from its expiration timestamp.
constexpr int kExpirationTimeMarginSec = 60;
77 
/// Reports whether `filename` can be opened for reading as an input stream.
///
/// Only an open attempt is made; no bytes are read.
/// NOTE(review): on some platforms std::ifstream can open a directory, so a
/// directory path may also yield true — confirm whether callers need a
/// stricter regular-file check.
bool IsFile(const string& filename) {
  const std::ifstream probe(filename);
  return probe.good();
}
83 
84 /// Returns the credentials file name from the env variable.
GetEnvironmentVariableFileName(string * filename)85 Status GetEnvironmentVariableFileName(string* filename) {
86   if (!filename) {
87     return errors::FailedPrecondition("'filename' cannot be nullptr.");
88   }
89   const char* result = std::getenv(kGoogleApplicationCredentials);
90   if (!result || !IsFile(result)) {
91     return errors::NotFound(strings::StrCat("$", kGoogleApplicationCredentials,
92                                             " is not set or corrupt."));
93   }
94   *filename = result;
95   return OkStatus();
96 }
97 
98 /// Returns the well known file produced by command 'gcloud auth login'.
GetWellKnownFileName(string * filename)99 Status GetWellKnownFileName(string* filename) {
100   if (!filename) {
101     return errors::FailedPrecondition("'filename' cannot be nullptr.");
102   }
103   string config_dir;
104   const char* config_dir_override = std::getenv(kCloudSdkConfig);
105   if (config_dir_override) {
106     config_dir = config_dir_override;
107   } else {
108     // Determine the home dir path.
109     const char* home_dir = std::getenv("HOME");
110     if (!home_dir) {
111       return errors::FailedPrecondition("Could not read $HOME.");
112     }
113     config_dir = io::JoinPath(home_dir, kGCloudConfigFolder);
114   }
115   auto result = io::JoinPath(config_dir, kWellKnownCredentialsFile);
116   if (!IsFile(result)) {
117     return errors::NotFound(
118         "Could not find the credentials file in the standard gcloud location.");
119   }
120   *filename = result;
121   return OkStatus();
122 }
123 
124 }  // namespace
125 
GoogleAuthProvider(std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client)126 GoogleAuthProvider::GoogleAuthProvider(
127     std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client)
128     : GoogleAuthProvider(std::unique_ptr<OAuthClient>(new OAuthClient()),
129                          std::move(compute_engine_metadata_client),
130                          Env::Default()) {}
131 
GoogleAuthProvider(std::unique_ptr<OAuthClient> oauth_client,std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client,Env * env)132 GoogleAuthProvider::GoogleAuthProvider(
133     std::unique_ptr<OAuthClient> oauth_client,
134     std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client,
135     Env* env)
136     : oauth_client_(std::move(oauth_client)),
137       compute_engine_metadata_client_(
138           std::move(compute_engine_metadata_client)),
139       env_(env) {}
140 
GetToken(string * t)141 Status GoogleAuthProvider::GetToken(string* t) {
142   mutex_lock lock(mu_);
143   const uint64 now_sec = env_->NowSeconds();
144 
145   if (now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
146     *t = current_token_;
147     return OkStatus();
148   }
149 
150   if (GetTokenForTesting().ok()) {
151     *t = current_token_;
152     return OkStatus();
153   }
154 
155   auto token_from_files_status = GetTokenFromFiles();
156   if (token_from_files_status.ok()) {
157     *t = current_token_;
158     return OkStatus();
159   }
160 
161   char* no_gce_check_var = std::getenv(kNoGceCheck);
162   bool skip_gce_check = no_gce_check_var != nullptr &&
163                         absl::EqualsIgnoreCase(no_gce_check_var, "true");
164   Status token_from_gce_status;
165   if (skip_gce_check) {
166     token_from_gce_status =
167         Status(error::CANCELLED,
168                strings::StrCat("GCE check skipped due to presence of $",
169                                kNoGceCheck, " environment variable."));
170   } else {
171     token_from_gce_status = GetTokenFromGce();
172   }
173 
174   if (token_from_gce_status.ok()) {
175     *t = current_token_;
176     return OkStatus();
177   }
178 
179   if (skip_gce_check) {
180     LOG(INFO)
181         << "Attempting an empty bearer token since no token was retrieved "
182         << "from files, and GCE metadata check was skipped.";
183   } else {
184     LOG(WARNING)
185         << "All attempts to get a Google authentication bearer token failed, "
186         << "returning an empty token. Retrieving token from files failed with "
187            "\""
188         << token_from_files_status.ToString() << "\"."
189         << " Retrieving token from GCE failed with \""
190         << token_from_gce_status.ToString() << "\".";
191   }
192 
193   // Public objects can still be accessed with an empty bearer token,
194   // so return an empty token instead of failing.
195   *t = "";
196 
197   // We only want to keep returning our empty token if we've tried and failed
198   // the (potentially slow) task of detecting GCE.
199   if (skip_gce_check) {
200     expiration_timestamp_sec_ = 0;
201   } else {
202     expiration_timestamp_sec_ = UINT64_MAX;
203   }
204   current_token_ = "";
205 
206   return OkStatus();
207 }
208 
GetTokenFromFiles()209 Status GoogleAuthProvider::GetTokenFromFiles() {
210   string credentials_filename;
211   if (!GetEnvironmentVariableFileName(&credentials_filename).ok() &&
212       !GetWellKnownFileName(&credentials_filename).ok()) {
213     return errors::NotFound("Could not locate the credentials file.");
214   }
215 
216   Json::Value json;
217   Json::Reader reader;
218   std::ifstream credentials_fstream(credentials_filename);
219   if (!reader.parse(credentials_fstream, json)) {
220     return errors::FailedPrecondition(
221         "Couldn't parse the JSON credentials file.");
222   }
223   if (json.isMember("refresh_token")) {
224     TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson(
225         json, kOAuthV3Url, &current_token_, &expiration_timestamp_sec_));
226   } else if (json.isMember("private_key")) {
227     TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson(
228         json, kOAuthV4Url, kOAuthScope, &current_token_,
229         &expiration_timestamp_sec_));
230   } else {
231     return errors::FailedPrecondition(
232         "Unexpected content of the JSON credentials file.");
233   }
234   return OkStatus();
235 }
236 
GetTokenFromGce()237 Status GoogleAuthProvider::GetTokenFromGce() {
238   std::vector<char> response_buffer;
239   const uint64 request_timestamp_sec = env_->NowSeconds();
240 
241   TF_RETURN_IF_ERROR(compute_engine_metadata_client_->GetMetadata(
242       kGceTokenPath, &response_buffer));
243   StringPiece response =
244       StringPiece(&response_buffer[0], response_buffer.size());
245 
246   TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse(
247       response, request_timestamp_sec, &current_token_,
248       &expiration_timestamp_sec_));
249 
250   return OkStatus();
251 }
252 
GetTokenForTesting()253 Status GoogleAuthProvider::GetTokenForTesting() {
254   const char* token = std::getenv(kGoogleAuthTokenForTesting);
255   if (!token) {
256     return errors::NotFound("The env variable for testing was not set.");
257   }
258   expiration_timestamp_sec_ = UINT64_MAX;
259   current_token_ = token;
260   return OkStatus();
261 }
262 
263 }  // namespace tensorflow
264