xref: /aosp_15_r20/external/tink/testing/cross_language/util/testing_servers.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS-IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""testing_server starts up testing gRPC servers in different languages."""
15
16import os
17import subprocess
18import time
19
20from typing import List, Optional, Type, TypeVar
21from absl import logging
22import grpc
23import portpicker
24import tink
25
26from tink.proto import tink_pb2
27from util import _primitives
28from protos import testing_api_pb2
29from protos import testing_api_pb2_grpc
30
# Type variable for the primitive interface requested from remote_primitive().
P = TypeVar('P')

# Server paths are relative to a root folder where all the servers are located.
# For each language the candidates are tried in order and the first one that
# exists is used. The root folder can be set manually as follows:
#   bazel test util:testing_servers_test \
#     --test_env TINK_CROSS_LANG_ROOT_PATH=<path to the root folder>
_SERVER_PATHS = {
    'cc': ['cc/bazel-bin/testing_server', 'cc/testing_server'],
    'go': ['go/bazel-bin/testing_server_/testing_server', 'go/testing_server'],
    'java': [
        'java_src/bazel-bin/testing_server_deploy.jar',
        'java_src/testing_server'
    ],
    'python': ['python/bazel-bin/testing_server', 'python/testing_server']
}

# All languages that have a testing server.
LANGUAGES = list(_SERVER_PATHS)

# Pairs of (keyset reader type, keyset writer type) that belong together.
KEYSET_READER_WRITER_TYPES = [('KEYSET_READER_BINARY', 'KEYSET_WRITER_BINARY'),
                              ('KEYSET_READER_JSON', 'KEYSET_WRITER_JSON')]

# Location of the java binary used to run the Java testing server deploy jar,
# relative to the root folder where all the servers are located.
_JAVA_PATH = ('java_src/bazel-bin/testing_server.runfiles/local_jdk/bin/java')
56
# Maps each primitive name to the gRPC stub class of its testing service.
_PRIMITIVE_STUBS = {
    'aead': testing_api_pb2_grpc.AeadStub,
    'daead': testing_api_pb2_grpc.DeterministicAeadStub,
    'streaming_aead': testing_api_pb2_grpc.StreamingAeadStub,
    'hybrid': testing_api_pb2_grpc.HybridStub,
    'mac': testing_api_pb2_grpc.MacStub,
    'signature': testing_api_pb2_grpc.SignatureStub,
    'prf': testing_api_pb2_grpc.PrfSetStub,
    'jwt': testing_api_pb2_grpc.JwtStub,
}

# All primitives.
_PRIMITIVES = list(_PRIMITIVE_STUBS)

# Languages whose testing server implements each primitive's service. Stubs for
# a primitive are only created for the languages listed here.
SUPPORTED_LANGUAGES_BY_PRIMITIVE = {
    'aead': ['cc', 'go', 'java', 'python'],
    'daead': ['cc', 'go', 'java', 'python'],
    'streaming_aead': ['cc', 'go', 'java', 'python'],
    'hybrid': ['cc', 'go', 'java', 'python'],
    'mac': ['cc', 'go', 'java', 'python'],
    'signature': ['cc', 'go', 'java', 'python'],
    'prf': ['cc', 'java', 'go', 'python'],
    'jwt': ['cc', 'java', 'go', 'python'],
}
81
# Key URI prefixes passed to the Go server only: in Go the key URIs are
# required (a prefix is accepted), in the other languages they are optional.
GCP_KEY_URI_PREFIX = (
    'gcp-kms://projects/tink-test-infrastructure/locations/global/'
    'keyRings/unit-and-integration-testing/cryptoKeys/')
AWS_KEY_URI_PREFIX = 'aws-kms://arn:aws:kms:us-east-2:235739564943:'

# Value of the TEST_SRCDIR environment variable, or '' when it is not set. The
# KMS credential files below are resolved relative to it.
_TEST_SRCDIR = os.environ.get('TEST_SRCDIR', '')

GCP_CREDENTIALS_PATH = os.path.join(
    _TEST_SRCDIR, 'cross_language_test/testdata/gcp/credential.json')
AWS_CREDENTIALS_INI_PATH = os.path.join(
    _TEST_SRCDIR, 'cross_language_test/testdata/aws/credentials.ini')
AWS_CREDENTIALS_CRED_PATH = os.path.join(
    _TEST_SRCDIR, 'cross_language_test/testdata/aws/credentials.cred')

# Folder holding the server binaries, relative to TEST_SRCDIR.
_RELATIVE_ROOT_PATH = 'tink_base/testing'
99
100
def _root_path() -> str:
  """Locates the folder that contains all the testing server binaries.

  Lookup order:
    1. TINK_CROSS_LANG_ROOT_PATH, used as-is.
    2. TEST_SRCDIR joined with _RELATIVE_ROOT_PATH.
  The first variable that is set decides the result; later ones are not
  consulted.

  Returns:
    The root path of the cross language test servers.

  Raises:
    ValueError: if neither environment variable is set.
    FileNotFoundError: if the first variable that is set yields a path that
      does not exist.
  """
  env = os.environ
  if 'TINK_CROSS_LANG_ROOT_PATH' in env:
    var_name = 'TINK_CROSS_LANG_ROOT_PATH'
    root = env[var_name]
  elif 'TEST_SRCDIR' in env:
    var_name = 'TEST_SRCDIR'
    root = os.path.join(env[var_name], _RELATIVE_ROOT_PATH)
  else:
    raise ValueError('No root path environment variable set')

  if not os.path.exists(root):
    raise FileNotFoundError(
        f'Variable {var_name} is set but has an invalid path {root}')
  return root
131
132
def _server_path(lang: str) -> str:
  """Locates the testing server binary for `lang`.

  The candidates in _SERVER_PATHS[lang] are resolved against the root path and
  tried in order; the first one that exists is returned.

  Raises:
    RuntimeError: if none of the candidates exists.
  """
  root_dir = _root_path()
  candidates = (os.path.join(root_dir, rel) for rel in _SERVER_PATHS[lang])
  for candidate in candidates:
    logging.info('try path: %s', candidate)
    if os.path.exists(candidate):
      return candidate
  raise RuntimeError('Executable for lang %s not found' % lang)
142
143
def _server_cmd(lang: str, port: int) -> List[str]:
  """Builds the argv that launches the testing server for `lang`.

  Args:
    lang: the server language, one of LANGUAGES.
    port: the port the server should listen on.

  Returns:
    The command as a list of strings. A Java server packaged as a .jar is
    launched through the java binary at _JAVA_PATH.
  """
  # Java reads AWS credentials from a .cred file; all other servers use .ini.
  aws_credentials_path = (
      AWS_CREDENTIALS_CRED_PATH if lang == 'java' else AWS_CREDENTIALS_INI_PATH)

  server_path = _server_path(lang)
  # TODO(b/249015767): Refactor KMS integration to pass credentials via gRPC.
  server_args = ['--port', '%d' % port,
                 '--gcp_credentials_path', GCP_CREDENTIALS_PATH,
                 '--aws_credentials_path', aws_credentials_path]
  if lang == 'go':
    # Key URIs are optional everywhere except Go, where they are required but
    # may be a prefix.
    server_args += ['--gcp_key_uri', GCP_KEY_URI_PREFIX,
                    '--aws_key_uri', AWS_KEY_URI_PREFIX]

  launcher = []
  if lang == 'java' and server_path.endswith('.jar'):
    launcher = [os.path.join(_root_path(), _JAVA_PATH), '-jar']
  return launcher + [server_path] + server_args
171
172
def _get_file_content(filename: str) -> str:
  """Returns the whole text content of `filename`.

  This is used to embed server logs in error messages. Bytes that are not
  valid UTF-8 are replaced with U+FFFD rather than raising
  UnicodeDecodeError, which would otherwise mask the error being reported.
  """
  with open(filename, 'r', encoding='utf-8', errors='replace') as f:
    return f.read()
176
177
class _TestingServers():
  """TestingServers starts up testing gRPC servers and returns service stubs.

  One server subprocess is launched per entry of LANGUAGES; its combined
  stdout/stderr is written to a log file under TEST_UNDECLARED_OUTPUTS_DIR.
  Metadata and keyset stubs are created for every language, primitive stubs
  only for the languages listed in SUPPORTED_LANGUAGES_BY_PRIMITIVE.
  """

  def __init__(self, test_name: str):
    """Starts every server and waits until each one is ready.

    Args:
      test_name: prefix of the server log file names.

    Raises:
      RuntimeError: if TEST_UNDECLARED_OUTPUTS_DIR is unset, a log file cannot
        be opened, or a server is not ready within 30 seconds. Errors raised
        while building a server command propagate as well. In every case the
        servers launched so far are killed and their log files closed first.
    """
    self._server = {}
    self._output_file = {}
    self._channel = {}
    self._metadata_stub = {}
    self._keyset_stub = {}
    self._aead_stub = {}
    self._daead_stub = {}
    self._streaming_aead_stub = {}
    self._hybrid_stub = {}
    self._mac_stub = {}
    self._signature_stub = {}
    self._prf_stub = {}
    self._jwt_stub = {}
    self._test_name = test_name

    try:
      self._launch_servers()
      self._wait_for_servers()
    except BaseException:
      # Without this, a failure for one language would leave the servers that
      # were already launched running as orphaned processes.
      self._abort_startup()
      raise

    for primitive in _PRIMITIVES:
      stubs = getattr(self, '_%s_stub' % primitive)
      for lang in SUPPORTED_LANGUAGES_BY_PRIMITIVE[primitive]:
        stubs[lang] = _PRIMITIVE_STUBS[primitive](self._channel[lang])

  def _launch_servers(self) -> None:
    """Launches one server per language and opens a gRPC channel to each."""
    for lang in LANGUAGES:
      port = portpicker.pick_unused_port()
      cmd = _server_cmd(lang, port)
      logging.info('cmd = %s', cmd)
      output_path = self._get_output_path(lang)
      logging.info('writing server output to %s', output_path)
      try:
        self._output_file[lang] = open(output_path, 'w+')
      except IOError as e:
        logging.info('unable to open server output file %s', output_path)
        raise RuntimeError('Could not start %s server' % lang) from e
      self._server[lang] = subprocess.Popen(
          cmd, stdout=self._output_file[lang], stderr=subprocess.STDOUT)
      logging.info('%s server started on port %d with pid: %d. Log output: %s',
                   lang, port, self._server[lang].pid,
                   self._output_file[lang].name)
      self._channel[lang] = grpc.secure_channel(
          '[::]:%d' % port, grpc.local_channel_credentials())

  def _wait_for_servers(self) -> None:
    """Blocks until every channel is ready, then creates the common stubs."""
    for lang in LANGUAGES:
      try:
        grpc.channel_ready_future(self._channel[lang]).result(timeout=30)
      except Exception as e:
        logging.info('Timeout while connecting to server %s', lang)
        self._server[lang].kill()
        _, _ = self._server[lang].communicate()
        raise RuntimeError(
            'Could not start %s server, output=%s' %
            (lang, _get_file_content(self._output_file[lang].name))) from e
      self._metadata_stub[lang] = testing_api_pb2_grpc.MetadataStub(
          self._channel[lang])
      self._keyset_stub[lang] = testing_api_pb2_grpc.KeysetStub(
          self._channel[lang])

  def _abort_startup(self) -> None:
    """Releases the channels, processes and log files a failed start opened."""
    for channel in self._channel.values():
      channel.close()
    for process in self._server.values():
      if process.poll() is None:
        process.kill()
      # Reap the process so it does not linger as a zombie.
      process.wait()
    for output_file in self._output_file.values():
      output_file.close()

  def _get_output_path(self, lang) -> str:
    """Returns the path of the log file that captures `lang`'s server output.

    Raises:
      RuntimeError: if TEST_UNDECLARED_OUTPUTS_DIR is not set.
    """
    try:
      output_dir = os.environ['TEST_UNDECLARED_OUTPUTS_DIR']
    except KeyError as e:
      raise RuntimeError(
          'Could not start %s server, TEST_UNDECLARED_OUTPUTS_DIR environment '
          'variable must be set' % lang) from e
    output_file = '%s-%s-%s' % (self._test_name, lang, 'server.log')
    return os.path.join(output_dir, output_file)

  def keyset_stub(self, lang) -> testing_api_pb2_grpc.KeysetStub:
    return self._keyset_stub[lang]

  def aead_stub(self, lang) -> testing_api_pb2_grpc.AeadStub:
    return self._aead_stub[lang]

  def daead_stub(self, lang) -> testing_api_pb2_grpc.DeterministicAeadStub:
    return self._daead_stub[lang]

  def streaming_aead_stub(self, lang) -> testing_api_pb2_grpc.StreamingAeadStub:
    return self._streaming_aead_stub[lang]

  def hybrid_stub(self, lang) -> testing_api_pb2_grpc.HybridStub:
    return self._hybrid_stub[lang]

  def mac_stub(self, lang) -> testing_api_pb2_grpc.MacStub:
    return self._mac_stub[lang]

  def signature_stub(self, lang) -> testing_api_pb2_grpc.SignatureStub:
    return self._signature_stub[lang]

  def prf_stub(self, lang) -> testing_api_pb2_grpc.PrfSetStub:
    return self._prf_stub[lang]

  def jwt_stub(self, lang) -> testing_api_pb2_grpc.JwtStub:
    return self._jwt_stub[lang]

  def metadata_stub(self, lang) -> testing_api_pb2_grpc.MetadataStub:
    return self._metadata_stub[lang]

  def stop(self):
    """Stops all servers, then prints each server's log to stdout."""
    logging.info('Stopping servers...')
    for lang in LANGUAGES:
      self._channel[lang].close()
    for lang in LANGUAGES:
      self._server[lang].terminate()
    # The servers share a 2 second grace period after terminate(); whichever
    # is still running once it expires is killed.
    deadline = time.monotonic() + 2
    for lang in LANGUAGES:
      process = self._server[lang]
      try:
        process.wait(timeout=max(0.0, deadline - time.monotonic()))
      except subprocess.TimeoutExpired:
        logging.info('Killing server %s.', lang)
        process.kill()
        # Reap the killed process so it does not linger as a zombie.
        process.wait()
    for lang in LANGUAGES:
      self._output_file[lang].close()
    logging.info('All servers stopped.')

    print()
    print()
    for lang in LANGUAGES:
      label = lang + ' '
      total_reps = 1 + 100 // len(label)
      length = total_reps * len(label) - 1
      print('=' * length)
      print(label * total_reps)
      print('v' * length)
      print(_get_file_content(self._get_output_path(lang)))
      print('^' * length)
      print(label * total_reps)
      print('=' * length)
      print()
305
# The servers launched by start(); the module-level helpers below use it.
_ts = None


def start(output_files_prefix: str) -> None:
  """Starts all servers and checks that each reports its expected language.

  Args:
    output_files_prefix: prefix of the server log file names.

  Raises:
    ValueError: if a server reports a language other than the one it was
      started for.
  """
  global _ts
  _ts = _TestingServers(output_files_prefix)

  versions = {}
  for lang in LANGUAGES:
    response = _ts.metadata_stub(lang).GetServerInfo(
        testing_api_pb2.ServerInfoRequest())
    if lang != response.language:
      raise ValueError(
          'lang = %s != response.language = %s' % (lang, response.language))
    if response.tink_version:
      versions[lang] = response.tink_version
    else:
      logging.warning('server in lang %s has no tink version.', lang)
  unique_versions = sorted(set(versions.values()))
  if not unique_versions:
    # Indexing an empty list would raise IndexError here.
    logging.warning('No server reported a tink version.')
  elif len(unique_versions) == 1:
    logging.info('Tink version: %s', unique_versions[0])
  else:
    # Report every version rather than an arbitrary one.
    logging.info('Tink versions by language: %s', versions)
327
328
def stop() -> None:
  """Shuts down the servers launched by start() and prints their logs."""
  running_servers = _ts
  running_servers.stop()
332
333
def key_template(lang: str, template_name: str) -> tink_pb2.KeyTemplate:
  """Fetches the key template named `template_name` from the `lang` server."""
  keyset_service = _ts.keyset_stub(lang)
  return _primitives.key_template(keyset_service, template_name)
337
338
def new_keyset(lang: str, template: tink_pb2.KeyTemplate) -> bytes:
  """Generates a new keyset from `template` using the `lang` server.

  Returns:
    The new keyset, serialized (not a KeysetHandle).
  """
  keyset_service = _ts.keyset_stub(lang)
  return _primitives.new_keyset(keyset_service, template)
342
343
def public_keyset(lang: str, private_keyset: bytes) -> bytes:
  """Derives the public keyset of `private_keyset` using the `lang` server.

  Args:
    lang: language of the server to use.
    private_keyset: serialized keyset containing private keys.

  Returns:
    The serialized public keyset (not a KeysetHandle).
  """
  keyset_service = _ts.keyset_stub(lang)
  return _primitives.public_keyset(keyset_service, private_keyset)
347
348
def keyset_to_json(lang: str, keyset: bytes) -> str:
  """Converts a serialized keyset to its JSON form using the `lang` server."""
  keyset_service = _ts.keyset_stub(lang)
  return _primitives.keyset_to_json(keyset_service, keyset)
351
352
def keyset_from_json(lang: str, json_keyset: str) -> bytes:
  """Converts a JSON keyset back to bytes using the `lang` server."""
  keyset_service = _ts.keyset_stub(lang)
  return _primitives.keyset_from_json(keyset_service, json_keyset)
355
356
def keyset_read_encrypted(lang: str, encrypted_keyset: bytes,
                          master_keyset: bytes,
                          associated_data: Optional[bytes],
                          keyset_reader_type: str) -> bytes:
  """Reads an encrypted keyset through the Keyset service of the `lang` server.

  Args:
    lang: language of the server to use.
    encrypted_keyset: the encrypted keyset to read.
    master_keyset: serialized keyset protecting `encrypted_keyset`
      (presumably an AEAD keyset; confirm in _primitives).
    associated_data: optional associated data, or None.
    keyset_reader_type: reader name, e.g. a first element of a pair in
      KEYSET_READER_WRITER_TYPES.

  Returns:
    The keyset returned by the server (presumably decrypted and serialized).
  """
  keyset_service = _ts.keyset_stub(lang)
  return _primitives.keyset_read_encrypted(
      keyset_service, encrypted_keyset, master_keyset, associated_data,
      keyset_reader_type)
364
365
def keyset_write_encrypted(lang: str, keyset: bytes, master_keyset: bytes,
                           associated_data: Optional[bytes],
                           keyset_writer_type: str) -> bytes:
  """Writes `keyset` in encrypted form using the `lang` server.

  Args:
    lang: language of the server to use.
    keyset: the serialized keyset to write.
    master_keyset: serialized keyset used to protect the output
      (presumably an AEAD keyset; confirm in _primitives).
    associated_data: optional associated data, or None.
    keyset_writer_type: writer name, e.g. a second element of a pair in
      KEYSET_READER_WRITER_TYPES.

  Returns:
    The encrypted keyset produced by the server.
  """
  keyset_service = _ts.keyset_stub(lang)
  return _primitives.keyset_write_encrypted(
      keyset_service, keyset, master_keyset, associated_data,
      keyset_writer_type)
372
373
def jwk_set_to_keyset(lang: str, jwk_set: str) -> bytes:
  """Converts a JWK set into a keyset using the JWT service of `lang`."""
  jwt_service = _ts.jwt_stub(lang)
  return _primitives.jwk_set_to_keyset(jwt_service, jwk_set)
376
377
def jwk_set_from_keyset(lang: str, keyset: bytes) -> str:
  """Converts a keyset into a JWK set using the JWT service of `lang`."""
  jwt_service = _ts.jwt_stub(lang)
  return _primitives.jwk_set_from_keyset(jwt_service, keyset)
380
381
def remote_primitive(lang: str, keyset: bytes, primitive_class: Type[P]) -> P:
  """Returns a primitive whose operations run on the `lang` testing server.

  Creating the primitive performs an RPC to the `lang` server that tries to
  'Create' it from `keyset`. The returned object forwards its calls to the
  service implemented in that language.

  Args:
    lang: the language of the server to use.
    keyset: the serialized keyset.
    primitive_class: the primitive interface, e.g. tink.aead.Aead.

  Returns:
    A primitive backed by the `lang` server.

  Raises:
    TinkError: if the server fails to create the primitive.
    ValueError: if `primitive_class` is not supported.
  """
  # Select the _primitives wrapper, the stub accessor and any trailing
  # constructor arguments. Comparisons run in order and stop at the first
  # match, so only the matching stub accessor is ever used.
  if primitive_class == tink.aead.Aead:
    wrapper, get_stub, extra = _primitives.Aead, _ts.aead_stub, (None,)
  elif primitive_class == tink.daead.DeterministicAead:
    wrapper, get_stub, extra = (
        _primitives.DeterministicAead, _ts.daead_stub, (None,))
  elif primitive_class == tink.streaming_aead.StreamingAead:
    wrapper, get_stub, extra = (
        _primitives.StreamingAead, _ts.streaming_aead_stub, ())
  elif primitive_class == tink.hybrid.HybridDecrypt:
    wrapper, get_stub, extra = (
        _primitives.HybridDecrypt, _ts.hybrid_stub, (None,))
  elif primitive_class == tink.hybrid.HybridEncrypt:
    wrapper, get_stub, extra = (
        _primitives.HybridEncrypt, _ts.hybrid_stub, (None,))
  elif primitive_class == tink.mac.Mac:
    wrapper, get_stub, extra = _primitives.Mac, _ts.mac_stub, (None,)
  elif primitive_class == tink.signature.PublicKeySign:
    wrapper, get_stub, extra = (
        _primitives.PublicKeySign, _ts.signature_stub, (None,))
  elif primitive_class == tink.signature.PublicKeyVerify:
    wrapper, get_stub, extra = (
        _primitives.PublicKeyVerify, _ts.signature_stub, (None,))
  elif primitive_class == tink.prf.PrfSet:
    wrapper, get_stub, extra = _primitives.PrfSet, _ts.prf_stub, (None,)
  elif primitive_class == tink.jwt.JwtMac:
    wrapper, get_stub, extra = _primitives.JwtMac, _ts.jwt_stub, ()
  elif primitive_class == tink.jwt.JwtPublicKeySign:
    wrapper, get_stub, extra = _primitives.JwtPublicKeySign, _ts.jwt_stub, ()
  elif primitive_class == tink.jwt.JwtPublicKeyVerify:
    wrapper, get_stub, extra = (
        _primitives.JwtPublicKeyVerify, _ts.jwt_stub, ())
  else:
    raise ValueError('Unsupported P in remote_primitive: ' +
                     str(primitive_class))
  return wrapper(lang, get_stub(lang), keyset, *extra)
430  raise ValueError('Unsupported P in remote_primitive: ' + str(primitive_class))
431