Skip to content

proxy

Flight SSH proxy functionality.

FlightProxy

FlightProxy(url, backend_url, netloc_map, **kwargs)

Bases: FlightServerBase

Transparent Flight proxy that rewrites endpoint URIs based on mapping.

Rewrites FlightInfo endpoint location URIs on GetFlightInfo responses; all other RPC methods are forwarded to the backend unchanged. This is useful for proxying a remote Flight server through SSH port forwards.

Parameters:

Name Type Description Default
url str

URL on which to listen for queries.

required
backend_url str

URL of remote info server to proxy.

required
netloc_map dict[str, str]

Dictionary mapping of endpoint locations to be replaced. It should be keyed on OLD_HOST:PORT with value as NEW_HOST:PORT. The URL scheme will be preserved.

required
Source code in arrakis/proxy.py
75
76
77
78
79
80
81
82
83
84
85
86
def __init__(
    self,
    url: str,
    backend_url: str,
    netloc_map: dict[str, str],
    **kwargs,
):
    """Initialize the proxy server and connect to the backend.

    Parameters
    ----------
    url : str
        URL handed to the base server initializer.
    backend_url : str
        URL of the backend Flight server; a client connection to it
        is opened immediately and kept on ``self.client``.
    netloc_map : dict[str, str]
        Mapping of endpoint locations to be replaced, stored as given
        on ``self.netloc_map``.
    **kwargs
        Passed through unchanged to the base server initializer.

    """
    # the base server must be set up before anything else is attached
    super().__init__(url, **kwargs)

    self.netloc_map = netloc_map
    self.client = flight.connect(backend_url)

    logger.debug(
        "flight proxy server initialized: %s -> %s",
        url,
        backend_url,
    )
    logger.debug("netloc_map: %s", self.netloc_map)

SSHConnection

SSHConnection(host)

Manage a background SSH connection.

Parameters:

Name Type Description Default
host str

Remote SSH host.

required
Source code in arrakis/proxy.py
133
134
135
136
137
138
def __init__(self, host: str):
    """Record the remote host and prepare a control path for it.

    A fresh temporary directory is created; the control path is a
    file inside it named after the host.  No connection is made here.

    Parameters
    ----------
    host : str
        Remote SSH host.

    """
    self.host = host

    # private temporary directory that will hold the control path
    control_dir = pathlib.Path(tempfile.mkdtemp())
    control_path = control_dir / host
    self.ctrl_dir = control_dir
    self.ctrl_path = control_path
    # string form, ready to be placed on an ssh command line
    self.ctrl = str(control_path)

    # REMOTE_NETLOC -> LOCAL_NETLOC for every forward that was set up
    self._netloc_map: dict[str, str] = {}

netloc_map property

netloc_map

Dictionary of port forwards

REMOTE_NETLOC: LOCAL_NETLOC

close

close()

Close the connection

Source code in arrakis/proxy.py
226
227
228
229
def close(self):
    """Close the connection.

    Asks the ssh control master to exit and then removes the temporary
    control directory.  Both steps are best effort: the exit status of
    the ssh command is ignored (``check=False``), and the directory is
    removed even if it is not empty or is already gone, so calling
    ``close()`` more than once is harmless.
    """
    import shutil

    self.exec(["-O", "exit"], check=False, capture_output=True)
    # A plain rmdir() raises if anything (e.g. a control socket that has
    # not been cleaned up yet) is still inside the directory, or if the
    # directory was already removed by an earlier close().  This is a
    # private temp dir, so remove it unconditionally.
    shutil.rmtree(self.ctrl_dir, ignore_errors=True)

connect

connect()

Connect to the SSH host.

Source code in arrakis/proxy.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def connect(self):
    """Connect to the ssh host.

    Runs ``ssh`` with ``-M``, ``-f`` and ``ControlPersist=yes``, using
    ``self.ctrl`` as the ``-S`` control path, and a remote ``sleep 60``
    as the command.  Raises ``subprocess.CalledProcessError`` if ssh
    exits with a non-zero status.
    """
    # NOTE: "-N" is intentionally not passed; a remote command is run
    # instead.
    ssh_options = ["-S", self.ctrl, "-M", "-o", "ControlPersist=yes", "-f"]
    remote_command = ["sleep", "60"]
    argv = ["ssh", *ssh_options, self.host, *remote_command]

    logger.debug(" ".join(argv))
    subprocess.run(argv, check=True)  # noqa S603

exec

exec(ssh_cmd=None, shell_cmd=None, **kwargs)

Execute an ssh command via the control master.

Source code in arrakis/proxy.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def exec(
    self,
    ssh_cmd: list[str] | None = None,
    shell_cmd: list[str] | None = None,
    **kwargs,
):
    """Run an ssh command that uses the existing control path.

    Parameters
    ----------
    ssh_cmd : list[str] | None
        Extra ssh arguments, placed before the host name.
    shell_cmd : list[str] | None
        Command words placed after the host name.
    **kwargs
        Passed through unchanged to ``subprocess.run``.

    Returns
    -------
    The value returned by ``subprocess.run``.

    """
    argv = [
        "ssh",
        "-S",
        self.ctrl,
        "-o",
        "ControlMaster=no",
        # empty/None pieces contribute nothing to the command line
        *(ssh_cmd or ()),
        self.host,
        *(shell_cmd or ()),
    ]
    logger.debug(" ".join(argv))
    return subprocess.run(argv, **kwargs)  # noqa S603

forward_port

forward_port(remote_netloc, local_port=None, *, wait=False)

Initiate a local port forward to the remote host.

Source code in arrakis/proxy.py
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def forward_port(
    self, remote_netloc: str, local_port: int | None = None, *, wait: bool = False
):
    """Initiate a local port forward to the host.

    Parameters
    ----------
    remote_netloc : str
        Remote HOST:PORT that the forward should reach.
    local_port : int | None
        Local port to listen on.  If not given (or 0) an unused port
        is picked.
    wait : bool
        If true, block until a TCP connection to the local port is
        accepted.  There is no timeout.

    Returns
    -------
    str
        The local netloc, ``localhost:<local_port>``.  The mapping
        ``remote_netloc -> local netloc`` is also recorded in
        ``self.netloc_map``.

    Raises
    ------
    subprocess.CalledProcessError
        If the ssh forward command fails (it is run with ``check=True``).

    """
    if not local_port:
        # hacky way to find an unused port. the acquired port
        # could be stolen between when the socket releases it and
        # ssh tries to take it.
        with socket.socket() as probe:
            probe.bind(("localhost", 0))
            _, local_port = probe.getsockname()

    local_netloc = f"localhost:{local_port}"

    forward = f"{local_netloc}:{remote_netloc}"
    logger.debug("forwarding port: %s", forward)
    self.exec(["-O", "forward", "-L", forward], check=True)

    if wait:
        # wait for forward to be established.  A socket whose connect()
        # failed is in an unspecified state and must not be reused, so a
        # fresh one is created for every attempt; the context manager
        # closes it on every path, including unexpected errors.
        while True:
            with socket.socket() as probe:
                try:
                    probe.connect(("localhost", local_port))
                    break
                except ConnectionRefusedError:
                    pass
            time.sleep(0.01)

    self.netloc_map[remote_netloc] = local_netloc
    return local_netloc

ssh_proxy

ssh_proxy(ssh_host, arrakis_server=None, server_netloc='localhost:31206')

Create Flight proxy server over SSH.

This is done by:

  1. Make SSH connection to the remote host.
  2. Determine initial Flight info server URL on the remote side.
  3. Retrieve all known endpoints from the info server.
  4. Set up local ssh port forwards to all endpoints.
  5. Launch Flight proxy server that rewrites endpoints to point to the local port forwards.

This should always be used as a context manager so that all connections are closed when done.

Parameters:

Name Type Description Default
ssh_host str

Remote SSH host.

required
arrakis_server str | None

Remote Arrakis info server. If not specified it will be determined from the ARRAKIS_SERVER env var on the remote side.

None
server_netloc str

Flight proxy server HOST:PORT. Defaults to "localhost:31206".

'localhost:31206'
Source code in arrakis/proxy.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
@contextlib.contextmanager
def ssh_proxy(
    ssh_host: str,
    arrakis_server: str | None = None,
    server_netloc: str = "localhost:31206",
):
    """Create Flight proxy server over SSH.

    This is done by:

    0. Make SSH connection to the remote host.
    1. Determine initial Flight info server URL on the remote side.
    2. Retrieve all known endpoints from the info server.
    3. Set up local ssh port forwards to all endpoints.
    4. Launch Flight proxy server that rewrites endpoints to point to
       the local port forwards.

    This should always be used as a context manager so that all
    connections are closed when done.

    Parameters
    ----------
    ssh_host : str
        Remote SSH host.
    arrakis_server : str | None
        Remote Arrakis info server.  If not specified it will be
        determined from the ARRAKIS_SERVER env var on the remote side.
        A bare HOST:PORT (no scheme) is accepted as well as a URL.
    server_netloc : str
        Flight proxy server HOST:PORT.  Defaults to "localhost:31206".

    Yields
    ------
    str
        URL of the local proxy server, ``grpc://<server_netloc>``.

    Raises
    ------
    OSError
        If ``server_netloc`` cannot be bound.
    ValueError
        If ``server_netloc`` is not of the form HOST:PORT, or if the
        remote ARRAKIS_SERVER could not be determined.
    subprocess.CalledProcessError
        If one of the checked ssh commands fails.

    """
    logger.info("creating ssh proxy via %s", ssh_host)

    # split on the last colon only, so the port is always the final field
    server_host, server_port = server_netloc.rsplit(":", 1)

    # check/hold the requested server port.  The socket is released
    # right before the proxy server starts; the outer ``finally`` makes
    # sure it is also released if anything fails before that point
    # (closing an already closed socket is a no-op).
    server_socket = socket.socket()
    try:
        try:
            server_socket.bind((server_host, int(server_port)))
        except OSError:
            msg = f"server address already in use: {server_netloc}"
            raise OSError(msg) from None

        # initiate the ssh connection to the host as a context manager, so
        # that the connection is properly shut down if there are any
        # errors during setup.
        with SSHConnection(ssh_host) as ssh:
            # if not specified, try to determine the remote server location
            # from the ARRAKIS_SERVER env var on the remote host
            if arrakis_server is None:
                logger.debug("resolving remote ARRAKIS_SERVER...")
                arrakis_server = (
                    ssh.exec(
                        shell_cmd=["printenv", "ARRAKIS_SERVER"],
                        capture_output=True,
                        check=True,
                    )
                    .stdout.decode()
                    .strip()
                )
                if not arrakis_server:
                    msg = "Could not determine remote ARRAKIS_SERVER."
                    raise ValueError(msg)

            parsed = urlparse(arrakis_server, scheme="grpc")
            if not parsed.netloc:
                # a bare "HOST:PORT" is not parsed as a network location
                # by urlparse (the netloc comes back empty); retry with
                # an explicit scheme so the netloc is populated.
                parsed = urlparse(f"grpc://{arrakis_server}")

            # create an initial forward to the info server so that we can
            # query for the endpoint information
            backend_netloc = ssh.forward_port(parsed.netloc)

            logger.debug("retrieving endpoints...")
            backend_url = f"grpc://{backend_netloc}"
            endpoints = Client(backend_url).endpoints()

            # setup forwards for all endpoints
            logger.debug("creating endpoint port forwarding...")
            for endpoint in endpoints:
                remote_netloc = urlparse(endpoint).netloc
                ssh.forward_port(remote_netloc)

            # start the proxy server; the held port has to be released
            # first so that the server can bind it
            logger.debug("starting Flight proxy server...")
            server_url = f"grpc://{server_netloc}"
            server_socket.close()
            server = FlightProxy(server_url, backend_url, ssh.netloc_map)

            # serve() blocks, so run it on a single background thread
            executor = ThreadPoolExecutor(max_workers=1)
            executor.submit(server.serve)

            # yield context manager
            try:
                yield server_url
            finally:
                logger.info("shutting down ssh proxy...")
                server.shutdown()
                executor.shutdown()
                logger.info("done.")
    finally:
        server_socket.close()