Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions mlpstorage_py/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,25 @@ def _restore_signal_handlers(self):
self._original_handlers = {}


def _mpi_params_contain_flag(params: Optional[List[str]], flag: str) -> bool:
"""Return True if ``params`` already specifies the MPI ``flag``.

Matches any of the surface forms a user may pass through ``--mpi-params``:
``--flag``, ``-flag``, ``--flag=value``, ``-flag=value``. ``flag`` is the
bare name without leading dashes (e.g. ``"bind-to"``).
"""
if not params:
return False
candidates = {f"--{flag}", f"-{flag}"}
for tok in params:
if not isinstance(tok, str):
continue
head = tok.split("=", 1)[0]
if head in candidates:
return True
return False


def generate_mpi_prefix_cmd(
mpi_cmd: str,
hosts: List[str],
Expand Down Expand Up @@ -567,13 +586,22 @@ def generate_mpi_prefix_cmd(
host_part = host.split(':')[0] if ':' in host else host
unique_hosts.add(host_part)

if len(unique_hosts) > 1:
# Multi-host: prioritize even distribution across nodes
prefix += " --bind-to none --map-by node"
is_multi_host = len(unique_hosts) > 1

# OpenMPI rejects duplicate --bind-to / --map-by occurrences, so suppress
# the default for whichever of those flags the user supplied via --mpi-params.
user_set_bind_to = _mpi_params_contain_flag(params, "bind-to")
user_set_map_by = _mpi_params_contain_flag(params, "map-by")

if not user_set_bind_to:
prefix += " --bind-to none"
if not user_set_map_by:
prefix += " --map-by node" if is_multi_host else " --map-by socket"

if is_multi_host:
logger.info("MPI BTL transport: auto (multi-host run; transport managed by network fabric)")
else:
# Single-host: optimize for NUMA domains
prefix += " --bind-to none --map-by socket"
if mpi_btl == "vader":
prefix += " --mca btl vader,self"
logger.info("MPI BTL transport: vader (POSIX shared-memory)")
Expand Down
105 changes: 105 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,111 @@ def test_processes_per_node_position_before_bind_to(self, mock_logger):
bindto_idx = result.index('--bind-to')
assert npernode_idx < bindto_idx

def test_user_bind_to_in_mpi_params_suppresses_default(self, mock_logger):
"""User-supplied --bind-to in --mpi-params suppresses the default --bind-to none.

OpenMPI rejects duplicate occurrences of --bind-to, so emitting both the
default and the user value would break the command.
"""
result = generate_mpi_prefix_cmd(
mpi_cmd=MPIRUN,
hosts=['host1'],
num_processes=4,
oversubscribe=False,
allow_run_as_root=False,
params=['--bind-to', 'core'],
logger=mock_logger,
)
assert '--bind-to none' not in result
assert '--bind-to core' in result
# --map-by default still applied since user didn't override it
assert '--map-by socket' in result

def test_user_map_by_in_mpi_params_suppresses_default(self, mock_logger):
"""User-supplied --map-by in --mpi-params suppresses the default --map-by."""
result = generate_mpi_prefix_cmd(
mpi_cmd=MPIRUN,
hosts=['host1'],
num_processes=4,
oversubscribe=False,
allow_run_as_root=False,
params=['--map-by', 'core'],
logger=mock_logger,
)
assert '--map-by socket' not in result
assert '--map-by node' not in result
assert '--map-by core' in result
# --bind-to default still applied since user didn't override it
assert '--bind-to none' in result

def test_user_override_both_bind_and_map(self, mock_logger):
"""User overriding both --bind-to and --map-by suppresses both defaults."""
result = generate_mpi_prefix_cmd(
mpi_cmd=MPIRUN,
hosts=['host1', 'host2'],
num_processes=8,
oversubscribe=False,
allow_run_as_root=False,
params=['--bind-to', 'socket', '--map-by', 'core'],
logger=mock_logger,
)
# No defaults emitted
assert '--bind-to none' not in result
assert '--map-by node' not in result
# Only the user values present (each --bind-to / --map-by appears once)
assert result.count('--bind-to') == 1
assert result.count('--map-by') == 1
assert '--bind-to socket' in result
assert '--map-by core' in result

def test_user_override_with_equals_form(self, mock_logger):
"""The --flag=value token form also suppresses defaults."""
result = generate_mpi_prefix_cmd(
mpi_cmd=MPIRUN,
hosts=['host1'],
num_processes=4,
oversubscribe=False,
allow_run_as_root=False,
params=['--bind-to=core', '--map-by=numa'],
logger=mock_logger,
)
assert '--bind-to none' not in result
assert '--map-by socket' not in result
assert result.count('--bind-to') == 1
assert result.count('--map-by') == 1

def test_user_override_with_single_dash_form(self, mock_logger):
"""Single-dash ``-bind-to`` / ``-map-by`` forms also suppress defaults."""
result = generate_mpi_prefix_cmd(
mpi_cmd=MPIRUN,
hosts=['host1'],
num_processes=4,
oversubscribe=False,
allow_run_as_root=False,
params=['-bind-to', 'core', '-map-by', 'socket'],
logger=mock_logger,
)
assert '--bind-to none' not in result
assert '--map-by socket' not in result.replace('-map-by socket', '')
# Two map-by tokens would be a bug; we only emit the user's one
# (which uses single-dash form here)
assert result.count('--map-by') == 0
assert result.count('--bind-to') == 0

def test_unrelated_mpi_params_do_not_suppress_defaults(self, mock_logger):
"""Unrelated --mpi-params flags leave defaults intact."""
result = generate_mpi_prefix_cmd(
mpi_cmd=MPIRUN,
hosts=['host1'],
num_processes=4,
oversubscribe=False,
allow_run_as_root=False,
params=['--mca', 'btl', 'tcp,self', '-x', 'FOO=bar'],
logger=mock_logger,
)
assert '--bind-to none' in result
assert '--map-by socket' in result


class TestCommandExecutor:
"""Tests for CommandExecutor class."""
Expand Down
Loading