From dad6e8f71635f0e3b2954eeebfad94fff0f4e875 Mon Sep 17 00:00:00 2001 From: Bingyi Cao Date: Wed, 22 Apr 2026 08:19:05 +0000 Subject: [PATCH 1/6] Fix DPT decoder bugs: add ReLU, fix DepthDecoder head, add key remapping - Add F.relu() after DPTHead project conv to match Scenic's output_activation=True default - Fix DepthDecoder to route through parent's nn.Linear head instead of bypassing it - Register bin_centers as a buffer with configurable num_depth_bins - Add weight key remapping for all decoder types --- pytorch/decoders.py | 51 +++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/pytorch/decoders.py b/pytorch/decoders.py index e1dd93f..824e522 100644 --- a/pytorch/decoders.py +++ b/pytorch/decoders.py @@ -196,6 +196,7 @@ def forward( out = self.fusion_blocks[i](out, residual=x[-(i + 1)]) out = self.project(out) + out = F.relu(out) return out @@ -264,16 +265,26 @@ def __init__(self, num_classes: int = 150, **kwargs) -> None: class DepthDecoder(Decoder): - """Decoder for monocular depth prediction using classification bins.""" + """Decoder for monocular depth prediction using classification bins. - def __init__(self, min_depth: float = 0.001, max_depth: float = 10.0, **kwargs) -> None: - # Decoder requires out_channels, we pass 256 as we use channels as bins, - # although we bypass the head in forward(). - super().__init__(out_channels=256, **kwargs) + Predicts depth by classifying each pixel into uniformly-spaced depth bins + and computing the expected depth value. + """ + + def __init__( + self, + num_depth_bins: int = 256, + min_depth: float = 0.001, + max_depth: float = 10.0, + **kwargs, + ) -> None: + super().__init__(out_channels=num_depth_bins, **kwargs) self.min_depth = min_depth self.max_depth = max_depth + self.num_depth_bins = num_depth_bins self.register_buffer( - "bin_centers", torch.linspace(min_depth, max_depth, 256) + "bin_centers", + torch.linspace(min_depth, max_depth, num_depth_bins), ) def forward( @@ -281,20 +292,23 @@ def forward( intermediate_features: List[Tuple[torch.Tensor, torch.Tensor]], image_size: Optional[Tuple[int, int]] = None, ) -> torch.Tensor: - # Bypass super().forward() to avoid the linear head applied there, - # and use raw DPT features as logits. - logits = self.dpt(intermediate_features) # (B, C, H', W') - # Apply ReLU and shift + # 1. Get DPT features + task head (nn.Linear) via parent class. + # Output shape: (B, num_depth_bins, H', W') + logits = super().forward(intermediate_features) + + # 2. Classification-based depth prediction (following Scenic/AdaBins): + # relu + shift -> linear normalisation -> expectation over bins. logits = torch.relu(logits) + self.min_depth - # Normalize to probabilities along the channel dimension probs = logits / torch.sum(logits, dim=1, keepdim=True) - # Compute expectation: sum(prob * bin_center) - depth_map = torch.einsum( - "bchw,c->bhw", probs, self.bin_centers.to(logits.device) - ) + depth_map = torch.einsum("bchw,c->bhw", probs, self.bin_centers.to(logits.device)) + + # 3. Upsample to target resolution. if image_size is not None: depth_map = F.interpolate( - depth_map.unsqueeze(1), size=image_size, mode="bilinear", align_corners=False + depth_map.unsqueeze(1), + size=image_size, + mode="bilinear", + align_corners=False, ).squeeze(1) return depth_map.unsqueeze(1) @@ -315,7 +329,12 @@ def __init__(self, **kwargs) -> None: "convs.": "dpt.convs.", "fusion_blocks.": "dpt.fusion_blocks.", "project.": "dpt.project.", + # Task-specific head keys (Scenic Dense -> PyTorch head.*) "segmentation_head.": "head.", + "pixel_segmentation.": "head.", + "pixel_depth_classif.": "head.", + "pixel_depth_regress.": "head.", + "pixel_normals.": "head.", } From 7474fc13bde014a84874a82d86e063efd5b06ebb Mon Sep 17 00:00:00 2001 From: bingyic <107590227+bingyic@users.noreply.github.com> Date: Mon, 11 May 2026 02:06:33 -0700 Subject: [PATCH 2/6] Add decoder inference Colab (segmentation, depth, normals) --- pytorch/TIPS_decoder_inference.ipynb | 466 +++++++++++++++++++++++++++ 1 file changed, 466 insertions(+) create mode 100644 pytorch/TIPS_decoder_inference.ipynb diff --git a/pytorch/TIPS_decoder_inference.ipynb b/pytorch/TIPS_decoder_inference.ipynb new file mode 100644 index 0000000..7ac788e --- /dev/null +++ b/pytorch/TIPS_decoder_inference.ipynb @@ -0,0 +1,466 @@ +{ + "cells": [ + { + "id": "94b870ed", + "cell_type": "code", + "source": [ + "# Copyright 2025 Google LLC.\n", + "#\n", + "# SPDX-License-Identifier: Apache-2.0" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "05dc14da", + "cell_type": "markdown", + "source": [ + "# TIPSv2: Segmentation, Depth \u0026 Surface Normals Inference\n", + "\n", + "This notebook demonstrates how to run **dense prediction** inference using\n", + "the TIPSv2 DPT decoders:\n", + "\n", + "1. **Semantic Segmentation** (ADE20K, 150 classes)\n", + "2. **Monocular Depth Estimation** (classification-based, 256 bins)\n", + "3. **Surface Normal Estimation**\n", + "\n", + "The DPT decoder heads take intermediate ViT features from the TIPS vision\n", + "encoder and produce pixel-level predictions.\n", + "\n", + "**Requirements:** GPU runtime recommended for faster inference." + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "3dcd30e6", + "cell_type": "markdown", + "source": [ + "## Setup" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "16cac140", + "cell_type": "code", + "source": [ + "# @title Install dependencies and clone TIPS repo.\n", + "import os\n", + "import sys\n", + "\n", + "ROOT_DIR = os.getcwd()\n", + "TIPS_DIR = os.path.join(ROOT_DIR, 'tips')\n", + "\n", + "# Install required packages.\n", + "!pip install -q torch torchvision torchaudio\n", + "!pip install -q tensorflow_text scikit-learn\n", + "\n", + "# Clone the TIPS repository.\n", + "if not os.path.exists(TIPS_DIR):\n", + " !git clone https://github.com/google-deepmind/tips.git {TIPS_DIR}\n", + "\n", + "# Add the root directory to PYTHONPATH so that `tips.*` imports work.\n", + "if ROOT_DIR not in sys.path:\n", + " sys.path.insert(0, ROOT_DIR)\n", + "\n", + "print(f'ROOT_DIR: {ROOT_DIR}')\n", + "print(f'TIPS_DIR: {TIPS_DIR}')\n", + "print('Installation complete!')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "985f995d", + "cell_type": "code", + "source": [ + "# @title Download checkpoints and sample images.\n", + "import urllib.request\n", + "import zipfile\n", + "\n", + "variant = 'L' # @param [\"B\", \"L\", \"So\", \"g\"]\n", + "\n", + "CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/tips_data/v2_0/checkpoints/pytorch'\n", + "NYU_URL = 'http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part6.zip'\n", + "ADE20K_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'\n", + "\n", + "CKPT_DIR = os.path.join(ROOT_DIR, 'checkpoints')\n", + "os.makedirs(CKPT_DIR, exist_ok=True)\n", + "\n", + "# Checkpoint naming maps.\n", + "V2_CKPT_BASENAME_MAP = {\n", + " 'B': 'tips_v2_oss_b14', 'L': 'tips_v2_oss_l14',\n", + " 'So': 'tips_v2_oss_so14', 'g': 'tips_v2_oss_g14',\n", + "}\n", + "V2_DPT_BASENAME_MAP = {\n", + " 'B': 'tips_v2_b14', 'L': 'tips_v2_l14',\n", + " 'So': 'tips_v2_so400m14', 'g': 'tips_v2_g14',\n", + "}\n", + "ckpt_basename = V2_CKPT_BASENAME_MAP[variant]\n", + "dpt_basename = V2_DPT_BASENAME_MAP[variant]\n", + "\n", + "# Download vision encoder checkpoint.\n", + "vision_ckpt_name = f'{ckpt_basename}_vision.npz'\n", + "image_encoder_checkpoint = os.path.join(CKPT_DIR, vision_ckpt_name)\n", + "if not os.path.exists(image_encoder_checkpoint):\n", + " print(f'Downloading vision encoder...')\n", + " urllib.request.urlretrieve(f'{CHECKPOINT_BASE_URL}/{vision_ckpt_name}', image_encoder_checkpoint)\n", + "\n", + "# Download DPT checkpoints (Segmentation, Depth, Normals).\n", + "dpt_tasks = ['segmentation', 'depth', 'normals']\n", + "dpt_checkpoint_paths = {}\n", + "for task in dpt_tasks:\n", + " dpt_zip_name = f'{dpt_basename}_{task}_dpt_pytorch.zip'\n", + " dpt_zip_path = os.path.join(CKPT_DIR, dpt_zip_name)\n", + " if not os.path.exists(dpt_zip_path):\n", + " print(f'Downloading DPT {task} checkpoint...')\n", + " try:\n", + " urllib.request.urlretrieve(f'{CHECKPOINT_BASE_URL}/{dpt_zip_name}', dpt_zip_path)\n", + " except Exception as e:\n", + " print(f' Failed: {e}')\n", + " dpt_checkpoint_paths[task] = dpt_zip_path\n", + "\n", + "# Download NYU depth dataset (for depth \u0026 normals demo).\n", + "NYU_IMG_DIR = os.path.join(ROOT_DIR, 'nyu_images')\n", + "if not os.path.isdir(NYU_IMG_DIR):\n", + " print('Downloading NYU dataset...')\n", + " nyu_tmp = os.path.join(ROOT_DIR, 'bedrooms_part6.zip')\n", + " urllib.request.urlretrieve(NYU_URL, nyu_tmp)\n", + " os.makedirs(NYU_IMG_DIR, exist_ok=True)\n", + " with zipfile.ZipFile(nyu_tmp, 'r') as z:\n", + " z.extractall(NYU_IMG_DIR)\n", + " os.remove(nyu_tmp)\n", + "\n", + "# Download ADE20K dataset (for segmentation demo).\n", + "ADE20K_DIR = os.path.join(ROOT_DIR, 'ADEChallengeData2016')\n", + "if not os.path.isdir(ADE20K_DIR):\n", + " print('Downloading ADE20K dataset...')\n", + " ade_tmp = os.path.join(ROOT_DIR, 'ADEChallengeData2016.zip')\n", + " urllib.request.urlretrieve(ADE20K_URL, ade_tmp)\n", + " with zipfile.ZipFile(ade_tmp, 'r') as z:\n", + " z.extractall(ROOT_DIR)\n", + " os.remove(ade_tmp)\n", + "\n", + "print('All downloads complete!')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "d1d3260a", + "cell_type": "markdown", + "source": [ + "## Load Models" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "1462be26", + "cell_type": "code", + "source": [ + "# @title Load the TIPS vision encoder and all three DPT decoders.\n", + "import numpy as np\n", + "import torch\n", + "from tips.pytorch import image_encoder\n", + "from tips.pytorch.decoders import (\n", + " Decoder, SegmentationDecoder, DepthDecoder, NormalsDecoder,\n", + " load_decoder_weights,\n", + ")\n", + "\n", + "image_size = 448 # @param {type: \"number\"}\n", + "PATCH_SIZE = 14\n", + "\n", + "# Model configs per variant.\n", + "MODEL_CONSTRUCTOR_MAP = {\n", + " 'B': 'vit_base', 'L': 'vit_large', 'So': 'vit_so400m', 'g': 'vit_giant2',\n", + "}\n", + "EMBED_DIM_MAP = {'B': 768, 'L': 1024, 'So': 1152, 'g': 1536}\n", + "INTERMEDIATE_LAYERS_MAP = {\n", + " 'B': [2, 5, 8, 11], 'L': [5, 11, 17, 23],\n", + " 'So': [6, 13, 20, 26], 'g': [9, 19, 29, 39],\n", + "}\n", + "\n", + "vit_constructor = getattr(image_encoder, MODEL_CONSTRUCTOR_MAP[variant])\n", + "embed_dim = EMBED_DIM_MAP[variant]\n", + "intermediate_layers = INTERMEDIATE_LAYERS_MAP[variant]\n", + "post_process_channels = (embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim)\n", + "ffn_layer = 'swiglu' if variant == 'g' else 'mlp'\n", + "\n", + "# --- Vision Encoder ---\n", + "weights_image = dict(np.load(image_encoder_checkpoint, allow_pickle=False))\n", + "for key in weights_image:\n", + " weights_image[key] = torch.tensor(weights_image[key])\n", + "\n", + "with torch.no_grad():\n", + " model_image = vit_constructor(\n", + " img_size=image_size, patch_size=PATCH_SIZE, ffn_layer=ffn_layer,\n", + " block_chunks=0, init_values=1.0,\n", + " interpolate_antialias=True, interpolate_offset=0.0,\n", + " )\n", + " model_image.load_state_dict(weights_image)\n", + " model_image.eval()\n", + "print(f'✓ Vision encoder loaded ({variant})')\n", + "\n", + "# --- Segmentation Decoder ---\n", + "with torch.no_grad():\n", + " seg_model = SegmentationDecoder(\n", + " num_classes=150, input_embed_dim=embed_dim,\n", + " post_process_channels=post_process_channels,\n", + " )\n", + " load_decoder_weights(seg_model, dpt_checkpoint_paths['segmentation'])\n", + " seg_model.eval()\n", + "\n", + "# --- Depth Decoder ---\n", + "with torch.no_grad():\n", + " depth_model = DepthDecoder(\n", + " input_embed_dim=embed_dim,\n", + " post_process_channels=post_process_channels,\n", + " )\n", + " load_decoder_weights(depth_model, dpt_checkpoint_paths['depth'])\n", + " depth_model.eval()\n", + "\n", + "# --- Normals Decoder ---\n", + "with torch.no_grad():\n", + " normals_model = NormalsDecoder(\n", + " input_embed_dim=embed_dim,\n", + " post_process_channels=post_process_channels,\n", + " )\n", + " load_decoder_weights(normals_model, dpt_checkpoint_paths['normals'])\n", + " normals_model.eval()\n", + "\n", + "print('✓ All decoders loaded')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "06b04c61", + "cell_type": "code", + "source": [ + "# @title Define helper: extract ViT features.\n", + "import torchvision.transforms as TVT\n", + "import PIL.Image\n", + "\n", + "transform = TVT.Compose([TVT.Resize((image_size, image_size)), TVT.ToTensor()])\n", + "\n", + "def extract_features(img_path):\n", + " \"\"\"Load image and extract intermediate ViT features.\"\"\"\n", + " img = PIL.Image.open(img_path).convert(\"RGB\")\n", + " tensor = transform(img).unsqueeze(0)\n", + " device = next(model_image.parameters()).device\n", + " tensor = tensor.to(device)\n", + " with torch.no_grad():\n", + " features = model_image.get_intermediate_layers(\n", + " tensor, n=intermediate_layers, reshape=True,\n", + " return_class_token=True, norm=True,\n", + " )\n", + " # Reorder: (feat, cls) -\u003e (cls, feat)\n", + " features = [(cls, feat) for feat, cls in features]\n", + " return img, features" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "a4f824d8", + "cell_type": "code", + "source": [ + "# @title Define ADE20K class names and color palette.\n", + "import colorsys\n", + "\n", + "ADE20K_CLASSES = (\n", + " 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',\n", + " 'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person',\n", + " 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair',\n", + " 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea',\n", + " 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk',\n", + " 'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion',\n", + " 'base', 'box', 'column', 'signboard', 'chest of drawers',\n", + " 'counter', 'sand', 'sink', 'skyscraper', 'fireplace',\n", + " 'refrigerator', 'grandstand', 'path', 'stairs', 'runway',\n", + " 'case', 'pool table', 'pillow', 'screen door', 'stairway',\n", + " 'river', 'bridge', 'bookcase', 'blind', 'coffee table',\n", + " 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop',\n", + " 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair',\n", + " 'boat', 'bar', 'arcade machine', 'hovel', 'bus',\n", + " 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning',\n", + " 'streetlight', 'booth', 'television', 'airplane', 'dirt track',\n", + " 'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman',\n", + " 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',\n", + " 'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool',\n", + " 'stool', 'barrel', 'basket', 'waterfall', 'tent', 'bag',\n", + " 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',\n", + " 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',\n", + " 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',\n", + " 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier',\n", + " 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower',\n", + " 'radiator', 'glass', 'clock', 'flag',\n", + ")\n", + "\n", + "def _generate_ade20k_palette(n=150):\n", + " palette = np.zeros((n, 3), dtype=np.uint8)\n", + " for i in range(n):\n", + " hue = i / n\n", + " saturation = 0.7 + 0.3 * ((i * 7) % 10) / 10\n", + " value = 0.6 + 0.4 * ((i * 3) % 10) / 10\n", + " r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)\n", + " palette[i] = [int(r * 255), int(g * 255), int(b * 255)]\n", + " return palette\n", + "\n", + "ADE20K_PALETTE = _generate_ade20k_palette()\n", + "print(f'Defined {len(ADE20K_CLASSES)} ADE20K classes')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "9d2e893b", + "cell_type": "markdown", + "source": [ + "## Run Inference\n", + "\n", + "### Semantic Segmentation (ADE20K)" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "51628b05", + "cell_type": "code", + "source": [ + "# @title Run segmentation on an ADE20K sample image.\n", + "import matplotlib.pyplot as plt\n", + "\n", + "ade_img_dir = os.path.join(ADE20K_DIR, 'images', 'validation')\n", + "ade_images = sorted([\n", + " os.path.join(ade_img_dir, f)\n", + " for f in os.listdir(ade_img_dir) if f.endswith('.jpg')\n", + "])\n", + "\n", + "image_path = ade_images[1] # A castle scene\n", + "print(f'Image: {image_path}')\n", + "\n", + "img, features = extract_features(image_path)\n", + "\n", + "with torch.no_grad():\n", + " seg_logits = seg_model(features, image_size=(image_size, image_size))\n", + " seg_map = seg_logits.argmax(dim=1).squeeze(0).cpu().numpy()\n", + "\n", + "colored_seg = ADE20K_PALETTE[seg_map]\n", + "\n", + "# Print top classes found.\n", + "unique_classes, counts = np.unique(seg_map, return_counts=True)\n", + "top_idx = np.argsort(-counts)[:5]\n", + "print('Top classes:')\n", + "for idx in top_idx:\n", + " cls_id = unique_classes[idx]\n", + " pct = 100 * counts[idx] / seg_map.size\n", + " print(f' {ADE20K_CLASSES[cls_id]:20s} ({pct:.1f}%)')\n", + "\n", + "plt.figure(figsize=(12, 5))\n", + "plt.subplot(1, 2, 1)\n", + "plt.imshow(img.resize((image_size, image_size)))\n", + "plt.title('Input Image')\n", + "plt.axis('off')\n", + "plt.subplot(1, 2, 2)\n", + "plt.imshow(colored_seg)\n", + "plt.title('Semantic Segmentation')\n", + "plt.axis('off')\n", + "plt.tight_layout()\n", + "plt.show()" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "d296d2e5", + "cell_type": "markdown", + "source": [ + "### Depth Estimation \u0026 Surface Normals (NYU)" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "7796c341", + "cell_type": "code", + "source": [ + "# @title Run depth and normals inference on NYU images.\n", + "import torch.nn.functional as F\n", + "\n", + "# Collect NYU sample images.\n", + "valid_extensions = ('.ppm', '.jpg', '.jpeg', '.png')\n", + "nyu_images = []\n", + "for root, dirs, files in os.walk(NYU_IMG_DIR):\n", + " for file in files:\n", + " if file.lower().endswith(valid_extensions):\n", + " nyu_images.append(os.path.join(root, file))\n", + "\n", + "selected_images = nyu_images[:3]\n", + "\n", + "for i, image_path in enumerate(selected_images):\n", + " print(f'Processing image {i+1}/{len(selected_images)}: {os.path.basename(image_path)}')\n", + " img, features = extract_features(image_path)\n", + "\n", + " with torch.no_grad():\n", + " # --- Depth ---\n", + " depth_map = depth_model(features, image_size=(image_size, image_size))\n", + " depth_np = depth_map.squeeze().cpu().numpy()\n", + " # Normalize for visualization.\n", + " depth_np = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min() + 1e-8)\n", + "\n", + " # --- Normals ---\n", + " # Get raw low-res output first.\n", + " normals_map = normals_model(features)\n", + " # L2 normalize.\n", + " normals_map = F.normalize(normals_map, dim=1)\n", + " # Upsample with bicubic for smooth results.\n", + " normals_map = F.interpolate(\n", + " normals_map, size=(image_size, image_size),\n", + " mode='bicubic', align_corners=False,\n", + " )\n", + " # Re-normalize after upsampling.\n", + " normals_map = F.normalize(normals_map, dim=1)\n", + " normals_np = normals_map.squeeze(0).cpu().numpy().transpose(1, 2, 0)\n", + " # Map [-1, 1] -\u003e [0, 1] for display.\n", + " normals_np = np.clip((normals_np + 1.0) / 2.0, 0.0, 1.0)\n", + "\n", + " # Visualize.\n", + " plt.figure(figsize=(15, 5))\n", + " plt.subplot(1, 3, 1)\n", + " plt.imshow(img.resize((image_size, image_size)))\n", + " plt.title(f'Input ({i+1})')\n", + " plt.axis('off')\n", + "\n", + " plt.subplot(1, 3, 2)\n", + " plt.imshow(depth_np, cmap='turbo')\n", + " plt.title(f'Depth ({i+1})')\n", + " plt.axis('off')\n", + "\n", + " plt.subplot(1, 3, 3)\n", + " plt.imshow(normals_np)\n", + " plt.title(f'Surface Normals ({i+1})')\n", + " plt.axis('off')\n", + "\n", + " plt.tight_layout()\n", + " plt.show()" + ], + "metadata": {}, + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat_minor": 5, + "nbformat": 4 +} From 4a4989ee8fdc5cb22e31bd474c5bf24ec8a35edb Mon Sep 17 00:00:00 2001 From: bingyic <107590227+bingyic@users.noreply.github.com> Date: Tue, 12 May 2026 01:35:12 -0700 Subject: [PATCH 3/6] Fix checkpoint URLs: use Scenic checkpoints for DPT decoders --- pytorch/TIPS_decoder_inference.ipynb | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch/TIPS_decoder_inference.ipynb b/pytorch/TIPS_decoder_inference.ipynb index 7ac788e..b2cab29 100644 --- a/pytorch/TIPS_decoder_inference.ipynb +++ b/pytorch/TIPS_decoder_inference.ipynb @@ -81,7 +81,8 @@ "\n", "variant = 'L' # @param [\"B\", \"L\", \"So\", \"g\"]\n", "\n", - "CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/tips_data/v2_0/checkpoints/pytorch'\n", + "VISION_CKPT_URL = 'https://storage.googleapis.com/tips_data/v2_0/checkpoints/pytorch'\n", + "DPT_CKPT_URL = 'https://storage.googleapis.com/tips_data/v2_0/checkpoints/scenic'\n", "NYU_URL = 'http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part6.zip'\n", "ADE20K_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'\n", "\n", @@ -105,18 +106,19 @@ "image_encoder_checkpoint = os.path.join(CKPT_DIR, vision_ckpt_name)\n", "if not os.path.exists(image_encoder_checkpoint):\n", " print(f'Downloading vision encoder...')\n", - " urllib.request.urlretrieve(f'{CHECKPOINT_BASE_URL}/{vision_ckpt_name}', image_encoder_checkpoint)\n", + " urllib.request.urlretrieve(f'{VISION_CKPT_URL}/{vision_ckpt_name}', image_encoder_checkpoint)\n", "\n", "# Download DPT checkpoints (Segmentation, Depth, Normals).\n", + "# These use Scenic-format checkpoints (Flax .npy arrays in a zip).\n", "dpt_tasks = ['segmentation', 'depth', 'normals']\n", "dpt_checkpoint_paths = {}\n", "for task in dpt_tasks:\n", - " dpt_zip_name = f'{dpt_basename}_{task}_dpt_pytorch.zip'\n", + " dpt_zip_name = f'{dpt_basename}_{task}_dpt.zip'\n", " dpt_zip_path = os.path.join(CKPT_DIR, dpt_zip_name)\n", " if not os.path.exists(dpt_zip_path):\n", " print(f'Downloading DPT {task} checkpoint...')\n", " try:\n", - " urllib.request.urlretrieve(f'{CHECKPOINT_BASE_URL}/{dpt_zip_name}', dpt_zip_path)\n", + " urllib.request.urlretrieve(f'{DPT_CKPT_URL}/{dpt_zip_name}', dpt_zip_path)\n", " except Exception as e:\n", " print(f' Failed: {e}')\n", " dpt_checkpoint_paths[task] = dpt_zip_path\n", From b9d44187df9a90735d1eafbe31432b664e25d8d2 Mon Sep 17 00:00:00 2001 From: Bingyi Cao Date: Thu, 14 May 2026 06:20:26 +0000 Subject: [PATCH 4/6] Fix DPT decoder: add output_activation param, GELU tanh approx, ConvTranspose kernel flip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add output_activation parameter to DPTHead and Decoder (defaults to False for Scenic parity where dpt_head_from_config sets it to False) - Use tanh approximation for GELU to match JAX default behavior - Add ConvTranspose kernel spatial flipping in load_decoder_weights for correct Flax→PyTorch weight conversion --- pytorch/decoders.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/pytorch/decoders.py b/pytorch/decoders.py index 824e522..31ff450 100644 --- a/pytorch/decoders.py +++ b/pytorch/decoders.py @@ -136,7 +136,8 @@ def forward( x_flat = x.flatten(2).transpose(1, 2) readout = cls_token.unsqueeze(1).expand(-1, x_flat.shape[1], -1) x_cat = torch.cat([x_flat, readout], dim=-1) - x_proj = F.gelu(self.readout_projects[i](x_cat)) + # JAX GELU uses tanh approximation by default. + x_proj = F.gelu(self.readout_projects[i](x_cat), approximate='tanh') x = x_proj.transpose(1, 2).reshape(b, d, h, w) x = self.out_projections[i](x) x = self.resize_layers[i](x) @@ -164,8 +165,10 @@ def __init__( channels: int = 256, post_process_channels: Tuple[int, ...] = (128, 256, 512, 1024), readout_type: str = "project", + output_activation: bool = False, ) -> None: super().__init__() + self.output_activation = output_activation self.reassemble = ReassembleBlocks( input_embed_dim=input_embed_dim, out_channels=post_process_channels, @@ -196,7 +199,10 @@ def forward( out = self.fusion_blocks[i](out, residual=x[-(i + 1)]) out = self.project(out) - out = F.relu(out) + # NOTE: Scenic's dpt_head_from_config sets output_activation=False, + # so NO ReLU is applied after the project layer by default. + if self.output_activation: + out = F.relu(out) return out @@ -218,6 +224,7 @@ def __init__( channels: int = 256, post_process_channels: Tuple[int, ...] = (128, 256, 512, 1024), readout_type: str = "project", + output_activation: bool = False, ) -> None: super().__init__() self.channels = channels @@ -227,6 +234,7 @@ def __init__( channels=channels, post_process_channels=post_process_channels, readout_type=readout_type, + output_activation=output_activation, ) # Common head for all dense prediction tasks self.head = nn.Linear(self.channels, self.out_channels) @@ -356,6 +364,17 @@ def load_decoder_weights( """ weights = dict(np.load(checkpoint_path, allow_pickle=False)) + # ConvTranspose kernel names — these need spatial flipping. + # Flax ConvTranspose uses transpose_kernel=False (no kernel flip), + # while PyTorch ConvTranspose2d always flips. To compensate, we + # rotate the kernel 180° (flip both spatial dims) during loading. + _CONV_TRANSPOSE_KEYS = { + 'dpt.reassemble.resize_layers.0.weight', + 'dpt.reassemble.resize_layers.1.weight', + 'reassemble.resize_layers.0.weight', + 'reassemble.resize_layers.1.weight', + } + sd = {} for key, value in weights.items(): new_key = key @@ -364,7 +383,11 @@ def load_decoder_weights( if key.startswith(old_prefix): new_key = new_prefix + key[len(old_prefix):] break - sd[new_key] = torch.from_numpy(value) + t = torch.from_numpy(value) + # Flip ConvTranspose kernels spatially for Flax→PyTorch parity. + if new_key in _CONV_TRANSPOSE_KEYS and t.ndim == 4: + t = t.flip([2, 3]) # Rotate 180° (flip H and W dims) + sd[new_key] = t model.load_state_dict(sd, strict=True) print(f"Loaded decoder weights from {checkpoint_path} ({len(sd)} tensors)") From eb4a1f98effe00728703842118467e4a5bcc557b Mon Sep 17 00:00:00 2001 From: Bingyi Cao Date: Thu, 14 May 2026 06:59:06 +0000 Subject: [PATCH 5/6] Add Scenic/Flax checkpoint loading to load_decoder_weights Support loading weights directly from Scenic/Flax checkpoints with automatic key mapping and tensor transposition: - Detect Scenic format via '/' separators in keys - Map Flax param tree names to PyTorch state_dict keys - Transpose conv kernels (HWIO -> OIHW), linear (IO -> OI) - Flip ConvTranspose kernels spatially for Flax parity - Map PreActResidualConvUnit_0/1 to residual_unit/main_unit - Auto-populate registered buffers (e.g. bin_centers) --- pytorch/decoders.py | 225 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 197 insertions(+), 28 deletions(-) diff --git a/pytorch/decoders.py b/pytorch/decoders.py index 31ff450..50094f7 100644 --- a/pytorch/decoders.py +++ b/pytorch/decoders.py @@ -345,49 +345,218 @@ def __init__(self, **kwargs) -> None: "pixel_normals.": "head.", } +# Scenic/Flax head param names -> PyTorch head.* +_SCENIC_HEAD_NAMES = { + "pixel_segmentation", + "pixel_depth_classif", + "pixel_depth_regress", + "pixel_normals", + "segmentation_head", +} + +# ConvTranspose keys that need spatial flipping. +_CONV_TRANSPOSE_KEYS = { + 'dpt.reassemble.resize_layers.0.weight', + 'dpt.reassemble.resize_layers.1.weight', +} + + +def _is_scenic_format(keys): + """Check if checkpoint keys use Scenic/Flax naming (``/`` separators).""" + return any('/' in k for k in keys) + + +def _convert_scenic_checkpoint(weights): + """Convert Scenic/Flax checkpoint to PyTorch state_dict. + + Scenic checkpoints use Flax parameter tree naming: + decoder/dpt/reassemble_blocks/out_projection_0/kernel + which maps to PyTorch: + dpt.reassemble.out_projections.0.weight + + Weight conversions: + - Conv kernels: (H, W, Cin, Cout) -> (Cout, Cin, H, W) + - ConvTranspose kernels: same + 180-degree spatial flip + - Dense/Linear kernels: (in, out) -> (out, in) + - Biases: direct copy + """ + sd = {} + + # Build a nested dict from flat Scenic keys + tree = {} + for key, value in weights.items(): + # Strip "decoder/" prefix if present. + k = key[len("decoder/"):] if key.startswith("decoder/") else key + parts = k.split("/") + d = tree + for p in parts[:-1]: + d = d.setdefault(p, {}) + d[parts[-1]] = np.array(value) + + dpt_params = tree.get("dpt", tree) + + # --- ReassembleBlocks --- + rb = dpt_params.get("reassemble_blocks", {}) + for i in range(4): + # out_projections (Conv2d 1x1) + op = rb.get(f"out_projection_{i}", {}) + if "kernel" in op: + sd[f"dpt.reassemble.out_projections.{i}.weight"] = torch.from_numpy( + op["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in op: + sd[f"dpt.reassemble.out_projections.{i}.bias"] = torch.from_numpy( + op["bias"].copy() + ) + # readout_projects (Linear) + rp = rb.get(f"readout_projects_{i}", {}) + if "kernel" in rp: + sd[f"dpt.reassemble.readout_projects.{i}.weight"] = torch.from_numpy( + rp["kernel"].T.copy() + ) + if "bias" in rp: + sd[f"dpt.reassemble.readout_projects.{i}.bias"] = torch.from_numpy( + rp["bias"].copy() + ) + + # resize_layers: 0=ConvTranspose, 1=ConvTranspose, 2=Identity, 3=Conv + for idx in [0, 1]: + rl = rb.get(f"resize_layers_{idx}", {}) + if "kernel" in rl: + w = rl["kernel"][::-1, ::-1, :, :].copy() # 180-degree spatial flip + sd[f"dpt.reassemble.resize_layers.{idx}.weight"] = torch.from_numpy( + w.transpose(2, 3, 0, 1).copy() + ) + if "bias" in rl: + sd[f"dpt.reassemble.resize_layers.{idx}.bias"] = torch.from_numpy( + rl["bias"].copy() + ) + # resize_layers_2 = Identity (no weights) + rl3 = rb.get("resize_layers_3", {}) + if "kernel" in rl3: + sd["dpt.reassemble.resize_layers.3.weight"] = torch.from_numpy( + rl3["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in rl3: + sd["dpt.reassemble.resize_layers.3.bias"] = torch.from_numpy( + rl3["bias"].copy() + ) + + # --- Convs (3x3, no bias) --- + for i in range(4): + c = dpt_params.get(f"convs_{i}", {}) + if "kernel" in c: + sd[f"dpt.convs.{i}.weight"] = torch.from_numpy( + c["kernel"].transpose(3, 2, 0, 1).copy() + ) + + # --- Fusion blocks --- + for i in range(4): + fb = dpt_params.get(f"fusion_blocks_{i}", {}) + if i == 0: + # No residual unit, only 1 PreActResidualConvUnit -> main_unit + pacu = fb.get("PreActResidualConvUnit_0", {}) + for cname in ["conv1", "conv2"]: + if cname in pacu and "kernel" in pacu[cname]: + sd[f"dpt.fusion_blocks.{i}.main_unit.{cname}.weight"] = ( + torch.from_numpy( + pacu[cname]["kernel"].transpose(3, 2, 0, 1).copy() + ) + ) + else: + # Residual unit (index 0) + main unit (index 1) + pacu0 = fb.get("PreActResidualConvUnit_0", {}) + pacu1 = fb.get("PreActResidualConvUnit_1", {}) + for cname in ["conv1", "conv2"]: + if cname in pacu0 and "kernel" in pacu0[cname]: + sd[f"dpt.fusion_blocks.{i}.residual_unit.{cname}.weight"] = ( + torch.from_numpy( + pacu0[cname]["kernel"].transpose(3, 2, 0, 1).copy() + ) + ) + if cname in pacu1 and "kernel" in pacu1[cname]: + sd[f"dpt.fusion_blocks.{i}.main_unit.{cname}.weight"] = ( + torch.from_numpy( + pacu1[cname]["kernel"].transpose(3, 2, 0, 1).copy() + ) + ) + # out_conv (Conv2d 1x1) -- Scenic names it Conv_0 + oc = fb.get("Conv_0", fb.get("out_conv", {})) + if "kernel" in oc: + sd[f"dpt.fusion_blocks.{i}.out_conv.weight"] = torch.from_numpy( + oc["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in oc: + sd[f"dpt.fusion_blocks.{i}.out_conv.bias"] = torch.from_numpy( + oc["bias"].copy() + ) + + # --- Project --- + proj = dpt_params.get("project", {}) + if "kernel" in proj: + sd["dpt.project.weight"] = torch.from_numpy( + proj["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in proj: + sd["dpt.project.bias"] = torch.from_numpy(proj["bias"].copy()) + + # --- Task head (Dense/Linear) --- + for head_name in _SCENIC_HEAD_NAMES: + if head_name in tree: + h = tree[head_name] + if "kernel" in h: + sd["head.weight"] = torch.from_numpy(h["kernel"].T.copy()) + if "bias" in h: + sd["head.bias"] = torch.from_numpy(h["bias"].copy()) + break + + return sd + def load_decoder_weights( model: Decoder, checkpoint_path: str, ) -> Decoder: - """Load pre-converted PyTorch weights into a Decoder. + """Load weights into a Decoder from a checkpoint file. - Supports both the legacy flat key format (e.g. ``reassemble.…``) and the - new hierarchical format (e.g. ``dpt.reassemble.…``). + Supports three checkpoint formats: + 1. Scenic/Flax format: keys with ``/`` separators and ``kernel``/``bias`` + naming (e.g. ``decoder/dpt/reassemble_blocks/out_projection_0/kernel``). + Weights are automatically transposed from Flax layout to PyTorch layout. + 2. Legacy flat format: keys like ``reassemble.…`` that get remapped to + ``dpt.reassemble.…``. + 3. PyTorch hierarchical format: keys already match the model state_dict. Args: model: A Decoder instance (SegmentationDecoder, DepthDecoder, etc.). - checkpoint_path: Path to a checkpoint file. + checkpoint_path: Path to a ``.npz`` checkpoint file. Returns: The model with loaded weights. """ weights = dict(np.load(checkpoint_path, allow_pickle=False)) - # ConvTranspose kernel names — these need spatial flipping. - # Flax ConvTranspose uses transpose_kernel=False (no kernel flip), - # while PyTorch ConvTranspose2d always flips. To compensate, we - # rotate the kernel 180° (flip both spatial dims) during loading. - _CONV_TRANSPOSE_KEYS = { - 'dpt.reassemble.resize_layers.0.weight', - 'dpt.reassemble.resize_layers.1.weight', - 'reassemble.resize_layers.0.weight', - 'reassemble.resize_layers.1.weight', - } - - sd = {} - for key, value in weights.items(): - new_key = key - # Remap legacy flat keys to hierarchical names. - for old_prefix, new_prefix in _LEGACY_KEY_PREFIXES.items(): - if key.startswith(old_prefix): - new_key = new_prefix + key[len(old_prefix):] - break - t = torch.from_numpy(value) - # Flip ConvTranspose kernels spatially for Flax→PyTorch parity. - if new_key in _CONV_TRANSPOSE_KEYS and t.ndim == 4: - t = t.flip([2, 3]) # Rotate 180° (flip H and W dims) - sd[new_key] = t + if _is_scenic_format(weights): + sd = _convert_scenic_checkpoint(weights) + else: + # Legacy flat or PyTorch hierarchical format. + sd = {} + for key, value in weights.items(): + new_key = key + for old_prefix, new_prefix in _LEGACY_KEY_PREFIXES.items(): + if key.startswith(old_prefix): + new_key = new_prefix + key[len(old_prefix):] + break + t = torch.from_numpy(value) + # Flip ConvTranspose kernels spatially for Flax->PyTorch parity. + if new_key in _CONV_TRANSPOSE_KEYS and t.ndim == 4: + t = t.flip([2, 3]) + sd[new_key] = t + + # Add registered buffers not in checkpoint (e.g. bin_centers). + for name, buf in model.named_buffers(): + if name not in sd: + sd[name] = buf model.load_state_dict(sd, strict=True) print(f"Loaded decoder weights from {checkpoint_path} ({len(sd)} tensors)") From 00f5fb07a087850fc09d4345b6714c1acd8e5bec Mon Sep 17 00:00:00 2001 From: Bingyi Cao Date: Thu, 14 May 2026 09:08:00 +0000 Subject: [PATCH 6/6] Remove references to Scenic in comments --- pytorch/decoders.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch/decoders.py b/pytorch/decoders.py index 50094f7..72e6595 100644 --- a/pytorch/decoders.py +++ b/pytorch/decoders.py @@ -199,7 +199,7 @@ def forward( out = self.fusion_blocks[i](out, residual=x[-(i + 1)]) out = self.project(out) - # NOTE: Scenic's dpt_head_from_config sets output_activation=False, + # NOTE: By default, the reference implementation does not apply output activation, # so NO ReLU is applied after the project layer by default. if self.output_activation: out = F.relu(out) @@ -304,7 +304,7 @@ def forward( # Output shape: (B, num_depth_bins, H', W') logits = super().forward(intermediate_features) - # 2. Classification-based depth prediction (following Scenic/AdaBins): + # 2. Classification-based depth prediction: # relu + shift -> linear normalisation -> expectation over bins. logits = torch.relu(logits) + self.min_depth probs = logits / torch.sum(logits, dim=1, keepdim=True) @@ -367,9 +367,9 @@ def _is_scenic_format(keys): def _convert_scenic_checkpoint(weights): - """Convert Scenic/Flax checkpoint to PyTorch state_dict. + """Convert Flax parameter tree checkpoint to PyTorch state_dict. - Scenic checkpoints use Flax parameter tree naming: + These checkpoints use Flax parameter tree naming: decoder/dpt/reassemble_blocks/out_projection_0/kernel which maps to PyTorch: dpt.reassemble.out_projections.0.weight @@ -520,7 +520,7 @@ def load_decoder_weights( """Load weights into a Decoder from a checkpoint file. Supports three checkpoint formats: - 1. Scenic/Flax format: keys with ``/`` separators and ``kernel``/``bias`` + 1. Flax format: keys with ``/`` separators and ``kernel``/``bias`` naming (e.g. ``decoder/dpt/reassemble_blocks/out_projection_0/kernel``). Weights are automatically transposed from Flax layout to PyTorch layout. 2. Legacy flat format: keys like ``reassemble.…`` that get remapped to