Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ test = [
"Pympler==1.1",
"scipy==1.17.1",
# HTTP server and client for mock server fixture
"aiohttp==3.14.0",
"aiohttp==3.14.1",
# Plotting for benchmark sweep mode
"matplotlib==3.10.8",
# Property-based testing (CLI fuzz)
Expand Down
13 changes: 8 additions & 5 deletions src/inference_endpoint/commands/benchmark/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,9 @@ def _load_datasets(
acc_cfg.accuracy_config.extras or {},
)
)
ds.load(
api_type=config.endpoint_config.api_type, model_params=config.model_params
)
# Per-dataset max_new_tokens override (falls back to global model_params).
acc_model_params = acc_cfg.get_model_params(config.model_params)
ds.load(api_type=config.endpoint_config.api_type, model_params=acc_model_params)
logger.info(f"Loaded {ds} - {ds.num_samples()} samples")

if not accuracy_cfgs:
Expand All @@ -321,9 +321,12 @@ def _load_datasets(
raise InputValidationError("Multiple performance datasets not supported")

try:
dataloader = DataLoaderFactory.create_loader(performance_cfgs[0])
perf_cfg = performance_cfgs[0]
# Per-dataset max_new_tokens override (falls back to global model_params).
perf_model_params = perf_cfg.get_model_params(config.model_params)
dataloader = DataLoaderFactory.create_loader(perf_cfg)
dataloader.load(
api_type=config.endpoint_config.api_type, model_params=config.model_params
api_type=config.endpoint_config.api_type, model_params=perf_model_params
)
logger.info(f"Loaded {dataloader.num_samples()} samples")
except FileNotFoundError as e:
Expand Down
22 changes: 22 additions & 0 deletions src/inference_endpoint/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,17 @@ class Dataset(BaseModel):
] = None
format: str | None = Field(None, description="Dataset format (auto-detected)")
samples: int | None = Field(None, gt=0, description="Number of samples to use")
max_new_tokens: int | None = Field(
None,
gt=0,
description=(
"Per-dataset override of model_params.max_new_tokens (sent as the "
"per-request max_tokens). Lets a performance dataset use a small cap "
"(to avoid server-side KV over-reservation/overload at high concurrency) "
"while accuracy datasets use a larger cap (to avoid truncating long "
"reasoning output). Falls back to model_params.max_new_tokens when unset."
),
)
Comment thread
roborluo marked this conversation as resolved.
eval_method: EvalMethod | None = Field(
None, description="Accuracy evaluation method"
)
Expand All @@ -322,6 +333,17 @@ def _auto_derive_name(self) -> Self:
object.__setattr__(self, "name", Path(self.path).stem)
return self

def get_model_params(self, model_params: ModelParams) -> ModelParams:
"""Apply this dataset's per-dataset max_new_tokens override.

Returns ``model_params`` unchanged when the dataset does not set a
max_new_tokens; otherwise returns a copy with max_new_tokens replaced
by the per-dataset value.
"""
if self.max_new_tokens is None:
return model_params
return model_params.model_copy(update={"max_new_tokens": self.max_new_tokens})


class AccuracyConfig(BaseModel):
"""Accuracy configuration.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ model_params:
presence_penalty: null # Presence penalty
frequency_penalty: null # Frequency penalty
chat_template_kwargs: null # Per-request chat-template kwargs forwarded to compatible servers.
max_new_tokens: 1024 # Max output tokens
max_new_tokens: 1024
osl_distribution: null # Output sequence length distribution
streaming: 'on' # Streaming mode: auto/on/off | options: auto, on, off
tokenizer_name: null # HF repo ID or local path for the tokenizer. Overrides model name for client-side token metrics (ISL/OSL/TPOT).
Expand All @@ -22,6 +22,7 @@ datasets: # Dataset configs
path: '<DATASET_PATH eg: tests/assets/datasets/dummy_1k.jsonl>' # Dataset file path
format: null # Dataset format (auto-detected)
samples: null # Number of samples to use
max_new_tokens: null
eval_method: null
parser: # Column remapping: {prompt: <col>, system: <col>}
prompt: text_input
Expand All @@ -32,6 +33,7 @@ datasets: # Dataset configs
path: '<DATASET_PATH eg: tests/assets/datasets/ds_samples.jsonl>' # Dataset file path
format: null # Dataset format (auto-detected)
samples: null # Number of samples to use
max_new_tokens: null
eval_method: exact_match # Accuracy evaluation method | options: exact_match, contains, judge
parser: # Column remapping: {prompt: <col>, system: <col>}
prompt: question
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ model_params:
presence_penalty: null # Presence penalty
frequency_penalty: null # Frequency penalty
chat_template_kwargs: null # Per-request chat-template kwargs forwarded to compatible servers.
max_new_tokens: 1024 # Max output tokens
max_new_tokens: 1024
osl_distribution: null # Output sequence length distribution
streaming: 'off' # Streaming mode: auto/on/off | options: auto, on, off
tokenizer_name: null # HF repo ID or local path for the tokenizer. Overrides model name for client-side token metrics (ISL/OSL/TPOT).
Expand All @@ -22,6 +22,7 @@ datasets: # Dataset configs
path: '<DATASET_PATH eg: tests/assets/datasets/dummy_1k.jsonl>' # Dataset file path
format: null # Dataset format (auto-detected)
samples: null # Number of samples to use
max_new_tokens: null
eval_method: null
parser: # Column remapping: {prompt: <col>, system: <col>}
prompt: text_input
Expand All @@ -32,6 +33,7 @@ datasets: # Dataset configs
path: '<DATASET_PATH eg: tests/assets/datasets/ds_samples.jsonl>' # Dataset file path
format: null # Dataset format (auto-detected)
samples: null # Number of samples to use
max_new_tokens: null
eval_method: exact_match # Accuracy evaluation method | options: exact_match, contains, judge
parser: # Column remapping: {prompt: <col>, system: <col>}
prompt: question
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ model_params:
presence_penalty: null # Presence penalty
frequency_penalty: null # Frequency penalty
chat_template_kwargs: null # Per-request chat-template kwargs forwarded to compatible servers.
max_new_tokens: 1024 # Max output tokens
max_new_tokens: 1024
osl_distribution: null # Output sequence length distribution
streaming: 'on' # Streaming mode: auto/on/off | options: auto, on, off
tokenizer_name: null # HF repo ID or local path for the tokenizer. Overrides model name for client-side token metrics (ISL/OSL/TPOT).
Expand All @@ -22,6 +22,7 @@ datasets: # Dataset configs
path: '<DATASET_PATH eg: tests/assets/datasets/dummy_1k.jsonl>' # Dataset file path
format: null # Dataset format (auto-detected)
samples: null # Number of samples to use
max_new_tokens: null
eval_method: null
parser: # Column remapping: {prompt: <col>, system: <col>}
prompt: text_input
Expand All @@ -32,6 +33,7 @@ datasets: # Dataset configs
path: '<DATASET_PATH eg: tests/assets/datasets/ds_samples.jsonl>' # Dataset file path
format: null # Dataset format (auto-detected)
samples: null # Number of samples to use
max_new_tokens: null
eval_method: exact_match # Accuracy evaluation method | options: exact_match, contains, judge
parser: # Column remapping: {prompt: <col>, system: <col>}
prompt: question
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/config/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,59 @@ def test_auto_derive_name(self):
ds = Dataset(path="datasets/my_data.jsonl")
assert ds.name == "my_data"

@pytest.mark.unit
def test_max_new_tokens_defaults_none(self):
ds = Dataset(name="perf", type=DatasetType.PERFORMANCE, path="data.jsonl")
assert ds.max_new_tokens is None

@pytest.mark.unit
def test_per_dataset_max_new_tokens_override(self):
ds = Dataset(
name="aime25",
type=DatasetType.ACCURACY,
path="aime25.jsonl",
eval_method=EvalMethod.EXACT_MATCH,
max_new_tokens=32768,
)
assert ds.max_new_tokens == 32768

@pytest.mark.unit
def test_max_new_tokens_rejects_non_positive(self):
with pytest.raises(ValueError, match="greater than 0"):
Dataset(
name="perf",
type=DatasetType.PERFORMANCE,
path="data.jsonl",
max_new_tokens=0,
)
Comment thread
roborluo marked this conversation as resolved.

@pytest.mark.unit
def test_get_model_params_falls_back_when_unset(self):
ds = Dataset(name="perf", type=DatasetType.PERFORMANCE, path="data.jsonl")
base = ModelParams(name="m", max_new_tokens=1024)
result = ds.get_model_params(base)
# No per-dataset override -> returns the global params unchanged.
assert result is base
assert result.max_new_tokens == 1024

@pytest.mark.unit
def test_get_model_params_applies_override(self):
ds = Dataset(
name="aime25",
type=DatasetType.ACCURACY,
path="aime25.jsonl",
eval_method=EvalMethod.EXACT_MATCH,
max_new_tokens=32768,
)
base = ModelParams(name="m", max_new_tokens=1024, temperature=0.7)
result = ds.get_model_params(base)
# Override replaces max_new_tokens but leaves other params and the
# original (frozen) instance untouched.
assert result is not base
assert result.max_new_tokens == 32768
assert result.temperature == 0.7
assert base.max_new_tokens == 1024


class TestBenchmarkConfig:
@pytest.mark.unit
Expand Down
Loading
Loading