137 changes: 135 additions & 2 deletions README.md
@@ -49,7 +49,7 @@ The repository profile is the source of truth for reproducing experiments end-to
## End-to-End Pipeline

```text
Stage 1 preprocess metadata + zscores
Stage 1 normalize dataset inputs (PV1/CWP/BKP) to a shared training contract
Stage 2 generate ESM-2 per-residue embeddings
Stage 3 build residue-level label shards
Stage 4 train FFNN (seeded or ensemble-kfold, DDP-aware)
@@ -60,7 +60,91 @@ Stage 7 evaluate residue metrics (+ optional Cocci peptide compare)

## Stage Reference

### Stage 1: Preprocess
### Stage 1: Multi-Dataset Prepare (PV1/CWP/BKP)

**CLI:** `pepseqpred-prepare-dataset` (`src/pepseqpred/apps/prepare_dataset_cli.py`)

This stage is the recommended entrypoint when training on one or more of:

- PV1 (human virome)
- CWP/Cocci (fungal)
- BKP (bacterial)

It normalizes source-specific metadata and FASTA headers into a shared PV1-compatible contract so downstream embedding, label generation, and training CLIs can be reused unchanged.

**Core module**

- `src/pepseqpred/core/preprocess/preparedataset.py`

**Required output contract per dataset**

- `prepared_targets.fasta`
- `prepared_labels_metadata.tsv`
- `prepared_embedding_metadata.tsv`
- `prepare_summary.json`

**PV1 inputs and command**

- metadata TSV
- z-score TSV
- protein FASTA

```bash
pepseqpred-prepare-dataset \
localdata/PV1/PV1_meta_2020-11-23_cleaned.tsv \
localdata/PV1/prepared \
--dataset-kind pv1 \
--protein-fasta localdata/PV1/PV1_targets.fasta \
--z-file localdata/PV1/PV1_zscores.tsv
```

**CWP/Cocci inputs and command**

- metadata TSV
- protein FASTA
- reactive code list TSV
- non-reactive code list TSV

```bash
pepseqpred-prepare-dataset \
localdata/Cocci/CWP_metadata.tsv \
localdata/Cocci/prepared \
--dataset-kind cwp \
--protein-fasta localdata/Cocci/CWP_targets.faa \
--reactive-codes localdata/Cocci/CWP_reactive_Z20N4.tsv \
--nonreactive-codes localdata/Cocci/CWP_nonreactive_Z20N4.tsv
```

**BKP inputs and command**

- metadata TSV
- protein FASTA
- reactive code list TSV
- non-reactive code list TSV

```bash
pepseqpred-prepare-dataset \
localdata/BKP/BKP_metadata.tsv \
localdata/BKP/prepared \
--dataset-kind bkp \
--protein-fasta localdata/BKP/BKP.faa \
--reactive-codes localdata/BKP/BKP_reactive_Z20N4.tsv \
--nonreactive-codes localdata/BKP/BKP_nonreactive_Z20N4.tsv
```

**Dataset-specific grouping used for leakage-aware splitting (`--split-type id-family`)**

- PV1: family from PV1 `OXX`
- CWP/Cocci: `Cluster50ID` mapped to deterministic numeric IDs
- BKP: `reClusterID_70` mapped to deterministic numeric IDs

**Next stages after prepare**

- run `pepseqpred-esm` with `--embedding-key-mode id-family` and each dataset's `prepared_embedding_metadata.tsv`
- run `pepseqpred-labels` with `--embedding-key-delim -`
- train with `--split-type id-family`

### Stage 1 (Legacy): PV1 Z-Score Preprocess

**CLI:** `pepseqpred-preprocess` (`src/pepseqpred/apps/preprocess_cli.py`)

@@ -177,6 +261,54 @@ pepseqpred-train-ffnn \
--results-csv localdata/models/ffnn_smoke/runs.csv
```

**Submit one SLURM training job with multiple datasets (PV1 + CWP + BKP)**

`scripts/hpc/trainffnn.sh` accepts multiple embedding directories and multiple label shards in one call:

- all embedding dirs first
- separator `--`
- all label shard `.pt` files after `--`

```bash
# Example: use per-dataset shard outputs together in one training run
EMB_DIRS=(
/scratch/$USER/esm2/pv1/artifacts/pts/shard_000
/scratch/$USER/esm2/pv1/artifacts/pts/shard_001
/scratch/$USER/esm2/pv1/artifacts/pts/shard_002
/scratch/$USER/esm2/pv1/artifacts/pts/shard_003
/scratch/$USER/esm2/cwp/artifacts/pts/shard_000
/scratch/$USER/esm2/cwp/artifacts/pts/shard_001
/scratch/$USER/esm2/cwp/artifacts/pts/shard_002
/scratch/$USER/esm2/cwp/artifacts/pts/shard_003
/scratch/$USER/esm2/bkp/artifacts/pts/shard_000
/scratch/$USER/esm2/bkp/artifacts/pts/shard_001
/scratch/$USER/esm2/bkp/artifacts/pts/shard_002
/scratch/$USER/esm2/bkp/artifacts/pts/shard_003
)

LABEL_SHARDS=(
/scratch/$USER/labels/pv1/labels_shard_000.pt
/scratch/$USER/labels/pv1/labels_shard_001.pt
/scratch/$USER/labels/pv1/labels_shard_002.pt
/scratch/$USER/labels/pv1/labels_shard_003.pt
/scratch/$USER/labels/cwp/labels_shard_000.pt
/scratch/$USER/labels/cwp/labels_shard_001.pt
/scratch/$USER/labels/cwp/labels_shard_002.pt
/scratch/$USER/labels/cwp/labels_shard_003.pt
/scratch/$USER/labels/bkp/labels_shard_000.pt
/scratch/$USER/labels/bkp/labels_shard_001.pt
/scratch/$USER/labels/bkp/labels_shard_002.pt
/scratch/$USER/labels/bkp/labels_shard_003.pt
)

sbatch trainffnn.sh "${EMB_DIRS[@]}" -- "${LABEL_SHARDS[@]}"
```

Notes:

- Keep `SPLIT_TYPE=id-family` for family-aware leakage control across PV1/CWP/BKP.
- Protein IDs should be globally unique across all provided label shards/embedding dirs; a quick check is sketched below.
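
A minimal Python sketch of that uniqueness check; it assumes only the `<protein_id>-<family>.pt` embedding filename convention described in the split docs, and the helper name is hypothetical:

```python
# Quick duplicate-ID check across embedding dirs, assuming the
# <protein_id>-<family>.pt filename convention used by id-family mode.
from collections import Counter
from pathlib import Path
import sys

def check_unique_protein_ids(emb_dirs: list[str]) -> None:
    ids: list[str] = []
    for d in emb_dirs:
        for pt in Path(d).glob("*.pt"):
            # Protein ID is the stem up to the last "-"; family follows it.
            ids.append(pt.stem.rsplit("-", 1)[0])
    dups = [pid for pid, n in Counter(ids).items() if n > 1]
    if dups:
        sys.exit(f"duplicate protein IDs across embedding dirs: {dups[:10]}")

check_unique_protein_ids(sys.argv[1:])
```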

**Outputs**

- run checkpoint(s), usually `fully_connected.pt`
@@ -337,6 +469,7 @@ Bundled pretrained registry currently includes:

| CLI | File | Purpose |
| --- | --- | --- |
| `pepseqpred-prepare-dataset` | `apps/prepare_dataset_cli.py` | normalize PV1/CWP/BKP into shared training contract |
| `pepseqpred-preprocess` | `apps/preprocess_cli.py` | metadata + z-score preprocessing |
| `pepseqpred-esm` | `apps/esm_cli.py` | ESM-2 embedding generation |
| `pepseqpred-labels` | `apps/labels_cli.py` | residue label shard generation |
128 changes: 128 additions & 0 deletions docs/extra/pv1_cwp_bkp_merge_split_and_pos_weight.md
@@ -0,0 +1,128 @@
# PV1 + CWP + BKP Multi-Source Training/Validation and Positive Weight

This document explains how PV1, CWP, and BKP datasets are normalized separately and then used together for model training/validation across one or more DDP ranks. It also explains how the positive class weight is calculated at label build and train time.

## 1) Per-dataset normalization before multi-source training

Each source dataset is first normalized independently with `pepseqpred-prepare-dataset` into the same 4-file contract:

- `prepared_targets.fasta`
- `prepared_labels_metadata.tsv`
- `prepared_embedding_metadata.tsv`
- `prepare_summary.json`

### PV1 normalization

- Uses existing PV1 preprocessing (`preprocess_pv1`) to derive `Def epitope`, `Uncertain`, `Not epitope`.
- Parses family from PV1 fullname `OXX` (last comma-delimited token).
- Uses PV1 FASTA to validate protein IDs and alignment bounds.
- Group/family is the parsed PV1 family (numeric).

### CWP normalization

- Keeps only `CodeName` rows listed in `--reactive-codes` or `--nonreactive-codes`.
- Label mapping:
- reactive code -> `Def epitope=1`, `Not epitope=0`, `Uncertain=0`
- nonreactive code -> `Def epitope=0`, `Not epitope=1`, `Uncertain=0`
- Uses `SequenceAccession` as normalized `ProteinID`.
- Uses `Cluster50ID` as group token.
- Resolves align columns with fallbacks (`StartIndex/AlignStart/...`, `StopIndex/AlignStop/...`).
- Builds deterministic numeric group IDs from sorted unique `Cluster50ID` values with an offset.
- default offset in CLI: `100000000`

### BKP normalization

- Same reactive/nonreactive mapping logic as CWP.
- Uses `SequenceAccession` as normalized `ProteinID`.
- Uses `reClusterID_70` as group token.
- Resolves align columns with BKP-priority fallbacks (`alignStart`, `alignStop`, then others).
- Deterministic numeric group IDs from sorted unique `reClusterID_70` values with an offset (see the sketch after this list).
- default offset in CLI: `200000000`
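
Both CWP and BKP use the same mapping scheme. A minimal sketch of the deterministic group-ID construction, assuming only what the bullets above state (sorted unique tokens plus a fixed offset); the actual code in `preparedataset.py` may differ in details:

```python
def build_group_ids(tokens: list[str], offset: int) -> dict[str, int]:
    """Map cluster tokens (Cluster50ID / reClusterID_70 values) to
    deterministic numeric group IDs: sorted unique tokens + fixed offset."""
    return {tok: offset + i for i, tok in enumerate(sorted(set(tokens)))}

# The per-dataset offsets keep the numeric ID ranges disjoint:
cwp_groups = build_group_ids(["c7", "c1", "c7"], offset=100_000_000)
bkp_groups = build_group_ids(["r3", "r9"], offset=200_000_000)
assert cwp_groups == {"c1": 100_000_000, "c7": 100_000_001}
```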

### Shared normalized fullname format

After normalization, all datasets are rewritten to PV1-style fullnames:

`ID=<ProteinID> AC=<ProteinID> OXX=0,0,0,<GroupID>`

So downstream tools can parse a single `ID + family` pattern uniformly.
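
A minimal sketch of parsing such a fullname back into an ID and a numeric group (the family/group is the last comma-delimited token of `OXX`, as in PV1); the function name is illustrative, not the pipeline's parser:

```python
def parse_fullname(fullname: str) -> tuple[str, int]:
    """Parse 'ID=<ProteinID> AC=<ProteinID> OXX=0,0,0,<GroupID>'."""
    fields = dict(tok.split("=", 1) for tok in fullname.split())
    group_id = int(fields["OXX"].split(",")[-1])  # last comma-delimited token
    return fields["ID"], group_id

assert parse_fullname("ID=P123 AC=P123 OXX=0,0,0,100000001") == ("P123", 100000001)
```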

## 2) How PV1/CWP/BKP are used together for training/validation

There is currently no dedicated "multi-source merge" CLI; the datasets remain separate sources. The approach covered by the integration tests is:

1. Concatenate FASTA records from each dataset's `prepared_targets.fasta` into one combined FASTA.
2. Row-concatenate all `prepared_labels_metadata.tsv` files.
3. Row-concatenate all `prepared_embedding_metadata.tsv` files, then de-duplicate on `["Name", "Family"]`.

That is the behavior exercised in `tests/integration/test_prepare_dataset_multisource_pipeline.py`, and it forms the shared inputs for embedding, label generation, and training. Alternatively, embeddings and labels can be generated per dataset, and training can be given all per-dataset files at once, as outlined [here](../README.md#stage-4-train-ffnn).
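
A minimal pandas sketch of steps 2 and 3 above (step 1 is plain FASTA concatenation). The file paths are illustrative, and this is not the tested pipeline code:

```python
import pandas as pd

label_paths = ["pv1/prepared_labels_metadata.tsv", "cwp/prepared_labels_metadata.tsv"]
emb_paths = ["pv1/prepared_embedding_metadata.tsv", "cwp/prepared_embedding_metadata.tsv"]

# Step 2: row-concatenate all labels metadata.
labels = pd.concat([pd.read_csv(p, sep="\t") for p in label_paths], ignore_index=True)

# Step 3: row-concatenate embedding metadata, then de-duplicate on ["Name", "Family"].
emb = pd.concat([pd.read_csv(p, sep="\t") for p in emb_paths], ignore_index=True)
emb = emb.drop_duplicates(subset=["Name", "Family"])

labels.to_csv("combined_labels_metadata.tsv", sep="\t", index=False)
emb.to_csv("combined_embedding_metadata.tsv", sep="\t", index=False)
```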

## 3) How training/validation partitioning is done

### 3.1 Base protein universe

`train_ffnn_cli` builds a `ProteinDataset` from provided embedding dirs + label shards, then uses:

- `protein_ids = intersection(embedding_index IDs, label_index IDs)`

So split candidates are only proteins that exist in both embeddings and labels.

### 3.2 `split_type=id-family` (default)

- Family is parsed from embedding filename stem (`<protein_id>-<family>.pt`).
- If family is missing for an ID, it is treated as a singleton group:
- `__missing_family__:<protein_id>`
- Global split then uses grouped split functions so a family/group cannot appear in both train and validation.

### Seeded mode (`--train-mode seeded`)

- Uses `split_ids_grouped(ids, val_frac, split_seed, family_groups)`.
- Target val size is `floor(len(ids) * val_frac)`, but the exact fraction can differ because whole groups move together (sketched after this list).
- A leakage check is run after splitting and raises if any family overlaps.
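
A minimal sketch of the grouped seeded behavior described above (whole groups move to validation until the `floor(len(ids) * val_frac)` target is reached); ordering details of the real `split_ids_grouped` may differ:

```python
import math
import random

def split_ids_grouped_sketch(ids, val_frac, split_seed, family_groups):
    """family_groups maps protein_id -> group key; whole groups move together."""
    groups: dict[str, list[str]] = {}
    for pid in ids:
        groups.setdefault(family_groups[pid], []).append(pid)
    order = sorted(groups)                    # deterministic base order
    random.Random(split_seed).shuffle(order)  # seeded shuffle of groups
    target = math.floor(len(ids) * val_frac)
    val: list[str] = []
    for g in order:
        if len(val) >= target:
            break
        val.extend(groups[g])                 # whole group goes to validation
    val_set = set(val)
    train = [pid for pid in ids if pid not in val_set]
    return train, val
```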

### Ensemble-kfold mode (`--train-mode ensemble-kfold`)

- `--val-frac` is ignored.
- Uses grouped k-fold (`build_grouped_kfold_splits`) when `split_type=id-family`.
- Groups/families are assigned to folds intact, then each fold is one validation set (a sketch follows this list).
- Leakage check is run per fold.
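
A sketch of the grouped fold assignment, assuming a greedy size-balanced strategy (largest groups first, each into the currently smallest fold); the actual `build_grouped_kfold_splits` may assign differently:

```python
def grouped_kfold_sketch(ids, family_groups, k):
    """Assign whole groups to k folds; each fold later serves once as validation."""
    groups: dict[str, list[str]] = {}
    for pid in ids:
        groups.setdefault(family_groups[pid], []).append(pid)
    folds: list[list[str]] = [[] for _ in range(k)]
    for g in sorted(groups, key=lambda g: -len(groups[g])):
        min(folds, key=len).extend(groups[g])  # greedy balance by fold size
    return folds
```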

### 3.3 `split_type=id`

- No family grouping; IDs are split directly.
- Seeded mode uses `split_ids` (sketched after this list):
- shuffle IDs with `split_seed`
- validation = first `floor(N * val_frac)` IDs
- training = remaining IDs
- Ensemble-kfold mode uses plain `build_kfold_splits`.
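
Seeded `id` splitting is fully specified above; a minimal sketch:

```python
import math
import random

def split_ids_sketch(ids, val_frac, split_seed):
    """Shuffle with split_seed; first floor(N * val_frac) IDs become validation."""
    shuffled = list(ids)
    random.Random(split_seed).shuffle(shuffled)
    n_val = math.floor(len(shuffled) * val_frac)
    return shuffled[n_val:], shuffled[:n_val]  # (train, val)
```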

### 3.4 DDP rank partitioning after global train/val split

For multi-rank training, each run's already-determined global train/val ID lists are partitioned across ranks using `partition_ids_weighted`:

- greedy load balance by estimated embedding file size
- optional grouping by label-shard path for locality
- the train partition is enforced to be non-empty on every rank

This is a rank-level sharding step, not a new train/val split.
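
A minimal sketch of the greedy balancing idea, assuming only the bullets above; the shard-locality grouping and the non-empty-train guarantee of `partition_ids_weighted` are omitted:

```python
def partition_ids_weighted_sketch(ids, sizes, world_size):
    """sizes maps protein_id -> estimated embedding file size in bytes.
    Heaviest IDs first, each assigned to the currently lightest rank."""
    ranks: list[list[str]] = [[] for _ in range(world_size)]
    load = [0] * world_size
    for pid in sorted(ids, key=lambda p: -sizes[p]):
        r = load.index(min(load))
        ranks[r].append(pid)
        load[r] += sizes[pid]
    return ranks
```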

## 4) Positive class weight calculation

There are two connected pieces:

1. Label build time (`pepseqpred-labels --calc-pos-weight`)
- Each label shard writes:
- `class_stats.pos_count`
- `class_stats.neg_count`
- `class_stats.pos_weight`
- Formula:
- `pos_weight = neg_count / max(1, pos_count)`
- For 3-column labels `[Def epitope, Uncertain, Not epitope]`, counts only include residues where `Uncertain == 0` (see the first sketch below).

2. Train time (`pepseqpred-train-ffnn`)
- If `--pos-weight` is provided, that value is used directly.
- Otherwise it reads `class_stats` from all provided label shards and recomputes:
- `total_neg / max(1, total_pos)`
- That scalar is passed to `BCEWithLogitsLoss(pos_weight=...)`.
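
First sketch: per-shard class stats for 3-column labels, assuming a `(num_residues, 3)` tensor ordered `[Def epitope, Uncertain, Not epitope]`; the real shard format may carry more structure:

```python
import torch

def shard_class_stats(labels: torch.Tensor) -> dict:
    """labels: (num_residues, 3) tensor, columns [Def epitope, Uncertain, Not epitope]."""
    certain = labels[:, 1] == 0                 # drop Uncertain residues
    pos = int((labels[certain, 0] == 1).sum())
    neg = int((labels[certain, 2] == 1).sum())
    return {"pos_count": pos, "neg_count": neg,
            "pos_weight": neg / max(1, pos)}
```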

Note: automatic train-time `pos_weight` uses shard-level totals, not run-specific train-only IDs.
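
Second sketch: the train-time resolution rule and how the scalar reaches the loss. The stats dicts are illustrative:

```python
import torch
from torch import nn

def resolve_pos_weight(shard_stats: list[dict], cli_pos_weight: float | None) -> float:
    """An explicit --pos-weight wins; otherwise recompute from summed shard stats."""
    if cli_pos_weight is not None:
        return cli_pos_weight
    total_pos = sum(s["pos_count"] for s in shard_stats)
    total_neg = sum(s["neg_count"] for s in shard_stats)
    return total_neg / max(1, total_pos)

pw = resolve_pos_weight([{"pos_count": 120, "neg_count": 4800}], cli_pos_weight=None)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pw))
```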
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pepseqpred"
version = "1.0.0"
version = "1.1.0"
description = "Residue-level epitope prediction pipeline for peptide/protein workflows."
readme = "README.pypi.md"
requires-python = ">=3.12"
@@ -72,6 +72,7 @@ pepseqpred-esm = "pepseqpred.apps.esm_cli:main"
pepseqpred-labels = "pepseqpred.apps.labels_cli:main"
pepseqpred-predict = "pepseqpred.apps.prediction_cli:main"
pepseqpred-preprocess = "pepseqpred.apps.preprocess_cli:main"
pepseqpred-prepare-dataset = "pepseqpred.apps.prepare_dataset_cli:main"
pepseqpred-eval-ffnn = "pepseqpred.apps.evaluate_ffnn_cli:main"
pepseqpred-train-ffnn = "pepseqpred.apps.train_ffnn_cli:main"
pepseqpred-train-ffnn-optuna = "pepseqpred.apps.train_ffnn_optuna_cli:main"