Bug for fine-tuning Omat24 checkpoint model #2

@jinlhr542

I am trying to fine-tune the checkpoint_sevennet_mf_ompa.pth model, but Trainer.from_config fails with an AttributeError:

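import sevenn.util as util
from sevenn.train.graph_dataset import SevenNetGraphDataset

# Load the pretrained model together with the config stored in the checkpoint
model, config = util.model_from_checkpoint('checkpoint_sevennet_mf_ompa.pth')
cutoff = config['cutoff']
# working_dir and dataset_files are defined earlier in the notebook
dataset = SevenNetGraphDataset(cutoff=cutoff, root=working_dir, files=dataset_files, processed_name='train.pt')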

from sevenn.train.trainer import Trainer
import torch.optim.lr_scheduler as scheduler

trainer = Trainer.from_config(model, config)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[6], line 4
      1 from sevenn.train.trainer import Trainer
      2 import torch.optim.lr_scheduler as scheduler
----> 4 trainer = Trainer.from_config(model, config)
      6 # We have energy, force, stress loss function, which used to train 7net-0.
      7 # We will use it as it is, with loss weight: 1.0, 1.0, and 0.01 for energy, force, and stress, respectively.
      8 print(trainer.loss_functions)

File ~/miniconda3/envs/atomate2/lib/python3.12/site-packages/sevenn/train/trainer.py:88, in Trainer.from_config(model, config)
     84 @staticmethod
     85 def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer':
     86     trainer = Trainer(
     87         model,
---> 88         loss_functions=get_loss_functions_from_config(config),
     89         optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()],
     90         optimizer_args=config.get(KEY.OPTIM_PARAM, {}),
     91         scheduler_cls=scheduler_dict[
     92             config.get(KEY.SCHEDULER, 'exponentiallr').lower()
     93         ],
     94         scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}),
     95         device=config.get(KEY.DEVICE, 'auto'),
     96         distributed=config.get(KEY.IS_DDP, False),
     97         distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'),
     98     )
     99     return trainer

File ~/miniconda3/envs/atomate2/lib/python3.12/site-packages/sevenn/train/loss.py:211, in get_loss_functions_from_config(config)
    207 from sevenn.train.optim import loss_dict
    209 loss_functions = []  # list of tuples (loss_definition, weight)
--> 211 loss = loss_dict[config[KEY.LOSS].lower()]
    212 loss_param = config.get(KEY.LOSS_PARAM, {})
    214 use_weight = config.get(KEY.USE_WEIGHT, False)

AttributeError: 'dict' object has no attribute 'lower'
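
The traceback shows that config[KEY.LOSS] is a dict in this checkpoint, while the installed get_loss_functions_from_config calls .lower() on it and therefore expects a string. Below is a minimal sketch for inspecting that entry and, as a stopgap, passing a copy of the config with a string loss name. It rests on several assumptions: that KEY.LOSS resolves to the plain key 'loss', that the fallback name is one the installed loss_dict accepts, and that the example dict is only a hypothetical stand-in for what the checkpoint actually stores. The helper names are hypothetical as well and are not part of sevenn.

```python
from typing import Any, Dict

# Assumption: the string behind sevenn's KEY.LOSS is 'loss'.
LOSS_KEY = 'loss'


def describe_loss_entry(config: Dict[str, Any], key: str = LOSS_KEY) -> str:
    """Report the type and value stored under the loss key."""
    value = config.get(key)
    return f'{key!r} holds a {type(value).__name__}: {value!r}'


def with_string_loss(
    config: Dict[str, Any], fallback: str, key: str = LOSS_KEY
) -> Dict[str, Any]:
    """Return a shallow copy whose loss entry is a string.

    A loss entry that is already a string is kept; anything else is
    replaced by `fallback`. The original config is left untouched.
    """
    patched = dict(config)
    if not isinstance(patched.get(key), str):
        patched[key] = fallback
    return patched


# Hypothetical stand-in for the checkpoint config; the real contents of
# the loss dict (and the cutoff value) are not shown in the issue.
example_config = {'loss': {'name': 'huber'}, 'cutoff': 6.0}

print(describe_loss_entry(example_config))

# 'mse' is an assumed valid loss name; check loss_dict in your install.
patched = with_string_loss(example_config, fallback='mse')
print(patched['loss'].lower())  # the call that failed now works on a str
```

With the real objects, the same idea would be describe_loss_entry(config) followed by Trainer.from_config(model, with_string_loss(config, fallback=...)). Note that overriding the loss this way changes the training objective relative to what the checkpoint was trained with, so it is a workaround rather than a fix. A mismatch between the installed sevenn version and the version that wrote the checkpoint is another possible cause worth ruling out.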
