diff --git a/NDP-HNN/main.py b/NDP-HNN/main.py index cddb949..4458e73 100644 --- a/NDP-HNN/main.py +++ b/NDP-HNN/main.py @@ -3,6 +3,7 @@ Contributer: Lalith Bharadwaj Baru """ import argparse +import json import os from config import Config from utils import set_seed, get_device, ensure_dir @@ -52,7 +53,12 @@ def main(): ).to(device) #--- 4. train - model = train_model(model, snaps, dataset, epochs=args.epochs, lr=args.lr, device=device) + model, history = train_model(model, snaps, dataset, epochs=args.epochs, lr=args.lr, device=device) + + history_path = os.path.join(args.save_dir, "loss_history.json") + with open(history_path, "w") as f: + json.dump(history, f, indent=2) + print(f"Loss history saved to: {history_path}") #--- 5. embeddings (T, N, D) embeds = extract_embeddings(model, snaps, device=device) diff --git a/NDP-HNN/train.py b/NDP-HNN/train.py index d8f12db..00703ed 100644 --- a/NDP-HNN/train.py +++ b/NDP-HNN/train.py @@ -12,7 +12,7 @@ def train_model(model, dataset: Dict[str, Any], epochs: int = 30, lr: float = 1e-3, - device: str = "cuda"): + device: str = "cuda") -> tuple: birth_feat = dataset['birth_feat'] birth_times = dataset['birth_times'] @@ -21,9 +21,13 @@ def train_model(model, model.to(device) opt = torch.optim.Adam(model.parameters(), lr=lr) + history: Dict[str, List[float]] = {"loss": [], "loss_xyz": [], "loss_rec": []} + for epoch in range(1, epochs+1): state = None total_loss = 0.0 + total_xyz = 0.0 + total_rec = 0.0 for data in snapshots: data = data.to(device) @@ -55,8 +59,20 @@ def train_model(model, state = (state[0].detach(), state[1].detach()) else: state = state.detach() + total_loss += float(loss.item()) + total_xyz += float(loss_xyz.item()) + total_rec += float(loss_rec.item()) + + n = len(snapshots) + avg_loss = total_loss / n + avg_xyz = total_xyz / n + avg_rec = total_rec / n + + history["loss"].append(avg_loss) + history["loss_xyz"].append(avg_xyz) + history["loss_rec"].append(avg_rec) - print(f"Epoch {epoch:03d} — avg loss: {total_loss/len(snapshots):.4f}") + print(f"Epoch {epoch:03d} — loss: {avg_loss:.4f} xyz: {avg_xyz:.4f} rec: {avg_rec:.4f}") - return model + return model, history