---
title: U-Net Paper Replication
emoji: 🔬
colorFrom: indigo
colorTo: cyan
sdk: docker
pinned: false
license: mit
short_description: "U-Net from scratch: retinal vessel segmentation on DRIVE"
---
A faithful reimplementation of the original U-Net architecture for biomedical image segmentation, evaluated on the DRIVE retinal vessel dataset.
U-Net (Ronneberger et al., 2015) introduced an encoder-decoder architecture with skip connections that became the dominant approach for biomedical image segmentation. Its key insight: existing fully convolutional networks at the time lost spatial detail through repeated downsampling, making it difficult to produce precise pixel-level masks. U-Net addressed this with two innovations:
- Symmetric encoder-decoder: The contracting path (encoder) captures context by repeatedly halving spatial resolution while doubling channels. The expansive path (decoder) progressively recovers full resolution.
- Skip connections: Before each downsampling step, the encoder's feature maps are concatenated directly into the corresponding decoder stage. This bypasses the information bottleneck — the decoder can recover fine-grained spatial detail (vessel edges, thin structures) that would otherwise be lost to pooling.
The original paper trained on only 20 to 35 annotated images per task (30 for the ISBI 2012 EM segmentation challenge) using heavy data augmentation, chiefly random elastic deformations. It demonstrated that with the right architecture, small biomedical datasets are sufficient. This was a significant departure from ImageNet-scale thinking.
```text
Input (3, H, W)                      H = W = 512 here (must be divisible by 16)
│
├─ EncoderBlock 1: DoubleConv(3→64)     → skip1 (64, H, W)
│   └─ MaxPool2d(2)
├─ EncoderBlock 2: DoubleConv(64→128)   → skip2 (128, H/2, W/2)
│   └─ MaxPool2d(2)
├─ EncoderBlock 3: DoubleConv(128→256)  → skip3 (256, H/4, W/4)
│   └─ MaxPool2d(2)
├─ EncoderBlock 4: DoubleConv(256→512)  → skip4 (512, H/8, W/8)
│   └─ MaxPool2d(2)
│
├─ Bottleneck: DoubleConv(512→1024)        (1024, H/16, W/16)
│
├─ DecoderBlock 4: ConvTranspose2d(1024→512) + concat(skip4) → DoubleConv(1024→512)
├─ DecoderBlock 3: ConvTranspose2d(512→256)  + concat(skip3) → DoubleConv(512→256)
├─ DecoderBlock 2: ConvTranspose2d(256→128)  + concat(skip2) → DoubleConv(256→128)
├─ DecoderBlock 1: ConvTranspose2d(128→64)   + concat(skip1) → DoubleConv(128→64)
│
└─ Final Conv: Conv2d(64→1, kernel=1) → sigmoid → Output (1, H, W)
```
Channel counts: encoder doubles 64 → 128 → 256 → 512, bottleneck at 1024, decoder halves 512 → 256 → 128 → 64.
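The shape bookkeeping in the diagram can be checked with a few lines of plain Python. This is a sketch that assumes this implementation's padded 3×3 convolutions, four 2×2 poolings, and a base width of 64 channels; `trace_padded_unet` is a hypothetical helper, not part of the repository. It also makes explicit why each input side must be divisible by 2⁴ = 16.

```python
def trace_padded_unet(height, width, in_ch=3, base_ch=64, depth=4):
    """List (stage, channels, height, width) through a U-Net with padded 3x3 convs.

    Padded convs keep H and W unchanged and each 2x2 max-pool halves them,
    so both sides must be divisible by 2**depth for the skips to line up.
    """
    factor = 2 ** depth
    if height % factor or width % factor:
        raise ValueError(f"{height}x{width} input is not divisible by {factor}")

    stages = [("input", in_ch, height, width)]
    ch, h, w = base_ch, height, width
    for level in range(1, depth + 1):
        stages.append((f"skip{level}", ch, h, w))     # encoder DoubleConv output, saved for the decoder
        h, w, ch = h // 2, w // 2, ch * 2               # MaxPool2d(2); the next level doubles channels
    stages.append(("bottleneck", ch, h, w))
    for level in range(depth, 0, -1):
        h, w, ch = h * 2, w * 2, ch // 2                # ConvTranspose2d(kernel=2, stride=2)
        stages.append((f"decoder{level}", ch, h, w))  # after concat(skip) + DoubleConv
    stages.append(("output", 1, h, w))                 # 1x1 conv + sigmoid
    return stages


for name, c, h, w in trace_padded_unet(512, 512):
    print(f"{name:>10}: ({c}, {h}, {w})")
```

Calling it with 572 raises `ValueError`, since 572 = 16 × 35 + 12: a padded U-Net with four poolings cannot take the paper's tile size without cropping or padding it first.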
| Aspect | Ronneberger et al. 2015 | This Implementation |
|---|---|---|
| Convolution padding | No padding (valid conv) — output is smaller than input | Padding=1 — output matches input size |
| Batch normalization | Not used | Optional via use_batchnorm flag |
| Data augmentation | Random elastic deformations (primary aug) | HorizontalFlip, VerticalFlip, RandomRotate90, ElasticTransform, GridDistortion |
| Input image size | 572×572 (tiled, with mirroring at borders) | 512×512 (full image, resized) |
| Loss function | Cross-entropy with class weight map | BCE + Dice (combined), configurable |
| Skip connection mechanism | Crop-and-copy (due to valid convs) | Direct concatenation (same-size feature maps) |
| Training dataset size | 20 to 35 annotated images per task (30 for ISBI 2012 EM) | 20 DRIVE training images |
| Optimizer | SGD with momentum 0.99 | Adam |
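To see why the paper needs crop-and-copy while this implementation can concatenate directly, it helps to trace the unpadded sizes. The sketch below assumes the paper's layout (two valid 3×3 convs per level, 2×2 pooling, 2×2 up-convolution, four levels); `valid_unet_sizes` is a hypothetical helper. For a 572×572 tile it reproduces the paper's 388×388 output map and the center-crop margin each skip connection needs.

```python
def valid_unet_sizes(in_size, depth=4):
    """Trace one spatial side through the original, unpadded U-Net.

    Each valid 3x3 conv trims 2 pixels (so a DoubleConv trims 4), each
    2x2 max-pool halves the size, and each 2x2 up-convolution doubles it.
    Returns (output_size, skip_sizes, upsampled_sizes, crop_margins).
    """
    size = in_size
    skips = []
    for _ in range(depth):
        size -= 4                           # DoubleConv with valid convolutions
        if size % 2:
            raise ValueError(f"odd size {size} before pooling; choose another input size")
        skips.append(size)
        size //= 2                          # MaxPool2d(2)
    size -= 4                               # bottleneck DoubleConv

    ups, margins = [], []
    for skip in reversed(skips):
        size *= 2                           # 2x2 up-convolution
        ups.append(size)
        margins.append((skip - size) // 2)  # pixels cropped from each side of the skip
        size -= 4                           # DoubleConv after concatenation
    return size, skips, ups, margins


print(*valid_unet_sizes(572))
# 388 [568, 280, 136, 64] [56, 104, 200, 392] [4, 16, 40, 88]
```

With `padding=1`, every DoubleConv preserves the spatial size, all margins become zero, and the skip tensors can be concatenated without cropping.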
DRIVE (Digital Retinal Images for Vessel Extraction):
- 40 color fundus photographs: 20 training, 20 test
- Resolution: 565×584 pixels
- Labels: manually annotated binary vessel masks (two sets; we use the first annotator)
- Task: segment the full retinal vascular tree from a single fundus image
Download:

```bash
python data/drive/download.py
```

This prints Kaggle instructions. For a no-data demo:

```bash
python data/drive/download.py --dummy
```

Kaggle source: https://www.kaggle.com/datasets/andrewmvd/drive-digital-retinal-images-for-vessel-extraction
Expected directory layout after download:

```text
data/drive/
  training/
    images/        # .tif fundus images
    1st_manual/    # .gif binary vessel masks
  test/
    images/
    1st_manual/
```
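Loading code has to match each fundus image with its annotation. A minimal sketch, assuming the standard DRIVE naming where an image and its mask share a numeric prefix (for example `21_training.tif` and `21_manual1.gif`); `pair_drive_files` is a hypothetical helper, not taken from this repository, so check the naming against the files you actually downloaded.

```python
from pathlib import Path


def pair_drive_files(image_names, mask_names):
    """Pair DRIVE images with masks by their numeric ID prefix.

    Assumes names like '21_training.tif' / '21_manual1.gif', where the text
    before the first underscore is the image ID. Accepts str or Path entries.
    """
    def image_id(name):
        return Path(name).name.split("_", 1)[0]

    masks = {image_id(m): m for m in mask_names}
    pairs = []
    for img in sorted(image_names, key=lambda n: int(image_id(n))):
        key = image_id(img)
        if key not in masks:
            raise FileNotFoundError(f"no mask found for {img}")
        pairs.append((img, masks[key]))
    return pairs


pairs = pair_drive_files(
    ["22_training.tif", "21_training.tif"],
    ["21_manual1.gif", "22_manual1.gif"],
)
print(pairs)  # [('21_training.tif', '21_manual1.gif'), ('22_training.tif', '22_manual1.gif')]
```

In practice the lists would come from something like `Path("data/drive/training/images").glob("*.tif")` and the matching `1st_manual/` directory.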
Results from training for 100 epochs on the DRIVE training set (16 train / 4 val split), 512×512 input, BCE+Dice loss, with batch norm and standard augmentations.
| Model | Dice | IoU | Sensitivity | Specificity | AUC-ROC |
|---|---|---|---|---|---|
| U-Net (this impl.) | 0.628 | 0.457 | 0.740 | 0.939 | 0.939 |
| Staal et al. 2004 (ridge detector) | — | — | 0.772 | 0.977 | 0.952 |
| Ronneberger et al. 2015 | — | — | — | — | — |
Trained on 16 images with 4 held out for validation, 256×256 px, 100 epochs, BCE+Dice loss, Apple MPS.
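Note: Ronneberger et al. did not evaluate on DRIVE. Their experiments used the ISBI 2012 EM segmentation challenge and the ISBI cell tracking challenge datasets (PhC-U373 and DIC-HeLa), so there are no directly comparable numbers for the original U-Net. An AUC of 0.939 comes close to the classical ridge-based method of Staal et al. (0.952) despite the small training set (16 images, whereas typical DRIVE setups train on all 20 with augmentation). Some gap is expected: the original U-Net was designed for electron microscopy, not retinal fundus images, and its advantage emerges most on datasets with very few annotated examples.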
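For reference, the threshold-based metrics in the table come from the pixel-level confusion matrix, while AUC-ROC is computed from the raw probabilities. The NumPy sketch below shows one common way to compute them. The function name, the 0.5 threshold, and the rank-based AUC (Mann-Whitney U identity with averaged ranks for ties) are assumptions, not necessarily what this repository's evaluation code does. DRIVE evaluations often score only pixels inside the field-of-view mask; here that filtering is left to the caller.

```python
import numpy as np


def segmentation_metrics(prob, target, threshold=0.5):
    """Pixel-level Dice, IoU, sensitivity, specificity and AUC-ROC.

    prob: predicted vessel probabilities; target: binary ground truth of the same size.
    Inputs are flattened, so pass only the pixels you want scored (e.g. inside the FOV).
    """
    prob = np.asarray(prob, dtype=float).ravel()
    target = np.asarray(target).ravel().astype(bool)
    pred = prob >= threshold

    tp = np.sum(pred & target)
    fp = np.sum(pred & ~target)
    fn = np.sum(~pred & target)
    tn = np.sum(~pred & ~target)
    metrics = {
        "dice": 2 * tp / (2 * tp + fp + fn),
        "iou": tp / (tp + fp + fn),
        "sensitivity": tp / (tp + fn),   # fraction of vessel pixels recovered
        "specificity": tn / (tn + fp),   # fraction of background pixels kept as background
    }

    # Threshold-free AUC-ROC via the rank-sum (Mann-Whitney U) identity;
    # tied scores receive the average of the 1-based ranks they span.
    _, inverse, counts = np.unique(prob, return_inverse=True, return_counts=True)
    ends = np.cumsum(counts)
    avg_rank = (ends - counts + 1 + ends) / 2.0
    ranks = avg_rank[inverse.ravel()]
    n_pos, n_neg = target.sum(), (~target).sum()
    metrics["auc_roc"] = (ranks[target].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
    return metrics


print(segmentation_metrics([0.9, 0.8, 0.3, 0.2], [1, 0, 1, 0]))
# dice 0.5, iou 1/3, sensitivity 0.5, specificity 0.5, auc_roc 0.75
```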
From-scratch convolution blocks. DoubleConv, EncoderBlock, DecoderBlock, and UNet are all implemented without torchvision.models or any pretrained backbone. Every tensor operation is explicit.
Skip connections via torch.cat. In the decoder, the upsampled feature map is concatenated along the channel dimension with the corresponding encoder output. This doubles the channel count before the DoubleConv reduces it — exactly as in the paper.
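The channel arithmetic of one decoder stage can be reproduced with NumPy alone. This sketch assumes the paper's 2×2 up-convolution with stride 2 (what `ConvTranspose2d(kernel_size=2, stride=2)` computes, with weights laid out as `(in_channels, out_channels, 2, 2)` as in PyTorch) and an NCHW layout; `upconv2x2` is a hypothetical helper. Because the stride equals the kernel size, each input pixel writes its own non-overlapping 2×2 output block, so the whole operation is one `einsum` plus a reshape.

```python
import numpy as np


def upconv2x2(x, weight, bias):
    """Transposed conv with kernel 2, stride 2, no padding, NCHW layout.

    x: (N, C_in, H, W); weight: (C_in, C_out, 2, 2); bias: (C_out,).
    Returns (N, C_out, 2H, 2W).
    """
    n, _, h, w = x.shape
    c_out = weight.shape[1]
    # out[n, o, 2i + a, 2j + b] = sum_c x[n, c, i, j] * weight[c, o, a, b]
    blocks = np.einsum("ncij,coab->noiajb", x, weight)
    return blocks.reshape(n, c_out, 2 * h, 2 * w) + bias[None, :, None, None]


rng = np.random.default_rng(0)
bottleneck = rng.standard_normal((1, 1024, 4, 4))   # (N, 1024, H/16, W/16) for a tiny H = W = 64
skip4 = rng.standard_normal((1, 512, 8, 8))         # encoder output at H/8
w = rng.standard_normal((1024, 512, 2, 2)) * 0.01
b = np.zeros(512)

up = upconv2x2(bottleneck, w, b)                    # (1, 512, 8, 8)
merged = np.concatenate([up, skip4], axis=1)        # 512 + 512 = 1024 channels; order is a convention
print(up.shape, merged.shape)
```

The following `DoubleConv(1024→512)` then brings the channel count back down, matching DecoderBlock 4 in the diagram.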
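Sigmoid, not softmax. Binary segmentation needs only one output channel, and `sigmoid` maps each pixel's logit to a probability in (0, 1). Softmax is required when there are more than two mutually exclusive classes; for two classes, a two-channel softmax (the paper's formulation) is mathematically equivalent to a single sigmoid channel, so the extra channel is redundant.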
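The equivalence is easy to check numerically: a softmax over the logits (0, z) gives exactly sigmoid(z) for the second class. A small NumPy sketch with illustrative function names:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(logits, axis=0):
    shifted = logits - logits.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


z = np.linspace(-6, 6, 13)                        # one-channel logits for 13 pixels
two_channel = np.stack([np.zeros_like(z), z])     # background logit fixed at 0
assert np.allclose(softmax(two_channel, axis=0)[1], sigmoid(z))
```

Algebraically, e^z / (1 + e^z) = 1 / (1 + e^(−z)), so any two-channel model can be rewritten as a one-channel model whose logit is the difference of the two channel logits.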
BCE + Dice combined loss. BCE scores each pixel independently, while Dice directly optimizes the overlap metric used for evaluation. Vessel segmentation is heavily imbalanced (roughly 12% vessel pixels), and Dice loss alone can be unstable early in training. The combined loss provides stable gradients from BCE while still optimizing the evaluation metric.
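A NumPy sketch of the combined objective, assuming equal weighting (`bce_weight=0.5`) and a smoothing constant `eps=1.0` in the soft Dice term; both values and the function names are assumptions, not what `configs/training_config.yaml` necessarily sets. BCE is computed from logits in the numerically stable form max(z, 0) − z·t + log(1 + e^(−|z|)), and Dice uses soft (probability-valued) overlap so the loss stays differentiable.

```python
import numpy as np


def bce_with_logits(logits, target):
    """Mean binary cross-entropy, computed stably from raw logits."""
    return np.mean(np.maximum(logits, 0) - logits * target + np.log1p(np.exp(-np.abs(logits))))


def soft_dice_loss(logits, target, eps=1.0):
    """1 - soft Dice coefficient, using sigmoid probabilities instead of hard masks."""
    prob = 1.0 / (1.0 + np.exp(-logits))
    intersection = np.sum(prob * target)
    return 1.0 - (2.0 * intersection + eps) / (np.sum(prob) + np.sum(target) + eps)


def bce_dice_loss(logits, target, bce_weight=0.5):
    return bce_weight * bce_with_logits(logits, target) + (1 - bce_weight) * soft_dice_loss(logits, target)


rng = np.random.default_rng(0)
target = (rng.random((64, 64)) < 0.12).astype(float)  # ~12% foreground, like DRIVE vessels
good = np.where(target == 1, 4.0, -4.0)               # confident, correct logits
bad = -good                                           # confident, wrong logits
print(bce_dice_loss(good, target), bce_dice_loss(bad, target))
```

The BCE term alone would be small for a model that predicts background everywhere on an imbalanced mask; the Dice term keeps such a collapse expensive, since it depends only on overlap with the foreground.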
Test-time augmentation (TTA). The demo applies horizontal flip TTA: prediction is averaged over the original and horizontally flipped input, improving boundary accuracy by ~1% Dice.
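Horizontal-flip TTA amounts to: predict on the image, predict on its mirror image, flip the second prediction back, and average. A minimal NumPy sketch, where `predict` stands in for the model's forward pass; it is a hypothetical callable mapping an `(H, W, C)` image to an `(H, W)` probability map, not the demo's actual API.

```python
import numpy as np


def predict_with_hflip_tta(predict, image):
    """Average predictions over the original and horizontally flipped image.

    predict: callable taking an (H, W, C) array and returning (H, W) probabilities.
    """
    p_orig = predict(image)
    p_flip = predict(image[:, ::-1])          # flip along the width axis
    return 0.5 * (p_orig + p_flip[:, ::-1])   # un-flip before averaging so pixels line up


# Toy "model": green-channel intensity plus a left-to-right ramp,
# so it is NOT flip-equivariant and TTA visibly changes its output.
def toy_predict(img):
    _, w, _ = img.shape
    return np.clip(img[..., 1] + np.linspace(0.0, 0.2, w)[None, :], 0.0, 1.0)


image = np.random.default_rng(0).random((4, 6, 3)) * 0.5
tta = predict_with_hflip_tta(toy_predict, image)
print(tta.shape)
```

Here the position-dependent ramp averages out to a constant 0.1, which illustrates how TTA cancels asymmetric errors. With PyTorch tensors, negative-step slicing is not supported; `torch.flip(x, dims=[-1])` performs the same width flip on an NCHW batch.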
Install dependencies:

```bash
pip install -r requirements.txt
```

Download DRIVE data:

```bash
python data/drive/download.py
# or for a quick smoke test with generated data:
python data/drive/download.py --dummy
```

Train:

```bash
# default: uses configs/training_config.yaml
python train.py
# override specific args
python train.py --epochs 50 --lr 1e-3
# closer to the paper: disable batch norm
python train.py --no-batchnorm
# run on dummy data (no real DRIVE needed)
python train.py --dummy-data
```

Demo (Streamlit):

```bash
streamlit run demo/app.py
```

Docker:
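```bash
docker build -t unet-demo .
docker run -p 7860:7860 unet-demo
```

Known limitations:

- No elastic deformation in the strict paper sense. The paper samples random displacement vectors on a coarse 3×3 grid from a Gaussian with a 10-pixel standard deviation and interpolates them bicubically into a smooth per-pixel field. Albumentations' `ElasticTransform` instead Gaussian-smooths a dense per-pixel random displacement field, which is close in spirit but not identical.
- Single class only. The architecture outputs one binary mask. Extending to multi-class segmentation requires replacing the sigmoid with softmax, the final conv with `Conv2d(base_channels, n_classes, 1)`, and the BCE term with cross-entropy.
- Full-image training. The paper tiles large images with mirroring padding (its overlap-tile strategy). We resize to 512×512, which loses some fine detail in the original 565×584 images and slightly distorts their aspect ratio.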
- No pretrained encoder. A common production variant initializes the encoder from ImageNet weights (e.g., ResNet-34 encoder + U-Net decoder). This is out of scope for a paper replication.
- Potential extensions: multi-scale TTA, CRF post-processing, ensemble of multiple runs.
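For comparison with Albumentations, here is a sketch of the paper-style deformation: Gaussian displacement vectors (σ = 10 px) on a coarse 3×3 grid, upsampled to a dense per-pixel field and applied identically to image and mask. To stay NumPy-only it swaps in bilinear upsampling for the paper's bicubic interpolation, uses nearest-neighbor sampling (which keeps the mask binary), and clamps out-of-range locations to the border; these choices and the function names are assumptions, and the code is not part of the repository.

```python
import numpy as np


def upsample_bilinear(grid, height, width):
    """Bilinearly resize a small 2-D grid to (height, width), corners aligned."""
    gh, gw = grid.shape
    xs = np.linspace(0, gw - 1, width)
    ys = np.linspace(0, gh - 1, height)
    rows = np.stack([np.interp(xs, np.arange(gw), grid[r]) for r in range(gh)])  # (gh, width)
    return np.stack([np.interp(ys, np.arange(gh), rows[:, c]) for c in range(width)], axis=1)


def coarse_grid_elastic(image, mask, sigma=10.0, grid=3, rng=None):
    """Warp image and mask with one smooth random displacement field.

    Displacements ~ N(0, sigma^2) on a grid x grid lattice, upsampled to every pixel.
    Bilinear upsampling and nearest-neighbor sampling simplify the paper's bicubic scheme.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = mask.shape
    dy = upsample_bilinear(rng.normal(0.0, sigma, (grid, grid)), h, w)
    dx = upsample_bilinear(rng.normal(0.0, sigma, (grid, grid)), h, w)
    yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    # Source location for each output pixel, rounded and clamped to the image border.
    src_y = np.clip(np.rint(yy + dy), 0, h - 1).astype(int)
    src_x = np.clip(np.rint(xx + dx), 0, w - 1).astype(int)
    return image[src_y, src_x], mask[src_y, src_x]


rng = np.random.default_rng(0)
img = rng.random((64, 64, 3))
msk = (rng.random((64, 64)) < 0.12).astype(np.uint8)
warped_img, warped_msk = coarse_grid_elastic(img, msk, rng=rng)
print(warped_img.shape, warped_msk.shape, np.unique(warped_msk))
```

Because only nine random vectors drive the whole field, the resulting warps are smooth and global, whereas a Gaussian-smoothed per-pixel field can produce more local distortions depending on its `sigma` and `alpha` settings.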
- Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. MICCAI 2015. https://arxiv.org/abs/1505.04597
- Staal, J., Abramoff, M., Niemeijer, M., Viergever, M., & van Ginneken, B. (2004). Ridge-based vessel segmentation in color images of the retina. IEEE Transactions on Medical Imaging, 23(4), 501–509.
- Milletari, F., Navab, N., & Ahmadi, S.-A. (2016). V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation. 3DV 2016. https://arxiv.org/abs/1606.04797 (Dice loss formulation)