NadaWalid22/paper-replication

title: U-Net Paper Replication
emoji: 🔬
colorFrom: indigo
colorTo: cyan
sdk: docker
pinned: false
license: mit
short_description: U-Net from scratch — retinal vessel segmentation on DRIVE

U-Net from Scratch — Replicating Ronneberger et al. (2015)

A faithful reimplementation of the original U-Net architecture for biomedical image segmentation, evaluated on the DRIVE retinal vessel dataset.


Overview

U-Net (Ronneberger et al., 2015) introduced an encoder-decoder architecture with skip connections that became the dominant approach for biomedical image segmentation. It targets a specific weakness of the fully convolutional networks of the time: repeated downsampling discards spatial detail, which makes precise pixel-level masks hard to produce. U-Net addresses this with two innovations:

  1. Symmetric encoder-decoder: The contracting path (encoder) captures context by repeatedly halving spatial resolution while doubling channels. The expansive path (decoder) progressively recovers full resolution.
  2. Skip connections: The encoder's feature maps, taken just before each downsampling step, are concatenated with the upsampled features at the corresponding decoder stage. Because these maps never pass through the bottleneck, the decoder can recover fine-grained spatial detail (vessel edges, thin structures) that pooling would otherwise discard.

The original paper trained on only a few dozen annotated images per task (it cites approximately 30 per application) using heavy data augmentation, chiefly elastic deformations. This demonstrated that, with the right architecture, small biomedical datasets can be sufficient, a significant departure from ImageNet-scale thinking.


Architecture

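Input (3, H, W)    # H and W divisible by 16, e.g. (3, 512, 512); the paper's 572×572 applies only to its unpadded variant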
│
├─ EncoderBlock 1: DoubleConv(3→64)   → skip1 (64, H, W)
│   └─ MaxPool2d(2)
├─ EncoderBlock 2: DoubleConv(64→128) → skip2 (128, H/2, W/2)
│   └─ MaxPool2d(2)
├─ EncoderBlock 3: DoubleConv(128→256)→ skip3 (256, H/4, W/4)
│   └─ MaxPool2d(2)
├─ EncoderBlock 4: DoubleConv(256→512)→ skip4 (512, H/8, W/8)
│   └─ MaxPool2d(2)
│
├─ Bottleneck:     DoubleConv(512→1024)         (1024, H/16, W/16)
│
├─ DecoderBlock 4: ConvTranspose2d(1024→512) + concat(skip4) → DoubleConv(1024→512)
├─ DecoderBlock 3: ConvTranspose2d(512→256)  + concat(skip3) → DoubleConv(512→256)
├─ DecoderBlock 2: ConvTranspose2d(256→128)  + concat(skip2) → DoubleConv(256→128)
├─ DecoderBlock 1: ConvTranspose2d(128→64)   + concat(skip1) → DoubleConv(128→64)
│
└─ Final Conv:     Conv2d(64→1, kernel=1) → sigmoid → Output (1, H, W)

Channel counts: encoder doubles 64 → 128 → 256 → 512, bottleneck at 1024, decoder halves 512 → 256 → 128 → 64.
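
The diagram can be checked with a few lines of arithmetic. The sketch below is not taken from the repository: `unet_shapes` is a hypothetical helper that traces channel counts and spatial sizes through a same-padded U-Net, assuming four pooling steps and 64 base channels as in the diagram.

```python
def unet_shapes(height, width, base=64, depth=4):
    """Trace (stage, channels, height, width) through a same-padded U-Net."""
    factor = 2 ** depth
    if height % factor or width % factor:
        # Otherwise pooled sizes round down and the skips no longer line up.
        raise ValueError(f"height and width must be divisible by {factor}")
    stages = []
    ch, h, w = base, height, width
    # Encoder: record each skip tensor, then pool (halve size) and double channels.
    for level in range(1, depth + 1):
        stages.append((f"skip{level}", ch, h, w))
        ch, h, w = ch * 2, h // 2, w // 2
    stages.append(("bottleneck", ch, h, w))
    # Decoder: each block doubles the spatial size and halves the channels.
    for level in range(depth, 0, -1):
        ch, h, w = ch // 2, h * 2, w * 2
        stages.append((f"decoder{level}", ch, h, w))
    return stages


for name, ch, h, w in unet_shapes(512, 512):
    print(f"{name:<10} ({ch}, {h}, {w})")
```

For a 512×512 input the bottleneck comes out at (1024, 32, 32) and the last decoder stage at (64, 512, 512), matching the H/16 annotation above. An input such as 572×572 is rejected because 572 is not divisible by 16.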


Paper vs. Implementation

| Aspect | Ronneberger et al. 2015 | This Implementation |
|---|---|---|
| Convolution padding | No padding (valid conv); output is smaller than input | Padding=1; output matches input size |
| Batch normalization | Not used | Optional via use_batchnorm flag |
| Data augmentation | Random elastic deformations (primary aug) | HorizontalFlip, VerticalFlip, RandomRotate90, ElasticTransform, GridDistortion |
| Input image size | 572×572 (tiled, with mirroring at borders) | 512×512 (full image, resized) |
| Loss function | Cross-entropy with class weight map | BCE + Dice (combined), configurable |
| Skip connection mechanism | Crop-and-copy (due to valid convs) | Direct concatenation (same-size feature maps) |
| Training dataset size | ~30 annotated images per application | 20 DRIVE training images |
| Optimizer | SGD with momentum | Adam |
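
The padding row explains why the paper's 572×572 input yields a smaller output. As a rough check, the sketch below (the hypothetical helper `valid_unet_output_size` is not part of this repository) follows one spatial dimension through an unpadded U-Net: each 3×3 valid convolution removes 2 pixels, each pooling step halves the size, and each up-convolution doubles it.

```python
def valid_unet_output_size(size, depth=4):
    """Output side length of an unpadded (valid-convolution) U-Net."""
    for _ in range(depth):
        size -= 4  # two 3x3 valid convolutions, 2 pixels lost per conv
        if size % 2:
            raise ValueError("size before pooling must be even")
        size //= 2  # 2x2 max pooling
    size -= 4  # bottleneck double convolution
    for _ in range(depth):
        size = size * 2 - 4  # 2x2 up-convolution, then two valid convolutions
    return size  # the final 1x1 convolution does not change the size


print(valid_unet_output_size(572))  # 388
```

The 572 → 388 shrinkage is why the paper crops encoder features before concatenating them and mirrors the image borders when tiling. With padding=1, as used here, neither step is needed.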

Dataset — DRIVE

Digital Retinal Images for Vessel Extraction

  • 40 color fundus photographs: 20 training, 20 test
  • Resolution: 565×584 pixels
  • Labels: manually annotated binary vessel masks (two sets; we use the first annotator)
  • Task: segment the full retinal vascular tree from a single fundus image

Download:

python data/drive/download.py

This prints Kaggle instructions. For a no-data demo:

python data/drive/download.py --dummy

Kaggle source: https://www.kaggle.com/datasets/andrewmvd/drive-digital-retinal-images-for-vessel-extraction

Expected directory layout after download:

data/drive/
  training/
    images/         # .tif fundus images
    1st_manual/     # .gif binary vessel masks
  test/
    images/
    1st_manual/
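
A minimal sketch of pairing each image with its mask under this layout. It assumes the file naming of the original DRIVE release (for example `21_training.tif` with `21_manual1.gif`, matched on the leading number); the Kaggle copy or the repository's own loader may name files differently, and `pair_images_and_masks` is a hypothetical helper.

```python
from pathlib import Path


def pair_images_and_masks(split_dir):
    """Return sorted (image_path, mask_path) pairs for one split directory.

    Assumes names like '21_training.tif' and '21_manual1.gif', matched on
    the text before the first underscore.
    """
    split_dir = Path(split_dir)
    masks = {p.name.split("_")[0]: p for p in (split_dir / "1st_manual").glob("*.gif")}
    pairs = []
    for image in sorted((split_dir / "images").glob("*.tif")):
        key = image.name.split("_")[0]
        if key not in masks:
            raise FileNotFoundError(f"no mask found for {image.name}")
        pairs.append((image, masks[key]))
    return pairs
```

Calling `pair_images_and_masks("data/drive/training")` on a complete download should give 20 pairs. A missing mask raises an error instead of silently shrinking the dataset.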

Results

Results from training 100 epochs on DRIVE training set (16 train / 4 val split), 512×512 input, BCE+Dice loss, with batch norm and standard augmentations.

| Model | Dice | IoU | Sensitivity | Specificity | AUC-ROC |
|---|---|---|---|---|---|
| U-Net (this impl.) | 0.628 | 0.457 | 0.740 | 0.939 | 0.939 |
| Staal et al. 2004 (ridge detector) | - | - | 0.772 | 0.977 | 0.952 |
| Ronneberger et al. 2015 | - | - | - | - | - |

Trained on 16 images, 4 held out for validation, 256×256 px, 100 epochs, BCE+Dice loss, Apple MPS.
Note: Ronneberger et al. did not evaluate on DRIVE; their results are on the ISBI EM segmentation and ISBI cell tracking challenge data, so that row is left empty. The AUC of 0.939 is close to the classical ridge-based method (Staal et al., 0.952) despite the small training set (16 images, versus the usual 20 plus augmentation). A gap is expected: the original U-Net was designed for microscopy images, not retinal fundus images, and its advantage emerges most on datasets with very few annotated examples.
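
For reference, the four threshold-based metrics in the table all come from the same pixel-level confusion counts. The sketch below uses plain Python on flattened binary masks. It is an illustration, not the repository's evaluation code, and it leaves out AUC-ROC, which needs the unthresholded probabilities.

```python
def segmentation_metrics(pred, target):
    """Dice, IoU, sensitivity and specificity for flat sequences of 0/1 labels.

    Assumes both classes occur, so no denominator is zero.
    """
    tp = fp = fn = tn = 0
    for p, t in zip(pred, target):
        if p and t:
            tp += 1
        elif p:
            fp += 1
        elif t:
            fn += 1
        else:
            tn += 1
    return {
        "dice": 2 * tp / (2 * tp + fp + fn),
        "iou": tp / (tp + fp + fn),
        "sensitivity": tp / (tp + fn),  # recall on vessel pixels
        "specificity": tn / (tn + fp),  # recall on background pixels
    }


pred = [1, 1, 0, 0, 1, 0, 0, 0]
target = [1, 0, 0, 0, 1, 1, 0, 0]
print(segmentation_metrics(pred, target))
```

When both are computed from one set of counts, IoU = Dice / (2 − Dice). Applied to the reported Dice of 0.628 this gives about 0.458, close to the reported IoU of 0.457. A small difference is plausible if the metrics are averaged per image.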


Key Implementation Notes

From-scratch convolution blocks. DoubleConv, EncoderBlock, DecoderBlock, and UNet are all implemented without torchvision.models or any pretrained backbone. Every tensor operation is explicit.

Skip connections via torch.cat. In the decoder, the upsampled feature map is concatenated along the channel dimension with the corresponding encoder output. This doubles the channel count before the DoubleConv reduces it, as in the paper, except that padded convolutions make the paper's cropping step unnecessary.
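
A minimal sketch of one decoder stage, assuming PyTorch. The class names follow this README, but the constructor signatures and layer ordering are assumptions and may differ from the repository's code.

```python
import torch
from torch import nn


class DoubleConv(nn.Sequential):
    """Two 3x3 convolutions (padding=1), each followed by optional BN and ReLU."""

    def __init__(self, in_ch, out_ch, use_batchnorm=True):
        layers = []
        for ch in (in_ch, out_ch):
            layers.append(nn.Conv2d(ch, out_ch, kernel_size=3, padding=1,
                                    bias=not use_batchnorm))
            if use_batchnorm:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        super().__init__(*layers)


class DecoderBlock(nn.Module):
    """Upsample by 2, concatenate the encoder skip, then apply DoubleConv."""

    def __init__(self, in_ch, out_ch, use_batchnorm=True):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        # After concatenation the tensor has out_ch (upsampled) + out_ch (skip) channels.
        self.conv = DoubleConv(2 * out_ch, out_ch, use_batchnorm)

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([skip, x], dim=1)  # dim=1 is the channel dimension
        return self.conv(x)


block = DecoderBlock(1024, 512)
out = block(torch.randn(1, 1024, 4, 4), torch.randn(1, 512, 8, 8))
print(out.shape)  # torch.Size([1, 512, 8, 8])
```

The order of `skip` and `x` inside `torch.cat` does not affect the output shape. It only fixes which input channels of the following convolution see which tensor.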

Sigmoid not softmax. Binary segmentation has one output channel. sigmoid maps logits to [0, 1] per pixel. Softmax is only appropriate for mutually exclusive multi-class predictions.
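
One way to see why a single sigmoid channel is enough for the binary case: a two-class softmax over the logits [z, 0] assigns the first class exactly sigmoid(z). The check below uses only the standard library and is purely illustrative.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softmax(logits):
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


for z in (-3.0, 0.0, 1.5):
    # Probability of "vessel" under a 2-class softmax with the background logit fixed at 0.
    assert math.isclose(sigmoid(z), softmax([z, 0.0])[0])
    print(z, round(sigmoid(z), 4))
```

A second output channel would therefore add parameters without adding expressive power for a two-class problem.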

BCE + Dice combined loss. BCE penalizes per-pixel confidence; Dice directly optimizes the overlap metric. For vessel segmentation (heavily imbalanced — ~12% vessel pixels), Dice loss alone can be unstable early in training. The combined loss provides stable gradients from BCE while still optimizing the evaluation metric.
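
A minimal sketch of such a combined loss, assuming PyTorch. The 50/50 weighting, the smoothing constant and the name `bce_dice_loss` are assumptions, not values read from this repository. The function also takes raw logits for numerical stability, whereas the architecture above applies the sigmoid inside the model.

```python
import torch
import torch.nn.functional as F


def bce_dice_loss(logits, target, bce_weight=0.5, smooth=1.0):
    """Weighted sum of BCE and soft Dice loss for (N, 1, H, W) tensors.

    `logits` are pre-sigmoid scores; `target` holds 0.0 / 1.0 values.
    """
    bce = F.binary_cross_entropy_with_logits(logits, target)
    probs = torch.sigmoid(logits)
    dims = (1, 2, 3)  # reduce over everything except the batch dimension
    intersection = (probs * target).sum(dim=dims)
    denominator = probs.sum(dim=dims) + target.sum(dim=dims)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return bce_weight * bce + (1.0 - bce_weight) * (1.0 - dice.mean())


target = torch.zeros(2, 1, 8, 8)
target[:, :, 2:5, 2:5] = 1.0                  # small foreground square
confident = (target * 2 - 1) * 20             # large logits with the correct sign
print(bce_dice_loss(confident, target).item())    # close to 0
print(bce_dice_loss(-confident, target).item())   # much larger
```

The `smooth` term keeps the Dice ratio defined when a mask has no foreground pixels, which can happen with crops or dummy data.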

Test-time augmentation (TTA). The demo applies horizontal flip TTA: prediction is averaged over the original and horizontally flipped input, improving boundary accuracy by ~1% Dice.
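
Horizontal-flip TTA can be sketched as follows, assuming PyTorch and a model that returns per-pixel probabilities for an (N, C, H, W) batch. `predict_with_hflip_tta` is a hypothetical name, and the demo's actual code may differ.

```python
import torch


@torch.no_grad()
def predict_with_hflip_tta(model, image):
    """Average the prediction on `image` with the un-flipped prediction on its mirror."""
    pred = model(image)
    pred_flipped = model(torch.flip(image, dims=[-1]))  # flip along the width axis
    # Flip the second prediction back so both maps are pixel-aligned before averaging.
    return 0.5 * (pred + torch.flip(pred_flipped, dims=[-1]))


# Stand-in "model" for illustration: the per-pixel channel mean.
toy_model = lambda x: x.mean(dim=1, keepdim=True)
batch = torch.rand(2, 3, 8, 8)
print(predict_with_hflip_tta(toy_model, batch).shape)  # torch.Size([2, 1, 8, 8])
```

The step that is easy to get wrong is flipping the second prediction back. Without it, the average mixes each pixel with its mirror image.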


How to Run

Install dependencies:

pip install -r requirements.txt

Download DRIVE data:

python data/drive/download.py
# or for a quick smoke test with generated data:
python data/drive/download.py --dummy

Train:

# default: uses configs/training_config.yaml
python train.py

# override specific args
python train.py --epochs 50 --lr 1e-3

# strict paper replication (no batch norm)
python train.py --no-batchnorm

# run on dummy data (no real DRIVE needed)
python train.py --dummy-data

Demo (Streamlit):

streamlit run demo/app.py

Docker:

docker build -t unet-demo .
docker run -p 7860:7860 unet-demo

Limitations and Future Work

  • No elastic deformation in the strict paper sense. Albumentations' ElasticTransform is close but not identical to the Gaussian-smoothed random displacement fields in the paper.
  • Single class only. The architecture outputs one binary mask. Extending to multi-class segmentation requires replacing the sigmoid with softmax and the final conv with Conv2d(base_channels, n_classes, 1).
  • Full-image training. The paper tiles large images with mirroring padding. We resize to 512×512, which loses some fine detail in the original 565×584 images.
  • No pretrained encoder. A common production variant initializes the encoder from ImageNet weights (e.g., ResNet-34 encoder + U-Net decoder). This is out of scope for a paper replication.
  • Potential extensions: multi-scale TTA, CRF post-processing, ensemble of multiple runs.

References

  1. Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. MICCAI 2015. https://arxiv.org/abs/1505.04597
  2. Staal, J., Abramoff, M., Niemeijer, M., Viergever, M., & van Ginneken, B. (2004). Ridge-based vessel segmentation in color images of the retina. IEEE Transactions on Medical Imaging, 23(4), 501–509.
  3. Milletari, F., Navab, N., & Ahmadi, S.-A. (2016). V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation. 3DV 2016. https://arxiv.org/abs/1606.04797 (Dice loss formulation)
