diff --git a/README.md b/README.md index 8465c81..86c937b 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ The repository profile is the source of truth for reproducing experiments end-to ## End-to-End Pipeline ```text -Stage 1 preprocess metadata + zscores +Stage 1 normalize dataset inputs (PV1/CWP/BKP) to a shared training contract Stage 2 generate ESM-2 per-residue embeddings Stage 3 build residue-level label shards Stage 4 train FFNN (seeded or ensemble-kfold, DDP-aware) @@ -60,7 +60,91 @@ Stage 7 evaluate residue metrics (+ optional Cocci peptide compare) ## Stage Reference -### Stage 1: Preprocess +### Stage 1: Multi-Dataset Prepare (PV1/CWP/BKP) + +**CLI:** `pepseqpred-prepare-dataset` (`src/pepseqpred/apps/prepare_dataset_cli.py`) + +This stage is the recommended entrypoint when training on one or more of: + +- PV1 (human virome) +- CWP/Cocci (fungal) +- BKP (bacterial) + +It normalizes source-specific metadata and FASTA headers into a shared PV1-compatible contract so downstream embedding, label generation, and training CLIs can be reused unchanged. + +**Core module** + +- `src/pepseqpred/core/preprocess/preparedataset.py` + +**Required output contract per dataset** + +- `prepared_targets.fasta` +- `prepared_labels_metadata.tsv` +- `prepared_embedding_metadata.tsv` +- `prepare_summary.json` + +**PV1 inputs and command** + +- metadata TSV +- z-score TSV +- protein FASTA + +```bash +pepseqpred-prepare-dataset \ + localdata/PV1/PV1_meta_2020-11-23_cleaned.tsv \ + localdata/PV1/prepared \ + --dataset-kind pv1 \ + --protein-fasta localdata/PV1/PV1_targets.fasta \ + --z-file localdata/PV1/PV1_zscores.tsv +``` + +**CWP/Cocci inputs and command** + +- metadata TSV +- protein FASTA +- reactive code list TSV +- non-reactive code list TSV + +```bash +pepseqpred-prepare-dataset \ + localdata/Cocci/CWP_metadata.tsv \ + localdata/Cocci/prepared \ + --dataset-kind cwp \ + --protein-fasta localdata/Cocci/CWP_targets.faa \ + --reactive-codes localdata/Cocci/CWP_reactive_Z20N4.tsv \ + --nonreactive-codes localdata/Cocci/CWP_nonreactive_Z20N4.tsv +``` + +**BKP inputs and command** + +- metadata TSV +- protein FASTA +- reactive code list TSV +- non-reactive code list TSV + +```bash +pepseqpred-prepare-dataset \ + localdata/BKP/BKP_metadata.tsv \ + localdata/BKP/prepared \ + --dataset-kind bkp \ + --protein-fasta localdata/BKP/BKP.faa \ + --reactive-codes localdata/BKP/BKP_reactive_Z20N4.tsv \ + --nonreactive-codes localdata/BKP/BKP_nonreactive_Z20N4.tsv +``` + +**Dataset-specific grouping used for leakage-aware splitting (`--split-type id-family`)** + +- PV1: family from PV1 `OXX` +- CWP/Cocci: `Cluster50ID` mapped to deterministic numeric IDs +- BKP: `reClusterID_70` mapped to deterministic numeric IDs + +**Next stages after prepare** + +- run `pepseqpred-esm` with `--embedding-key-mode id-family` and each dataset's `prepared_embedding_metadata.tsv` +- run `pepseqpred-labels` with `--embedding-key-delim -` +- train with `--split-type id-family` + +### Stage 1 (Legacy): PV1 Z-Score Preprocess **CLI:** `pepseqpred-preprocess` (`src/pepseqpred/apps/preprocess_cli.py`) @@ -177,6 +261,54 @@ pepseqpred-train-ffnn \ --results-csv localdata/models/ffnn_smoke/runs.csv ``` +**Submit one SLURM training job with multiple datasets (PV1 + CWP + BKP)** + +`scripts/hpc/trainffnn.sh` accepts multiple embedding directories and multiple label shards in one call: + +- all embedding dirs first +- separator `--` +- all label shard `.pt` files after `--` + +```bash +# Example: use per-dataset shard outputs 
together in one training run +EMB_DIRS=( + /scratch/$USER/esm2/pv1/artifacts/pts/shard_000 + /scratch/$USER/esm2/pv1/artifacts/pts/shard_001 + /scratch/$USER/esm2/pv1/artifacts/pts/shard_002 + /scratch/$USER/esm2/pv1/artifacts/pts/shard_003 + /scratch/$USER/esm2/cwp/artifacts/pts/shard_000 + /scratch/$USER/esm2/cwp/artifacts/pts/shard_001 + /scratch/$USER/esm2/cwp/artifacts/pts/shard_002 + /scratch/$USER/esm2/cwp/artifacts/pts/shard_003 + /scratch/$USER/esm2/bkp/artifacts/pts/shard_000 + /scratch/$USER/esm2/bkp/artifacts/pts/shard_001 + /scratch/$USER/esm2/bkp/artifacts/pts/shard_002 + /scratch/$USER/esm2/bkp/artifacts/pts/shard_003 +) + +LABEL_SHARDS=( + /scratch/$USER/labels/pv1/labels_shard_000.pt + /scratch/$USER/labels/pv1/labels_shard_001.pt + /scratch/$USER/labels/pv1/labels_shard_002.pt + /scratch/$USER/labels/pv1/labels_shard_003.pt + /scratch/$USER/labels/cwp/labels_shard_000.pt + /scratch/$USER/labels/cwp/labels_shard_001.pt + /scratch/$USER/labels/cwp/labels_shard_002.pt + /scratch/$USER/labels/cwp/labels_shard_003.pt + /scratch/$USER/labels/bkp/labels_shard_000.pt + /scratch/$USER/labels/bkp/labels_shard_001.pt + /scratch/$USER/labels/bkp/labels_shard_002.pt + /scratch/$USER/labels/bkp/labels_shard_003.pt +) + +sbatch trainffnn.sh "${EMB_DIRS[@]}" -- "${LABEL_SHARDS[@]}" +``` + +Notes: + +- Keep `SPLIT_TYPE=id-family` for family-aware leakage control across PV1/CWP/BKP. +- Protein IDs should be globally unique across all provided label shards/embedding dirs. + **Outputs** - run checkpoint(s), usually `fully_connected.pt` @@ -337,6 +469,7 @@ Bundled pretrained registry currently includes: | CLI | File | Purpose | | --- | --- | --- | +| `pepseqpred-prepare-dataset` | `apps/prepare_dataset_cli.py` | normalize PV1/CWP/BKP into shared training contract | | `pepseqpred-preprocess` | `apps/preprocess_cli.py` | metadata + z-score preprocessing | | `pepseqpred-esm` | `apps/esm_cli.py` | ESM-2 embedding generation | | `pepseqpred-labels` | `apps/labels_cli.py` | residue label shard generation | diff --git a/docs/extra/pv1_cwp_bkp_merge_split_and_pos_weight.md b/docs/extra/pv1_cwp_bkp_merge_split_and_pos_weight.md new file mode 100644 index 0000000..66cbfd8 --- /dev/null +++ b/docs/extra/pv1_cwp_bkp_merge_split_and_pos_weight.md @@ -0,0 +1,128 @@ +# PV1 + CWP + BKP Multi-Source Training/Validation and Positive Weight + +This document explains how PV1, CWP, and BKP datasets are normalized separately and then used together for model training/validation across one or more DDP ranks. It also explains how the positive class weight is calculated at label build and train time. + +## 1) Per-dataset normalization before multi-source training + +Each source dataset is first normalized independently with `pepseqpred-prepare-dataset` into the same 4-file contract: + +- `prepared_targets.fasta` +- `prepared_labels_metadata.tsv` +- `prepared_embedding_metadata.tsv` +- `prepare_summary.json` + +### PV1 normalization + +- Uses existing PV1 preprocessing (`preprocess_pv1`) to derive `Def epitope`, `Uncertain`, `Not epitope`. +- Parses family from PV1 fullname `OXX` (last comma-delimited token). +- Uses PV1 FASTA to validate protein IDs and alignment bounds. +- Group/family is the parsed PV1 family (numeric). + +### CWP normalization + +- Keeps only `CodeName` rows listed in `--reactive-codes` or `--nonreactive-codes`. 
+- Label mapping:
+  - reactive code -> `Def epitope=1`, `Not epitope=0`, `Uncertain=0`
+  - nonreactive code -> `Def epitope=0`, `Not epitope=1`, `Uncertain=0`
+- Uses `SequenceAccession` as normalized `ProteinID`.
+- Uses `Cluster50ID` as group token.
+- Resolves align columns with fallbacks (`StartIndex/AlignStart/...`, `StopIndex/AlignStop/...`).
+- Builds deterministic numeric group IDs from sorted unique `Cluster50ID` values with an offset.
+  - default offset in CLI: `100000000`
+
+### BKP normalization
+
+- Same reactive/nonreactive mapping logic as CWP.
+- Uses `SequenceAccession` as normalized `ProteinID`.
+- Uses `reClusterID_70` as group token.
+- Resolves align columns with BKP-priority fallbacks (`alignStart`, `alignStop`, then others).
+- Deterministic numeric group IDs from sorted unique `reClusterID_70` values with an offset.
+  - default offset in CLI: `200000000`
+
+### Shared normalized fullname format
+
+After normalization, all datasets are rewritten to PV1-style fullnames:
+
+`ID=<protein_id> AC=<protein_id> OXX=0,0,0,<group_id>`
+
+So downstream tools can parse a single `ID + family` pattern uniformly.
+
+## 2) How PV1/CWP/BKP are used together for training/validation
+
+There is currently no dedicated "multi-source merge" CLI; the datasets remain separate sources. The supported way to use them together, exercised by the integration test, is:
+
+1. Concatenate FASTA records from each dataset's `prepared_targets.fasta` into one combined FASTA.
+2. Row-concatenate all `prepared_labels_metadata.tsv` files.
+3. Row-concatenate all `prepared_embedding_metadata.tsv` files, then de-duplicate on `["Name", "Family"]`.
+
+That is the behavior exercised in `tests/integration/test_prepare_dataset_multisource_pipeline.py` and is used to form shared inputs for embedding, label generation, and training. Alternatively, embeddings and labels can be generated per dataset, and training can be run by passing in all per-dataset files as outlined [here](../README.md#stage-4-train-ffnn).
+
+## 3) How training/validation partitioning is done
+
+### 3.1 Base protein universe
+
+`train_ffnn_cli` builds a `ProteinDataset` from the provided embedding dirs + label shards, then uses:
+
+- `protein_ids = intersection(embedding_index IDs, label_index IDs)`
+
+So split candidates are only proteins that exist in both embeddings and labels.
+
+### 3.2 `split_type=id-family` (default)
+
+- Family is parsed from the embedding filename stem (`<id>-<family>.pt`).
+- If family is missing for an ID, it is treated as a singleton group:
+  - `__missing_family__:<id>`
+- The global split then uses grouped split functions so a family/group cannot appear in both train and validation.
+
+### Seeded mode (`--train-mode seeded`)
+
+- Uses `split_ids_grouped(ids, val_frac, split_seed, family_groups)`.
+- Target val size is `floor(len(ids) * val_frac)`, but the exact fraction can differ because whole groups move together.
+- A leakage check is run after splitting and raises if any family overlaps.
+
+### Ensemble-kfold mode (`--train-mode ensemble-kfold`)
+
+- `--val-frac` is ignored.
+- Uses grouped k-fold (`build_grouped_kfold_splits`) when `split_type=id-family`.
+- Groups/families are assigned to folds intact, then each fold is one validation set.
+- Leakage check is run per fold.
+
+### 3.3 `split_type=id`
+
+- No family grouping; IDs are split directly.
+- Seeded mode uses `split_ids` (see the sketch after this list):
+  - shuffle IDs with `split_seed`
+  - validation = first `floor(N * val_frac)` IDs
+  - training = remaining IDs
+- Ensemble-kfold mode uses plain `build_kfold_splits`.
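+
+A minimal sketch of the seeded id-level split and the post-split leakage check described above, assuming they behave exactly as the bullets state; `split_ids_sketch` and `assert_no_group_overlap` are illustrative names, not the repository's `split_ids`/`split_ids_grouped` implementations:
+
+```python
+import math
+import random
+from typing import Dict, Iterable, List, Tuple
+
+
+def split_ids_sketch(ids: Iterable[str], val_frac: float, seed: int) -> Tuple[List[str], List[str]]:
+    """Shuffle ids with a fixed seed; the first floor(N * val_frac) ids become validation."""
+    shuffled = sorted(ids)                     # deterministic starting order for the sketch
+    random.Random(seed).shuffle(shuffled)      # seeded shuffle
+    n_val = math.floor(len(shuffled) * val_frac)
+    return shuffled[n_val:], shuffled[:n_val]  # (train_ids, val_ids)
+
+
+def assert_no_group_overlap(train_ids: List[str], val_ids: List[str], groups: Dict[str, str]) -> None:
+    """Leakage check in the spirit of the grouped split: no family/group on both sides."""
+    train_groups = {groups[pid] for pid in train_ids if pid in groups}
+    val_groups = {groups[pid] for pid in val_ids if pid in groups}
+    overlap = train_groups & val_groups
+    if overlap:
+        raise ValueError(f"family leakage between train and val: {sorted(overlap)[:5]}")
+```
+
+In the grouped (`id-family`) path the same idea applies, except whole families are moved between train and validation as units before the leakage check runs.
+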
+ +### 3.4 DDP rank partitioning after global train/val split + +For multi-rank training, each run's already-determined global train/val ID lists are partitioned across ranks using `partition_ids_weighted`: + +- greedy load balance by estimated embedding file size +- optional grouping by label-shard path for locality +- train partition enforces non-empty per rank + +This is a rank-level sharding step, not a new train/val split. + +## 4) Positive class weight calculation + +There are two connected pieces: + +1. Label build time (`pepseqpred-labels --calc-pos-weight`) + - Each label shard writes: + - `class_stats.pos_count` + - `class_stats.neg_count` + - `class_stats.pos_weight` + - Formula: + - `pos_weight = neg_count / max(1, pos_count)` + - For 3-column labels `[Def epitope, Uncertain, Not epitope]`, counts only include residues where `Uncertain == 0`. + +2. Train time (`pepseqpred-train-ffnn`) + - If `--pos-weight` is provided, that value is used directly. + - Otherwise it reads `class_stats` from all provided label shards and recomputes: + - `total_neg / max(1, total_pos)` + - That scalar is passed to `BCEWithLogitsLoss(pos_weight=...)`. + +Note: automatic train-time `pos_weight` uses shard-level totals, not run-specific train-only IDs. diff --git a/pyproject.toml b/pyproject.toml index d35c9f5..e685f2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pepseqpred" -version = "1.0.0" +version = "1.1.0" description = "Residue-level epitope prediction pipeline for peptide/protein workflows." readme = "README.pypi.md" requires-python = ">=3.12" @@ -72,6 +72,7 @@ pepseqpred-esm = "pepseqpred.apps.esm_cli:main" pepseqpred-labels = "pepseqpred.apps.labels_cli:main" pepseqpred-predict = "pepseqpred.apps.prediction_cli:main" pepseqpred-preprocess = "pepseqpred.apps.preprocess_cli:main" +pepseqpred-prepare-dataset = "pepseqpred.apps.prepare_dataset_cli:main" pepseqpred-eval-ffnn = "pepseqpred.apps.evaluate_ffnn_cli:main" pepseqpred-train-ffnn = "pepseqpred.apps.train_ffnn_cli:main" pepseqpred-train-ffnn-optuna = "pepseqpred.apps.train_ffnn_optuna_cli:main" diff --git a/src/pepseqpred/apps/prepare_dataset_cli.py b/src/pepseqpred/apps/prepare_dataset_cli.py new file mode 100644 index 0000000..bbba5d4 --- /dev/null +++ b/src/pepseqpred/apps/prepare_dataset_cli.py @@ -0,0 +1,210 @@ +"""prepare_dataset_cli.py + +Normalize PV1/CWP/BKP sources into a shared PV1-compatible training contract. +""" +import argparse +import time +from pathlib import Path +from pepseqpred.core.io.logger import setup_logger +from pepseqpred.core.preprocess.preparedataset import prepare_dataset + + +def main() -> None: + """Parse arguments and run dataset normalization.""" + t0 = time.perf_counter() + parser = argparse.ArgumentParser( + description=( + "Prepare dataset-specific metadata/labels/targets into a PV1-compatible " + "contract for embedding, label generation, and training." + ) + ) + parser.add_argument( + "meta_file", + type=Path, + help="Path to metadata TSV source file." + ) + parser.add_argument( + "output_dir", + type=Path, + help="Directory to write prepared outputs." + ) + parser.add_argument( + "--dataset-kind", + action="store", + dest="dataset_kind", + type=str, + choices=["pv1", "cwp", "bkp"], + required=True, + help="Dataset source kind." + ) + parser.add_argument( + "--protein-fasta", + action="store", + dest="protein_fasta", + type=Path, + required=True, + help="Protein FASTA used to resolve full protein sequences." 
+ ) + parser.add_argument( + "--z-file", + action="store", + dest="z_file", + type=Path, + default=None, + help="PV1 z-score TSV (required when --dataset-kind pv1)." + ) + parser.add_argument( + "--reactive-codes", + action="store", + dest="reactive_codes", + type=Path, + default=None, + help="Reactive code-list TSV (required for cwp/bkp)." + ) + parser.add_argument( + "--nonreactive-codes", + action="store", + dest="nonreactive_codes", + type=Path, + default=None, + help="Non-reactive code-list TSV (required for cwp/bkp)." + ) + parser.add_argument( + "--group-id-offset", + action="store", + dest="group_id_offset", + type=int, + default=None, + help=( + "Optional numeric offset applied to group IDs for cwp/bkp mapping. " + "Defaults: cwp=100000000, bkp=200000000, pv1=0." + ) + ) + + # PV1 threshold configuration (mirrors preprocess CLI defaults) + parser.add_argument( + "--is-epi-z-thresh", + action="store", + dest="is_epi_z_min", + type=float, + default=20.0, + help="Minimum z-score required for peptide to contain epitopes (pv1 only)." + ) + parser.add_argument( + "--is-epi-min-subs", + action="store", + dest="is_epi_min_subs", + type=int, + default=4, + help="Minimum # of subjects at/above z threshold for epitope calls (pv1 only)." + ) + parser.add_argument( + "--not-epi-z-thresh", + action="store", + dest="not_epi_z_max", + type=float, + default=10.0, + help="Maximum z-score for non-epitope calls (pv1 only)." + ) + parser.add_argument( + "--not-epi-max-subs", + action="store", + dest="not_epi_max_subs", + type=int, + default=0, + help=( + "Maximum # subjects below non-epitope z threshold. " + "Use 0 for all subjects (pv1 only)." + ) + ) + parser.add_argument( + "--subject-prefix", + action="store", + dest="subject_prefix", + type=str, + default="VW_", + help="Prefix for subject z-score columns (pv1 only)." + ) + + parser.add_argument( + "--log-level", + action="store", + dest="log_level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging level." + ) + parser.add_argument( + "--log-json", + action="store_true", + dest="log_json", + default=False, + help="Emit JSON logs." 
+ ) + args = parser.parse_args() + + logger = setup_logger( + log_dir=None, + log_level=args.log_level, + json_lines=args.log_json, + json_indent=2 if args.log_json else None, + name="prepare_dataset_cli" + ) + + if args.group_id_offset is None: + if args.dataset_kind == "cwp": + group_id_offset = 100_000_000 + elif args.dataset_kind == "bkp": + group_id_offset = 200_000_000 + else: + group_id_offset = 0 + else: + group_id_offset = int(args.group_id_offset) + + not_epitope_max_subjects = args.not_epi_max_subs if int( + args.not_epi_max_subs) != 0 else None + + logger.info("run_start", extra={"extra": { + "dataset_kind": args.dataset_kind, + "meta_file": str(args.meta_file), + "output_dir": str(args.output_dir), + "protein_fasta": str(args.protein_fasta), + "z_file": str(args.z_file) if args.z_file is not None else None, + "reactive_codes": str(args.reactive_codes) if args.reactive_codes is not None else None, + "nonreactive_codes": str(args.nonreactive_codes) if args.nonreactive_codes is not None else None, + "group_id_offset": group_id_offset + }}) + + summary = prepare_dataset( + dataset_kind=args.dataset_kind, + meta_path=args.meta_file, + output_dir=args.output_dir, + protein_fasta=args.protein_fasta, + reactive_codes=args.reactive_codes, + nonreactive_codes=args.nonreactive_codes, + z_path=args.z_file, + is_epitope_z_min=args.is_epi_z_min, + is_epitope_min_subjects=args.is_epi_min_subs, + not_epitope_z_max=args.not_epi_z_max, + not_epitope_max_subjects=not_epitope_max_subjects, + subject_prefix=args.subject_prefix, + group_id_offset=group_id_offset, + logger=logger + ) + + logger.info("run_done", extra={"extra": { + "dataset_kind": args.dataset_kind, + "prepared_targets_fasta": summary.get("prepared_targets_fasta"), + "prepared_labels_metadata_tsv": summary.get("prepared_labels_metadata_tsv"), + "prepared_embedding_metadata_tsv": summary.get("prepared_embedding_metadata_tsv"), + "prepare_summary_json": summary.get("prepare_summary_json"), + "n_targets": summary.get("n_targets"), + "n_label_rows": summary.get("n_label_rows"), + "n_label_proteins": summary.get("n_label_proteins"), + "duration_s": round(time.perf_counter() - t0, 3) + }}) + + +if __name__ == "__main__": + main() diff --git a/src/pepseqpred/core/preprocess/preparedataset.py b/src/pepseqpred/core/preprocess/preparedataset.py new file mode 100644 index 0000000..a87d9ec --- /dev/null +++ b/src/pepseqpred/core/preprocess/preparedataset.py @@ -0,0 +1,749 @@ +"""prepare_dataset.py + +Dataset normalization adapter for multi-source training preparation. + +This module converts PV1/CWP/BKP source inputs into a common PV1-compatible +contract used by existing embedding, label, and training CLIs. 
+""" +import csv +import json +import logging +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Set +import pandas as pd +from pepseqpred.core.io.keys import parse_fullname +from pepseqpred.core.preprocess.pv1 import preprocess as preprocess_pv1 + + +def _build_fullname(protein_id: str, group_numeric: int) -> str: + """Builds a PV1-style fullname for normalized outputs.""" + return f"ID={protein_id} AC={protein_id} OXX=0,0,0,{int(group_numeric)}" + + +def _clean_str(value: Any) -> str: + """Returns stripped string value with null-like tokens collapsed to empty.""" + if value is None: + return "" + if pd.isna(value): + return "" + text = str(value).strip() + if text == "" or text.lower() == "nan": + return "" + return text + + +def _record_drop( + drop_counts: Dict[str, int], + drop_examples: Dict[str, List[str]], + reason: str, + value: str, + max_examples: int = 20 +) -> None: + """Accumulates drop counts and examples by reason.""" + drop_counts[reason] = int(drop_counts.get(reason, 0)) + 1 + if reason not in drop_examples: + drop_examples[reason] = [] + examples = drop_examples[reason] + if len(examples) < max_examples and value not in examples: + examples.append(value) + + +def _read_code_set(tsv_path: Path | str) -> Set[str]: + """Loads code values from first TSV column.""" + tsv_path = Path(tsv_path) + out: set[str] = set() + with tsv_path.open("r", encoding="utf-8", newline="") as in_f: + reader = csv.reader(in_f, delimiter="\t") + for row in reader: + if len(row) == 0: + continue + value = str(row[0]).strip() + if value == "": + continue + low = value.lower() + if low in {"sequence name", "codename", "code_name"}: + continue + out.add(value) + return out + + +def _header_token_to_accession(token: str) -> str: + """Returns accession key from FASTA first token.""" + token = str(token).strip() + if token == "": + return token + if token.startswith(("tr|", "sp|")): + parts = token.split("|") + if len(parts) >= 2 and parts[1].strip() != "": + return parts[1].strip() + return token + + +def _read_fasta_records(fasta_path: Path | str) -> Iterator[Tuple[str, str]]: + """Yields (header_without_>, sequence).""" + header = None + seq_lines: List[str] = [] + with Path(fasta_path).open("r", encoding="utf-8") as fasta_f: + for raw in fasta_f: + line = raw.strip() + if line == "": + continue + if line.startswith(">"): + if header is not None: + yield header, "".join(seq_lines) + header = line[1:].strip() + seq_lines = [] + else: + seq_lines.append(line) + if header is not None: + yield header, "".join(seq_lines) + + +def _write_fasta_records(out_path: Path | str, records: Iterable[Tuple[str, str]]) -> None: + """Writes FASTA records.""" + out_path = Path(out_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", encoding="utf-8") as out_f: + for header, seq in records: + out_f.write(f">{header}\n{seq}\n") + + +def _build_nonpv1_fasta_index(fasta_path: Path | str) -> Tuple[Dict[str, str], Dict[str, List[str]]]: + """ + Builds accession -> sequence mapping from generic FASTA headers. 
+ + Returns + ------- + Tuple[Dict[str, str], Dict[str, List[str]]] + - unique accession->sequence map + - ambiguous accession->list_of_distinct_sequences + """ + seqs_by_accession: Dict[str, List[str]] = {} + for header, seq in _read_fasta_records(fasta_path): + token = header.split()[0] if len(header.split()) > 0 else "" + accession = _header_token_to_accession(token) + if accession == "": + continue + lst = seqs_by_accession.setdefault(accession, []) + if seq not in lst: + lst.append(seq) + + unique: Dict[str, str] = {} + ambiguous: Dict[str, List[str]] = {} + for accession, seqs in seqs_by_accession.items(): + if len(seqs) == 1: + unique[accession] = seqs[0] + elif len(seqs) > 1: + ambiguous[accession] = seqs + return unique, ambiguous + + +def _build_pv1_fasta_index(fasta_path: Path | str) -> Tuple[Dict[str, str], Dict[str, List[str]]]: + """ + Builds PV1 protein_id -> sequence mapping from PV1-style FASTA headers. + """ + seqs_by_id: Dict[str, List[str]] = {} + for header, seq in _read_fasta_records(fasta_path): + parsed = parse_fullname(header) + protein_id = str(parsed[0]) + lst = seqs_by_id.setdefault(protein_id, []) + if seq not in lst: + lst.append(seq) + + unique: Dict[str, str] = {} + ambiguous: Dict[str, List[str]] = {} + for protein_id, seqs in seqs_by_id.items(): + if len(seqs) == 1: + unique[protein_id] = seqs[0] + elif len(seqs) > 1: + ambiguous[protein_id] = seqs + return unique, ambiguous + + +def _resolve_nonpv1_align_columns( + df: pd.DataFrame, + dataset_kind: Literal["cwp", "bkp"] +) -> Tuple[List[str], List[str]]: + """Picks start/stop fallback columns for non-PV1 datasets.""" + if dataset_kind == "cwp": + start_candidates = ["StartIndex", "AlignStart", "Start", "alignStart"] + stop_candidates = ["StopIndex", "AlignStop", "Stop", "alignStop"] + else: + start_candidates = ["alignStart", "StartIndex", "AlignStart", "Start"] + stop_candidates = ["alignStop", "StopIndex", "AlignStop", "Stop"] + + start_cols = [col for col in start_candidates if col in df.columns] + stop_cols = [col for col in stop_candidates if col in df.columns] + if len(start_cols) == 0 or len(stop_cols) == 0: + raise ValueError( + f"Could not resolve alignment columns for dataset_kind='{dataset_kind}'. 
" + f"Found start_cols={start_cols}, stop_cols={stop_cols}, available={list(df.columns)}" + ) + return start_cols, stop_cols + + +def _coalesce_numeric_columns(df: pd.DataFrame, cols: List[str]) -> pd.Series: + """Coalesces numeric values across fallback columns left-to-right.""" + series = pd.to_numeric(df[cols[0]], errors="coerce") + for col in cols[1:]: + series = series.combine_first(pd.to_numeric(df[col], errors="coerce")) + return series + + +def _build_group_numeric_map(tokens: Iterable[str], offset: int) -> Dict[str, int]: + """Builds deterministic token->numeric map using lexical token order.""" + sorted_tokens = sorted({str(tok).strip() + for tok in tokens if str(tok).strip() != ""}) + mapping: Dict[str, int] = {} + for idx, token in enumerate(sorted_tokens, start=1): + mapping[token] = int(offset) + idx + return mapping + + +def _prepare_nonpv1_rows( + dataset_kind: Literal["cwp", "bkp"], + meta_df: pd.DataFrame, + reactive_codes: set[str], + nonreactive_codes: set[str], + protein_seqs_by_accession: Dict[str, str], + ambiguous_accessions: Dict[str, List[str]], + group_col: str, + group_id_offset: int +) -> Tuple[pd.DataFrame, Dict[str, Any]]: + """Normalizes CWP/BKP rows into a shared label metadata contract.""" + if "CodeName" not in meta_df.columns: + raise ValueError("Metadata must include column 'CodeName'") + if "SequenceAccession" not in meta_df.columns: + raise ValueError("Metadata must include column 'SequenceAccession'") + if group_col not in meta_df.columns: + raise ValueError( + f"Metadata must include grouping column '{group_col}'") + + overlap = sorted(list(reactive_codes & nonreactive_codes)) + if len(overlap) > 0: + preview = overlap[:10] + raise ValueError( + f"Reactive/non-reactive code overlap detected ({len(overlap)}), examples={preview}" + ) + + selected_codes = reactive_codes | nonreactive_codes + selected_df = meta_df[meta_df["CodeName"].astype( + str).isin(selected_codes)].copy() + + start_cols, stop_cols = _resolve_nonpv1_align_columns( + selected_df, dataset_kind) + if "PeptideSequence" in selected_df.columns: + peptide_col = "PeptideSequence" + elif "Peptide" in selected_df.columns: + peptide_col = "Peptide" + else: + peptide_col = "" + + selected_df["CodeName"] = selected_df["CodeName"].map(_clean_str) + selected_df["ProteinID"] = selected_df["SequenceAccession"].map(_clean_str) + selected_df["GroupToken"] = selected_df[group_col].map(_clean_str) + selected_df["AlignStart"] = _coalesce_numeric_columns( + selected_df, start_cols).astype("Int64") + selected_df["AlignStop"] = _coalesce_numeric_columns( + selected_df, stop_cols).astype("Int64") + selected_df["LabelSource"] = selected_df["CodeName"].map( + lambda code: "reactive" if code in reactive_codes else ( + "nonreactive" if code in nonreactive_codes else "unknown") + ) + selected_df["Def epitope"] = ( + selected_df["LabelSource"] == "reactive").astype("int8") + selected_df["Uncertain"] = 0 + selected_df["Not epitope"] = ( + selected_df["LabelSource"] == "nonreactive").astype("int8") + + drop_counts: Dict[str, int] = {} + drop_examples: Dict[str, List[str]] = {} + kept_rows: List[Dict[str, Any]] = [] + + for row in selected_df.to_dict(orient="records"): + code_name = _clean_str(row.get("CodeName", "")) + protein_id = _clean_str(row.get("ProteinID", "")) + group_token = _clean_str(row.get("GroupToken", "")) + + if protein_id == "": + _record_drop(drop_counts, drop_examples, + "missing_protein_id", code_name) + continue + if protein_id in ambiguous_accessions: + _record_drop(drop_counts, drop_examples, 
+ "ambiguous_protein_sequence", protein_id) + continue + seq = protein_seqs_by_accession.get(protein_id) + if seq is None: + _record_drop(drop_counts, drop_examples, + "missing_protein_sequence", protein_id) + continue + + if group_token == "": + _record_drop(drop_counts, drop_examples, + "missing_group_token", code_name) + continue + + start_raw = row.get("AlignStart") + stop_raw = row.get("AlignStop") + if pd.isna(start_raw) or pd.isna(stop_raw): + _record_drop(drop_counts, drop_examples, + "missing_align_bounds", code_name) + continue + + try: + start = int(start_raw) + stop = int(stop_raw) + except (TypeError, ValueError): + _record_drop(drop_counts, drop_examples, + "invalid_align_bounds", code_name) + continue + + if start < 0 or stop <= start or stop > len(seq): + _record_drop( + drop_counts, + drop_examples, + "out_of_bounds_align", + f"{code_name}:{start}:{stop}:{len(seq)}", + ) + continue + + if peptide_col != "": + peptide = _clean_str(row.get(peptide_col, "")) + else: + peptide = "" + if peptide == "": + peptide = str(seq[start:stop]) + + if len(peptide) == 0: + _record_drop(drop_counts, drop_examples, + "missing_peptide_sequence", code_name) + continue + + kept_rows.append( + { + "CodeName": code_name, + "ProteinID": protein_id, + "GroupToken": group_token, + "AlignStart": start, + "AlignStop": stop, + "Peptide": peptide, + "Def epitope": int(row["Def epitope"]), + "Uncertain": 0, + "Not epitope": int(row["Not epitope"]) + } + ) + + normalized_df = pd.DataFrame(kept_rows) + if normalized_df.empty: + raise ValueError( + f"No rows left after normalization for dataset_kind='{dataset_kind}'. " + f"drop_counts={drop_counts}" + ) + + group_map = _build_group_numeric_map( + normalized_df["GroupToken"].tolist(), group_id_offset) + normalized_df["GroupID"] = normalized_df["GroupToken"].map(group_map) + missing_group_numeric = normalized_df["GroupID"].isna() + if bool(missing_group_numeric.any()): + bad_codes = normalized_df.loc[missing_group_numeric, "CodeName"].head( + 10).tolist() + raise RuntimeError( + f"Missing GroupID after mapping, examples={bad_codes}") + normalized_df["GroupID"] = normalized_df["GroupID"].astype(int) + + summary = { + "selected_codes": int(len(selected_codes)), + "selected_rows": int(selected_df.shape[0]), + "normalized_rows": int(normalized_df.shape[0]), + "drop_counts": drop_counts, + "drop_examples": drop_examples, + "start_columns_used": start_cols, + "stop_columns_used": stop_cols, + "group_token_column": group_col, + "group_id_offset": int(group_id_offset), + "group_mapping": {k: int(v) for k, v in group_map.items()} + } + return normalized_df, summary + + +def _prepare_pv1_rows( + meta_path: Path | str, + z_path: Path | str, + protein_fasta: Path | str, + is_epitope_z_min: float, + is_epitope_min_subjects: int, + not_epitope_z_max: float, + not_epitope_max_subjects: Optional[int], + subject_prefix: str, + logger: Optional[logging.Logger] +) -> Tuple[pd.DataFrame, Dict[str, str], Dict[str, Any]]: + """Prepares PV1 rows using existing preprocess logic.""" + effective_logger = logger if logger is not None else logging.getLogger( + "prepare_dataset") + pre_df = preprocess_pv1( + meta_path=meta_path, + z_path=z_path, + fname_col="FullName", + code_col="CodeName", + is_epitope_z_min=is_epitope_z_min, + is_epitope_min_subjects=is_epitope_min_subjects, + not_epitope_z_max=not_epitope_z_max, + not_epitope_max_subjects=not_epitope_max_subjects, + prefix=subject_prefix, + save_path=None, + logger=effective_logger + ) + required = [ + "CodeName", + "FullName", + 
"Peptide", + "AlignStart", + "AlignStop", + "Def epitope", + "Uncertain", + "Not epitope" + ] + missing = [col for col in required if col not in pre_df.columns] + if len(missing) > 0: + raise ValueError( + f"PV1 preprocess output missing required columns: {missing}") + + pv1_seqs, ambiguous_pv1 = _build_pv1_fasta_index(protein_fasta) + if len(pv1_seqs) == 0: + raise ValueError("PV1 FASTA index is empty") + + drop_counts: Dict[str, int] = {} + drop_examples: Dict[str, List[str]] = {} + kept_rows: List[Dict[str, Any]] = [] + groups_by_id: Dict[str, str] = {} + + for row in pre_df.to_dict(orient="records"): + code_name = _clean_str(row.get("CodeName", "")) + fullname = _clean_str(row.get("FullName", "")) + try: + protein_id, _ac, _oxx, family = parse_fullname(fullname) + except ValueError: + _record_drop(drop_counts, drop_examples, + "invalid_fullname", code_name) + continue + + protein_id = _clean_str(protein_id) + family = _clean_str(family) + if family == "" or not family.isdigit(): + _record_drop(drop_counts, drop_examples, + "invalid_family", f"{protein_id}:{family}") + continue + + if protein_id in ambiguous_pv1: + _record_drop(drop_counts, drop_examples, + "ambiguous_protein_sequence", protein_id) + continue + seq = pv1_seqs.get(protein_id) + if seq is None: + _record_drop(drop_counts, drop_examples, + "missing_protein_sequence", protein_id) + continue + + start_raw = row.get("AlignStart") + stop_raw = row.get("AlignStop") + if pd.isna(start_raw) or pd.isna(stop_raw): + _record_drop(drop_counts, drop_examples, + "missing_align_bounds", code_name) + continue + try: + start = int(start_raw) + stop = int(stop_raw) + except (TypeError, ValueError): + _record_drop(drop_counts, drop_examples, + "invalid_align_bounds", code_name) + continue + if start < 0 or stop <= start or stop > len(seq): + _record_drop( + drop_counts, + drop_examples, + "out_of_bounds_align", + f"{code_name}:{start}:{stop}:{len(seq)}" + ) + continue + + prev_family = groups_by_id.get(protein_id) + if prev_family is None: + groups_by_id[protein_id] = family + elif prev_family != family: + _record_drop( + drop_counts, + drop_examples, + "conflicting_family_for_protein", + f"{protein_id}:{prev_family}:{family}" + ) + continue + + kept_rows.append( + { + "CodeName": code_name, + "ProteinID": protein_id, + "GroupToken": family, + "AlignStart": start, + "AlignStop": stop, + "Peptide": _clean_str(row.get("Peptide", "")), + "Def epitope": int(row.get("Def epitope", 0)), + "Uncertain": int(row.get("Uncertain", 0)), + "Not epitope": int(row.get("Not epitope", 0)) + } + ) + + normalized_df = pd.DataFrame(kept_rows) + if normalized_df.empty: + raise ValueError( + f"No PV1 rows left after normalization. 
drop_counts={drop_counts}" + ) + + normalized_df["GroupID"] = normalized_df["GroupToken"].astype(int) + summary = { + "selected_rows": int(pre_df.shape[0]), + "normalized_rows": int(normalized_df.shape[0]), + "drop_counts": drop_counts, + "drop_examples": drop_examples + } + return normalized_df, pv1_seqs, summary + + +def _finalize_and_write_outputs( + normalized_df: pd.DataFrame, + protein_seqs: Dict[str, str], + out_dir: Path | str +) -> Dict[str, Any]: + """Writes prepared targets, label metadata, and embedding metadata outputs.""" + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # filter rows to proteins that still have sequence records + normalized_df = normalized_df[normalized_df["ProteinID"].astype( + str).isin(set(protein_seqs.keys()))].copy() + if normalized_df.empty: + raise ValueError( + "No rows left after intersecting normalized rows with protein FASTA IDs") + + # ensure one group id per protein + id_group_counts = ( + normalized_df.groupby("ProteinID")["GroupID"] + .nunique(dropna=False) + .reset_index(name="n") + ) + conflicting = id_group_counts[id_group_counts["n"] > 1] + if not conflicting.empty: + examples = conflicting["ProteinID"].head(10).tolist() + raise ValueError( + "Found proteins assigned to multiple groups after normalization, " + f"examples={examples}" + ) + + group_by_id = ( + normalized_df.groupby("ProteinID")["GroupID"] + .first() + .to_dict() + ) + fullname_by_id = { + str(protein_id): _build_fullname(str(protein_id), int(group_id)) + for protein_id, group_id in group_by_id.items() + } + + normalized_df["FullName"] = normalized_df["ProteinID"].map(fullname_by_id) + label_cols = [ + "CodeName", + "FullName", + "Peptide", + "AlignStart", + "AlignStop", + "Def epitope", + "Uncertain", + "Not epitope", + "ProteinID", + "GroupToken", + "GroupID" + ] + label_df = normalized_df[label_cols].copy() + label_df = label_df.sort_values( + ["ProteinID", "AlignStart", "CodeName"]).reset_index(drop=True) + + emb_meta_df = ( + label_df[["ProteinID", "FullName", "GroupID"]] + .drop_duplicates() + .rename(columns={"FullName": "Name", "GroupID": "Family"}) + ) + emb_meta_df["Family"] = emb_meta_df["Family"].astype(int) + emb_meta_df = emb_meta_df.sort_values(["ProteinID"]).reset_index(drop=True) + + target_records = [] + for protein_id in sorted(group_by_id.keys()): + seq = protein_seqs.get(protein_id) + if seq is None: + continue + target_records.append((fullname_by_id[protein_id], seq)) + if len(target_records) == 0: + raise ValueError( + "No target FASTA records could be written after normalization") + + targets_path = out_dir / "prepared_targets.fasta" + labels_meta_path = out_dir / "prepared_labels_metadata.tsv" + emb_meta_path = out_dir / "prepared_embedding_metadata.tsv" + _write_fasta_records(targets_path, target_records) + label_df.to_csv(labels_meta_path, sep="\t", index=False) + emb_meta_df[["Name", "Family"]].to_csv( + emb_meta_path, sep="\t", index=False) + + return { + "prepared_targets_fasta": str(targets_path), + "prepared_labels_metadata_tsv": str(labels_meta_path), + "prepared_embedding_metadata_tsv": str(emb_meta_path), + "n_targets": int(len(target_records)), + "n_label_rows": int(label_df.shape[0]), + "n_label_proteins": int(label_df["ProteinID"].nunique()) + } + + +def prepare_dataset( + dataset_kind: Literal["pv1", "cwp", "bkp"], + meta_path: Path | str, + output_dir: Path | str, + protein_fasta: Path | str, + reactive_codes: Optional[Path | str] = None, + nonreactive_codes: Optional[Path | str] = None, + z_path: Optional[Path 
| str] = None, + is_epitope_z_min: float = 20.0, + is_epitope_min_subjects: int = 4, + not_epitope_z_max: float = 10.0, + not_epitope_max_subjects: Optional[int] = None, + subject_prefix: str = "VW_", + group_id_offset: int = 0, + logger: Optional[logging.Logger] = None +) -> Dict[str, Any]: + """ + Converts dataset-specific sources into a shared PV1-compatible contract. + + Outputs under `output_dir`: + - prepared_targets.fasta + - prepared_labels_metadata.tsv + - prepared_embedding_metadata.tsv + - prepare_summary.json + """ + dataset_kind = str(dataset_kind).strip().lower() + if dataset_kind not in {"pv1", "cwp", "bkp"}: + raise ValueError( + f"Unsupported dataset_kind='{dataset_kind}'. Expected one of: pv1,cwp,bkp" + ) + + meta_path = Path(meta_path) + output_dir = Path(output_dir) + protein_fasta = Path(protein_fasta) + if not meta_path.exists(): + raise FileNotFoundError(f"Metadata file not found: {meta_path}") + if not protein_fasta.exists(): + raise FileNotFoundError(f"Protein FASTA not found: {protein_fasta}") + + summary: Dict[str, Any] = { + "dataset_kind": dataset_kind, + "meta_path": str(meta_path), + "protein_fasta": str(protein_fasta) + } + + if dataset_kind == "pv1": + if z_path is None: + raise ValueError("--z-file is required when --dataset-kind pv1") + z_path = Path(z_path) + if not z_path.exists(): + raise FileNotFoundError(f"Z-score file not found: {z_path}") + summary["z_path"] = str(z_path) + + normalized_df, pv1_seqs, prep_summary = _prepare_pv1_rows( + meta_path=meta_path, + z_path=z_path, + protein_fasta=protein_fasta, + is_epitope_z_min=is_epitope_z_min, + is_epitope_min_subjects=is_epitope_min_subjects, + not_epitope_z_max=not_epitope_z_max, + not_epitope_max_subjects=not_epitope_max_subjects, + subject_prefix=subject_prefix, + logger=logger + ) + outputs = _finalize_and_write_outputs( + normalized_df=normalized_df, + protein_seqs=pv1_seqs, + out_dir=output_dir + ) + summary["normalization"] = prep_summary + else: + if reactive_codes is None or nonreactive_codes is None: + raise ValueError( + "--reactive-codes and --nonreactive-codes are required " + "when --dataset-kind is cwp or bkp" + ) + reactive_codes = Path(reactive_codes) + nonreactive_codes = Path(nonreactive_codes) + if not reactive_codes.exists(): + raise FileNotFoundError( + f"Reactive code file not found: {reactive_codes}") + if not nonreactive_codes.exists(): + raise FileNotFoundError( + f"Non-reactive code file not found: {nonreactive_codes}") + summary["reactive_codes"] = str(reactive_codes) + summary["nonreactive_codes"] = str(nonreactive_codes) + + meta_df = pd.read_csv(meta_path, sep="\t", dtype=str) + reactive_set = _read_code_set(reactive_codes) + nonreactive_set = _read_code_set(nonreactive_codes) + + unique_seqs, ambiguous = _build_nonpv1_fasta_index(protein_fasta) + if len(unique_seqs) == 0: + raise ValueError( + f"No protein sequences resolved from FASTA: {protein_fasta}") + + if dataset_kind == "cwp": + group_col = "Cluster50ID" + else: + group_col = "reClusterID_70" + + normalized_df, prep_summary = _prepare_nonpv1_rows( + dataset_kind=dataset_kind, + meta_df=meta_df, + reactive_codes=reactive_set, + nonreactive_codes=nonreactive_set, + protein_seqs_by_accession=unique_seqs, + ambiguous_accessions=ambiguous, + group_col=group_col, + group_id_offset=int(group_id_offset) + ) + outputs = _finalize_and_write_outputs( + normalized_df=normalized_df, + protein_seqs=unique_seqs, + out_dir=output_dir + ) + summary["normalization"] = prep_summary + summary["n_ambiguous_accessions"] = 
int(len(ambiguous)) + if len(ambiguous) > 0: + summary["ambiguous_accession_examples"] = sorted( + list(ambiguous.keys()))[:20] + + summary.update(outputs) + summary_path = output_dir / "prepare_summary.json" + summary_path.write_text( + json.dumps(summary, indent=2, ensure_ascii=False), + encoding="utf-8" + ) + summary["prepare_summary_json"] = str(summary_path) + + if logger is not None: + logger.info( + "prepare_dataset_done", + extra={"extra": { + "dataset_kind": dataset_kind, + "output_dir": str(output_dir), + "n_targets": int(summary["n_targets"]), + "n_label_rows": int(summary["n_label_rows"]), + "n_label_proteins": int(summary["n_label_proteins"]), + "summary_path": str(summary_path) + }} + ) + + return summary diff --git a/src/pepseqpred/core/train/ddp.py b/src/pepseqpred/core/train/ddp.py index 0c0664f..491c6bc 100644 --- a/src/pepseqpred/core/train/ddp.py +++ b/src/pepseqpred/core/train/ddp.py @@ -104,6 +104,10 @@ def ddp_gather_all_1d(t: torch.Tensor, device: torch.device) -> Tuple[List[torch """ if not _ddp_enabled(): return [t], [int(t.numel())] + if t.dim() != 1: + raise ValueError( + f"ddp_gather_all_1d expects a 1D tensor, got shape={tuple(t.shape)}" + ) sizes = torch.tensor([t.numel()], device=device, dtype=torch.long) size_list = [torch.zeros_like(sizes) for _ in range(_ddp_world())] @@ -111,6 +115,25 @@ def ddp_gather_all_1d(t: torch.Tensor, device: torch.device) -> Tuple[List[torch sizes_int = [int(s.item()) for s in size_list] max_size = max(sizes_int) if sizes_int else int(t.numel()) + max_allowed_raw = os.environ.get( + "PEPSEQPRED_DDP_MAX_GATHER_ELEMS", "100000000") + try: + max_allowed = max(1, int(max_allowed_raw)) + except ValueError: + max_allowed = 100_000_000 + + bad_sizes = [ + {"rank_index": idx, "size": int(size)} + for idx, size in enumerate(sizes_int) + if (int(size) < 0) or (int(size) > max_allowed) + ] + if bad_sizes: + raise RuntimeError( + "Invalid gathered tensor sizes detected in ddp_gather_all_1d; " + "possible DDP desync or failed collective. " + f"sizes={bad_sizes} max_allowed={max_allowed}" + ) + padded = torch.zeros(max_size, device=device, dtype=t.dtype) if t.numel() > 0: padded[:t.numel()] = t diff --git a/src/pepseqpred/core/train/trainer.py b/src/pepseqpred/core/train/trainer.py index fabf369..4f317f9 100644 --- a/src/pepseqpred/core/train/trainer.py +++ b/src/pepseqpred/core/train/trainer.py @@ -13,6 +13,7 @@ from typing import Optional, List, Dict, Any, Tuple import torch import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as TorchDDP from torch.utils.data import DataLoader import numpy as np import optuna @@ -124,8 +125,12 @@ def _batch_step(self, batch: torch.Tensor, train: bool = True) -> Dict[str, Any] raise ValueError( f"Expected y_onehot shape (B, L), got {tuple(y.shape)}") - # get logits to calculate loss and validate shape - logits = self.model(X) + # In eval with DDP, call the wrapped module directly to avoid + # DDP forward-time collectives on uneven per-rank validation inputs. 
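+        # Note: no backward pass runs when train=False, so gradient synchronization is
+        # unaffected here; only the wrapper's forward-time work is bypassed.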
+ if (not train) and isinstance(self.model, TorchDDP): + logits = self.model.module(X) + else: + logits = self.model(X) if logits.shape != y.shape: raise ValueError( f"Expected logits shape {tuple(y.shape)}, got {tuple(logits.shape)}") diff --git a/tests/integration/test_prepare_dataset_multisource_pipeline.py b/tests/integration/test_prepare_dataset_multisource_pipeline.py new file mode 100644 index 0000000..4aaf848 --- /dev/null +++ b/tests/integration/test_prepare_dataset_multisource_pipeline.py @@ -0,0 +1,365 @@ +import sys +import types +from pathlib import Path + +import pandas as pd +import pytest +import torch + +import pepseqpred.apps.esm_cli as esm_cli +import pepseqpred.apps.labels_cli as labels_cli +import pepseqpred.apps.train_ffnn_cli as train_cli +from pepseqpred.core.io.keys import parse_fullname +from pepseqpred.core.preprocess.preparedataset import prepare_dataset +from pepseqpred.core.train.split import split_ids_grouped + +pytestmark = [pytest.mark.integration, pytest.mark.slow] + + +class FakeAlphabet: + def get_batch_converter(self): + def _batch_converter(pairs): + labels = [name for name, _seq in pairs] + seqs = [seq for _name, seq in pairs] + max_len = max((len(seq) for seq in seqs), default=0) + tokens = torch.zeros((len(seqs), max_len + 2), dtype=torch.long) + for i, seq in enumerate(seqs): + seq_len = len(seq) + tokens[i, 1:1 + seq_len] = 1 + tokens[i, 1 + seq_len] = 2 + return labels, seqs, tokens + + return _batch_converter + + +class FakeESMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, batch_tokens, repr_layers, return_contacts=False): + _ = return_contacts + batch_size, token_len = batch_tokens.shape + rep_dim = 3 # append_seq_len => final emb dim=4 + reps = torch.ones((batch_size, token_len, rep_dim), + dtype=torch.float32) + return {"representations": {repr_layers[0]: reps}} + + +def _write_code_list(path: Path, codes: list[str]) -> None: + path.write_text("Sequence name\n" + "\n".join(codes) + + "\n", encoding="utf-8") + + +def _append_fasta_records(path: Path, records: list[tuple[str, str]]) -> None: + with path.open("a", encoding="utf-8") as out_f: + for header, seq in records: + out_f.write(f">{header}\n{seq}\n") + + +def _build_pv1_inputs(root: Path) -> tuple[Path, Path, Path]: + root.mkdir(parents=True, exist_ok=True) + meta = root / "pv1_meta.tsv" + z = root / "pv1_z.tsv" + fasta = root / "pv1_targets.fasta" + + pd.DataFrame( + [ + { + "CodeName": "pv1_pep_1", + "Category": "SetCover", + "SpeciesID": "1", + "Species": "PV1", + "Protein": "Prot", + "FullName": "ID=PV1P001 AC=A1 OXX=11,22,301_0_4", + "Peptide": "MNPQ", + "Encoding": "enc", + }, + { + "CodeName": "pv1_pep_2", + "Category": "SetCover", + "SpeciesID": "1", + "Species": "PV1", + "Protein": "Prot", + "FullName": "ID=PV1P001 AC=A1 OXX=11,22,301_2_6", + "Peptide": "PQRS", + "Encoding": "enc", + }, + ] + ).to_csv(meta, sep="\t", index=False) + pd.DataFrame( + [ + {"Sequence name": "pv1_pep_1", "VW_001": 30.0, "VW_002": 0.0}, + {"Sequence name": "pv1_pep_2", "VW_001": 1.0, "VW_002": 2.0}, + ] + ).to_csv(z, sep="\t", index=False) + fasta.write_text( + ">ID=PV1P001 AC=A1 OXX=11,22,301\nMNPQRS\n", + encoding="utf-8", + ) + return meta, z, fasta + + +def _build_cwp_inputs(root: Path) -> tuple[Path, Path, Path, Path]: + root.mkdir(parents=True, exist_ok=True) + meta = root / "cwp_meta.tsv" + reactive = root / "cwp_reactive.tsv" + nonreactive = root / "cwp_nonreactive.tsv" + fasta = root / "cwp_targets.faa" + + 
pd.DataFrame( + [ + { + "CodeName": "CWP_000001", + "SequenceAccession": "A0CWP1", + "Cluster50ID": "Cocci_id50_010", + "StartIndex": 0, + "StopIndex": 4, + "PeptideSequence": "ACDE", + }, + { + "CodeName": "CWP_000002", + "SequenceAccession": "A0CWP1", + "Cluster50ID": "Cocci_id50_010", + "StartIndex": 1, + "StopIndex": 5, + "PeptideSequence": "CDEF", + }, + ] + ).to_csv(meta, sep="\t", index=False) + _write_code_list(reactive, ["CWP_000001"]) + _write_code_list(nonreactive, ["CWP_000002"]) + fasta.write_text(">tr|A0CWP1|A0CWP1_FAKE\nACDEFG\n", encoding="utf-8") + return meta, reactive, nonreactive, fasta + + +def _build_bkp_inputs(root: Path) -> tuple[Path, Path, Path, Path]: + root.mkdir(parents=True, exist_ok=True) + meta = root / "bkp_meta.tsv" + reactive = root / "bkp_reactive.tsv" + nonreactive = root / "bkp_nonreactive.tsv" + fasta = root / "bkp_targets.faa" + + pd.DataFrame( + [ + { + "CodeName": "BKP_000001", + "SequenceAccession": "A0BKP1", + "reClusterID_70": "BKP1_id70_200", + "alignStart": "0.0", + "alignStop": "4.0", + "PeptideSequence": "WXYZ", + }, + { + "CodeName": "BKP_000002", + "SequenceAccession": "A0BKP1", + "reClusterID_70": "BKP1_id70_200", + "alignStart": "1.0", + "alignStop": "5.0", + "PeptideSequence": "XYZA", + }, + ] + ).to_csv(meta, sep="\t", index=False) + _write_code_list(reactive, ["BKP_000001"]) + _write_code_list(nonreactive, ["BKP_000002"]) + fasta.write_text(">tr|A0BKP1|A0BKP1_FAKE\nWXYZAB\n", encoding="utf-8") + return meta, reactive, nonreactive, fasta + + +def test_prepare_dataset_multisource_pipeline_smoke(monkeypatch, tmp_path: Path): + # Build three mini datasets. + pv1_meta, pv1_z, pv1_fasta = _build_pv1_inputs(tmp_path / "pv1") + cwp_meta, cwp_reactive, cwp_nonreactive, cwp_fasta = _build_cwp_inputs( + tmp_path / "cwp") + bkp_meta, bkp_reactive, bkp_nonreactive, bkp_fasta = _build_bkp_inputs( + tmp_path / "bkp") + + out_pv1 = tmp_path / "out_pv1" + out_cwp = tmp_path / "out_cwp" + out_bkp = tmp_path / "out_bkp" + + prepare_dataset( + dataset_kind="pv1", + meta_path=pv1_meta, + z_path=pv1_z, + output_dir=out_pv1, + protein_fasta=pv1_fasta, + is_epitope_min_subjects=1, + ) + prepare_dataset( + dataset_kind="cwp", + meta_path=cwp_meta, + output_dir=out_cwp, + protein_fasta=cwp_fasta, + reactive_codes=cwp_reactive, + nonreactive_codes=cwp_nonreactive, + group_id_offset=100_000_000, + ) + prepare_dataset( + dataset_kind="bkp", + meta_path=bkp_meta, + output_dir=out_bkp, + protein_fasta=bkp_fasta, + reactive_codes=bkp_reactive, + nonreactive_codes=bkp_nonreactive, + group_id_offset=200_000_000, + ) + + # Combine prepared artifacts. 
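+    # Merge contract from the docs: concatenate prepared_targets.fasta records,
+    # row-concatenate prepared_labels_metadata.tsv, row-concatenate
+    # prepared_embedding_metadata.tsv, then de-duplicate on ["Name", "Family"].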
+ combined_dir = tmp_path / "combined" + combined_dir.mkdir(parents=True, exist_ok=True) + combined_fasta = combined_dir / "prepared_targets.fasta" + combined_meta = combined_dir / "prepared_labels_metadata.tsv" + combined_emb_meta = combined_dir / "prepared_embedding_metadata.tsv" + combined_fasta.write_text("", encoding="utf-8") + + for source in [out_pv1, out_cwp, out_bkp]: + recs = [] + header = None + seq_lines = [] + for raw in (source / "prepared_targets.fasta").read_text(encoding="utf-8").splitlines(): + line = raw.strip() + if line == "": + continue + if line.startswith(">"): + if header is not None: + recs.append((header, "".join(seq_lines))) + header = line[1:].strip() + seq_lines = [] + else: + seq_lines.append(line) + if header is not None: + recs.append((header, "".join(seq_lines))) + _append_fasta_records(combined_fasta, recs) + + labels_df = pd.concat( + [ + pd.read_csv(out_pv1 / "prepared_labels_metadata.tsv", sep="\t"), + pd.read_csv(out_cwp / "prepared_labels_metadata.tsv", sep="\t"), + pd.read_csv(out_bkp / "prepared_labels_metadata.tsv", sep="\t"), + ], + ignore_index=True, + ) + labels_df.to_csv(combined_meta, sep="\t", index=False) + + emb_meta_df = pd.concat( + [ + pd.read_csv(out_pv1 / "prepared_embedding_metadata.tsv", sep="\t"), + pd.read_csv(out_cwp / "prepared_embedding_metadata.tsv", sep="\t"), + pd.read_csv(out_bkp / "prepared_embedding_metadata.tsv", sep="\t"), + ], + ignore_index=True, + ).drop_duplicates(subset=["Name", "Family"]) + emb_meta_df.to_csv(combined_emb_meta, sep="\t", index=False) + + # Assert grouped split behavior: no family overlap between train/val IDs. + id_to_family = { + parse_fullname(str(name))[0]: str(int(family)) + for name, family in emb_meta_df[["Name", "Family"]].itertuples(index=False, name=None) + } + all_ids = sorted(id_to_family.keys()) + train_ids, val_ids = split_ids_grouped( + all_ids, + val_frac=0.34, + seed=11, + groups=id_to_family, + ) + train_fams = {id_to_family[pid] for pid in train_ids} + val_fams = {id_to_family[pid] for pid in val_ids} + assert train_fams.isdisjoint(val_fams) + + # Run ESM CLI with fake model. + fake_pretrained = types.SimpleNamespace( + fake_model=lambda: (FakeESMModel(), FakeAlphabet()) + ) + monkeypatch.setattr(esm_cli.esm, "pretrained", fake_pretrained) + monkeypatch.setattr(esm_cli.torch.cuda, "is_available", lambda: False) + monkeypatch.setattr(esm_cli.torch.cuda, "device_count", lambda: 0) + + embs_out = tmp_path / "esm_out" + monkeypatch.setattr( + sys, + "argv", + [ + "esm_cli.py", + "--fasta-file", + str(combined_fasta), + "--metadata-file", + str(combined_emb_meta), + "--out-dir", + str(embs_out), + "--embedding-key-mode", + "id-family", + "--model-name", + "fake_model", + "--max-tokens", + "16", + "--batch-size", + "4", + ], + ) + esm_cli.main() + + emb_dir = embs_out / "artifacts" / "pts" + assert emb_dir.exists() + + # Run labels CLI against prepared metadata and generated embeddings. + labels_pt = tmp_path / "labels.pt" + monkeypatch.setattr( + sys, + "argv", + [ + "labels_cli.py", + str(combined_meta), + str(labels_pt), + "--emb-dir", + str(emb_dir), + "--embedding-key-delim", + "-", + "--calc-pos-weight", + ], + ) + labels_cli.main() + assert labels_pt.exists() + + # Train smoke run using grouped split (id-family) over all three datasets. 
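+    # One combined embedding dir + one combined label shard; --split-type id-family
+    # ensures no family appears in both train and validation (cf. the grouped-split
+    # assertion above).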
+    save_dir = tmp_path / "train_out"
+    monkeypatch.setattr(
+        sys,
+        "argv",
+        [
+            "train_ffnn_cli.py",
+            "--embedding-dirs",
+            str(emb_dir),
+            "--label-shards",
+            str(labels_pt),
+            "--epochs",
+            "1",
+            "--batch-size",
+            "2",
+            "--num-workers",
+            "0",
+            "--hidden-sizes",
+            "8",
+            "--dropouts",
+            "0.1",
+            "--val-frac",
+            "0.34",
+            "--split-type",
+            "id-family",
+            "--split-seeds",
+            "11",
+            "--train-seeds",
+            "101",
+            "--save-path",
+            str(save_dir),
+            "--results-csv",
+            str(save_dir / "runs.csv"),
+        ],
+    )
+    train_cli.main()
+
+    assert (save_dir / "runs.csv").exists()
+    run_dirs = list(save_dir.glob("run_*"))
+    assert run_dirs
+    assert (run_dirs[0] / "fully_connected.pt").exists()
diff --git a/tests/unit/apps/test_cli_wrappers.py b/tests/unit/apps/test_cli_wrappers.py
index ebef2d2..9dbf627 100644
--- a/tests/unit/apps/test_cli_wrappers.py
+++ b/tests/unit/apps/test_cli_wrappers.py
@@ -5,6 +5,7 @@ import pytest
 
 import pepseqpred.apps.esm_cli as esm_cli
 import pepseqpred.apps.labels_cli as labels_cli
+import pepseqpred.apps.prepare_dataset_cli as prepare_dataset_cli
 import pepseqpred.apps.preprocess_cli as preprocess_cli
 
 pytestmark = pytest.mark.unit
@@ -167,3 +168,56 @@ def test_esm_cli_id_family_requires_metadata(monkeypatch, tmp_path: Path):
 
     with pytest.raises(ValueError, match="Metadata file is required"):
         esm_cli.main()
+
+
+def test_prepare_dataset_cli_invokes_adapter(monkeypatch, tmp_path: Path):
+    captured = {}
+
+    def fake_prepare_dataset(**kwargs):
+        captured["kwargs"] = kwargs
+        return {
+            "prepared_targets_fasta": str(tmp_path / "prepared_targets.fasta"),
+            "prepared_labels_metadata_tsv": str(tmp_path / "prepared_labels_metadata.tsv"),
+            "prepared_embedding_metadata_tsv": str(tmp_path / "prepared_embedding_metadata.tsv"),
+            "prepare_summary_json": str(tmp_path / "prepare_summary.json"),
+            "n_targets": 2,
+            "n_label_rows": 4,
+            "n_label_proteins": 2,
+        }
+
+    ns = argparse.Namespace(
+        meta_file=tmp_path / "meta.tsv",
+        output_dir=tmp_path / "out",
+        dataset_kind="cwp",
+        protein_fasta=tmp_path / "proteins.faa",
+        z_file=None,
+        reactive_codes=tmp_path / "reactive.tsv",
+        nonreactive_codes=tmp_path / "nonreactive.tsv",
+        group_id_offset=None,
+        is_epi_z_min=20.0,
+        is_epi_min_subs=4,
+        not_epi_z_max=10.0,
+        not_epi_max_subs=0,
+        subject_prefix="VW_",
+        log_level="INFO",
+        log_json=False,
+    )
+
+    monkeypatch.setattr(
+        prepare_dataset_cli.argparse.ArgumentParser, "parse_args", lambda self: ns
+    )
+    monkeypatch.setattr(
+        prepare_dataset_cli,
+        "setup_logger",
+        lambda **kwargs: logging.getLogger("prepare_dataset_cli_test")
+    )
+    monkeypatch.setattr(
+        prepare_dataset_cli,
+        "prepare_dataset",
+        fake_prepare_dataset
+    )
+
+    prepare_dataset_cli.main()
+
+    assert captured["kwargs"]["dataset_kind"] == "cwp"
+    assert int(captured["kwargs"]["group_id_offset"]) == 100_000_000
diff --git a/tests/unit/core/preprocess/test_prepare_dataset.py b/tests/unit/core/preprocess/test_prepare_dataset.py
new file mode 100644
index 0000000..3ef8d4f
--- /dev/null
+++ b/tests/unit/core/preprocess/test_prepare_dataset.py
@@ -0,0 +1,223 @@
+import json
+from pathlib import Path
+
+import pandas as pd
+import pytest
+
+from pepseqpred.core.preprocess.preparedataset import (
+    _build_group_numeric_map,
+    prepare_dataset,
+)
+
+pytestmark = pytest.mark.unit
+
+
+def _write_tsv(path: Path, rows: list[dict]) -> None:
+    pd.DataFrame(rows).to_csv(path, sep="\t", index=False)
+
+
+def _write_code_list(path: Path, codes: list[str]) -> None:
+    lines = ["Sequence name", *codes]
+    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+
+
+def test_group_mapping_is_deterministic_and_disjoint_ranges():
+    cwp_tokens = ["Cocci_id50_010", "Cocci_id50_001", "Cocci_id50_010"]
+    bkp_tokens = ["BKP1_id70_200", "BKP1_id70_050"]
+
+    cwp_map = _build_group_numeric_map(cwp_tokens, offset=100_000_000)
+    bkp_map = _build_group_numeric_map(bkp_tokens, offset=200_000_000)
+
+    assert cwp_map == {
+        "Cocci_id50_001": 100000001,
+        "Cocci_id50_010": 100000002,
+    }
+    assert bkp_map == {
+        "BKP1_id70_050": 200000001,
+        "BKP1_id70_200": 200000002,
+    }
+    assert max(cwp_map.values()) < min(bkp_map.values())
+
+
+def test_prepare_cwp_outputs_and_drop_report(tmp_path: Path):
+    meta_path = tmp_path / "cwp_meta.tsv"
+    protein_fasta = tmp_path / "proteins.faa"
+    reactive = tmp_path / "reactive.tsv"
+    nonreactive = tmp_path / "nonreactive.tsv"
+    out_dir = tmp_path / "prepared"
+
+    _write_tsv(
+        meta_path,
+        [
+            {
+                "CodeName": "CWP_000001",
+                "SequenceAccession": "A0A111",
+                "Cluster50ID": "Cocci_id50_010",
+                "StartIndex": 0,
+                "StopIndex": 4,
+                "PeptideSequence": "MNPQ",
+            },
+            {
+                "CodeName": "CWP_000002",
+                "SequenceAccession": "WP_200.1",
+                "Cluster50ID": "Cocci_id50_001",
+                "StartIndex": 1,
+                "StopIndex": 5,
+                "PeptideSequence": "QWER",
+            },
+            {
+                "CodeName": "CWP_000003",
+                "SequenceAccession": "MISSING_ACC",
+                "Cluster50ID": "Cocci_id50_002",
+                "StartIndex": 0,
+                "StopIndex": 3,
+                "PeptideSequence": "AAA",
+            },
+        ],
+    )
+    protein_fasta.write_text(
+        ">tr|A0A111|A0A111_FAKE desc\nMNPQRS\n"
+        ">WP_200.1 hypothetical protein\nAQWERT\n",
+        encoding="utf-8",
+    )
+    _write_code_list(reactive, ["CWP_000001"])
+    _write_code_list(nonreactive, ["CWP_000002", "CWP_000003"])
+
+    summary = prepare_dataset(
+        dataset_kind="cwp",
+        meta_path=meta_path,
+        output_dir=out_dir,
+        protein_fasta=protein_fasta,
+        reactive_codes=reactive,
+        nonreactive_codes=nonreactive,
+        group_id_offset=1000,
+    )
+
+    assert (out_dir / "prepared_targets.fasta").exists()
+    assert (out_dir / "prepared_labels_metadata.tsv").exists()
+    assert (out_dir / "prepared_embedding_metadata.tsv").exists()
+    assert (out_dir / "prepare_summary.json").exists()
+    assert int(summary["n_label_rows"]) == 2
+    assert int(summary["n_label_proteins"]) == 2
+
+    labels_df = pd.read_csv(out_dir / "prepared_labels_metadata.tsv", sep="\t")
+    emb_df = pd.read_csv(out_dir / "prepared_embedding_metadata.tsv", sep="\t")
+    assert labels_df["CodeName"].tolist() == ["CWP_000001", "CWP_000002"]
+
+    fam_by_name = dict(zip(emb_df["Name"], emb_df["Family"]))
+    name_c1 = labels_df.loc[labels_df["CodeName"]
+                            == "CWP_000001", "FullName"].iloc[0]
+    name_c2 = labels_df.loc[labels_df["CodeName"]
+                            == "CWP_000002", "FullName"].iloc[0]
+    assert int(fam_by_name[name_c2]) == 1001
+    assert int(fam_by_name[name_c1]) == 1002
+
+    payload = json.loads(
+        (out_dir / "prepare_summary.json").read_text(encoding="utf-8"))
+    assert int(payload["normalization"]["drop_counts"]
+               ["missing_protein_sequence"]) == 1
+
+
+def test_prepare_bkp_derives_peptide_and_align_fallback(tmp_path: Path):
+    meta_path = tmp_path / "bkp_meta.tsv"
+    protein_fasta = tmp_path / "proteins.faa"
+    reactive = tmp_path / "reactive.tsv"
+    nonreactive = tmp_path / "nonreactive.tsv"
+    out_dir = tmp_path / "prepared"
+
+    _write_tsv(
+        meta_path,
+        [
+            {
+                "CodeName": "BKP_000001",
+                "SequenceAccession": "A0B123",
+                "reClusterID_70": "BKP1_id70_200",
+                "alignStart": "0.0",
+                "alignStop": "4.0",
+                "PeptideSequence": "",
+            }
+        ],
+    )
+    protein_fasta.write_text(
+        ">tr|A0B123|A0B123_FAKE desc\nMNPQRS\n",
+        encoding="utf-8",
+    )
+    _write_code_list(reactive, ["BKP_000001"])
+    _write_code_list(nonreactive, [])
+
+    summary = prepare_dataset(
+        dataset_kind="bkp",
+        meta_path=meta_path,
+        output_dir=out_dir,
+        protein_fasta=protein_fasta,
+        reactive_codes=reactive,
+        nonreactive_codes=nonreactive,
+        group_id_offset=2000,
+    )
+
+    assert int(summary["n_label_rows"]) == 1
+    labels_df = pd.read_csv(out_dir / "prepared_labels_metadata.tsv", sep="\t")
+    emb_df = pd.read_csv(out_dir / "prepared_embedding_metadata.tsv", sep="\t")
+
+    assert labels_df["Peptide"].iloc[0] == "MNPQ"
+    assert int(labels_df["AlignStart"].iloc[0]) == 0
+    assert int(labels_df["AlignStop"].iloc[0]) == 4
+    assert int(emb_df["Family"].iloc[0]) == 2001
+
+
+def test_prepare_pv1_reuses_preprocess_and_family_from_fullname(tmp_path: Path):
+    pv1_meta = tmp_path / "pv1_meta.tsv"
+    pv1_z = tmp_path / "pv1_z.tsv"
+    pv1_fasta = tmp_path / "pv1_targets.fasta"
+    out_dir = tmp_path / "prepared"
+
+    _write_tsv(
+        pv1_meta,
+        [
+            {
+                "CodeName": "pep1",
+                "Category": "SetCover",
+                "SpeciesID": "1",
+                "Species": "X",
+                "Protein": "Y",
+                "FullName": "ID=P001 AC=A1 OXX=11,22,33_0_4",
+                "Peptide": "MNPQ",
+                "Encoding": "enc",
+            },
+            {
+                "CodeName": "pep2",
+                "Category": "SetCover",
+                "SpeciesID": "1",
+                "Species": "X",
+                "Protein": "Y",
+                "FullName": "ID=P001 AC=A1 OXX=11,22,33_2_6",
+                "Peptide": "PQRS",
+                "Encoding": "enc",
+            },
+        ],
+    )
+    _write_tsv(
+        pv1_z,
+        [
+            {"Sequence name": "pep1", "VW_001": 30.0, "VW_002": 0.0},
+            {"Sequence name": "pep2", "VW_001": 1.0, "VW_002": 2.0},
+        ],
+    )
+    pv1_fasta.write_text(
+        ">ID=P001 AC=A1 OXX=11,22,33\nMNPQRS\n",
+        encoding="utf-8",
+    )
+
+    summary = prepare_dataset(
+        dataset_kind="pv1",
+        meta_path=pv1_meta,
+        output_dir=out_dir,
+        protein_fasta=pv1_fasta,
+        z_path=pv1_z,
+        is_epitope_min_subjects=1,
+    )
+
+    assert int(summary["n_label_rows"]) == 2
+    assert int(summary["n_label_proteins"]) == 1
+    emb_df = pd.read_csv(out_dir / "prepared_embedding_metadata.tsv", sep="\t")
+    assert int(emb_df["Family"].iloc[0]) == 33
diff --git a/tests/unit/core/train/test_support_modules.py b/tests/unit/core/train/test_support_modules.py
index 9d3cf4c..81433d9 100644
--- a/tests/unit/core/train/test_support_modules.py
+++ b/tests/unit/core/train/test_support_modules.py
@@ -184,3 +184,24 @@ def _all_gather(out_list, in_tensor):
     assert gathered[0].tolist() == [9, 8, 0, 0]
     assert gathered[1].tolist() == [7, 6, 5, 4]
     assert gathered[2].tolist() == [3, 0, 0, 0]
+
+
+def test_ddp_gather_all_1d_rejects_invalid_gathered_sizes(monkeypatch):
+    monkeypatch.setattr(ddp_mod.dist, "is_available", lambda: True)
+    monkeypatch.setattr(ddp_mod.dist, "is_initialized", lambda: True)
+    monkeypatch.setattr(ddp_mod.dist, "get_world_size", lambda: 2)
+
+    def _all_gather(out_list, in_tensor):
+        if in_tensor.dtype == torch.long and in_tensor.numel() == 1:
+            out_list[0].fill_(2)
+            out_list[1].fill_(10**12)
+            return
+        raise AssertionError("Payload gather should not execute after size validation failure")
+
+    monkeypatch.setattr(ddp_mod.dist, "all_gather", _all_gather)
+
+    with pytest.raises(RuntimeError, match="Invalid gathered tensor sizes"):
+        ddp_mod.ddp_gather_all_1d(
+            torch.tensor([1, 2], dtype=torch.int64),
+            torch.device("cpu")
+        )
diff --git a/tests/unit/core/train/test_trainer_fit.py b/tests/unit/core/train/test_trainer_fit.py
index 3b290c7..6fe6e1b 100644
--- a/tests/unit/core/train/test_trainer_fit.py
+++ b/tests/unit/core/train/test_trainer_fit.py
@@ -3,7 +3,9 @@ import optuna
 import pytest
 import torch
+import torch.nn as nn
 
 from pepseqpred.core.models.ffnn import PepSeqFFNN
+import pepseqpred.core.train.trainer as trainer_mod
 from pepseqpred.core.train.trainer import (
     Trainer,
     TrainerConfig,
@@ -117,3 +119,35 @@ def test_fit_optuna_prune_path(tmp_path: Path):
 
     with pytest.raises(optuna.TrialPruned):
         trainer.fit_optuna(save_dir=tmp_path, trial=_AlwaysPruneTrial(), score_key="f1")
+
+
+def test_run_epoch_eval_uses_wrapped_module_for_ddp_like_models(monkeypatch):
+    class _FakeDDP(nn.Module):
+        def __init__(self, module: nn.Module):
+            super().__init__()
+            self.module = module
+
+        def forward(self, _x):
+            raise RuntimeError("DDP wrapper forward should not be used during eval")
+
+    monkeypatch.setattr(trainer_mod, "TorchDDP", _FakeDDP)
+
+    base_model = PepSeqFFNN(
+        emb_dim=4,
+        hidden_sizes=(8,),
+        dropouts=(0.0,),
+        num_classes=1,
+        use_layer_norm=False,
+        use_residual=False
+    )
+    wrapped_model = _FakeDDP(base_model)
+    trainer = Trainer(
+        model=wrapped_model,
+        train_loader=_make_batches(n_batches=1),
+        val_loader=_make_batches(n_batches=1),
+        logger=logging.getLogger("trainer_eval_ddp_like_test"),
+        config=TrainerConfig(epochs=1, batch_size=2, learning_rate=1e-2, device="cpu")
+    )
+
+    out = trainer._run_epoch(0, train=False)
+    assert "eval_metrics" in out