Catastrophic forgetting remains a critical challenge for deep neural networks in Continual Learning (CL), as learning new tasks undermines consolidated knowledge. Parameter-efficient fine-tuning techniques for CL are gaining traction because they address catastrophic forgetting with a lightweight training schedule while avoiding degradation of the consolidated knowledge in pre-trained models. However, the low-rank adapters (LoRA) used in these approaches are highly sensitive to rank selection, which can lead to suboptimal resource allocation and performance. To this end, we introduce PEARL, a rehearsal-free CL framework that performs dynamic rank allocation for LoRA components during CL training. Specifically, PEARL leverages reference task weights and adaptively determines the rank of task-specific LoRA components based on the current task's proximity to the reference task weights in parameter space. To demonstrate its versatility, we evaluate PEARL across three vision architectures (ResNet, Separable Convolutional Network, and Vision Transformer) and a multitude of CL scenarios, and show that it outperforms all considered baselines by a large margin.
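The intuition is that tasks whose weights lie close to the reference task weights need less adaptation capacity and can therefore be served by a smaller LoRA rank. The snippet below is a conceptual sketch of this rank-allocation idea only, assuming cosine similarity as the proximity measure and a linear mapping to a rank budget; it is not the implementation used in this repository.

```python
# Conceptual sketch (assumptions: cosine similarity as proximity, linear rank mapping);
# not the repository's actual rank-allocation code.
import torch

def allocate_rank(task_weight: torch.Tensor,
                  reference_weight: torch.Tensor,
                  r_min: int = 2,
                  r_max: int = 16) -> int:
    """Map parameter-space proximity to a LoRA rank: tasks close to the
    reference weights get a small rank, distant tasks get a larger one."""
    sim = torch.nn.functional.cosine_similarity(
        task_weight.flatten(), reference_weight.flatten(), dim=0
    ).clamp(0.0, 1.0).item()
    # Low similarity (distant task) -> rank near r_max; high similarity -> rank near r_min.
    return round(r_min + (1.0 - sim) * (r_max - r_min))
```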
Before you begin, ensure you have met the following requirements:
- Python: Make sure Python 3.x is installed on your system. This project was developed and tested with Python 3.9.18.
- Virtual Environment: It is recommended to use a virtual environment / conda to manage dependencies.
- Clone the repository:
  `git clone https://github.com/NeurAI-Lab/pearl.git`
- Create a virtual environment and activate it:
  `python3 -m venv myenv`
  `source myenv/bin/activate`
- Install the required packages:
  `pip install -r requirements.txt`
- CIFAR-10, CIFAR-100, and Tiny ImageNet will be downloaded automatically when the project runs.
- For DomainNet and ImageNet-R, run the `download.sh` script. This will download the datasets and store them under the `data` directory.
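For the automatically downloaded datasets, the standard torchvision download pattern applies. The snippet below is a minimal sketch using torchvision's CIFAR-100 loader with the `../data` root from the configuration example further down; the repository's own data-loading code may differ.

```python
# Minimal sketch: fetch CIFAR-100 into the data root used in the config example below.
# This is the standard torchvision pattern, not necessarily the repository's loader code.
from torchvision import datasets, transforms

train_set = datasets.CIFAR100(
    root="../data",               # matches the data_root example
    train=True,
    download=True,                # downloads the archive on first use
    transform=transforms.ToTensor(),
)
print(f"CIFAR-100 training set: {len(train_set)} images")
```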
The configuration file supports the following parameters:
dataset --> Dataset to use (e.g., cifar100).
image_size --> Size of the input images (e.g., 32).
n_classes_per_task --> Number of classes per task (e.g., 20).
n_tasks --> Total number of tasks (e.g., 5).
n_epochs --> Number of training epochs for full model finetuning (e.g., 10).
n_epochs_lora --> Number of training epochs for LoRA (e.g., 10).
model --> Model architecture to use (e.g., resnet, separable_conv, vit).
batch_size --> Size of the batches for training (e.g., 32).
seed --> Random seed for reproducibility (e.g., 20).
lr --> Learning rate for training (e.g., 0.001).
lr_lora --> Learning rate for LoRA (e.g., 0.001).
scheduler --> Type of learning rate scheduler (e.g., step, cosine).
data_root --> Root directory of the dataset (e.g., ../data).
log_dir --> Directory for storing log files (e.g., ./logs_vit).
vit_name --> Pre-trained ViT model name (e.g., vit_base_patch16_224_in21k).
target_modules --> Target modules for LoRA (e.g., ["key"]).
alpha --> LoRA scaling factor alpha (e.g., 2r, i.e., twice the rank).
weight_init --> Whether to perform weight initialization (e.g., False).
factor --> Factor for number of filters (e.g., 2).
depth --> Number of layers (e.g., 3).
svd_threshold --> SVD threshold (e.g., 0).
forward_transfer --> Whether to use forward transfer (e.g., True).
weight_renorm --> Whether to use weight renormalization (e.g., False).
verbose --> Whether to print verbose information (e.g., True).
save_model --> Whether to save the model (e.g., False).
use_wandb --> Whether to use Weights and Biases for tracking (e.g., False).
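A minimal sketch of editing these parameters programmatically with Python's `configparser` is shown below; the default-section handling and the `my_config.ini` output name are assumptions for illustration, so check `config/config.ini` in the repository for the actual layout.

```python
# Sketch: load the shipped config, override a few fields, and write a variant.
# The section handling and output filename are assumptions, not the repo's schema.
import configparser

config = configparser.ConfigParser()
config.read("config/config.ini")

section = config.default_section  # "DEFAULT"; the actual file may use named sections
config[section]["dataset"] = "cifar100"
config[section]["n_classes_per_task"] = "20"
config[section]["n_tasks"] = "5"
config[section]["model"] = "resnet"
config[section]["lr"] = "0.001"

with open("config/my_config.ini", "w") as f:
    config.write(f)

# Then launch training with the new file:
#   python main.py --config config/my_config.ini
```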
To train the network with the default settings:
`cd PEARL/`
`python main.py --config config/config.ini`