Merged
Commits
55 commits
2964e91
update numpy concat
sahilsethi0105 Mar 13, 2026
7856555
add all local changes
sahilsethi0105 Mar 13, 2026
74d1216
integrate audioset
sahilsethi0105 Mar 17, 2026
814b0a2
Merge remote-tracking branch 'origin/main' into sahil_dev_branch
sahilsethi0105 Mar 17, 2026
ac026de
update sampling rate and audioset metrics
sahilsethi0105 Mar 18, 2026
fc54671
add new backbones, spectrograms, and prototype mechanism
sahilsethi0105 Mar 19, 2026
f0f4c21
update scripts for audio
sahilsethi0105 Mar 19, 2026
0941cfd
fix multigpu training
sahilsethi0105 Mar 20, 2026
b4851f7
fix typo
sahilsethi0105 Mar 20, 2026
3855ee8
add local changes
sahilsethi0105 Mar 21, 2026
8267bf2
merge in main changes
sahilsethi0105 Mar 21, 2026
41f76eb
delay early stopping
sahilsethi0105 Mar 21, 2026
877b51d
remove comments
sahilsethi0105 Mar 21, 2026
da28bac
add contrastive pretraining in
sahilsethi0105 Mar 25, 2026
b6bcfc0
add resume from checkpoint, original loss coeffs, and audio into LR
sahilsethi0105 Mar 26, 2026
cb65db7
fix supervised loss
sahilsethi0105 Mar 29, 2026
f53fa34
add ilp
sahilsethi0105 Mar 29, 2026
5dc1e4b
update weight resumed
sahilsethi0105 Mar 29, 2026
d7d4103
add audio edits
sahilsethi0105 Mar 29, 2026
d20ae28
Merge remote-tracking branch 'origin/main' into sahil_dev_branch
sahilsethi0105 Mar 29, 2026
4c95a07
add ilp check code
sahilsethi0105 Mar 29, 2026
20cfaa6
remove softmax
sahilsethi0105 Mar 29, 2026
b7c54f0
fix protopool bug
sahilsethi0105 Mar 30, 2026
9993839
debug assignment ilp pipeline
sahilsethi0105 Mar 30, 2026
ccb514a
add audio pipeline scripts
sahilsethi0105 Mar 30, 2026
32ec5b5
allow extra keys for fine tuning
sahilsethi0105 Apr 1, 2026
bd11f50
Add files via upload
sahilsethi0105 Apr 14, 2026
6985d16
tweaks
StevenSong Apr 20, 2026
0b5fe2e
Merge remote-tracking branch 'origin/ss-downstream-audio' into ss-audio
StevenSong Apr 21, 2026
4a85b4e
prepare audio scripts
StevenSong Apr 21, 2026
ea35aff
ensure ecg models still work
StevenSong Apr 21, 2026
79bcff6
rename base_dataset in package
StevenSong Apr 21, 2026
bc152b1
resolve torchcodec dependencies for datasets (bump torch to 2.7, test…
StevenSong Apr 21, 2026
294b49b
piping
StevenSong Apr 21, 2026
3bd8a46
graft implementation of full audioset training from hf dataset, inclu…
StevenSong Apr 22, 2026
6452abc
boilerplate for esc50
StevenSong Apr 22, 2026
06aa03f
add esc50 dataset
StevenSong Apr 22, 2026
dfe59d7
enable multiclass (instead of multilabel) training, wip debugging dow…
StevenSong Apr 22, 2026
e4d23c3
align pretraining config to what was done in other branch
StevenSong Apr 22, 2026
51b5b8a
put kwargs on specific stages
StevenSong Apr 22, 2026
b2edf2a
use dataset source seconds
StevenSong Apr 22, 2026
3167ace
add urban sound, cleanup scripts
StevenSong Apr 23, 2026
8fc83ea
cleanup script
StevenSong Apr 24, 2026
21e6b9e
fix some ecg bugs introduced during renames, run ecg ablations over p…
StevenSong Apr 26, 2026
e514fa2
add audio voxceleb1 and iemocap
StevenSong Apr 26, 2026
25dee68
run st-mem, prepare results
StevenSong Apr 28, 2026
221a135
bootstrap audio results
StevenSong Apr 28, 2026
6f0b7e7
cleanup audio queue script
StevenSong Apr 28, 2026
80aa9bd
bootstrap audio
StevenSong Apr 29, 2026
aa83f91
derive prototypes from fms
StevenSong Apr 30, 2026
72b3987
fix seeding bug
StevenSong Apr 30, 2026
b1ef307
fix audio seeding just in case we do multiseed later
StevenSong May 1, 2026
3c18aa4
add fine tuning to protopool
StevenSong May 1, 2026
d5eec3c
last ablations
StevenSong May 5, 2026
e8b3e32
remove unused things for paper
StevenSong May 5, 2026
41 changes: 31 additions & 10 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
.gitmodules @StevenSong
.pre-commit-config.yaml @StevenSong
.secrets.baseline @StevenSong
configs/audio @StevenSong
configs/multi-arm-heedb @StevenSong
configs/pretrain-supervised.yaml @StevenSong
configs/pretrain-unsupervised.yaml @StevenSong
@@ -13,21 +14,28 @@ configs/target-unguided.yaml @StevenSong
data-preprocessing/ @StevenSong
env.yaml @StevenSong
external/PierreElias-IntroECG @StevenSong
external/vuno-ST-MEM @StevenSong
hf-token-plugin.py @StevenSong
LICENSE @StevenSong
results/ @StevenSong
protossl/__init__.py @StevenSong
protossl/datasets/__init__.py @StevenSong
protossl/datasets/streaming_loaders @StevenSong
protossl/datasets/_audioset_contrastive_wrapper_dataset.py @sahilsethi0105
protossl/datasets/_audioset_dataset.py @sahilsethi0105
protossl/datasets/_base_ecg_dataset.py @StevenSong
protossl/datasets/_base_dataset.py @StevenSong
protossl/datasets/_cinc_dataset.py @StevenSong
protossl/datasets/_code15_dataset.py @StevenSong
protossl/datasets/_echonext_dataset.py @StevenSong
protossl/datasets/_heedb_dataset.py @StevenSong
protossl/datasets/_iemocap_dataset.py @StevenSong
protossl/datasets/_mimic_dataset.py @StevenSong
protossl/datasets/_pclr_wrapper_dataset.py @StevenSong
protossl/datasets/_ptbxl_dataset.py @StevenSong
protossl/datasets/_urbansound8k_dataset.py @StevenSong
protossl/datasets/_utils.py @StevenSong
protossl/datasets/_voxceleb1id_dataset.py @StevenSong
protossl/datasets/_zzu_dataset.py @StevenSong
protossl/datasets/streaming_loaders @StevenSong
protossl/defines.py @StevenSong
protossl/lightning_utils.py @StevenSong
protossl/models/__init__.py @StevenSong
@@ -42,6 +50,7 @@ protossl/models/_prototype_supervisor.py @StevenSong
protossl/models/encoders/__init__.py @StevenSong
protossl/models/encoders/_base_encoder.py @StevenSong
protossl/models/encoders/_net1d.py @StevenSong
protossl/models/encoders/_panns_encoder.py @sahilsethi0105
protossl/models/encoders/_prototype_encoder_with_assignment.py @StevenSong
protossl/models/encoders/_prototype_encoder.py @StevenSong
protossl/models/encoders/_resnet1d.py @StevenSong
@@ -50,27 +59,38 @@ protossl/models/helpers/__init__.py @StevenSong
protossl/models/helpers/_prototype_ilp_assigner.py @sahilsethi0105
protossl/models/layers/__init__.py @StevenSong
protossl/models/layers/_multi_input_linear.py @StevenSong
protossl/models/layers/_panns_backbones.py @StevenSong
protossl/trainer.py @StevenSong
plot/ecg-results.ipynb @StevenSong
pyproject.toml @StevenSong
README.md @StevenSong @sahilsethi0105
requirements.txt @StevenSong
scripts/_cache_data.py @StevenSong
scripts/_eval_probs.py @StevenSong
scripts/_eval_probs_bootstrapped.py @StevenSong
scripts/_eval_probs.py @StevenSong
scripts/_linear_probe.py @StevenSong
scripts/_slurm_wrapper.sh @StevenSong
scripts/_submit_job.sh @StevenSong
scripts/0-run-cache-data.sh @StevenSong
scripts/1-run-blackbox-direct.sh @StevenSong
scripts/2-run-labsup-proto-direct.sh @StevenSong
scripts/2-z-* @StevenSong
scripts/3-run-protossl-heedb-pila.sh @StevenSong
scripts/3-z-* @StevenSong
scripts/4-run-labsup-proto-heedb-rila.sh @StevenSong
scripts/5-run-ecgfounder-logreg.sh @StevenSong
scripts/4-y-* @StevenSong
scripts/4-z-* @StevenSong
scripts/5-1-run-ecgfounder-logreg.sh @StevenSong
scripts/5-2-run-stmem-logreg.sh @StevenSong
scripts/6-run-protossl-heedb-pia.sh @StevenSong
scripts/ablations @StevenSong
scripts/audio-1-1-run-proto-from-scratch.sh
scripts/ecgfounder/_compute_ecgfounder_embeddings.py @StevenSong
scripts/7-run-protossl-heedb-pit.sh @StevenSong
scripts/8-run-protossl-heedb-pip.sh @StevenSong
scripts/9-0-run-ecgfounder-patches.sh @StevenSong
scripts/9-1-run-ecgfounder-lap.sh @StevenSong
scripts/9-2-run-ecgfounder-clustering.sh @StevenSong
scripts/9-3-run-ecgfounder-random.sh @StevenSong
scripts/audio @StevenSong
scripts/ecg-fms/_compute_ecgfounder_embeddings.py @StevenSong
scripts/ecg-fms/_compute_stmem_embeddings.py @StevenSong
scripts/echonext/_tabular_logreg.py @StevenSong
scripts/echonext/run-columbia-minimodel.sh @StevenSong
scripts/echonext/run-tabular-logreg.sh @StevenSong
@@ -80,10 +100,11 @@ scripts/pretrain/run-heedb-normalizations.sh @StevenSong
scripts/pretrain/run-pass-heedb-pretrain-no-attn.sh @StevenSong
scripts/pretrain/run-pass-heedb-pretrain.sh @StevenSong
scripts/pretrain/run-prosup-heedb-pretrain.sh @StevenSong
scripts/prototypes-from-fms @StevenSong
scripts/queue-experiments.sh @StevenSong
scripts/README.md @StevenSong @sahilsethi0105
user-study/prepare_samples.ipynb @StevenSong
user-study/decode_samples.ipynb @StevenSong
user-study/analyze_results.ipynb @StevenSong
user-study/images @StevenSong
user-study/metadata.csv @StevenSong
user-study/prepare_samples.ipynb @StevenSong
user-study/results.csv @StevenSong
2 changes: 1 addition & 1 deletion .gitignore
@@ -217,5 +217,5 @@ temp/
outputs*/
OLD/
slurm-logs/
plot/figs/
results/figs/
ecgfounder-checkpoint/
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "user-study/images"]
path = user-study/images
url = git@github.com:StevenSong/protossl-user-study-images.git
[submodule "external/vuno-ST-MEM"]
path = external/vuno-ST-MEM
url = git@github.com:vuno/ST-MEM.git
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -9,6 +9,7 @@ repos:
hooks:
- id: detect-secrets
args: ['--baseline', '.secrets.baseline']
exclude: data-preprocessing/hf_audioset_ids
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
19 changes: 18 additions & 1 deletion README.md
@@ -18,7 +18,7 @@ git clone git@github.com:StevenSong/ProtoSSL.git
cd ProtoSSL

# 2) create and activate environment
# NOTE: you don't have to use conda, just make sure you're using the same python version and install from `requirements.txt` instead
# NOTE: if you don't use conda, make sure you're using the same python version, install from `requirements.txt`, and MAKE SURE YOU HAVE FFMPEG 5.* FOR TORCHCODEC (see below)
conda env create -f env.yaml
conda activate protossl

@@ -30,3 +30,20 @@ pip install -e .

# 5) dev away
```

**torchcodec:** torchcodec is fragile with respect to its dependencies. We've pinned `torch==2.7.0`, which is compatible with `torchcodec==0.4.0`; both are compiled against CUDA 12.8 (which we use on our machines). This torchcodec version is only compatible with `datasets==4.0.0`. If you see errors relating to torchcodec (you can diagnose this by simply importing torchcodec), make sure the dependencies are compatible not only in their versions but also in the CUDA versions they were built against. We also use `ffmpeg=5.*` installed via conda. If you see an error about `libnppicc.so.12` not being found, the linker may not be finding the NPP binaries (which we make available by installing `nvidia-npp-cu12`). To fix this, try setting the `LD_LIBRARY_PATH` environment variable:
```bash
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/npp/lib:$LD_LIBRARY_PATH
# test by importing torchcodec in a python runtime
```
If this works, you can make the fix permanent with conda activation scripts:
```bash
mkdir -p $CONDA_PREFIX/etc/conda/activate.d
mkdir -p $CONDA_PREFIX/etc/conda/deactivate.d

# Set on activate
echo 'export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/npp/lib:$LD_LIBRARY_PATH' > $CONDA_PREFIX/etc/conda/activate.d/npp_lib.sh

# Unset on deactivate
echo 'export LD_LIBRARY_PATH=$(echo $LD_LIBRARY_PATH | sed "s|$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/npp/lib:||g")' > $CONDA_PREFIX/etc/conda/deactivate.d/npp_lib.sh
```
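Note that the activate script above prepends the NPP directory unconditionally, so repeated activations accumulate duplicate entries. An idempotent version of the prepend/strip logic can be sketched in Python (hypothetical helpers; the `python3.10` path segment assumes the env's Python version):

```python
import os


def npp_lib_dir(conda_prefix: str, py_version: str = "3.10") -> str:
    # Mirrors the path used in the activation script; adjust py_version
    # if the conda env uses a different Python.
    return os.path.join(conda_prefix, "lib", f"python{py_version}",
                        "site-packages", "nvidia", "npp", "lib")


def prepend_path(entry: str, current: str) -> str:
    # Idempotent prepend: skip if the entry is already on the path.
    parts = [p for p in current.split(":") if p]
    return current if entry in parts else ":".join([entry] + parts)


def strip_path(entry: str, current: str) -> str:
    # Mirror of the deactivate script's sed substitution.
    return ":".join(p for p in current.split(":") if p and p != entry)
```

The `sed`-based deactivate script above has the same duplicate-entry caveat: if the path was prepended twice, both copies are removed, which is harmless here.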
60 changes: 60 additions & 0 deletions configs/audio/pretrain-supervised.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# lightning.pytorch==2.6.0
seed_everything: 42
trainer:
devices: 1
precision: 32
logger:
class_path: protossl.lightning_utils.StrictWandbLogger
init_args:
project: ProtoSSL-Audio
# save_dir: /path/to/runs/
# name: exp-name
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val_loss
mode: min
patience: 10 # check reduce lr on plateau
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
mode: min
save_last: True
- class_path: PredictionWriter
max_epochs: 5
log_every_n_steps: 1
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.001
weight_decay: 0.01
lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
monitor: val_loss
patience: 5 # check early stopping
model:
class_path: LitModel
init_args:
backbone_type: Cnn14
conv_type: PANNS
input_channels: 1
prototype_type: partial
partial_len: 32000 # 1-second @ 32 kHz
partial_overlap: 0.5
prototype_h: 1
prototype_w: 1
n_prototypes_per_label: 5
# pretrained_weights: /path/to/weights
data:
class_path: LitData
init_args:
# dataset_path: /path/to/dataset
sampling_rate: 32000
batch_size: 128
num_workers: 12
prefetch_factor: 2
# pipeline_stage: learn-prototypes-supervised|project-prototypes-supervised|compute-embeddings|train-classifier
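With `prototype_type: partial`, the `partial_len`/`partial_overlap` settings above describe 1-second windows at 32 kHz with 50% overlap. How such windows tile a waveform can be sketched as follows (an illustration of the arithmetic only, not the repo's actual slicing code):

```python
def partial_windows(n_samples: int, partial_len: int = 32000,
                    overlap: float = 0.5) -> list[tuple[int, int]]:
    # With overlap=0.5 the hop is half a window: 16000 samples at 32 kHz.
    hop = int(partial_len * (1 - overlap))
    return [(start, start + partial_len)
            for start in range(0, n_samples - partial_len + 1, hop)]


# A 10-second clip at 32 kHz yields 19 half-overlapping 1-second windows.
windows = partial_windows(10 * 32000)
```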
63 changes: 63 additions & 0 deletions configs/audio/pretrain-unsupervised.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# lightning.pytorch==2.6.0
seed_everything: 42
trainer:
devices: 1
precision: 32
logger:
class_path: protossl.lightning_utils.StrictWandbLogger
init_args:
project: ProtoSSL-Audio
# save_dir: /path/to/runs/
# name: exp-name
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val_loss
mode: min
patience: 10 # check reduce lr on plateau
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
mode: min
save_last: True
- class_path: PredictionWriter
max_epochs: 5
log_every_n_steps: 1
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.001
weight_decay: 0.01
lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
monitor: val_loss
patience: 5 # check early stopping
model:
class_path: LitModel
init_args:
backbone_type: Cnn14
conv_type: PANNS
input_channels: 1
prototype_type: partial
partial_len: 32000 # 1-second @ 32 kHz
partial_overlap: 0.5
prototype_h: 1
prototype_w: 1
n_prototypes: 2635 # 527 audioset labels, comparing to 5 prototypes per label for supervised pretraining
model_kwargs: '{"cola_loss_weight": 2, "clar_loss_weight": 1, "koleo_loss_weight": 1}'
# pretrained_weights: /path/to/weights
data:
class_path: LitData
init_args:
# dataset_path: /path/to/dataset
sampling_rate: 32000
batch_size: 128
num_workers: 12
prefetch_factor: 2
data_kwargs: '{"cola_view_seconds": 2}'
contrastive_pair_mode: cola+clar
# pipeline_stage: learn-prototypes|project-prototypes|compute-embeddings|train-classifier
53 changes: 53 additions & 0 deletions configs/audio/target-blackbox.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# lightning.pytorch==2.6.0
seed_everything: 42
trainer:
devices: 1
precision: 32
logger:
class_path: protossl.lightning_utils.StrictWandbLogger
init_args:
project: ProtoSSL-Audio
# save_dir: /path/to/runs/
# name: exp-name
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val_loss
mode: min
patience: 20 # check reduce lr on plateau
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
mode: min
save_last: True
- class_path: PredictionWriter
max_epochs: 1000
log_every_n_steps: 1
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.001
weight_decay: 0.01
lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
monitor: val_loss
patience: 10 # check early stopping
model:
class_path: LitModel
init_args:
backbone_type: Cnn14
conv_type: PANNS
input_channels: 1
model_kwargs: '{"label_type": "multiclass"}'
data:
class_path: LitData
init_args:
# dataset_path: /path/to/dataset
sampling_rate: 32000
batch_size: 128
num_workers: 4
pipeline_stage: train-classifier
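The `label_type: multiclass` kwarg above (cf. the `enable multiclass (instead of multilabel) training` commit) switches the head from independent per-class sigmoids to a single softmax. The difference in decoding predictions can be sketched as (a hypothetical helper, not the repo's API):

```python
import math


def decode(logits: list[float], label_type: str, threshold: float = 0.5):
    if label_type == "multiclass":
        # Exactly one class per example: take the argmax.
        return max(range(len(logits)), key=lambda i: logits[i])
    if label_type == "multilabel":
        # Any subset of classes: threshold each sigmoid independently.
        return [1 / (1 + math.exp(-z)) >= threshold for z in logits]
    raise ValueError(f"unknown label_type: {label_type!r}")
```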
60 changes: 60 additions & 0 deletions configs/audio/target-guided-2ppl.yaml
@@ -0,0 +1,60 @@
# lightning.pytorch==2.6.0
seed_everything: 42
trainer:
devices: 1
precision: 32
logger:
class_path: protossl.lightning_utils.StrictWandbLogger
init_args:
project: ProtoSSL-Audio
# save_dir: /path/to/runs/
# name: exp-name
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val_loss
mode: min
patience: 20 # check reduce lr on plateau
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
mode: min
save_last: True
- class_path: PredictionWriter
max_epochs: 1000
log_every_n_steps: 1
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.001
weight_decay: 0.01
lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
monitor: val_loss
patience: 10 # check early stopping
model:
class_path: LitModel
init_args:
backbone_type: Cnn14
conv_type: PANNS
input_channels: 1
prototype_type: partial
partial_len: 32000 # 1-second @ 32 kHz
partial_overlap: 0.5
prototype_h: 1
prototype_w: 1
n_prototypes_per_label: 2
# model_kwargs: '{"label_type": "multiclass", "use_default_weights": True}' # granularly set these on each stage
# pretrained_weights: /path/to/weights
data:
class_path: LitData
init_args:
# dataset_path: /path/to/dataset
sampling_rate: 32000
batch_size: 128
num_workers: 4
# pipeline_stage: learn-prototypes-supervised|learn-prototype-assignments|project-prototypes-supervised|compute-embeddings|train-classifier