# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Multipurpose TensorFlow Docker Helper.
16
17- Assembles Dockerfiles
18- Builds images (and optionally runs image tests)
19- Pushes images to Docker Hub (provided with credentials)
20
21Logs are written to stderr; the list of successfully built images is
22written to stdout.
23
24Read README.md (in this directory) for instructions!
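A typical invocation (illustrative; "my_release" stands in for a release name
defined in your spec file) looks like:

  python assembler.py --release my_release --construct_dockerfiles \
      --spec_file ./spec.yml --partial_dir ./partials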
25"""
26
27import collections
28import copy
29import errno
30import itertools
31import json
32import multiprocessing
33import os
34import platform
35import re
36import shutil
37import sys
38
39from absl import app
40from absl import flags
41import cerberus
42import docker
43import yaml
44
45FLAGS = flags.FLAGS
46
47flags.DEFINE_string('hub_username', None,
48                    'Dockerhub username, only used with --upload_to_hub')
49
flags.DEFINE_string(
    'hub_password', None,
    ('Dockerhub password, only used with --upload_to_hub. Pass it in via an '
     'environment variable so your password isn\'t saved in your shell '
     'history.'))

flags.DEFINE_integer('hub_timeout', 3600,
                     'Abort Hub upload if it takes longer than this.')

flags.DEFINE_string(
    'repository', 'tensorflow',
    'Tag local images as {repository}:tag (in addition to the '
    'hub_repository, if uploading to hub)')

flags.DEFINE_string(
    'hub_repository', None,
    'Push tags to this Docker Hub repository, e.g. tensorflow/tensorflow')

flags.DEFINE_boolean(
    'upload_to_hub',
    False,
    ('Push built images to Docker Hub (you must also provide --hub_username, '
     '--hub_password, and --hub_repository)'),
    short_name='u',
)

flags.DEFINE_boolean(
    'construct_dockerfiles', False,
    'Generate Dockerfiles and write them to --dockerfile_dir.',
    short_name='d')

flags.DEFINE_boolean(
    'keep_temp_dockerfiles',
    False,
    'Retain .temp.Dockerfiles created while building images.',
    short_name='k')

flags.DEFINE_boolean(
    'build_images', False,
    'Build Docker images from the assembled Dockerfiles.', short_name='b')

flags.DEFINE_string(
    'run_tests_path', None,
    ('Execute test scripts on the built images before pushing them. The flag '
     'value must be a full path to the "tests" directory, which is usually '
     '$(realpath ./tests). A failed test counts the same as a failed build.'))

flags.DEFINE_boolean(
    'stop_on_failure', False,
    ('Stop processing tags if any one build fails. If False or not specified, '
     'failures are reported but do not affect the other images.'))

flags.DEFINE_boolean(
    'dry_run',
    False,
    'Do not build or deploy anything at all.',
    short_name='n',
)

flags.DEFINE_string(
    'exclude_tags_matching',
    None,
    ('Regular expression that skips processing on any tag it matches. Must '
     'match entire string, e.g. ".*gpu.*" ignores all GPU tags.'),
    short_name='x')

flags.DEFINE_string(
    'only_tags_matching',
    None,
    ('Regular expression that skips processing on any tag it does not match. '
     'Must match entire string, e.g. ".*gpu.*" includes only GPU tags.'),
    short_name='i')

flags.DEFINE_string(
    'dockerfile_dir',
    './dockerfiles', 'Path to an output directory for Dockerfiles.'
    ' Will be created if it doesn\'t exist.'
    ' Existing files in this directory will be deleted when new Dockerfiles'
    ' are made.',
    short_name='o')

flags.DEFINE_string(
    'partial_dir',
    './partials',
    'Path to a directory containing foo.partial.Dockerfile partial files.'
    ' May contain subdirectories, e.g. "bar/baz.partial.Dockerfile".',
    short_name='p')

flags.DEFINE_multi_string(
    'release', [],
    'Set of releases to build and tag. Defaults to every release type.',
    short_name='r')

flags.DEFINE_multi_string(
    'arg', [],
    ('Extra build arguments. These are used for expanding tag names if needed '
     '(e.g. --arg _TAG_PREFIX=foo) and are passed as Docker build arguments '
     '(unused args will print a warning).'),
    short_name='a')

flags.DEFINE_boolean(
    'nocache', False,
    'Disable the Docker build cache; identical to "docker build --no-cache"')

flags.DEFINE_string(
    'spec_file',
    './spec.yml',
    'Path to the YAML specification file',
    short_name='s')

# Schema used to verify the contents of the tag spec (spec.yml) with Cerberus.
# Must be converted from YAML to a dict to work.
# Note: can add python references with e.g.
# !!python/name:builtins.str
# !!python/name:__main__.funcname
# (but this may not be considered safe?)
SCHEMA_TEXT = """
header:
  type: string

slice_sets:
  type: dict
  keyschema:
    type: string
  valueschema:
    type: list
    schema:
      type: dict
      schema:
        add_to_name:
          type: string
        dockerfile_exclusive_name:
          type: string
        dockerfile_subdirectory:
          type: string
        partials:
          type: list
          schema:
            type: string
            ispartial: true
        test_runtime:
          type: string
          required: false
        tests:
          type: list
          default: []
          schema:
            type: string
        args:
          type: list
          default: []
          schema:
            type: string
            isfullarg: true

releases:
  type: dict
  keyschema:
    type: string
  valueschema:
    type: dict
    schema:
      is_dockerfiles:
        type: boolean
        required: false
        default: false
      upload_images:
        type: boolean
        required: false
        default: true
      tag_specs:
        type: list
        required: true
        schema:
          type: string
"""

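# For orientation, a minimal spec that satisfies the schema above could look
# roughly like this (an illustrative sketch only; the slice set, partial, and
# release names below are made up, not taken from the real spec.yml):
#
#   header: |
#     # DO NOT EDIT - generated file
#   slice_sets:
#     ubuntu:
#       - add_to_name: ""
#         partials: ["ubuntu/base"]
#   releases:
#     nightly:
#       tag_specs: ["{ubuntu}"]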

class TfDockerTagValidator(cerberus.Validator):
  """Custom Cerberus validator for TF tag spec.

  Note: Each _validate_foo function's docstring must end with a segment
  describing its own validation schema, e.g. "The rule's arguments are...". If
  you add a new validator, you can copy/paste that section.
  """

  def __init__(self, *args, **kwargs):
    # See http://docs.python-cerberus.org/en/stable/customize.html
    if 'partials' in kwargs:
      self.partials = kwargs['partials']
    super(cerberus.Validator, self).__init__(*args, **kwargs)

  def _validate_ispartial(self, ispartial, field, value):
    """Validate that a partial references an existing partial spec.

    Args:
      ispartial: Value of the rule, a bool
      field: The field being validated
      value: The field's value
    The rule's arguments are validated against this schema:
    {'type': 'boolean'}
    """
    if ispartial and value not in self.partials:
      self._error(field,
                  '{} is not present in the partials directory.'.format(value))

  def _validate_isfullarg(self, isfullarg, field, value):
    """Validate that a string is either a full ARG=VALUE pair or a bare ARG.

    Args:
      isfullarg: Value of the rule, a bool
      field: The field being validated
      value: The field's value
    The rule's arguments are validated against this schema:
    {'type': 'boolean'}
    """
    if isfullarg and '=' not in value:
      self._error(field, '{} should be of the form ARG=VALUE.'.format(value))
    if not isfullarg and '=' in value:
      self._error(field, '{} should be of the form ARG (no =).'.format(value))


def eprint(*args, **kwargs):
  print(*args, file=sys.stderr, flush=True, **kwargs)


def aggregate_all_slice_combinations(spec, slice_set_names):
  """Figure out all of the possible slice groupings for a tag spec."""
  slice_sets = copy.deepcopy(spec['slice_sets'])

  for name in slice_set_names:
    for slice_set in slice_sets[name]:
      slice_set['set_name'] = name

  slices_grouped_but_not_keyed = [slice_sets[name] for name in slice_set_names]
  all_slice_combos = list(itertools.product(*slices_grouped_but_not_keyed))
  return all_slice_combos


def build_name_from_slices(format_string, slices, args, is_dockerfile=False):
  """Build the tag name (cpu-devel...) from a list of slices."""
  name_formatter = copy.deepcopy(args)
  name_formatter.update({s['set_name']: s['add_to_name'] for s in slices})
  name_formatter.update({
      s['set_name']: s['dockerfile_exclusive_name']
      for s in slices
      if is_dockerfile and 'dockerfile_exclusive_name' in s
  })
  name = format_string.format(**name_formatter)
  return name

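# A worked example of build_name_from_slices (the slice set names here are
# hypothetical): with format_string='{ubuntu}{python}', args={}, and slices
# [{'set_name': 'ubuntu', 'add_to_name': 'devel'},
#  {'set_name': 'python', 'add_to_name': '-py3'}], the function returns
# 'devel-py3'. When is_dockerfile is True, a slice's
# 'dockerfile_exclusive_name' (if present) is used instead of its
# 'add_to_name'.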

def update_args_dict(args_dict, updater):
  """Update a dict of arg values with more values from a list or dict."""
  if isinstance(updater, list):
    for arg in updater:
      key, sep, value = arg.partition('=')
      if sep == '=':
        args_dict[key] = value
  if isinstance(updater, dict):
    for key, value in updater.items():
      args_dict[key] = value
  return args_dict


def get_slice_sets_and_required_args(slice_sets, tag_spec):
  """Extract used-slice-sets and required CLI arguments from a spec string.

  For example, {FOO}{bar}{bat} finds FOO, bar, and bat. Assuming bar and bat
  are both named slice sets, FOO must be specified on the command line.

  Args:
     slice_sets: Dict of named slice sets
     tag_spec: The tag spec string, e.g. {_FOO}{blep}

  Returns:
     (used_slice_sets, required_args), a tuple of lists
  """
  required_args = []
  used_slice_sets = []

  extract_bracketed_words = re.compile(r'\{([^}]+)\}')
  possible_args_or_slice_set_names = extract_bracketed_words.findall(tag_spec)
  for name in possible_args_or_slice_set_names:
    if name in slice_sets:
      used_slice_sets.append(name)
    else:
      required_args.append(name)

  return (used_slice_sets, required_args)


def gather_tag_args(slices, cli_input_args, required_args):
  """Build a dictionary of all the CLI and slice-specified args for a tag."""
  args = {}

  for s in slices:
    args = update_args_dict(args, s['args'])

  args = update_args_dict(args, cli_input_args)
  for arg in required_args:
    if arg not in args:
      eprint(('> Error: {} is not a valid slice_set, and also isn\'t an arg '
              'provided on the command line. If it is an arg, please specify '
              'it with --arg. If not, check the slice_sets list.'.format(arg)))
      exit(1)

  return args


def gather_slice_list_items(slices, key):
  """For a list of slices, get the flattened list of all of a certain key."""
  return list(itertools.chain(*[s[key] for s in slices if key in s]))


def find_first_slice_value(slices, key):
  """For a list of slices, get the first value for a certain key."""
  for s in slices:
    if key in s and s[key] is not None:
      return s[key]
  return None


def assemble_tags(spec, cli_args, enabled_releases, all_partials):
  """Gather all the tags based on our spec.

  Args:
    spec: Nested dict containing full Tag spec
    cli_args: List of ARG=foo arguments to pass along to Docker build
    enabled_releases: List of releases to parse. Empty list = all
    all_partials: Dict of every partial, for reference

  Returns:
    Dict of tags and how to build them
  """
  tag_data = collections.defaultdict(list)

  for name, release in spec['releases'].items():
    for tag_spec in release['tag_specs']:
      if enabled_releases and name not in enabled_releases:
        eprint('> Skipping release {}'.format(name))
        continue

      used_slice_sets, required_cli_args = get_slice_sets_and_required_args(
          spec['slice_sets'], tag_spec)

      slice_combos = aggregate_all_slice_combinations(spec, used_slice_sets)
      for slices in slice_combos:

        tag_args = gather_tag_args(slices, cli_args, required_cli_args)
        tag_name = build_name_from_slices(tag_spec, slices, tag_args,
                                          release['is_dockerfiles'])
        used_partials = gather_slice_list_items(slices, 'partials')
        used_tests = gather_slice_list_items(slices, 'tests')
        test_runtime = find_first_slice_value(slices, 'test_runtime')
        dockerfile_subdirectory = find_first_slice_value(
            slices, 'dockerfile_subdirectory')
        dockerfile_contents = merge_partials(spec['header'], used_partials,
                                             all_partials)

        tag_data[tag_name].append({
            'release': name,
            'tag_spec': tag_spec,
            'is_dockerfiles': release['is_dockerfiles'],
            'upload_images': release['upload_images'],
            'cli_args': tag_args,
            'dockerfile_subdirectory': dockerfile_subdirectory or '',
            'partials': used_partials,
            'tests': used_tests,
            'test_runtime': test_runtime,
            'dockerfile_contents': dockerfile_contents,
        })

  return tag_data


def merge_partials(header, used_partials, all_partials):
  """Merge all partial contents with their header."""
  used_partials = list(used_partials)
  return '\n'.join([header] + [all_partials[u] for u in used_partials])


def upload_in_background(hub_repository, dock, image, tag):
  """Upload a docker image (to be used by multiprocessing)."""
  image.tag(hub_repository, tag=tag)
  print(dock.images.push(hub_repository, tag=tag))


def mkdir_p(path):
  """Create a directory and its parents, even if it already exists."""
  try:
    os.makedirs(path)
  except OSError as e:
    if e.errno != errno.EEXIST:
      raise


def gather_existing_partials(partial_path):
  """Find and read all available partials.

  Args:
    partial_path (string): read partials from this directory.

  Returns:
    Dict[string, string] of partial short names (like "ubuntu/python" or
      "bazel") to the full contents of that partial.
  """
  partials = {}
  for path, _, files in os.walk(partial_path):
    for name in files:
      fullpath = os.path.join(path, name)
      if '.partial.Dockerfile' not in fullpath:
        eprint(('> Probably not a problem: skipping {}, which is not a '
                'partial.').format(fullpath))
        continue
      # partial_dir/foo/bar.partial.Dockerfile -> foo/bar
      simple_name = fullpath[len(partial_path) + 1:-len('.partial.dockerfile')]
      with open(fullpath, 'r') as f:
        partial_contents = f.read()
      partials[simple_name] = partial_contents
  return partials


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Read the full spec file, used for everything
  with open(FLAGS.spec_file, 'r') as spec_file:
    tag_spec = yaml.safe_load(spec_file)

  # Get existing partial contents
  partials = gather_existing_partials(FLAGS.partial_dir)

  # Abort if the spec file is invalid
  schema = yaml.safe_load(SCHEMA_TEXT)
  v = TfDockerTagValidator(schema, partials=partials)
  if not v.validate(tag_spec):
    eprint('> Error: {} is an invalid spec! The errors are:'.format(
        FLAGS.spec_file))
    eprint(yaml.dump(v.errors, indent=2))
    exit(1)
  tag_spec = v.normalized(tag_spec)

  # Assemble tags and images used to build them
  all_tags = assemble_tags(tag_spec, FLAGS.arg, FLAGS.release, partials)

  # Empty Dockerfile directory if building new Dockerfiles
  if FLAGS.construct_dockerfiles:
    eprint('> Emptying Dockerfile dir "{}"'.format(FLAGS.dockerfile_dir))
    shutil.rmtree(FLAGS.dockerfile_dir, ignore_errors=True)
    mkdir_p(FLAGS.dockerfile_dir)

  # Set up Docker helper
  dock = docker.from_env()

  # Login to Docker if uploading images
  if FLAGS.upload_to_hub:
    if not FLAGS.hub_username:
      eprint('> Error: please set --hub_username when uploading to Dockerhub.')
      exit(1)
    if not FLAGS.hub_repository:
      eprint(
          '> Error: please set --hub_repository when uploading to Dockerhub.')
      exit(1)
    if not FLAGS.hub_password:
      eprint('> Error: please set --hub_password when uploading to Dockerhub.')
      exit(1)
    dock.login(
        username=FLAGS.hub_username,
        password=FLAGS.hub_password,
    )

  # Each tag has a name ('tag') and a definition consisting of the contents
  # of its Dockerfile, its build arg list, etc.
  failed_tags = []
  succeeded_tags = []
  for tag, tag_defs in all_tags.items():
    for tag_def in tag_defs:
      eprint('> Working on {}'.format(tag))

      if FLAGS.exclude_tags_matching and re.match(FLAGS.exclude_tags_matching,
                                                  tag):
        eprint('>> Excluded due to match against "{}".'.format(
            FLAGS.exclude_tags_matching))
        continue

      if FLAGS.only_tags_matching and not re.match(FLAGS.only_tags_matching,
                                                   tag):
        eprint('>> Excluded due to failure to match against "{}".'.format(
            FLAGS.only_tags_matching))
        continue

      # Write releases marked "is_dockerfiles" into the Dockerfile directory
      if FLAGS.construct_dockerfiles and tag_def['is_dockerfiles']:
        path = os.path.join(FLAGS.dockerfile_dir,
                            tag_def['dockerfile_subdirectory'],
                            tag + '.Dockerfile')
        eprint('>> Writing {}...'.format(path))
        if not FLAGS.dry_run:
          mkdir_p(os.path.dirname(path))
          with open(path, 'w') as f:
            f.write(tag_def['dockerfile_contents'])

      # Skip building images unless --build_images was requested
      # (e.g. for Dockerfile-only releases)
      if not FLAGS.build_images:
        continue

      # Only build images for the host architecture
      proc_arch = platform.processor()
      is_x86 = proc_arch.startswith('x86')
      if (is_x86 and any(arch in tag for arch in ['ppc64le']) or
          not is_x86 and proc_arch not in tag):
        continue

      # Generate a temporary Dockerfile for the build, since docker-py
      # needs a file path relative to the build context (i.e. the current
      # directory)
      dockerfile = os.path.join(FLAGS.dockerfile_dir, tag + '.temp.Dockerfile')
      if not FLAGS.dry_run:
        with open(dockerfile, 'w') as f:
          f.write(tag_def['dockerfile_contents'])
      eprint('>> (Temporary) writing {}...'.format(dockerfile))

      repo_tag = '{}:{}'.format(FLAGS.repository, tag)
      eprint('>> Building {} using build args:'.format(repo_tag))
      for arg, value in tag_def['cli_args'].items():
        eprint('>>> {}={}'.format(arg, value))

      # Note that we are NOT using cache_from, which appears to limit
      # available cache layers to those from explicitly specified layers. Many
      # of our layers are similar between local builds, so we want to use the
      # implied local build cache.
      tag_failed = False
      image, logs = None, []
      if not FLAGS.dry_run:
        try:
          # Use low level APIClient in order to stream log output
          resp = dock.api.build(
              timeout=FLAGS.hub_timeout,
              path='.',
              nocache=FLAGS.nocache,
              dockerfile=dockerfile,
              buildargs=tag_def['cli_args'],
              tag=repo_tag)
          last_event = None
          image_id = None
          # Manually process the log output, extracting the build status and
          # image id so that we can retrieve the built image afterwards
          while True:
            try:
              output = next(resp).decode('utf-8')
              json_output = json.loads(output.strip('\r\n'))
              if 'stream' in json_output:
                eprint(json_output['stream'], end='')
                match = re.search(r'(^Successfully built |sha256:)([0-9a-f]+)$',
                                  json_output['stream'])
                if match:
                  image_id = match.group(2)
                last_event = json_output['stream']
                # collect all log lines into the logs object
                logs.append(json_output)
            except StopIteration:
              eprint('Docker image build complete.')
              break
            except ValueError:
              eprint('Error parsing output from docker image build: {}'.format(
                  output))
          # If the image ID is not set, the image failed to build properly.
          # Raise an error in that case with the last log line and all logs
          if image_id:
            image = dock.images.get(image_id)
          else:
            raise docker.errors.BuildError(last_event or 'Unknown', logs)

          # Run tests if requested, and dump output
          # Could be improved by backgrounding, but would need better
          # multiprocessing support to track failures properly.
          if FLAGS.run_tests_path:
            if not tag_def['tests']:
              eprint('>>> No tests to run.')
            for test in tag_def['tests']:
              eprint('>> Testing {}...'.format(test))
              container = dock.containers.run(
                  image,
                  '/tests/' + test,
                  working_dir='/',
                  log_config={'type': 'journald'},
                  detach=True,
                  stderr=True,
                  stdout=True,
                  volumes={
                      FLAGS.run_tests_path: {
                          'bind': '/tests',
                          'mode': 'ro'
                      }
                  },
                  runtime=tag_def['test_runtime'])
              ret = container.wait()
              code = ret['StatusCode']
              out = container.logs(stdout=True, stderr=False)
              err = container.logs(stdout=False, stderr=True)
              container.remove()
              if out:
                eprint('>>> Output stdout:')
                eprint(out.decode('utf-8'))
              else:
                eprint('>>> No test standard out.')
              if err:
                eprint('>>> Output stderr:')
                eprint(err.decode('utf-8'))
              else:
                eprint('>>> No test standard err.')
              if code != 0:
                eprint('>> {} failed tests with status: "{}"'.format(
                    repo_tag, code))
                failed_tags.append(tag)
                tag_failed = True
                if FLAGS.stop_on_failure:
                  eprint('>> ABORTING due to --stop_on_failure!')
                  exit(1)
              else:
                eprint('>> Tests look good!')

        except docker.errors.BuildError as e:
          eprint('>> {} failed to build with message: "{}"'.format(
              repo_tag, e.msg))
          eprint('>> Build logs follow:')
          log_lines = [l.get('stream', '') for l in e.build_log]
          eprint(''.join(log_lines))
          failed_tags.append(tag)
          tag_failed = True
          if FLAGS.stop_on_failure:
            eprint('>> ABORTING due to --stop_on_failure!')
            exit(1)

        # Clean temporary dockerfiles if they were created earlier
        if not FLAGS.keep_temp_dockerfiles:
          os.remove(dockerfile)

      # Upload new images to DockerHub as long as they built + passed tests
      if FLAGS.upload_to_hub:
        if not tag_def['upload_images']:
          continue
        if tag_failed:
          continue

        eprint('>> Uploading to {}:{}'.format(FLAGS.hub_repository, tag))
        if not FLAGS.dry_run:
          p = multiprocessing.Process(
              target=upload_in_background,
              args=(FLAGS.hub_repository, dock, image, tag))
          p.start()

      if not tag_failed:
        succeeded_tags.append(tag)

  if failed_tags:
    eprint(
        '> Some tags failed to build or failed testing; check the scrollback '
        'for errors: {}'.format(','.join(failed_tags)))
    exit(1)

  eprint('> Writing built{} tags to standard out.'.format(
      ' and tested' if FLAGS.run_tests_path else ''))
  for tag in succeeded_tags:
    print('{}:{}'.format(FLAGS.repository, tag))


if __name__ == '__main__':
  app.run(main)
