xref: /aosp_15_r20/external/autotest/server/hosts/ssh_multiplex.py (revision 9c5db1993ded3edbeafc8092d69fe5de2ee02df7)
1# Lint as: python2, python3
2# Copyright 2017 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6from __future__ import absolute_import
7from __future__ import division
8from __future__ import print_function
9
10import logging
11import multiprocessing
12import os
13import threading
14
15from autotest_lib.client.common_lib import autotemp
16from autotest_lib.server import utils
17import six
18
19# TODO b:169251326 terms below are set outside of this codebase
20# and should be updated when possible. ("master" -> "main")
# Command line for the background ssh process that owns the multiplex socket.
# maybe_start() fills it in with a dict providing 'socket' (control socket
# path), 'user', 'port' ("-p <port>" or an empty string) and 'hostname'.
_MAIN_SSH_COMMAND_TEMPLATE = ' '.join([
        '/usr/bin/ssh -a -x -N',
        '-o ControlMaster=yes',  # Create multiplex socket. # nocheck
        '-o ControlPath=%(socket)s',
        '-o StrictHostKeyChecking=no',
        '-o UserKnownHostsFile=/dev/null',
        '-o BatchMode=yes',
        '-o ConnectTimeout=30',
        '-o ServerAliveInterval=30',
        '-o ServerAliveCountMax=1',
        '-o ConnectionAttempts=1',
        '-o Protocol=2',
        '-l %(user)s %(port)s %(hostname)s',
])
34
35
36class MainSsh(object):
37    """Manages multiplex ssh connection."""
38
39    def __init__(self, hostname, user, port):
40        self._hostname = hostname
41        self._user = user
42        self._port = port
43
44        self._main_job = None
45        self._main_tempdir = None
46
47        self._lock = multiprocessing.Lock()
48
49    def __del__(self):
50        self.close()
51
52    @property
53    def _socket_path(self):
54        return os.path.join(self._main_tempdir.name, 'socket')
55
56    @property
57    def ssh_option(self):
58        """Returns the ssh option to use this multiplexed ssh.
59
60        If background process is not running, returns an empty string.
61        """
62        if not self._main_tempdir:
63            return ''
64        return '-o ControlPath=%s' % (self._socket_path,)
65
66    def maybe_start(self, timeout=5):
67        """Starts the background process to run multiplex ssh connection.
68
69        If there already is a background process running, this does nothing.
70        If there is a stale process or a stale socket, first clean them up,
71        then create a background process.
72
73        @param timeout: timeout in seconds (default 5) to wait for main ssh
74                        connection to be established. If timeout is reached, a
75                        warning message is logged, but no other action is
76                        taken.
77        """
78        # Multiple processes might try in parallel to clean up the old main
79        # ssh connection and create a new one, therefore use a lock to protect
80        # against race conditions.
81        with self._lock:
82            # If a previously started main SSH connection is not running
83            # anymore, it needs to be cleaned up and then restarted.
84            if (self._main_job and (not os.path.exists(self._socket_path) or
85                                      self._main_job.sp.poll() is not None)):
86                logging.info(
87                        'Main-ssh connection to %s is down.', self._hostname)
88                self._close_internal()
89
90            # Start a new main SSH connection.
91            if not self._main_job:
92                # Create a shared socket in a temp location.
93                self._main_tempdir = autotemp.tempdir(dir=_short_tmpdir())
94
95                # Start the main SSH connection in the background.
96                main_cmd = _MAIN_SSH_COMMAND_TEMPLATE % {
97                        'hostname': self._hostname,
98                        'user': self._user,
99                        'port': "-p %s" % self._port if self._port else "",
100                        'socket': self._socket_path,
101                }
102                logging.info(
103                    'Starting main-ssh connection \'%s\'', main_cmd)
104                self._main_job = utils.BgJob(
105                    main_cmd, nickname='main-ssh',
106                    stdout_tee=utils.DEVNULL, stderr_tee=utils.DEVNULL,
107                    unjoinable=True)
108
109                # To prevent a race between the main ssh connection
110                # startup and its first attempted use, wait for socket file to
111                # exist before returning.
112                try:
113                    utils.poll_for_condition(
114                            condition=lambda: os.path.exists(self._socket_path),
115                            timeout=timeout,
116                            sleep_interval=0.2,
117                            desc='main-ssh connection up')
118                except utils.TimeoutError:
119                    # poll_for_conditional already logs an error upon timeout
120                    pass
121
122
123    def close(self):
124        """Releases all resources used by multiplexed ssh connection."""
125        with self._lock:
126            self._close_internal()
127
128    def _close_internal(self):
129        # Assume that when this is called, _lock should be acquired, already.
130        if self._main_job:
131            logging.debug('Nuking ssh main_job')
132            utils.nuke_subprocess(self._main_job.sp)
133            self._main_job = None
134
135        if self._main_tempdir:
136            logging.debug('Cleaning ssh main_tempdir')
137            self._main_tempdir.clean()
138            self._main_tempdir = None
139
140
141class ConnectionPool(object):
142    """Holds SSH multiplex connection instance."""
143
144    def __init__(self):
145        self._pool = {}
146        self._lock = threading.Lock()
147
148    def get(self, hostname, user, port):
149        """Returns MainSsh instance for the given endpoint.
150
151        If the pool holds the instance already, returns it. If not, create the
152        instance, and returns it.
153
154        Caller has the responsibility to call maybe_start() before using it.
155
156        @param hostname: Host name of the endpoint.
157        @param user: User name to log in.
158        @param port: Port number sshd is listening.
159        """
160        key = (hostname, user, port)
161        logging.debug('Get main ssh connection for %s@%s%s', user, hostname,
162                      ":%s" % port if port else "")
163
164        with self._lock:
165            conn = self._pool.get(key)
166            if not conn:
167                conn = MainSsh(hostname, user, port)
168                self._pool[key] = conn
169            return conn
170
171    def shutdown(self):
172        """Closes all ssh multiplex connections."""
173        for ssh in six.itervalues(self._pool):
174            ssh.close()
175
176
177def _short_tmpdir():
178    # crbug/865171 Unix domain socket paths are limited to 108 characters.
179    # crbug/945523 Swarming does not like too many top-level directories in
180    # /tmp.
181    # So use a shared parent directory in /tmp
182    user = os.environ.get("USER", "no_USER")[:8]
183    d = '/tmp/ssh-main_%s' % user
184    if not os.path.exists(d):
185        os.mkdir(d)
186    return d
187