#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Service registry for apitools."""

import collections
import logging
import re
import textwrap

from apitools.base.py import base_api
from apitools.gen import util

# We're a code generator. I don't care.
# pylint:disable=too-many-statements

# NOTE(review): several generated-code string literals below (the
# printer.Indent(indent=...) strings and the continuation lines of the
# generated docstrings / super().__init__ call) contain single spaces in this
# copy; they look like collapsed runs of spaces. Confirm against upstream
# before relying on the indentation of the generated code.

# Loose sanity check for a "type/subtype" MIME pattern (wildcards allowed).
# re.match only anchors at the start, so trailing text such as parameters is
# tolerated.
_MIME_PATTERN_RE = re.compile(r'(?i)[a-z0-9_*-]+/[a-z0-9_*-]+')


class ServiceRegistry(object):

    """Registry for service types.

    Collects, per service (resource) name, an ordered mapping of method name
    to base_api.ApiMethodInfo built from a discovery document, and emits
    either a Python client module (WriteFile) or proto service declarations
    (WriteProtoFile) from the collected information.
    """

    def __init__(self, client_info, message_registry,
                 names, root_package, base_files_package,
                 unelidable_request_methods):
        """Create an empty registry.

        Args:
          client_info: Client metadata (package, version, scopes, URLs, ...)
              read when emitting the generated client.
          message_registry: Registry that synthesized request/response
              message schemas are added to and looked up from.
          names: Name-mangling helper (ClassName, CleanName, FieldName,
              MethodName, NormalizeRelativePath).
          root_package: Package the generated messages module is imported
              from; if falsy the module is imported without a 'from' prefix.
          base_files_package: Package the generated code imports base_api
              from.
          unelidable_request_methods: Method ids that always get a
              synthesized request type, even when the body type would do.
        """
        self.__client_info = client_info
        self.__package = client_info.package
        self.__names = names
        # Service name -> OrderedDict(method name -> ApiMethodInfo).
        self.__service_method_info_map = collections.OrderedDict()
        self.__message_registry = message_registry
        self.__root_package = root_package
        self.__base_files_package = base_files_package
        self.__unelidable_request_methods = unelidable_request_methods
        # Starts with the client's scopes; grows with per-method scopes.
        self.__all_scopes = set(self.__client_info.scopes)

    def Validate(self):
        """Validate the underlying message registry."""
        self.__message_registry.Validate()

    @property
    def scopes(self):
        """Sorted list of the client scopes plus all method-declared scopes."""
        return sorted(self.__all_scopes)

    def __GetServiceClassName(self, service_name):
        """Return the generated class name, e.g. 'FooService', for a service."""
        return self.__names.ClassName(
            '%sService' % self.__names.ClassName(service_name))

    def __PrintDocstring(self, printer, method_info, method_name, name):
        """Print a docstring for a service method.

        Args:
          printer: Output helper the generated lines are written to.
          method_info: base_api.ApiMethodInfo for the method.
          method_name: Generated method name.
          name: Service name, used in the fallback description.
        """
        if method_info.description:
            description = util.CleanDescription(method_info.description)
            # Make sure the summary (first) line ends with a period; any
            # remaining lines are kept unchanged.
            first_line, newline, remaining = description.partition('\n')
            if not first_line.endswith('.'):
                first_line = '%s.' % first_line
            description = '%s%s%s' % (first_line, newline, remaining)
        else:
            description = '%s method for the %s service.' % (method_name, name)
        with printer.CommentContext():
            printer('r"""%s' % description)
        printer()
        printer('Args:')
        printer(' request: (%s) input message', method_info.request_type_name)
        printer(' global_params: (StandardQueryParameters, default: None) '
                'global arguments')
        if method_info.upload_config:
            printer(' upload: (Upload, default: None) If present, upload')
            printer(' this stream with the request.')
        if method_info.supports_download:
            printer(
                ' download: (Download, default: None) If present, download')
            printer(' data from the request via this stream.')
        printer('Returns:')
        printer(' (%s) The response message.', method_info.response_type_name)
        printer('"""')

    def __WriteSingleService(
            self, printer, name, method_info_map, client_class_name):
        """Emit the generated service class for one resource.

        Args:
          printer: Output helper the generated lines are written to.
          name: Service (resource) name.
          method_info_map: Ordered map of method name -> ApiMethodInfo.
          client_class_name: Name of the enclosing generated client class;
              the service class is referenced as a nested class of it in the
              generated super() call.
        """
        printer()
        class_name = self.__GetServiceClassName(name)
        printer('class %s(base_api.BaseApiService):', class_name)
        with printer.Indent():
            printer('"""Service class for the %s resource."""', name)
            printer()
            printer('_NAME = %s', repr(name))

            # Print the configs for the methods first.
            printer()
            printer('def __init__(self, client):')
            with printer.Indent():
                printer('super(%s.%s, self).__init__(client)',
                        client_class_name, class_name)
                printer('self._upload_configs = {')
                with printer.Indent(indent=' '):
                    for method_name, method_info in method_info_map.items():
                        upload_config = method_info.upload_config
                        if upload_config is not None:
                            printer(
                                "'%s': base_api.ApiUploadInfo(", method_name)
                            with printer.Indent(indent=' '):
                                attrs = sorted(
                                    x.name for x in upload_config.all_fields())
                                for attr in attrs:
                                    printer('%s=%r,',
                                            attr, getattr(upload_config, attr))
                            printer('),')
                printer('}')

            # Now write each method in turn.
            for method_name, method_info in method_info_map.items():
                printer()
                params = ['self', 'request', 'global_params=None']
                if method_info.upload_config:
                    params.append('upload=None')
                if method_info.supports_download:
                    params.append('download=None')
                printer('def %s(%s):', method_name, ', '.join(params))
                with printer.Indent():
                    self.__PrintDocstring(
                        printer, method_info, method_name, name)
                    printer("config = self.GetMethodConfig('%s')", method_name)
                    upload_config = method_info.upload_config
                    if upload_config is not None:
                        printer("upload_config = self.GetUploadConfig('%s')",
                                method_name)
                    arg_lines = [
                        'config, request, global_params=global_params']
                    if method_info.upload_config:
                        arg_lines.append(
                            'upload=upload, upload_config=upload_config')
                    if method_info.supports_download:
                        arg_lines.append('download=download')
                    printer('return self._RunMethod(')
                    with printer.Indent(indent=' '):
                        for line in arg_lines[:-1]:
                            printer('%s,', line)
                        printer('%s)', arg_lines[-1])
                printer()
                # Attach the method config as a lazily-built attribute of the
                # generated method; upload_config and description are omitted
                # (upload configs are emitted in __init__ above).
                printer('{0}.method_config = lambda: base_api.ApiMethodInfo('
                        .format(method_name))
                with printer.Indent(indent=' '):
                    attrs = sorted(
                        x.name for x in method_info.all_fields())
                    for attr in attrs:
                        if attr in ('upload_config', 'description'):
                            continue
                        value = getattr(method_info, attr)
                        if value is not None:
                            printer('%s=%r,', attr, value)
                printer(')')

    def __WriteProtoServiceDeclaration(self, printer, name, method_info_map):
        """Write a single service declaration to a proto file.

        Each rpc is preceded by its description, wrapped into '//' comment
        lines that fit the printer's current width.
        """
        printer()
        printer('service %s {', self.__GetServiceClassName(name))
        with printer.Indent():
            for method_name, method_info in method_info_map.items():
                for line in textwrap.wrap(method_info.description,
                                          printer.CalculateWidth() - 3):
                    printer('// %s', line)
                printer('rpc %s (%s) returns (%s);',
                        method_name,
                        method_info.request_type_name,
                        method_info.response_type_name)
        printer('}')

    def WriteProtoFile(self, printer):
        """Write the services in this registry to printer as a proto2 file."""
        self.Validate()
        client_info = self.__client_info
        printer('// Generated services for %s version %s.',
                client_info.package, client_info.version)
        printer()
        printer('syntax = "proto2";')
        printer('package %s;', self.__package)
        printer('import "%s";', client_info.messages_proto_file_name)
        printer()
        for name, method_info_map in self.__service_method_info_map.items():
            self.__WriteProtoServiceDeclaration(printer, name, method_info_map)

    def WriteFile(self, printer):
        """Write the generated Python client module for this registry.

        The module contains one client class (named by
        client_info.client_class_name) holding the client constants, an
        __init__ that instantiates every service, and one nested service
        class per registered service.
        """
        self.Validate()
        client_info = self.__client_info
        printer('"""Generated client library for %s version %s."""',
                client_info.package, client_info.version)
        printer('# NOTE: This file is autogenerated and should not be edited '
                'by hand.')
        printer('from %s import base_api', self.__base_files_package)
        if self.__root_package:
            import_prefix = 'from {0} '.format(self.__root_package)
        else:
            import_prefix = ''
        printer('%simport %s as messages', import_prefix,
                client_info.messages_rule_name)
        printer()
        printer()
        printer('class %s(base_api.BaseApiClient):',
                client_info.client_class_name)
        with printer.Indent():
            printer(
                '"""Generated client library for service %s version %s."""',
                client_info.package, client_info.version)
            printer()
            printer('MESSAGES_MODULE = messages')
            printer('BASE_URL = {0!r}'.format(client_info.base_url))
            printer('MTLS_BASE_URL = {0!r}'.format(client_info.mtls_base_url))
            printer()
            printer('_PACKAGE = {0!r}'.format(client_info.package))
            # Fall back to the userinfo.email scope when none is configured.
            printer('_SCOPES = {0!r}'.format(
                client_info.scopes or
                ['https://www.googleapis.com/auth/userinfo.email']))
            printer('_VERSION = {0!r}'.format(client_info.version))
            printer('_CLIENT_ID = {0!r}'.format(client_info.client_id))
            printer('_CLIENT_SECRET = {0!r}'.format(client_info.client_secret))
            printer('_USER_AGENT = {0!r}'.format(client_info.user_agent))
            printer('_CLIENT_CLASS_NAME = {0!r}'.format(
                client_info.client_class_name))
            printer('_URL_VERSION = {0!r}'.format(client_info.url_version))
            printer('_API_KEY = {0!r}'.format(client_info.api_key))
            printer()
            printer("def __init__(self, url='', credentials=None,")
            with printer.Indent(indent=' '):
                printer('get_credentials=True, http=None, model=None,')
                printer('log_request=False, log_response=False,')
                printer('credentials_args=None, default_global_params=None,')
                printer('additional_http_headers=None, '
                        'response_encoding=None):')
            with printer.Indent():
                printer('"""Create a new %s handle."""', client_info.package)
                printer('url = url or self.BASE_URL')
                printer(
                    'super(%s, self).__init__(', client_info.client_class_name)
                printer(' url, credentials=credentials,')
                printer(' get_credentials=get_credentials, http=http, '
                        'model=model,')
                printer(' log_request=log_request, '
                        'log_response=log_response,')
                printer(' credentials_args=credentials_args,')
                printer(' default_global_params=default_global_params,')
                printer(' additional_http_headers=additional_http_headers,')
                printer(' response_encoding=response_encoding)')
                for name in self.__service_method_info_map:
                    printer('self.%s = self.%s(self)',
                            name, self.__GetServiceClassName(name))
            for name, method_info in self.__service_method_info_map.items():
                self.__WriteSingleService(
                    printer, name, method_info, client_info.client_class_name)

    def __RegisterService(self, service_name, method_info_map):
        """Record method_info_map under service_name.

        Raises:
          ValueError: if service_name has already been registered.
        """
        if service_name in self.__service_method_info_map:
            raise ValueError(
                'Attempt to re-register descriptor %s' % service_name)
        self.__service_method_info_map[service_name] = method_info_map

    def __CreateRequestType(self, method_description, body_type=None):
        """Create a request type for this method.

        The synthesized message has one field per method parameter (those in
        'parameterOrder' first, then any remaining parameters) plus, when
        body_type is given, a field holding the request body. When body_type
        has no 'description', a default one is added to it in place.

        Args:
          method_description: Discovery-document description of the method.
          body_type: The method's 'request' entry (a '$ref' dict), or None.

        Returns:
          The name of the newly registered request message.

        Raises:
          ValueError: if a parameter has no 'type', or the body field name
              collides with a parameter field.
        """
        schema = {}
        schema['id'] = self.__names.ClassName('%sRequest' % (
            self.__names.ClassName(method_description['id'], separator='.'),))
        schema['type'] = 'object'
        schema['properties'] = collections.OrderedDict()
        parameters = method_description.get('parameters', {})
        if 'parameterOrder' not in method_description:
            ordered_parameters = list(parameters)
        else:
            ordered_parameters = method_description['parameterOrder'][:]
            for k in parameters:
                if k not in ordered_parameters:
                    ordered_parameters.append(k)
        for parameter_name in ordered_parameters:
            field_name = self.__names.CleanName(parameter_name)
            field = dict(parameters[parameter_name])
            if 'type' not in field:
                raise ValueError('No type found in parameter %s' % field)
            schema['properties'][field_name] = field
        if body_type is not None:
            body_field_name = self.__GetRequestField(
                method_description, body_type)
            if body_field_name in schema['properties']:
                raise ValueError('Failed to normalize request resource name')
            if 'description' not in body_type:
                body_type['description'] = (
                    'A %s resource to be passed as the request body.' % (
                        self.__GetRequestType(body_type),))
            schema['properties'][body_field_name] = body_type
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __CreateVoidResponseType(self, method_description):
        """Create and register an empty response type; return its name."""
        schema = {}
        method_name = self.__names.ClassName(
            method_description['id'], separator='.')
        schema['id'] = self.__names.ClassName('%sResponse' % method_name)
        schema['type'] = 'object'
        schema['description'] = 'An empty %s response.' % method_name
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __NeedRequestType(self, method_description, request_type):
        """Determine if this method needs a new request type created.

        A synthesized request type can be skipped (the body message is used
        directly) only when there is a body type, the method id is not listed
        as unelidable, and every method parameter is a path parameter whose
        cleaned name is already a field of the body message.
        """
        if not request_type:
            return True
        method_id = method_description.get('id', '')
        if method_id in self.__unelidable_request_methods:
            return True
        message = self.__message_registry.LookupDescriptorOrDie(request_type)
        if message is None:
            return True
        field_names = [x.name for x in message.fields]
        parameters = method_description.get('parameters', {})
        return not all(
            param_info.get('location') == 'path' and
            self.__names.CleanName(param_name) in field_names
            for param_name, param_info in parameters.items())

    def __MaxSizeToInt(self, max_size):
        """Convert max_size to an int.

        Accepts a plain byte count ('1024') or a count with a KB/MB/GB/TB
        suffix (case-insensitive), where each unit is a power of 1024.

        Raises:
          ValueError: if max_size cannot be parsed or the unit is unknown.
        """
        size_groups = re.match(r'(?P<size>\d+)(?P<unit>.B)?$', max_size)
        if size_groups is None:
            raise ValueError('Could not parse maxSize')
        size, unit = size_groups.group('size', 'unit')
        shift = 0
        if unit is not None:
            unit_dict = {'KB': 10, 'MB': 20, 'GB': 30, 'TB': 40}
            shift = unit_dict.get(unit.upper())
            if shift is None:
                raise ValueError('Unknown unit %s' % unit)
        return int(size) * (1 << shift)

    def __ComputeUploadConfig(self, media_upload_config, method_id):
        """Fill out the upload config for this method.

        Args:
          media_upload_config: The method's 'mediaUpload' entry; may be None
              when the entry is absent.
          method_id: Method id, used only in log messages.

        Returns:
          A base_api.ApiUploadInfo populated from media_upload_config.
        """
        # The caller uses .get('mediaUpload'), so this may be None.
        media_upload_config = media_upload_config or {}
        config = base_api.ApiUploadInfo()
        if 'maxSize' in media_upload_config:
            config.max_size = self.__MaxSizeToInt(
                media_upload_config['maxSize'])
        if 'accept' not in media_upload_config:
            logging.warning(
                'No accept types found for upload configuration in '
                'method %s, using */*', method_id)
        # The default must be a list: iterating the bare string '*/*' would
        # add the single characters '*', '/', '*' instead of one pattern.
        config.accept.extend([
            str(a) for a in media_upload_config.get('accept', ['*/*'])])

        for accept_pattern in config.accept:
            if not _MIME_PATTERN_RE.match(accept_pattern):
                logging.warning('Unexpected MIME type: %s', accept_pattern)
        protocols = media_upload_config.get('protocols', {})
        for protocol in ('simple', 'resumable'):
            media = protocols.get(protocol, {})
            for attr in ('multipart', 'path'):
                if attr in media:
                    setattr(config, '%s_%s' % (protocol, attr), media[attr])
        return config

    def __ComputeMethodInfo(self, method_description, request, response,
                            request_field):
        """Compute the base_api.ApiMethodInfo for this method.

        Side effect: any 'scopes' declared by the method are added to the
        registry-wide scope set.

        Raises:
          ValueError: if a parameter location is neither 'query' nor 'path'.
        """
        relative_path = self.__names.NormalizeRelativePath(
            ''.join((self.__client_info.base_path,
                     method_description['path'])))
        method_id = method_description['id']
        # Only required parameters are kept, in 'parameterOrder' order.
        ordered_params = []
        for param_name in method_description.get('parameterOrder', []):
            param_info = method_description['parameters'][param_name]
            if param_info.get('required', False):
                ordered_params.append(param_name)
        method_info = base_api.ApiMethodInfo(
            relative_path=relative_path,
            method_id=method_id,
            http_method=method_description['httpMethod'],
            description=util.CleanDescription(
                method_description.get('description', '')),
            query_params=[],
            path_params=[],
            ordered_params=ordered_params,
            request_type_name=self.__names.ClassName(request),
            response_type_name=self.__names.ClassName(response),
            request_field=request_field,
        )
        flat_path = method_description.get('flatPath')
        if flat_path is not None:
            flat_path = self.__names.NormalizeRelativePath(
                self.__client_info.base_path + flat_path)
            # Only record flat_path when it adds information.
            if flat_path != relative_path:
                method_info.flat_path = flat_path
        if method_description.get('supportsMediaUpload', False):
            method_info.upload_config = self.__ComputeUploadConfig(
                method_description.get('mediaUpload'), method_id)
        method_info.supports_download = method_description.get(
            'supportsMediaDownload', False)
        self.__all_scopes.update(method_description.get('scopes', ()))
        for param, desc in method_description.get('parameters', {}).items():
            param = self.__names.CleanName(param)
            location = desc['location']
            if location == 'query':
                method_info.query_params.append(param)
            elif location == 'path':
                method_info.path_params.append(param)
            else:
                raise ValueError(
                    'Unknown parameter location %s for parameter %s' % (
                        location, param))
        method_info.path_params.sort()
        method_info.query_params.sort()
        return method_info

    def __BodyFieldName(self, body_type):
        """Return the field name for a '$ref' body type, or '' if None."""
        if body_type is None:
            return ''
        return self.__names.FieldName(body_type['$ref'])

    def __GetRequestType(self, body_type):
        """Return the class name referenced by body_type's '$ref'."""
        return self.__names.ClassName(body_type.get('$ref'))

    def __GetRequestField(self, method_description, body_type):
        """Determine the request field for this method.

        Starts from the body type's field name and appends '_resource', then
        '_body' as often as needed, until it no longer collides with a
        method parameter name.
        """
        body_field_name = self.__BodyFieldName(body_type)
        if body_field_name in method_description.get('parameters', {}):
            body_field_name = self.__names.FieldName(
                '%s_resource' % body_field_name)
        # It's exceedingly unlikely that we'd get two name collisions, which
        # means it's bound to happen at some point.
        while body_field_name in method_description.get('parameters', {}):
            body_field_name = self.__names.FieldName(
                '%s_body' % body_field_name)
        return body_field_name

    def AddServiceFromResource(self, service_name, methods):
        """Add a new service named service_name with the given methods.

        Nested 'resources' are registered recursively as separate services
        named '<service_name>_<subresource>'. Methods and nested resources
        are processed in sorted order.

        Args:
          service_name: Resource name from the discovery document.
          methods: The resource's discovery entry, with optional 'methods'
              and 'resources' keys.
        """
        service_name = self.__names.CleanName(service_name)
        method_descriptions = methods.get('methods', {})
        method_info_map = collections.OrderedDict()
        items = sorted(method_descriptions.items())
        for method_name, method_description in items:
            method_name = self.__names.MethodName(method_name)

            # NOTE: According to the discovery document, if the request or
            # response is present, it will simply contain a `$ref`.
            body_type = method_description.get('request')
            if body_type is None:
                request_type = None
            else:
                request_type = self.__GetRequestType(body_type)
            if self.__NeedRequestType(method_description, request_type):
                request = self.__CreateRequestType(
                    method_description, body_type=body_type)
                request_field = self.__GetRequestField(
                    method_description, body_type)
            else:
                request = request_type
                request_field = base_api.REQUEST_IS_BODY

            if 'response' in method_description:
                response = method_description['response']['$ref']
            else:
                response = self.__CreateVoidResponseType(method_description)

            method_info_map[method_name] = self.__ComputeMethodInfo(
                method_description, request, response, request_field)

        nested_services = methods.get('resources', {})
        services = sorted(nested_services.items())
        for subservice_name, submethods in services:
            new_service_name = '%s_%s' % (service_name, subservice_name)
            self.AddServiceFromResource(new_service_name, submethods)

        self.__RegisterService(service_name, method_info_map)