diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index cd7470e..fdf0780 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -82,6 +82,7 @@ scripts/4-z-* @StevenSong scripts/5-1-run-ecgfounder-logreg.sh @StevenSong scripts/5-2-run-stmem-logreg.sh @StevenSong scripts/6-run-protossl-heedb-pia.sh @StevenSong +scripts/6-z-* @StevenSong scripts/7-run-protossl-heedb-pit.sh @StevenSong scripts/8-run-protossl-heedb-pip.sh @StevenSong scripts/9-0-run-ecgfounder-patches.sh @StevenSong diff --git a/protossl/datasets/_cinc_dataset.py b/protossl/datasets/_cinc_dataset.py index 7b1452c..e461169 100644 --- a/protossl/datasets/_cinc_dataset.py +++ b/protossl/datasets/_cinc_dataset.py @@ -40,6 +40,7 @@ def __init__( _path = Path(dataset_path) df = pd.read_csv(_path / "georgia.csv") df = df[df["split"] == split] + self._df = df.reset_index(drop=True) self.source_ids = torch.as_tensor(df["patient_id"].to_numpy()) self.sample_ids = torch.as_tensor(df["ecg_id"].to_numpy()) diff --git a/protossl/datasets/_code15_dataset.py b/protossl/datasets/_code15_dataset.py index 128fa7b..150a8ca 100644 --- a/protossl/datasets/_code15_dataset.py +++ b/protossl/datasets/_code15_dataset.py @@ -40,6 +40,7 @@ def __init__( _path = Path(dataset_path) df = pd.read_csv(_path / "labels.csv") df = df[df["split"] == split] + self._df = df.reset_index(drop=True) self.source_ids = torch.as_tensor(df["patient_id"].to_numpy()) self.sample_ids = torch.as_tensor(df["exam_id"].to_numpy()) diff --git a/protossl/datasets/_echonext_dataset.py b/protossl/datasets/_echonext_dataset.py index 03b4bd8..249b6db 100644 --- a/protossl/datasets/_echonext_dataset.py +++ b/protossl/datasets/_echonext_dataset.py @@ -41,6 +41,7 @@ def __init__( df = pd.read_csv(_path / "EchoNext_metadata_100k.csv") df = df.rename(columns=mapping) split_mask = df["split"] == split + self._df = df.loc[split_mask].reset_index(drop=True) id_df = df.loc[split_mask, ["patient_key", "ecg_key"]].reset_index(drop=True) label_df = df.loc[split_mask, target_cols].reset_index(drop=True) diff --git a/protossl/datasets/_mimic_dataset.py b/protossl/datasets/_mimic_dataset.py index 906dcc8..6f608a9 100644 --- a/protossl/datasets/_mimic_dataset.py +++ b/protossl/datasets/_mimic_dataset.py @@ -43,6 +43,7 @@ def __init__( _path = Path(dataset_path) df = pd.read_csv(_path / "ed-ecgs.csv") df = df[df["split"] == split] + self._df = df.reset_index(drop=True) self.source_ids = torch.as_tensor(df["subject_id"].to_numpy()) self.sample_ids = torch.as_tensor(df["study_id"].to_numpy()) diff --git a/protossl/datasets/_ptbxl_dataset.py b/protossl/datasets/_ptbxl_dataset.py index 93750b2..a85b0ea 100644 --- a/protossl/datasets/_ptbxl_dataset.py +++ b/protossl/datasets/_ptbxl_dataset.py @@ -66,6 +66,7 @@ def __init__( else: raise ValueError(f"Unknown split: {split}") df = df[mask] + self._df = df.reset_index(drop=True) self.source_ids = torch.as_tensor(df["patient_id"].astype(int).to_numpy()) self.sample_ids = torch.as_tensor(df.index.to_numpy()) diff --git a/protossl/datasets/_zzu_dataset.py b/protossl/datasets/_zzu_dataset.py index cdf72e5..57e8624 100644 --- a/protossl/datasets/_zzu_dataset.py +++ b/protossl/datasets/_zzu_dataset.py @@ -57,6 +57,7 @@ def __init__( ) df = get_zzu_dataframe(dataset_path) df = df[df["split"] == split] + self._df = df.reset_index(drop=True) self.source_ids = torch.as_tensor(df["Patient_ID"].to_numpy()) self.sample_ids = torch.as_tensor(df["ECG_ID"].to_numpy()) diff --git a/results/audio-results.ipynb b/results/audio-results.ipynb index be6ef80..01b6815 100644 --- a/results/audio-results.ipynb +++ b/results/audio-results.ipynb @@ -52,7 +52,7 @@ " his = []\n", " for i in range(n_folds):\n", " fold_path = Path(fpath.format(i))\n", - " metrics = pd.read_csv(fold_path / model_dir / \"metrics-bootstrapped.csv\", index_col=\"Label\")\n", + " metrics = pd.read_csv(fold_path / model_dir / \"metrics-bootstrapped-v2.csv\", index_col=\"Label\")\n", " accs.append(metrics.loc[\"Multiclass\", \"Accuracy\"])\n", " los.append(metrics.loc[\"Multiclass\", \"Accuracy 95% CI (lo)\"])\n", " his.append(metrics.loc[\"Multiclass\", \"Accuracy 95% CI (hi)\"])\n", @@ -60,7 +60,7 @@ " lo = np.mean(los)\n", " hi = np.mean(his)\n", " else:\n", - " metrics = pd.read_csv(Path(fpath) / model_dir / \"metrics-bootstrapped.csv\", index_col=\"Label\")\n", + " metrics = pd.read_csv(Path(fpath) / model_dir / \"metrics-bootstrapped-v2.csv\", index_col=\"Label\")\n", " acc = metrics.loc[\"Multiclass\", \"Accuracy\"]\n", " lo = metrics.loc[\"Multiclass\", \"Accuracy 95% CI (lo)\"]\n", " hi = metrics.loc[\"Multiclass\", \"Accuracy 95% CI (hi)\"]\n", diff --git a/results/data-tables.ipynb b/results/data-tables.ipynb new file mode 100644 index 0000000..aaded21 --- /dev/null +++ b/results/data-tables.ipynb @@ -0,0 +1,841 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "from typing import Optional\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from protossl.datasets import (\n", + " HeedbECGDataset,\n", + " EchoNextECGDataset,\n", + " MimicECGDataset,\n", + " ZzuECGDataset,\n", + " PtbxlECGDataset,\n", + " get_ptbxl_labels,\n", + " CincECGDataset,\n", + " Code15ECGDataset,\n", + ")\n", + "from protossl.defines import (\n", + " HEEDB_TARGETS, \n", + " ECHONEXT_TARGETS,\n", + " MIMIC_TARGETS,\n", + " ZZU_TARGETS,\n", + " PTBXL_TARGETS,\n", + " CINC_TARGETS,\n", + " CODE15_TARGETS,\n", + ")\n", + "\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "from typing import Optional\n", + "\n", + "\n", + "class TableOne:\n", + " \"\"\"\n", + " A simplified, fast version of TableOne for large datasets.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : pd.DataFrame\n", + " Input data.\n", + " columns : list\n", + " Columns to include in the table.\n", + " categorical : list\n", + " Categorical columns.\n", + " continuous : list\n", + " Continuous columns.\n", + " nonnormal : list, optional\n", + " Continuous columns to display as median [Q1, Q3] instead of mean (SD).\n", + " nunique : list, optional\n", + " Columns for which to report the number of unique values per group\n", + " (e.g. patient ID, study ID). Rendered as a single row per column.\n", + " groupby : str, optional\n", + " Column to group by (becomes column headers).\n", + " order : dict, optional\n", + " Mapping of column name -> ordered list of values. Used for groupby column\n", + " ordering and for ordering categorical row values.\n", + " binary_show : dict, optional\n", + " Mapping of binary categorical column name -> the single class to display.\n", + " Only columns explicitly listed here will be collapsed to a single row;\n", + " binary columns not listed are rendered with all levels like any other\n", + " categorical.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " data: pd.DataFrame,\n", + " columns: list,\n", + " categorical: Optional[list] = None,\n", + " continuous: Optional[list] = None,\n", + " nonnormal: Optional[list] = None,\n", + " nunique: Optional[list] = None,\n", + " groupby: Optional[str] = None,\n", + " order: Optional[dict] = None,\n", + " binary_show: Optional[dict] = None,\n", + " ):\n", + " self.data = data\n", + " self.columns = list(columns)\n", + " self.categorical = list(categorical) if categorical else []\n", + " self.continuous = list(continuous) if continuous else []\n", + " self.nonnormal = set(nonnormal) if nonnormal else set()\n", + " self.nunique = list(nunique) if nunique else []\n", + " self.groupby = groupby\n", + " self.order = order or {}\n", + " self.binary_show = binary_show or {}\n", + "\n", + " # Don't render the groupby column as a row — it's a diagonal submatrix.\n", + " if self.groupby is not None:\n", + " self.categorical = [c for c in self.categorical if c != self.groupby]\n", + " self.continuous = [c for c in self.continuous if c != self.groupby]\n", + " self.nunique = [c for c in self.nunique if c != self.groupby]\n", + " self.columns = [c for c in self.columns if c != self.groupby]\n", + "\n", + " # Make sure nunique columns appear in self.columns even if the user\n", + " # didn't list them explicitly.\n", + " for c in self.nunique:\n", + " if c not in self.columns:\n", + " self.columns.append(c)\n", + "\n", + " self.table = self._build()\n", + "\n", + " # ------------------------------------------------------------------ #\n", + " # Group setup\n", + " # ------------------------------------------------------------------ #\n", + " def _group_indexers(self):\n", + " \"\"\"Return list of (group_label, boolean_mask, n) tuples.\"\"\"\n", + " if self.groupby is None:\n", + " mask = np.ones(len(self.data), dtype=bool)\n", + " return [(\"Overall\", mask, int(mask.sum()))]\n", + "\n", + " col = self.data[self.groupby]\n", + " if self.groupby in self.order:\n", + " levels = list(self.order[self.groupby])\n", + " else:\n", + " levels = list(pd.unique(col.dropna()))\n", + "\n", + " out = []\n", + " for lvl in levels:\n", + " mask = (col == lvl).to_numpy()\n", + " out.append((lvl, mask, int(mask.sum())))\n", + " return out\n", + "\n", + " # ------------------------------------------------------------------ #\n", + " # Continuous formatting\n", + " # ------------------------------------------------------------------ #\n", + " @staticmethod\n", + " def _fmt_mean_sd(x: np.ndarray) -> str:\n", + " x = x[~np.isnan(x)]\n", + " if x.size == 0:\n", + " return \"\"\n", + " return f\"{x.mean():.1f} ({x.std(ddof=1):.1f})\"\n", + "\n", + " @staticmethod\n", + " def _fmt_median_iqr(x: np.ndarray) -> str:\n", + " x = x[~np.isnan(x)]\n", + " if x.size == 0:\n", + " return \"\"\n", + " q1, med, q3 = np.quantile(x, [0.25, 0.5, 0.75])\n", + " return f\"{med:.1f} [{q1:.1f}, {q3:.1f}]\"\n", + "\n", + " def _continuous_row(self, col: str, groups):\n", + " nonnormal = col in self.nonnormal\n", + " label = f\"{col}, {'median [Q1, Q3]' if nonnormal else 'mean (SD)'}\"\n", + " values = self.data[col].to_numpy(dtype=float, na_value=np.nan)\n", + "\n", + " row = {}\n", + " for gname, mask, _ in groups:\n", + " sub = values[mask]\n", + " row[gname] = self._fmt_median_iqr(sub) if nonnormal else self._fmt_mean_sd(sub)\n", + " return label, row\n", + "\n", + " # ------------------------------------------------------------------ #\n", + " # Unique-count formatting\n", + " # ------------------------------------------------------------------ #\n", + " def _nunique_row(self, col: str, groups):\n", + " label = f\"{col}, unique n\"\n", + " series = self.data[col].to_numpy()\n", + " row = {}\n", + " for gname, mask, _ in groups:\n", + " sub = series[mask]\n", + " # pd.unique handles NaN, mixed types, and is faster than np.unique\n", + " # on object dtype.\n", + " row[gname] = f\"{pd.unique(sub).size}\"\n", + " return label, row\n", + "\n", + " # ------------------------------------------------------------------ #\n", + " # Categorical formatting\n", + " # ------------------------------------------------------------------ #\n", + " def _categorical_rows(self, col: str, groups):\n", + " \"\"\"\n", + " Yields (label, {group: 'n (pct%)'} ) tuples.\n", + " Only collapses to a single row if the user has explicitly listed the\n", + " column in binary_show.\n", + " \"\"\"\n", + " series = self.data[col]\n", + "\n", + " # Explicit single-class selection — only path that collapses the column.\n", + " if col in self.binary_show:\n", + " chosen = self.binary_show[col]\n", + " label = f\"{col} = {chosen}, n (%)\"\n", + " row = {}\n", + " chosen_mask = (series == chosen).to_numpy()\n", + " for gname, mask, n in groups:\n", + " if n == 0:\n", + " row[gname] = \"\"\n", + " continue\n", + " count = int((chosen_mask & mask).sum())\n", + " pct = 100.0 * count / n\n", + " row[gname] = f\"{count} ({pct:.1f})\"\n", + " yield label, row\n", + " return\n", + "\n", + " # Determine level order\n", + " if col in self.order:\n", + " levels = list(self.order[col])\n", + " else:\n", + " levels = sorted(series.dropna().unique().tolist(), key=lambda v: (str(type(v)), v))\n", + "\n", + " # Header row + one row per level\n", + " yield f\"{col}, n (%)\", {gname: \"\" for gname, _, _ in groups}\n", + " for lvl in levels:\n", + " label = f\" {lvl}\"\n", + " row = {}\n", + " lvl_mask = (series == lvl).to_numpy()\n", + " for gname, mask, n in groups:\n", + " if n == 0:\n", + " row[gname] = \"\"\n", + " continue\n", + " count = int((lvl_mask & mask).sum())\n", + " pct = 100.0 * count / n\n", + " row[gname] = f\"{count} ({pct:.1f})\"\n", + " yield label, row\n", + "\n", + " # ------------------------------------------------------------------ #\n", + " # Build\n", + " # ------------------------------------------------------------------ #\n", + " def _build(self) -> pd.DataFrame:\n", + " groups = self._group_indexers()\n", + " group_names = [g[0] for g in groups]\n", + "\n", + " rows = []\n", + " index = []\n", + "\n", + " # n row\n", + " index.append(\"n\")\n", + " rows.append({gname: f\"{n}\" for gname, _, n in groups})\n", + "\n", + " nunique_set = set(self.nunique)\n", + "\n", + " # Preserve user-specified column order\n", + " for col in self.columns:\n", + " if col in nunique_set:\n", + " label, row = self._nunique_row(col, groups)\n", + " index.append(label)\n", + " rows.append(row)\n", + " elif col in self.continuous:\n", + " label, row = self._continuous_row(col, groups)\n", + " index.append(label)\n", + " rows.append(row)\n", + " elif col in self.categorical:\n", + " for label, row in self._categorical_rows(col, groups):\n", + " index.append(label)\n", + " rows.append(row)\n", + "\n", + " df = pd.DataFrame(rows, index=index, columns=group_names)\n", + " df.columns = pd.MultiIndex.from_tuples(\n", + " [(self.groupby or \"\", c) for c in df.columns]\n", + " )\n", + " return df\n", + "\n", + " # ------------------------------------------------------------------ #\n", + " # Display\n", + " # ------------------------------------------------------------------ #\n", + " def __repr__(self) -> str:\n", + " return self.table.to_string()\n", + "\n", + " def _repr_html_(self) -> str:\n", + " return self.table.to_html()\n", + "\n", + " def to_csv(self, path, **kwargs):\n", + " return self.table.to_csv(path, **kwargs)\n", + "\n", + " def to_latex(self, **kwargs):\n", + " return self.table.to_latex(**kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "ecg_dir = Path(\"/opt/gpudata/ecg\")\n", + "\n", + "datasets = {\n", + " \"EchoNext\": {\n", + " 72475: \"echonext\",\n", + " 32768: \"echonext-32k\",\n", + " 16384: \"echonext-16k\",\n", + " 8192: \"echonext-8k\",\n", + " 4096: \"echonext-4k\",\n", + " 2048: \"echonext-2k\",\n", + " 1024: \"echonext-1k\",\n", + " 512: \"echonext-512\",\n", + " 256: \"echonext-256\",\n", + " },\n", + " \"MIMIC-IV-ECG\": {\n", + " 78470: \"mimic-iv-ecg\",\n", + " 32768: \"mimic-iv-ecg-32k\",\n", + " 16384: \"mimic-iv-ecg-16k\",\n", + " 8192: \"mimic-iv-ecg-8k\",\n", + " 4096: \"mimic-iv-ecg-4k\",\n", + " 2048: \"mimic-iv-ecg-2k\",\n", + " 1024: \"mimic-iv-ecg-1k\",\n", + " 512: \"mimic-iv-ecg-512\",\n", + " 256: \"mimic-iv-ecg-256\",\n", + " },\n", + " \"CODE-15%\": {\n", + " 74112: \"code15\",\n", + " 32768: \"code15-32k\",\n", + " 16384: \"code15-16k\",\n", + " 8192: \"code15-8k\",\n", + " 4096: \"code15-4k\",\n", + " 2048: \"code15-2k\",\n", + " 1024: \"code15-1k\",\n", + " 512: \"code15-512\",\n", + " 256: \"code15-256\",\n", + " },\n", + " \"PTB-XL\": {\n", + " 17418: \"ptb-xl\",\n", + " 8722: \"ptb-xl-8k\",\n", + " 4356: \"ptb-xl-4k\",\n", + " 2175: \"ptb-xl-2k\",\n", + " 1091: \"ptb-xl-1k\",\n", + " 547: \"ptb-xl-512\",\n", + " 273: \"ptb-xl-256\",\n", + " },\n", + " \"CinC Georgia\": {\n", + " 8192: \"cinc-2020\",\n", + " 4096: \"cinc-2020-4k\",\n", + " 2048: \"cinc-2020-2k\",\n", + " 1024: \"cinc-2020-1k\",\n", + " 512: \"cinc-2020-512\",\n", + " 256: \"cinc-2020-256\",\n", + " },\n", + " \"ZZU pECG\": {\n", + " 8658: \"zzu-pecg\",\n", + " 4096: \"zzu-pecg-4k\",\n", + " 2048: \"zzu-pecg-2k\",\n", + " 1024: \"zzu-pecg-1k\",\n", + " 512: \"zzu-pecg-512\",\n", + " 256: \"zzu-pecg-256\",\n", + " },\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "### HEEDB" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "if not os.path.exists(\"temp_heedb_train.csv\"):\n", + " ds_train = HeedbECGDataset(\n", + " dataset_path=str(ecg_dir / \"heedb\"),\n", + " split=\"train\",\n", + " sampling_rate=100,\n", + " )\n", + " ds_val = HeedbECGDataset(\n", + " dataset_path=str(ecg_dir / \"heedb\"),\n", + " split=\"val\",\n", + " sampling_rate=100,\n", + " )\n", + " ds_test = HeedbECGDataset(\n", + " dataset_path=str(ecg_dir / \"heedb\"),\n", + " split=\"test\",\n", + " sampling_rate=100,\n", + " )\n", + "\n", + " train_df = ds_train._df.copy()\n", + " val_df = ds_val._df.copy()\n", + " test_df = ds_test._df.copy()\n", + "\n", + " train_df[list(HEEDB_TARGETS)] = ds_train.labels\n", + " val_df[list(HEEDB_TARGETS)] = ds_val.labels\n", + " test_df[list(HEEDB_TARGETS)] = ds_test.labels\n", + "\n", + " train_df[[\"patient_id\", \"age\", \"sex\", \"source\", \"split\"] + list(HEEDB_TARGETS)].to_csv(\"temp_heedb_train.csv\", index=False)\n", + " val_df[[\"patient_id\", \"age\", \"sex\", \"source\", \"split\"] + list(HEEDB_TARGETS)].to_csv(\"temp_heedb_val.csv\", index=False)\n", + " test_df[[\"patient_id\", \"age\", \"sex\", \"source\", \"split\"] + list(HEEDB_TARGETS)].to_csv(\"temp_heedb_test.csv\", index=False)\n", + "\n", + "train_df = pd.read_csv(\"temp_heedb_train.csv\")\n", + "val_df = pd.read_csv(\"temp_heedb_val.csv\")\n", + "test_df = pd.read_csv(\"temp_heedb_test.csv\")\n", + "\n", + "df = pd.concat([train_df, val_df, test_df], ignore_index=True)\n", + "df[\"sex\"] = df[\"sex\"].str.lower()\n", + "df.columns = [\"Patient\"] + list(df.columns.str.title())[1:]\n", + "\n", + "tab1 = TableOne(\n", + " data=df,\n", + " columns=list(df.columns),\n", + " nunique=[\"Patient\"],\n", + " categorical=[c for c in df.columns if c != \"Age\" and c != \"Patient\"],\n", + " continuous=[\"Age\"],\n", + " nonnormal=[\"Age\"],\n", + " groupby=\"Split\",\n", + " order={\n", + " \"Split\": list(df[\"Split\"].drop_duplicates()),\n", + " },\n", + " binary_show={\n", + " k.title(): 1 for k in HEEDB_TARGETS\n", + " },\n", + ")\n", + "\n", + "print(tab1.to_latex())" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "### EchoNext" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "ds_key = \"EchoNext\"\n", + "make_dataset = lambda ds_dir, split: EchoNextECGDataset(dataset_path=str(ecg_dir/ds_dir), split=split, sampling_rate=100)\n", + "def get_df(ds_dir, split):\n", + " ds = make_dataset(ds_dir, split)\n", + " df = ds._df[[\"patient_key\", \"age_at_ecg\", \"sex\", \"split\"] + list(ECHONEXT_TARGETS)].copy()\n", + " df.columns = [\"Patient\", \"Age\", \"Sex\", \"Split\"] + list(ECHONEXT_TARGETS)\n", + " if split == \"train\":\n", + " suffix = ds_dir.split(\"echonext\")[-1]\n", + " if suffix == \"\":\n", + " suffix = \"Full\"\n", + " else:\n", + " suffix = suffix[1:] # remove leading dash\n", + " df[\"Split\"] = f\"Train ({suffix})\"\n", + " else:\n", + " df[\"Split\"] = split.title()\n", + " return df\n", + "\n", + "\n", + "dfs_train = []\n", + "df_val = None\n", + "df_test = None\n", + "for i, (train_size, ds_dir) in enumerate(datasets[ds_key].items()):\n", + " df_train = get_df(ds_dir, \"train\")\n", + " assert len(df_train) == train_size\n", + " dfs_train.append(df_train)\n", + " if i == 0:\n", + " df_val = get_df(ds_dir, \"val\")\n", + " df_test = get_df(ds_dir, \"test\")\n", + "\n", + "df = pd.concat(dfs_train + [df_val, df_test], ignore_index=True)\n", + "\n", + "tab1 = TableOne(\n", + " data=df,\n", + " columns=list(df.columns),\n", + " nunique=[\"Patient\"],\n", + " categorical=[c for c in df.columns if c != \"Age\" and c != \"Patient\"],\n", + " continuous=[\"Age\"],\n", + " nonnormal=[\"Age\"],\n", + " groupby=\"Split\",\n", + " order={\n", + " \"Split\": list(df[\"Split\"].drop_duplicates()),\n", + " },\n", + " binary_show={\n", + " k: 1 for k in ECHONEXT_TARGETS\n", + " },\n", + ")\n", + "\n", + "print(tab1.to_latex())" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "### MIMIC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "ds_key = \"MIMIC-IV-ECG\"\n", + "make_dataset = lambda ds_dir, split: MimicECGDataset(dataset_path=str(ecg_dir/ds_dir), split=split, sampling_rate=100)\n", + "def get_df(ds_dir, split):\n", + " ds = make_dataset(ds_dir, split)\n", + " df = ds._df[[\"subject_id\", \"age\", \"gender\", \"split\"] + MIMIC_TARGETS].copy()\n", + " df.columns = [\"Patient\", \"Age\", \"Sex\", \"Split\"] + MIMIC_TARGETS\n", + " if split == \"train\":\n", + " suffix = ds_dir.split(\"mimic-iv-ecg\")[-1]\n", + " if suffix == \"\":\n", + " suffix = \"Full\"\n", + " else:\n", + " suffix = suffix[1:] # remove leading dash\n", + " df[\"Split\"] = f\"Train ({suffix})\"\n", + " else:\n", + " df[\"Split\"] = split.title()\n", + " return df\n", + "\n", + "\n", + "dfs_train = []\n", + "df_val = None\n", + "df_test = None\n", + "for i, (train_size, ds_dir) in enumerate(datasets[ds_key].items()):\n", + " df_train = get_df(ds_dir, \"train\")\n", + " assert len(df_train) == train_size\n", + " dfs_train.append(df_train)\n", + " if i == 0:\n", + " df_val = get_df(ds_dir, \"val\")\n", + " df_test = get_df(ds_dir, \"test\")\n", + "\n", + "df = pd.concat(dfs_train + [df_val, df_test], ignore_index=True)\n", + "\n", + "tab1 = TableOne(\n", + " data=df,\n", + " columns=list(df.columns),\n", + " nunique=[\"Patient\"],\n", + " categorical=[c for c in df.columns if c != \"Age\" and c != \"Patient\"],\n", + " continuous=[\"Age\"],\n", + " nonnormal=[\"Age\"],\n", + " groupby=\"Split\",\n", + " order={\n", + " \"Split\": list(df[\"Split\"].drop_duplicates()),\n", + " },\n", + " binary_show={\n", + " k: 1 for k in MIMIC_TARGETS\n", + " },\n", + ")\n", + "\n", + "print(tab1.to_latex())" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "### ZZU" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "ds_key = \"ZZU pECG\"\n", + "make_dataset = lambda ds_dir, split: ZzuECGDataset(dataset_path=str(ecg_dir/ds_dir), split=split, sampling_rate=100)\n", + "def get_df(ds_dir, split):\n", + " ds = make_dataset(ds_dir, split)\n", + " df = ds._df[[\"Patient_ID\", \"Age\", \"Gender\", \"split\"] + list(ZZU_TARGETS)].copy()\n", + " df.columns = [\"Patient\", \"Age\", \"Sex\", \"Split\"] + list(ZZU_TARGETS)\n", + " df[\"Age\"] = df[\"Age\"].str.strip(\"d\").astype(int) # age in days for ZZU\n", + " if split == \"train\":\n", + " suffix = ds_dir.split(\"zzu-pecg\")[-1]\n", + " if suffix == \"\":\n", + " suffix = \"Full\"\n", + " else:\n", + " suffix = suffix[1:] # remove leading dash\n", + " df[\"Split\"] = f\"Train ({suffix})\"\n", + " else:\n", + " df[\"Split\"] = split.title()\n", + " return df\n", + "\n", + "\n", + "dfs_train = []\n", + "df_val = None\n", + "df_test = None\n", + "for i, (train_size, ds_dir) in enumerate(datasets[ds_key].items()):\n", + " df_train = get_df(ds_dir, \"train\")\n", + " assert len(df_train) == train_size\n", + " dfs_train.append(df_train)\n", + " if i == 0:\n", + " df_val = get_df(ds_dir, \"val\")\n", + " df_test = get_df(ds_dir, \"test\")\n", + "\n", + "df = pd.concat(dfs_train + [df_val, df_test], ignore_index=True)\n", + "\n", + "tab1 = TableOne(\n", + " data=df,\n", + " columns=list(df.columns),\n", + " nunique=[\"Patient\"],\n", + " categorical=[c for c in df.columns if c != \"Age\" and c != \"Patient\"],\n", + " continuous=[\"Age\"],\n", + " nonnormal=[\"Age\"],\n", + " groupby=\"Split\",\n", + " order={\n", + " \"Split\": list(df[\"Split\"].drop_duplicates()),\n", + " },\n", + " binary_show={\n", + " k: 1 for k in ZZU_TARGETS\n", + " },\n", + ")\n", + "\n", + "print(tab1.to_latex())" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "### PTB-XL" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "ds_key = \"PTB-XL\"\n", + "make_dataset = lambda ds_dir, split: PtbxlECGDataset(dataset_path=str(ecg_dir/ds_dir), split=split, sampling_rate=100)\n", + "def get_df(ds_dir, split):\n", + " ds = make_dataset(ds_dir, split)\n", + " ds._df[PTBXL_TARGETS] = get_ptbxl_labels(ds._df)\n", + " ds._df[\"split\"] = split\n", + " df = ds._df[[\"patient_id\", \"age\", \"sex\", \"split\"] + PTBXL_TARGETS].copy()\n", + " df.columns = [\"Patient\", \"Age\", \"Sex\", \"Split\"] + PTBXL_TARGETS\n", + " if split == \"train\":\n", + " suffix = ds_dir.split(\"ptb-xl\")[-1]\n", + " if suffix == \"\":\n", + " suffix = \"Full\"\n", + " else:\n", + " suffix = suffix[1:] # remove leading dash\n", + " df[\"Split\"] = f\"Train ({suffix})\"\n", + " else:\n", + " df[\"Split\"] = split.title()\n", + " return df\n", + "\n", + "\n", + "dfs_train = []\n", + "df_val = None\n", + "df_test = None\n", + "for i, (train_size, ds_dir) in enumerate(datasets[ds_key].items()):\n", + " df_train = get_df(ds_dir, \"train\")\n", + " assert len(df_train) == train_size\n", + " dfs_train.append(df_train)\n", + " if i == 0:\n", + " df_val = get_df(ds_dir, \"val\")\n", + " df_test = get_df(ds_dir, \"test\")\n", + "\n", + "df = pd.concat(dfs_train + [df_val, df_test], ignore_index=True)\n", + "\n", + "tab1 = TableOne(\n", + " data=df,\n", + " columns=list(df.columns),\n", + " nunique=[\"Patient\"],\n", + " categorical=[c for c in df.columns if c != \"Age\" and c != \"Patient\"],\n", + " continuous=[\"Age\"],\n", + " nonnormal=[\"Age\"],\n", + " groupby=\"Split\",\n", + " order={\n", + " \"Split\": list(df[\"Split\"].drop_duplicates()),\n", + " },\n", + " binary_show={\n", + " k: 1 for k in PTBXL_TARGETS\n", + " },\n", + ")\n", + "\n", + "print(tab1.to_latex())" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "### CinC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "ds_key = \"CinC Georgia\"\n", + "make_dataset = lambda ds_dir, split: CincECGDataset(dataset_path=str(ecg_dir/ds_dir), split=split, sampling_rate=100)\n", + "def get_df(ds_dir, split):\n", + " ds = make_dataset(ds_dir, split)\n", + " df = ds._df[[\"patient_id\", \"age\", \"sex\", \"split\"] + CINC_TARGETS].copy()\n", + " df.columns = [\"Patient\", \"Age\", \"Sex\", \"Split\"] + CINC_TARGETS\n", + " if split == \"train\":\n", + " suffix = ds_dir.split(\"cinc-2020\")[-1]\n", + " if suffix == \"\":\n", + " suffix = \"Full\"\n", + " else:\n", + " suffix = suffix[1:] # remove leading dash\n", + " df[\"Split\"] = f\"Train ({suffix})\"\n", + " else:\n", + " df[\"Split\"] = split.title()\n", + " return df\n", + "\n", + "\n", + "dfs_train = []\n", + "df_val = None\n", + "df_test = None\n", + "for i, (train_size, ds_dir) in enumerate(datasets[ds_key].items()):\n", + " df_train = get_df(ds_dir, \"train\")\n", + " assert len(df_train) == train_size\n", + " dfs_train.append(df_train)\n", + " if i == 0:\n", + " df_val = get_df(ds_dir, \"val\")\n", + " df_test = get_df(ds_dir, \"test\")\n", + "\n", + "df = pd.concat(dfs_train + [df_val, df_test], ignore_index=True)\n", + "\n", + "tab1 = TableOne(\n", + " data=df,\n", + " columns=list(df.columns),\n", + " nunique=[\"Patient\"],\n", + " categorical=[c for c in df.columns if c != \"Age\" and c != \"Patient\"],\n", + " continuous=[\"Age\"],\n", + " nonnormal=[\"Age\"],\n", + " groupby=\"Split\",\n", + " order={\n", + " \"Split\": list(df[\"Split\"].drop_duplicates()),\n", + " },\n", + " binary_show={\n", + " k: 1 for k in CINC_TARGETS\n", + " },\n", + ")\n", + "\n", + "print(tab1.to_latex())" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "### CODE-15%" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "ds_key = \"CODE-15%\"\n", + "make_dataset = lambda ds_dir, split: Code15ECGDataset(dataset_path=str(ecg_dir/ds_dir), split=split, sampling_rate=100)\n", + "def get_df(ds_dir, split):\n", + " ds = make_dataset(ds_dir, split)\n", + " df = ds._df[[\"patient_id\", \"age\", \"is_male\", \"split\"] + CODE15_TARGETS].copy()\n", + " df.columns = [\"Patient\", \"Age\", \"Sex\", \"Split\"] + CODE15_TARGETS\n", + " if split == \"train\":\n", + " suffix = ds_dir.split(\"code15\")[-1]\n", + " if suffix == \"\":\n", + " suffix = \"Full\"\n", + " else:\n", + " suffix = suffix[1:] # remove leading dash\n", + " df[\"Split\"] = f\"Train ({suffix})\"\n", + " else:\n", + " df[\"Split\"] = split.title()\n", + " return df\n", + "\n", + "\n", + "dfs_train = []\n", + "df_val = None\n", + "df_test = None\n", + "for i, (train_size, ds_dir) in enumerate(datasets[ds_key].items()):\n", + " df_train = get_df(ds_dir, \"train\")\n", + " assert len(df_train) == train_size\n", + " dfs_train.append(df_train)\n", + " if i == 0:\n", + " df_val = get_df(ds_dir, \"val\")\n", + " df_test = get_df(ds_dir, \"test\")\n", + "\n", + "df = pd.concat(dfs_train + [df_val, df_test], ignore_index=True)\n", + "\n", + "tab1 = TableOne(\n", + " data=df,\n", + " columns=list(df.columns),\n", + " nunique=[\"Patient\"],\n", + " categorical=[c for c in df.columns if c != \"Age\" and c != \"Patient\"],\n", + " continuous=[\"Age\"],\n", + " nonnormal=[\"Age\"],\n", + " groupby=\"Split\",\n", + " order={\n", + " \"Split\": list(df[\"Split\"].drop_duplicates()),\n", + " },\n", + " binary_show={\n", + " k: 1 for k in CODE15_TARGETS\n", + " },\n", + ")\n", + "\n", + "print(tab1.to_latex())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "protossl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/results/ecg-results.ipynb b/results/ecg-results.ipynb index e89dfc8..2ec018b 100644 --- a/results/ecg-results.ipynb +++ b/results/ecg-results.ipynb @@ -125,12 +125,15 @@ " ###\n", " \"ProtoSSL HEEDB (PIA)\": (\"cornflowerblue\", \"protossl-heedb-pia\"),\n", " \"ProtoSSL HEEDB (PIA) (FT)\": (\"lightsteelblue\", \"protossl-heedb-pia-ft\"),\n", + " ###\n", + " \"ProtoSSL HEEDB (83ppl)\": (\"yellow\", \"protossl-heedb-pila-83ppl\"),\n", + " \"ProtoSSL HEEDB (PIA) (83ppl)\": (\"lime\", \"protossl-heedb-pia-83ppl\"),\n", " \"ProtoSSL HEEDB (PIT)\": (\"blue\", \"protossl-heedb-pit\"),\n", - " \"ProtoSSL HEEDB (PIP)\": (\"darkblue\", \"protossl-heedb-pip\"),\n", + " \"ProtoSSL HEEDB (PIP)\": (\"red\", \"protossl-heedb-pip\"),\n", " ###\n", - " \"ECGFounder (LAP)\": (\"indianred\", \"ecgfounder-lap\"),\n", + " \"ECGFounder (LAP)\": (\"indianred\", \"ecgfounder-lap-1000-init\"),\n", " \"ECGFounder (SK-OT)\": (\"firebrick\", \"ecgfounder-clustering\"),\n", - " \"ECGFounder (Rand)\": (\"maroon\", \"ecgfounder-random\"),\n", + " \"ECGFounder (Rand)\": (\"maroon\", \"ecgfounder-random-1000-init\"),\n", "}\n", "\n", "use_new = [\n", @@ -149,12 +152,15 @@ " \"ProtoSSL HEEDB\",\n", " \"ProtoSSL HEEDB (FT)\",\n", " \"ProtoSSL HEEDB (PIA)\",\n", + " \"ProtoSSL HEEDB (PIA) (FT)\", # new\n", " \"ProtoSSL HEEDB (PIT)\",\n", " \"ProtoSSL HEEDB (PIP)\",\n", " \"ProtoSSL HEEDB (7ppl)\",\n", " \"ProtoSSL HEEDB (7ppl) (FT)\",\n", " \"ProtoSSL HEEDB (28ppl)\",\n", " \"ProtoSSL HEEDB (28ppl) (FT)\",\n", + " \"ProtoSSL HEEDB (83ppl)\",\n", + " \"ProtoSSL HEEDB (PIA) (83ppl)\",\n", "]\n", "\n", "def get_palette(exp_names):\n", @@ -183,7 +189,7 @@ " _output_dir = Path(f\"/opt/gpu_working/steven/new/protossl-outputs-seed{seed}\")\n", " else:\n", " _output_dir = output_dir\n", - " metrics_csv = _output_dir / run_dir / exp_dir / \"metrics-bootstrapped.csv\"\n", + " metrics_csv = _output_dir / run_dir / exp_dir / \"metrics-bootstrapped-v2.csv\"\n", " if not os.path.exists(metrics_csv):\n", " continue\n", " metrics = pd.read_csv(metrics_csv, index_col=\"Label\")\n", @@ -493,6 +499,25 @@ " \"ProtoSSL HEEDB (FT)\",\n", " \"ProtoSSL HEEDB (PIA)\",\n", " \"ProtoSSL HEEDB (PIA) (FT)\",\n", + " ],\n", + " seed=[42, 67, 70, 73, 99],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "plot_lift(\n", + " df=results,\n", + " dataset=\"EchoNext\",\n", + " metric=\"Multilabel (AUROC)\",\n", + " models=[\n", + " \"ProtoSSL HEEDB (83ppl)\",\n", + " \"ProtoSSL HEEDB (PIA) (83ppl)\",\n", " \"ProtoSSL HEEDB (PIT)\",\n", " \"ProtoSSL HEEDB (PIP)\",\n", " ],\n", @@ -503,7 +528,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -527,7 +552,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -563,7 +588,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -584,7 +609,7 @@ }, { "cell_type": "markdown", - "id": "15", + "id": "16", "metadata": {}, "source": [ "## Make Tables" @@ -593,7 +618,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -640,7 +665,7 @@ }, { "cell_type": "markdown", - "id": "17", + "id": "18", "metadata": {}, "source": [ "### Full Data Scale, Primary Models" @@ -649,7 +674,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -670,7 +695,7 @@ }, { "cell_type": "markdown", - "id": "19", + "id": "20", "metadata": {}, "source": [ "### All Data Scales, Primary Models" @@ -679,7 +704,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -693,7 +718,7 @@ }, { "cell_type": "markdown", - "id": "21", + "id": "22", "metadata": {}, "source": [ "### Blackbox Baselines" @@ -702,7 +727,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -716,7 +741,7 @@ }, { "cell_type": "markdown", - "id": "23", + "id": "24", "metadata": {}, "source": [ "### Ablations" @@ -724,7 +749,7 @@ }, { "cell_type": "markdown", - "id": "24", + "id": "25", "metadata": {}, "source": [ "#### Num Prototypes" @@ -733,7 +758,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -772,7 +797,7 @@ }, { "cell_type": "markdown", - "id": "26", + "id": "27", "metadata": {}, "source": [ "#### SupProto NoProj" @@ -781,7 +806,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -802,22 +827,22 @@ }, { "cell_type": "markdown", - "id": "28", + "id": "29", "metadata": {}, "source": [ - "#### LAP vs ProtoPool vs PIT vs PIP" + "#### LAP vs ProtoPool Assignment" ] }, { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "30", "metadata": {}, "outputs": [], "source": [ "print(\n", " pivoted.loc[\n", - " [\"ProtoSSL HEEDB (FT)\", \"ProtoSSL HEEDB (PIA) (FT)\", \"ProtoSSL HEEDB\", \"ProtoSSL HEEDB (PIA)\", \"ProtoSSL HEEDB (PIT)\", \"ProtoSSL HEEDB (PIP)\"],\n", + " [\"ProtoSSL HEEDB (FT)\", \"ProtoSSL HEEDB (PIA) (FT)\", \"ProtoSSL HEEDB\", \"ProtoSSL HEEDB (PIA)\"],\n", " [\"EchoNext\"]\n", " ].T.rename(columns={\n", " \"ProtoSSL HEEDB\": \"ProtoSSL HEEDB (LAP)\",\n", @@ -828,7 +853,32 @@ }, { "cell_type": "markdown", - "id": "30", + "id": "31", + "metadata": {}, + "source": [ + "#### No Assignment (PIT & PIP)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " pivoted.loc[\n", + " [\"ProtoSSL HEEDB (83ppl)\", \"ProtoSSL HEEDB (PIT)\", \"ProtoSSL HEEDB (PIP)\"],\n", + " [\"EchoNext\"]\n", + " ].T.rename(columns={\n", + " \"ProtoSSL HEEDB (83ppl)\": \"ProtoSSL HEEDB (LAP) (83PPL)\",\n", + " }).to_latex()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "33", "metadata": {}, "source": [ "### Prototypes from FM" @@ -837,7 +887,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "34", "metadata": {}, "outputs": [], "source": [ @@ -852,7 +902,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32", + "id": "35", "metadata": {}, "outputs": [], "source": [] diff --git a/scripts/3-z-run-protossl-heedb-pila-83ppl.sh b/scripts/3-z-run-protossl-heedb-pila-83ppl.sh new file mode 100755 index 0000000..876ca9a --- /dev/null +++ b/scripts/3-z-run-protossl-heedb-pila-83ppl.sh @@ -0,0 +1,124 @@ +#!/bin/bash + +set -e + +# set these env vars prior to executing this script +: "${DATASET_PATH:?Env var DATASET_PATH must be set prior to script execution}" +: "${RUN_DIR:?Env var RUN_DIR must be set prior to script execution}" +: "${REPO_ROOT:?Env var REPO_ROOT must be set prior to script execution}" +: "${SEED:=42}" +echo "Using SEED=$SEED" +echo "Using DATASET_PATH=$DATASET_PATH" +echo "Using RUN_DIR=$RUN_DIR" +echo "Using REPO_ROOT=$REPO_ROOT" +cd $REPO_ROOT/scripts + +# experiment parameters +EXP_NAME="protossl-heedb-pila-83ppl" +PRETRAIN_RUN="$RUN_DIR/../pass-pretrain-heedb-no-attn" + +if [ ! -e "$RUN_DIR/$EXP_NAME/metrics-bootstrapped.csv" ]; then + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --model.n_prototypes_per_label 83 \ + --pipeline-stage learn-prototype-assignments \ + --assignment-strategy ilp_effect_size \ + --model.n_prototypes 1000 \ + --trainer.logger.save_dir $RUN_DIR \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $PRETRAIN_RUN/learn-prototypes/latest/best.ckpt + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --model.n_prototypes_per_label 83 \ + --pipeline-stage project-prototypes-supervised \ + --trainer.logger.save_dir $RUN_DIR \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --model.n_prototypes_per_label 83 \ + --pipeline-stage compute-embeddings \ + --trainer.logger.save_dir $RUN_DIR \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + + python _linear_probe.py \ + --random-seed $SEED \ + --dataset-path $DATASET_PATH \ + --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ + --output-path $RUN_DIR/$EXP_NAME + + python _eval_probs.py \ + --dataset-path $DATASET_PATH \ + --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ + --output-path $RUN_DIR/$EXP_NAME + + python _eval_probs_bootstrapped.py \ + --dataset-path $DATASET_PATH \ + --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ + --output-path $RUN_DIR/$EXP_NAME + +fi + +# now fine-tune +# PRETRAIN_RUN=$RUN_DIR/$EXP_NAME +# EXP_NAME="$EXP_NAME-ft" + +# if [ ! -e "$RUN_DIR/$EXP_NAME/metrics-bootstrapped.csv" ]; then + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ +# --seed_everything $SEED \ +# --model.n_prototypes_per_label 83 \ +# --pipeline-stage learn-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ +# --seed_everything $SEED \ +# --model.n_prototypes_per_label 83 \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ +# --seed_everything $SEED \ +# --model.n_prototypes_per_label 83 \ +# --pipeline-stage compute-embeddings \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + +# python _linear_probe.py \ +# --random-seed $SEED \ +# --dataset-path $DATASET_PATH \ +# --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ +# --output-path $RUN_DIR/$EXP_NAME + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME + +# python _eval_probs_bootstrapped.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME + +# fi diff --git a/scripts/4-run-labsup-proto-heedb-rila.sh b/scripts/4-run-labsup-proto-heedb-rila.sh index acd3b8e..3c49364 100755 --- a/scripts/4-run-labsup-proto-heedb-rila.sh +++ b/scripts/4-run-labsup-proto-heedb-rila.sh @@ -17,94 +17,102 @@ cd $REPO_ROOT/scripts EXP_NAME="labsup-proto-heedb-rila" PRETRAIN_RUN="$RUN_DIR/../prosup-pretrain-heedb" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototype-assignments \ - --assignment-strategy ilp_effect_size \ - --model.n_prototypes 1000 \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $PRETRAIN_RUN/project-prototypes-supervised/latest/proj.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ - --seed_everything $SEED \ - --pipeline-stage compute-embeddings \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt - -python _linear_probe.py \ ---random-seed $SEED \ ---dataset-path $DATASET_PATH \ ---prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ ---output-path $RUN_DIR/$EXP_NAME - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME - -python _eval_probs_bootstrapped.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +if [ ! -e "$RUN_DIR/$EXP_NAME/metrics-bootstrapped.csv" ]; then + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --pipeline-stage learn-prototype-assignments \ + --assignment-strategy ilp_effect_size \ + --model.n_prototypes 1000 \ + --trainer.logger.save_dir $RUN_DIR \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $PRETRAIN_RUN/project-prototypes-supervised/latest/proj.ckpt + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --pipeline-stage project-prototypes-supervised \ + --trainer.logger.save_dir $RUN_DIR \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --pipeline-stage compute-embeddings \ + --trainer.logger.save_dir $RUN_DIR \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + + python _linear_probe.py \ + --random-seed $SEED \ + --dataset-path $DATASET_PATH \ + --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ + --output-path $RUN_DIR/$EXP_NAME + + python _eval_probs.py \ + --dataset-path $DATASET_PATH \ + --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ + --output-path $RUN_DIR/$EXP_NAME + + python _eval_probs_bootstrapped.py \ + --dataset-path $DATASET_PATH \ + --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ + --output-path $RUN_DIR/$EXP_NAME + +fi # now fine-tune PRETRAIN_RUN=$RUN_DIR/$EXP_NAME EXP_NAME="$EXP_NAME-ft" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ - --seed_everything $SEED \ - --pipeline-stage compute-embeddings \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt - -python _linear_probe.py \ ---random-seed $SEED \ ---dataset-path $DATASET_PATH \ ---prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ ---output-path $RUN_DIR/$EXP_NAME - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME - -python _eval_probs_bootstrapped.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +if [ ! -e "$RUN_DIR/$EXP_NAME/metrics-bootstrapped.csv" ]; then + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --pipeline-stage learn-prototypes-supervised \ + --trainer.logger.save_dir $RUN_DIR/ \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --pipeline-stage project-prototypes-supervised \ + --trainer.logger.save_dir $RUN_DIR/ \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --pipeline-stage compute-embeddings \ + --trainer.logger.save_dir $RUN_DIR \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + + python _linear_probe.py \ + --random-seed $SEED \ + --dataset-path $DATASET_PATH \ + --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ + --output-path $RUN_DIR/$EXP_NAME + + python _eval_probs.py \ + --dataset-path $DATASET_PATH \ + --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ + --output-path $RUN_DIR/$EXP_NAME + + python _eval_probs_bootstrapped.py \ + --dataset-path $DATASET_PATH \ + --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ + --output-path $RUN_DIR/$EXP_NAME + +fi diff --git a/scripts/6-z-run-protossl-heedb-pia-83ppl.sh b/scripts/6-z-run-protossl-heedb-pia-83ppl.sh new file mode 100755 index 0000000..a231c00 --- /dev/null +++ b/scripts/6-z-run-protossl-heedb-pia-83ppl.sh @@ -0,0 +1,124 @@ +#!/bin/bash + +set -e + +# set these env vars prior to executing this script +: "${DATASET_PATH:?Env var DATASET_PATH must be set prior to script execution}" +: "${RUN_DIR:?Env var RUN_DIR must be set prior to script execution}" +: "${REPO_ROOT:?Env var REPO_ROOT must be set prior to script execution}" +: "${SEED:=42}" +echo "Using SEED=$SEED" +echo "Using DATASET_PATH=$DATASET_PATH" +echo "Using RUN_DIR=$RUN_DIR" +echo "Using REPO_ROOT=$REPO_ROOT" +cd $REPO_ROOT/scripts + +# experiment parameters +EXP_NAME="protossl-heedb-pia-83ppl" +PRETRAIN_RUN="$RUN_DIR/../pass-pretrain-heedb-no-attn" + +if [ ! -e "$RUN_DIR/$EXP_NAME/metrics-bootstrapped.csv" ]; then + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --model.n_prototypes_per_label 83 \ + --pipeline-stage learn-prototype-assignments \ + --assignment-strategy protopool \ + --model.n_prototypes 1000 \ + --trainer.logger.save_dir $RUN_DIR \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $PRETRAIN_RUN/learn-prototypes/latest/best.ckpt + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --model.n_prototypes_per_label 83 \ + --pipeline-stage project-prototypes-supervised \ + --trainer.logger.save_dir $RUN_DIR \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt + + python -m protossl.trainer \ + --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ + --seed_everything $SEED \ + --model.n_prototypes_per_label 83 \ + --pipeline-stage compute-embeddings \ + --trainer.logger.save_dir $RUN_DIR \ + --trainer.logger.name $EXP_NAME \ + --data.dataset_path $DATASET_PATH \ + --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + + python _linear_probe.py \ + --random-seed $SEED \ + --dataset-path $DATASET_PATH \ + --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ + --output-path $RUN_DIR/$EXP_NAME + + python _eval_probs.py \ + --dataset-path $DATASET_PATH \ + --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ + --output-path $RUN_DIR/$EXP_NAME + + python _eval_probs_bootstrapped.py \ + --dataset-path $DATASET_PATH \ + --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ + --output-path $RUN_DIR/$EXP_NAME + +fi + +# now fine-tune +# PRETRAIN_RUN=$RUN_DIR/$EXP_NAME +# EXP_NAME="$EXP_NAME-ft" + +# if [ ! -e "$RUN_DIR/$EXP_NAME/metrics-bootstrapped.csv" ]; then + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ +# --seed_everything $SEED \ +# --model.n_prototypes_per_label 83 \ +# --pipeline-stage learn-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ +# --seed_everything $SEED \ +# --model.n_prototypes_per_label 83 \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/target-guided-14ppl.yaml \ +# --seed_everything $SEED \ +# --model.n_prototypes_per_label 83 \ +# --pipeline-stage compute-embeddings \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + +# python _linear_probe.py \ +# --random-seed $SEED \ +# --dataset-path $DATASET_PATH \ +# --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ +# --output-path $RUN_DIR/$EXP_NAME + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME + +# python _eval_probs_bootstrapped.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME + +# fi diff --git a/scripts/9-1-run-ecgfounder-lap.sh b/scripts/9-1-run-ecgfounder-lap.sh index 1c990b9..270944e 100644 --- a/scripts/9-1-run-ecgfounder-lap.sh +++ b/scripts/9-1-run-ecgfounder-lap.sh @@ -17,13 +17,14 @@ echo "Using REPO_ROOT=$REPO_ROOT" cd $REPO_ROOT/scripts # experiment parameters -EXP_NAME="ecgfounder-lap" +EXP_NAME="ecgfounder-lap-1000-init" -python prototypes-from-fms/assign_random_prototypes.py \ +python prototypes-from-fms/LAP_assign_random_prototypes.py \ --random-seed $SEED \ --dataset-path $DATASET_PATH \ --patch-embeddings $RUN_DIR/ecgfounder-patches \ --prototypes-per-label $PPL \ +--n-init-protos 1000 \ --output-path $RUN_DIR/$EXP_NAME python _linear_probe.py \ diff --git a/scripts/9-2-run-ecgfounder-clustering.sh b/scripts/9-2-run-ecgfounder-clustering.sh index f3b1c83..6abe395 100644 --- a/scripts/9-2-run-ecgfounder-clustering.sh +++ b/scripts/9-2-run-ecgfounder-clustering.sh @@ -19,7 +19,7 @@ cd $REPO_ROOT/scripts # experiment parameters EXP_NAME="ecgfounder-clustering" -python prototypes-from-fms/prototype_clustering.py \ +python prototypes-from-fms/SK_OT_prototype_clustering.py \ --random-seed $SEED \ --dataset-path $DATASET_PATH \ --patch-embeddings $RUN_DIR/ecgfounder-patches \ diff --git a/scripts/9-3-run-ecgfounder-random.sh b/scripts/9-3-run-ecgfounder-random.sh index 74c4f25..078609f 100644 --- a/scripts/9-3-run-ecgfounder-random.sh +++ b/scripts/9-3-run-ecgfounder-random.sh @@ -17,13 +17,14 @@ echo "Using REPO_ROOT=$REPO_ROOT" cd $REPO_ROOT/scripts # experiment parameters -EXP_NAME="ecgfounder-random" +EXP_NAME="ecgfounder-random-1000-init" -python prototypes-from-fms/only_project_random_prototypes.py \ +python prototypes-from-fms/Rand_assign_random_prototypes.py \ --random-seed $SEED \ --dataset-path $DATASET_PATH \ --patch-embeddings $RUN_DIR/ecgfounder-patches \ --prototypes-per-label $PPL \ +--n-init-protos 1000 \ --output-path $RUN_DIR/$EXP_NAME python _linear_probe.py \ diff --git a/scripts/_eval_probs_bootstrapped.py b/scripts/_eval_probs_bootstrapped.py index 2f2d2e3..b049932 100644 --- a/scripts/_eval_probs_bootstrapped.py +++ b/scripts/_eval_probs_bootstrapped.py @@ -19,7 +19,6 @@ def parse_args(): parser.add_argument("--output-path", required=True) parser.add_argument("--label-subset", nargs="+") parser.add_argument("--n-bootstraps", type=int, default=1000) - parser.add_argument("--bootstrap-frac", type=float, default=0.5) parser.add_argument("--n-jobs", type=int, default=24) args = parser.parse_args() return args @@ -29,9 +28,8 @@ def worker( y_test: np.ndarray, y_prob: np.ndarray, n_bootstraps: int, - bootstrap_frac: float, ) -> dict: - bootstrap_n = int(len(y_test) * bootstrap_frac) + bootstrap_n = len(y_test) label_pos_frac = y_test.sum() / len(y_test) label_metrics = dict() pos_idxs = np.argwhere(y_test).squeeze(1) @@ -79,7 +77,6 @@ def main( output_path: str, label_subset: list[str] | None = None, n_bootstraps: int = 1000, - bootstrap_frac: float = 0.5, n_jobs: int = 24, ): ds_cls, src_label_names, is_audio = infer_dataset_class_from_path(dataset_path) @@ -109,7 +106,6 @@ def main( "y_test": test_targets[:, src_label_names.index(target_col)], "y_prob": target_probs[:, src_label_names.index(target_col)], "n_bootstraps": n_bootstraps, - "bootstrap_frac": bootstrap_frac, } for target_col in label_names ] @@ -123,7 +119,7 @@ def main( _labels = [x for x in label_names if x != composite_target] metrics.loc["Multilabel Averaged"] = metrics.loc[_labels].mean() - metrics_path = os.path.join(output_path, "metrics-bootstrapped.csv") + metrics_path = os.path.join(output_path, "metrics-bootstrapped-v2.csv") metrics.to_csv(metrics_path) print(f"Saved metrics to {metrics_path}") @@ -136,6 +132,5 @@ def main( output_path=args.output_path, label_subset=args.label_subset, n_bootstraps=args.n_bootstraps, - bootstrap_frac=args.bootstrap_frac, n_jobs=args.n_jobs, ) diff --git a/scripts/_slurm_wrapper.sh b/scripts/_slurm_wrapper.sh index feef3c5..22d9395 100755 --- a/scripts/_slurm_wrapper.sh +++ b/scripts/_slurm_wrapper.sh @@ -15,4 +15,8 @@ python3() { srun -u python3 "$@"; } export -f python export -f python3 -bash $1 +case "$1" in + *.sh) bash "$1" "${@:2}" ;; + *.py) python3 "$1" "${@:2}" ;; + *) echo "Unsupported extension" >&2; exit 1 ;; +esac diff --git a/scripts/audio/1-run-blackbox-direct.sh b/scripts/audio/1-run-blackbox-direct.sh index 0d1f04b..00f5d95 100755 --- a/scripts/audio/1-run-blackbox-direct.sh +++ b/scripts/audio/1-run-blackbox-direct.sh @@ -27,19 +27,19 @@ cd $REPO_ROOT/scripts/audio # experiment parameters EXP_NAME="blackbox-direct" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-blackbox.yaml \ - --seed_everything $SEED \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH - -cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-blackbox.yaml \ +# --seed_everything $SEED \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH + +# cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ diff --git a/scripts/audio/2-run-labsup-proto-direct.sh b/scripts/audio/2-run-labsup-proto-direct.sh index 3d22cb8..bc6420b 100755 --- a/scripts/audio/2-run-labsup-proto-direct.sh +++ b/scripts/audio/2-run-labsup-proto-direct.sh @@ -29,43 +29,43 @@ cd $REPO_ROOT/scripts/audio # experiment parameters EXP_NAME="labsup-proto-direct" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage learn-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage compute-embeddings \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage compute-embeddings \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt -python _linear_probe.py \ ---random-seed $SEED \ ---dataset-path $DATASET_PATH \ ---prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ ---output-path $RUN_DIR/$EXP_NAME +# python _linear_probe.py \ +# --random-seed $SEED \ +# --dataset-path $DATASET_PATH \ +# --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ +# --output-path $RUN_DIR/$EXP_NAME -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ diff --git a/scripts/audio/3-run-protossl-audioset-pila.sh b/scripts/audio/3-run-protossl-audioset-pila.sh index 67a3765..eb0e8c8 100755 --- a/scripts/audio/3-run-protossl-audioset-pila.sh +++ b/scripts/audio/3-run-protossl-audioset-pila.sh @@ -30,46 +30,46 @@ cd $REPO_ROOT/scripts/audio EXP_NAME="protossl-audioset-pila" PRETRAIN_RUN="$RUN_DIR/../pass-audioset" -python -m protossl.trainer \ - --seed_everything $SEED \ - --pipeline-stage learn-prototype-assignments \ - --assignment-strategy ilp_effect_size \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --model.n_prototypes 2635 \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $PRETRAIN_RUN/learn-prototypes/latest/best.ckpt \ - --model.model_kwargs '{"label_type": "multiclass"}' - -python -m protossl.trainer \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt - -python -m protossl.trainer \ - --seed_everything $SEED \ - --pipeline-stage compute-embeddings \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt - -python _linear_probe.py \ ---random-seed $SEED \ ---dataset-path $DATASET_PATH \ ---prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ ---output-path $RUN_DIR/$EXP_NAME - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python -m protossl.trainer \ +# --seed_everything $SEED \ +# --pipeline-stage learn-prototype-assignments \ +# --assignment-strategy ilp_effect_size \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --model.n_prototypes 2635 \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $PRETRAIN_RUN/learn-prototypes/latest/best.ckpt \ +# --model.model_kwargs '{"label_type": "multiclass"}' + +# python -m protossl.trainer \ +# --seed_everything $SEED \ +# --pipeline-stage project-prototypes-supervised \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt + +# python -m protossl.trainer \ +# --seed_everything $SEED \ +# --pipeline-stage compute-embeddings \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + +# python _linear_probe.py \ +# --random-seed $SEED \ +# --dataset-path $DATASET_PATH \ +# --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ +# --output-path $RUN_DIR/$EXP_NAME + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ @@ -80,44 +80,44 @@ python _eval_probs_bootstrapped.py \ PRETRAIN_RUN=$RUN_DIR/$EXP_NAME EXP_NAME="$EXP_NAME-ft" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt \ - --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage compute-embeddings \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt - -python _linear_probe.py \ ---random-seed $SEED \ ---dataset-path $DATASET_PATH \ ---prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ ---output-path $RUN_DIR/$EXP_NAME - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage learn-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt \ +# --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage compute-embeddings \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + +# python _linear_probe.py \ +# --random-seed $SEED \ +# --dataset-path $DATASET_PATH \ +# --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ +# --output-path $RUN_DIR/$EXP_NAME + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ diff --git a/scripts/audio/4-run-labsup-proto-audioset-rila.sh b/scripts/audio/4-run-labsup-proto-audioset-rila.sh index 9336640..a696997 100755 --- a/scripts/audio/4-run-labsup-proto-audioset-rila.sh +++ b/scripts/audio/4-run-labsup-proto-audioset-rila.sh @@ -30,46 +30,46 @@ cd $REPO_ROOT/scripts/audio EXP_NAME="labsup-proto-audioset-rila" PRETRAIN_RUN="$RUN_DIR/../prosup-audioset" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototype-assignments \ - --assignment-strategy ilp_effect_size \ - --model.n_prototypes 2635 \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $PRETRAIN_RUN/project-prototypes-supervised/latest/proj.ckpt \ - --model.model_kwargs '{"label_type": "multiclass"}' - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage compute-embeddings \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt - -python _linear_probe.py \ ---random-seed $SEED \ ---dataset-path $DATASET_PATH \ ---prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ ---output-path $RUN_DIR/$EXP_NAME - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage learn-prototype-assignments \ +# --assignment-strategy ilp_effect_size \ +# --model.n_prototypes 2635 \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $PRETRAIN_RUN/project-prototypes-supervised/latest/proj.ckpt \ +# --model.model_kwargs '{"label_type": "multiclass"}' + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage compute-embeddings \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + +# python _linear_probe.py \ +# --random-seed $SEED \ +# --dataset-path $DATASET_PATH \ +# --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ +# --output-path $RUN_DIR/$EXP_NAME + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ @@ -80,44 +80,44 @@ python _eval_probs_bootstrapped.py \ PRETRAIN_RUN=$RUN_DIR/$EXP_NAME EXP_NAME="$EXP_NAME-ft" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt \ - --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage compute-embeddings \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt - -python _linear_probe.py \ ---random-seed $SEED \ ---dataset-path $DATASET_PATH \ ---prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ ---output-path $RUN_DIR/$EXP_NAME - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage learn-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt \ +# --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage compute-embeddings \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + +# python _linear_probe.py \ +# --random-seed $SEED \ +# --dataset-path $DATASET_PATH \ +# --prototype-embeddings $RUN_DIR/$EXP_NAME/compute-embeddings/latest \ +# --output-path $RUN_DIR/$EXP_NAME + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ diff --git a/scripts/audio/_eval_probs_bootstrapped.py b/scripts/audio/_eval_probs_bootstrapped.py index 3331f52..2dce406 100644 --- a/scripts/audio/_eval_probs_bootstrapped.py +++ b/scripts/audio/_eval_probs_bootstrapped.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd import scipy.stats as st -from sklearn.model_selection import train_test_split +from sklearn.utils import resample from protossl.datasets import infer_dataset_class_from_path @@ -16,7 +16,6 @@ def parse_args(): parser.add_argument("--probs-npy", required=True) parser.add_argument("--output-path", required=True) parser.add_argument("--n-bootstraps", type=int, default=1000) - parser.add_argument("--bootstrap-frac", type=float, default=0.5) args = parser.parse_args() return args @@ -27,7 +26,6 @@ def main( probs_npy: str, output_path: str, n_bootstraps: int = 1000, - bootstrap_frac: float = 0.5, ): ds_cls, label_names, is_audio = infer_dataset_class_from_path(dataset_path) if not is_audio: @@ -48,21 +46,24 @@ def main( assert (y >= 0).all() and (y <= 1).all(), "Values must be 0 or 1" os.makedirs(output_path, exist_ok=True) - assert not os.path.exists(os.path.join(output_path, "metrics-bootstrapped.csv")) + assert not os.path.exists(os.path.join(output_path, "metrics-bootstrapped-v2.csv")) y_true = y.argmax(axis=-1) # convert to multiclass y_pred = y_prob.argmax(axis=-1) - bootstrap_n = int(len(y_true) * bootstrap_frac) + bootstrap_n = len(y_true) bootstrapped_metrics = defaultdict(lambda: defaultdict(list)) for b in range(n_bootstraps): - bootstrap_y_true, _, bootstrap_y_pred, _ = train_test_split( + temp = resample( y_true, y_pred, - train_size=bootstrap_n, + replace=True, + n_samples=bootstrap_n, random_state=b, stratify=y_true, ) + assert temp is not None + bootstrap_y_true, bootstrap_y_pred = temp acc = (bootstrap_y_true == bootstrap_y_pred).sum() / len(bootstrap_y_true) bootstrapped_metrics["Multiclass"]["Accuracy"].append(acc) @@ -79,7 +80,7 @@ def main( metrics = pd.DataFrame.from_dict(metrics, orient="index") metrics.index.name = "Label" - metrics_path = os.path.join(output_path, "metrics-bootstrapped.csv") + metrics_path = os.path.join(output_path, "metrics-bootstrapped-v2.csv") metrics.to_csv(metrics_path) print(f"Saved metrics to {metrics_path}") @@ -91,5 +92,4 @@ def main( probs_npy=args.probs_npy, output_path=args.output_path, n_bootstraps=args.n_bootstraps, - bootstrap_frac=args.bootstrap_frac, ) diff --git a/scripts/audio/voxceleb-2-run-labsup-proto-direct.sh b/scripts/audio/voxceleb-2-run-labsup-proto-direct.sh index aaa1397..fc716a0 100755 --- a/scripts/audio/voxceleb-2-run-labsup-proto-direct.sh +++ b/scripts/audio/voxceleb-2-run-labsup-proto-direct.sh @@ -33,40 +33,40 @@ cd $REPO_ROOT/scripts/audio # experiment parameters EXP_NAME="labsup-proto-direct" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage learn-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage train-classifier \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.model_kwargs '{"label_type": "multiclass"}' \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage train-classifier \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.model_kwargs '{"label_type": "multiclass"}' \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt -cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy +# cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ diff --git a/scripts/audio/voxceleb-3-run-protossl-audioset-pila.sh b/scripts/audio/voxceleb-3-run-protossl-audioset-pila.sh index 866e750..7158ae9 100755 --- a/scripts/audio/voxceleb-3-run-protossl-audioset-pila.sh +++ b/scripts/audio/voxceleb-3-run-protossl-audioset-pila.sh @@ -34,43 +34,43 @@ cd $REPO_ROOT/scripts/audio EXP_NAME="protossl-audioset-pila" PRETRAIN_RUN="$RUN_DIR/../pass-audioset" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototype-assignments \ - --assignment-strategy ilp_effect_size \ - --model.n_prototypes 2635 \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $PRETRAIN_RUN/learn-prototypes/latest/best.ckpt \ - --model.model_kwargs '{"label_type": "multiclass"}' - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage train-classifier \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.model_kwargs '{"label_type": "multiclass"}' \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt - -cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage learn-prototype-assignments \ +# --assignment-strategy ilp_effect_size \ +# --model.n_prototypes 2635 \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $PRETRAIN_RUN/learn-prototypes/latest/best.ckpt \ +# --model.model_kwargs '{"label_type": "multiclass"}' + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage train-classifier \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.model_kwargs '{"label_type": "multiclass"}' \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + +# cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ @@ -81,41 +81,41 @@ python _eval_probs_bootstrapped.py \ PRETRAIN_RUN=$RUN_DIR/$EXP_NAME EXP_NAME="$EXP_NAME-ft" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt \ - --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage train-classifier \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.model_kwargs '{"label_type": "multiclass"}' \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt - -cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage learn-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt \ +# --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage train-classifier \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.model_kwargs '{"label_type": "multiclass"}' \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + +# cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ diff --git a/scripts/audio/voxceleb-4-run-labsup-proto-audioset-rila.sh b/scripts/audio/voxceleb-4-run-labsup-proto-audioset-rila.sh index 0f71011..5fd8ab9 100755 --- a/scripts/audio/voxceleb-4-run-labsup-proto-audioset-rila.sh +++ b/scripts/audio/voxceleb-4-run-labsup-proto-audioset-rila.sh @@ -34,43 +34,43 @@ cd $REPO_ROOT/scripts/audio EXP_NAME="labsup-proto-audioset-rila" PRETRAIN_RUN="$RUN_DIR/../prosup-audioset" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototype-assignments \ - --assignment-strategy ilp_effect_size \ - --model.n_prototypes 2635 \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $PRETRAIN_RUN/project-prototypes-supervised/latest/proj.ckpt \ - --model.model_kwargs '{"label_type": "multiclass"}' - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage train-classifier \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.model_kwargs '{"label_type": "multiclass"}' \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt - -cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage learn-prototype-assignments \ +# --assignment-strategy ilp_effect_size \ +# --model.n_prototypes 2635 \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $PRETRAIN_RUN/project-prototypes-supervised/latest/proj.ckpt \ +# --model.model_kwargs '{"label_type": "multiclass"}' + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototype-assignments/latest/assigned.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage train-classifier \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.model_kwargs '{"label_type": "multiclass"}' \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + +# cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ @@ -81,41 +81,41 @@ python _eval_probs_bootstrapped.py \ PRETRAIN_RUN=$RUN_DIR/$EXP_NAME EXP_NAME="$EXP_NAME-ft" -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage learn-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt \ - --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage project-prototypes-supervised \ - --trainer.logger.save_dir $RUN_DIR/ \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt - -python -m protossl.trainer \ - --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ - --seed_everything $SEED \ - --pipeline-stage train-classifier \ - --trainer.logger.save_dir $RUN_DIR \ - --trainer.logger.name $EXP_NAME \ - --data.dataset_path $DATASET_PATH \ - --model.model_kwargs '{"label_type": "multiclass"}' \ - --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt - -cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy - -python _eval_probs.py \ ---dataset-path $DATASET_PATH \ ---probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ ---output-path $RUN_DIR/$EXP_NAME +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage learn-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $PRETRAIN_RUN/learn-prototype-assignments/latest/assigned.ckpt \ +# --model.model_kwargs '{"label_type": "multiclass", "use_default_weights": True}' + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage project-prototypes-supervised \ +# --trainer.logger.save_dir $RUN_DIR/ \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/learn-prototypes-supervised/latest/best.ckpt + +# python -m protossl.trainer \ +# --config $REPO_ROOT/configs/audio/target-guided-$PPL.yaml \ +# --seed_everything $SEED \ +# --pipeline-stage train-classifier \ +# --trainer.logger.save_dir $RUN_DIR \ +# --trainer.logger.name $EXP_NAME \ +# --data.dataset_path $DATASET_PATH \ +# --model.model_kwargs '{"label_type": "multiclass"}' \ +# --model.pretrained_weights $RUN_DIR/$EXP_NAME/project-prototypes-supervised/latest/proj.ckpt + +# cp $RUN_DIR/$EXP_NAME/train-classifier/latest/probs.npy $RUN_DIR/$EXP_NAME/probs.npy + +# python _eval_probs.py \ +# --dataset-path $DATASET_PATH \ +# --probs-npy $RUN_DIR/$EXP_NAME/probs.npy \ +# --output-path $RUN_DIR/$EXP_NAME python _eval_probs_bootstrapped.py \ --dataset-path $DATASET_PATH \ diff --git a/scripts/prototypes-from-fms/assign_random_prototypes.py b/scripts/prototypes-from-fms/LAP_assign_random_prototypes.py similarity index 94% rename from scripts/prototypes-from-fms/assign_random_prototypes.py rename to scripts/prototypes-from-fms/LAP_assign_random_prototypes.py index 61aa46f..c56ec3e 100644 --- a/scripts/prototypes-from-fms/assign_random_prototypes.py +++ b/scripts/prototypes-from-fms/LAP_assign_random_prototypes.py @@ -21,6 +21,7 @@ def parse_args(): parser.add_argument("--prototypes-per-label", type=int, required=True) parser.add_argument("--output-path", required=True) parser.add_argument("--random-seed", type=int, default=42) + parser.add_argument("--n-init-protos", type=int) return parser.parse_args() @@ -31,6 +32,7 @@ def main( prototypes_per_label: int, output_path: str, random_seed: int = 42, + n_init_protos: int | None = None, ): ds_cls, _, is_audio = infer_dataset_class_from_path(dataset_path) assert not is_audio, "ECG only script" @@ -57,10 +59,13 @@ def main( N, L, H = X_train.shape _N, C = y_train.shape assert N == _N + if n_init_protos is None: + n_init_protos = C * K + assert n_init_protos >= C * K rng = torch.manual_seed(random_seed) protos = nn.init.trunc_normal_( # (P, H) - torch.empty(C * K, H), std=0.02, generator=rng + torch.empty(n_init_protos, H), std=0.02, generator=rng ) protos_norm = F.normalize(protos, dim=-1) @@ -135,4 +140,5 @@ def main( prototypes_per_label=args.prototypes_per_label, output_path=args.output_path, random_seed=args.random_seed, + n_init_protos=args.n_init_protos, ) diff --git a/scripts/prototypes-from-fms/only_project_random_prototypes.py b/scripts/prototypes-from-fms/Rand_assign_random_prototypes.py similarity index 88% rename from scripts/prototypes-from-fms/only_project_random_prototypes.py rename to scripts/prototypes-from-fms/Rand_assign_random_prototypes.py index a031edd..735d49a 100644 --- a/scripts/prototypes-from-fms/only_project_random_prototypes.py +++ b/scripts/prototypes-from-fms/Rand_assign_random_prototypes.py @@ -20,6 +20,7 @@ def parse_args(): parser.add_argument("--prototypes-per-label", type=int, required=True) parser.add_argument("--output-path", required=True) parser.add_argument("--random-seed", type=int, default=42) + parser.add_argument("--n-init-protos", type=int) return parser.parse_args() @@ -30,6 +31,7 @@ def main( prototypes_per_label: int, output_path: str, random_seed: int = 42, + n_init_protos: int | None = None, ): ds_cls, _, is_audio = infer_dataset_class_from_path(dataset_path) assert not is_audio, "ECG only script" @@ -56,12 +58,19 @@ def main( N, L, H = X_train.shape _N, C = y_train.shape assert N == _N + if n_init_protos is None: + n_init_protos = C * K + assert n_init_protos >= C * K rng = torch.manual_seed(random_seed) - protos = nn.init.trunc_normal_( # (C, K, H) - torch.empty(C, K, H), std=0.02, generator=rng + _protos = nn.init.trunc_normal_( # (P, H) + torch.empty(n_init_protos, H), std=0.02, generator=rng ) + # randomly "assign" C*K prototypes + indices = torch.randperm(n_init_protos, generator=rng)[: C * K] + protos = _protos[indices].reshape(C, K, H) # (C, K, H) + # project prototypes to be real samples protos_norm = F.normalize(protos, dim=-1) # (C, K, H) X_norm = F.normalize(X_train, dim=-1) # (N, L, H) diff --git a/scripts/prototypes-from-fms/prototype_clustering.py b/scripts/prototypes-from-fms/SK_OT_prototype_clustering.py similarity index 100% rename from scripts/prototypes-from-fms/prototype_clustering.py rename to scripts/prototypes-from-fms/SK_OT_prototype_clustering.py diff --git a/scripts/queue-experiments.sh b/scripts/queue-experiments.sh index 4169637..a08c60f 100755 --- a/scripts/queue-experiments.sh +++ b/scripts/queue-experiments.sh @@ -2,8 +2,7 @@ set -e -# declare -a DATASETS=("echonext" "ptbxl" "cinc" "mimic" "zzu" "code15") -declare -a DATASETS=("echonext") +declare -a DATASETS=("echonext" "ptbxl" "cinc" "mimic" "zzu" "code15") declare -a SEEDS=(42 67 70 73 99) declare -A DATASET_DIRS=( ["echonext"]="echonext" @@ -46,37 +45,42 @@ for seed in "${SEEDS[@]}"; do cache_id=$(submit_job "$suffix" 0-run-cache-data.sh) echo $cache_id - # submit_job "$suffix" 1-run-blackbox-direct.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 2-run-labsup-proto-direct.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 3-run-protossl-heedb-pila.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 4-run-labsup-proto-heedb-rila.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 1-run-blackbox-direct.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 2-run-labsup-proto-direct.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 3-run-protossl-heedb-pila.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 4-run-labsup-proto-heedb-rila.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 5-1-run-ecgfounder-logreg.sh # does not depend on same 100 Hz cache (takes 500 Hz) - # submit_job "$suffix" 5-2-run-stmem-logreg.sh # does not depend on same 100 Hz cache (takes 250 Hz) + submit_job "$suffix" 5-1-run-ecgfounder-logreg.sh # does not depend on same 100 Hz cache (takes 500 Hz) + submit_job "$suffix" 5-2-run-stmem-logreg.sh # does not depend on same 100 Hz cache (takes 250 Hz) if [ "$dataset" == "echonext" ]; then echo "Ablations" # prototypes per label - # submit_job "$suffix" 2-z-run-labsup-proto-direct-7ppl.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 2-z-run-labsup-proto-direct-28ppl.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 3-z-run-protossl-heedb-pila-7ppl.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 3-z-run-protossl-heedb-pila-28ppl.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 4-z-run-labsup-proto-heedb-rila-7ppl.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 4-z-run-labsup-proto-heedb-rila-28ppl.sh "--dependency=afterok:$cache_id" - + submit_job "$suffix" 2-z-run-labsup-proto-direct-7ppl.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 2-z-run-labsup-proto-direct-28ppl.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 3-z-run-protossl-heedb-pila-7ppl.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 3-z-run-protossl-heedb-pila-28ppl.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 4-z-run-labsup-proto-heedb-rila-7ppl.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 4-z-run-labsup-proto-heedb-rila-28ppl.sh "--dependency=afterok:$cache_id" + + # supproto no-proj submit_job "$suffix" 4-y-run-labsup-proto-heedb-pila.sh "--dependency=afterok:$cache_id" - # assignment/projection method - # submit_job "$suffix" 6-run-protossl-heedb-pia.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 7-run-protossl-heedb-pit.sh "--dependency=afterok:$cache_id" - # submit_job "$suffix" 8-run-protossl-heedb-pip.sh "--dependency=afterok:$cache_id" + # assignment method + submit_job "$suffix" 6-run-protossl-heedb-pia.sh "--dependency=afterok:$cache_id" + + # no assignment + submit_job "$suffix" 3-z-run-protossl-heedb-pila-83ppl.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 6-z-run-protossl-heedb-pia-83ppl.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 7-run-protossl-heedb-pit.sh "--dependency=afterok:$cache_id" + submit_job "$suffix" 8-run-protossl-heedb-pip.sh "--dependency=afterok:$cache_id" # start with pretrained encoder - # patches_id=$(submit_job "$suffix" 9-0-run-ecgfounder-patches.sh) - # echo $patches_id - # submit_job "$suffix" 9-1-run-ecgfounder-lap.sh "--dependency=afterok:$patches_id" - # submit_job "$suffix" 9-2-run-ecgfounder-clustering.sh "--dependency=afterok:$patches_id" - # submit_job "$suffix" 9-3-run-ecgfounder-random.sh "--dependency=afterok:$patches_id" + patches_id=$(submit_job "$suffix" 9-0-run-ecgfounder-patches.sh) + echo $patches_id + submit_job "$suffix" 9-1-run-ecgfounder-lap.sh "--dependency=afterok:$patches_id" + submit_job "$suffix" 9-2-run-ecgfounder-clustering.sh "--dependency=afterok:$patches_id" + submit_job "$suffix" 9-3-run-ecgfounder-random.sh "--dependency=afterok:$patches_id" fi done done