diff --git a/pytorch/TIPS_decoder_inference.ipynb b/pytorch/TIPS_decoder_inference.ipynb new file mode 100644 index 0000000..b2cab29 --- /dev/null +++ b/pytorch/TIPS_decoder_inference.ipynb @@ -0,0 +1,468 @@ +{ + "cells": [ + { + "id": "94b870ed", + "cell_type": "code", + "source": [ + "# Copyright 2025 Google LLC.\n", + "#\n", + "# SPDX-License-Identifier: Apache-2.0" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "05dc14da", + "cell_type": "markdown", + "source": [ + "# TIPSv2: Segmentation, Depth \u0026 Surface Normals Inference\n", + "\n", + "This notebook demonstrates how to run **dense prediction** inference using\n", + "the TIPSv2 DPT decoders:\n", + "\n", + "1. **Semantic Segmentation** (ADE20K, 150 classes)\n", + "2. **Monocular Depth Estimation** (classification-based, 256 bins)\n", + "3. **Surface Normal Estimation**\n", + "\n", + "The DPT decoder heads take intermediate ViT features from the TIPS vision\n", + "encoder and produce pixel-level predictions.\n", + "\n", + "**Requirements:** GPU runtime recommended for faster inference." + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "3dcd30e6", + "cell_type": "markdown", + "source": [ + "## Setup" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "16cac140", + "cell_type": "code", + "source": [ + "# @title Install dependencies and clone TIPS repo.\n", + "import os\n", + "import sys\n", + "\n", + "ROOT_DIR = os.getcwd()\n", + "TIPS_DIR = os.path.join(ROOT_DIR, 'tips')\n", + "\n", + "# Install required packages.\n", + "!pip install -q torch torchvision torchaudio\n", + "!pip install -q tensorflow_text scikit-learn\n", + "\n", + "# Clone the TIPS repository.\n", + "if not os.path.exists(TIPS_DIR):\n", + " !git clone https://github.com/google-deepmind/tips.git {TIPS_DIR}\n", + "\n", + "# Add the root directory to PYTHONPATH so that `tips.*` imports work.\n", + "if ROOT_DIR not in sys.path:\n", + " sys.path.insert(0, ROOT_DIR)\n", + "\n", + "print(f'ROOT_DIR: {ROOT_DIR}')\n", + "print(f'TIPS_DIR: {TIPS_DIR}')\n", + "print('Installation complete!')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "985f995d", + "cell_type": "code", + "source": [ + "# @title Download checkpoints and sample images.\n", + "import urllib.request\n", + "import zipfile\n", + "\n", + "variant = 'L' # @param [\"B\", \"L\", \"So\", \"g\"]\n", + "\n", + "VISION_CKPT_URL = 'https://storage.googleapis.com/tips_data/v2_0/checkpoints/pytorch'\n", + "DPT_CKPT_URL = 'https://storage.googleapis.com/tips_data/v2_0/checkpoints/scenic'\n", + "NYU_URL = 'http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part6.zip'\n", + "ADE20K_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'\n", + "\n", + "CKPT_DIR = os.path.join(ROOT_DIR, 'checkpoints')\n", + "os.makedirs(CKPT_DIR, exist_ok=True)\n", + "\n", + "# Checkpoint naming maps.\n", + "V2_CKPT_BASENAME_MAP = {\n", + " 'B': 'tips_v2_oss_b14', 'L': 'tips_v2_oss_l14',\n", + " 'So': 'tips_v2_oss_so14', 'g': 'tips_v2_oss_g14',\n", + "}\n", + "V2_DPT_BASENAME_MAP = {\n", + " 'B': 'tips_v2_b14', 'L': 'tips_v2_l14',\n", + " 'So': 'tips_v2_so400m14', 'g': 'tips_v2_g14',\n", + "}\n", + "ckpt_basename = V2_CKPT_BASENAME_MAP[variant]\n", + "dpt_basename = V2_DPT_BASENAME_MAP[variant]\n", + "\n", + "# Download vision encoder checkpoint.\n", + "vision_ckpt_name = f'{ckpt_basename}_vision.npz'\n", + "image_encoder_checkpoint = os.path.join(CKPT_DIR, vision_ckpt_name)\n", + "if not os.path.exists(image_encoder_checkpoint):\n", + " 
print(f'Downloading vision encoder...')\n", + " urllib.request.urlretrieve(f'{VISION_CKPT_URL}/{vision_ckpt_name}', image_encoder_checkpoint)\n", + "\n", + "# Download DPT checkpoints (Segmentation, Depth, Normals).\n", + "# These use Scenic-format checkpoints (Flax .npy arrays in a zip).\n", + "dpt_tasks = ['segmentation', 'depth', 'normals']\n", + "dpt_checkpoint_paths = {}\n", + "for task in dpt_tasks:\n", + " dpt_zip_name = f'{dpt_basename}_{task}_dpt.zip'\n", + " dpt_zip_path = os.path.join(CKPT_DIR, dpt_zip_name)\n", + " if not os.path.exists(dpt_zip_path):\n", + " print(f'Downloading DPT {task} checkpoint...')\n", + " try:\n", + " urllib.request.urlretrieve(f'{DPT_CKPT_URL}/{dpt_zip_name}', dpt_zip_path)\n", + " except Exception as e:\n", + " print(f' Failed: {e}')\n", + " dpt_checkpoint_paths[task] = dpt_zip_path\n", + "\n", + "# Download NYU depth dataset (for depth \u0026 normals demo).\n", + "NYU_IMG_DIR = os.path.join(ROOT_DIR, 'nyu_images')\n", + "if not os.path.isdir(NYU_IMG_DIR):\n", + " print('Downloading NYU dataset...')\n", + " nyu_tmp = os.path.join(ROOT_DIR, 'bedrooms_part6.zip')\n", + " urllib.request.urlretrieve(NYU_URL, nyu_tmp)\n", + " os.makedirs(NYU_IMG_DIR, exist_ok=True)\n", + " with zipfile.ZipFile(nyu_tmp, 'r') as z:\n", + " z.extractall(NYU_IMG_DIR)\n", + " os.remove(nyu_tmp)\n", + "\n", + "# Download ADE20K dataset (for segmentation demo).\n", + "ADE20K_DIR = os.path.join(ROOT_DIR, 'ADEChallengeData2016')\n", + "if not os.path.isdir(ADE20K_DIR):\n", + " print('Downloading ADE20K dataset...')\n", + " ade_tmp = os.path.join(ROOT_DIR, 'ADEChallengeData2016.zip')\n", + " urllib.request.urlretrieve(ADE20K_URL, ade_tmp)\n", + " with zipfile.ZipFile(ade_tmp, 'r') as z:\n", + " z.extractall(ROOT_DIR)\n", + " os.remove(ade_tmp)\n", + "\n", + "print('All downloads complete!')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "d1d3260a", + "cell_type": "markdown", + "source": [ + "## Load Models" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "1462be26", + "cell_type": "code", + "source": [ + "# @title Load the TIPS vision encoder and all three DPT decoders.\n", + "import numpy as np\n", + "import torch\n", + "from tips.pytorch import image_encoder\n", + "from tips.pytorch.decoders import (\n", + " Decoder, SegmentationDecoder, DepthDecoder, NormalsDecoder,\n", + " load_decoder_weights,\n", + ")\n", + "\n", + "image_size = 448 # @param {type: \"number\"}\n", + "PATCH_SIZE = 14\n", + "\n", + "# Model configs per variant.\n", + "MODEL_CONSTRUCTOR_MAP = {\n", + " 'B': 'vit_base', 'L': 'vit_large', 'So': 'vit_so400m', 'g': 'vit_giant2',\n", + "}\n", + "EMBED_DIM_MAP = {'B': 768, 'L': 1024, 'So': 1152, 'g': 1536}\n", + "INTERMEDIATE_LAYERS_MAP = {\n", + " 'B': [2, 5, 8, 11], 'L': [5, 11, 17, 23],\n", + " 'So': [6, 13, 20, 26], 'g': [9, 19, 29, 39],\n", + "}\n", + "\n", + "vit_constructor = getattr(image_encoder, MODEL_CONSTRUCTOR_MAP[variant])\n", + "embed_dim = EMBED_DIM_MAP[variant]\n", + "intermediate_layers = INTERMEDIATE_LAYERS_MAP[variant]\n", + "post_process_channels = (embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim)\n", + "ffn_layer = 'swiglu' if variant == 'g' else 'mlp'\n", + "\n", + "# --- Vision Encoder ---\n", + "weights_image = dict(np.load(image_encoder_checkpoint, allow_pickle=False))\n", + "for key in weights_image:\n", + " weights_image[key] = torch.tensor(weights_image[key])\n", + "\n", + "with torch.no_grad():\n", + " model_image = vit_constructor(\n", + " img_size=image_size, patch_size=PATCH_SIZE, 
ffn_layer=ffn_layer,\n", + " block_chunks=0, init_values=1.0,\n", + " interpolate_antialias=True, interpolate_offset=0.0,\n", + " )\n", + " model_image.load_state_dict(weights_image)\n", + " model_image.eval()\n", + "print(f'✓ Vision encoder loaded ({variant})')\n", + "\n", + "# --- Segmentation Decoder ---\n", + "with torch.no_grad():\n", + " seg_model = SegmentationDecoder(\n", + " num_classes=150, input_embed_dim=embed_dim,\n", + " post_process_channels=post_process_channels,\n", + " )\n", + " load_decoder_weights(seg_model, dpt_checkpoint_paths['segmentation'])\n", + " seg_model.eval()\n", + "\n", + "# --- Depth Decoder ---\n", + "with torch.no_grad():\n", + " depth_model = DepthDecoder(\n", + " input_embed_dim=embed_dim,\n", + " post_process_channels=post_process_channels,\n", + " )\n", + " load_decoder_weights(depth_model, dpt_checkpoint_paths['depth'])\n", + " depth_model.eval()\n", + "\n", + "# --- Normals Decoder ---\n", + "with torch.no_grad():\n", + " normals_model = NormalsDecoder(\n", + " input_embed_dim=embed_dim,\n", + " post_process_channels=post_process_channels,\n", + " )\n", + " load_decoder_weights(normals_model, dpt_checkpoint_paths['normals'])\n", + " normals_model.eval()\n", + "\n", + "print('✓ All decoders loaded')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "06b04c61", + "cell_type": "code", + "source": [ + "# @title Define helper: extract ViT features.\n", + "import torchvision.transforms as TVT\n", + "import PIL.Image\n", + "\n", + "transform = TVT.Compose([TVT.Resize((image_size, image_size)), TVT.ToTensor()])\n", + "\n", + "def extract_features(img_path):\n", + " \"\"\"Load image and extract intermediate ViT features.\"\"\"\n", + " img = PIL.Image.open(img_path).convert(\"RGB\")\n", + " tensor = transform(img).unsqueeze(0)\n", + " device = next(model_image.parameters()).device\n", + " tensor = tensor.to(device)\n", + " with torch.no_grad():\n", + " features = model_image.get_intermediate_layers(\n", + " tensor, n=intermediate_layers, reshape=True,\n", + " return_class_token=True, norm=True,\n", + " )\n", + " # Reorder: (feat, cls) -\u003e (cls, feat)\n", + " features = [(cls, feat) for feat, cls in features]\n", + " return img, features" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "a4f824d8", + "cell_type": "code", + "source": [ + "# @title Define ADE20K class names and color palette.\n", + "import colorsys\n", + "\n", + "ADE20K_CLASSES = (\n", + " 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',\n", + " 'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person',\n", + " 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair',\n", + " 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea',\n", + " 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk',\n", + " 'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion',\n", + " 'base', 'box', 'column', 'signboard', 'chest of drawers',\n", + " 'counter', 'sand', 'sink', 'skyscraper', 'fireplace',\n", + " 'refrigerator', 'grandstand', 'path', 'stairs', 'runway',\n", + " 'case', 'pool table', 'pillow', 'screen door', 'stairway',\n", + " 'river', 'bridge', 'bookcase', 'blind', 'coffee table',\n", + " 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop',\n", + " 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair',\n", + " 'boat', 'bar', 'arcade machine', 'hovel', 'bus',\n", + " 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning',\n", + " 'streetlight', 'booth', 'television', 'airplane', 'dirt track',\n", + " 
'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman',\n", + " 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',\n", + " 'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool',\n", + " 'stool', 'barrel', 'basket', 'waterfall', 'tent', 'bag',\n", + " 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',\n", + " 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',\n", + " 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',\n", + " 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier',\n", + " 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower',\n", + " 'radiator', 'glass', 'clock', 'flag',\n", + ")\n", + "\n", + "def _generate_ade20k_palette(n=150):\n", + " palette = np.zeros((n, 3), dtype=np.uint8)\n", + " for i in range(n):\n", + " hue = i / n\n", + " saturation = 0.7 + 0.3 * ((i * 7) % 10) / 10\n", + " value = 0.6 + 0.4 * ((i * 3) % 10) / 10\n", + " r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)\n", + " palette[i] = [int(r * 255), int(g * 255), int(b * 255)]\n", + " return palette\n", + "\n", + "ADE20K_PALETTE = _generate_ade20k_palette()\n", + "print(f'Defined {len(ADE20K_CLASSES)} ADE20K classes')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "9d2e893b", + "cell_type": "markdown", + "source": [ + "## Run Inference\n", + "\n", + "### Semantic Segmentation (ADE20K)" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "51628b05", + "cell_type": "code", + "source": [ + "# @title Run segmentation on an ADE20K sample image.\n", + "import matplotlib.pyplot as plt\n", + "\n", + "ade_img_dir = os.path.join(ADE20K_DIR, 'images', 'validation')\n", + "ade_images = sorted([\n", + " os.path.join(ade_img_dir, f)\n", + " for f in os.listdir(ade_img_dir) if f.endswith('.jpg')\n", + "])\n", + "\n", + "image_path = ade_images[1] # A castle scene\n", + "print(f'Image: {image_path}')\n", + "\n", + "img, features = extract_features(image_path)\n", + "\n", + "with torch.no_grad():\n", + " seg_logits = seg_model(features, image_size=(image_size, image_size))\n", + " seg_map = seg_logits.argmax(dim=1).squeeze(0).cpu().numpy()\n", + "\n", + "colored_seg = ADE20K_PALETTE[seg_map]\n", + "\n", + "# Print top classes found.\n", + "unique_classes, counts = np.unique(seg_map, return_counts=True)\n", + "top_idx = np.argsort(-counts)[:5]\n", + "print('Top classes:')\n", + "for idx in top_idx:\n", + " cls_id = unique_classes[idx]\n", + " pct = 100 * counts[idx] / seg_map.size\n", + " print(f' {ADE20K_CLASSES[cls_id]:20s} ({pct:.1f}%)')\n", + "\n", + "plt.figure(figsize=(12, 5))\n", + "plt.subplot(1, 2, 1)\n", + "plt.imshow(img.resize((image_size, image_size)))\n", + "plt.title('Input Image')\n", + "plt.axis('off')\n", + "plt.subplot(1, 2, 2)\n", + "plt.imshow(colored_seg)\n", + "plt.title('Semantic Segmentation')\n", + "plt.axis('off')\n", + "plt.tight_layout()\n", + "plt.show()" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "d296d2e5", + "cell_type": "markdown", + "source": [ + "### Depth Estimation \u0026 Surface Normals (NYU)" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "7796c341", + "cell_type": "code", + "source": [ + "# @title Run depth and normals inference on NYU images.\n", + "import torch.nn.functional as F\n", + "\n", + "# Collect NYU sample images.\n", + "valid_extensions = ('.ppm', '.jpg', '.jpeg', '.png')\n", + "nyu_images = []\n", + "for root, dirs, files in os.walk(NYU_IMG_DIR):\n", + " for file in files:\n", + " if 
file.lower().endswith(valid_extensions):\n", + " nyu_images.append(os.path.join(root, file))\n", + "\n", + "selected_images = nyu_images[:3]\n", + "\n", + "for i, image_path in enumerate(selected_images):\n", + " print(f'Processing image {i+1}/{len(selected_images)}: {os.path.basename(image_path)}')\n", + " img, features = extract_features(image_path)\n", + "\n", + " with torch.no_grad():\n", + " # --- Depth ---\n", + " depth_map = depth_model(features, image_size=(image_size, image_size))\n", + " depth_np = depth_map.squeeze().cpu().numpy()\n", + " # Normalize for visualization.\n", + " depth_np = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min() + 1e-8)\n", + "\n", + " # --- Normals ---\n", + " # Get raw low-res output first.\n", + " normals_map = normals_model(features)\n", + " # L2 normalize.\n", + " normals_map = F.normalize(normals_map, dim=1)\n", + " # Upsample with bicubic for smooth results.\n", + " normals_map = F.interpolate(\n", + " normals_map, size=(image_size, image_size),\n", + " mode='bicubic', align_corners=False,\n", + " )\n", + " # Re-normalize after upsampling.\n", + " normals_map = F.normalize(normals_map, dim=1)\n", + " normals_np = normals_map.squeeze(0).cpu().numpy().transpose(1, 2, 0)\n", + " # Map [-1, 1] -\u003e [0, 1] for display.\n", + " normals_np = np.clip((normals_np + 1.0) / 2.0, 0.0, 1.0)\n", + "\n", + " # Visualize.\n", + " plt.figure(figsize=(15, 5))\n", + " plt.subplot(1, 3, 1)\n", + " plt.imshow(img.resize((image_size, image_size)))\n", + " plt.title(f'Input ({i+1})')\n", + " plt.axis('off')\n", + "\n", + " plt.subplot(1, 3, 2)\n", + " plt.imshow(depth_np, cmap='turbo')\n", + " plt.title(f'Depth ({i+1})')\n", + " plt.axis('off')\n", + "\n", + " plt.subplot(1, 3, 3)\n", + " plt.imshow(normals_np)\n", + " plt.title(f'Surface Normals ({i+1})')\n", + " plt.axis('off')\n", + "\n", + " plt.tight_layout()\n", + " plt.show()" + ], + "metadata": {}, + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat_minor": 5, + "nbformat": 4 +} diff --git a/pytorch/decoders.py b/pytorch/decoders.py index 4bbc1b1..72e6595 100644 --- a/pytorch/decoders.py +++ b/pytorch/decoders.py @@ -136,7 +136,8 @@ def forward( x_flat = x.flatten(2).transpose(1, 2) readout = cls_token.unsqueeze(1).expand(-1, x_flat.shape[1], -1) x_cat = torch.cat([x_flat, readout], dim=-1) - x_proj = F.gelu(self.readout_projects[i](x_cat)) + # JAX GELU uses tanh approximation by default. + x_proj = F.gelu(self.readout_projects[i](x_cat), approximate='tanh') x = x_proj.transpose(1, 2).reshape(b, d, h, w) x = self.out_projections[i](x) x = self.resize_layers[i](x) @@ -164,8 +165,10 @@ def __init__( channels: int = 256, post_process_channels: Tuple[int, ...] = (128, 256, 512, 1024), readout_type: str = "project", + output_activation: bool = False, ) -> None: super().__init__() + self.output_activation = output_activation self.reassemble = ReassembleBlocks( input_embed_dim=input_embed_dim, out_channels=post_process_channels, @@ -196,6 +199,10 @@ def forward( out = self.fusion_blocks[i](out, residual=x[-(i + 1)]) out = self.project(out) + # NOTE: By default, the reference implementation does not apply output activation, + # so NO ReLU is applied after the project layer by default. + if self.output_activation: + out = F.relu(out) return out @@ -217,6 +224,7 @@ def __init__( channels: int = 256, post_process_channels: Tuple[int, ...] 
= (128, 256, 512, 1024), readout_type: str = "project", + output_activation: bool = False, ) -> None: super().__init__() self.channels = channels @@ -226,6 +234,7 @@ def __init__( channels=channels, post_process_channels=post_process_channels, readout_type=readout_type, + output_activation=output_activation, ) # Common head for all dense prediction tasks self.head = nn.Linear(self.channels, self.out_channels) @@ -264,36 +273,50 @@ def __init__(self, num_classes: int = 150, **kwargs) -> None: class DepthDecoder(Decoder): - """Decoder for monocular depth prediction using classification bins.""" + """Decoder for monocular depth prediction using classification bins. - def __init__(self, min_depth: float = 0.001, max_depth: float = 10.0, **kwargs) -> None: - # Decoder requires out_channels, we pass 256 as we use channels as bins, - # although we bypass the head in forward(). - super().__init__(out_channels=256, **kwargs) + Predicts depth by classifying each pixel into uniformly-spaced depth bins + and computing the expected depth value. + """ + + def __init__( + self, + num_depth_bins: int = 256, + min_depth: float = 0.001, + max_depth: float = 10.0, + **kwargs, + ) -> None: + super().__init__(out_channels=num_depth_bins, **kwargs) self.min_depth = min_depth self.max_depth = max_depth - + self.num_depth_bins = num_depth_bins + self.register_buffer( + "bin_centers", + torch.linspace(min_depth, max_depth, num_depth_bins), + ) def forward( self, intermediate_features: List[Tuple[torch.Tensor, torch.Tensor]], image_size: Optional[Tuple[int, int]] = None, ) -> torch.Tensor: - # Bypass super().forward() to avoid the linear head applied there, - # and use raw DPT features as logits. - logits = self.dpt(intermediate_features) # (B, C, H', W') - # Apply ReLU and shift + # 1. Get DPT features + task head (nn.Linear) via parent class. + # Output shape: (B, num_depth_bins, H', W') + logits = super().forward(intermediate_features) + + # 2. Classification-based depth prediction: + # relu + shift -> linear normalisation -> expectation over bins. logits = torch.relu(logits) + self.min_depth - # Normalize to probabilities along the channel dimension probs = logits / torch.sum(logits, dim=1, keepdim=True) - # Compute expectation: sum(prob * bin_center) - bin_centers = torch.linspace( - self.min_depth, self.max_depth, 256, device=logits.device, dtype=logits.dtype - ) - depth_map = torch.einsum("bchw,c->bhw", probs, bin_centers) + depth_map = torch.einsum("bchw,c->bhw", probs, self.bin_centers.to(logits.device)) + + # 3. Upsample to target resolution. if image_size is not None: depth_map = F.interpolate( - depth_map.unsqueeze(1), size=image_size, mode="bilinear", align_corners=False + depth_map.unsqueeze(1), + size=image_size, + mode="bilinear", + align_corners=False, ).squeeze(1) return depth_map.unsqueeze(1) @@ -314,37 +337,226 @@ def __init__(self, **kwargs) -> None: "convs.": "dpt.convs.", "fusion_blocks.": "dpt.fusion_blocks.", "project.": "dpt.project.", + # Task-specific head keys (Scenic Dense -> PyTorch head.*) "segmentation_head.": "head.", + "pixel_segmentation.": "head.", + "pixel_depth_classif.": "head.", + "pixel_depth_regress.": "head.", + "pixel_normals.": "head.", +} + +# Scenic/Flax head param names -> PyTorch head.* +_SCENIC_HEAD_NAMES = { + "pixel_segmentation", + "pixel_depth_classif", + "pixel_depth_regress", + "pixel_normals", + "segmentation_head", +} + +# ConvTranspose keys that need spatial flipping. 
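+# (These entries are only consulted for legacy flat-format checkpoints in
+# load_decoder_weights; Scenic/Flax-format checkpoints get their ConvTranspose
+# kernels flipped inside _convert_scenic_checkpoint instead.)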
+_CONV_TRANSPOSE_KEYS = { + 'dpt.reassemble.resize_layers.0.weight', + 'dpt.reassemble.resize_layers.1.weight', } +def _is_scenic_format(keys): + """Check if checkpoint keys use Scenic/Flax naming (``/`` separators).""" + return any('/' in k for k in keys) + + +def _convert_scenic_checkpoint(weights): + """Convert Flax parameter tree checkpoint to PyTorch state_dict. + + These checkpoints use Flax parameter tree naming: + decoder/dpt/reassemble_blocks/out_projection_0/kernel + which maps to PyTorch: + dpt.reassemble.out_projections.0.weight + + Weight conversions: + - Conv kernels: (H, W, Cin, Cout) -> (Cout, Cin, H, W) + - ConvTranspose kernels: same + 180-degree spatial flip + - Dense/Linear kernels: (in, out) -> (out, in) + - Biases: direct copy + """ + sd = {} + + # Build a nested dict from flat Scenic keys + tree = {} + for key, value in weights.items(): + # Strip "decoder/" prefix if present. + k = key[len("decoder/"):] if key.startswith("decoder/") else key + parts = k.split("/") + d = tree + for p in parts[:-1]: + d = d.setdefault(p, {}) + d[parts[-1]] = np.array(value) + + dpt_params = tree.get("dpt", tree) + + # --- ReassembleBlocks --- + rb = dpt_params.get("reassemble_blocks", {}) + for i in range(4): + # out_projections (Conv2d 1x1) + op = rb.get(f"out_projection_{i}", {}) + if "kernel" in op: + sd[f"dpt.reassemble.out_projections.{i}.weight"] = torch.from_numpy( + op["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in op: + sd[f"dpt.reassemble.out_projections.{i}.bias"] = torch.from_numpy( + op["bias"].copy() + ) + # readout_projects (Linear) + rp = rb.get(f"readout_projects_{i}", {}) + if "kernel" in rp: + sd[f"dpt.reassemble.readout_projects.{i}.weight"] = torch.from_numpy( + rp["kernel"].T.copy() + ) + if "bias" in rp: + sd[f"dpt.reassemble.readout_projects.{i}.bias"] = torch.from_numpy( + rp["bias"].copy() + ) + + # resize_layers: 0=ConvTranspose, 1=ConvTranspose, 2=Identity, 3=Conv + for idx in [0, 1]: + rl = rb.get(f"resize_layers_{idx}", {}) + if "kernel" in rl: + w = rl["kernel"][::-1, ::-1, :, :].copy() # 180-degree spatial flip + sd[f"dpt.reassemble.resize_layers.{idx}.weight"] = torch.from_numpy( + w.transpose(2, 3, 0, 1).copy() + ) + if "bias" in rl: + sd[f"dpt.reassemble.resize_layers.{idx}.bias"] = torch.from_numpy( + rl["bias"].copy() + ) + # resize_layers_2 = Identity (no weights) + rl3 = rb.get("resize_layers_3", {}) + if "kernel" in rl3: + sd["dpt.reassemble.resize_layers.3.weight"] = torch.from_numpy( + rl3["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in rl3: + sd["dpt.reassemble.resize_layers.3.bias"] = torch.from_numpy( + rl3["bias"].copy() + ) + + # --- Convs (3x3, no bias) --- + for i in range(4): + c = dpt_params.get(f"convs_{i}", {}) + if "kernel" in c: + sd[f"dpt.convs.{i}.weight"] = torch.from_numpy( + c["kernel"].transpose(3, 2, 0, 1).copy() + ) + + # --- Fusion blocks --- + for i in range(4): + fb = dpt_params.get(f"fusion_blocks_{i}", {}) + if i == 0: + # No residual unit, only 1 PreActResidualConvUnit -> main_unit + pacu = fb.get("PreActResidualConvUnit_0", {}) + for cname in ["conv1", "conv2"]: + if cname in pacu and "kernel" in pacu[cname]: + sd[f"dpt.fusion_blocks.{i}.main_unit.{cname}.weight"] = ( + torch.from_numpy( + pacu[cname]["kernel"].transpose(3, 2, 0, 1).copy() + ) + ) + else: + # Residual unit (index 0) + main unit (index 1) + pacu0 = fb.get("PreActResidualConvUnit_0", {}) + pacu1 = fb.get("PreActResidualConvUnit_1", {}) + for cname in ["conv1", "conv2"]: + if cname in pacu0 and "kernel" in pacu0[cname]: + 
sd[f"dpt.fusion_blocks.{i}.residual_unit.{cname}.weight"] = ( + torch.from_numpy( + pacu0[cname]["kernel"].transpose(3, 2, 0, 1).copy() + ) + ) + if cname in pacu1 and "kernel" in pacu1[cname]: + sd[f"dpt.fusion_blocks.{i}.main_unit.{cname}.weight"] = ( + torch.from_numpy( + pacu1[cname]["kernel"].transpose(3, 2, 0, 1).copy() + ) + ) + # out_conv (Conv2d 1x1) -- Scenic names it Conv_0 + oc = fb.get("Conv_0", fb.get("out_conv", {})) + if "kernel" in oc: + sd[f"dpt.fusion_blocks.{i}.out_conv.weight"] = torch.from_numpy( + oc["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in oc: + sd[f"dpt.fusion_blocks.{i}.out_conv.bias"] = torch.from_numpy( + oc["bias"].copy() + ) + + # --- Project --- + proj = dpt_params.get("project", {}) + if "kernel" in proj: + sd["dpt.project.weight"] = torch.from_numpy( + proj["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in proj: + sd["dpt.project.bias"] = torch.from_numpy(proj["bias"].copy()) + + # --- Task head (Dense/Linear) --- + for head_name in _SCENIC_HEAD_NAMES: + if head_name in tree: + h = tree[head_name] + if "kernel" in h: + sd["head.weight"] = torch.from_numpy(h["kernel"].T.copy()) + if "bias" in h: + sd["head.bias"] = torch.from_numpy(h["bias"].copy()) + break + + return sd + + def load_decoder_weights( model: Decoder, checkpoint_path: str, ) -> Decoder: - """Load pre-converted PyTorch weights into a Decoder. + """Load weights into a Decoder from a checkpoint file. - Supports both the legacy flat key format (e.g. ``reassemble.…``) and the - new hierarchical format (e.g. ``dpt.reassemble.…``). + Supports three checkpoint formats: + 1. Flax format: keys with ``/`` separators and ``kernel``/``bias`` + naming (e.g. ``decoder/dpt/reassemble_blocks/out_projection_0/kernel``). + Weights are automatically transposed from Flax layout to PyTorch layout. + 2. Legacy flat format: keys like ``reassemble.…`` that get remapped to + ``dpt.reassemble.…``. + 3. PyTorch hierarchical format: keys already match the model state_dict. Args: model: A Decoder instance (SegmentationDecoder, DepthDecoder, etc.). - checkpoint_path: Path to a checkpoint file. + checkpoint_path: Path to a ``.npz`` checkpoint file. Returns: The model with loaded weights. """ weights = dict(np.load(checkpoint_path, allow_pickle=False)) - sd = {} - for key, value in weights.items(): - new_key = key - # Remap legacy flat keys to hierarchical names. - for old_prefix, new_prefix in _LEGACY_KEY_PREFIXES.items(): - if key.startswith(old_prefix): - new_key = new_prefix + key[len(old_prefix):] - break - sd[new_key] = torch.from_numpy(value) + if _is_scenic_format(weights): + sd = _convert_scenic_checkpoint(weights) + else: + # Legacy flat or PyTorch hierarchical format. + sd = {} + for key, value in weights.items(): + new_key = key + for old_prefix, new_prefix in _LEGACY_KEY_PREFIXES.items(): + if key.startswith(old_prefix): + new_key = new_prefix + key[len(old_prefix):] + break + t = torch.from_numpy(value) + # Flip ConvTranspose kernels spatially for Flax->PyTorch parity. + if new_key in _CONV_TRANSPOSE_KEYS and t.ndim == 4: + t = t.flip([2, 3]) + sd[new_key] = t + + # Add registered buffers not in checkpoint (e.g. bin_centers). + for name, buf in model.named_buffers(): + if name not in sd: + sd[name] = buf model.load_state_dict(sd, strict=True) print(f"Loaded decoder weights from {checkpoint_path} ({len(sd)} tensors)")