from __future__ import annotations import asyncio import ctypes import errno import functools import logging import socket import ssl import struct import sys import weakref from ssl import SSLCertVerificationError, SSLError from typing import Any, ClassVar from tlz import sliding_window from tornado import gen, netutil from tornado.iostream import IOStream, StreamClosedError from tornado.tcpclient import TCPClient from tornado.tcpserver import TCPServer import dask from dask.utils import parse_timedelta from distributed.comm.addressing import parse_host_port, unparse_host_port from distributed.comm.core import ( BaseListener, Comm, CommClosedError, Connector, FatalCommClosedError, ) from distributed.comm.registry import Backend from distributed.comm.utils import ( ensure_concrete_host, from_frames, get_tcp_server_address, to_frames, ) from distributed.protocol.utils import host_array, pack_frames_prelude, unpack_frames from distributed.system import MEMORY_LIMIT from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6, nbytes logger = logging.getLogger(__name__) # We must not load more than this into a buffer at a time # It's currently unclear why that is # see # - https://github.com/dask/distributed/pull/5854 # - https://bugs.python.org/issue42853 # - https://github.com/dask/distributed/pull/8507 C_INT_MAX = 256 ** ctypes.sizeof(ctypes.c_int) // 2 - 1 MAX_BUFFER_SIZE = MEMORY_LIMIT / 2 def set_tcp_timeout(comm): """ Set kernel-level TCP timeout on the stream. """ if comm.closed(): return timeout = dask.config.get("distributed.comm.timeouts.tcp") timeout = int(parse_timedelta(timeout, default="seconds")) sock = comm.socket # Default (unsettable) value on Windows # https://msdn.microsoft.com/en-us/library/windows/desktop/dd877220(v=vs.85).aspx nprobes = 10 assert timeout >= nprobes + 1, "Timeout too low" idle = max(2, timeout // 4) interval = max(1, (timeout - idle) // nprobes) idle = timeout - interval * nprobes assert idle > 0 try: if sys.platform.startswith("win"): logger.debug("Setting TCP keepalive: idle=%d, interval=%d", idle, interval) sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle * 1000, interval * 1000)) else: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) try: TCP_KEEPIDLE = socket.TCP_KEEPIDLE TCP_KEEPINTVL = socket.TCP_KEEPINTVL TCP_KEEPCNT = socket.TCP_KEEPCNT except AttributeError: if sys.platform == "darwin": TCP_KEEPIDLE = 0x10 # (named "TCP_KEEPALIVE" in C) TCP_KEEPINTVL = 0x101 TCP_KEEPCNT = 0x102 else: TCP_KEEPIDLE = None if TCP_KEEPIDLE is not None: logger.debug( "Setting TCP keepalive: nprobes=%d, idle=%d, interval=%d", nprobes, idle, interval, ) sock.setsockopt(socket.SOL_TCP, TCP_KEEPCNT, nprobes) sock.setsockopt(socket.SOL_TCP, TCP_KEEPIDLE, idle) sock.setsockopt(socket.SOL_TCP, TCP_KEEPINTVL, interval) if sys.platform.startswith("linux"): logger.debug("Setting TCP user timeout: %d ms", timeout * 1000) TCP_USER_TIMEOUT = 18 # since Linux 2.6.37 sock.setsockopt(socket.SOL_TCP, TCP_USER_TIMEOUT, timeout * 1000) except OSError: logger.exception("Could not set timeout on TCP stream.") def get_stream_address(comm): """ Get a stream's local address. """ # raise OSError in case the comm is closed, s.t. # retry code can handle it appropriately; see also # https://github.com/dask/distributed/issues/7953 if comm.closed(): raise CommClosedError() return unparse_host_port(*comm.socket.getsockname()[:2]) def convert_stream_closed_error(obj, exc): """ Re-raise StreamClosedError as CommClosedError. """ if exc.real_error is not None: # The stream was closed because of an underlying OS error exc = exc.real_error if isinstance(exc, ssl.SSLError): if exc.reason and "UNKNOWN_CA" in exc.reason: raise FatalCommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}") raise CommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}") from exc else: raise CommClosedError(f"in {obj}: {exc}") from exc def _close_comm(ref): """Callback to close Dask Comm when Tornado Stream closes Parameters ---------- ref: weak reference to a Dask comm """ comm = ref() if comm: comm._closed = True class TCP(Comm): """ An established communication based on an underlying Tornado IOStream. """ max_shard_size: ClassVar[int] = dask.utils.parse_bytes( dask.config.get("distributed.comm.shard") ) stream: IOStream | None def __init__( self, stream: IOStream, local_addr: str, peer_addr: str, deserialize: bool = True, ): self._closed = False super().__init__(deserialize=deserialize) self._local_addr = local_addr self._peer_addr = peer_addr self.stream = stream self._finalizer = weakref.finalize(self, self._get_finalizer()) self._finalizer.atexit = False self._extra: dict = {} ref = weakref.ref(self) stream.set_close_callback(functools.partial(_close_comm, ref)) stream.set_nodelay(True) set_tcp_timeout(stream) self._read_extra() def _read_extra(self): pass def _get_finalizer(self): r = repr(self) def finalize(stream=self.stream, r=r): # stream is None if a StreamClosedError is raised during interpreter # shutdown if stream is not None and not stream.closed(): logger.warning(f"Closing dangling stream in {r}") stream.close() return finalize @property def local_address(self) -> str: return self._local_addr @property def peer_address(self) -> str: return self._peer_addr async def read(self, deserializers=None): stream = self.stream if stream is None: raise CommClosedError() fmt = "Q" fmt_size = struct.calcsize(fmt) try: # Don't store multiple numpy or parquet buffers into the same buffer, or # none will be released until all are released. frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size) (frames_nosplit_nbytes,) = struct.unpack(fmt, frames_nosplit_nbytes_bin) frames_nosplit = await read_bytes_rw(stream, frames_nosplit_nbytes) frames, buffers_nbytes = unpack_frames(frames_nosplit, partial=True) for buffer_nbytes in buffers_nbytes: buffer = await read_bytes_rw(stream, buffer_nbytes) frames.append(buffer) except StreamClosedError as e: self.stream = None self._closed = True convert_stream_closed_error(self, e) except BaseException: # Some OSError, CancelledError or another "low-level" exception. # We do not really know what was already read from the underlying # socket, so it is not even safe to retry here using the same stream. # The only safe thing to do is to abort. # (See also GitHub #4133, #6548). self.abort() raise else: try: msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers, allow_offload=self.allow_offload, ) except EOFError: # Frames possibly garbled or truncated by communication error self.abort() raise CommClosedError("aborted stream on truncated data") return msg async def write(self, msg, serializers=None, on_error="message"): stream = self.stream if stream is None: raise CommClosedError() frames = await to_frames( msg, allow_offload=self.allow_offload, serializers=serializers, on_error=on_error, context={ "sender": self.local_info, "recipient": self.remote_info, **self.handshake_options, }, frame_split_size=self.max_shard_size, ) frames, frames_nbytes, frames_nbytes_total = _add_frames_header(frames) try: # trick to enqueue all frames for writing beforehand for each_frame_nbytes, each_frame in zip(frames_nbytes, frames): if each_frame_nbytes: # Make sure that `len(data) == data.nbytes` # See each_frame = ensure_memoryview(each_frame) for i, j in sliding_window( 2, range( 0, each_frame_nbytes + C_INT_MAX, C_INT_MAX, ), ): chunk = each_frame[i:j] chunk_nbytes = chunk.nbytes if stream._write_buffer is None: raise StreamClosedError() stream._write_buffer.append(chunk) stream._total_write_index += chunk_nbytes # start writing frames stream.write(b"") except StreamClosedError as e: self.stream = None self._closed = True convert_stream_closed_error(self, e) except BaseException: # Some OSError or a another "low-level" exception. We do not really know # what was already written to the underlying socket, so it is not even safe # to retry here using the same stream. The only safe thing to do is to # abort. (See also GitHub #4133). # In case of, for instance, KeyboardInterrupts or other # BaseExceptions that could be handled further upstream, we equally # want to discard this comm self.abort() raise return frames_nbytes_total @gen.coroutine def close(self): # We use gen.coroutine here rather than async def to avoid errors like # Task was destroyed but it is pending! # Triggered by distributed.deploy.tests.test_local::test_silent_startup stream, self.stream = self.stream, None self._closed = True if stream is not None and not stream.closed(): try: # Flush the stream's write buffer by waiting for a last write. if stream.writing(): yield stream.write(b"") stream.socket.shutdown(socket.SHUT_RDWR) except OSError: pass finally: self._finalizer.detach() stream.close() def abort(self) -> None: stream, self.stream = self.stream, None self._closed = True if stream is not None and not stream.closed(): self._finalizer.detach() stream.close() def closed(self) -> bool: return self._closed @property def extra_info(self): return self._extra async def read_bytes_rw(stream: IOStream, n: int) -> memoryview: """Read n bytes from stream. Unlike stream.read_bytes, allow for very large messages and return a writeable buffer. """ buf = host_array(n) for i, j in sliding_window( 2, range(0, n + C_INT_MAX, C_INT_MAX), ): chunk = buf[i:j] actual = await stream.read_into(chunk) # type: ignore[arg-type] assert actual == chunk.nbytes return buf def _add_frames_header( frames: list[bytes | memoryview], ) -> tuple[list[bytes | memoryview], list[int], int]: """ """ frames_nbytes = [nbytes(f) for f in frames] frames_nbytes_total = sum(frames_nbytes) # Calculate the number of bytes that are inclusive of: # - prelude # - msgpack header # - simple pickle bytes # - compressed buffers # - first uncompressed buffer (possibly sharded), IFF the pickle bytes are # negligible in size # # All these can be fetched by read() into a single buffer with a single call to # Tornado, because they will be dereferenced soon after they are deserialized. # Read uncompressed numpy/parquet buffers, which will survive indefinitely past # the end of read(), into their own host arrays so that their memory can be # released independently. frames_nbytes_nosplit = 0 first_uncompressed_buffer: object = None for frame, nb in zip(frames, frames_nbytes): buffer = frame.obj if isinstance(frame, memoryview) else frame if not isinstance(buffer, bytes): # Uncompressed buffer; it will be referenced by the unpickled object if first_uncompressed_buffer is None: if frames_nbytes_nosplit > max(2048, nb * 0.05): # Don't extend the lifespan of non-trivial amounts of pickled bytes # to that of the buffers break first_uncompressed_buffer = buffer elif first_uncompressed_buffer is not buffer: # don't split sharded frame # Always store 2+ separate numpy/parquet objects onto separate # buffers break frames_nbytes_nosplit += nb header = pack_frames_prelude(frames) header = struct.pack("Q", nbytes(header) + frames_nbytes_nosplit) + header header_nbytes = nbytes(header) frames = [header, *frames] frames_nbytes = [header_nbytes, *frames_nbytes] frames_nbytes_total += header_nbytes if frames_nbytes_total < 2**17 or ( # 128 kiB total frames_nbytes_total < 2**25 # 32 MiB total and frames_nbytes_total // len(frames) < 2**15 # 32 kiB mean ): # very small or very fragmented; send in one go frames = [b"".join(frames)] frames_nbytes = [frames_nbytes_total] return frames, frames_nbytes, frames_nbytes_total class TLS(TCP): """ A TLS-specific version of TCP. """ # Workaround for OpenSSL 1.0.2 (can drop with OpenSSL 1.1.1) max_shard_size = min(C_INT_MAX, TCP.max_shard_size) def _read_extra(self): TCP._read_extra(self) sock = self.stream.socket if sock is not None: self._extra.update(peercert=sock.getpeercert(), cipher=sock.cipher()) cipher, proto, bits = self._extra["cipher"] logger.debug( "TLS connection with %r: protocol=%s, cipher=%s, bits=%d", self._peer_addr, proto, cipher, bits, ) def _expect_tls_context(connection_args): ctx = connection_args.get("ssl_context") if not isinstance(ctx, ssl.SSLContext): raise TypeError( "TLS expects a `ssl_context` argument of type " "ssl.SSLContext (perhaps check your TLS configuration?)" f" Instead got {ctx!r}" ) return ctx class RequireEncryptionMixin: def _check_encryption(self, address, connection_args): if not self.encrypted and connection_args.get("require_encryption"): # XXX Should we have a dedicated SecurityError class? raise RuntimeError( "encryption required by Dask configuration, " "refusing communication from/to %r" % (self.prefix + address,) ) _NUMERIC_ONLY = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV async def _getaddrinfo(host, port, *, family, type=socket.SOCK_STREAM): # If host and port are numeric, then getaddrinfo doesn't block and we # can skip get_running_loop().getaddrinfo which is implemented by # running in a ThreadPoolExecutor. So we try first with the # _NUMERIC_ONLY flags set, and then only use the threadpool if that # fails with EAI_NONAME: try: return socket.getaddrinfo( host, port, family=family, type=type, flags=_NUMERIC_ONLY, ) except socket.gaierror as e: if e.errno != socket.EAI_NONAME: raise # That failed; it's a real hostname. We better use a thread. return await asyncio.get_running_loop().getaddrinfo( host, port, family=family, type=socket.SOCK_STREAM ) class _DefaultLoopResolver(netutil.Resolver): """ Resolver implementation using `asyncio.loop.getaddrinfo`. backport from Tornado 6.2+ https://github.com/tornadoweb/tornado/blob/3de78b7a15ba7134917a18b0755ea24d7f8fde94/tornado/netutil.py#L416-L432 With an additional optimization based on https://github.com/python-trio/trio/blob/4edfd41bd5519a2e626e87f6c6ca9fb32b90a6f4/trio/_socket.py#L125-L192 (Copyright Contributors to the Trio project.) And proposed to cpython in https://github.com/python/cpython/pull/31497/ """ async def resolve( self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC ) -> list[tuple[int, Any]]: # On Solaris, getaddrinfo fails if the given port is not found # in /etc/services and no socket type is given, so we must pass # one here. The socket type used here doesn't seem to actually # matter (we discard the one we get back in the results), # so the addresses we return should still be usable with SOCK_DGRAM. return [ (fam, address) for fam, _, _, _, address in await _getaddrinfo( host, port, family=family, type=socket.SOCK_STREAM ) ] class BaseTCPConnector(Connector, RequireEncryptionMixin): client: ClassVar[TCPClient] = TCPClient(resolver=_DefaultLoopResolver()) async def connect(self, address, deserialize=True, **connection_args): self._check_encryption(address, connection_args) ip, port = parse_host_port(address) kwargs = self._get_connect_args(**connection_args) try: # server_hostname option (for SNI) only works with tornado.iostream.IOStream if "server_hostname" in kwargs: stream = await self.client.connect( ip, port, max_buffer_size=MAX_BUFFER_SIZE ) stream = await stream.start_tls(False, **kwargs) else: stream = await self.client.connect( ip, port, max_buffer_size=MAX_BUFFER_SIZE, **kwargs ) # Under certain circumstances tornado will have a closed connection with an # error and not raise a StreamClosedError. # # This occurs with tornado 5.x and openssl 1.1+ if stream.closed() and stream.error: raise StreamClosedError(stream.error) except StreamClosedError as e: # The socket connect() call failed convert_stream_closed_error(self, e) except SSLCertVerificationError as err: raise FatalCommClosedError( "TLS certificate does not match. Check your security settings. " "More info at https://distributed.dask.org/en/latest/tls.html" ) from err except SSLError as err: raise FatalCommClosedError() from err local_address = self.prefix + get_stream_address(stream) comm = self.comm_class( stream, local_address, self.prefix + address, deserialize ) return comm class TCPConnector(BaseTCPConnector): prefix = "tcp://" comm_class = TCP encrypted = False def _get_connect_args(self, **connection_args): return {} class TLSConnector(BaseTCPConnector): prefix = "tls://" comm_class = TLS encrypted = True def _get_connect_args(self, **connection_args): tls_args = {"ssl_options": _expect_tls_context(connection_args)} if connection_args.get("server_hostname"): tls_args["server_hostname"] = connection_args["server_hostname"] return tls_args class BaseTCPListener(BaseListener, RequireEncryptionMixin): def __init__( self, address, comm_handler, deserialize=True, allow_offload=True, default_host=None, default_port=0, **connection_args, ): super().__init__() self._check_encryption(address, connection_args) self.ip, self.port = parse_host_port(address, default_port) self.default_host = default_host self.comm_handler = comm_handler self.deserialize = deserialize self.allow_offload = allow_offload self.server_args = self._get_server_args(**connection_args) self.tcp_server = None self.bound_address = None async def start(self): self.tcp_server = TCPServer(max_buffer_size=MAX_BUFFER_SIZE, **self.server_args) self.tcp_server.handle_stream = self._handle_stream backlog = int(dask.config.get("distributed.comm.socket-backlog")) for _ in range(5): try: # When shuffling data between workers, there can # really be O(cluster size) connection requests # on a single worker socket, make sure the backlog # is large enough not to lose any. sockets = netutil.bind_sockets( self.port, address=self.ip, backlog=backlog ) except OSError as e: # EADDRINUSE can happen sporadically when trying to bind # to an ephemeral port if self.port != 0 or e.errno != errno.EADDRINUSE: raise exc = e else: self.tcp_server.add_sockets(sockets) break else: raise exc self.get_host_port() # trigger assignment to self.bound_address def stop(self): tcp_server, self.tcp_server = self.tcp_server, None if tcp_server is not None: tcp_server.stop() def _check_started(self): if self.tcp_server is None: raise ValueError("invalid operation on non-started TCPListener") async def _handle_stream(self, stream, address): address = self.prefix + unparse_host_port(*address[:2]) stream = await self._prepare_stream(stream, address) if stream is None: # Preparation failed return logger.debug("Incoming connection from %r to %r", address, self.contact_address) local_address = self.prefix + get_stream_address(stream) comm = self.comm_class(stream, local_address, address, self.deserialize) comm.allow_offload = self.allow_offload try: await self.on_connection(comm) except CommClosedError: logger.info("Connection from %s closed before handshake completed", address) return await self.comm_handler(comm) def get_host_port(self): """ The listening address as a (host, port) tuple. """ self._check_started() if self.bound_address is None: self.bound_address = get_tcp_server_address(self.tcp_server) # IPv6 getsockname() can return more a 4-len tuple return self.bound_address[:2] @property def listen_address(self): """ The listening address as a string. """ return self.prefix + unparse_host_port(*self.get_host_port()) @property def contact_address(self): """ The contact address as a string. """ host, port = self.get_host_port() host = ensure_concrete_host(host, default_host=self.default_host) return self.prefix + unparse_host_port(host, port) class TCPListener(BaseTCPListener): prefix = "tcp://" comm_class = TCP encrypted = False def _get_server_args(self, **connection_args): return {} async def _prepare_stream(self, stream, address): return stream class TLSListener(BaseTCPListener): prefix = "tls://" comm_class = TLS encrypted = True def _get_server_args(self, **connection_args): ctx = _expect_tls_context(connection_args) return {"ssl_options": ctx} async def _prepare_stream(self, stream, address): try: await stream.wait_for_handshake() except OSError as e: # The handshake went wrong, log and ignore logger.warning( "Listener on %r: TLS handshake failed with remote %r: %s", self.listen_address, address, getattr(e, "real_error", None) or e, ) else: return stream class BaseTCPBackend(Backend): # I/O def get_connector(self): return self._connector_class() def get_listener(self, loc, handle_comm, deserialize, **connection_args): return self._listener_class(loc, handle_comm, deserialize, **connection_args) # Address handling def get_address_host(self, loc): return parse_host_port(loc)[0] def get_address_host_port(self, loc): return parse_host_port(loc) def resolve_address(self, loc): host, port = parse_host_port(loc) return unparse_host_port(ensure_ip(host), port) def get_local_address_for(self, loc): host, port = parse_host_port(loc) host = ensure_ip(host) if ":" in host: local_host = get_ipv6(host) else: local_host = get_ip(host) return unparse_host_port(local_host, None) class TCPBackend(BaseTCPBackend): _connector_class = TCPConnector _listener_class = TCPListener class TLSBackend(BaseTCPBackend): _connector_class = TLSConnector _listener_class = TLSListener