""" :ref:`UCX`_ based communications for distributed. See :ref:`communications` for more. .. _UCX: https://github.com/openucx/ucx """ from __future__ import annotations import functools import logging import os import struct import weakref from collections.abc import Awaitable, Callable, Collection from typing import TYPE_CHECKING, Any from unittest.mock import patch import dask from dask.utils import parse_bytes from distributed.comm.addressing import parse_host_port, unparse_host_port from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector from distributed.comm.registry import Backend, backends from distributed.comm.utils import ensure_concrete_host, from_frames, to_frames from distributed.diagnostics.nvml import ( CudaDeviceInfo, get_device_index_and_uuid, has_cuda_context, ) from distributed.protocol.utils import host_array from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes logger = logging.getLogger(__name__) # In order to avoid double init when forking/spawning new processes (multiprocess), # we make sure only to import and initialize UCX once at first use. This is also # required to ensure Dask configuration gets propagated to UCX, which needs # variables to be set before being imported. if TYPE_CHECKING: try: import ucp except ImportError: pass else: ucp = None device_array = None pre_existing_cuda_context = False cuda_context_created = False _warning_suffix = ( "This is often the result of a CUDA-enabled library calling a CUDA runtime function before " "Dask-CUDA can spawn worker processes. Please make sure any such function calls don't happen " "at import time or in the global scope of a program." ) def _get_device_and_uuid_str(device_info: CudaDeviceInfo) -> str: return f"{device_info.device_index} ({str(device_info.uuid)})" def _warn_existing_cuda_context(device_info: CudaDeviceInfo, pid: int) -> None: device_uuid_str = _get_device_and_uuid_str(device_info) logger.warning( f"A CUDA context for device {device_uuid_str} already exists " f"on process ID {pid}. {_warning_suffix}" ) def _warn_cuda_context_wrong_device( device_info_expected: CudaDeviceInfo, device_info_actual: CudaDeviceInfo, pid: int ) -> None: expected_device_uuid_str = _get_device_and_uuid_str(device_info_expected) actual_device_uuid_str = _get_device_and_uuid_str(device_info_actual) logger.warning( f"Worker with process ID {pid} should have a CUDA context assigned to device " f"{expected_device_uuid_str}, but instead the CUDA context is on device " f"{actual_device_uuid_str}. {_warning_suffix}" ) def synchronize_stream(stream=0): import numba.cuda ctx = numba.cuda.current_context() cu_stream = numba.cuda.driver.drvapi.cu_stream(stream) stream = numba.cuda.driver.Stream(ctx, cu_stream, None) stream.synchronize() def init_once(): global ucp, device_array global ucx_create_endpoint, ucx_create_listener global pre_existing_cuda_context, cuda_context_created if ucp is not None: return # remove/process dask.ucx flags for valid ucx options ucx_config, ucx_environment = _prepare_ucx_config() # We ensure the CUDA context is created before initializing UCX. This can't # be safely handled externally because communications in Dask start before # preload scripts run. # Precedence: # 1. external environment # 2. ucx_config (high level settings passed to ucp.init) # 3. 
    ucx_tls = os.environ.get(
        "UCX_TLS",
        ucx_config.get("TLS", ucx_environment.get("UCX_TLS", "")),
    )
    if (
        dask.config.get("distributed.comm.ucx.create-cuda-context") is True
        # This is not foolproof, if UCX_TLS=all we might require CUDA
        # depending on configuration of UCX, but this is better than
        # nothing
        or ("cuda" in ucx_tls and "^cuda" not in ucx_tls)
    ):
        try:
            import numba.cuda
        except ImportError:
            raise ImportError(
                "CUDA support with UCX requires Numba for context management"
            )

        cuda_visible_device = get_device_index_and_uuid(
            os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
        )
        pre_existing_cuda_context = has_cuda_context()
        if pre_existing_cuda_context.has_context:
            _warn_existing_cuda_context(
                pre_existing_cuda_context.device_info, os.getpid()
            )

        numba.cuda.current_context()

        cuda_context_created = has_cuda_context()
        if (
            cuda_context_created.has_context
            and cuda_context_created.device_info.uuid != cuda_visible_device.uuid
        ):
            _warn_cuda_context_wrong_device(
                cuda_visible_device, cuda_context_created.device_info, os.getpid()
            )

    import ucp as _ucp

    ucp = _ucp

    with patch.dict(os.environ, ucx_environment):
        # We carefully ensure that ucx_environment only contains things
        # that don't override ucx_config or existing slots in the
        # environment, so the user's external environment can safely
        # override things here.
        ucp.init(options=ucx_config, env_takes_precedence=True)

    pool_size_str = dask.config.get("distributed.rmm.pool-size")

    # Find the function, `device_array()`, to use when allocating new CUDA arrays
    try:
        import rmm

        def device_array(n):
            return rmm.DeviceBuffer(size=n)

        if pool_size_str is not None:
            pool_size = parse_bytes(pool_size_str)
            rmm.reinitialize(
                pool_allocator=True, managed_memory=False, initial_pool_size=pool_size
            )
    except ImportError:
        try:
            import numba.cuda

            def numba_device_array(n):
                a = numba.cuda.device_array((n,), dtype="u1")
                weakref.finalize(a, numba.cuda.current_context)
                return a

            device_array = numba_device_array
        except ImportError:

            def device_array(n):
                raise RuntimeError(
                    "In order to send/recv CUDA arrays, Numba or RMM is required"
                )

        if pool_size_str is not None:
            logger.warning(
                "Initial RMM pool size defined, but RMM is not available. "
                "Please consider installing RMM or removing the pool size option."
            )


def _close_comm(ref):
    """Callback to close Dask Comm when UCX Endpoint closes or errors

    Parameters
    ----------
    ref: weak reference to a Dask UCX comm
    """
    comm = ref()
    if comm is not None:
        comm._closed = True


class UCX(Comm):
    """Comm object using UCP.

    Parameters
    ----------
    ep : ucp.Endpoint
        The UCP endpoint.
    address : str
        The address, prefixed with `ucx://`, to use.
    deserialize : bool, default True
        Whether to deserialize data in :meth:`distributed.protocol.loads`

    Notes
    -----
    The read-write cycle uses the following pattern:

    Each msg is serialized into a number of "data" frames. We prepend these
    real frames with two additional frames:

        1. is_gpus: Boolean indicators for whether each frame should be
           received into GPU memory. Packed in '?' format. Unpack with
           ``<nframes>?`` format.
        2. frame_size : Unsigned ints describing the size of each frame (in
           bytes) to receive. Packed in 'Q' format, so a length-0 frame is
           equivalent to an unsized frame. Unpacked with ``<nframes>Q``.

    The expected read cycle is:

    1. Read the frame describing if the connection is closing and the number
       of frames
    2. Read the frame describing whether each data frame is gpu-bound
    3. Read the frame describing whether each data frame is sized
    4. Read all the data frames.
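
    As a rough sketch (illustrative only; ``ep``, ``msg``, and
    ``header_bytes`` are stand-in names, not part of the class API), the
    metadata exchange performed by :meth:`write` and :meth:`read` is
    equivalent to::

        # writer side: close flag and frame count, then CUDA flags and sizes
        await ep.send(struct.pack("?Q", False, nframes))
        await ep.send(
            struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes)
        )

        # reader side: unpack the same two messages
        shutdown, nframes = struct.unpack("?Q", msg)
        header = struct.unpack(nframes * "?" + nframes * "Q", header_bytes)
        cuda_frames, sizes = header[:nframes], header[nframes:]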
""" def __init__( # type: ignore[no-untyped-def] self, ep, local_addr: str, peer_addr: str, deserialize: bool = True ): super().__init__(deserialize=deserialize) self._ep = ep if local_addr: assert local_addr.startswith("ucx") assert peer_addr.startswith("ucx") self._local_addr = local_addr self._peer_addr = peer_addr self.comm_flag = None # When the UCX endpoint closes or errors the registered callback # is called. if hasattr(self._ep, "set_close_callback"): ref = weakref.ref(self) self._ep.set_close_callback(functools.partial(_close_comm, ref)) self._closed = False self._has_close_callback = True else: self._has_close_callback = False logger.debug("UCX.__init__ %s", self) @property def local_address(self) -> str: return self._local_addr @property def peer_address(self) -> str: return self._peer_addr @property def same_host(self) -> bool: """Unlike in TCP, local_address can be blank""" return super().same_host if self._local_addr else False @log_errors async def write( self, msg: dict, serializers: Collection[str] | None = None, on_error: str = "message", ) -> int: if self.closed(): raise CommClosedError("Endpoint is closed -- unable to send message") if serializers is None: serializers = ("cuda", "dask", "pickle", "error") # msg can also be a list of dicts when sending batched messages frames = await to_frames( msg, serializers=serializers, on_error=on_error, allow_offload=self.allow_offload, ) nframes = len(frames) cuda_frames = tuple(hasattr(f, "__cuda_array_interface__") for f in frames) sizes = tuple(nbytes(f) for f in frames) cuda_send_frames, send_frames = zip( *( (is_cuda, each_frame) for is_cuda, each_frame in zip(cuda_frames, frames) if nbytes(each_frame) > 0 ) ) try: # Send meta data # Send close flag and number of frames (_Bool, int64) await self.ep.send(struct.pack("?Q", False, nframes)) # Send which frames are CUDA (bool) and # how large each frame is (uint64) await self.ep.send( struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes) ) # Send frames # It is necessary to first synchronize the default stream before start # sending We synchronize the default stream because UCX is not # stream-ordered and syncing the default stream will wait for other # non-blocking CUDA streams. Note this is only sufficient if the memory # being sent is not currently in use on non-blocking CUDA streams. if any(cuda_send_frames): synchronize_stream(0) for each_frame in send_frames: await self.ep.send(each_frame) return sum(sizes) except ucp.exceptions.UCXBaseException: self.abort() raise CommClosedError("While writing, the connection was closed") @log_errors async def read(self, deserializers=("cuda", "dask", "pickle", "error")): if deserializers is None: deserializers = ("cuda", "dask", "pickle", "error") try: # Recv meta data # Recv close flag and number of frames (_Bool, int64) msg = host_array(struct.calcsize("?Q")) await self.ep.recv(msg) (shutdown, nframes) = struct.unpack("?Q", msg) if shutdown: # The writer is closing the connection raise CommClosedError("Connection closed by writer") # Recv which frames are CUDA (bool) and # how large each frame is (uint64) header_fmt = nframes * "?" + nframes * "Q" header = host_array(struct.calcsize(header_fmt)) await self.ep.recv(header) header = struct.unpack(header_fmt, header) cuda_frames, sizes = header[:nframes], header[nframes:] except BaseException as e: # In addition to UCX exceptions, may be CancelledError or another # "low-level" exception. The only safe thing to do is to abort. 
    @log_errors
    async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
        if deserializers is None:
            deserializers = ("cuda", "dask", "pickle", "error")

        try:
            # Recv metadata: close flag and number of frames (_Bool, uint64)
            msg = host_array(struct.calcsize("?Q"))
            await self.ep.recv(msg)
            (shutdown, nframes) = struct.unpack("?Q", msg)

            if shutdown:  # The writer is closing the connection
                raise CommClosedError("Connection closed by writer")

            # Recv which frames are CUDA (bool) and
            # how large each frame is (uint64)
            header_fmt = nframes * "?" + nframes * "Q"
            header = host_array(struct.calcsize(header_fmt))
            await self.ep.recv(header)
            header = struct.unpack(header_fmt, header)
            cuda_frames, sizes = header[:nframes], header[nframes:]
        except BaseException as e:
            # In addition to UCX exceptions, this may be a CancelledError or
            # another "low-level" exception. The only safe thing to do is to
            # abort. (See also https://github.com/dask/distributed/pull/6574.)
            self.abort()
            raise CommClosedError(
                f"Connection closed by writer.\nInner exception: {e!r}"
            )
        else:
            # Recv frames
            frames = [
                device_array(each_size) if is_cuda else host_array(each_size)
                for is_cuda, each_size in zip(cuda_frames, sizes)
            ]
            cuda_recv_frames, recv_frames = zip(
                *(
                    (is_cuda, each_frame)
                    for is_cuda, each_frame in zip(cuda_frames, frames)
                    if nbytes(each_frame) > 0
                )
            )

            # It is necessary to first populate `frames` with CUDA arrays and
            # synchronize the default stream before we start receiving, to
            # ensure the buffers have been allocated
            if any(cuda_recv_frames):
                synchronize_stream(0)

            try:
                for each_frame in recv_frames:
                    await self.ep.recv(each_frame)
            except BaseException as e:
                # In addition to UCX exceptions, this may be a CancelledError
                # or another "low-level" exception. The only safe thing to do
                # is to abort. (See also
                # https://github.com/dask/distributed/pull/6574.)
                self.abort()
                raise CommClosedError(
                    f"Connection closed by writer.\nInner exception: {e!r}"
                )

            try:
                msg = await from_frames(
                    frames,
                    deserialize=self.deserialize,
                    deserializers=deserializers,
                    allow_offload=self.allow_offload,
                )
            except EOFError:
                # Frames possibly garbled or truncated by communication error
                self.abort()
                raise CommClosedError("Aborted stream on truncated data")
            return msg

    async def close(self):
        self._closed = True
        if self._ep is not None:
            try:
                await self.ep.send(struct.pack("?Q", True, 0))
            except (  # noqa: B030
                ucp.exceptions.UCXError,
                ucp.exceptions.UCXCloseError,
                ucp.exceptions.UCXCanceled,
            ) + (getattr(ucp.exceptions, "UCXConnectionReset", ()),):
                # If the other end is in the process of closing, UCX will
                # sometimes raise an `Input/output` error, which we can ignore.
                pass
            self.abort()
            self._ep = None

    def abort(self):
        self._closed = True
        if self._ep is not None:
            self._ep.abort()
            self._ep = None

    @property
    def ep(self):
        if self._ep is not None:
            return self._ep
        else:
            raise CommClosedError("UCX Endpoint is closed")

    def closed(self):
        if self._has_close_callback is True:
            # The self._closed flag is separate from the endpoint's lifetime:
            # even when the endpoint has closed or errored, there may be
            # messages still in its buffer to be received, even though sending
            # is not possible anymore.
            return self._closed
        else:
            return self._ep is None


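# A minimal end-to-end sketch of these comm classes (hypothetical address and
# message; requires the `ucp` package and a reachable UCX listener):
#
#     async def example():
#         comm = await UCXConnector().connect("10.0.0.1:8786")
#         await comm.write({"op": "ping"})
#         reply = await comm.read()
#         await comm.close()
#
# In practice these objects are created for you by distributed's comm
# registry, via the "ucx://" address prefix.

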
class UCXConnector(Connector):
    prefix = "ucx://"
    comm_class = UCX
    encrypted = False

    async def connect(
        self, address: str, deserialize: bool = True, **connection_args: Any
    ) -> UCX:
        logger.debug("UCXConnector.connect: %s", address)
        ip, port = parse_host_port(address)
        init_once()
        try:
            ep = await ucp.create_endpoint(ip, port)
        except ucp.exceptions.UCXBaseException:
            raise CommClosedError("Connection closed before handshake completed")
        return self.comm_class(
            ep,
            local_addr="",
            peer_addr=self.prefix + address,
            deserialize=deserialize,
        )


class UCXListener(BaseListener):
    prefix = UCXConnector.prefix
    comm_class = UCXConnector.comm_class
    encrypted = UCXConnector.encrypted

    def __init__(
        self,
        address: str,
        comm_handler: Callable[[UCX], Awaitable[None]] | None = None,
        deserialize: bool = False,
        allow_offload: bool = True,
        **connection_args: Any,
    ):
        super().__init__()
        if not address.startswith("ucx"):
            address = "ucx://" + address
        self.ip, self._input_port = parse_host_port(address, default_port=0)
        self.comm_handler = comm_handler
        self.deserialize = deserialize
        self.allow_offload = allow_offload
        self._ep = None  # type: ucp.Endpoint
        self.ucp_server = None
        self.connection_args = connection_args

    @property
    def port(self):
        return self.ucp_server.port

    @property
    def address(self):
        return "ucx://" + self.ip + ":" + str(self.port)

    async def start(self):
        async def serve_forever(client_ep):
            ucx = UCX(
                client_ep,
                local_addr=self.address,
                peer_addr=self.address,
                deserialize=self.deserialize,
            )
            ucx.allow_offload = self.allow_offload
            try:
                await self.on_connection(ucx)
            except CommClosedError:
                logger.debug("Connection closed before handshake completed")
                return
            if self.comm_handler:
                await self.comm_handler(ucx)

        init_once()
        self.ucp_server = ucp.create_listener(serve_forever, port=self._input_port)

    def stop(self):
        self.ucp_server = None

    def get_host_port(self):
        # TODO: TCP raises if this hasn't started yet.
        return self.ip, self.port

    @property
    def listen_address(self):
        return self.prefix + unparse_host_port(*self.get_host_port())

    @property
    def contact_address(self):
        host, port = self.get_host_port()
        host = ensure_concrete_host(host)  # TODO: ensure_concrete_host
        return self.prefix + unparse_host_port(host, port)

    @property
    def bound_address(self):
        # TODO: Does this become part of the base API? Kinda hazy, since
        # we exclude it for inproc.
        return self.get_host_port()


class UCXBackend(Backend):
    # I/O

    def get_connector(self):
        return UCXConnector()

    def get_listener(self, loc, handle_comm, deserialize, **connection_args):
        return UCXListener(loc, handle_comm, deserialize, **connection_args)

    # Address handling
    # This duplicates BaseTCPBackend

    def get_address_host(self, loc):
        return parse_host_port(loc)[0]

    def get_address_host_port(self, loc):
        return parse_host_port(loc)

    def resolve_address(self, loc):
        host, port = parse_host_port(loc)
        return unparse_host_port(ensure_ip(host), port)

    def get_local_address_for(self, loc):
        host, port = parse_host_port(loc)
        host = ensure_ip(host)
        if ":" in host:
            local_host = get_ipv6(host)
        else:
            local_host = get_ip(host)
        return unparse_host_port(local_host, None)


backends["ucx"] = UCXBackend()


def _prepare_ucx_config():
    """Translate dask config options to appropriate UCX config options

    Returns
    -------
    tuple
        Options suitable for passing to ``ucp.init`` and additional UCX options
        that will be inserted directly into the environment while calling
        ``ucp.init``.
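
    Examples
    --------
    A rough illustration of the translation (values are hypothetical)::

        dask.config.set(
            {
                "distributed.comm.ucx.infiniband": True,
                "distributed.comm.ucx.cuda-copy": True,
                "distributed.comm.ucx.environment": {"memtype-cache": "n"},
            }
        )
        high, env = _prepare_ucx_config()
        # high == {"TLS": "rc,tcp,cuda_copy", "SOCKADDR_TLS_PRIORITY": "tcp"}
        # env == {"UCX_MEMTYPE_CACHE": "n"}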
""" # configuration of UCX can happen in two ways: # 1) high level on/off flags which correspond to UCX configuration # 2) explicitly defined UCX configuration flags in distributed.comm.ucx.environment # High-level settings in (1) are preferred to settings in (2) # Settings in the external environment override both high_level_options = {} # if any of the high level flags are set, as long as they are not Null/None, # we assume we should configure basic TLS settings for UCX, otherwise we # leave UCX to its default configuration if any( [ dask.config.get("distributed.comm.ucx.tcp"), dask.config.get("distributed.comm.ucx.nvlink"), dask.config.get("distributed.comm.ucx.infiniband"), ] ): if dask.config.get("distributed.comm.ucx.rdmacm"): tls = "tcp" tls_priority = "rdmacm" else: tls = "tcp" tls_priority = "tcp" # CUDA COPY can optionally be used with ucx -- we rely on the user # to define when messages will include CUDA objects. Note: # defining only the Infiniband flag will not enable cuda_copy if any( [ dask.config.get("distributed.comm.ucx.nvlink"), dask.config.get("distributed.comm.ucx.cuda-copy"), ] ): tls = tls + ",cuda_copy" if dask.config.get("distributed.comm.ucx.infiniband"): tls = "rc," + tls if dask.config.get("distributed.comm.ucx.nvlink"): tls = tls + ",cuda_ipc" high_level_options = {"TLS": tls, "SOCKADDR_TLS_PRIORITY": tls_priority} # Pick up any other ucx environment settings environment_options = {} for k, v in dask.config.get("distributed.comm.ucx.environment", {}).items(): # {"some-name": value} is translated to {"UCX_SOME_NAME": value} key = "_".join(map(str.upper, ("UCX", *k.split("-")))) if (hl_key := key[4:]) in high_level_options: logger.warning( f"Ignoring {k}={v} ({key=}) in ucx.environment, " f"preferring {hl_key}={high_level_options[hl_key]} " "from high level options" ) elif key in os.environ: # This is only info because setting UCX configuration via # environment variables is a reasonably common approach logger.info( f"Ignoring {k}={v} ({key=}) in ucx.environment, " f"preferring {key}={os.environ[key]} from external environment" ) else: environment_options[key] = v return high_level_options, environment_options