An experimental approach using pathology-guided hybrid CNN–Transformer models
University of Hertfordshire, Department of Computer Science: BSc Artificial Intelligence Project, 6COM2017
Author: Riya Basak; Supervised by: Dr Kheng Lee Koay
Brain tumour MRI classifiers can appear accurate while relying on non-tumour shortcuts such as artefacts, skull boundaries, or acquisition bias.
This project develops hybrid CNN–Transformer models with experimental guidance modules and explainability tools to encourage tumour-centred reasoning and more interpretable outputs.
The repository contains:

- offline preprocessing with leakage-safe SHA1 deduplication, tight-crop preprocessing, and split generation,
- four training variants:
  - Hybrid A (PFD-A and GSTE-A),
  - Hybrid B (PFD-B and GSTE-B),
  - ablation without PFD-A / GSTE-A,
  - ablation without PFD-B / GSTE-B,
- post-hoc explainability outputs using Grad-CAM++ and attention rollout,
- a local Flask demo web app for qualitative inspection on a single uploaded image,
- a reusable PFD-GSTE guidance library in `pfd_gste/` for importing the guidance modules into other CNN, Transformer, or hybrid medical image classifiers.
| Field | Detail |
|---|---|
| Project title | Mitigating Shortcut Learning in Brain Tumour MRI Classification |
| Programme | Modular BSc (Hons) Computer Science (Artificial Intelligence) |
| Module | 6COM2017 – Artificial Intelligence Project |
| Institution | University of Hertfordshire |
| Author | Riya Basak |
| Supervisor | Dr Kheng Lee Koay |
| Main repository | HybridResNet50V2-RViT |
| Main task | Four-class brain tumour MRI classification |
| Classes | glioma, meningioma, pituitary, notumor |
| Core methods | PFD, GSTE, hybrid CNN–Transformer classification, Grad-CAM++, attention rollout, MC Dropout |
This project introduces two experimental guidance modules inside a hybrid CNN–Transformer pipeline.
| Component | Description |
|---|---|
| PFD — Pathology-Focused Disentanglement | A learnable soft spatial mask over the CNN feature map. |
| GSTE — Guided Semantic Token Evolution | Reuses the same mask to guide transformer tokens. |
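To make the two components concrete, here is a minimal PyTorch sketch of the general idea: a 1×1 convolution predicts a soft spatial mask in (0, 1) over the CNN feature map (PFD-style gating), and the same mask, flattened, reweights the tokens built from that map (GSTE-style guidance) without changing the token count. The class name `SoftMaskGuidance`, the 1×1-conv mask head, the sigmoid, and the `1 + mask` residual weighting are all assumptions for illustration; this is not the code in `pfd_gste/guidance.py`.

```python
import torch
import torch.nn as nn


class SoftMaskGuidance(nn.Module):
    """Hypothetical sketch: a soft spatial mask (PFD-style) reused to weight tokens (GSTE-style)."""

    def __init__(self, in_channels: int, embed_dim: int):
        super().__init__()
        # assumed design: a 1x1 conv predicts one mask logit per spatial position
        self.mask_head = nn.Conv2d(in_channels, 1, kernel_size=1)
        # projects each C-dimensional feature vector to a token embedding
        self.token_proj = nn.Linear(in_channels, embed_dim)

    def forward(self, feat: torch.Tensor):
        # feat: (B, C, H, W)
        mask = torch.sigmoid(self.mask_head(feat))                   # (B, 1, H, W), values in (0, 1)
        gated = feat * mask                                           # soft spatial gating
        tokens = self.token_proj(gated.flatten(2).transpose(1, 2))   # (B, H*W, D)
        weights = mask.flatten(2).transpose(1, 2)                     # (B, H*W, 1)
        # residual-style reweighting: every token is kept, highlighted ones are amplified
        tokens = tokens * (1.0 + weights)
        return tokens, mask


guide = SoftMaskGuidance(in_channels=2048, embed_dim=128)
tokens, mask = guide(torch.randn(2, 2048, 7, 7))   # ResNet50V2-sized feature map
print(tokens.shape, mask.shape)
```

With a 7×7 feature map this yields 49 tokens, which matches the token count used by GSTE-A below.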
Two guidance strengths are implemented for matched comparison:
| Variant | Guidance design |
|---|---|
| Hybrid A | PFD-A gates only the transformer token pathway; GSTE-A reweights 49 CNN tokens. |
| Hybrid B | PFD-B influences both the CNN descriptor and transformer guidance pathway; GSTE-B weights 196 patch tokens and can optionally shrink them toward highlighted regions. |
```text
.
├── README.md
├── requirements.txt
├── .gitattributes # Git LFS tracking rules for model checkpoints
├── .gitignore # ignored files and folders
├── CITATION.bib # BibTeX citation
├── CITATION.cff # machine-readable citation metadata
├── CODE_OF_CONDUCT.md # contributor and responsible-use guidelines
├── LICENSE
├── Research_Note.pdf # research note
├── additional.txt # additional submitted project text
├── data/
│   ├── raw/brain-tumor-mri-dataset/ # downloaded Kaggle dataset
│   ├── processed/tightcrop/ # generated 224x224 tight-cropped images
│   └── splits/tightcrop/ # generated train.csv / val.csv / test.csv
├── docs/
│   ├── dataset_prep.md # preprocessing notes
│   └── images/
│       ├── hybrid-a-architecture.png
│       ├── hybrid-b-architecture.png
│       └── demo-app.png
├── pfd_gste/ # reusable PFD-GSTE guidance modules
│   ├── __init__.py # public package interface
│   └── guidance.py # PFD, GSTE-A, GSTE-B, patch guidance, and MC-dropout helpers
├── results/ # preprocessing audit outputs, CSV summaries, plots, and evaluation outputs
├── Misclassified-results/ # saved misclassification examples and related analysis outputs
├── scripts/
│   ├── data.py # BrainMRICSV and build_transforms
│   ├── dataset_prep.py # offline preprocessing, leakage-safe deduplication, and split generation
│   ├── dataset_plots.py # dataset plots used by preprocessing
│   └── Confusion_metrics_plot_generator.py # helper plotting utilities
├── Hybrid-model-with-pfdA-gsteA/
│   ├── models/
│   │   └── hybrid_model.py
│   ├── train-A.py
│   ├── Xai-A.py
│   └── best_model.pt # trained checkpoint stored through Git LFS, if pulled correctly
├── Hybrid-model-with-pfdB-gsteB/
│   ├── models/
│   │   └── hybrid_model.py
│   ├── train-B.py
│   ├── Xai-B.py
│   └── best_model.pt # trained checkpoint stored through Git LFS, if pulled correctly
├── Hybrid-model-without-pfdA-gsteA/
│   ├── models/
│   │   └── hybrid_model.py
│   ├── train-without-A.py
│   ├── Xai-without-A.py
│   └── best_model.pt # trained checkpoint stored through Git LFS, if pulled correctly
├── Hybrid-model-without-pfdB-gsteB/
│   ├── models/
│   │   └── hybrid_model.py
│   ├── train-without-B.py
│   ├── Xai-without-B.py
│   └── best_model.pt # trained checkpoint stored through Git LFS, if pulled correctly
└── webapp/
    ├── app.py # local Flask demo app
    ├── models_registry.json # checkpoint/model registry used by the demo
    ├── requirements.txt # separate dependencies for the web app
    └── templates/
        └── index.html # demo web interface
```
There are two ways to use this repository:
- Reproduce the full project from raw data by running preprocessing, training, and XAI scripts.
- Run the demo web app locally using trained checkpoints stored with Git LFS.
The demo web app is a supporting qualitative inspection tool. It is not the main training or evaluation pipeline.
Run commands from the main project folder unless a step tells you to change into another folder such as webapp/.
- Download and extract the Kaggle dataset into `data/raw/brain-tumor-mri-dataset/`.
- Run preprocessing: `python scripts/dataset_prep.py`.
- Run one of the training scripts for Hybrid A, Hybrid B, Ablation A, or Ablation B.
- Optionally run the corresponding XAI script.
- Optionally run the demo app from `webapp/`.
Dataset used: Masoud Nickparvar (2021), Brain Tumor MRI Dataset
Classes: glioma, meningioma, pituitary, and no tumor
Code label for no tumor: notumor
Reference: Masoud Nickparvar. (2021). Brain Tumor MRI Dataset [Data set]. Kaggle. https://doi.org/10.34740/KAGGLE/DSV/2645886
The benchmark combines multiple public sources, including figshare, SARTAJ, and Br35H.
In this project’s curation notes, SARTAJ glioma-class issues were handled by using figshare images instead.
To reproduce the work from scratch:
- Download the Kaggle dataset.
- Extract it into:
data/raw/brain-tumor-mri-dataset/
The preprocessing script expects the raw dataset to be present in that location before it is run.
Preprocessing is performed by:

```bash
python scripts/dataset_prep.py
```

The script:

- removes duplicates using SHA1-based leakage-safe deduplication,
- creates a clean 224×224 RGB dataset using tight-crop preprocessing,
- keeps Kaggle Testing as the held-out test set and creates train/val from Kaggle Training.
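The leakage-safe deduplication step can be sketched as follows. This is a minimal illustration, not the logic of `scripts/dataset_prep.py`: it hashes raw file bytes with SHA1, keeps the first copy of each hash, and processes the held-out Kaggle Testing images first, so any training-pool image that is byte-identical to a test image is dropped and cannot leak across the split. The function name `dedup_leakage_safe` and the rule of giving the test set priority are assumptions.

```python
import hashlib
from pathlib import Path


def sha1_of_file(path: Path) -> str:
    """Return the SHA1 hex digest of a file's raw bytes, read in 1 MiB chunks."""
    h = hashlib.sha1()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()


def dedup_leakage_safe(train_paths, test_paths):
    """Hypothetical helper: remove exact duplicates within and across the two pools.

    Test images are hashed first, so a duplicate shared by both pools is kept
    only in the held-out test set and removed from the training pool.
    """
    seen = set()
    kept_test, kept_train, removed = [], [], []
    for pool, kept in ((test_paths, kept_test), (train_paths, kept_train)):
        for p in pool:
            digest = sha1_of_file(Path(p))
            if digest in seen:
                removed.append(p)
            else:
                seen.add(digest)
                kept.append(p)
    return kept_train, kept_test, removed
```

Hashing bytes only catches exact copies; near-duplicates (re-encoded or resized versions of the same scan) would need a perceptual hash, which this sketch does not attempt.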
| Step | Detail |
|---|---|
| Raw images scanned | 7023 |
| Unique images after deduplication | 6726 |
| Duplicates removed | 297 |
| Total suspect images | 0 |
| Image correction | EXIF transpose where needed |
| Crop method | tight crop |
| Crop threshold | 5 |
| Crop margin | 10 |
| Colour format | RGB |
| Final image size | 224×224 |
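A tight crop with these settings can be sketched with NumPy and Pillow: after EXIF correction, pixels brighter than the threshold (5) are treated as foreground, their bounding box is padded by the margin (10 px), and the crop is resized to 224×224 RGB. Interpreting the threshold as a greyscale intensity on a 0 to 255 scale, leaving all-dark images uncropped, and using bilinear resizing are assumptions; `scripts/dataset_prep.py` may differ in these details.

```python
import numpy as np
from PIL import Image, ImageOps


def tight_crop(img: Image.Image, threshold: int = 5, margin: int = 10, size: int = 224) -> Image.Image:
    """Crop to the bounding box of pixels above `threshold`, pad by `margin`, resize to `size`."""
    img = ImageOps.exif_transpose(img).convert("RGB")   # apply EXIF orientation, force RGB
    grey = np.asarray(img.convert("L"))
    ys, xs = np.nonzero(grey > threshold)                # foreground pixel coordinates
    if ys.size > 0:                                      # assumed: blank images are not cropped
        h, w = grey.shape
        top = max(int(ys.min()) - margin, 0)
        left = max(int(xs.min()) - margin, 0)
        bottom = min(int(ys.max()) + margin + 1, h)
        right = min(int(xs.max()) + margin + 1, w)
        img = img.crop((left, top, right, bottom))
    return img.resize((size, size), Image.BILINEAR)
```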
Processed images:
data/processed/tightcrop/{train,val,test}/{class}/
CSV split files:
data/splits/tightcrop/train.csv
data/splits/tightcrop/val.csv
data/splits/tightcrop/test.csv
Audit and summary outputs:
results/
| Split | Count |
|---|---|
| Train | 4353 |
| Val | 1089 |
| Test | 1284 |
| Split | glioma | meningioma | pituitary | notumor |
|---|---|---|---|---|
| Train | 1057 | 1064 | 1152 | 1080 |
| Val | 264 | 267 | 288 | 270 |
| Test | 299 | 304 | 300 | 381 |
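The train/val totals above are consistent with a stratified 80/20 split of the 5,442 deduplicated Kaggle Training images: scikit-learn rounds a fractional test size up, and ceil(0.2 × 5442) = 1089. A minimal sketch of such a split follows; whether `scripts/dataset_prep.py` actually uses `train_test_split` with `stratify` and seed 42 is an assumption, and `make_train_val_split` is a hypothetical name.

```python
from sklearn.model_selection import train_test_split


def make_train_val_split(paths, labels, val_fraction=0.2, seed=42):
    """Hypothetical helper: stratified train/val split of the Kaggle Training pool."""
    train_p, val_p, train_y, val_y = train_test_split(
        paths,
        labels,
        test_size=val_fraction,
        stratify=labels,      # keep class proportions similar in both splits
        random_state=seed,    # fixed seed for a repeatable split
    )
    return (train_p, train_y), (val_p, val_y)
```

The Kaggle Testing folder is not passed to this function at all, which keeps the held-out test set untouched by the split.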
Hybrid A uses pathology-focused guidance only on the transformer token pathway.
- input: RGB 224×224,
- ResNet50V2 backbone produces `feat` with shape `(B, 2048, 7, 7)`,
- CNN descriptor is computed from ungated features,
- PFD-A gates features only for transformer token formation,
- the 7×7 grid becomes 49 tokens,
- GSTE-A reweights those tokens while keeping token count fixed,
- four internal rotations are used: 0°, 90°, 180°, 270°,
- a rotation-aware transformer encoder produces the token descriptor,
- CNN and transformer descriptors are fused for 4-class classification.
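The tensor shapes in this pathway can be checked with a short sketch. It turns a `(B, 2048, 7, 7)` feature map into 49 tokens and embeds four rotated views of the grid with `torch.rot90`. Applying the 0°, 90°, 180° and 270° rotations to the 7×7 feature grid (rather than to the input image), the 128-dimensional linear projection, and averaging the rotated views are assumptions made for illustration; the rotation-aware encoder in `hybrid_model.py` may work differently.

```python
import torch
import torch.nn as nn


def grid_to_tokens(feat: torch.Tensor) -> torch.Tensor:
    """(B, C, H, W) -> (B, H*W, C): one token per spatial position."""
    return feat.flatten(2).transpose(1, 2)


def rotated_token_views(feat: torch.Tensor, proj: nn.Module) -> torch.Tensor:
    """Hypothetical: embed the grid under 0/90/180/270 degree rotations and average the views."""
    views = []
    for k in range(4):                                    # k quarter-turns: 0, 90, 180, 270 degrees
        rotated = torch.rot90(feat, k=k, dims=(2, 3))      # rotate the spatial grid
        views.append(proj(grid_to_tokens(rotated)))        # (B, 49, D)
    return torch.stack(views, dim=0).mean(dim=0)           # (B, 49, D)


feat = torch.randn(2, 2048, 7, 7)    # ResNet50V2 feature map for a batch of 2
proj = nn.Linear(2048, 128)          # assumed token embedding size
tokens = rotated_token_views(feat, proj)
print(tokens.shape)                  # torch.Size([2, 49, 128])
```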
Hybrid B applies stronger pathology-focused guidance across the CNN descriptor and transformer guidance pathway.
- uses the same backbone and learned pathology mask,
- CNN descriptor is computed from gated features,
- the transformer branch uses raw-image patch tokens (14×14 = 196 tokens),
- the guidance mask is upsampled and pooled to the patch grid,
- optional token-grid shrinking can reduce computation toward highlighted regions,
- the fusion and classification pattern remains the same.
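The mask-to-patch step can be sketched as follows: the 7×7 pathology mask is bilinearly upsampled to the 224×224 image size and then average-pooled with a 16×16 window, giving one weight per patch on the 14×14 grid (196 weights). The bilinear mode, the 16-pixel patch size (matching `patch_size=16` in the library example later in this README), and the residual-style token weighting are assumptions; optional token-grid shrinking is not shown.

```python
import torch
import torch.nn.functional as F


def mask_to_patch_weights(mask: torch.Tensor, image_size: int = 224, patch_size: int = 16) -> torch.Tensor:
    """Map a (B, 1, h, w) CNN mask to (B, N, 1) per-patch weights, N = (image_size // patch_size) ** 2."""
    up = F.interpolate(mask, size=(image_size, image_size), mode="bilinear", align_corners=False)
    pooled = F.avg_pool2d(up, kernel_size=patch_size, stride=patch_size)   # (B, 1, 14, 14)
    return pooled.flatten(2).transpose(1, 2)                               # (B, 196, 1)


mask = torch.sigmoid(torch.randn(2, 1, 7, 7))    # stand-in for a learned pathology mask
patch_tokens = torch.randn(2, 196, 128)           # stand-in for raw-image patch embeddings
weights = mask_to_patch_weights(mask)
guided = patch_tokens * (1.0 + weights)           # assumed residual-style reweighting
print(weights.shape, guided.shape)
```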
| Item | Detail |
|---|---|
| Python | 3.12.2 |
| Training platform | Kaggle |
| GPU | Tesla P100 |
| Local development | macOS |
| Platform string | macOS-26.2-arm64-arm-64bit |
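For preprocessing, training, evaluation, and XAI (`requirements.txt`): `torch`, `torchvision`, `timm`, `numpy`, `pandas`, `scikit-learn`, `matplotlib`, `pillow`.

For the demo web app (`webapp/requirements.txt`): `flask`, `torch`, `torchvision`, `timm`, `numpy`, `pillow`, `matplotlib`.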
For preprocessing, training, and XAI:

```bash
python3 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt
```

For the demo web app, use the separate environment in `webapp/requirements.txt`.
Training scripts do not call preprocessing automatically. They read only:

- `data/splits/tightcrop/train.csv`,
- `data/splits/tightcrop/val.csv`,
- `data/splits/tightcrop/test.csv`,

using the shared loader and transforms from `scripts/data.py` (`BrainMRICSV`, `build_transforms`).
| Setting | Value |
|---|---|
| Epochs | 100 |
| Batch size | 32 |
| Seed | 42 |
| Device | "cuda" if available, else "cpu" |
| DataLoader workers | 2 |
| Pin memory | True |
| Hybrid B-style training loader | drop_last=True |
Training transforms (augmentation is applied to training only):

- `RandomRotation(±15°)`,
- `RandomHorizontalFlip(p=0.5)`,
- `RandomAffine(translate=0.05)`,
- optional Gaussian noise (std=0.02, p=0.5),
- normalisation with mean = std = (0.5, 0.5, 0.5).
Validation and test use tensor conversion and normalisation only.
| Component | Detail |
|---|---|
| Optimiser | AdamW |
| CNN learning rate | 1e-4 |
| Transformer / fusion learning rate | 5e-4 |
| Weight decay | 0.01 |
| No decay applied to | bias, norm, or 1D parameters |
| Training loss | CrossEntropy(label_smoothing=0.05) |
| Evaluation loss | plain CrossEntropy |
| Scheduler | CosineAnnealingLR |
| Minimum LR | eta_min = 1e-6 |
| Warmup | freeze CNN for 5 epochs, then unfreeze and rebuild optimiser and scheduler |
| Gradient clipping | max_norm = 1.0 |
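The optimiser configuration in the table can be sketched as below: parameters are split into CNN and transformer/fusion groups with different learning rates, and biases, normalisation weights, and other 1-D tensors get no weight decay. The `model.cnn` attribute and the helper names `split_decay` and `build_optimiser` are hypothetical; treat this as an illustration of the settings, not a copy of the training scripts.

```python
import torch
import torch.nn as nn


def split_decay(named_params, lr, weight_decay=0.01):
    """Return two AdamW param groups: with decay (>=2-D weights) and without (bias/norm/1-D)."""
    decay, no_decay = [], []
    for name, p in named_params:
        if not p.requires_grad:           # frozen parameters (e.g. CNN during warmup) are skipped
            continue
        if p.ndim <= 1 or name.endswith(".bias") or "norm" in name.lower():
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "lr": lr, "weight_decay": weight_decay},
        {"params": no_decay, "lr": lr, "weight_decay": 0.0},
    ]


def build_optimiser(model: nn.Module, epochs: int, cnn_lr=1e-4, head_lr=5e-4):
    """Hypothetical: assumes the backbone lives in `model.cnn`; everything else is transformer/fusion."""
    cnn_ids = {id(p) for p in model.cnn.parameters()}
    cnn_named = [(n, p) for n, p in model.named_parameters() if id(p) in cnn_ids]
    head_named = [(n, p) for n, p in model.named_parameters() if id(p) not in cnn_ids]
    groups = split_decay(cnn_named, cnn_lr) + split_decay(head_named, head_lr)
    groups = [g for g in groups if g["params"]]            # drop empty groups
    optimiser = torch.optim.AdamW(groups)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=epochs, eta_min=1e-6)
    return optimiser, scheduler
```

For the 5-epoch warmup, the CNN parameters would have `requires_grad=False` and be left out; after unfreezing, calling `build_optimiser` again rebuilds both the optimiser and the scheduler, as the table describes.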
Optional flags:

- `--amp` for mixed precision,
- `--freeze_cnn_bn` to freeze CNN BatchNorm statistics.
| Item | Detail |
|---|---|
| Monitored metric | validation macro-F1 |
| Patience | 10 |
| Best checkpoint | best_model.pt |
| Checkpoint contents | weights, class names, normalisation values, training arguments, and model configuration |
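The checkpoint-selection rule in the table amounts to early stopping on validation macro-F1. A minimal sketch of that bookkeeping is shown below; the `EarlyStopper` class is hypothetical, and treating only a strictly higher score as an improvement (ties do not reset the counter) is an assumption.

```python
class EarlyStopper:
    """Track the best validation macro-F1 and signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience: int = 10):
        self.patience = patience
        self.best_score = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def update(self, epoch: int, score: float) -> bool:
        """Return True if this epoch is a new best, i.e. the point to save best_model.pt."""
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            self.bad_epochs = 0
            return True
        self.bad_epochs += 1
        return False

    @property
    def should_stop(self) -> bool:
        return self.bad_epochs >= self.patience
```

In a training loop, `update` is called once per epoch with the validation macro-F1, the checkpoint is saved when it returns `True`, and the loop breaks when `should_stop` becomes `True`.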
Each run saves:

- `best_model.pt`,
- `history.csv`,
- `loss_curves.png`,
- `acc_curves.png`,
- `confusion_matrix.png`,
- `metrics.json`.
To train each variant:

```bash
python Hybrid-model-with-pfdA-gsteA/train-A.py
python Hybrid-model-with-pfdB-gsteB/train-B.py
python Hybrid-model-without-pfdA-gsteA/train-without-A.py
python Hybrid-model-without-pfdB-gsteB/train-without-B.py

# optional: mixed precision
python Hybrid-model-with-pfdA-gsteA/train-A.py --amp
```

Each variant has a corresponding XAI script for qualitative inspection of model focus:

```bash
python Hybrid-model-with-pfdA-gsteA/Xai-A.py
python Hybrid-model-with-pfdB-gsteB/Xai-B.py
python Hybrid-model-without-pfdA-gsteA/Xai-without-A.py
python Hybrid-model-without-pfdB-gsteB/Xai-without-B.py
```

These scripts generate qualitative outputs such as Grad-CAM++ and attention rollout visualisations.
To support fair comparison across variants:
- preprocessing is performed once offline,
- all training variants use the same leakage-aware CSV splits,
- the default random seed is 42,
- the default training length is 100 epochs,
- the best checkpoint is selected using validation macro-F1,
- test performance is reported on the held-out test split,
- Hybrid B-style runs use `drop_last=True`; Hybrid A-style runs do not.
Because package versions, CUDA availability, drivers, and hardware can differ across systems, exact runtime behaviour may vary even when the same code and seed are used.
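Expected data layout before training:

```text
data/
├── raw/brain-tumor-mri-dataset/
├── processed/tightcrop/
└── splits/tightcrop/
    ├── train.csv
    ├── val.csv
    └── test.csv
```

If the preprocessing outputs (`processed/tightcrop/` and the three split CSVs) do not exist, the training scripts will not run correctly.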
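A quick pre-flight check before launching training can confirm that preprocessing has produced these files. The helper below is hypothetical and not part of the repository; it only checks that the paths exist and that each split CSV has a header plus at least one data row.

```python
from pathlib import Path

REQUIRED = [
    Path("data/processed/tightcrop"),
    Path("data/splits/tightcrop/train.csv"),
    Path("data/splits/tightcrop/val.csv"),
    Path("data/splits/tightcrop/test.csv"),
]


def missing_outputs(root="."):
    """Return the required preprocessing outputs that are absent or empty under `root`."""
    problems = []
    for rel in REQUIRED:
        path = Path(root) / rel
        if not path.exists():
            problems.append(str(rel))
        elif path.suffix == ".csv":
            with open(path, encoding="utf-8") as f:
                if sum(1 for _ in f) < 2:          # header plus at least one data row
                    problems.append(str(rel))
    return problems
```

Run from the repository root, `missing_outputs()` returns an empty list when everything is in place; otherwise `python scripts/dataset_prep.py` should be run first.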
The table below summarises the final test-set performance recorded in each run’s metrics.json.
| Model | Test Acc | Macro F1 (test) | Cohen’s Kappa | MCC | Macro Specificity | Best Epoch (val macro-F1) |
|---|---|---|---|---|---|---|
| Hybrid A (PFD-A + GSTE-A) | 0.9875 | 0.9875 | 0.9833 | 0.9833 | 0.9959 | 43 |
| Hybrid B (PFD-B + GSTE-B) | 0.9852 | 0.9849 | 0.9802 | 0.9803 | 0.9952 | 14 |
| Without A (ablation) | 0.9875 | 0.9873 | 0.9833 | 0.9834 | 0.9959 | 30 |
| Without B (ablation) | 0.9922 | 0.9920 | 0.9896 | 0.9896 | 0.9975 | 42 |
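These metrics can be recomputed from test-set labels and predictions with scikit-learn and NumPy. The sketch below is an assumption about how they are defined, not the project's evaluation code: in particular, macro specificity is taken as the unweighted mean over classes of TN / (TN + FP), computed one-vs-rest from the confusion matrix, and `summarise` is a hypothetical name.

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    cohen_kappa_score,
    confusion_matrix,
    f1_score,
    matthews_corrcoef,
)


def summarise(y_true, y_pred, num_classes: int = 4) -> dict:
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp          # predicted as class k but actually another class
    fn = cm.sum(axis=1) - tp          # class k predicted as another class
    tn = cm.sum() - tp - fp - fn
    specificity = tn / np.maximum(tn + fp, 1)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "macro_f1": f1_score(y_true, y_pred, average="macro"),
        "cohen_kappa": cohen_kappa_score(y_true, y_pred),
        "mcc": matthews_corrcoef(y_true, y_pred),
        "macro_specificity": float(specificity.mean()),
    }
```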
The project supports post-hoc explainability and uncertainty analysis:
- Grad-CAM++ on the CNN branch,
- attention rollout on the transformer branch,
- MC Dropout at inference to estimate predictive mean and variance.
These tools are used for qualitative inspection of tumour-centred evidence rather than as replacements for quantitative evaluation.
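MC Dropout, as listed above, keeps dropout active at inference and summarises the softmax outputs of several stochastic forward passes. The generic sketch below illustrates the idea; it is not the signature of `mc_dropout_predict` in `pfd_gste/`, and the function name `mc_dropout_sketch`, the restriction to `nn.Dropout` layers, and the default of 20 passes are assumptions.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def mc_dropout_sketch(model: nn.Module, images: torch.Tensor, passes: int = 20):
    """Return per-class predictive mean and variance over `passes` stochastic forward passes."""
    model.eval()                                  # keep BatchNorm and similar layers in inference mode
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()                             # re-enable dropout only
    probs = torch.stack([torch.softmax(model(images), dim=1) for _ in range(passes)])  # (T, B, K)
    model.eval()                                  # restore fully deterministic mode
    return probs.mean(dim=0), probs.var(dim=0)
```

Switching only the dropout layers back to training mode avoids updating BatchNorm running statistics during the repeated passes.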
Trained checkpoints are stored using Git LFS.
Do not use GitHub “Download ZIP”, because ZIP downloads may contain pointer files instead of the real .pt checkpoints.
If checkpoint files are missing or unexpectedly small after cloning, run:

```bash
git lfs pull
```

The demo app was tested on Python 3.12.2 on macOS.
On some Windows systems, Python 3.11.x may be more reliable depending on available PyTorch wheels.
Torch installation varies by OS, CPU/GPU, and Python version.
If installing torch or torchvision fails, use the official PyTorch install command for your platform first, then install the remaining packages from webapp/requirements.txt.
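A checkpoint that was not fetched through Git LFS is a small text pointer file rather than a PyTorch archive. The helper below is a hedged sketch for spotting this from Python; it relies on Git LFS pointer files beginning with `version https://git-lfs.github.com/spec/v1`, while the function name `is_lfs_pointer` and the 1 KB size cut-off are assumptions.

```python
from pathlib import Path

LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path, max_pointer_bytes: int = 1024) -> bool:
    """Return True if `path` looks like a Git LFS pointer file instead of the real checkpoint."""
    p = Path(path)
    if not p.is_file() or p.stat().st_size > max_pointer_bytes:
        return False
    with open(p, "rb") as f:
        return f.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX


# run from the repository root to check every variant's checkpoint
for ckpt in sorted(Path(".").glob("Hybrid-model-*/best_model.pt")):
    status = "LFS pointer, run `git lfs pull`" if is_lfs_pointer(ckpt) else "ok"
    print(f"{ckpt}: {status}")
```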
macOS:

```bash
brew install git-lfs
git lfs install

cd ~/Downloads
git clone https://github.com/AnnyaB/HybridResNet50V2-RViT.git
cd HybridResNet50V2-RViT
ls -lh Hybrid-model-with-pfdA-gsteA/best_model.pt

cd webapp
python3 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt
python app.py
```

Linux (Debian/Ubuntu):

```bash
sudo apt update
sudo apt install -y git-lfs
git lfs install

cd ~/Downloads
git clone https://github.com/AnnyaB/HybridResNet50V2-RViT.git
cd HybridResNet50V2-RViT
ls -lh Hybrid-model-with-pfdA-gsteA/best_model.pt

cd webapp
python3 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt
python app.py
```

Windows (PowerShell):

```powershell
winget install --id Python.Python.3.11 -e
winget upgrade --id Python.Python.3.11
winget install --id Git.Git -e
winget install --id GitHub.GitLFS -e
git lfs install

git --version
git lfs version
py -3.11 --version

mkdir $env:USERPROFILE\ai_project
cd $env:USERPROFILE\ai_project
git clone https://github.com/AnnyaB/HybridResNet50V2-RViT.git
cd HybridResNet50V2-RViT
dir Hybrid-model-with-pfdA-gsteA\best_model.pt

cd webapp
py -3.11 -m venv .venv
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
.\.venv\Scripts\Activate.ps1
python -m pip install --upgrade pip
pip install -r requirements.txt
python app.py
```

After startup, open:
http://127.0.0.1:5000
app.py starts the local Flask demo web application for model inference and visualisation. It does not train or test the models.
If a small set of 16 test images is included with the web application materials, it can be used for a quick check of the demo.
If you see a timm warning about a deprecated model-name mapping, this is usually not an error.
It means a model alias name was remapped internally and the model can still load normally.
For the out-of-distribution (OOD) demo, a small external sample was taken from Fernando Feltrin’s Brain Tumor MRI Images 44 Classes Kaggle dataset.
The dataset is described as a collection of T1, contrast-enhanced T1, and T2 brain MRI images grouped by tumour type, and the class list includes meningioma, together with many other specific tumour categories.
Only five randomly selected T1ce meningioma images were used for the demo.
This was done because meningioma is explicitly provided as a named class, whereas the project’s other three classes do not map cleanly to this dataset:
- pituitary and no tumour are not listed as classes on the dataset page,
- glioma is not presented as one single class but is split across multiple, more specific tumour labels.
Therefore, this dataset was used only for a small qualitative OOD demonstration, not for a formal benchmark evaluation.
The folder pfd_gste/ contains reusable PyTorch modules for pathology-focused feature gating and guided token reweighting.
These modules are not tied to the original ResNet50V2-RViT model. They can be imported into other CNN, Transformer, or hybrid classifiers for related medical image classification tasks, such as brain, breast, lung, retinal, or other tumour-classification problems.
| Module | Purpose |
|---|---|
| `PFDGSTEVariantA` | Feature-token guidance for models where transformer tokens are produced from CNN feature maps. |
| `PFDGSTEVariantB` | Patch-token guidance for models where raw-image patch tokens are guided by a CNN-derived pathology mask. |
| `PathologyFocusedGate` | Standalone soft spatial feature gating. |
| `mc_dropout_predict` | Helper for MC-dropout uncertainty estimation. |
```python
from pfd_gste import PFDGSTEVariantA, PFDGSTEVariantB
```

Use Variant A when a CNN backbone returns a spatial feature map and the transformer operates on feature tokens.
```python
import torch.nn as nn

from pfd_gste import PFDGSTEVariantA


class GuidedFeatureClassifier(nn.Module):
    def __init__(self, backbone, feature_channels, num_classes, embed_dim=128):
        super().__init__()
        # backbone must return a spatial feature map of shape (B, feature_channels, H, W)
        self.backbone = backbone
        self.guidance = PFDGSTEVariantA(feature_channels, embed_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=embed_dim * 4,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=4)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, images):
        features = self.backbone(images)
        # guided feature tokens, plus the mask and alpha outputs (unused here)
        tokens, mask, alpha = self.guidance(features)
        tokens = self.encoder(tokens)
        # mean-pool the encoded tokens before the linear classifier
        logits = self.classifier(tokens.mean(dim=1))
        return logits
```

Use Variant B when PFD should guide the CNN feature pathway and GSTE should guide raw-image patch tokens.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from pfd_gste import PFDGSTEVariantB


class GuidedPatchClassifier(nn.Module):
    def __init__(self, backbone, feature_channels, num_classes, embed_dim=128):
        super().__init__()
        self.backbone = backbone
        self.guidance = PFDGSTEVariantB(
            in_channels=feature_channels,
            embed_dim=embed_dim,
            patch_size=16,
            min_side=7,
            max_shrink=0.50,
        )
        # CNN branch: pool the gated feature map and project it to embed_dim
        self.feature_pool = nn.AdaptiveAvgPool2d(1)
        self.feature_proj = nn.Linear(feature_channels, embed_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=embed_dim * 4,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=4)
        # fused CNN + transformer descriptor -> class logits
        self.classifier = nn.Linear(embed_dim * 2, num_classes)

    def forward(self, images):
        features = self.backbone(images)
        # the module takes the raw images (for patch tokens) and the CNN features (for the mask)
        gated_features, tokens, mask, alpha, token_hw = self.guidance(
            images,
            features,
            shrink=True,
        )
        cnn_vec = self.feature_pool(gated_features).flatten(1)
        cnn_vec = F.relu(self.feature_proj(cnn_vec), inplace=True)
        tokens = self.encoder(tokens)
        token_vec = tokens.mean(dim=1)
        logits = self.classifier(torch.cat([cnn_vec, token_vec], dim=1))
        return logits
```

The modules can be trained normally as part of a PyTorch model.
```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, optimiser, device):
    model.train()
    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
    total_loss = 0.0
    total_correct = 0
    total_count = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        optimiser.zero_grad(set_to_none=True)
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        # clip the global gradient norm to 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimiser.step()
        # accumulate sample-weighted loss and accuracy for the epoch
        total_loss += loss.item() * images.size(0)
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_count += images.size(0)
    return total_loss / total_count, total_correct / total_count
```

From the repository root:
```bash
python - <<'PY'
from pfd_gste import PFDGSTEVariantA, PFDGSTEVariantB
print("PFD-GSTE library imports correctly.")
PY
```

The reusable guidance components are available as the `pfd-gste` Python package for Python 3.12:
```bash
pip install pfd-gste
```

The PyPI package contains only the reusable PFD-GSTE guidance modules. It does not include the dataset, trained checkpoints, complete classifiers, experimental results, or Flask application.
The modules can also be imported locally from a cloned copy of this repository when commands are run from the repository root or when the repository root is included in PYTHONPATH.
Bolya, D., Fu, C., Dai, X., Zhang, P., Feichtenhofer, C. and Hoffman, J. (2022) ‘Token merging: Your ViT but faster’, arXiv preprint. https://doi.org/10.48550/arXiv.2210.09461
da Costa-Luis, C.O. (2019) ‘tqdm: a fast, extensible progress meter for Python and CLI’, Journal of Open Source Software, 4(37), 1277. https://doi.org/10.21105/joss.01277
Feltrin, F. (2023) Brain Tumor MRI Images 44 Classes [dataset]. Kaggle. Available at: https://www.kaggle.com/datasets/fernando2rad/brain-tumor-mri-images-44c (Accessed: 15 April 2026).
Harris, C.R., Millman, K.J., van der Walt, S.J., Gommers, R., Virtanen, P., Cournapeau, D., Wieser, E., Taylor, J., Berg, S., Smith, N.J., Kern, R., Picus, M., Hoyer, S., van Kerkwijk, M.H., Brett, M., Haldane, A., del Río, J.F., Wiebe, M., Peterson, P., Gérard-Marchant, P., Sheppard, K., Reddy, T., Weckesser, W., Abbasi, H., Gohlke, C. and Oliphant, T.E. (2020) ‘Array programming with NumPy’, Nature, 585, pp. 357–362. https://doi.org/10.1038/s41586-020-2649-2
He, K., Zhang, X., Ren, S. and Sun, J. (2016) ‘Identity mappings in deep residual networks’, European Conference on Computer Vision, pp. 630–645. https://doi.org/10.1007/978-3-319-46493-0_38
Hugging Face (2019) timm/resnetv2_50x1_bit.goog_in21k_ft_in1k [Pretrained model weights]. Available at: https://huggingface.co/timm/resnetv2_50x1_bit.goog_in21k_ft_in1k (Accessed: 14 February 2026).
Hunter, J.D. (2007) ‘Matplotlib: a 2D graphics environment’, Computing in Science & Engineering, 9(3), pp. 90–95. https://doi.org/10.1109/MCSE.2007.55
Kleinberg, J. and Tardos, E. (2006) Algorithm design. 1st edn. Boston, MA: Pearson Education / Addison-Wesley.
Kolesnikov, A. et al. (2020) ‘Big Transfer (BiT): General visual representation learning’, European Conference on Computer Vision. https://doi.org/10.48550/arXiv.1912.11370
Krishnan, P.T., Krishnadoss, P., Khandelwal, M., Gupta, D., Nihaal, A. and Kumar, T.S. (2024) ‘Enhancing brain tumor detection in MRI with a rotation invariant Vision Transformer’, Frontiers in Neuroinformatics, 18, 1414925. https://doi.org/10.3389/fninf.2024.1414925
McKinney, W. (2010) ‘Data structures for statistical computing in Python’, in van der Walt, S. and Millman, J. (eds.) Proceedings of the 9th Python in Science Conference, pp. 56–61. https://doi.org/10.25080/Majora-92bf1922-00a
Pallets (2024) Flask documentation. Available at: https://flask.palletsprojects.com/ (Accessed: 12 February 2026).
Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., Desmaison, A., Köpf, A., Yang, E., DeVito, Z., Raison, M., Tejani, A., Chilamkurthy, S., Steiner, B., Fang, L., Bai, J. and Chintala, S. (2019) ‘PyTorch: an imperative style, high-performance deep learning library’, in Wallach, H., Larochelle, H., Beygelzimer, A., d’Alché-Buc, F., Fox, E. and Garnett, R. (eds.) Advances in Neural Information Processing Systems, 32, pp. 8024–8035. https://doi.org/10.48550/arXiv.1912.01703
PyTorch (2024) torchvision documentation. Available at: https://docs.pytorch.org/vision/main/index.html (Accessed: 22 October 2025).
Rao, Y. et al. (2021) ‘DynamicViT: Efficient vision transformers with dynamic token sparsification’, Advances in Neural Information Processing Systems. https://doi.org/10.48550/arXiv.2106.02034
Sarada, B., Reddy, K.N., Muktisingh, R., Babu, R. and Babu, B.S.S.V.R. (2025) ‘Brain tumor classification using modified ResNet50V2 deep learning model’, International Journal of Computing and Digital Systems, 17(1), pp. 1–11. https://doi.org/10.12785/ijcds/1571021750
Xia, T., Chartsias, A. and Tsaftaris, S.A. (2020) ‘Pseudo-healthy synthesis with pathology disentanglement and adversarial learning’, Medical Image Analysis, 64, 101719. https://doi.org/10.1016/j.media.2020.101719
This project is released under the MIT License.
This means the code may be used, copied, modified, merged, published, distributed, sublicensed, and reused in future research or software projects, provided that the original copyright notice and MIT License text are included.
If this repository, code, trained models, or PFD-GSTE guidance modules are useful in your work, please cite:
Basak, R. (2026) Mitigating Shortcut Learning in Brain Tumour MRI Classification. BSc Artificial Intelligence Project, University of Hertfordshire. Available at: https://github.com/AnnyaB/HybridResNet50V2-RViT
A public Hugging Face Space is available for browser-based testing of the deployed research demo:
Live app: https://huggingface.co/spaces/AnnyaaB/brain-tumour-pfd-gste-demo
The deployed app can be opened from a phone, tablet, laptop, or desktop browser. You do not need to clone this GitHub repository or install Python locally to try the interface.
Current deployed version:
- runs as a Dockerised Flask application on Hugging Face Spaces;
- loads the four trained PyTorch model variants;
- accepts a single uploaded MRI image;
- returns model predictions, confidence values, tumour probability, and probability plots;
- runs the project’s qualitative XAI workflow, including Grad-CAM++ and attention-rollout overlays where the corresponding model exposes the required activations and attention information;
- uses the same model-specific XAI helper files as the local demo workflow: `Xai-A.py`, `Xai-B.py`, `Xai-without-A.py`, and `Xai-without-B.py`.
The public Hugging Face Space is intended as a convenient browser-based research demo. Because it runs in a different Docker/Linux/CPU environment from the original local development machine, and because MC-dropout and gradient-based XAI can be sensitive to package versions, hardware kernels, and stochastic forward passes, the deployed output may not be pixel-identical to the local-machine output.
For the most controlled check of the original research workflow, you should clone the repository, pull the Git LFS checkpoints, install the documented dependencies, and run the local Flask app or the model-specific XAI scripts on your own machine. The local workflow should be treated as the reference path for full reproducibility inspection.
This software is for research and educational use only.
It is not a certified medical device and must not be used for clinical diagnosis, patient management, or treatment decisions.
Any outputs produced by this code are experimental and may be incorrect.
For questions, reproducibility issues, or suggested improvements, please open a GitHub issue.


