40 changes: 40 additions & 0 deletions docker/Dockerfile.rocm
@@ -0,0 +1,40 @@
# EmbodiedGen on AMD ROCm (gfx942 / MI300X / MI308X), ROCm 6.4.3 + PyTorch 2.6.
#
# Builds the FULL image-to-3D generation stack on ROCm by swapping the CUDA-only
# libraries for verified ROCm equivalents (rocm-lib-compat / ZJLi2013/rocm3d):
# spconv-cu120 -> spconv_rocm | nvdiffrast -> nvdiffrast@rocm | gsplat -> amd_gsplat
# pytorch3d -> ROCm 6.4/py3.12 wheel | flash-attn -> FA2-Triton | numpy -> <2
# kaolin (CUDA-only) -> sitecustomize stub (texture/mesh-IO stage only)
# Verified on AMD Instinct MI308X: SAM3D + TRELLIS pipelines import & initialize
# (spconv backend + flash_attn; SAM3D attention -> sdpa). See docker/README.rocm.md.
#
# Build (from repo root, submodules checked out):
# docker build -f docker/Dockerfile.rocm -t embodiedgen:rocm6.4.3 .
# Run img3d (GPT-free / no texture-bake smoke):
# docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video \
# --shm-size 32g -e FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
# embodiedgen:rocm6.4.3 python -m embodied_gen.models.sam3d

FROM rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0

ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_ROOT_USER_ACTION=ignore \
    PYTORCH_ROCM_ARCH=gfx942 \
    GPU_ARCHS=gfx942 \
    FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
    PYTHONPATH=/workspace/EmbodiedGen

WORKDIR /workspace/EmbodiedGen

# Source tree (incl. thirdparty/TRELLIS, thirdparty/sam3d submodules) is required
# for install_rocm.sh (cleans requirements.txt, builds extensions, installs stub).
COPY . /workspace/EmbodiedGen

# Install the ROCm generation stack. Compiles spconv_rocm / nvdiffrast / flash-attn
# from source (hipcc), so this layer is the slow one.
RUN bash docker/install_rocm.sh

# Smoke: the full img3d import+init chain on ROCm (no model download / no GPT).
CMD ["python", "-c", "import embodied_gen.models.sam3d, thirdparty.TRELLIS.trellis.pipelines; print('EmbodiedGen ROCm generation image OK')"]
78 changes: 78 additions & 0 deletions docker/README.rocm.md
@@ -0,0 +1,78 @@
# EmbodiedGen on AMD ROCm (MI300X / MI308X)

Run the EmbodiedGen **image-to-3D generation** stack on AMD GPUs (gfx942) with
ROCm 6.4.3 + PyTorch 2.6, by swapping the CUDA-only libraries for verified ROCm
equivalents. Verified on **AMD Instinct MI308X**: the SAM3D and TRELLIS pipelines
import and initialize (spconv backend + flash-attn; SAM3D attention auto-selects
`sdpa`).

> Library swaps follow the `rocm-lib-compat` reference
> ([ZJLi2013/rocm3d](https://github.com/ZJLi2013/rocm3d)). The same TRELLIS-v1
> stack is independently verified there via `VAST-AI/AniGen`.

## TL;DR

```bash
# from repo root, with submodules checked out:
git submodule update --init --recursive
docker build -f docker/Dockerfile.rocm -t embodiedgen:rocm6.4.3 .

# import+init smoke (no download / no GPT):
docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video \
--shm-size 32g -e FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
embodiedgen:rocm6.4.3

# full GPT-free image->3D (downloads facebook/sam-3d-objects, ~15 GB; saves splat.ply):
docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video \
--shm-size 32g -e FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
-v "$PWD":/workspace/EmbodiedGen embodiedgen:rocm6.4.3 \
python -m embodied_gen.models.sam3d
```

To run the swaps directly in a base container instead of building the image:

```bash
docker run -it --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 32g \
-v "$PWD":/workspace/EmbodiedGen -w /workspace/EmbodiedGen \
rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0 \
bash docker/install_rocm.sh
```

## CUDA -> ROCm dependency map

| Upstream (CUDA) | ROCm replacement | Status on MI308X |
|---|---|---|
| `spconv-cu120` | [`ZJLi2013/spconv_rocm`](https://github.com/ZJLi2013/spconv_rocm) (source) | ✅ import OK |
| `nvdiffrast` | [`ZJLi2013/nvdiffrast@rocm`](https://github.com/ZJLi2013/nvdiffrast) | ✅ import OK |
| `gsplat` | `amd_gsplat` (`pypi.amd.com/rocm-6.4.3`), import name `gsplat` | ✅ default GS backend |
| `pytorch3d` | ROCm 6.4 / py3.12 prebuilt wheel | ✅ import OK |
| `flash-attn` | FA2-Triton (`FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE` at install **and** runtime) | ✅ import OK |
| `xformers` | not needed — SAM3D attention auto-selects `sdpa` | ✅ skipped |
| `numpy` (base ships 2.x) | pin `numpy<2` (diffusers/transformers requirement) | ✅ |
| `kaolin` (no ROCm wheel; setup.py requires `nvcc`) | `sitecustomize` stub (`docker/kaolin_stub.py`) | ⚠️ texture-stage only |
| `diff-gaussian-rasterization` | optional 'inria' GS backend (gsplat is default) | ⏸ optional |
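
The "import OK" column can be re-checked from inside the container with a short import probe. The following is a minimal sketch, not part of this PR: `probe_imports` and `ROCM_STACK` are hypothetical names, and the module list is assumed from the import names in the table above. It sets `FLASH_ATTENTION_TRITON_AMD_ENABLE` before importing, since the table notes the flag is needed at runtime as well as at install.

```python
import importlib
import os

# Import names assumed from the dependency table (hypothetical constant).
ROCM_STACK = ["torch", "spconv", "nvdiffrast.torch", "gsplat", "pytorch3d", "flash_attn"]


def probe_imports(modules):
    """Return {module_name: None on success, else a short error string}."""
    # The FA2-Triton backend needs this flag at runtime; do not override a user value.
    os.environ.setdefault("FLASH_ATTENTION_TRITON_AMD_ENABLE", "TRUE")
    results = {}
    for name in modules:
        try:
            importlib.import_module(name)
            results[name] = None
        except Exception as exc:  # a broken native build can raise more than ImportError
            results[name] = f"{type(exc).__name__}: {exc}"[:160]
    return results
```

Calling `probe_imports(ROCM_STACK)` inside the image yields one entry per replacement library; any non-`None` value is the truncated import error for that row.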

## The kaolin stub (`docker/kaolin_stub.py`)

`kaolin` is CUDA-only and is imported at the top of `embodied_gen/data/utils.py`,
but is only **used** inside the texture-backprojection / mesh-IO stage
(`kal.io.*.import_mesh`, `kal.render.materials.PBRMaterial`,
`kaolin.render.camera.Camera`) and as type references in `thirdparty/sam3d`.
None of it is on the core geometry+gaussian generation path. The stub (imported
from `sitecustomize.py` at interpreter startup) fabricates any `kaolin.*` module so
every `import kaolin` resolves. Stubbed calls return a placeholder `True`, so the
texture stage fails downstream if it is actually invoked. This mirrors the proven
`ZJLi2013/RealWonder` bypass (~85% of the pipeline usable on ROCm).

The upstream-friendly long-term fix is to make the kaolin imports in
`data/utils.py` lazy/optional so the stub is unnecessary.
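
A minimal sketch of that lazy-import approach is shown here. It is an illustration only, not the code in `data/utils.py`: `require_optional` and `get_kaolin` are hypothetical names, and how the texture stage would consume the returned module is left out.

```python
import importlib


def require_optional(module_name, feature):
    """Import an optional dependency on first use, or raise a clear error."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{feature} needs the optional dependency '{module_name}', "
            "which could not be imported on this platform."
        ) from exc


def get_kaolin():
    # Hypothetical accessor: the texture stage would call this instead of relying
    # on a module-level `import kaolin`, so the geometry path never touches kaolin.
    return require_optional("kaolin", "texture backprojection")
```

With this shape, importing the module succeeds on ROCm without any stub, and the error only appears when a kaolin-dependent function is actually called.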

## Known gaps

- **Texture backprojection** (`backproject_v3` / `differentiable_render`) calls
real kaolin mesh-IO and is not available under the stub. Core image-to-3D
(segmentation -> SAM3D geometry + gaussian + mesh export) runs without it.
- **GPT quality-checkers / URDF semantics** (`img3d-cli`) need a GPT key; the
`python -m embodied_gen.models.sam3d` path skips them entirely.
- **`diff-gaussian-rasterization`** (mip-splatting / antialiasing fork) needs
`__trap`->`abort` and a `cooperative_groups/reduce.h` shim to build on ROCm;
it is optional because `gsplat` is the default gaussian backend.
122 changes: 122 additions & 0 deletions docker/install_rocm.sh
@@ -0,0 +1,122 @@
#!/bin/bash
# EmbodiedGen ROCm install (gfx942 / MI300X / MI308X), ROCm 6.4.3 + PyTorch 2.6.
# Swaps the CUDA-only generation stack for ROCm-compatible builds following the
# rocm-lib-compat reference (github.com/ZJLi2013/rocm3d). Intended to run inside
# rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0 with the repo at
# /workspace/EmbodiedGen. Each step reports PASS/FAIL but does not abort, so one
# run yields a full ROCm-compat status map.
set -uo pipefail

export PYTORCH_ROCM_ARCH=gfx942
export GPU_ARCHS=gfx942
export FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE
export PIP_ROOT_USER_ACTION=ignore
REPO=${REPO:-/workspace/EmbodiedGen}
cd "$REPO"

PASS=(); FAIL=()
step () {  # step "name" cmd...
  local name="$1"; shift
  echo "==================== STEP: $name ===================="
  if "$@"; then echo "[PASS] $name"; PASS+=("$name");
  else echo "[FAIL] $name"; FAIL+=("$name"); fi
}
pipi () { pip install --no-cache-dir "$@"; }

# --- 0. keep ROCm torch from base image (do NOT reinstall cu118 torch) ---
python -c "import torch;print('base torch',torch.__version__,'hip',torch.version.hip,'gpu',torch.cuda.is_available())"

# --- 1. requirements.txt minus CUDA-pinned libs (handled below or via base) ---
EXCLUDE='torch|torchvision|torchaudio|xformers|gsplat|flash.attn|flash-attn|triton|spconv|spconv-cu120|pytorch3d'
grep -vEi "^(${EXCLUDE})([<>=!~;[:space:]]|$)" requirements.txt > /tmp/req_clean.txt
echo "--- cleaned requirements (CUDA libs stripped) ---"; cat /tmp/req_clean.txt
step "requirements(clean)" pipi -r /tmp/req_clean.txt --use-deprecated=legacy-resolver
# numpy: EmbodiedGen's diffusers/transformers REQUIRE numpy<2. (The rocm-lib-compat
# "use docker numpy 2.x" guidance does NOT apply here; the base image ships numpy 2.x.)
step "numpy<2" pipi "numpy<2"
# NOTE: xformers is NOT required -- SAM3D attention auto-selects the `sdpa` backend on
# ROCm, and TRELLIS uses spconv+flash_attn. Skipping xformers avoids the torch 2.9.1
# bump that would break the pytorch3d/gsplat ROCm 6.4 wheels.

# --- 2. ROCm replacements for the CUDA-only generation stack (rocm-lib-compat) ---
# All verified on gfx942 / MI300X, ROCm 6.4 (same stack as VAST-AI/AniGen,
# a TRELLIS-v1 image-to-3D repo, in the rocm3d supported-repo list).
# spconv (CUDA spconv-cu120 -> ZJLi2013/spconv_rocm)
step "spconv_rocm" bash -c '
rm -rf /tmp/spconv_rocm &&
git clone --depth 1 -b rocm https://github.com/ZJLi2013/spconv_rocm.git /tmp/spconv_rocm &&
pip install --no-cache-dir -e /tmp/spconv_rocm'
# nvdiffrast (NVlabs -> ZJLi2013/nvdiffrast@rocm)
step "nvdiffrast_rocm" pipi "git+https://github.com/ZJLi2013/nvdiffrast.git@rocm" --no-build-isolation
# gsplat (-> amd_gsplat prebuilt; import name stays `gsplat`; default gaussian backend)
step "amd_gsplat" pipi amd_gsplat --extra-index-url=https://pypi.amd.com/rocm-6.4.3/simple/
# pytorch3d (-> prebuilt ROCm 6.4 / py3.12 wheel)
step "pytorch3d_rocm" pipi https://github.com/ZJLi2013/pytorch3d/releases/download/rocm6.4-py3.12/pytorch3d-0.7.9-cp312-cp312-linux_x86_64.whl
# flash-attn (FA2 Triton on ROCm 6.4). NOTE: requires FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE
# at BOTH install and runtime; otherwise import falls back to the CUDA `flash_attn_2_cuda`.
step "flash_attn(triton)" bash -c 'FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE pip install --no-cache-dir flash-attn --no-build-isolation'

# --- 3. pure git deps from upstream install_basic.sh (CUDA-agnostic) ---
step "utils3d" pipi "utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15"
step "clip" pipi "clip@git+https://github.com/openai/CLIP.git"
step "segment_anything" pipi "segment-anything@git+https://github.com/facebookresearch/segment-anything.git@dca509f"
step "kolors" pipi "kolors@git+https://github.com/HochCC/Kolors.git"
step "MoGe" pipi "MoGe@git+https://github.com/microsoft/MoGe.git@a8c3734"

# --- 4. OPTIONAL: diff-gaussian-rasterization ('inria' gaussian backend) ---
# img3d's default gaussian backend is gsplat (amd_gsplat, above), and both TRELLIS
# and SAM3D guard the diff_gaussian_rasterization import in try/except, so this is
# optional. The CUDA-clean ROCm source is graphdeco-inria built via expenses/
# gaussian-splatting's ROCm branch (rocm3d supported-repo list). EmbodiedGen wires
# TRELLIS to the *mip-splatting* antialiasing fork, whose ROCm build additionally
# needs `__trap`->abort and a cooperative_groups/reduce.h shim (PR candidate).
# Left out of the default install; uncomment to add the non-AA 'inria' backend:
# touch /opt/rocm/include/device_launch_parameters.h
# step "diff_gaussian_rasterization" bash -c '
# rm -rf /tmp/dgr &&
# git clone https://github.com/graphdeco-inria/diff-gaussian-rasterization /tmp/dgr &&
# cd /tmp/dgr && git submodule update --init --recursive &&
# PYTORCH_ROCM_ARCH=gfx942 pip install --no-cache-dir . --no-build-isolation'

# --- 5. ROCm runtime shims, installed as sitecustomize (run at interpreter startup) ---
# (a) kaolin bypass: kaolin is CUDA-only (no ROCm wheel, setup.py hard-requires nvcc).
# Imported at module top of embodied_gen/data/utils.py but only *used* inside the
# texture-backprojection / mesh-IO stage (kal.io.*.import_mesh, render.materials,
# render.camera) + type refs in thirdparty/sam3d -- none on the core image->3D
# geometry+gaussian path. Stub pattern proven on ZJLi2013/RealWonder (same
# SAM-3D-Objects/kaolin dep). check_tensor-style validators return truthy.
# (b) spconv KRSC->Native weight bridge: SAM3D/TRELLIS checkpoints store sparse-conv
# weights in CUDA spconv's ImplicitGemm KRSC layout (5D [out,k,k,k,in]); spconv_rocm
# falls back to the Native algo (3D [Kvol,in,out], 2D when kvol==1), so load_state_dict
# mismatches. The shim converts on load. (Upstream fix belongs in spconv_rocm.)
SITE=$(python -c "import site;print(site.getsitepackages()[0])")
HERE="$(dirname "$0")"
if cp "$HERE/kaolin_stub.py" "$SITE/kaolin_stub.py" \
&& cp "$HERE/spconv_rocm_compat.py" "$SITE/spconv_rocm_compat.py" \
&& printf 'import kaolin_stub\nimport spconv_rocm_compat\n' > "$SITE/sitecustomize.py"; then
echo "[PASS] rocm-shims -> $SITE/sitecustomize.py (kaolin_stub + spconv_rocm_compat)"; PASS+=("rocm-shims")
else
echo "[FAIL] rocm-shims copy"; FAIL+=("rocm-shims")
fi

# --- 6. import smoke: what actually loads on ROCm (flash-attn needs the env var) ---
echo "==================== IMPORT SMOKE ===================="
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE python - <<'PY'
core = ["torch","spconv","nvdiffrast.torch","gsplat","pytorch3d","flash_attn","trimesh","diffusers"]
optional = ["diff_gaussian_rasterization","kaolin"]
import importlib
def check(m):
    try:
        importlib.import_module(m); print(f"[import OK ] {m}")
    except Exception as e:
        print(f"[import ERR] {m}: {type(e).__name__}: {str(e)[:160]}")
print("-- core --"); [check(m) for m in core]
print("-- optional --"); [check(m) for m in optional]
import torch
print("torch", torch.__version__, "hip", torch.version.hip, "gpu", torch.cuda.is_available())
PY

echo "==================== SUMMARY ===================="
echo "PASS (${#PASS[@]}): ${PASS[*]}"
echo "FAIL (${#FAIL[@]}): ${FAIL[*]}"
echo "INSTALL_ROCM_DONE"
70 changes: 70 additions & 0 deletions docker/kaolin_stub.py
@@ -0,0 +1,70 @@
"""ROCm kaolin bypass for EmbodiedGen (generalized from ZJLi2013/RealWonder).

kaolin is CUDA-only (no ROCm wheel; setup.py hard-requires nvcc). In EmbodiedGen
it is imported at module top of `embodied_gen/data/utils.py` and used only inside
the texture-backprojection / differentiable-render stage (`kal.io.*.import_mesh`,
`kal.render.materials.PBRMaterial`, `kaolin.render.camera.Camera`), plus type-level
references in thirdparty/sam3d. None of it is on the core image->3D geometry+gaussian
generation path (gsplat is the gaussian backend), so stubbing `kaolin` lets img3d-cli
run on ROCm. The texture-baking stage that actually calls these will surface a clear
error instead of crashing every import.

Activation (must run before any `import kaolin`):
- install this file as `sitecustomize.py` (or import it from one) in a directory on PYTHONPATH, or
- `import kaolin_stub` as the very first import of the entrypoint.

This is the ROCm-unblock shim; the upstream-PR-appropriate fix is to make the kaolin
imports in `data/utils.py` lazy/optional (see docs/exp18.md).
"""
import importlib.abc
import importlib.machinery
import sys
import types


class _KaolinStubModule(types.ModuleType):
    def __init__(self, name):
        super().__init__(name)
        self.__file__ = "<kaolin-stub>"
        self.__path__ = []
        self.__spec__ = None

    def __getattr__(self, name):
        # Let dunder lookups (e.g. __file__, __wrapped__, __all__) behave normally.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        # Capitalized -> isinstance-safe stub class (Camera, PBRMaterial, ...).
        if name and name[0].isupper():
            return type(name, (), {})
        # Lowercase -> a no-op callable that also behaves like a submodule.
        stub = _KaolinCallableStub(f"{self.__name__}.{name}")
        return stub


class _KaolinCallableStub(_KaolinStubModule):
    def __call__(self, *args, **kwargs):
        # Return a truthy no-op. The only kaolin calls on the core geometry/mesh path
        # are validators like `kaolin.utils.testing.check_tensor(...)`, used inside
        # `assert torch.is_tensor(x) and check_tensor(...)`, which need a truthy return.
        # Data-returning kaolin calls (kal.io.*.import_mesh, render.*) live in the
        # texture-backprojection stage and will fail fast downstream (documented gap).
        return True


class _KaolinFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    def find_spec(self, name, path=None, target=None):
        if name == "kaolin" or name.startswith("kaolin."):
            return importlib.machinery.ModuleSpec(name, self)
        return None

    def create_module(self, spec):
        return _KaolinStubModule(spec.name)

    def exec_module(self, module):
        pass


if "kaolin" not in sys.modules and not any(
    isinstance(f, _KaolinFinder) for f in sys.meta_path
):
    sys.meta_path.insert(0, _KaolinFinder())
60 changes: 60 additions & 0 deletions docker/spconv_rocm_compat.py
@@ -0,0 +1,60 @@
"""spconv KRSC->Native weight bridge for ROCm.

CUDA spconv 2.3.x defaults to the ImplicitGemm conv algo, whose `SparseConvolution`
weights are stored in KRSC layout: 5D `[out_channels, kd, kh, kw, in_channels]`.
ROCm `spconv_rocm` (2.3.8+rocm1) lacks the implicit-gemm kernels and falls back to
the Native algo, whose weights are 3D `[kernel_volume, in_channels, out_channels]`.

So checkpoints trained on CUDA (e.g. facebook/sam-3d-objects, SAM3D, TRELLIS) fail to
load on ROCm with errors like:
size mismatch ... copying a param with shape [128, 3, 3, 3, 128]
the shape in current model is [27, 128, 128]

This patches `torch.nn.Module.load_state_dict` to transparently convert any 5D KRSC
spconv weight into the 3D Native layout when (and only when) the destination model
parameter is the matching 3D shape. It is a no-op on CUDA / already-native weights.

Activation: import before loading any spconv checkpoint (e.g. via sitecustomize).
This is the ROCm-unblock shim; the upstream-appropriate fix belongs in spconv_rocm
(accept KRSC checkpoints under the Native algo).
"""
import torch

_orig_load_state_dict = torch.nn.Module.load_state_dict


def _krsc_to_native(w: torch.Tensor):
    # KRSC [out, kd, kh, kw, in] -> Native [kd*kh*kw, in, out]; the Native algo squeezes
    # kernel_volume==1 (1x1x1 conv) to 2D [in, out].
    out, kd, kh, kw, inc = w.shape
    kvol = kd * kh * kw
    native = w.permute(1, 2, 3, 4, 0).contiguous().reshape(kvol, inc, out)
    return native, kvol, inc, out


def _patched_load_state_dict(self, state_dict, strict=True, *args, **kwargs):
    # `strict` is forwarded positionally so that any extra positional arguments keep
    # their slots (passing strict=strict ahead of *args would bind it twice).
    try:
        own = self.state_dict()
    except Exception:
        return _orig_load_state_dict(self, state_dict, strict, *args, **kwargs)

    converted = 0
    fixed = dict(state_dict)
    for name, val in state_dict.items():
        tgt = own.get(name)
        if tgt is None or not hasattr(val, "ndim") or val.ndim != 5:
            continue
        native, kvol, inc, out = _krsc_to_native(val)
        if tgt.ndim == 3 and tuple(tgt.shape) == (kvol, inc, out):
            fixed[name] = native
            converted += 1
        elif tgt.ndim == 2 and kvol == 1 and tuple(tgt.shape) == (inc, out):
            fixed[name] = native.reshape(inc, out)
            converted += 1
    if not converted:
        # Nothing to convert: hand the caller's state_dict through untouched so the
        # patch is a true no-op (the dict() copy drops attributes such as _metadata).
        return _orig_load_state_dict(self, state_dict, strict, *args, **kwargs)
    print(f"[spconv-rocm-compat] converted {converted} KRSC->Native conv weights")
    return _orig_load_state_dict(self, fixed, strict, *args, **kwargs)


if getattr(torch.nn.Module.load_state_dict, "__name__", "") != "_patched_load_state_dict":
    torch.nn.Module.load_state_dict = _patched_load_state_dict