From ab7f7fe300ac78b967dee1789ca12295f49c69c6 Mon Sep 17 00:00:00 2001 From: Ruthwik Date: Mon, 27 Apr 2026 14:22:53 -0700 Subject: [PATCH 01/45] initial commit: dataprep step spec --- dimos/learning/dataprep.py | 195 ++++++++++++++++++++++++++++ dimos/learning/dataset.example.yaml | 81 ++++++++++++ dimos/learning/formats/hdf5.py | 35 +++++ dimos/learning/formats/lerobot.py | 35 +++++ dimos/learning/formats/rlds.py | 34 +++++ dimos/learning/spec.py | 147 +++++++++++++++++++++ 6 files changed, 527 insertions(+) create mode 100644 dimos/learning/dataprep.py create mode 100644 dimos/learning/dataset.example.yaml create mode 100644 dimos/learning/formats/hdf5.py create mode 100644 dimos/learning/formats/lerobot.py create mode 100644 dimos/learning/formats/rlds.py create mode 100644 dimos/learning/spec.py diff --git a/dimos/learning/dataprep.py b/dimos/learning/dataprep.py new file mode 100644 index 0000000000..53b1e7c411 --- /dev/null +++ b/dimos/learning/dataprep.py @@ -0,0 +1,195 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset builder/loader for the DimOS Learning Framework. + +Reads a `DatasetSpec` (see `dimos.learning.spec`) and either: + - builds a training-ready dataset on disk in HDF5/RLDS/LeRobot, or + - returns a PyTorch Dataset for training. + +The same spec also drives inference observation construction. + +Workflow: + # 1. 
Record a teleop session (Sam's PR #1708) + dimos --blueprint quest_teleop_xarm7 --record-path session.db + + # 2. Build a training-ready dataset + python -m dimos.learning.dataprep build dataset.yaml + + # 3. Train using the same spec + from dimos.learning.dataprep import load_dataset, load_spec + spec = load_spec("dataset.yaml") + ds = load_dataset(spec) +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterator +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np + +from dimos.learning.spec import ( + DatasetSpec, + Episode, + EpisodeConfig, + FieldRef, + FilterConfig, + OutputConfig, + Sample, +) + +Writer = Callable[[Iterator[Sample], OutputConfig], Path] + +if TYPE_CHECKING: + import torch + + from dimos.memory2.store.sqlite import SqliteStore + + +# ───────────────────────────────────────────────────────────────────────────── +# Spec I/O +# ───────────────────────────────────────────────────────────────────────────── + + +def load_spec(path: str | Path) -> DatasetSpec: + """Load a DatasetSpec from .yaml/.yml/.json (dispatch by extension).""" + raise NotImplementedError + + +def save_spec(spec: DatasetSpec, path: str | Path) -> None: + """Write a DatasetSpec back to .yaml/.yml/.json (round-trip safe).""" + raise NotImplementedError + + +# ───────────────────────────────────────────────────────────────────────────── +# Episode extraction +# ───────────────────────────────────────────────────────────────────────────── + + +def extract_episodes(store: SqliteStore, cfg: EpisodeConfig) -> list[Episode]: + """Extract episode boundaries from the recording per the configured strategy. + + BUTTONS: scan cfg.button_stream for rising edges on cfg.start/save/discard. 
+ State machine: + IDLE --start press--> RECORDING (begin episode) + RECORDING --save press--> IDLE (commit, success=True) + RECORDING --discard press--> IDLE (drop) + RECORDING --start press--> RECORDING (auto-commit, begin new) + session ends mid-episode: always discard + + RANGES: emit one Episode per (start_ts, end_ts) tuple in cfg.ranges. + + WHOLE: emit a single Episode covering the entire recording's time range. + """ + raise NotImplementedError + + +def filter_episodes(eps: list[Episode], cfg: FilterConfig | None) -> list[Episode]: + """Apply success/duration/label whitelist filters. None = pass-through.""" + raise NotImplementedError + + +# ───────────────────────────────────────────────────────────────────────────── +# Stream synchronization (build per-timestep samples) +# ───────────────────────────────────────────────────────────────────────────── + + +def iter_samples( + store: SqliteStore, + episode: Episode, + spec: DatasetSpec, +) -> Iterator[Sample]: + """Yield synced (obs, action) Samples for one episode. + + Walks the anchor stream at sync.rate_hz between episode.start_ts and + episode.end_ts. For each anchor timestamp, pulls the nearest observation/ + action from each configured stream within sync.tolerance_ms. Applies any + declared preprocess (e.g. jpeg_decode for Image, field projection for + JointState). Skips frames where any required stream lacks a sample within + tolerance. + """ + raise NotImplementedError + + +def _resolve_field(msg: Any, ref: FieldRef) -> np.ndarray: + """Pull a single field from a stream message and convert to np.ndarray. + + Applies ref.field projection (attribute access) and ref.preprocess hook + (named transform like jpeg_decode). Returns a numpy array suitable for + inclusion in a Sample. 
+ """ + raise NotImplementedError + + +# ───────────────────────────────────────────────────────────────────────────── +# Public API +# ───────────────────────────────────────────────────────────────────────────── + + +def _get_writer(format_name: str) -> Writer: + """Lazy-import the `write` function for a given format. Avoids loading + heavy deps (h5py, tfds, lerobot) for unused formats.""" + if format_name == "lerobot": + from dimos.learning.formats.lerobot import write + elif format_name == "hdf5": + from dimos.learning.formats.hdf5 import write + elif format_name == "rlds": + from dimos.learning.formats.rlds import write + else: + raise ValueError( + f"Unknown dataset format: {format_name!r}. Supported: lerobot, hdf5, rlds." + ) + return write + + +def build_dataset(spec: DatasetSpec) -> Path: + """End-to-end: raw session.db -> on-disk dataset in spec.output.format. + + Returns the path written. Requires spec.output to be set. Dispatches to + the appropriate writer in `dimos.learning.formats` via `_get_writer`. + """ + raise NotImplementedError + + +def load_dataset(spec: DatasetSpec) -> torch.utils.data.Dataset[Sample]: + """Training-time loader: returns a PyTorch Dataset over the source recording. + + Materializes Samples on the fly (lazy). Does not require spec.output. + Pre-extracts episodes once and indexes anchor timestamps for O(1) __getitem__. + """ + raise NotImplementedError + + +def inspect(spec: DatasetSpec) -> dict[str, Any]: + """Stats for a session: episode count, duration distribution, per-stream counts. + + Used by `python -m dimos.learning.dataset inspect`. 
+ """ + raise NotImplementedError + + +# ───────────────────────────────────────────────────────────────────────────── +# CLI +# ───────────────────────────────────────────────────────────────────────────── + + +def main() -> None: + """CLI entrypoint: build / inspect a dataset spec.""" + raise NotImplementedError + + +if __name__ == "__main__": + main() diff --git a/dimos/learning/dataset.example.yaml b/dimos/learning/dataset.example.yaml new file mode 100644 index 0000000000..df01cb03f7 --- /dev/null +++ b/dimos/learning/dataset.example.yaml @@ -0,0 +1,81 @@ +# DimOS Learning — DatasetSpec template +# +# This file is the contract between data collection and training. Same spec is +# used by `python -m dimos.learning.dataset build` (export to disk) and by +# `load_dataset(spec)` (training-time PyTorch Dataset). +# +# Stream names below must match the topic names recorded by RecordReplay. To +# discover them: `python -m dimos.learning.dataset inspect dataset.yaml` +# (or look in the SQLite registry of session.db). + +# ─── Source recording ──────────────────────────────────────────────────────── +source: ./session.db + +# ─── How to slice the session into episodes ────────────────────────────────── +episodes: + extractor: buttons # buttons | ranges | whole_session + + # BUTTONS extractor: state machine over the recorded button stream + button_stream: buttons # stream name (matches LCM topic, sanitized) + start: A # rising edge -> begin episode + save: B # rising edge -> end + save + discard: X # rising edge -> end + drop + # Note: if recording stops mid-episode without an explicit save/discard, + # the in-progress episode is always discarded. 
+ + # RANGES extractor: explicit absolute timestamps (only used when extractor: ranges) + # ranges: + # - [1730000000.0, 1730000045.5] + # - [1730000060.0, 1730000110.2] + + default_task_label: pick_red_cube # optional; applied to every extracted episode + +# ─── What goes into each timestep ──────────────────────────────────────────── +# Each entry: dataset_key -> { stream, type?, field?, preprocess? } +# stream — recorded stream name (LCM topic, sanitized) +# type — optional dotted message type (for codec dispatch) +# field — attribute on the message; omit to keep the whole message +# preprocess — named transform applied after field projection + +observation: + image: + stream: camera_color_image + type: sensor_msgs.Image + preprocess: jpeg_decode # raw JPEG bytes -> HxWx3 uint8 + joint_pos: + stream: coordinator_joint_state + type: sensor_msgs.JointState + field: position + joint_vel: + stream: coordinator_joint_state + type: sensor_msgs.JointState + field: velocity + +action: + target_pos: + stream: coordinator_joint_command + type: sensor_msgs.JointState + field: position + +# ─── Synchronization (build per-timestep samples) ──────────────────────────── +sync: + anchor: image # which observation key drives the timeline + rate_hz: 30 # downsample anchor to this rate; 0 = native + tolerance_ms: 50 # max time delta when picking nearest sample + strategy: nearest # nearest | interp + +# ─── Per-episode filters (optional) ────────────────────────────────────────── +filters: + success_only: true + min_duration_s: 1.0 + # max_duration_s: 60.0 + # task_labels: [pick_red_cube, pick_blue_cube] + +# ─── Output (only required when calling build_dataset / `... 
build`) ───────── +output: + format: lerobot # lerobot | hdf5 | rlds + path: ./datasets/pick_red/ + metadata: + fps: 30 + robot: xarm7 + task_label: pick_red_cube diff --git a/dimos/learning/formats/hdf5.py b/dimos/learning/formats/hdf5.py new file mode 100644 index 0000000000..b40324fd02 --- /dev/null +++ b/dimos/learning/formats/hdf5.py @@ -0,0 +1,35 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HDF5 dataset writer. + +Produces a single .hdf5 file with one group per episode: + /data/episode_000000/ + observation/ # one dataset per observation key (T, ...) + action/ # one dataset per action key (T, ...) + ts # timestamps (T,) + /metadata # JSON-encoded spec + per-episode tags +""" + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path + +from dimos.learning.spec import OutputConfig, Sample + + +def write(samples: Iterator[Sample], output: OutputConfig) -> Path: + """Write samples to a single HDF5 file. Returns the file path.""" + raise NotImplementedError diff --git a/dimos/learning/formats/lerobot.py b/dimos/learning/formats/lerobot.py new file mode 100644 index 0000000000..bb82161e3c --- /dev/null +++ b/dimos/learning/formats/lerobot.py @@ -0,0 +1,35 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LeRobot v2 dataset writer. + +Produces a directory layout compatible with HuggingFace LeRobot: + <root>/ + meta/info.json # schema, fps, total episodes/frames + meta/episodes.jsonl # per-episode metadata (length, task) + data/chunk-000/episode_000000.parquet # tabular obs+action + videos/chunk-000/<camera_key>/episode_000000.mp4 # encoded image streams +""" + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path + +from dimos.learning.spec import OutputConfig, Sample + + +def write(samples: Iterator[Sample], output: OutputConfig) -> Path: + """Write samples in LeRobot v2 layout. Returns the dataset root path.""" + raise NotImplementedError diff --git a/dimos/learning/formats/rlds.py b/dimos/learning/formats/rlds.py new file mode 100644 index 0000000000..3fbf2d40db --- /dev/null +++ b/dimos/learning/formats/rlds.py @@ -0,0 +1,34 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RLDS / TFDS dataset writer. 
+ +Produces a TFDS-on-disk layout (TFRecord shards + dataset_info.json) following +the RLDS Episode/Step protocol used by Open X-Embodiment, RT-X, etc. + +Each TF Example encodes one Episode as a sequence of Steps with: + observation/, action/, reward, discount, is_first, is_last, is_terminal +""" + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path + +from dimos.learning.spec import OutputConfig, Sample + + +def write(samples: Iterator[Sample], output: OutputConfig) -> Path: + """Write samples as TFDS/RLDS shards. Returns the dataset directory path.""" + raise NotImplementedError diff --git a/dimos/learning/spec.py b/dimos/learning/spec.py new file mode 100644 index 0000000000..740de364e3 --- /dev/null +++ b/dimos/learning/spec.py @@ -0,0 +1,147 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data definitions for the DimOS Learning Framework. + +Contains the YAML/JSON-backed `DatasetSpec` schema and the runtime data +classes (`Episode`, `Sample`) shared between collection, training, and +inference. No logic — just typed records and constants. Safe to import +from anywhere (no circular dependencies). 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field + +# ───────────────────────────────────────────────────────────────────────────── +# DatasetSpec — the YAML/JSON schema +# ───────────────────────────────────────────────────────────────────────────── + + +class EpisodeConfig(BaseModel): + """How to slice the continuous recording into episodes.""" + + extractor: Literal["buttons", "ranges", "whole_session"] = "buttons" + + # BUTTONS extractor: friendly names map to Quest Buttons attrs via BUTTON_ALIASES. + # The state machine always discards an in-progress episode if the recording ends + # without an explicit save/discard press. + button_stream: str = "buttons" + start: str = "A" # rising edge -> begin episode + save: str = "B" # rising edge -> end + save + discard: str = "X" # rising edge -> end + drop + + # RANGES extractor: explicit absolute timestamps + ranges: list[tuple[float, float]] | None = None + + # Optional default label applied to every extracted episode + default_task_label: str | None = None + + +class FieldRef(BaseModel): + """Pointer to a field in a recorded stream.""" + + stream: str # LCM stream / topic name as recorded in session.db + type: str | None = None # optional dotted type (e.g. "sensor_msgs.Image"); for codec dispatch + field: str | None = None # attribute on the message; None = whole message + preprocess: str | None = None # named preprocess hook (e.g. 
"jpeg_decode", "normalize_image") + + +class SyncConfig(BaseModel): + """How to build per-timestep samples by aligning multiple streams.""" + + anchor: str # key in `observation` that drives the timeline + rate_hz: float = 30.0 # downsample anchor to this rate; 0 = use anchor's native rate + tolerance_ms: float = 50.0 # max allowed time delta when picking nearest sample + strategy: Literal["nearest", "interp"] = "nearest" + + +class FilterConfig(BaseModel): + """Per-episode filters applied after extraction.""" + + success_only: bool = True + min_duration_s: float = 0.0 + max_duration_s: float | None = None + task_labels: list[str] | None = None # whitelist; None = all + + +class OutputConfig(BaseModel): + """Where and how to write the built dataset.""" + + format: Literal["lerobot", "hdf5", "rlds"] + path: Path + metadata: dict[str, Any] = Field(default_factory=dict) + + +class DatasetSpec(BaseModel): + """Top-level spec. Same instance used at build, load, and inference time.""" + + source: Path # path to session.db produced by RecordReplay + episodes: EpisodeConfig + observation: dict[str, FieldRef] # obs key -> stream field + action: dict[str, FieldRef] # action key -> stream field + sync: SyncConfig + filters: FilterConfig | None = None + output: OutputConfig | None = None # only required by build_dataset() + + +# ───────────────────────────────────────────────────────────────────────────── +# Runtime data +# ───────────────────────────────────────────────────────────────────────────── + + +# Friendly Quest controller names -> Buttons attribute names. +# Override by supplying an attribute name directly in the spec. 
+BUTTON_ALIASES: dict[str, str] = { + "A": "right_primary", + "B": "right_secondary", + "X": "left_primary", + "Y": "left_secondary", + "LT": "left_trigger", + "RT": "right_trigger", + "LG": "left_grip", + "RG": "right_grip", + "MENU_L": "left_menu", + "MENU_R": "right_menu", +} + + +class Episode(BaseModel): + """A single demonstration carved from a session.""" + + id: str + start_ts: float + end_ts: float + task_label: str | None = None + success: bool = True + metadata: dict[str, Any] = Field(default_factory=dict) + + @property + def duration(self) -> float: + return self.end_ts - self.start_ts + + +class Sample(BaseModel): + """One synchronized timestep: aligned obs + action at ts.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + ts: float + episode_id: str + observation: dict[str, np.ndarray] + action: dict[str, np.ndarray] From 7c9f53540d87f1f2c348bdaf7f80ede5aff65f30 Mon Sep 17 00:00:00 2001 From: Ruthwik Date: Wed, 29 Apr 2026 20:04:18 -0700 Subject: [PATCH 02/45] temp: learning spec files --- dimos/learning/PLAN.md | 645 ++++++++++++++++++ dimos/learning/collection/blueprint.py | 117 ++++ dimos/learning/collection/episode_monitor.py | 116 ++++ dimos/learning/dataprep.py | 289 ++++---- dimos/learning/dataset.example.yaml | 13 +- dimos/learning/inference/action_replayer.py | 134 ++++ dimos/learning/inference/blueprint.py | 136 ++++ .../learning/inference/chunk_policy_module.py | 156 +++++ dimos/learning/inference/obs_builder.py | 72 ++ dimos/learning/policy/base.py | 93 +++ dimos/learning/policy/lerobot_policy.py | 106 +++ dimos/learning/spec.py | 64 +- dimos/learning/training/blueprint.py | 103 +++ dimos/learning/training/configs.py | 94 +++ dimos/learning/training/monitor_module.py | 89 +++ dimos/learning/training/split.py | 41 ++ dimos/learning/training/stats.py | 108 +++ dimos/learning/training/train.py | 129 ++++ dimos/learning/training/trainer_module.py | 263 +++++++ 19 files changed, 2630 insertions(+), 138 deletions(-) create mode 
100644 dimos/learning/PLAN.md create mode 100644 dimos/learning/collection/blueprint.py create mode 100644 dimos/learning/collection/episode_monitor.py create mode 100644 dimos/learning/inference/action_replayer.py create mode 100644 dimos/learning/inference/blueprint.py create mode 100644 dimos/learning/inference/chunk_policy_module.py create mode 100644 dimos/learning/inference/obs_builder.py create mode 100644 dimos/learning/policy/base.py create mode 100644 dimos/learning/policy/lerobot_policy.py create mode 100644 dimos/learning/training/blueprint.py create mode 100644 dimos/learning/training/configs.py create mode 100644 dimos/learning/training/monitor_module.py create mode 100644 dimos/learning/training/split.py create mode 100644 dimos/learning/training/stats.py create mode 100644 dimos/learning/training/train.py create mode 100644 dimos/learning/training/trainer_module.py diff --git a/dimos/learning/PLAN.md b/dimos/learning/PLAN.md new file mode 100644 index 0000000000..0dab9571df --- /dev/null +++ b/dimos/learning/PLAN.md @@ -0,0 +1,645 @@ +# DimOS Learning Framework — v1 Plan + +## v1 Scope + +**Goal:** end-to-end pipeline that lets a DimOS user collect teleop demos, train a policy, and run it on a real arm — for two concrete targets: + +1. **BC / ACT** — train ACT (Action Chunking Transformer) on a pick-and-place demo set on xArm7. +2. **VLA finetune** — finetune a pretrained π₀ / π₀.₅ checkpoint on the same demo set. + +Both targets share a single `DatasetSpec`, a single LeRobot dataset on disk, and a single inference module. The choice between ACT and a VLA is just a different training entry point and a different policy class at inference time. + +**v1 architectural mandate — fully DimOS-native:** + +Every phase of the pipeline (collection, training, inference) is exposed as a **Module + Blueprint** with RPC surfaces. There is **one user-facing UX**: `dimos --blueprint <name>` for everything. Agent skills can drive any phase via @rpc. 
Composition between phases is just port wiring (`builder.done → trainer.builder_done`). + +For collection and training — where the actual work is offline batch processing — the Module is an **orchestrator over a subprocess**: it spawns `python -m dimos.learning.dataprep build` or `python -m ...training.train`, parses its progress lines, and republishes them as typed events. The work itself stays in the subprocess (heavy deps isolated, process-cancellable, separately testable). Inference Modules do real live work because there is real live data flow. + +Every Module exposes: +- `@rpc start()` / `@rpc stop()` — lifecycle +- `@rpc <action>(...)` — at least one agent-callable action +- `@rpc get_status()` — observability +- typed `In[...]` / `Out[...]` ports for blueprint composition + +**Out of v1 (deferred to v2 — see bottom of file):** +- RL (online + offline) +- Pure proprioceptive policies in the 100 Hz tick loop (`PolicyControlTask`) +- Multi-embodiment / cross-task training +- Distributed / multi-GPU training +- Live recording of episode boundaries (we keep post-hoc button extraction) + +**Design principle:** lean on existing infrastructure — `RecordReplay` (PR #1708), memory2, the Module/Blueprint system, **and** the `lerobot` library which already implements ACT and π₀/π₀.₅ end-to-end. We add only the DimOS glue. + +--- + +## Architecture Overview + +``` + COLLECT TRAIN INFER + ─────── ───── ───── + Teleop + Camera load_dataset(spec) ChunkPolicyModule (1–30 Hz) + ↓ ↓ ↓ (action chunks) + RecordReplay --record-path train_bc / finetune_vla ActionReplayer (100 Hz) + ↓ ↓ ↓ + session.db checkpoint (.safetensors) Coordinator joint_command + ↓ + stats.json ↓ + dataset.yaml + dataprep.py Hardware + ↓ + LeRobot v2 dataset on disk +``` + +The **same `dataset.yaml`** is the contract between collection (export), training (load), and inference (live obs construction). 
It is the single source of truth for what counts as observation/action, episode boundaries, sync strategy, and feature shapes. + +--- + +## 1. Data Collection — STATUS: IMPLEMENTED, NEEDS POLISH + +### 1.1 Recording (no new code) + +Use Sam's `RecordReplay` from PR #1708: + +```bash +dimos --blueprint quest_teleop_xarm7 --record-path session.db +``` + +Captures every LCM topic — joint states, joint commands, camera images, controller `Buttons`, IMU, etc. One stream per topic in a single `SqliteStore`. Episode boundaries are not marked at record time; they are recovered offline from button presses. + +### 1.2 The dataset spec + +Implemented in `dimos/learning/spec.py`. Schema is pydantic v2 → round-trips YAML/JSON. Eight typed classes: `EpisodeConfig`, `FieldRef`, `SyncConfig`, `FilterConfig`, `OutputConfig`, `DatasetSpec`, `Episode`, `Sample`. Friendly Quest button names (`A`/`B`/`X`/...) resolve to `Buttons` bit fields via `BUTTON_ALIASES`. + +Example (see `dimos/learning/dataset.example.yaml` for the live template): + +```yaml +source: session.db + +episodes: + extractor: buttons # buttons | ranges | whole_session + start: A # press to begin + save: B # press to commit + discard: X # press to drop + default_task_label: pick_red_cube + +observation: + cam_high: + stream: camera_color_image + preprocess: jpeg_decode + cam_wrist: + stream: camera_wrist_color_image + preprocess: jpeg_decode + joint_pos: + stream: coordinator_joint_state + field: position + +action: + joint_target: + stream: coordinator_joint_command + field: position + +sync: + anchor: cam_high + rate_hz: 30 + tolerance_ms: 50 + strategy: nearest + +filters: + success_only: true + min_duration_s: 1.0 + task_labels: [pick_red_cube] + +output: + format: lerobot # primary v1 target + path: datasets/pick_red/ + metadata: + fps: 30 + robot: xarm7 +``` + +### 1.3 The pipeline file: `dimos/learning/dataprep.py` + +Implemented. 
Does everything: read raw `session.db`, extract episodes from button events, sync streams, dispatch to the chosen format writer. Same module exposes `load_dataset(spec)` for training. + +Public functions (all done): +- `load_spec(path)` / `save_spec(spec, path)` — YAML/JSON I/O +- `extract_episodes(store, cfg)` — three strategies (buttons/ranges/whole_session) +- `filter_episodes(eps, cfg)` — success / duration / label whitelist +- `iter_samples(store, episode, spec)` — anchor-rate timestep walker w/ bisect nearest-search +- `build_dataset(spec)` — full session.db → on-disk dataset +- `load_dataset(spec)` — returns a `torch.utils.data.Dataset[Sample]` +- `inspect(spec)` — episode/duration/per-stream stats +- `main()` — CLI: `build` / `inspect` / `review` (review is a stub) + +### 1.4 Format writers (in `dimos/learning/formats/`) + +| Format | v1 priority | Status | Why | +|-----------|-------------|--------|-----| +| `lerobot` | **primary** | done | Native input for both ACT and π₀/π₀.₅ via `lerobot` lib | +| `hdf5` | secondary | done | ACT-original codebase, debugging, smaller deps | +| `rlds` | v2 | done (gated on TF) | RT-X / OpenX-Embodiment compat — not needed for v1 | + +### 1.5 Gaps to close in v1 + +The collection pipeline is functional, but a few things are needed before LeRobot training works cleanly. Each is small. + +**(a) Per-episode task description** — π₀ is language-conditioned; LeRobot v2 has a `tasks.jsonl` table. Currently `Episode.task_label: str | None` is a tag; we need a free-form string per episode. Extend with: +```python +class Episode: + task_description: str | None = None # e.g. "pick up the red cube and place it on the blue plate" +``` +The LeRobot writer already emits `tasks.jsonl` keyed on `task_label` — switch it to use `task_description` (fall back to `task_label`). Population: `EpisodeConfig.default_task_description` for single-task sessions, or set per episode in the `review` CLI. 
+ +**(b) Dataset statistics** — LeRobot training requires `meta/stats.json` (per-feature mean/std/min/max/q01/q99). Add a streaming stats accumulator inside `formats/lerobot.py::write` so we don't need a second pass over the data. Image stats are computed on a subsample (every Nth frame) to bound cost. + +**(c) Train/val split** — LeRobot v2 supports filtering by episode index at training time, so we don't need to materialize two datasets. Add `FilterConfig.val_episode_ids: list[int] | None` and `FilterConfig.val_ratio: float | None` (deterministic seeded split). Trainer reads these. + +**(d) Image format on disk** — LeRobot v2 stores images as MP4 videos by default (`videos/chunk-NNN/<camera_key>/episode_NNNNNN.mp4`). Current writer writes them as parquet tensor columns, which works but inflates disk size. Switch to MP4 encoding via `imageio[ffmpeg]` for image streams ≥2D + uint8. Parquet cells then store frame indices, not pixels. + +**(e) `review` CLI** — currently a stub. Implement a minimal non-interactive form first: load spec, list episodes with metadata, allow batch retag via `--set-label PICK_RED --episode-ids 0,1,2,5`. Interactive TUI is v2. + +These five items + the existing skeleton complete the collection side for v1. + +--- + +## 2. Training — NEW v1 WORK + +### 2.1 Strategy: thin wrappers around `lerobot` + +The `lerobot` library (HuggingFace + Tesla-PI fork) already implements ACT, Diffusion Policy, π₀, π₀.₅ — including dataloaders for the LeRobot v2 format, normalization, action chunking, language tokenization, training loop, checkpointing, and ONNX export. + +**We do NOT reimplement these.** The v1 training pipeline is two thin Python wrappers that: +1. Take a DimOS `DatasetSpec`, +2. Translate it into a LeRobot config, +3. Invoke `lerobot.scripts.train.train()`, +4. Save the resulting checkpoint to a path that the inference module knows how to read. 
+ +This keeps `dimos/learning/training/` short, rides on a maintained upstream, and means a `pi0.5` upgrade is a config bump rather than a code change. + +### 2.2 File layout: `dimos/learning/training/` + +``` +dimos/learning/training/ + train.py # train_bc, finetune_vla — public entry points + configs.py # BCConfig, VLAConfig + stats.py # compute_stats(spec) -> dict (used by build_dataset too) + split.py # train/val episode split helper +``` + +### 2.3 The two entry points + +```python +def train_bc(spec: DatasetSpec, cfg: BCConfig, output_dir: Path) -> Path: + """Train an ACT (or other BC) policy on `spec`. Returns checkpoint path.""" + +def finetune_vla(spec: DatasetSpec, cfg: VLAConfig, output_dir: Path) -> Path: + """Finetune a pretrained π₀ / π₀.₅ on `spec`. Returns checkpoint path.""" +``` + +Both: +- Materialize the dataset via `build_dataset(spec)` if `spec.output.path` doesn't already exist (idempotent). +- Build a `lerobot.LeRobotDataset(spec.output.path)`. +- Build a LeRobot policy from `cfg`. +- Call the LeRobot training loop with `cfg.steps`, `cfg.batch_size`, `cfg.lr`, etc. +- Save final checkpoint + a sidecar `dimos_meta.json` with `{spec_path, dataset_path, dimos_version}` so inference can recover everything. 
+ +### 2.4 `BCConfig` (ACT-focused for v1) + +```python +class BCConfig(BaseModel): + policy_type: Literal["act", "diffusion"] = "act" + + # ACT model arch — defaults match the original ACT pick-and-place setup + chunk_size: int = 50 # action_horizon + n_obs_steps: int = 1 + hidden_dim: int = 512 + n_layers: int = 4 + n_heads: int = 8 + use_vae: bool = True + kl_weight: float = 10.0 + + # Vision backbone + vision_backbone: str = "resnet18" + pretrained: bool = True + + # Optim + steps: int = 100_000 + batch_size: int = 8 + lr: float = 1e-5 + lr_backbone: float = 1e-5 + weight_decay: float = 1e-4 + + # Eval + val_ratio: float = 0.1 + save_every: int = 10_000 +``` + +### 2.5 `VLAConfig` (π₀ / π₀.₅ finetune) + +```python +class VLAConfig(BaseModel): + policy_type: Literal["pi0", "pi0_5"] = "pi0_5" + pretrained_path: str # HF hub id or local path + finetune_mode: Literal["full", "lora"] = "lora" + lora_rank: int = 16 + freeze_vision: bool = True + freeze_language: bool = True + + chunk_size: int = 50 # default π₀ action horizon + + steps: int = 30_000 + batch_size: int = 4 + lr: float = 5e-5 + weight_decay: float = 1e-4 + save_every: int = 5_000 + + # The spec's task_description per episode is the language conditioning at train time. + # No additional config needed. +``` + +### 2.6 Stats and split — pulled out so they're reusable + +`stats.compute_stats(spec)` walks the materialized dataset once, accumulating Welford mean/std for joint vectors and per-channel image stats on a subsample. Writes `meta/stats.json`. Called from both `build_dataset` (so the disk-resident dataset is self-describing) and `train_bc` / `finetune_vla` (idempotent — skip if `stats.json` already exists). + +`split.train_val_split(spec, val_ratio, seed=0)` returns two episode-id lists. Deterministic. Trainer passes these to LeRobot via its episode filter. 
+ +### 2.7 CLI + +```bash +# Train ACT on a built dataset +python -m dimos.learning.training.train bc dataset.yaml \ + --output runs/act_pick_red \ + --steps 100000 --batch-size 8 + +# Finetune π₀.₅ +python -m dimos.learning.training.train vla dataset.yaml \ + --output runs/pi05_pick_red \ + --pretrained lerobot/pi0_5 \ + --finetune-mode lora --lora-rank 16 +``` + +The CLI is a tiny argparse wrapper that builds `BCConfig`/`VLAConfig` and calls the function. + +### 2.8 Dependencies + +Adds to `pyproject.toml` (under a `[project.optional-dependencies]` `learning` extra so default installs aren't bloated): +- `lerobot >= 0.2` +- `torch >= 2.3` (already implied by `lerobot`) +- `imageio[ffmpeg]` (MP4 image encoding) +- For VLA only: `transformers`, `accelerate`, `peft` (LoRA) + +User installs with `pip install -e .[learning]`. + +--- + +## 3. Inference — NEW v1 WORK + +### 3.1 The two paths, simplified for v1 + +For v1 we need exactly **one** inference module: `ChunkPolicyModule`. Both ACT and π₀/π₀.₅ produce action chunks (sequences of length `chunk_size`), so the same module handles them. The model runs slow (1–30 Hz depending on whether it's ACT or VLA); a separate `ActionReplayer` plays the chunk back at the coordinator's 100 Hz tick rate. + +``` + ┌──────────────────────────┐ + │ ChunkPolicyModule │ ← runs in its own thread/process at policy rate + │ In: color_image │ + │ joint_state │ + │ language_text │ + │ Out: action_chunk │ + └────────────┬─────────────┘ + │ (T, action_dim) + ▼ + ┌──────────────────────────┐ + │ ActionReplayer │ ← part of the ControlTask graph + │ ControlTask @ 100 Hz │ + │ pops next action, │ + │ emits JointCommand │ + └──────────────────────────┘ +``` + +`PolicyControlTask` (joint-only, in-tick-loop) is **deferred to v2** — it's only useful for proprioceptive policies, which we don't train in v1. 
+ +### 3.2 File layout: `dimos/learning/inference/` + +``` +dimos/learning/inference/ + chunk_policy_module.py # ChunkPolicyModule + action_replayer.py # ActionReplayer (subclass of BaseControlTask) + obs_builder.py # spec.observation -> live obs dict, decoupled + blueprints.py # autoconnect helpers +``` + +### 3.3 `dimos/learning/policy/` — Policy protocol + +```python +class Policy(Protocol): + @classmethod + def load(cls, path: Path, device: str = "cuda") -> "Policy": ... + def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: + """Return shape (chunk_size, action_dim).""" + +class LeRobotPolicy(Policy): + """Wraps any lerobot.PreTrainedPolicy (ACT, π₀, π₀.₅, Diffusion). + Detects policy_type from the checkpoint metadata.""" +``` + +v1 ships `LeRobotPolicy` only. `OnnxPolicy`, custom `TorchPolicy` are v2. + +### 3.4 `ChunkPolicyModule` skeleton + +```python +class ChunkPolicyModule(Module): + color_image: In[Image] + joint_state: In[JointState] + language_text: In[str] # optional; ignored if policy doesn't use it + + action_chunk: Out[ActionChunk] # new typed message: ts, joint_names, positions[T, N] + + def __init__(self, *, spec_path: str, policy_path: str, + inference_rate_hz: float, device: str = "cuda"): + self._spec = load_spec(spec_path) + self._policy = LeRobotPolicy.load(Path(policy_path), device) + self._obs_builder = ObsBuilder(self._spec) + self._stats = load_stats(Path(policy_path).parent / "stats.json") + + @rate_limited(inference_rate_hz) + def on_tick(self): + obs = self._obs_builder.build( + color_image=self.color_image.latest(), + joint_state=self.joint_state.latest(), + language=self.language_text.latest_or("default task"), + ) + chunk = self._policy.predict_chunk(self._stats.normalize(obs)) + chunk = self._stats.unnormalize_actions(chunk) + self.action_chunk.publish(ActionChunk(positions=chunk, ts=time.time(), ...)) +``` + +### 3.5 `ActionReplayer` — a `BaseControlTask` + +Lives in the tick loop. Subscribes to `action_chunk`. 
Maintains a buffer of pending actions with their relative timestamps; `compute(state)` interpolates to the current tick time. + +```python +class ActionReplayer(BaseControlTask): + name = "policy_replay" + + def __init__(self, joint_names: list[str], chunk_topic: str, policy_dt: float): + self._joint_names = joint_names + self._buffer: deque[tuple[float, np.ndarray]] = deque() # (target_ts, positions) + ... + + def on_action_chunk(self, msg: ActionChunk) -> None: + # Push new chunk, drop overlap with current buffer (latest chunk wins) + ... + + def compute(self, state: CoordinatorState) -> JointCommandOutput | None: + if not self._buffer: + return None + target = self._lookup_or_interp(state.now) + return JointCommandOutput(joint_names=self._joint_names, positions=target) +``` + +Key behaviors: +- **Latest chunk wins**: when a new chunk arrives, drop any buffered actions ≥ the new chunk's start time. Smooth (no gap) because the new chunk's first action is conditioned on near-current obs. +- **Lookback safety**: if the policy module stalls and the buffer empties, hold last commanded position (don't fall to zero). Log a warning. +- **Temporal ensembling (optional v1 nice-to-have)**: when ACT publishes overlapping chunks, exponentially weight predictions for the same target timestamp. Off by default. + +### 3.6 Live observation construction + +`ObsBuilder` reads `spec.observation` and exposes `build(**live_streams) -> dict[str, np.ndarray]`. Reuses the same `_resolve_field` and `preprocess` registry from `dataprep.py` so train and infer share normalization. This is the single most important consistency guarantee in the framework. 
+ +### 3.7 Inference blueprint + +```python +# Vision policy (ACT or π₀.₅) → Coordinator +pick_red_cube = autoconnect( + RealSenseCamera.blueprint(camera_id="cam_high"), + ChunkPolicyModule.blueprint( + spec_path="datasets/pick_red/dataset.yaml", + policy_path="runs/pi05_pick_red/checkpoint", + inference_rate_hz=5.0, + ), + ControlCoordinator.blueprint( + hardware=[xarm7], + tasks=[ + TaskConfig(name="policy_replay", type="action_replayer", + chunk_topic="action_chunk", policy_dt=0.2), + ], + ), +) +``` + +### 3.8 Coordinator change (one line) + +In `dimos/control/coordinator.py::_create_task_from_config` add a case for `type == "action_replayer"` that constructs the `ActionReplayer`. Plus a new `ActionReplayerConfig` in `dimos/control/task.py`'s task config union. Same pattern as existing tasks. + +--- + +## 4. File Structure (v1) + +``` +dimos/learning/ + spec.py # ✅ DatasetSpec (DataPrep-callable types) + dataprep.py # ✅ DataPrep class + CLI + dataset.example.yaml # ✅ + + formats/ + lerobot.py # ✅ (needs MP4 + stats — §1.5) + hdf5.py # ✅ + rlds.py # ✅ (v2 priority) + + collection/ # 🆕 v1 — collection blueprint + episode_monitor.py # EpisodeMonitorModule (live counters via @rpc) + blueprint.py # collection_blueprint(...) + + training/ # 🆕 v1 — training scripts + orchestrator Modules + configs.py # BCConfig, VLAConfig [script] + stats.py # Stats class + compute_stats [script] + split.py # train_val_split [script] + train.py # train_bc, finetune_vla, CLI [script] + trainer_module.py # TrainerModule — wraps build+train [Module] + monitor_module.py # LearningMonitorModule (rerun + JSONL) [Module] + blueprint.py # learning_train_{act,vla,idle} + + inference/ # 🆕 v1 — inference blueprint (real live Modules) + obs_builder.py # ObsBuilder (uses DataPrep.resolve_field) + chunk_policy_module.py # ChunkPolicyModule (real Module) + action_replayer.py # ActionReplayer (BaseControlTask) + blueprint.py # policy_blueprint(...) 
+ + policy/ # 🆕 v1 — Policy abstraction + base.py # Policy protocol + ActionChunk message + lerobot_policy.py # LeRobotPolicy (ACT, π₀, π₀.₅) +``` + +**Note on script-vs-Module split inside `training/`:** The four `[script]` +files (`configs.py`, `stats.py`, `split.py`, `train.py`) hold the actual +training logic and are independently usable from notebooks/CI/tests. The +three `[Module]` files wrap them as DimOS Modules that spawn the scripts +as subprocesses — that's the dual-surface UX the v1 mandate requires. + +Critical files outside `dimos/learning/`: + +| File | Change | +|------|--------| +| `dimos/control/coordinator.py` | Add `"action_replayer"` case in `_create_task_from_config` | +| `dimos/control/task.py` | Add `ActionReplayerConfig` to task config union | +| `pyproject.toml` | Add `[project.optional-dependencies].learning` extra | +| `dimos/messages/` (or wherever DimOS LCM types live) | New `ActionChunk` type: `(joint_names, positions[T,N], ts, dt)` | + +--- + +## 5. End-to-End Demo Recipe + +Two flows, one command list each. Both assume `pip install -e .[learning]` is done. + +### 5.1 ACT pick-and-place on xArm7 — blueprint-first UX + +Each phase is `dimos --blueprint <name>`. Underlying scripts (`python -m +dimos.learning.dataprep`, `python -m dimos.learning.training.train`) are +still callable directly for CI / notebooks / debugging — but the +default flow is the blueprint surface. + +```bash +# 1. Collect — teleop + camera + RecordReplay + EpisodeMonitorModule +dimos --blueprint learning_collect_quest_xarm7 --record-path data/pick_red.db +# (operator presses A=start / B=save / X=discard; +# EpisodeMonitorModule.status streams "episodes_saved: N" live) + +# 2.
Train — DatasetBuilderModule + TrainerModule + LearningMonitorModule +dimos --blueprint learning_train_act \ + --spec dataset.yaml --output runs/act_pick_red +# Inside: builder runs first (subprocess: dataprep build), trainer +# auto-fires on builder.done (subprocess: train bc), monitor logs to rerun. + +# 3. Infer — Camera + ChunkPolicyModule + ActionReplayer + Coordinator +dimos --blueprint learning_infer_pick_red \ + --policy-path runs/act_pick_red +``` + +### 5.2 π₀.₅ finetune on the same data + +Steps 1 + 3 unchanged. Step 2 is the same `learning_train_*` blueprint +with `--kind vla` and a `--pretrained` flag — agent or human just changes +the trigger payload, not the blueprint. + +```bash +dimos --blueprint learning_train_vla \ + --spec dataset.yaml --output runs/pi05_pick_red \ + --pretrained lerobot/pi0_5 --finetune-mode lora --lora-rank 16 +``` + +### 5.3 Agent-driven flow (same Modules, no `auto_run`) + +Demonstrates the @rpc surface. Run a single training blueprint with auto-run +disabled; a chat agent then drives every phase: + +```bash +dimos --blueprint learning_train_idle # builder + trainer + monitor, all idle +``` + +``` +agent: "build the dataset for pick_red" + → DatasetBuilderModule.build(spec_path="dataset.yaml") + ← BuildProgress events stream back to the chat + ← BuildDone(success=True, dataset_path="datasets/pick_red/") + +agent: "train ACT on it for 100k steps" + → TrainerModule.train( + spec_path="dataset.yaml", + output_dir="runs/act_pick_red", + config_kind="bc", + config_overrides={"steps": 100_000}, + ) + ← TrainProgress events stream loss/step + ← TrainDone(success=True, checkpoint_dir="runs/act_pick_red/...") + +agent: "deploy it on the xarm" + → launches `dimos --blueprint learning_infer_pick_red --policy-path ...` +``` + +The fact that the same Modules drive both the "everything-auto" CLI flow +(§5.1) and the "agent-driven" flow (§5.3) is the v1 architectural payoff. + +--- + +## 6. 
Verification (what we test before declaring v1 done) + +1. **Recording** — teleop blueprint with `--record-path` produces a session.db whose stream listing matches the spec. +2. **Build** — `python -m dimos.learning.dataprep build dataset.yaml` against a real session, then `lerobot.LeRobotDataset(path)` opens it without error and `len(ds) > 0`. +3. **Stats** — `meta/stats.json` exists and has finite, non-degenerate values for every observation/action key. +4. **Train** — `train_bc` runs ≥1k steps end-to-end on a real session; loss decreases; checkpoint loads back via `LeRobotPolicy.load`. +5. **VLA finetune** — `finetune_vla` runs ≥500 steps with LoRA on top of a downloaded π₀.₅ checkpoint; no OOM at batch=4 on a 24 GB GPU; loss decreases. +6. **Live obs parity** — `ObsBuilder.build(...)` on a fake live stream and `iter_samples(...)` on the same data give bit-identical observation dicts. +7. **Inference (sim)** — `ChunkPolicyModule` + `ActionReplayer` + `ControlCoordinator` with MuJoCo xArm7 produces non-NaN joint commands at 100 Hz, replays a 50-step chunk smoothly, recovers when policy module stalls. +8. **Inference (hw)** — same blueprint on real xArm7 produces a successful pick-and-place at ≥30% success rate after 50 demos. (Success rate is informational; the test is "no crashes, no jerks, no diverging commands.") + +--- + +## 7. 
Key Design Decisions (v1) + +| Decision | Rationale | +|----------|-----------| +| LeRobot v2 is the canonical on-disk format | Both ACT and π₀/π₀.₅ train from it natively; no custom dataloaders | +| `lerobot` library does the heavy lifting | We don't reimplement ACT, π₀, dataloader, normalization, or training loop | +| One inference module (`ChunkPolicyModule`), not three | ACT and VLA both produce chunks; only the model class differs | +| `ActionReplayer` lives in the tick loop, model lives in a Module | Decouple slow inference (1–5 Hz VLA) from fast control (100 Hz) | +| `ObsBuilder` reused between train and infer | Single source of truth for observation construction — eliminates train/serve skew | +| Episode metadata carries `task_description` | π₀/π₀.₅ are language-conditioned; `task_label` alone is too narrow | +| Stats computed at build time, written to disk | Trainers and inference both read from `meta/stats.json` — no recompute | +| Coordinator change is one new task type (`action_replayer`) | Minimal, additive | +| RLDS, ONNX, RL, proprio-only policy task → v2 | Deliberately not in scope | + +--- + +## 8. Risks & Open Questions for v1 + +- **`lerobot` API stability.** We're pinning to a specific minor version. If their training entry point changes, our wrapper breaks. Mitigation: pin tightly in `pyproject.toml`, add an integration test that exercises the wrapper. +- **π₀.₅ checkpoint availability.** Depends on the public release. Fallback: ship v1 with π₀ only, add π₀.₅ when it lands (config bump). +- **Action space match.** π₀ assumes a 7-DoF EEF action by default; xArm7 joint-position control is 7-DoF joint. Need to either (a) keep π₀'s action head and use joint targets in its expected layout, or (b) retrain the action head. v1 chooses (a) via the spec's action key naming. +- **Real-time perf of chunk replay.** First action of a chunk is conditioned on `t = chunk_arrival_time`, but it executes some ms later. 
With 50-step chunks at 30 Hz this is ~1.6 s of buffer; if the policy module stalls, the replayer drifts. Mitigation: replayer rejects stale chunks (`now − chunk.ts > policy_dt × 1.5`) and re-requests. +- **Camera calibration in the spec.** ACT/π₀ are sensitive to camera placement. The spec doesn't currently encode camera intrinsics/extrinsics. v1 punt: rely on the operator to record from the same physical setup at infer time. Add `metadata.cameras` schema in v1.5 if it bites. + +--- + +--- + +# v2 Considerations (deferred) + +These are explicitly out of v1 scope. Pulled here so we don't lose track. Anything from this list that becomes cheap during v1 implementation gets promoted up. + +### Training +- **RL** — `train_rl(env_cfg, model_cfg)` for online (PPO/SAC) and offline (CQL/IQL/AWAC) RL. Needs an env wrapper around DimOS (sim primarily — MuJoCo via the existing `MujocoCamera` work). +- **Multi-task / multi-embodiment training** — train one policy on demos from xArm7 + Piper + Mock. Needs URDF retargeting, embodiment-id conditioning. +- **Distributed training** — `accelerate` / FSDP for VLA full-finetune on multi-GPU. +- **Curriculum / dataset weighting** — sample harder episodes more often, weight by reward, etc. +- **Diffusion Policy** as a first-class BC option (lerobot supports it; just a config). +- **Active data collection** — uncertainty-based suggestion of which demos to collect next. + +### Inference +- **`PolicyControlTask`** — joint-only proprioceptive policies in the tick loop (100 Hz, no Module overhead). Useful for residual policies, locomotion. +- **`OnnxPolicy` / `TorchScriptPolicy`** — alternate Policy backends for deployment without `lerobot` runtime dep. +- **Cross-embodiment retargeting at inference** — train on xArm7 demos, deploy on Piper. +- **Temporal ensembling on by default** — currently nice-to-have in v1, make it the default after measuring its effect on jerk. 
+- **Async chunk pipelining** — request chunk N+1 while replaying chunk N to hide policy latency completely. +- **Real-time safety layer** — collision check + joint-limit clamp downstream of `ActionReplayer`. + +### Data Collection +- **Live `EpisodeManagerModule`** — annotate episode boundaries at record time instead of post-hoc. Useful when the operator wants to pause/resume and the recording length is huge. Currently overkill. +- **RLDS / TFDS writer** — needed for OpenX-Embodiment contributions and RT-X-style training. Skeleton already exists. +- **Interactive `review` TUI** — scrub through episodes, watch the camera stream, retag/discard. v1 ships a non-interactive batch retag CLI only. +- **Camera intrinsics/extrinsics in spec metadata** — required for any cross-embodiment or sim-real work. +- **Force/torque streams as observation** — schema already supports it; need preprocess hooks for FT data. +- **Imitation-from-observation** — episodes without action streams (just video). Needs an inverse dynamics model. +- **Custom transports for streaming** — record directly to a remote SqliteStore over network, bypass local disk. Probably never needed. + +### Training Module / DimOS-native +- **`TrainingModule` agent skill** — expose `train_bc` / `finetune_vla` as RPCs so the LLM agent can trigger training from a chat session. Requires sandboxing (long-running subprocess management, GPU resource gating). +- **`PolicyRegistry` Module** — track all trained policies, their specs, eval results. A "model zoo" served as a Module for blueprint composition. + +### Tooling +- **W&B / TensorBoard integration** standardized across all trainers. +- **Eval harness** — replay validation episodes through the trained policy in sim, report success rate. Needed for any kind of automated training pipeline. +- **HuggingFace Hub upload** — `dimos.learning.training.publish(checkpoint, repo_id)` so trained policies are sharable. 
+ +### Possibly to promote into v1 if cheap +- **Stats subsampling for image features** — needed in v1 anyway; making it configurable is a 5-line addition. +- **Episode-level train/val split** — already needed in v1; might as well expose `--split-by hash(episode_id)` as a third strategy. +- **Diffusion Policy via the same `train_bc` entry point** — `lerobot` already has it, the only additional work is one more policy_type literal in `BCConfig`. Probably ship it. +- **`OnnxPolicy`** — only worth promoting if a deployment target without `lerobot` install exists. None in v1; defer. diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py new file mode 100644 index 0000000000..3a4233c045 --- /dev/null +++ b/dimos/learning/collection/blueprint.py @@ -0,0 +1,117 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Collection blueprints for the DimOS Learning Framework. + +Each blueprint composes a teleop session + a camera + the +EpisodeMonitorModule (for live operator feedback). RecordReplay is NOT a +Module — it intercepts at the transport layer and is enabled via the CLI +flag `--record-path session.db`. 
+ +Usage: + dimos run learning-collect-quest-xarm7 --record-path data/pick_red.db +""" + +from __future__ import annotations + +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.transport import LCMTransport +from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera +from dimos.learning.collection.episode_monitor import ( + EpisodeMonitorModule, + EpisodeStatus, +) +from dimos.msgs.sensor_msgs.Image import Image +from dimos.teleop.quest.blueprints import ( + teleop_quest_dual, + teleop_quest_piper, + teleop_quest_xarm6, + teleop_quest_xarm7, +) +from dimos.teleop.quest.quest_types import Buttons + +# ── XArm7 + Quest ──────────────────────────────────────────────────────────── + +learning_collect_quest_xarm7 = autoconnect( + teleop_quest_xarm7, + RealSenseCamera.blueprint(enable_pointcloud=False), + EpisodeMonitorModule.blueprint(), +).transports( + { + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport( + "/learning/episode_status", EpisodeStatus + ), + } +) + + +# ── Piper + Quest ──────────────────────────────────────────────────────────── + +learning_collect_quest_piper = autoconnect( + teleop_quest_piper, + RealSenseCamera.blueprint(enable_pointcloud=False), + EpisodeMonitorModule.blueprint(), +).transports( + { + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport( + "/learning/episode_status", EpisodeStatus + ), + } +) + + +# ── XArm6 + Quest ──────────────────────────────────────────────────────────── + +learning_collect_quest_xarm6 = autoconnect( + teleop_quest_xarm6, + RealSenseCamera.blueprint(enable_pointcloud=False), + EpisodeMonitorModule.blueprint(), +).transports( + { + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): 
LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport( + "/learning/episode_status", EpisodeStatus + ), + } +) + + +# ── Dual arm (XArm6 + Piper) + Quest ───────────────────────────────────────── + +learning_collect_quest_dual = autoconnect( + teleop_quest_dual, + RealSenseCamera.blueprint(enable_pointcloud=False), + EpisodeMonitorModule.blueprint(), +).transports( + { + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport( + "/learning/episode_status", EpisodeStatus + ), + } +) + + +__all__ = [ + "learning_collect_quest_dual", + "learning_collect_quest_piper", + "learning_collect_quest_xarm6", + "learning_collect_quest_xarm7", +] diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py new file mode 100644 index 0000000000..daa79f0251 --- /dev/null +++ b/dimos/learning/collection/episode_monitor.py @@ -0,0 +1,116 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Live episode-status feedback during teleop recording. + +Watches the buttons stream and runs the same start/save/discard state +machine that `DataPrep.extract_episodes` runs offline — but here it runs +live so the operator can see counters update in real time. Pure observability: +this module does NOT write anything. 
The recording itself is RecordReplay's +job; episode boundary extraction still happens post-hoc inside DataPrep. + +Why a separate live state-machine instead of just consuming DataPrep's offline +output? Because the operator wants feedback *during* the session ("episodes +saved: 12") to know when to stop, retry a bad demo, etc. + +Agent surface: `get_status()` returns the latest counters; `reset_counters()` +zeroes them between recording sessions without restarting the blueprint. +""" + +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.teleop.quest.quest_types import Buttons + + +class EpisodeStatus(BaseModel): + """Live counters published every state transition.""" + + state: Literal["idle", "recording"] + episodes_saved: int + episodes_discarded: int + current_episode_start_ts: float | None # None when state == "idle" + last_event: Literal["start", "save", "discard", "init"] = "init" + + +class EpisodeMonitorModuleConfig(ModuleConfig): + """Match the same fields used by `EpisodeConfig` in the dataset spec + so the live monitor and the offline extractor agree on what each button + means. Friendly names ("A", "B", "X") resolve via BUTTON_ALIASES. + """ + + button_stream: str = "buttons" + start: str = "A" + save: str = "B" + discard: str = "X" + + +class EpisodeMonitorModule(Module): + """Live operator feedback for teleop recording sessions.""" + + config: EpisodeMonitorModuleConfig + + buttons: In[Buttons] + + status: Out[EpisodeStatus] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._state: Literal["idle", "recording"] = "idle" + self._saved: int = 0 + self._discarded: int = 0 + self._current_start_ts: float | None = None + # Previous bit-state of each watched button, for rising-edge detection. 
+ self._prev_bits: dict[str, bool] = {} + + @rpc + def start(self) -> None: + """Subscribe to `buttons` and emit an initial idle status.""" + raise NotImplementedError + + @rpc + def stop(self) -> None: + """Unsubscribe and call super().stop().""" + raise NotImplementedError + + @rpc + def reset_counters(self) -> EpisodeStatus: + """Zero the saved/discarded counters and force state back to idle. + Returns the new status.""" + raise NotImplementedError + + @rpc + def get_status(self) -> EpisodeStatus: + """Return the current EpisodeStatus snapshot.""" + raise NotImplementedError + + # ── internals ──────────────────────────────────────────────────────────── + + def _on_buttons(self, msg: Buttons) -> None: + """Detect rising edges on start/save/discard buttons; advance state + machine; publish EpisodeStatus on every transition. + + State machine — must mirror DataPrep.extract_episodes in BUTTONS mode: + IDLE --start press--> RECORDING (begin) + RECORDING --save press---> IDLE (saved += 1) + RECORDING --discard -----> IDLE (discarded += 1) + RECORDING --start press--> RECORDING (auto-commit prev, begin new) + """ + raise NotImplementedError diff --git a/dimos/learning/dataprep.py b/dimos/learning/dataprep.py index 53b1e7c411..c531302f67 100644 --- a/dimos/learning/dataprep.py +++ b/dimos/learning/dataprep.py @@ -14,11 +14,15 @@ """Dataset builder/loader for the DimOS Learning Framework. -Reads a `DatasetSpec` (see `dimos.learning.spec`) and either: +`DataPrep` is the single user-facing entry point. It reads a `DatasetSpec` +(see `dimos.learning.spec`) and either: - builds a training-ready dataset on disk in HDF5/RLDS/LeRobot, or - returns a PyTorch Dataset for training. -The same spec also drives inference observation construction. 
+Stateless helpers (episode extraction, sample iteration, field resolution) +live as `@staticmethod`s on `DataPrep` so they share one namespace and are +callable without an instance — the live `ObsBuilder` at inference time +reuses `DataPrep.resolve_field` for that reason. Workflow: # 1. Record a teleop session (Sam's PR #1708) @@ -28,9 +32,9 @@ python -m dimos.learning.dataprep build dataset.yaml # 3. Train using the same spec - from dimos.learning.dataprep import load_dataset, load_spec - spec = load_spec("dataset.yaml") - ds = load_dataset(spec) + from dimos.learning.dataprep import DataPrep + dp = DataPrep.from_file("dataset.yaml") + ds = dp.load() """ from __future__ import annotations @@ -45,10 +49,10 @@ DatasetSpec, Episode, EpisodeConfig, - FieldRef, FilterConfig, OutputConfig, Sample, + StreamField, ) Writer = Callable[[Iterator[Sample], OutputConfig], Path] @@ -60,125 +64,172 @@ # ───────────────────────────────────────────────────────────────────────────── -# Spec I/O +# DataPrep — the only thing this module exports besides `main()` # ───────────────────────────────────────────────────────────────────────────── -def load_spec(path: str | Path) -> DatasetSpec: - """Load a DatasetSpec from .yaml/.yml/.json (dispatch by extension).""" - raise NotImplementedError - - -def save_spec(spec: DatasetSpec, path: str | Path) -> None: - """Write a DatasetSpec back to .yaml/.yml/.json (round-trip safe).""" - raise NotImplementedError +class DataPrep: + """Build / load / inspect a dataset from a `DatasetSpec`. + Holds the open `SqliteStore` and the cached, filtered episode list so + repeated operations on the same spec (e.g. `inspect()` then `build()`) + don't redo work. Construction is cheap — the store and episodes are + computed lazily on first access. 
-# ───────────────────────────────────────────────────────────────────────────── -# Episode extraction -# ───────────────────────────────────────────────────────────────────────────── - - -def extract_episodes(store: SqliteStore, cfg: EpisodeConfig) -> list[Episode]: - """Extract episode boundaries from the recording per the configured strategy. - - BUTTONS: scan cfg.button_stream for rising edges on cfg.start/save/discard. - State machine: - IDLE --start press--> RECORDING (begin episode) - RECORDING --save press--> IDLE (commit, success=True) - RECORDING --discard press--> IDLE (drop) - RECORDING --start press--> RECORDING (auto-commit, begin new) - session ends mid-episode: always discard - - RANGES: emit one Episode per (start_ts, end_ts) tuple in cfg.ranges. - - WHOLE: emit a single Episode covering the entire recording's time range. + Not a DimOS Module: no ports, no runtime lifecycle. It's a stateful + façade over the static helpers below. """ - raise NotImplementedError - - -def filter_episodes(eps: list[Episode], cfg: FilterConfig | None) -> list[Episode]: - """Apply success/duration/label whitelist filters. None = pass-through.""" - raise NotImplementedError - - -# ───────────────────────────────────────────────────────────────────────────── -# Stream synchronization (build per-timestep samples) -# ───────────────────────────────────────────────────────────────────────────── - - -def iter_samples( - store: SqliteStore, - episode: Episode, - spec: DatasetSpec, -) -> Iterator[Sample]: - """Yield synced (obs, action) Samples for one episode. - - Walks the anchor stream at sync.rate_hz between episode.start_ts and - episode.end_ts. For each anchor timestamp, pulls the nearest observation/ - action from each configured stream within sync.tolerance_ms. Applies any - declared preprocess (e.g. jpeg_decode for Image, field projection for - JointState). Skips frames where any required stream lacks a sample within - tolerance. 
- """ - raise NotImplementedError - -def _resolve_field(msg: Any, ref: FieldRef) -> np.ndarray: - """Pull a single field from a stream message and convert to np.ndarray. - - Applies ref.field projection (attribute access) and ref.preprocess hook - (named transform like jpeg_decode). Returns a numpy array suitable for - inclusion in a Sample. - """ - raise NotImplementedError - - -# ───────────────────────────────────────────────────────────────────────────── -# Public API -# ───────────────────────────────────────────────────────────────────────────── - - -def _get_writer(format_name: str) -> Writer: - """Lazy-import the `write` function for a given format. Avoids loading - heavy deps (h5py, tfds, lerobot) for unused formats.""" - if format_name == "lerobot": - from dimos.learning.formats.lerobot import write - elif format_name == "hdf5": - from dimos.learning.formats.hdf5 import write - elif format_name == "rlds": - from dimos.learning.formats.rlds import write - else: - raise ValueError( - f"Unknown dataset format: {format_name!r}. Supported: lerobot, hdf5, rlds." - ) - return write - - -def build_dataset(spec: DatasetSpec) -> Path: - """End-to-end: raw session.db -> on-disk dataset in spec.output.format. - - Returns the path written. Requires spec.output to be set. Dispatches to - the appropriate writer in `dimos.learning.formats` via `_get_writer`. - """ - raise NotImplementedError - - -def load_dataset(spec: DatasetSpec) -> torch.utils.data.Dataset[Sample]: - """Training-time loader: returns a PyTorch Dataset over the source recording. - - Materializes Samples on the fly (lazy). Does not require spec.output. - Pre-extracts episodes once and indexes anchor timestamps for O(1) __getitem__. - """ - raise NotImplementedError - - -def inspect(spec: DatasetSpec) -> dict[str, Any]: - """Stats for a session: episode count, duration distribution, per-stream counts. - - Used by `python -m dimos.learning.dataset inspect`. 
- """ - raise NotImplementedError + # ── construction ───────────────────────────────────────────────────────── + + def __init__(self, spec: DatasetSpec) -> None: + """Bind to a spec. Does not open the store or extract episodes yet.""" + raise NotImplementedError + + @classmethod + def from_file(cls, path: str | Path) -> DataPrep: + """Convenience: `DataPrep.from_file("dataset.yaml")`.""" + raise NotImplementedError + + # ── lazy-cached state ──────────────────────────────────────────────────── + + @property + def store(self) -> SqliteStore: + """Open the recording's SqliteStore on first access; cached thereafter.""" + raise NotImplementedError + + @property + def episodes(self) -> list[Episode]: + """Extract + filter episodes on first access; cached thereafter. + + Equivalent to: + DataPrep.filter_episodes( + DataPrep.extract_episodes(store, spec.episodes), + spec.filters, + ) + """ + raise NotImplementedError + + # ── operations ─────────────────────────────────────────────────────────── + + def iter_samples(self) -> Iterator[Sample]: + """Yield synced Samples across every episode, in episode order.""" + raise NotImplementedError + + def build(self) -> Path: + """End-to-end: source session.db -> on-disk dataset in spec.output.format. + + Returns the path written. Requires `spec.output` to be set. Dispatches + to the appropriate writer in `dimos.learning.formats` via `_get_writer`. + """ + raise NotImplementedError + + def load(self) -> torch.utils.data.Dataset[Sample]: + """Training-time loader: returns a PyTorch Dataset over the source recording. + + Materializes Samples on-the-fly (lazy). Does not require `spec.output`. + Pre-extracts episodes once and indexes anchor timestamps for O(1) + `__getitem__`. + """ + raise NotImplementedError + + def inspect(self) -> dict[str, Any]: + """Stats for a session: episode count, duration distribution, + per-stream sample counts. Used by `python -m dimos.learning.dataprep inspect`. 
+ """ + raise NotImplementedError + + def close(self) -> None: + """Close the underlying SqliteStore. Safe to call multiple times.""" + raise NotImplementedError + + def __enter__(self) -> DataPrep: + return self + + def __exit__(self, *exc: object) -> None: + self.close() + + # ── stateless helpers ──────────────────────────────────────────────────── + # + # Static so they're callable without an instance. `resolve_field` in + # particular is reused by `dimos.learning.inference.obs_builder` to build + # live observations, so train and infer share exactly one code path for + # field projection + preprocess. + + @staticmethod + def extract_episodes(store: SqliteStore, cfg: EpisodeConfig) -> list[Episode]: + """Extract episode boundaries per the configured strategy. + + BUTTONS: scan cfg.button_stream for rising edges on cfg.start/save/discard. + State machine: + IDLE --start press--> RECORDING (begin episode) + RECORDING --save press--> IDLE (commit, success=True) + RECORDING --discard press--> IDLE (drop) + RECORDING --start press--> RECORDING (auto-commit, begin new) + session ends mid-episode: always discard + + RANGES: emit one Episode per (start_ts, end_ts) tuple in cfg.ranges. + + WHOLE: emit a single Episode covering the entire recording's time range. + """ + raise NotImplementedError + + @staticmethod + def filter_episodes(eps: list[Episode], cfg: FilterConfig | None) -> list[Episode]: + """Apply success / duration / label whitelist filters. `None` = pass-through. + + Note: train/val split fields on FilterConfig (`val_episode_ids`, + `val_ratio`) are *not* applied here — they're consumed by the trainer, + which needs the full episode list to materialize both splits. + """ + raise NotImplementedError + + @staticmethod + def iter_episode_samples( + store: SqliteStore, + episode: Episode, + spec: DatasetSpec, + ) -> Iterator[Sample]: + """Yield synced (obs, action) Samples for a single episode. 
+ + Walks the anchor stream at sync.rate_hz between episode.start_ts and + episode.end_ts. For each anchor timestamp, pulls the nearest observation/ + action from each configured stream within sync.tolerance_ms. Applies any + declared preprocess (e.g. jpeg_decode for Image, field projection for + JointState). Skips frames where any required stream lacks a sample + within tolerance. + """ + raise NotImplementedError + + @staticmethod + def resolve_field(msg: Any, ref: StreamField) -> np.ndarray: + """Pull a single field from a stream message and convert to np.ndarray. + + Applies ref.field projection (attribute access) and ref.preprocess hook + (named transform like jpeg_decode). Returns a numpy array suitable for + inclusion in a Sample. + + Reused by the live ObsBuilder at inference time — single source of + truth for observation construction across train and infer. + """ + raise NotImplementedError + + @staticmethod + def _get_writer(format_name: str) -> Writer: + """Lazy-import the `write` function for a given format. Avoids loading + heavy deps (h5py, tfds, lerobot) for unused formats. + """ + if format_name == "lerobot": + from dimos.learning.formats.lerobot import write + elif format_name == "hdf5": + from dimos.learning.formats.hdf5 import write + elif format_name == "rlds": + from dimos.learning.formats.rlds import write + else: + raise ValueError( + f"Unknown dataset format: {format_name!r}. Supported: lerobot, hdf5, rlds." 
+ ) + return write # ───────────────────────────────────────────────────────────────────────────── @@ -187,7 +238,7 @@ def inspect(spec: DatasetSpec) -> dict[str, Any]: def main() -> None: - """CLI entrypoint: build / inspect a dataset spec.""" + """CLI entrypoint: `build` / `inspect` / `review` a dataset spec.""" raise NotImplementedError diff --git a/dimos/learning/dataset.example.yaml b/dimos/learning/dataset.example.yaml index df01cb03f7..83a774fd68 100644 --- a/dimos/learning/dataset.example.yaml +++ b/dimos/learning/dataset.example.yaml @@ -1,11 +1,11 @@ # DimOS Learning — DatasetSpec template # # This file is the contract between data collection and training. Same spec is -# used by `python -m dimos.learning.dataset build` (export to disk) and by +# used by `python -m dimos.learning.dataprep build` (export to disk) and by # `load_dataset(spec)` (training-time PyTorch Dataset). # # Stream names below must match the topic names recorded by RecordReplay. To -# discover them: `python -m dimos.learning.dataset inspect dataset.yaml` +# discover them: `python -m dimos.learning.dataprep inspect dataset.yaml` # (or look in the SQLite registry of session.db). # ─── Source recording ──────────────────────────────────────────────────────── @@ -28,7 +28,9 @@ episodes: # - [1730000000.0, 1730000045.5] # - [1730000060.0, 1730000110.2] - default_task_label: pick_red_cube # optional; applied to every extracted episode + default_task_label: pick_red_cube # optional; short categorical tag + # Free-form natural-language task string used as language conditioning for VLAs (π₀, π₀.₅). + default_task_description: "pick up the red cube and place it on the blue plate" # ─── What goes into each timestep ──────────────────────────────────────────── # Each entry: dataset_key -> { stream, type?, field?, preprocess? } @@ -71,6 +73,11 @@ filters: # max_duration_s: 60.0 # task_labels: [pick_red_cube, pick_blue_cube] + # Train/val split — both optional. 
val_episode_ids takes precedence over val_ratio. + # val_episode_ids: [0, 7, 13] + # val_ratio: 0.1 + # val_split_seed: 0 + # ─── Output (only required when calling build_dataset / `... build`) ───────── output: format: lerobot # lerobot | hdf5 | rlds diff --git a/dimos/learning/inference/action_replayer.py b/dimos/learning/inference/action_replayer.py new file mode 100644 index 0000000000..1d7e07ca59 --- /dev/null +++ b/dimos/learning/inference/action_replayer.py @@ -0,0 +1,134 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Replay policy-emitted ActionChunks at the coordinator's tick rate. + +`ChunkPolicyModule` runs slow (1–30 Hz) and emits sequences of future actions +(ActionChunks). The coordinator runs at 100 Hz. This task bridges them: +subscribe to the chunk topic, maintain a small buffer of pending (target_ts, +positions) entries, and on each tick interpolate to the current time. + +Lives in the tick loop because hardware writes happen there. Designed so a +slow / stalled policy doesn't crash the controller — see "fault behavior" +below. +""" + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass + +from dimos.control.task import ( + BaseControlTask, + CoordinatorState, + JointCommandOutput, + ResourceClaim, +) +from dimos.learning.policy.base import ActionChunk + + +@dataclass +class ActionReplayerConfig: + """Configuration for ActionReplayer. 
+ + Attributes: + joint_names: joints this task commands. Must match the policy's + `joint_names` (caller is responsible — typically wired from the + checkpoint's `dimos_meta.json`). + chunk_topic: name of the topic ChunkPolicyModule publishes on. + ActionReplayer subscribes via the coordinator's transport. + priority: tick-loop arbitration priority. + max_chunk_age_s: drop any incoming chunk whose `ts` is more than this + many seconds old at receive time. Guards against stalls. + hold_on_stall: if the buffer empties (policy fell behind / died), + hold the last commanded position instead of returning None + (which would let lower-priority tasks take over). + temporal_ensemble: when overlapping chunks arrive, exponentially + weight predictions for the same target time (ACT trick). + Off by default; v1 nice-to-have. + """ + + joint_names: list[str] + chunk_topic: str = "action_chunk" + priority: int = 10 + max_chunk_age_s: float = 0.5 + hold_on_stall: bool = True + temporal_ensemble: bool = False + + +class ActionReplayer(BaseControlTask): + """ControlTask that replays policy chunks into joint commands at tick rate. + + Behavior: + - On each new chunk, drop any buffered targets at or after the new + chunk's first target_ts (latest chunk wins). + - On each tick, interpolate (or look up nearest) target for `state.now`. + - If `state.now` is past the buffer end: + - hold last position if `hold_on_stall=True` + - else go inactive (return None) + - Stale chunks (`now - chunk.ts > max_chunk_age_s`) are dropped. + + Fault behavior: + - Policy dies / module crashes: buffer drains, behavior degrades to + "hold last position" (or inactive). Hardware never sees zero or NaN. + """ + + def __init__(self, name: str, config: ActionReplayerConfig) -> None: + """Initialize. Subscription to `chunk_topic` is set up by the coordinator + when the task is registered (we expose `on_action_chunk` for it to call). 
+ """ + raise NotImplementedError + + # ── ControlTask interface ──────────────────────────────────────────────── + + @property + def name(self) -> str: + raise NotImplementedError + + def claim(self) -> ResourceClaim: + """Claim `config.joint_names` at `config.priority`.""" + raise NotImplementedError + + def is_active(self) -> bool: + """Active iff the buffer has a non-stale target for `now` (or + `hold_on_stall` is true and we've ever received a chunk).""" + raise NotImplementedError + + def compute(self, state: CoordinatorState) -> JointCommandOutput | None: + """Return interpolated joint targets for `state.now`. + + Pure lookup over the buffered chunk; no model inference happens here. + Must complete in well under 10 ms to not jeopardize the 100 Hz loop. + """ + raise NotImplementedError + + # ── chunk handling ─────────────────────────────────────────────────────── + + def on_action_chunk(self, msg: ActionChunk) -> None: + """Push a new chunk's actions into the buffer. + + Steps: + 1. If `time_now - msg.ts > max_chunk_age_s`: drop and log. + 2. Compute target_ts for each action: `msg.ts + i * msg.dt`. + 3. Drop any buffered entries with target_ts >= msg.ts + msg.dt. + 4. Append the new (target_ts, positions) pairs in order. + """ + raise NotImplementedError + + # ── internals ──────────────────────────────────────────────────────────── + + def _interpolate(self, t: float) -> JointCommandOutput | None: + """Look up or linearly interpolate the buffer at time `t`. Returns + None if `t` is outside the buffered range and `hold_on_stall=False`.""" + raise NotImplementedError diff --git a/dimos/learning/inference/blueprint.py b/dimos/learning/inference/blueprint.py new file mode 100644 index 0000000000..507124ec75 --- /dev/null +++ b/dimos/learning/inference/blueprint.py @@ -0,0 +1,136 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference blueprints for the DimOS Learning Framework. + +Each blueprint composes: + Camera (publishes color_image) + ChunkPolicyModule (consumes obs, publishes ActionChunk at policy rate) + ControlCoordinator with ActionReplayer task (replays chunks at 100 Hz) + Hardware (consumes joint_command) + +The same blueprint serves both ACT (vision) and pi0/pi0.5 (vision + +language) — `ChunkPolicyModule` auto-detects from the loaded checkpoint +whether the policy expects language. For VLA, an LLM agent skill or a +language-source Module publishes to the `language_text` topic. + +Note: ActionReplayer is a ControlTask, not a Module. It runs inside the +ControlCoordinator and is registered via `task_type="action_replayer"` in +the coordinator's task config. The coordinator variants below currently +reference the existing teleop-IK coordinator blueprints; v1 implementation +adds learning-specific coordinator blueprints under +`dimos/control/blueprints/learning.py` that swap teleop_ik for action_replayer +(see plan §7 critical files). 
+ +Usage: + dimos run learning-infer-xarm7 \\ + --ChunkPolicyModule.config.spec_path dataset.yaml \\ + --ChunkPolicyModule.config.policy_path runs/act_pick_red \\ + --ChunkPolicyModule.config.inference_rate_hz 30 +""" + +from __future__ import annotations + +from dimos.control.blueprints.teleop import ( + coordinator_teleop_piper, + coordinator_teleop_xarm6, + coordinator_teleop_xarm7, +) +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.transport import LCMTransport +from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera +from dimos.learning.inference.chunk_policy_module import ChunkPolicyModule +from dimos.learning.policy.base import ActionChunk +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.JointState import JointState + +# Topics shared across variants. +_T_COLOR_IMAGE = "/camera/color_image" +_T_JOINT_STATE = "/coordinator/joint_state" +_T_LANGUAGE = "/learning/language_text" +_T_ACTION_CHUNK = "/learning/action_chunk" + + +# ── XArm7 (ACT-rate, 30 Hz) ────────────────────────────────────────────────── + +learning_infer_xarm7 = autoconnect( + RealSenseCamera.blueprint(enable_pointcloud=False), + ChunkPolicyModule.blueprint(inference_rate_hz=30.0), + coordinator_teleop_xarm7, # TODO: replace with coordinator_action_replayer_xarm7 +).transports( + { + ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), + ("joint_state", JointState): LCMTransport(_T_JOINT_STATE, JointState), + ("language_text", str): LCMTransport(_T_LANGUAGE, str), + ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), + } +) + + +# ── Piper (ACT-rate) ───────────────────────────────────────────────────────── + +learning_infer_piper = autoconnect( + RealSenseCamera.blueprint(enable_pointcloud=False), + ChunkPolicyModule.blueprint(inference_rate_hz=30.0), + coordinator_teleop_piper, +).transports( + { + ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), + ("joint_state", 
JointState): LCMTransport(_T_JOINT_STATE, JointState), + ("language_text", str): LCMTransport(_T_LANGUAGE, str), + ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), + } +) + + +# ── XArm6 (ACT-rate) ───────────────────────────────────────────────────────── + +learning_infer_xarm6 = autoconnect( + RealSenseCamera.blueprint(enable_pointcloud=False), + ChunkPolicyModule.blueprint(inference_rate_hz=30.0), + coordinator_teleop_xarm6, +).transports( + { + ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), + ("joint_state", JointState): LCMTransport(_T_JOINT_STATE, JointState), + ("language_text", str): LCMTransport(_T_LANGUAGE, str), + ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), + } +) + + +# ── XArm7 (VLA-rate, 5 Hz) ─────────────────────────────────────────────────── +# Same wiring; only the policy thread rate differs. pi0/pi0.5 are slow +# enough that running them at 30 Hz wastes GPU. + +learning_infer_vla_xarm7 = autoconnect( + RealSenseCamera.blueprint(enable_pointcloud=False), + ChunkPolicyModule.blueprint(inference_rate_hz=5.0), + coordinator_teleop_xarm7, +).transports( + { + ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), + ("joint_state", JointState): LCMTransport(_T_JOINT_STATE, JointState), + ("language_text", str): LCMTransport(_T_LANGUAGE, str), + ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), + } +) + + +__all__ = [ + "learning_infer_piper", + "learning_infer_vla_xarm7", + "learning_infer_xarm6", + "learning_infer_xarm7", +] diff --git a/dimos/learning/inference/chunk_policy_module.py b/dimos/learning/inference/chunk_policy_module.py new file mode 100644 index 0000000000..734f77ac19 --- /dev/null +++ b/dimos/learning/inference/chunk_policy_module.py @@ -0,0 +1,156 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Vision/VLA policy as a DimOS Module — produces action chunks at policy rate. + +One module covers both v1 inference targets: + - ACT (10–30 Hz, vision + joint state) + - pi0/pi0.5 (1–5 Hz, vision + joint state + language) + +`ChunkPolicyModule` runs the policy in a background thread at `inference_rate_hz`, +publishes each output as an `ActionChunk` message, and is consumed by +`ActionReplayer` (in the coordinator's tick loop) which interpolates to 100 Hz. + +Heavy ML deps (`lerobot`, `torch`) are imported lazily via `LeRobotPolicy.load`, +not at module import time — so just having this in a blueprint doesn't pull +CUDA into every install. 
+""" + +from __future__ import annotations + +import threading +from pathlib import Path +from typing import Any + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.learning.inference.obs_builder import ObsBuilder +from dimos.learning.policy.base import ActionChunk, Policy +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.JointState import JointState + + +class ChunkPolicyModuleConfig(ModuleConfig): + """Config for ChunkPolicyModule.""" + + spec_path: str # path to dataset.yaml — supplies obs construction + policy_path: str # path to lerobot checkpoint dir + inference_rate_hz: float = 5.0 # 5 Hz default for VLA; 30 Hz for ACT + device: str = "cuda" + default_language: str = "" # used when `language_text` port has no value yet + + +class ChunkPolicyModule(Module): + """Runs a Policy at `inference_rate_hz`, publishes ActionChunks. + + Live message latching: + - `color_image` and `joint_state` are cached on every receive; the + policy thread reads the latest cached value at each tick. + - `language_text` is optional; if the policy doesn't expect language + (`policy.expects_language is False`) the port is ignored. + + The thread loop is best-effort wrt `inference_rate_hz`: if a forward pass + takes longer than the period, the next tick fires immediately; we never + queue stale work. + """ + + config: ChunkPolicyModuleConfig + + color_image: In[Image] + joint_state: In[JointState] + language_text: In[str] + + action_chunk: Out[ActionChunk] + + def __init__(self, **kwargs: Any) -> None: + """Defer all heavy init to `start()`.""" + super().__init__(**kwargs) + # Latched live messages — written by port callbacks, read by policy thread. 
+ self._latest_image: Image | None = None + self._latest_joint_state: JointState | None = None + self._latest_language: str | None = None + self._latch_lock = threading.Lock() + + # Filled in start(): + self._policy: Policy | None = None + self._obs_builder: ObsBuilder | None = None + self._chunk_id: int = 0 + + # Thread control + self._thread: threading.Thread | None = None + self._stop = threading.Event() + + # ── lifecycle ──────────────────────────────────────────────────────────── + + @rpc + def start(self) -> None: + """Load spec + policy, subscribe to ports, spawn the inference thread. + + Steps: + 1. `spec = DatasetSpec.from_file(config.spec_path)` + 2. `self._policy = LeRobotPolicy.load(config.policy_path, device=config.device)` + 3. `self._obs_builder = ObsBuilder(spec)` + 4. Subscribe color_image / joint_state / language_text -> latch handlers. + 5. Start the policy thread targeting `_run_loop`. + """ + raise NotImplementedError + + @rpc + def stop(self) -> None: + """Stop the inference thread and call `super().stop()`.""" + raise NotImplementedError + + # ── agent surface ──────────────────────────────────────────────────────── + + @rpc + def set_language(self, text: str) -> None: + """Override the language conditioning text without touching the + upstream `language_text` port. Useful when an LLM agent skill drives + VLA task switching.""" + raise NotImplementedError + + @rpc + def reload_policy(self, policy_path: str, device: str | None = None) -> None: + """Hot-swap the policy checkpoint without restarting the blueprint. 
+ Stops the inference thread, loads the new checkpoint, restarts.""" + raise NotImplementedError + + @rpc + def get_status(self) -> dict[str, Any]: + """Return {'running': bool, 'chunk_count': int, 'policy_path': str, + 'expects_language': bool, 'last_chunk_ts': float | None}.""" + raise NotImplementedError + + # ── inference loop ─────────────────────────────────────────────────────── + + def _run_loop(self) -> None: + """Background thread. Sleep to next deadline, build obs, call policy, + publish chunk. Logs and continues on any per-tick error so a single + bad observation doesn't kill inference. + """ + raise NotImplementedError + + def _build_live_obs(self) -> dict[str, Any] | None: + """Snapshot the latched messages and assemble the dict the ObsBuilder wants. + + Returns None if any required stream hasn't received a message yet + (the loop will skip this tick and try again). + """ + raise NotImplementedError + + def _next_chunk_id(self) -> int: + cid = self._chunk_id + self._chunk_id += 1 + return cid diff --git a/dimos/learning/inference/obs_builder.py b/dimos/learning/inference/obs_builder.py new file mode 100644 index 0000000000..f08db973d6 --- /dev/null +++ b/dimos/learning/inference/obs_builder.py @@ -0,0 +1,72 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Live observation construction for inference. + +`ObsBuilder` is the inference-time counterpart to `DataPrep.iter_episode_samples`. 
+At training time, samples are built by walking recorded streams; at inference +time we have the *latest* message on each live stream. The transformation +from per-stream messages to a model-ready obs dict must be identical between +the two paths or we get train/serve skew. + +To guarantee that, `ObsBuilder` reuses `DataPrep.resolve_field` for field +projection + preprocess. The only thing it adds is a mapping from the spec's +`stream:` names to live message objects supplied by the caller (the +ChunkPolicyModule, which has the actual `In[Image]` / `In[JointState]` ports). +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from dimos.learning.spec import DatasetSpec + + +class ObsBuilder: + """Builds the model-input dict from the latest live messages. + + Construction takes a `DatasetSpec`. `build()` takes a `{stream_name: msg}` + dict where keys are the recorded stream names referenced by + `spec.observation[*].stream`, and values are the live LCM messages. + + The caller (ChunkPolicyModule) is responsible for resolving its In ports + to those stream names — that's a small static mapping it sets up once. + """ + + def __init__(self, spec: DatasetSpec) -> None: + """Cache the observation StreamFields for fast lookup at tick rate.""" + raise NotImplementedError + + def build(self, live_messages: dict[str, Any]) -> dict[str, np.ndarray]: + """Project + preprocess the latest message on each obs stream. + + Args: + live_messages: stream_name -> latest message object (e.g. + {"camera_color_image": image_msg, "coordinator_joint_state": joint_state_msg}). + Every stream referenced by `spec.observation` must be present; + missing streams raise. + + Returns: + obs dict keyed by `spec.observation` keys (e.g. "cam_high", + "joint_pos"). Values are np.ndarrays whose shapes/dtypes match + what `iter_episode_samples` produced at training time.
+ """ + raise NotImplementedError + + def required_streams(self) -> set[str]: + """Stream names this builder reads from. Used by ChunkPolicyModule + to wire its In ports + assert the live_messages dict is complete.""" + raise NotImplementedError diff --git a/dimos/learning/policy/base.py b/dimos/learning/policy/base.py new file mode 100644 index 0000000000..021f95229c --- /dev/null +++ b/dimos/learning/policy/base.py @@ -0,0 +1,93 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Policy abstraction — what `ChunkPolicyModule` calls every inference tick. + +The Policy protocol decouples model format (lerobot PreTrainedPolicy in v1, +ONNX/TorchScript in v2) from the inference module. Anything that satisfies +this protocol is droppable into a blueprint. + +`ActionChunk` is the typed message published by `ChunkPolicyModule` and +consumed by `ActionReplayer`. v1 uses a pydantic model; v2 will replace it +with a generated LCM type so it can flow over the wire — the field layout +here matches what that LCM type will look like. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Protocol, runtime_checkable + +import numpy as np +from pydantic import BaseModel, ConfigDict + + +class ActionChunk(BaseModel): + """A predicted sequence of joint targets, plus the metadata to replay it. + + Fields: + ts: wall-clock time the chunk was produced (seconds). 
+ joint_names: names matching the action key ordering used at training. + positions: shape (T, N) — T future steps, N = len(joint_names). + dt: expected interval between successive actions (seconds). + Replayer uses ts + i*dt as the target time for action i. + chunk_id: monotonic id for ordering / dedup at the replayer. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + ts: float + joint_names: list[str] + positions: np.ndarray # (T, N) + dt: float + chunk_id: int + + +@runtime_checkable +class Policy(Protocol): + """What ChunkPolicyModule needs from any policy implementation.""" + + @classmethod + def load(cls, path: str | Path, device: str = "cuda") -> Policy: + """Load a checkpoint directory. `path` is a lerobot checkpoint dir in v1. + + Implementations should also load the sidecar `dimos_meta.json` and + `meta/stats.json` so `predict_chunk` can normalize/unnormalize without + the caller doing it. + """ + ... + + @property + def chunk_size(self) -> int: + """Number of actions emitted per `predict_chunk` call (T).""" + ... + + @property + def joint_names(self) -> list[str]: + """Action joint names, matching the spec's action key ordering.""" + ... + + @property + def expects_language(self) -> bool: + """True if the policy reads `obs['language_text']` (VLAs); False otherwise.""" + ... + + def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: + """Return shape (chunk_size, action_dim) — already unnormalized to joint space. + + `obs` keys must match `spec.observation`. The policy applies its own + input normalization internally (using the stats it loaded with the + checkpoint). + """ + ... diff --git a/dimos/learning/policy/lerobot_policy.py b/dimos/learning/policy/lerobot_policy.py new file mode 100644 index 0000000000..3222116b54 --- /dev/null +++ b/dimos/learning/policy/lerobot_policy.py @@ -0,0 +1,106 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LeRobot policy wrapper. + +Wraps any `lerobot.PreTrainedPolicy` (ACT, Diffusion, pi0, pi0.5) behind the +`Policy` protocol. This is the only Policy implementation in v1 — both +training entry points produce checkpoints loadable by this class. + +Heavy deps (`lerobot`, `torch`) are imported lazily inside `load()` so simply +importing this module does not require a CUDA install. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np + +from dimos.learning.policy.base import Policy + +if TYPE_CHECKING: + pass # lerobot / torch deferred to load() + + +class LeRobotPolicy: + """Adapter for lerobot's PreTrainedPolicy → DimOS Policy protocol.""" + + # Type-erased to keep this file import-light. Concrete type: + # _model: lerobot.policies.pretrained.PreTrainedPolicy + _model: Any + _stats: dict[str, Any] + _chunk_size: int + _joint_names: list[str] + _expects_language: bool + _device: str + + def __init__( + self, + model: Any, + stats: dict[str, Any], + chunk_size: int, + joint_names: list[str], + expects_language: bool, + device: str, + ) -> None: + """Direct constructor — prefer `LeRobotPolicy.load(path)` in user code.""" + raise NotImplementedError + + @classmethod + def load(cls, path: str | Path, device: str = "cuda") -> LeRobotPolicy: + """Load a lerobot checkpoint directory. 
+ + Expected layout under `path`: + config.json / model.safetensors - the lerobot checkpoint + meta/stats.json - normalization stats + dimos_meta.json - DimOS sidecar (spec + provenance) + + Auto-detects the policy class (act / diffusion / pi0 / pi0_5) from + the lerobot config and sets `expects_language` accordingly. + """ + raise NotImplementedError + + @property + def chunk_size(self) -> int: + return self._chunk_size + + @property + def joint_names(self) -> list[str]: + return self._joint_names + + @property + def expects_language(self) -> bool: + return self._expects_language + + def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: + """Run one forward pass; return (chunk_size, action_dim). + + Steps: + 1. Normalize obs via `self._stats` (image: /255 + per-channel norm; + vector: (x - mean) / std). + 2. Convert to torch tensors on `self._device`, add batch dim. + 3. Call `self._model.select_action_chunk(obs)` (or equivalent). + 4. Move back to numpy, drop batch dim. + 5. Unnormalize actions via `self._stats`. + + Matches the pipeline used inside lerobot's training loop, so live + inference sees the same numerics as training-time evaluation. + """ + raise NotImplementedError + + +# Sanity check: make the protocol relationship explicit at import time. +_: type[Policy] = LeRobotPolicy diff --git a/dimos/learning/spec.py b/dimos/learning/spec.py index 740de364e3..da5ac79028 100644 --- a/dimos/learning/spec.py +++ b/dimos/learning/spec.py @@ -33,6 +33,33 @@ # ───────────────────────────────────────────────────────────────────────────── +class DatasetSpec(BaseModel): + """Top-level spec. Same instance used at build, load, and inference time. + + A `DatasetSpec` (loaded from YAML/JSON) is the contract between data + collection (raw RecordReplay session -> on-disk dataset) and training + (loading the same spec to feed a model). The same spec also drives + inference observation construction. 
+ """ + + source: Path # path to session.db produced by RecordReplay + episodes: EpisodeConfig + observation: dict[str, StreamField] # obs key -> stream field + action: dict[str, StreamField] # action key -> stream field + sync: SyncConfig + filters: FilterConfig | None = None + output: OutputConfig | None = None # only required by DataPrep.build() + + @classmethod + def from_file(cls, path: str | Path) -> DatasetSpec: + """Load from .yaml/.yml/.json (dispatch by extension).""" + raise NotImplementedError + + def save(self, path: str | Path) -> None: + """Write to .yaml/.yml/.json (round-trip safe).""" + raise NotImplementedError + + class EpisodeConfig(BaseModel): """How to slice the continuous recording into episodes.""" @@ -49,12 +76,15 @@ class EpisodeConfig(BaseModel): # RANGES extractor: explicit absolute timestamps ranges: list[tuple[float, float]] | None = None - # Optional default label applied to every extracted episode + # Default label/description applied to every extracted episode unless overridden. + # task_description is the free-form natural-language string used as language + # conditioning for VLA policies (e.g. "pick up the red cube and place it on the blue plate"). default_task_label: str | None = None + default_task_description: str | None = None -class FieldRef(BaseModel): - """Pointer to a field in a recorded stream.""" +class StreamField(BaseModel): + """Pointer to a field in a recorded stream — one (obs|action) key's data source.""" stream: str # LCM stream / topic name as recorded in session.db type: str | None = None # optional dotted type (e.g. "sensor_msgs.Image"); for codec dispatch @@ -79,6 +109,13 @@ class FilterConfig(BaseModel): max_duration_s: float | None = None task_labels: list[str] | None = None # whitelist; None = all + # Train/val split. Episodes whose index lands in val become the validation set + # at training time; everything else is train. `val_episode_ids` takes precedence + # over `val_ratio`. 
Both None = no split (everything is train). + val_episode_ids: list[int] | None = None + val_ratio: float | None = None + val_split_seed: int = 0 + class OutputConfig(BaseModel): """Where and how to write the built dataset.""" @@ -88,18 +125,6 @@ class OutputConfig(BaseModel): metadata: dict[str, Any] = Field(default_factory=dict) -class DatasetSpec(BaseModel): - """Top-level spec. Same instance used at build, load, and inference time.""" - - source: Path # path to session.db produced by RecordReplay - episodes: EpisodeConfig - observation: dict[str, FieldRef] # obs key -> stream field - action: dict[str, FieldRef] # action key -> stream field - sync: SyncConfig - filters: FilterConfig | None = None - output: OutputConfig | None = None # only required by build_dataset() - - # ───────────────────────────────────────────────────────────────────────────── # Runtime data # ───────────────────────────────────────────────────────────────────────────── @@ -127,7 +152,8 @@ class Episode(BaseModel): id: str start_ts: float end_ts: float - task_label: str | None = None + task_label: str | None = None # short categorical tag (e.g. "pick_red_cube") + task_description: str | None = None # free-form natural-language string for VLA conditioning success: bool = True metadata: dict[str, Any] = Field(default_factory=dict) @@ -145,3 +171,9 @@ class Sample(BaseModel): episode_id: str observation: dict[str, np.ndarray] action: dict[str, np.ndarray] + + +# DatasetSpec is defined before its referenced subclasses so it reads as the +# top-of-file entry point. Resolve those forward references now that every +# referenced class exists in the module namespace. +DatasetSpec.model_rebuild() diff --git a/dimos/learning/training/blueprint.py b/dimos/learning/training/blueprint.py new file mode 100644 index 0000000000..7c7c97852b --- /dev/null +++ b/dimos/learning/training/blueprint.py @@ -0,0 +1,103 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training blueprints for the DimOS Learning Framework. + +Each blueprint composes TrainerModule + LearningMonitorModule. TrainerModule +handles both the dataset-build and the training subprocesses internally +(see its docstring) — there is no separate builder Module in v1. + +Variants: + learning_train_act - auto build (if needed) then train ACT (BC) + learning_train_vla - auto build (if needed) then finetune pi0/pi0.5 + learning_train_idle - module idle, agent drives via @rpc + +Defaults (spec_path, output_dir, ...) are placeholders; override at run +time via CLI flags or @rpc calls. Per-job overrides on the trigger payload +take precedence over module config. + +Usage: + dimos run learning-train-act \\ + --TrainerModule.config.spec_path dataset.yaml \\ + --TrainerModule.config.output_dir runs/act_pick_red +""" + +from __future__ import annotations + +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.transport import LCMTransport +from dimos.learning.training.monitor_module import LearningMonitorModule +from dimos.learning.training.trainer_module import ( + TrainDone, + TrainerModule, + TrainProgress, +) + +# Topic names — shared across all variants so monitors / agents subscribe once. 
+_T_TRAIN_PROGRESS = "/learning/train/progress" +_T_TRAIN_DONE = "/learning/train/done" + + +# ── ACT (BC) — auto build (if needed) then train ───────────────────────────── + +learning_train_act = autoconnect( + TrainerModule.blueprint(config_kind="bc", auto_run=True), + LearningMonitorModule.blueprint(log_to_rerun=True), +).transports( + { + ("progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), + ("done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), + ("train_progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), + ("train_done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), + } +) + + +# ── pi0 / pi0.5 (VLA finetune) — auto build (if needed) then train ─────────── + +learning_train_vla = autoconnect( + TrainerModule.blueprint(config_kind="vla", auto_run=True), + LearningMonitorModule.blueprint(log_to_rerun=True), +).transports( + { + ("progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), + ("done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), + ("train_progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), + ("train_done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), + } +) + + +# ── Idle — TrainerModule waits for explicit @rpc / external trigger ────────── +# Agent-driven: agent skill calls TrainerModule.train(...) (or .build_only(...)) +# over RPC. Same module, no auto behavior. 
+ +learning_train_idle = autoconnect( + TrainerModule.blueprint(auto_run=False), + LearningMonitorModule.blueprint(log_to_rerun=True), +).transports( + { + ("progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), + ("done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), + ("train_progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), + ("train_done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), + } +) + + +__all__ = [ + "learning_train_act", + "learning_train_idle", + "learning_train_vla", +] diff --git a/dimos/learning/training/configs.py b/dimos/learning/training/configs.py new file mode 100644 index 0000000000..1630d925dd --- /dev/null +++ b/dimos/learning/training/configs.py @@ -0,0 +1,94 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trainer configs for v1. + +Two pydantic configs, one per training entry point: + - BCConfig -> consumed by train_bc (ACT, optionally Diffusion) + - VLAConfig -> consumed by finetune_vla (pi0, pi0.5) + +Both are translated into a `lerobot` training config inside the trainer; +fields here are the small, opinionated subset DimOS users actually need to +tune. Anything not exposed falls back to the lerobot default. 
+""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel + + +class BCConfig(BaseModel): + """Behavior-cloning trainer config (v1: ACT, with Diffusion as a flag).""" + + policy_type: Literal["act", "diffusion"] = "act" + + # Action chunking + chunk_size: int = 50 # number of future actions predicted per inference call + n_obs_steps: int = 1 # observation history length passed to the policy + + # ACT model arch (ignored for Diffusion) + hidden_dim: int = 512 + n_layers: int = 4 + n_heads: int = 8 + use_vae: bool = True + kl_weight: float = 10.0 + + # Vision backbone + vision_backbone: str = "resnet18" + pretrained: bool = True + + # Optim + steps: int = 100_000 + batch_size: int = 8 + lr: float = 1e-5 + lr_backbone: float = 1e-5 + weight_decay: float = 1e-4 + + # Eval / checkpointing + save_every: int = 10_000 + eval_every: int = 5_000 + seed: int = 0 + device: str = "cuda" + + +class VLAConfig(BaseModel): + """VLA finetune config (v1: pi0, pi0.5).""" + + policy_type: Literal["pi0", "pi0_5"] = "pi0_5" + + # Pretrained checkpoint — HF hub id or local path + pretrained_path: str + + # Finetune mode + finetune_mode: Literal["full", "lora"] = "lora" + lora_rank: int = 16 + freeze_vision: bool = True + freeze_language: bool = True + + # Action chunking — pi0/pi0.5 default + chunk_size: int = 50 + + # Optim + steps: int = 30_000 + batch_size: int = 4 + lr: float = 5e-5 + weight_decay: float = 1e-4 + + # Eval / checkpointing + save_every: int = 5_000 + eval_every: int = 2_500 + seed: int = 0 + device: str = "cuda" diff --git a/dimos/learning/training/monitor_module.py b/dimos/learning/training/monitor_module.py new file mode 100644 index 0000000000..c2f5d2cc42 --- /dev/null +++ b/dimos/learning/training/monitor_module.py @@ -0,0 +1,89 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Visualize / log training progress. + +Subscribes to the unified `TrainProgress` + `TrainDone` streams from +`TrainerModule` (which covers both build and train phases via `phase` field); +logs to: + - rerun (if the rerun bridge is available — already a DimOS dep) + - JSONL file (structured, post-hoc analysis) + - stdout (always, terse summary line per event) + +Optional in any blueprint. Sits passively on the bus so the same training +session can have multiple monitors (one in dev, one writing to a server). +""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In +from dimos.learning.training.trainer_module import TrainDone, TrainProgress + + +class LearningMonitorModuleConfig(ModuleConfig): + """Where to send progress events. + + Attributes: + log_to_rerun: forward every event to the rerun bridge if importable. + log_to_stdout: print one terse summary line per event. + jsonl_path: if set, append JSON-per-line to this file. + train_loss_smoothing: EMA smoothing factor for the rerun loss curve. + """ + + log_to_rerun: bool = True + log_to_stdout: bool = True + jsonl_path: str | None = None + train_loss_smoothing: float = 0.9 + + +class LearningMonitorModule(Module): + """Pure subscriber. 
Owns no work, just visualization fan-out.""" + + config: LearningMonitorModuleConfig + + train_progress: In[TrainProgress] + train_done: In[TrainDone] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._jsonl_handle: Any = None # opened in start() + self._train_loss_ema: float | None = None + + @rpc + def start(self) -> None: + """Open the JSONL file (if configured), subscribe to both ports.""" + raise NotImplementedError + + @rpc + def stop(self) -> None: + """Flush + close the JSONL file; super().stop().""" + raise NotImplementedError + + # ── handlers (called from port subscriptions) ──────────────────────────── + + def _on_train_progress(self, msg: TrainProgress) -> None: + """Forward to enabled sinks. Routes by `msg.phase`: + - phase == "build": log dataset progress (episodes, samples) + - phase in {"train","eval"}: log loss curves (with EMA smoothing for rerun) + - other phases: log message line only. + """ + raise NotImplementedError + + def _on_train_done(self, msg: TrainDone) -> None: + """Final summary line; close JSONL section.""" + raise NotImplementedError diff --git a/dimos/learning/training/split.py b/dimos/learning/training/split.py new file mode 100644 index 0000000000..b4a6f27b0f --- /dev/null +++ b/dimos/learning/training/split.py @@ -0,0 +1,41 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Episode-level train/val split. 
+ +LeRobot v2 supports filtering by episode index at training time, so we don't +materialize two datasets. We compute the partition once and pass the index +lists to the trainer. + +Resolution order (first non-None wins): + 1. `cfg.val_episode_ids` — explicit whitelist + 2. `cfg.val_ratio` — deterministic random split via cfg.val_split_seed + 3. neither set — empty val (everything is train) +""" + +from __future__ import annotations + +from dimos.learning.spec import Episode, FilterConfig + + +def train_val_split( + episodes: list[Episode], + cfg: FilterConfig | None, +) -> tuple[list[int], list[int]]: + """Partition `episodes` (already filtered) into (train_ids, val_ids). + + Returns lists of episode *indices* into `episodes`, not Episode objects. + LeRobot consumes index lists. Determinism is via `cfg.val_split_seed`. + """ + raise NotImplementedError diff --git a/dimos/learning/training/stats.py b/dimos/learning/training/stats.py new file mode 100644 index 0000000000..23ca337c24 --- /dev/null +++ b/dimos/learning/training/stats.py @@ -0,0 +1,108 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Per-feature dataset statistics — written once, read by trainers + inference. + +LeRobot expects `meta/stats.json` next to the dataset with mean/std/min/max/q01/q99 +for every observation and action key. 
The same dict is consumed by: + - the trainer (normalize inputs / unnormalize predicted actions) + - the inference `ObsBuilder` and `ActionReplayer` (same normalization, live) + +`Stats` is a streaming Welford accumulator: feed it `Sample` instances one at a +time and call `.result()` at the end. Used both inside `formats.lerobot.write` +(so the dataset is self-describing on first build) and as a standalone pass via +`compute_stats_from_prep(dp)` if a dataset on disk needs stats recomputed. + +Image stats are computed on a subsample (every Nth frame) to bound cost. +""" + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from dimos.learning.spec import Sample + +if TYPE_CHECKING: + from dimos.learning.dataprep import DataPrep + + +class Stats: + """Streaming Welford accumulator for per-feature mean/std/min/max/q01/q99. + + Update with one Sample at a time; call `.result()` at the end to get the + serializable dict written to `meta/stats.json`. + + Quantiles (q01, q99) are computed from a reservoir sample of size + `quantile_reservoir` per feature — bounded memory for unbounded streams. + """ + + def __init__( + self, + image_subsample: int = 10, + quantile_reservoir: int = 10_000, + seed: int = 0, + ) -> None: + """Configure cost knobs. + + Args: + image_subsample: include every Nth image frame in stats; N=1 for full + accuracy, larger N for faster builds on long sessions. + quantile_reservoir: reservoir size per feature for q01/q99. + seed: for the reservoir sampler. + """ + raise NotImplementedError + + def update(self, sample: Sample) -> None: + """Fold one Sample into the running statistics for every obs/action key.""" + raise NotImplementedError + + def result(self) -> dict[str, Any]: + """Return the LeRobot-compatible stats dict. + + Schema: + { + "observation.": {"mean": [...], "std": [...], "min": [...], + "max": [...], "q01": [...], "q99": [...]}, + "action.": {...
same keys ...}, + ... + } + """ + raise NotImplementedError + + def save(self, path: str | Path) -> None: + """Write `result()` to `path` as JSON.""" + raise NotImplementedError + + @classmethod + def load(cls, path: str | Path) -> dict[str, Any]: + """Read a stats JSON from disk. Returns the raw dict, not a Stats instance.""" + raise NotImplementedError + + +def compute_stats(samples: Iterator[Sample], **kw: Any) -> dict[str, Any]: + """One-shot helper: drain `samples`, return the stats dict. + + Equivalent to `s = Stats(**kw); for x in samples: s.update(x); return s.result()`. + """ + raise NotImplementedError + + +def compute_stats_from_prep(dp: DataPrep, **kw: Any) -> dict[str, Any]: + """One-shot helper that pulls samples from a DataPrep instance. + + Convenience for "I have a built dataset on disk and need to recompute stats." + """ + raise NotImplementedError diff --git a/dimos/learning/training/train.py b/dimos/learning/training/train.py new file mode 100644 index 0000000000..2730bdd715 --- /dev/null +++ b/dimos/learning/training/train.py @@ -0,0 +1,129 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training entry points for v1. + +Two functions, both thin wrappers around `lerobot`: + - train_bc(spec, cfg, output_dir) -> ACT (or Diffusion) BC training + - finetune_vla(spec, cfg, output_dir) -> pi0 / pi0.5 finetune + +Both: + 1. 
Materialize the dataset via `DataPrep.build()` if `spec.output.path` + doesn't already exist (idempotent). + 2. Open the materialized dataset as a `lerobot.LeRobotDataset`. + 3. Translate the DimOS config to a lerobot config, build the policy. + 4. Compute / load `meta/stats.json`. + 5. Compute the train/val split. + 6. Call lerobot's training loop. + 7. Write the checkpoint + a sidecar `dimos_meta.json` so inference can + reconstruct everything from `output_dir` alone. + +We do NOT reimplement the training loop, optimizer schedule, normalization, +action chunking, language tokenization, or checkpoint format. Riding on +lerobot keeps this file small and means a `pi0.5` upgrade is a config bump. +""" + +from __future__ import annotations + +from pathlib import Path + +from dimos.learning.spec import DatasetSpec +from dimos.learning.training.configs import BCConfig, VLAConfig + +# Sidecar written next to the lerobot checkpoint so inference can recover +# the spec + dataset path that produced this policy. +DIMOS_META_FILENAME = "dimos_meta.json" + + +def train_bc(spec: DatasetSpec, cfg: BCConfig, output_dir: str | Path) -> Path: + """Train an ACT (or Diffusion) BC policy on `spec`. + + Returns the path to the final checkpoint directory. The returned dir + contains the lerobot checkpoint + `dimos_meta.json` linking back to the + spec and dataset used. + """ + raise NotImplementedError + + +def finetune_vla(spec: DatasetSpec, cfg: VLAConfig, output_dir: str | Path) -> Path: + """Finetune a pretrained pi0 / pi0.5 on `spec`. + + Loads `cfg.pretrained_path` (HF hub id or local), wraps it for the + requested `finetune_mode` (full or LoRA), runs lerobot's training loop, + and writes the resulting checkpoint to `output_dir`. + + Returns the checkpoint directory path. + """ + raise NotImplementedError + + +# ───────────────────────────────────────────────────────────────────────────── +# Internals — translate DimOS configs into the lerobot training entry point. 
+# ───────────────────────────────────────────────────────────────────────────── + + +def _ensure_dataset(spec: DatasetSpec) -> Path: + """If `spec.output.path` doesn't exist on disk yet, run `DataPrep.build()`. + + Returns the resolved dataset path. Raises if `spec.output` is None + (training requires a materialized dataset). + """ + raise NotImplementedError + + +def _build_lerobot_config_bc(spec: DatasetSpec, cfg: BCConfig, dataset_path: Path) -> object: + """Translate a DimOS BCConfig + spec into a lerobot training config. + + Returns the lerobot config object opaque to the rest of this file — + everything lerobot-specific stays inside the implementation. + """ + raise NotImplementedError + + +def _build_lerobot_config_vla(spec: DatasetSpec, cfg: VLAConfig, dataset_path: Path) -> object: + """Translate a DimOS VLAConfig + spec into a lerobot training config.""" + raise NotImplementedError + + +def _write_dimos_meta(output_dir: Path, spec: DatasetSpec, dataset_path: Path) -> None: + """Write `dimos_meta.json` next to the checkpoint. + + Schema: + { + "dimos_version": "...", + "spec": , + "dataset_path": "", + "lerobot_version": "..." + } + Used by inference to rehydrate the spec without a separate yaml. + """ + raise NotImplementedError + + +# ───────────────────────────────────────────────────────────────────────────── +# CLI +# ───────────────────────────────────────────────────────────────────────────── + + +def main() -> None: + """CLI entrypoint: + + python -m dimos.learning.training.train bc --output [...] + python -m dimos.learning.training.train vla --output --pretrained [...] + """ + raise NotImplementedError + + +if __name__ == "__main__": + main() diff --git a/dimos/learning/training/trainer_module.py b/dimos/learning/training/trainer_module.py new file mode 100644 index 0000000000..490b1b73ee --- /dev/null +++ b/dimos/learning/training/trainer_module.py @@ -0,0 +1,263 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DimOS Module wrapper around the v1 training pipeline. + +A single training job is two subprocesses run in sequence: + 1. `python -m dimos.learning.dataprep build` (skipped if output exists) + 2. `python -m dimos.learning.training.train ...` + +`TrainerModule` runs both, parses their progress lines, and republishes them +under one unified `TrainProgress` stream with a `phase` field. There is no +separate builder Module — building is always a precursor to training in v1, +so the wiring tax of two Modules + a chain port wasn't worth it. + +Why subprocess: keeps `lerobot`, `torch`, CUDA out of the runtime's import +graph. Process isolation also means a CUDA OOM doesn't poison the runtime. 
+ +Wiring patterns: + - Default blueprint: `auto_run=True` -> module fires on start() + - Agent skill: agent calls `@rpc train(...)` directly + - Build-only (rare): agent calls `@rpc build_only(spec_path)` + - External trigger: publish `TrainTrigger` on the trigger port +""" + +from __future__ import annotations + +import subprocess +import threading +from pathlib import Path +from typing import Any, Literal + +from pydantic import BaseModel + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out + + +# ───────────────────────────────────────────────────────────────────────────── +# Message types +# ───────────────────────────────────────────────────────────────────────────── + + +class TrainTrigger(BaseModel): + """Start a training job. Empty trigger uses module config defaults.""" + + spec_path: str | None = None + output_dir: str | None = None + config_kind: Literal["bc", "vla"] | None = None + config_overrides: dict[str, Any] = {} # merged onto BCConfig/VLAConfig + skip_build: bool = False # set when caller knows the dataset is already built + job_id: str | None = None + + +class TrainProgress(BaseModel): + """Unified progress event covering both build and train phases. + + `phase` indicates which subprocess the event came from. For phase=="build" + the train-specific fields (loss, val_loss, step counts) are zero/None. + For phase=="train" the build-specific fields are zero. 
+ """ + + job_id: str + phase: Literal["build", "load", "train", "eval", "save", "done", "failed"] + message: str = "" + + # Build-phase fields (meaningful only when phase == "build") + samples_written: int = 0 + current_episode: int = 0 + total_episodes: int = 0 + + # Train-phase fields (meaningful only when phase in {"load","train","eval","save"}) + step: int = 0 + total_steps: int = 0 + loss: float | None = None + val_loss: float | None = None + eta_s: float | None = None + + +class TrainDone(BaseModel): + """Terminal event with the final checkpoint dir or an error.""" + + job_id: str + success: bool + dataset_path: Path | None = None # the (possibly newly-built) dataset + checkpoint_dir: Path | None = None # None on failure + error: str | None = None + + +# ───────────────────────────────────────────────────────────────────────────── +# Module +# ───────────────────────────────────────────────────────────────────────────── + + +class TrainerModuleConfig(ModuleConfig): + """Trainer module config. + + Attributes: + spec_path: default spec path (used for both build and train). + output_dir: default checkpoint output directory. + config_kind: "bc" (ACT/Diffusion) or "vla" (pi0/pi0.5). + config_path: optional BCConfig/VLAConfig YAML override. + python_executable: subprocess python; "" = current sys.executable. + skip_build_if_exists: if the dataset path already exists on disk, + skip the build phase. Default True (idempotent). + auto_run: if True, start a job on `start()`. + max_concurrent: cap on simultaneous jobs. v1 uses 1. 
+ """ + + spec_path: str = "" + output_dir: str = "" + config_kind: Literal["bc", "vla"] = "bc" + config_path: str | None = None + python_executable: str = "" + skip_build_if_exists: bool = True + auto_run: bool = False + max_concurrent: int = 1 + + +class TrainerModule(Module): + """Spawns dataprep build (if needed) then train; reports unified progress.""" + + config: TrainerModuleConfig + + trigger: In[TrainTrigger] + + progress: Out[TrainProgress] + done: Out[TrainDone] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + # job_id -> (build_proc, train_proc, watcher_thread). Each proc may be None. + self._jobs: dict[str, dict[str, Any]] = {} + self._jobs_lock = threading.Lock() + self._next_job_id = 0 + + # ── lifecycle ──────────────────────────────────────────────────────────── + + @rpc + def start(self) -> None: + """Subscribe to `trigger`. If `auto_run`, kick off one training job.""" + raise NotImplementedError + + @rpc + def stop(self) -> None: + """Cancel all in-flight jobs, then super().stop().""" + raise NotImplementedError + + # ── agent / external surface ───────────────────────────────────────────── + + @rpc + def train( + self, + spec_path: str | None = None, + output_dir: str | None = None, + config_kind: Literal["bc", "vla"] | None = None, + config_overrides: dict[str, Any] | None = None, + skip_build: bool = False, + ) -> str: + """Start a build-then-train job. Returns job_id. + + All arguments override `config` for this job only. If `skip_build` or + `config.skip_build_if_exists` and the dataset is on disk, the build + phase is skipped. + """ + raise NotImplementedError + + @rpc + def build_only(self, spec_path: str | None = None) -> str: + """Run only the dataset-build subprocess; do not train. + + Convenience for the rare standalone case (CI dataset bake, debugging + a new spec). Returns job_id; emits TrainProgress with phase=="build" + events and a TrainDone with checkpoint_dir=None on completion. 
+ """ + raise NotImplementedError + + @rpc + def cancel(self, job_id: str) -> bool: + """SIGTERM the active subprocess (build or train); True if cancelled.""" + raise NotImplementedError + + @rpc + def list_jobs(self) -> list[str]: + """Return active job ids.""" + raise NotImplementedError + + @rpc + def list_checkpoints(self, output_dir: str | None = None) -> list[str]: + """Scan `output_dir` (defaults to config.output_dir) and return paths + to checkpoint subdirectories. Useful for agent flows like 'train then + deploy the latest checkpoint'.""" + raise NotImplementedError + + # ── internals ──────────────────────────────────────────────────────────── + + def _on_trigger(self, msg: TrainTrigger) -> None: + """Port handler — calls `self.train(...)`.""" + raise NotImplementedError + + def _run_job( + self, + job_id: str, + spec_path: str, + output_dir: str, + config_kind: Literal["bc", "vla"], + config_overrides: dict[str, Any], + skip_build: bool, + train: bool, + ) -> None: + """Background thread driving one job through its phases. + + Sequence: + 1. Resolve dataset path from spec. + 2. If `skip_build` is False and dataset doesn't exist (or + `skip_build_if_exists` is False), spawn `dataprep build` and + stream its progress as phase=="build". + 3. If `train` is True, spawn `train` and stream progress as + phase in {"load","train","eval","save"}. + 4. Emit terminal TrainDone. 
+ """ + raise NotImplementedError + + def _spawn_build(self, spec_path: str, job_id: str) -> subprocess.Popen[str]: + """Build argv for `python -m dimos.learning.dataprep build --progress-json`.""" + raise NotImplementedError + + def _spawn_train( + self, + spec_path: str, + output_dir: str, + config_kind: Literal["bc", "vla"], + config_overrides: dict[str, Any], + job_id: str, + ) -> subprocess.Popen[str]: + """Build argv for `python -m dimos.learning.training.train --progress-json`.""" + raise NotImplementedError + + def _stream_build_progress(self, job_id: str, proc: subprocess.Popen[str]) -> int: + """Read stdout JSON-per-line from the build subprocess; publish each + as TrainProgress(phase="build"). Returns subprocess exit code.""" + raise NotImplementedError + + def _stream_train_progress(self, job_id: str, proc: subprocess.Popen[str]) -> int: + """Read stdout JSON-per-line from the train subprocess; publish each + as TrainProgress(phase in {"load","train","eval","save"}). Returns exit code.""" + raise NotImplementedError + + def _allocate_job_id(self) -> str: + jid = f"train-{self._next_job_id}" + self._next_job_id += 1 + return jid From a10dce103bfd4e7a6c6772ab00f4cf595d79da85 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 30 Apr 2026 17:08:46 -0700 Subject: [PATCH 03/45] temp spec update --- dimos/learning/PLAN.md | 645 ------------------ dimos/learning/collection/blueprint.py | 16 +- dimos/learning/inference/action_replayer.py | 1 - .../learning/inference/chunk_policy_module.py | 1 - dimos/learning/learning_spec.md | 164 +++++ dimos/learning/policy/lerobot_policy.py | 5 +- dimos/learning/specs/datacollection.md | 247 +++++++ dimos/learning/specs/inference.md | 184 +++++ dimos/learning/specs/structure.md | 141 ++++ dimos/learning/specs/training.md | 60 ++ dimos/learning/training/monitor_module.py | 6 +- dimos/learning/training/train.py | 4 +- dimos/learning/training/trainer_module.py | 3 +- 13 files changed, 807 insertions(+), 670 deletions(-) 
delete mode 100644 dimos/learning/PLAN.md create mode 100644 dimos/learning/learning_spec.md create mode 100644 dimos/learning/specs/datacollection.md create mode 100644 dimos/learning/specs/inference.md create mode 100644 dimos/learning/specs/structure.md create mode 100644 dimos/learning/specs/training.md diff --git a/dimos/learning/PLAN.md b/dimos/learning/PLAN.md deleted file mode 100644 index 0dab9571df..0000000000 --- a/dimos/learning/PLAN.md +++ /dev/null @@ -1,645 +0,0 @@ -# DimOS Learning Framework — v1 Plan - -## v1 Scope - -**Goal:** end-to-end pipeline that lets a DimOS user collect teleop demos, train a policy, and run it on a real arm — for two concrete targets: - -1. **BC / ACT** — train ACT (Action Chunking Transformer) on a pick-and-place demo set on xArm7. -2. **VLA finetune** — finetune a pretrained π₀ / π₀.₅ checkpoint on the same demo set. - -Both targets share a single `DatasetSpec`, a single LeRobot dataset on disk, and a single inference module. The choice between ACT and a VLA is just a different training entry point and a different policy class at inference time. - -**v1 architectural mandate — fully DimOS-native:** - -Every phase of the pipeline (collection, training, inference) is exposed as a **Module + Blueprint** with RPC surfaces. There is **one user-facing UX**: `dimos --blueprint <name>` for everything. Agent skills can drive any phase via @rpc. Composition between phases is just port wiring (`builder.done → trainer.builder_done`). - -For collection and training — where the actual work is offline batch processing — the Module is an **orchestrator over a subprocess**: it spawns `python -m dimos.learning.dataprep build` or `python -m ...training.train`, parses its progress lines, and republishes them as typed events. The work itself stays in the subprocess (heavy deps isolated, process-cancellable, separately testable). Inference Modules do real live work because there is real live data flow. 
- -Every Module exposes: -- `@rpc start()` / `@rpc stop()` — lifecycle -- `@rpc <action>(...)` — at least one agent-callable action -- `@rpc get_status()` — observability -- typed `In[...]` / `Out[...]` ports for blueprint composition - -**Out of v1 (deferred to v2 — see bottom of file):** -- RL (online + offline) -- Pure proprioceptive policies in the 100 Hz tick loop (`PolicyControlTask`) -- Multi-embodiment / cross-task training -- Distributed / multi-GPU training -- Live recording of episode boundaries (we keep post-hoc button extraction) - -**Design principle:** lean on existing infrastructure — `RecordReplay` (PR #1708), memory2, the Module/Blueprint system, **and** the `lerobot` library which already implements ACT and π₀/π₀.₅ end-to-end. We add only the DimOS glue. - ---- - -## Architecture Overview - -``` - COLLECT TRAIN INFER - ─────── ───── ───── - Teleop + Camera load_dataset(spec) ChunkPolicyModule (1–30 Hz) - ↓ ↓ ↓ (action chunks) - RecordReplay --record-path train_bc / finetune_vla ActionReplayer (100 Hz) - ↓ ↓ ↓ - session.db checkpoint (.safetensors) Coordinator joint_command - ↓ + stats.json ↓ - dataset.yaml + dataprep.py Hardware - ↓ - LeRobot v2 dataset on disk -``` - -The **same `dataset.yaml`** is the contract between collection (export), training (load), and inference (live obs construction). It is the single source of truth for what counts as observation/action, episode boundaries, sync strategy, and feature shapes. - ---- - -## 1. Data Collection — STATUS: IMPLEMENTED, NEEDS POLISH - -### 1.1 Recording (no new code) - -Use Sam's `RecordReplay` from PR #1708: - -```bash -dimos --blueprint quest_teleop_xarm7 --record-path session.db -``` - -Captures every LCM topic — joint states, joint commands, camera images, controller `Buttons`, IMU, etc. One stream per topic in a single `SqliteStore`. Episode boundaries are not marked at record time; they are recovered offline from button presses. 
- -### 1.2 The dataset spec - -Implemented in `dimos/learning/spec.py`. Schema is pydantic v2 → round-trips YAML/JSON. Eight typed classes: `EpisodeConfig`, `FieldRef`, `SyncConfig`, `FilterConfig`, `OutputConfig`, `DatasetSpec`, `Episode`, `Sample`. Friendly Quest button names (`A`/`B`/`X`/...) resolve to `Buttons` bit fields via `BUTTON_ALIASES`. - -Example (see `dimos/learning/dataset.example.yaml` for the live template): - -```yaml -source: session.db - -episodes: - extractor: buttons # buttons | ranges | whole_session - start: A # press to begin - save: B # press to commit - discard: X # press to drop - default_task_label: pick_red_cube - -observation: - cam_high: - stream: camera_color_image - preprocess: jpeg_decode - cam_wrist: - stream: camera_wrist_color_image - preprocess: jpeg_decode - joint_pos: - stream: coordinator_joint_state - field: position - -action: - joint_target: - stream: coordinator_joint_command - field: position - -sync: - anchor: cam_high - rate_hz: 30 - tolerance_ms: 50 - strategy: nearest - -filters: - success_only: true - min_duration_s: 1.0 - task_labels: [pick_red_cube] - -output: - format: lerobot # primary v1 target - path: datasets/pick_red/ - metadata: - fps: 30 - robot: xarm7 -``` - -### 1.3 The pipeline file: `dimos/learning/dataprep.py` - -Implemented. Does everything: read raw `session.db`, extract episodes from button events, sync streams, dispatch to the chosen format writer. Same module exposes `load_dataset(spec)` for training. 
- -Public functions (all done): -- `load_spec(path)` / `save_spec(spec, path)` — YAML/JSON I/O -- `extract_episodes(store, cfg)` — three strategies (buttons/ranges/whole_session) -- `filter_episodes(eps, cfg)` — success / duration / label whitelist -- `iter_samples(store, episode, spec)` — anchor-rate timestep walker w/ bisect nearest-search -- `build_dataset(spec)` — full session.db → on-disk dataset -- `load_dataset(spec)` — returns a `torch.utils.data.Dataset[Sample]` -- `inspect(spec)` — episode/duration/per-stream stats -- `main()` — CLI: `build` / `inspect` / `review` (review is a stub) - -### 1.4 Format writers (in `dimos/learning/formats/`) - -| Format | v1 priority | Status | Why | -|-----------|-------------|--------|-----| -| `lerobot` | **primary** | done | Native input for both ACT and π₀/π₀.₅ via `lerobot` lib | -| `hdf5` | secondary | done | ACT-original codebase, debugging, smaller deps | -| `rlds` | v2 | done (gated on TF) | RT-X / OpenX-Embodiment compat — not needed for v1 | - -### 1.5 Gaps to close in v1 - -The collection pipeline is functional, but a few things are needed before LeRobot training works cleanly. Each is small. - -**(a) Per-episode task description** — π₀ is language-conditioned; LeRobot v2 has a `tasks.jsonl` table. Currently `Episode.task_label: str | None` is a tag; we need a free-form string per episode. Extend with: -```python -class Episode: - task_description: str | None = None # e.g. "pick up the red cube and place it on the blue plate" -``` -The LeRobot writer already emits `tasks.jsonl` keyed on `task_label` — switch it to use `task_description` (fall back to `task_label`). Population: `EpisodeConfig.default_task_description` for single-task sessions, or set per episode in the `review` CLI. - -**(b) Dataset statistics** — LeRobot training requires `meta/stats.json` (per-feature mean/std/min/max/q01/q99). Add a streaming stats accumulator inside `formats/lerobot.py::write` so we don't need a second pass over the data. 
Image stats are computed on a subsample (every Nth frame) to bound cost. - -**(c) Train/val split** — LeRobot v2 supports filtering by episode index at training time, so we don't need to materialize two datasets. Add `FilterConfig.val_episode_ids: list[int] | None` and `FilterConfig.val_ratio: float | None` (deterministic seeded split). Trainer reads these. - -**(d) Image format on disk** — LeRobot v2 stores images as MP4 videos by default (`videos/chunk-NNN/<video_key>/episode_NNNNNN.mp4`). Current writer writes them as parquet tensor columns, which works but inflates disk size. Switch to MP4 encoding via `imageio[ffmpeg]` for image streams ≥2D + uint8. Parquet cells then store frame indices, not pixels. - -**(e) `review` CLI** — currently a stub. Implement a minimal non-interactive form first: load spec, list episodes with metadata, allow batch retag via `--set-label PICK_RED --episode-ids 0,1,2,5`. Interactive TUI is v2. - -These five items + the existing skeleton complete the collection side for v1. - ---- - -## 2. Training — NEW v1 WORK - -### 2.1 Strategy: thin wrappers around `lerobot` - -The `lerobot` library (HuggingFace + Tesla-PI fork) already implements ACT, Diffusion Policy, π₀, π₀.₅ — including dataloaders for the LeRobot v2 format, normalization, action chunking, language tokenization, training loop, checkpointing, and ONNX export. - -**We do NOT reimplement these.** The v1 training pipeline is two thin Python wrappers that: -1. Take a DimOS `DatasetSpec`, -2. Translate it into a LeRobot config, -3. Invoke `lerobot.scripts.train.train()`, -4. Save the resulting checkpoint to a path that the inference module knows how to read. - -This keeps `dimos/learning/training/` short, rides on a maintained upstream, and means a `pi0.5` upgrade is a config bump rather than a code change. 
- -### 2.2 File layout: `dimos/learning/training/` - -``` -dimos/learning/training/ - train.py # train_bc, finetune_vla — public entry points - configs.py # BCConfig, VLAConfig - stats.py # compute_stats(spec) -> dict (used by build_dataset too) - split.py # train/val episode split helper -``` - -### 2.3 The two entry points - -```python -def train_bc(spec: DatasetSpec, cfg: BCConfig, output_dir: Path) -> Path: - """Train an ACT (or other BC) policy on `spec`. Returns checkpoint path.""" - -def finetune_vla(spec: DatasetSpec, cfg: VLAConfig, output_dir: Path) -> Path: - """Finetune a pretrained π₀ / π₀.₅ on `spec`. Returns checkpoint path.""" -``` - -Both: -- Materialize the dataset via `build_dataset(spec)` if `spec.output.path` doesn't already exist (idempotent). -- Build a `lerobot.LeRobotDataset(spec.output.path)`. -- Build a LeRobot policy from `cfg`. -- Call the LeRobot training loop with `cfg.steps`, `cfg.batch_size`, `cfg.lr`, etc. -- Save final checkpoint + a sidecar `dimos_meta.json` with `{spec_path, dataset_path, dimos_version}` so inference can recover everything. 
- -### 2.4 `BCConfig` (ACT-focused for v1) - -```python -class BCConfig(BaseModel): - policy_type: Literal["act", "diffusion"] = "act" - - # ACT model arch — defaults match the original ACT pick-and-place setup - chunk_size: int = 50 # action_horizon - n_obs_steps: int = 1 - hidden_dim: int = 512 - n_layers: int = 4 - n_heads: int = 8 - use_vae: bool = True - kl_weight: float = 10.0 - - # Vision backbone - vision_backbone: str = "resnet18" - pretrained: bool = True - - # Optim - steps: int = 100_000 - batch_size: int = 8 - lr: float = 1e-5 - lr_backbone: float = 1e-5 - weight_decay: float = 1e-4 - - # Eval - val_ratio: float = 0.1 - save_every: int = 10_000 -``` - -### 2.5 `VLAConfig` (π₀ / π₀.₅ finetune) - -```python -class VLAConfig(BaseModel): - policy_type: Literal["pi0", "pi0_5"] = "pi0_5" - pretrained_path: str # HF hub id or local path - finetune_mode: Literal["full", "lora"] = "lora" - lora_rank: int = 16 - freeze_vision: bool = True - freeze_language: bool = True - - chunk_size: int = 50 # default π₀ action horizon - - steps: int = 30_000 - batch_size: int = 4 - lr: float = 5e-5 - weight_decay: float = 1e-4 - save_every: int = 5_000 - - # The spec's task_description per episode is the language conditioning at train time. - # No additional config needed. -``` - -### 2.6 Stats and split — pulled out so they're reusable - -`stats.compute_stats(spec)` walks the materialized dataset once, accumulating Welford mean/std for joint vectors and per-channel image stats on a subsample. Writes `meta/stats.json`. Called from both `build_dataset` (so the disk-resident dataset is self-describing) and `train_bc` / `finetune_vla` (idempotent — skip if `stats.json` already exists). - -`split.train_val_split(spec, val_ratio, seed=0)` returns two episode-id lists. Deterministic. Trainer passes these to LeRobot via its episode filter. 
- -### 2.7 CLI - -```bash -# Train ACT on a built dataset -python -m dimos.learning.training.train bc dataset.yaml \ - --output runs/act_pick_red \ - --steps 100000 --batch-size 8 - -# Finetune π₀.₅ -python -m dimos.learning.training.train vla dataset.yaml \ - --output runs/pi05_pick_red \ - --pretrained lerobot/pi0_5 \ - --finetune-mode lora --lora-rank 16 -``` - -The CLI is a tiny argparse wrapper that builds `BCConfig`/`VLAConfig` and calls the function. - -### 2.8 Dependencies - -Adds to `pyproject.toml` (under a `[project.optional-dependencies]` `learning` extra so default installs aren't bloated): -- `lerobot >= 0.2` -- `torch >= 2.3` (already implied by `lerobot`) -- `imageio[ffmpeg]` (MP4 image encoding) -- For VLA only: `transformers`, `accelerate`, `peft` (LoRA) - -User installs with `pip install -e .[learning]`. - ---- - -## 3. Inference — NEW v1 WORK - -### 3.1 The two paths, simplified for v1 - -For v1 we need exactly **one** inference module: `ChunkPolicyModule`. Both ACT and π₀/π₀.₅ produce action chunks (sequences of length `chunk_size`), so the same module handles them. The model runs slow (1–30 Hz depending on whether it's ACT or VLA); a separate `ActionReplayer` plays the chunk back at the coordinator's 100 Hz tick rate. - -``` - ┌──────────────────────────┐ - │ ChunkPolicyModule │ ← runs in its own thread/process at policy rate - │ In: color_image │ - │ joint_state │ - │ language_text │ - │ Out: action_chunk │ - └────────────┬─────────────┘ - │ (T, action_dim) - ▼ - ┌──────────────────────────┐ - │ ActionReplayer │ ← part of the ControlTask graph - │ ControlTask @ 100 Hz │ - │ pops next action, │ - │ emits JointCommand │ - └──────────────────────────┘ -``` - -`PolicyControlTask` (joint-only, in-tick-loop) is **deferred to v2** — it's only useful for proprioceptive policies, which we don't train in v1. 
- -### 3.2 File layout: `dimos/learning/inference/` - -``` -dimos/learning/inference/ - chunk_policy_module.py # ChunkPolicyModule - action_replayer.py # ActionReplayer (subclass of BaseControlTask) - obs_builder.py # spec.observation -> live obs dict, decoupled - blueprints.py # autoconnect helpers -``` - -### 3.3 `dimos/learning/policy/` — Policy protocol - -```python -class Policy(Protocol): - @classmethod - def load(cls, path: Path, device: str = "cuda") -> "Policy": ... - def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: - """Return shape (chunk_size, action_dim).""" - -class LeRobotPolicy(Policy): - """Wraps any lerobot.PreTrainedPolicy (ACT, π₀, π₀.₅, Diffusion). - Detects policy_type from the checkpoint metadata.""" -``` - -v1 ships `LeRobotPolicy` only. `OnnxPolicy`, custom `TorchPolicy` are v2. - -### 3.4 `ChunkPolicyModule` skeleton - -```python -class ChunkPolicyModule(Module): - color_image: In[Image] - joint_state: In[JointState] - language_text: In[str] # optional; ignored if policy doesn't use it - - action_chunk: Out[ActionChunk] # new typed message: ts, joint_names, positions[T, N] - - def __init__(self, *, spec_path: str, policy_path: str, - inference_rate_hz: float, device: str = "cuda"): - self._spec = load_spec(spec_path) - self._policy = LeRobotPolicy.load(Path(policy_path), device) - self._obs_builder = ObsBuilder(self._spec) - self._stats = load_stats(Path(policy_path).parent / "stats.json") - - @rate_limited(inference_rate_hz) - def on_tick(self): - obs = self._obs_builder.build( - color_image=self.color_image.latest(), - joint_state=self.joint_state.latest(), - language=self.language_text.latest_or("default task"), - ) - chunk = self._policy.predict_chunk(self._stats.normalize(obs)) - chunk = self._stats.unnormalize_actions(chunk) - self.action_chunk.publish(ActionChunk(positions=chunk, ts=time.time(), ...)) -``` - -### 3.5 `ActionReplayer` — a `BaseControlTask` - -Lives in the tick loop. Subscribes to `action_chunk`. 
Maintains a buffer of pending actions with their relative timestamps; `compute(state)` interpolates to the current tick time. - -```python -class ActionReplayer(BaseControlTask): - name = "policy_replay" - - def __init__(self, joint_names: list[str], chunk_topic: str, policy_dt: float): - self._joint_names = joint_names - self._buffer: deque[tuple[float, np.ndarray]] = deque() # (target_ts, positions) - ... - - def on_action_chunk(self, msg: ActionChunk) -> None: - # Push new chunk, drop overlap with current buffer (latest chunk wins) - ... - - def compute(self, state: CoordinatorState) -> JointCommandOutput | None: - if not self._buffer: - return None - target = self._lookup_or_interp(state.now) - return JointCommandOutput(joint_names=self._joint_names, positions=target) -``` - -Key behaviors: -- **Latest chunk wins**: when a new chunk arrives, drop any buffered actions ≥ the new chunk's start time. Smooth (no gap) because the new chunk's first action is conditioned on near-current obs. -- **Lookback safety**: if the policy module stalls and the buffer empties, hold last commanded position (don't fall to zero). Log a warning. -- **Temporal ensembling (optional v1 nice-to-have)**: when ACT publishes overlapping chunks, exponentially weight predictions for the same target timestamp. Off by default. - -### 3.6 Live observation construction - -`ObsBuilder` reads `spec.observation` and exposes `build(**live_streams) -> dict[str, np.ndarray]`. Reuses the same `_resolve_field` and `preprocess` registry from `dataprep.py` so train and infer share normalization. This is the single most important consistency guarantee in the framework. 
- -### 3.7 Inference blueprint - -```python -# Vision policy (ACT or π₀.₅) → Coordinator -pick_red_cube = autoconnect( - RealSenseCamera.blueprint(camera_id="cam_high"), - ChunkPolicyModule.blueprint( - spec_path="datasets/pick_red/dataset.yaml", - policy_path="runs/pi05_pick_red/checkpoint", - inference_rate_hz=5.0, - ), - ControlCoordinator.blueprint( - hardware=[xarm7], - tasks=[ - TaskConfig(name="policy_replay", type="action_replayer", - chunk_topic="action_chunk", policy_dt=0.2), - ], - ), -) -``` - -### 3.8 Coordinator change (one line) - -In `dimos/control/coordinator.py::_create_task_from_config` add a case for `type == "action_replayer"` that constructs the `ActionReplayer`. Plus a new `ActionReplayerConfig` in `dimos/control/task.py`'s task config union. Same pattern as existing tasks. - ---- - -## 4. File Structure (v1) - -``` -dimos/learning/ - spec.py # ✅ DatasetSpec (DataPrep-callable types) - dataprep.py # ✅ DataPrep class + CLI - dataset.example.yaml # ✅ - - formats/ - lerobot.py # ✅ (needs MP4 + stats — §1.5) - hdf5.py # ✅ - rlds.py # ✅ (v2 priority) - - collection/ # 🆕 v1 — collection blueprint - episode_monitor.py # EpisodeMonitorModule (live counters via @rpc) - blueprint.py # collection_blueprint(...) - - training/ # 🆕 v1 — training scripts + orchestrator Modules - configs.py # BCConfig, VLAConfig [script] - stats.py # Stats class + compute_stats [script] - split.py # train_val_split [script] - train.py # train_bc, finetune_vla, CLI [script] - trainer_module.py # TrainerModule — wraps build+train [Module] - monitor_module.py # LearningMonitorModule (rerun + JSONL) [Module] - blueprint.py # learning_train_{act,vla,idle} - - inference/ # 🆕 v1 — inference blueprint (real live Modules) - obs_builder.py # ObsBuilder (uses DataPrep.resolve_field) - chunk_policy_module.py # ChunkPolicyModule (real Module) - action_replayer.py # ActionReplayer (BaseControlTask) - blueprint.py # policy_blueprint(...) 
- - policy/ # 🆕 v1 — Policy abstraction - base.py # Policy protocol + ActionChunk message - lerobot_policy.py # LeRobotPolicy (ACT, π₀, π₀.₅) -``` - -**Note on script-vs-Module split inside `training/`:** The four `[script]` -files (`configs.py`, `stats.py`, `split.py`, `train.py`) hold the actual -training logic and are independently usable from notebooks/CI/tests. The -three `[Module]` files wrap them as DimOS Modules that spawn the scripts -as subprocesses — that's the dual-surface UX the v1 mandate requires. - -Critical files outside `dimos/learning/`: - -| File | Change | -|------|--------| -| `dimos/control/coordinator.py` | Add `"action_replayer"` case in `_create_task_from_config` | -| `dimos/control/task.py` | Add `ActionReplayerConfig` to task config union | -| `pyproject.toml` | Add `[project.optional-dependencies].learning` extra | -| `dimos/messages/` (or wherever DimOS LCM types live) | New `ActionChunk` type: `(joint_names, positions[T,N], ts, dt)` | - ---- - -## 5. End-to-End Demo Recipe - -Two flows, one command list each. Both assume `pip install -e .[learning]` is done. - -### 5.1 ACT pick-and-place on xArm7 — blueprint-first UX - -Each phase is `dimos --blueprint <name>`. Underlying scripts (`python -m -dimos.learning.dataprep`, `python -m dimos.learning.training.train`) are -still callable directly for CI / notebooks / debugging — but the -default flow is the blueprint surface. - -```bash -# 1. Collect — teleop + camera + RecordReplay + EpisodeMonitorModule -dimos --blueprint learning_collect_quest_xarm7 --record-path data/pick_red.db -# (operator presses A=start / B=save / X=discard; -# EpisodeMonitorModule.status streams "episodes_saved: N" live) - -# 2. 
Train — DatasetBuilderModule + TrainerModule + LearningMonitorModule -dimos --blueprint learning_train_act \ - --spec dataset.yaml --output runs/act_pick_red -# Inside: builder runs first (subprocess: dataprep build), trainer -# auto-fires on builder.done (subprocess: train bc), monitor logs to rerun. - -# 3. Infer — Camera + ChunkPolicyModule + ActionReplayer + Coordinator -dimos --blueprint learning_infer_pick_red \ - --policy-path runs/act_pick_red -``` - -### 5.2 π₀.₅ finetune on the same data - -Steps 1 + 3 unchanged. Step 2 is the same `learning_train_*` blueprint -with `--kind vla` and a `--pretrained` flag — agent or human just changes -the trigger payload, not the blueprint. - -```bash -dimos --blueprint learning_train_vla \ - --spec dataset.yaml --output runs/pi05_pick_red \ - --pretrained lerobot/pi0_5 --finetune-mode lora --lora-rank 16 -``` - -### 5.3 Agent-driven flow (same Modules, no `auto_run`) - -Demonstrates the @rpc surface. Run a single training blueprint with auto-run -disabled; a chat agent then drives every phase: - -```bash -dimos --blueprint learning_train_idle # builder + trainer + monitor, all idle -``` - -``` -agent: "build the dataset for pick_red" - → DatasetBuilderModule.build(spec_path="dataset.yaml") - ← BuildProgress events stream back to the chat - ← BuildDone(success=True, dataset_path="datasets/pick_red/") - -agent: "train ACT on it for 100k steps" - → TrainerModule.train( - spec_path="dataset.yaml", - output_dir="runs/act_pick_red", - config_kind="bc", - config_overrides={"steps": 100_000}, - ) - ← TrainProgress events stream loss/step - ← TrainDone(success=True, checkpoint_dir="runs/act_pick_red/...") - -agent: "deploy it on the xarm" - → launches `dimos --blueprint learning_infer_pick_red --policy-path ...` -``` - -The fact that the same Modules drive both the "everything-auto" CLI flow -(§5.1) and the "agent-driven" flow (§5.3) is the v1 architectural payoff. - ---- - -## 6. 
Verification (what we test before declaring v1 done) - -1. **Recording** — teleop blueprint with `--record-path` produces a session.db whose stream listing matches the spec. -2. **Build** — `python -m dimos.learning.dataprep build dataset.yaml` against a real session, then `lerobot.LeRobotDataset(path)` opens it without error and `len(ds) > 0`. -3. **Stats** — `meta/stats.json` exists and has finite, non-degenerate values for every observation/action key. -4. **Train** — `train_bc` runs ≥1k steps end-to-end on a real session; loss decreases; checkpoint loads back via `LeRobotPolicy.load`. -5. **VLA finetune** — `finetune_vla` runs ≥500 steps with LoRA on top of a downloaded π₀.₅ checkpoint; no OOM at batch=4 on a 24 GB GPU; loss decreases. -6. **Live obs parity** — `ObsBuilder.build(...)` on a fake live stream and `iter_samples(...)` on the same data give bit-identical observation dicts. -7. **Inference (sim)** — `ChunkPolicyModule` + `ActionReplayer` + `ControlCoordinator` with MuJoCo xArm7 produces non-NaN joint commands at 100 Hz, replays a 50-step chunk smoothly, recovers when policy module stalls. -8. **Inference (hw)** — same blueprint on real xArm7 produces a successful pick-and-place at ≥30% success rate after 50 demos. (Success rate is informational; the test is "no crashes, no jerks, no diverging commands.") - ---- - -## 7. 
Key Design Decisions (v1) - -| Decision | Rationale | -|----------|-----------| -| LeRobot v2 is the canonical on-disk format | Both ACT and π₀/π₀.₅ train from it natively; no custom dataloaders | -| `lerobot` library does the heavy lifting | We don't reimplement ACT, π₀, dataloader, normalization, or training loop | -| One inference module (`ChunkPolicyModule`), not three | ACT and VLA both produce chunks; only the model class differs | -| `ActionReplayer` lives in the tick loop, model lives in a Module | Decouple slow inference (1–5 Hz VLA) from fast control (100 Hz) | -| `ObsBuilder` reused between train and infer | Single source of truth for observation construction — eliminates train/serve skew | -| Episode metadata carries `task_description` | π₀/π₀.₅ are language-conditioned; `task_label` alone is too narrow | -| Stats computed at build time, written to disk | Trainers and inference both read from `meta/stats.json` — no recompute | -| Coordinator change is one new task type (`action_replayer`) | Minimal, additive | -| RLDS, ONNX, RL, proprio-only policy task → v2 | Deliberately not in scope | - ---- - -## 8. Risks & Open Questions for v1 - -- **`lerobot` API stability.** We're pinning to a specific minor version. If their training entry point changes, our wrapper breaks. Mitigation: pin tightly in `pyproject.toml`, add an integration test that exercises the wrapper. -- **π₀.₅ checkpoint availability.** Depends on the public release. Fallback: ship v1 with π₀ only, add π₀.₅ when it lands (config bump). -- **Action space match.** π₀ assumes a 7-DoF EEF action by default; xArm7 joint-position control is 7-DoF joint. Need to either (a) keep π₀'s action head and use joint targets in its expected layout, or (b) retrain the action head. v1 chooses (a) via the spec's action key naming. -- **Real-time perf of chunk replay.** First action of a chunk is conditioned on `t = chunk_arrival_time`, but it executes some ms later. 
With 50-step chunks at 30 Hz this is ~1.6 s of buffer; if the policy module stalls, the replayer drifts. Mitigation: replayer rejects stale chunks (`now − chunk.ts > policy_dt × 1.5`) and re-requests. -- **Camera calibration in the spec.** ACT/π₀ are sensitive to camera placement. The spec doesn't currently encode camera intrinsics/extrinsics. v1 punt: rely on the operator to record from the same physical setup at infer time. Add `metadata.cameras` schema in v1.5 if it bites. - ---- - ---- - -# v2 Considerations (deferred) - -These are explicitly out of v1 scope. Pulled here so we don't lose track. Anything from this list that becomes cheap during v1 implementation gets promoted up. - -### Training -- **RL** — `train_rl(env_cfg, model_cfg)` for online (PPO/SAC) and offline (CQL/IQL/AWAC) RL. Needs an env wrapper around DimOS (sim primarily — MuJoCo via the existing `MujocoCamera` work). -- **Multi-task / multi-embodiment training** — train one policy on demos from xArm7 + Piper + Mock. Needs URDF retargeting, embodiment-id conditioning. -- **Distributed training** — `accelerate` / FSDP for VLA full-finetune on multi-GPU. -- **Curriculum / dataset weighting** — sample harder episodes more often, weight by reward, etc. -- **Diffusion Policy** as a first-class BC option (lerobot supports it; just a config). -- **Active data collection** — uncertainty-based suggestion of which demos to collect next. - -### Inference -- **`PolicyControlTask`** — joint-only proprioceptive policies in the tick loop (100 Hz, no Module overhead). Useful for residual policies, locomotion. -- **`OnnxPolicy` / `TorchScriptPolicy`** — alternate Policy backends for deployment without `lerobot` runtime dep. -- **Cross-embodiment retargeting at inference** — train on xArm7 demos, deploy on Piper. -- **Temporal ensembling on by default** — currently nice-to-have in v1, make it the default after measuring its effect on jerk. 
-- **Async chunk pipelining** — request chunk N+1 while replaying chunk N to hide policy latency completely. -- **Real-time safety layer** — collision check + joint-limit clamp downstream of `ActionReplayer`. - -### Data Collection -- **Live `EpisodeManagerModule`** — annotate episode boundaries at record time instead of post-hoc. Useful when the operator wants to pause/resume and the recording length is huge. Currently overkill. -- **RLDS / TFDS writer** — needed for OpenX-Embodiment contributions and RT-X-style training. Skeleton already exists. -- **Interactive `review` TUI** — scrub through episodes, watch the camera stream, retag/discard. v1 ships a non-interactive batch retag CLI only. -- **Camera intrinsics/extrinsics in spec metadata** — required for any cross-embodiment or sim-real work. -- **Force/torque streams as observation** — schema already supports it; need preprocess hooks for FT data. -- **Imitation-from-observation** — episodes without action streams (just video). Needs an inverse dynamics model. -- **Custom transports for streaming** — record directly to a remote SqliteStore over network, bypass local disk. Probably never needed. - -### Training Module / DimOS-native -- **`TrainingModule` agent skill** — expose `train_bc` / `finetune_vla` as RPCs so the LLM agent can trigger training from a chat session. Requires sandboxing (long-running subprocess management, GPU resource gating). -- **`PolicyRegistry` Module** — track all trained policies, their specs, eval results. A "model zoo" served as a Module for blueprint composition. - -### Tooling -- **W&B / TensorBoard integration** standardized across all trainers. -- **Eval harness** — replay validation episodes through the trained policy in sim, report success rate. Needed for any kind of automated training pipeline. -- **HuggingFace Hub upload** — `dimos.learning.training.publish(checkpoint, repo_id)` so trained policies are sharable. 
- -### Possibly to promote into v1 if cheap -- **Stats subsampling for image features** — needed in v1 anyway; making it configurable is a 5-line addition. -- **Episode-level train/val split** — already needed in v1; might as well expose `--split-by hash(episode_id)` as a third strategy. -- **Diffusion Policy via the same `train_bc` entry point** — `lerobot` already has it, the only additional work is one more policy_type literal in `BCConfig`. Probably ship it. -- **`OnnxPolicy`** — only worth promoting if a deployment target without `lerobot` install exists. None in v1; defer. diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index 3a4233c045..40cb39b032 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -51,9 +51,7 @@ { ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport( - "/learning/episode_status", EpisodeStatus - ), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), } ) @@ -68,9 +66,7 @@ { ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport( - "/learning/episode_status", EpisodeStatus - ), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), } ) @@ -85,9 +81,7 @@ { ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport( - "/learning/episode_status", EpisodeStatus - ), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), } ) @@ -102,9 +96,7 @@ { ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): 
LCMTransport( - "/learning/episode_status", EpisodeStatus - ), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), } ) diff --git a/dimos/learning/inference/action_replayer.py b/dimos/learning/inference/action_replayer.py index 1d7e07ca59..3903ea26f7 100644 --- a/dimos/learning/inference/action_replayer.py +++ b/dimos/learning/inference/action_replayer.py @@ -26,7 +26,6 @@ from __future__ import annotations -from collections import deque from dataclasses import dataclass from dimos.control.task import ( diff --git a/dimos/learning/inference/chunk_policy_module.py b/dimos/learning/inference/chunk_policy_module.py index 734f77ac19..4cccdfbf55 100644 --- a/dimos/learning/inference/chunk_policy_module.py +++ b/dimos/learning/inference/chunk_policy_module.py @@ -30,7 +30,6 @@ from __future__ import annotations import threading -from pathlib import Path from typing import Any from dimos.core.core import rpc diff --git a/dimos/learning/learning_spec.md b/dimos/learning/learning_spec.md new file mode 100644 index 0000000000..469a565c28 --- /dev/null +++ b/dimos/learning/learning_spec.md @@ -0,0 +1,164 @@ +### Data collection + +Two phases: +A. Recording - teleop/drive robot, streams recorded at `session.db` +B. DataPrep - convert `session.db` to `dataset/` + +Phase A - Recording +``` python + +spec = DatasetConfig.from_file("datasets/pick_cube.yaml") + +learning_collect_quest_xarm7 = autoconnect( + teleop_quest_xarm7, + RealSenseCamera.blueprint(enable_pointcloud=False), + EpisodeMonitorModule.blueprint(spec=spec), +) + +# hardware +# sim +``` + +`RecordReplay` is a transport-layer hook (`--record-path`); captures every +transport in the blueprint (including `episode_status`) into `session.db`. + +Open: feedback to operators — display episode number / status in the VR +headset. Unify episode-boundary declaration across VR / keyboard / active+passive +/ future inputs into one `episode_status` stream in the `.db` file. 
+ +``` python +from dimos.learning.config import DatasetConfig, EpisodeStatus + +class EpisodeMonitorModuleConfig(ModuleConfig): + spec: DatasetConfig + +class EpisodeMonitorModule(Module): + config: EpisodeMonitorModuleConfig + + buttons: In[Buttons] + keyboard: In[KeyPress] + status: Out[EpisodeStatus] + + @rpc + def reset_counters(self) -> EpisodeStatus: ... + @rpc + def get_status(self) -> EpisodeStatus: ... + + def _on_buttons(self, msg: Buttons) -> None: ... + def _on_keyboard(self, msg: KeyPress) -> None: ... +``` + +Config +```yaml +source: session.db + +episodes: + extractor: episode_status + default_task_label: pick_red_cube + button_map: {start: A, save: B, discard: X} + keyboard_map: {start: space, save: s, discard: d} + +observation: + cam: + stream: camera_color_image + field: image + joint_pos: + stream: coordinator_joint_state + field: position + +action: + joint_target: + stream: coordinator_joint_command + field: position + +sync: + anchor: cam + rate_hz: 30 + tolerance_ms: 50 + strategy: nearest + +output: + format: lerobot + path: datasets/pick_red/ + metadata: {fps: 30, robot: xarm7} +``` + +`learning/config.py` + +``` python +from dimos.protocol.service.spec import BaseConfig + + +class EpisodeStatus(BaseModel): # runtime message (not BaseConfig — built in code, not from YAML) + state: Literal["idle", "recording"] + episodes_saved: int + episodes_discarded: int + current_episode_start_ts: float | None + last_event: Literal["start", "save", "discard", "init"] = "init" + task_label: str | None = None + + +class EpisodeConfig(BaseConfig): + extractor: Literal["episode_status", "ranges", "whole_session"] + ranges: list[tuple[float, float]] | None = None + default_task_label: str | None = None + button_map: dict[Literal["start", "save", "discard"], str] = {"start": "A", "save": "B", "discard": "X"} + keyboard_map: dict[Literal["start", "save", "discard"], str] = {} + +class StreamField(BaseConfig): + stream: str + field: str | None = None + +class 
SyncConfig(BaseConfig): + anchor: str + rate_hz: float + tolerance_ms: float + strategy: Literal["nearest", "interp"] = "nearest" + +class OutputConfig(BaseConfig): + format: Literal["lerobot", "hdf5", "rlds"] = "lerobot" + path: Path + metadata: dict[str, Any] = {} + + +class DatasetConfig(BaseConfig): + source: str + episodes: EpisodeConfig + observation: dict[str, StreamField] + action: dict[str, StreamField] + sync: SyncConfig + output: OutputConfig + + @classmethod + def from_file(cls, path: str | Path) -> DatasetConfig: ... +``` + +Phase B - DataPrep + +``` +spec = DatasetConfig.from_file("datasets/pick_red.yaml") + +learning_dataprep = autoconnect( + DataPrepModule.blueprint(spec=spec), +).transports({}) +``` + +- DataPrepModule + +``` python +class DataPrepModuleConfig(ModuleConfig): + spec: DatasetConfig + output_dir: str | None = None + +class DataPrepModule(Module): + config: DataPrepModuleConfig + + @rpc + def build(self, output_dir: str | None = None) -> None: ... + @rpc + def cancel(self) -> bool: ... + @rpc + def get_status(self) -> dict[str, Any]: ... + @rpc + def inspect(self) -> dict[str, Any]: ... +``` diff --git a/dimos/learning/policy/lerobot_policy.py b/dimos/learning/policy/lerobot_policy.py index 3222116b54..6d9cbd3d8c 100644 --- a/dimos/learning/policy/lerobot_policy.py +++ b/dimos/learning/policy/lerobot_policy.py @@ -25,15 +25,12 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import Any import numpy as np from dimos.learning.policy.base import Policy -if TYPE_CHECKING: - pass # lerobot / torch deferred to load() - class LeRobotPolicy: """Adapter for lerobot's PreTrainedPolicy → DimOS Policy protocol.""" diff --git a/dimos/learning/specs/datacollection.md b/dimos/learning/specs/datacollection.md new file mode 100644 index 0000000000..e348f1454b --- /dev/null +++ b/dimos/learning/specs/datacollection.md @@ -0,0 +1,247 @@ +# Stage 1 — Data + +Two phases: + +1. 
**Recording** — live; operator drives the robot. `RecordReplay` writes streams to `session.db`. +2. **DataPrep** — offline; convert `session.db` → `dataset/`. + +One `dataset.yaml` drives both, parsed once into a `DatasetConfig`. + +--- + +## Phase A — Recording + +### Blueprint + +```python +# dimos/learning/collection/blueprint.py +spec = DatasetConfig.from_file("datasets/pick_red.yaml") + +learning_collect_quest_xarm7 = autoconnect( + teleop_quest_xarm7, + RealSenseCamera.blueprint(enable_pointcloud=False), + EpisodeMonitorModule.blueprint(spec=spec), +).transports({ + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), +}) +``` + +`RecordReplay` (`--record-path`) captures every transport above, including `episode_status`. + +--- + +### Dataset spec — YAML + +```yaml +# datasets/pick_red.yaml +source: session.db + +episodes: + extractor: episode_status + status_stream: episode_status + default_task_label: pick_red_cube + button_map: {start: A, save: B, discard: X} + keyboard_map: {start: space, save: s, discard: d} + +observation: + cam: + stream: camera_color_image + field: image + joint_pos: + stream: coordinator_joint_state + field: position + +action: + joint_target: + stream: coordinator_joint_command + field: position + +sync: + anchor: cam + rate_hz: 30 + tolerance_ms: 50 + strategy: nearest + +output: + format: lerobot + path: datasets/pick_red/ + metadata: {fps: 30, robot: xarm7} +``` + +--- + +### Dataset spec — pydantic classes + +```python +# dimos/learning/config.py +from dimos.protocol.service.spec import BaseConfig # extra="forbid" + + +class DatasetConfig(BaseConfig): + source: str + episodes: EpisodeConfig + observation: dict[str, StreamField] + action: dict[str, StreamField] + sync: SyncConfig + output: OutputConfig + + @classmethod + def from_file(cls, path: str | Path) -> 
DatasetConfig: ... + + +class EpisodeConfig(BaseConfig): + extractor: Literal["episode_status", "ranges", "whole_session"] = "episode_status" + status_stream: str = "episode_status" + ranges: list[tuple[float, float]] | None = None + default_task_label: str | None = None + button_map: dict[Literal["start", "save", "discard"], str] = {"start": "A", "save": "B", "discard": "X"} + keyboard_map: dict[Literal["start", "save", "discard"], str] = {} + + +class StreamField(BaseConfig): + stream: str + field: str | None = None + + +class SyncConfig(BaseConfig): + anchor: str + rate_hz: float + tolerance_ms: float + strategy: Literal["nearest", "interp"] = "nearest" + + +class OutputConfig(BaseConfig): + format: Literal["lerobot", "hdf5", "rlds"] = "lerobot" + path: Path + metadata: dict[str, Any] = {} +``` + +| Module | Reads | +|---|---| +| `EpisodeMonitorModule` | `spec.episodes.button_map` / `keyboard_map` / `default_task_label` | +| `DataPrepModule` | full spec | +| `ChunkPolicyModule` | `spec.observation`, `spec.sync` | + +--- + +### EpisodeMonitorModule + +Translates teleop input (buttons, keyboard, future inputs) into a canonical +`EpisodeStatus` stream. `DataPrep` reads only that stream — never raw inputs. + +```python +# dimos/learning/collection/episode_monitor.py + +class EpisodeStatus(BaseModel): + state: Literal["idle", "recording"] + episodes_saved: int + episodes_discarded: int + current_episode_start_ts: float | None + last_event: Literal["start", "save", "discard", "init"] = "init" + task_label: str | None = None + + +class EpisodeMonitorModuleConfig(ModuleConfig): + spec: DatasetConfig + + +class EpisodeMonitorModule(Module): + config: EpisodeMonitorModuleConfig + + buttons: In[Buttons] + keyboard: In[KeyPress] + status: Out[EpisodeStatus] + + @rpc + def reset_counters(self) -> EpisodeStatus: ... + @rpc + def get_status(self) -> EpisodeStatus: ... + + def _on_buttons(self, msg: Buttons) -> None: ... + def _on_keyboard(self, msg: KeyPress) -> None: ... 
+``` + +State machine: + +``` +IDLE --start--> RECORDING +RECORDING --save--> IDLE (commit) +RECORDING --discard--> IDLE (drop) +RECORDING --start--> RECORDING (auto-commit prev) +session end mid-episode: always discard +``` + +--- + +### Run + +```bash +dimos run learning-collect-quest-xarm7 \ + --spec-path datasets/pick_red.yaml \ + --record-path data/pick_red.db +``` + +--- + +## Phase B — DataPrep + +Reads `session.db`, slices on `episode_status`, syncs streams, writes +`dataset/`. Heavy deps run in a subprocess. + +### Blueprint + +```python +# dimos/learning/dataprep/blueprint.py +spec = DatasetConfig.from_file("datasets/pick_red.yaml") + +learning_dataprep = autoconnect( + DataPrepModule.blueprint(spec=spec, auto_run=True), +).transports({}) +``` + +### DataPrepModule + +```python +# dimos/learning/dataprep_module.py + +class DataPrepModuleConfig(ModuleConfig): + spec: DatasetConfig + output_dir: str | None = None + auto_run: bool = False + + +class DataPrepModule(Module): + config: DataPrepModuleConfig + + @rpc + def build(self, output_dir: str | None = None) -> None: ... + @rpc + def cancel(self) -> bool: ... + @rpc + def get_status(self) -> dict[str, Any]: ... + @rpc + def inspect(self) -> dict[str, Any]: ... +``` + +### Run + +```bash +dimos run learning-dataprep --spec-path datasets/pick_red.yaml +``` + +--- + +## End-to-end + +```bash +SPEC=datasets/pick_red.yaml + +dimos run learning-collect-quest-xarm7 --spec-path $SPEC --record-path data/pick_red.db +dimos run learning-dataprep --spec-path $SPEC +``` + +``` +session.db ─► dataset/ + meta/stats.json +``` diff --git a/dimos/learning/specs/inference.md b/dimos/learning/specs/inference.md new file mode 100644 index 0000000000..10e48b7701 --- /dev/null +++ b/dimos/learning/specs/inference.md @@ -0,0 +1,184 @@ +# Stage 3 — Inference + +- **`ChunkPolicyModule`** — Module @ 1–30 Hz. Builds obs via `ObsBuilder`, + calls `policy.predict_chunk(obs)`, emits `ActionChunk`. 
+- **`ActionReplayer`** — `BaseControlTask` in the 100 Hz `ControlCoordinator` + tick loop. Buffers chunks (latest-wins), interpolates to `state.now`, + emits `JointCommandOutput`. Holds last position on stall. + +--- + +## Blueprint + +```python +# dimos/learning/inference/blueprint.py +from dimos.learning.config import ActionChunk + +learning_infer_xarm7 = autoconnect( + RealSenseCamera.blueprint(enable_pointcloud=False), + ChunkPolicyModule.blueprint( + policy_path="runs/act_pick_red", + inference_rate_hz=30.0, + ), + coordinator_action_replayer_xarm7, +).transports({ + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("language_text", str): LCMTransport("/learning/language_text", str), + ("action_chunk", ActionChunk): LCMTransport("/learning/action_chunk", ActionChunk), +}) +``` + +## Message types + +`ActionChunk` lives in `dimos/learning/config.py` next to `EpisodeStatus` +and `DatasetConfig` — single import for all cross-stage contracts. +`Policy` is the backend Protocol; lives in `policy/base.py`. + +```python +# dimos/learning/config.py + +class ActionChunk(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + ts: float + joint_names: list[str] + positions: np.ndarray # (T, N) + dt: float + chunk_id: int +``` + +```python +# dimos/learning/policy/base.py + +@runtime_checkable +class Policy(Protocol): + @classmethod + def load(cls, path: str | Path, device: str = "cuda") -> Policy: ... + + @property + def chunk_size(self) -> int: ... + @property + def joint_names(self) -> list[str]: ... + @property + def expects_language(self) -> bool: ... + + def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: ... 
+``` + +## ChunkPolicyModule + +```python +# dimos/learning/inference/chunk_policy_module.py + +class ChunkPolicyModuleConfig(ModuleConfig): + policy_path: str # spec read from <policy_path>/dimos_meta.json + inference_rate_hz: float = 5.0 + device: str = "cuda" + default_language: str = "" + + +class ChunkPolicyModule(Module): + config: ChunkPolicyModuleConfig + + color_image: In[Image] + joint_state: In[JointState] + language_text: In[str] + action_chunk: Out[ActionChunk] + + @rpc + def set_language(self, text: str) -> None: ... + @rpc + def reload_policy(self, policy_path: str, device: str | None = None) -> None: ... + @rpc + def get_status(self) -> dict[str, Any]: ... + + # Lifecycle: start() loads policy + spawns the loop thread; stop() joins it. + def _run_loop(self) -> None: + period = 1.0 / self.config.inference_rate_hz + while not self._stop.is_set(): + t0 = time.monotonic() + obs = self._build_live_obs() + if obs is None: # waiting for first frames + time.sleep(period); continue + + positions = self.policy.predict_chunk(obs) # (T, action_dim) + self.action_chunk.publish(ActionChunk( + ts=time.time(), + joint_names=self.policy.joint_names, + positions=positions, + dt=period, + chunk_id=self._next_chunk_id(), + )) + time.sleep(max(0.0, period - (time.monotonic() - t0))) + + def _build_live_obs(self) -> dict[str, np.ndarray] | None: + # snapshot latched In[Image] / In[JointState] / In[str] under a lock, + # hand to ObsBuilder.build(...) → returns obs dict or None if not ready + ... +``` + +## ObsBuilder + +`ChunkPolicyModule.start()` reads the embedded spec from +`<policy_path>/dimos_meta.json` and constructs the `ObsBuilder` from it +— no `--spec-path` needed at inference. + +```python +# dimos/learning/inference/obs_builder.py + +class ObsBuilder: + def __init__(self, spec: DatasetConfig) -> None: ... + def build(self, live_messages: dict[str, Any]) -> dict[str, np.ndarray]: ... + def required_streams(self) -> set[str]: ...
+``` + +## ActionReplayer + +```python +# dimos/learning/inference/action_replayer.py + +@dataclass +class ActionReplayerConfig: + joint_names: list[str] + chunk_topic: str = "action_chunk" + priority: int = 10 + max_chunk_age_s: float = 0.5 + hold_on_stall: bool = True + temporal_ensemble: bool = False + + +class ActionReplayer(BaseControlTask): + def __init__(self, name: str, config: ActionReplayerConfig) -> None: ... + + @property + def name(self) -> str: ... + def claim(self) -> ResourceClaim: ... + def is_active(self) -> bool: ... + def compute(self, state: CoordinatorState) -> JointCommandOutput | None: ... + def on_action_chunk(self, msg: ActionChunk) -> None: ... +``` + +--- + +## Run + +```bash +dimos run learning-infer-xarm7 --policy-path runs/act_pick_red +``` + +--- + +## End-to-end + +```bash +SPEC=datasets/pick_red.yaml + +dimos run learning-collect-quest-xarm7 --spec-path $SPEC --record-path data/pick_red.db +dimos run learning-dataprep --spec-path $SPEC +dimos run learning-train --dataset-path dataset/ --output-dir runs/act_pick_red +dimos run learning-infer-xarm7 --policy-path runs/act_pick_red +``` + +``` +session.db ─► dataset/ ─► checkpoint/ ─► live policy +``` diff --git a/dimos/learning/specs/structure.md b/dimos/learning/specs/structure.md new file mode 100644 index 0000000000..0f8ce55da8 --- /dev/null +++ b/dimos/learning/specs/structure.md @@ -0,0 +1,141 @@ +# Folder Structure + +The four spec docs in this directory are the source of truth. The code +tree below is the implementation layout — each file maps to a section in +one of the three stage docs. 
+ +``` +dimos/learning/ +│ +├── specs/ # ← spec docs (you are here) +│ ├── structure.md # this file — folder layout +│ ├── datacollection.md # Stage 1 — recording + dataprep + inspect +│ ├── training.md # Stage 2 — TrainerModule +│ └── inference.md # Stage 3 — ChunkPolicyModule + ActionReplayer +│ +├── __init__.py +├── config.py # DatasetConfig + sub-configs (pydantic BaseConfig) +├── dataset.example.yaml # annotated example spec +│ +├── dataprep.py # DataPrep façade + resolve_field staticmethod +│ # `python -m dimos.learning.dataprep build|inspect` +├── dataprep_module.py # DataPrepModule (wraps the subprocess for blueprint UX) +│ +├── collection/ # ── Stage 1 / Phase A: live recording ── +│ ├── __init__.py +│ ├── episode_monitor.py # EpisodeStatus, EpisodeMonitorModule(Config) +│ └── blueprint.py # learning_collect_quest_{xarm7,xarm6,piper,dual} +│ +├── formats/ # ── dataset writers (DataPrep._get_writer dispatches) ── +│ ├── __init__.py +│ ├── lerobot.py # LeRobot v2 (parquet + MP4 + meta/stats.json) +│ ├── hdf5.py # flat HDF5 +│ └── rlds.py # RLDS / TFDS +│ +├── training/ # ── Stage 2: offline training ── +│ ├── __init__.py +│ ├── trainer_module.py # TrainProgress, TrainDone, TrainerModule(Config) +│ ├── train.py # subprocess CLI +│ # `python -m dimos.learning.training.train {bc|vla}` +│ ├── configs.py # bc / vla training configs +│ ├── split.py # train/val episode-level split +│ ├── stats.py # meta/stats.json computation (norm/unnorm) +│ └── blueprint.py # learning_train +│ +├── policy/ # ── policy backends (live + checkpoint loading) ── +│ ├── __init__.py +│ ├── base.py # ActionChunk pydantic + Policy Protocol +│ └── lerobot_policy.py # LeRobotPolicy.load → reads dimos_meta.json + stats.json +│ +└── inference/ # ── Stage 3: live policy serving ── + ├── __init__.py + ├── chunk_policy_module.py # ChunkPolicyModule(Config); slow Module @ 1–30 Hz + ├── obs_builder.py # ObsBuilder; calls DataPrep.resolve_field + ├── action_replayer.py # ActionReplayer 
(BaseControlTask, NOT a Module) + └── blueprint.py # learning_infer_{xarm7,xarm6,piper} + # + learning_infer_vla_{xarm7,...} +``` + +--- + +## Where each artifact is produced / consumed + +| Artifact | Producer | Consumer | +|---|---|---| +| `dataset.yaml` | human (operator) | `DataPrep`, `ObsBuilder` | +| `session.db` | `RecordReplay` (transport hook, `--record-path`) | `DataPrep` | +| `dataset/` + stats | `dataprep build` → `formats/<format>.py` | `lerobot.LeRobotDataset`, `train.py` | +| `checkpoint/` + meta | `train.py` | `LeRobotPolicy.load`, `ChunkPolicyModule` | +| `ActionChunk` (live) | `ChunkPolicyModule` (Module, LCM) | `ActionReplayer` (BaseControlTask) | +| `JointCommandOutput` | `ActionReplayer` (in 100 Hz tick loop) | `ControlCoordinator` → hardware | + +--- + +## `DatasetConfig` as the single source of truth + +`DatasetConfig` (loaded once from `dataset.yaml`) drives module configs +across stages — same instance, no drift between train and serve. + +```python +# Top-level, in each blueprint factory: +spec = DatasetConfig.from_file(spec_path) + +# Passed as a typed field on each module's config: +EpisodeMonitorModule.blueprint(spec=spec) # Stage 1: spec.episodes +DataPrepModule.blueprint(spec=spec) # Stage 1: full spec +ChunkPolicyModule.blueprint(spec=spec, ...) # Stage 3: spec.observation, spec.sync +``` + +| Stage | Module | How it gets the spec | +|---|---|---| +| 1A | `EpisodeMonitorModule` | passed in via blueprint (`spec=spec`); reads `spec.episodes` for button maps | +| 1B | `DataPrepModule` | passed in via blueprint; reads full spec. **DataPrep snapshots the spec into `dataset/dataset.yaml`** so downstream stages don't need the YAML. | +| 2 | `TrainerModule` | reads `dataset/dataset.yaml` + LeRobot `info.json`; copies spec snapshot into `checkpoint/dimos_meta.json` | +| 3 | `ChunkPolicyModule` | reads `<policy_path>/dimos_meta.json` at `start()`; constructs `ObsBuilder` from the embedded spec.
**No `--spec-path` flag needed at inference.** | + +The operator only ever passes `--spec-path` for Recording and DataPrep +(stages where the spec is the input). After DataPrep, the spec rides +with the data. + +Same `resolve_field` is invoked from `DataPrep.iter_episode_samples` +(Stage 1B) and `ObsBuilder.build` (Stage 3). One source of truth → +no train/serve skew. + +--- + +## What's deliberately not in this tree + +- **`RecordReplay`** — transport-layer hook (in `dimos/core/`), not a + `learning/` Module. Enabled by `--record-path` at the CLI; unaware of + what's recording. +- **`coordinator_action_replayer_<robot>`** — per-robot coordinator + blueprints that register the `ActionReplayer` task. These live next + to the rest of the per-robot wiring (likely + `dimos/robot/<robot>/blueprints.py`), not under `learning/`. +- **A second `ControlCoordinator`** — the existing one is reused. We add + one task type (`ActionReplayer`), not a parallel control stack. +- **New transports** — v1 is LCM-only on the wire. +- **New LCM message types** — `ActionChunk` is local-only pydantic in v1. + Promote to a generated LCM type in v2 only if cross-language consumers + need it. + +--- + +## Module / non-Module split (one rule) + +A class becomes a **Module** when it: +- has long-lived state worth `start()/stop()` lifecycle, **and** +- needs typed I/O ports across process boundaries.
+ +Otherwise it stays a plain class or a `BaseControlTask`: + +| Class | Type | Why | +|---|---|---| +| `EpisodeMonitorModule` | Module | Long-lived; subscribes to buttons; publishes status | +| `DataPrepModule` | Module | Wraps subprocess; agent-callable via `@skill` | +| `TrainerModule` | Module | Wraps subprocess; long-running; agent-callable | +| `ChunkPolicyModule` | Module | Long-lived inference thread; latched In ports | +| `DataPrep` | plain class | Stateless façade over static helpers; no ports | +| `ObsBuilder` | plain class | Pure function over latched messages | +| `ActionReplayer` | `BaseControlTask` | Must run in coordinator's 100 Hz thread, not via transport | +| `RecordReplay` | transport hook | Captures every stream uniformly; not a Module | diff --git a/dimos/learning/specs/training.md b/dimos/learning/specs/training.md new file mode 100644 index 0000000000..e732507f41 --- /dev/null +++ b/dimos/learning/specs/training.md @@ -0,0 +1,60 @@ +# Stage 2 — Training + +Offline. Reads `dataset/`, writes `checkpoint/` + `dimos_meta.json`. + +`TrainerModule` is an RPC façade over a training subprocess. Metrics → +TensorBoard. Lifecycle → `get_status()`. + +--- + +## Blueprint + +```python +# dimos/learning/training/blueprint.py +learning_train = autoconnect( + TrainerModule.blueprint(auto_run=True), +).transports({}) +``` + +## Module + +```python +class TrainerModuleConfig(ModuleConfig): + dataset_path: str = "" + output_dir: str = "" + config_kind: Literal["bc", "vla"] = "bc" + config_path: str | None = None + auto_run: bool = False + tensorboard_port: int = 6006 + + +class TrainerModule(Module): + config: TrainerModuleConfig + + @rpc + def train( + self, + dataset_path: str | None = None, + output_dir: str | None = None, + config_kind: Literal["bc", "vla"] | None = None, + config_overrides: dict[str, Any] | None = None, + ) -> None: ... + @rpc + def cancel(self) -> bool: ... + @rpc + def get_status(self) -> dict[str, Any]: ... 
+``` + +## Run + +```bash +dimos run learning-train \ + --dataset-path dataset/ \ + --output-dir runs/act_pick_red \ + --config-kind bc + +tensorboard --logdir runs/act_pick_red +``` + +Artifact: `checkpoint/` = `*.safetensors` + `dimos_meta.json` (spec snapshot, +`joint_names`, `chunk_size`, `policy_type`, `expects_language`). diff --git a/dimos/learning/training/monitor_module.py b/dimos/learning/training/monitor_module.py index c2f5d2cc42..e771933df6 100644 --- a/dimos/learning/training/monitor_module.py +++ b/dimos/learning/training/monitor_module.py @@ -78,9 +78,9 @@ def stop(self) -> None: def _on_train_progress(self, msg: TrainProgress) -> None: """Forward to enabled sinks. Routes by `msg.phase`: - - phase == "build": log dataset progress (episodes, samples) - - phase in {"train","eval"}: log loss curves (with EMA smoothing for rerun) - - other phases: log message line only. + - phase == "build": log dataset progress (episodes, samples) + - phase in {"train","eval"}: log loss curves (with EMA smoothing for rerun) + - other phases: log message line only. """ raise NotImplementedError diff --git a/dimos/learning/training/train.py b/dimos/learning/training/train.py index 2730bdd715..8d0545288a 100644 --- a/dimos/learning/training/train.py +++ b/dimos/learning/training/train.py @@ -119,8 +119,8 @@ def _write_dimos_meta(output_dir: Path, spec: DatasetSpec, dataset_path: Path) - def main() -> None: """CLI entrypoint: - python -m dimos.learning.training.train bc --output [...] - python -m dimos.learning.training.train vla --output --pretrained [...] + python -m dimos.learning.training.train bc --output [...] + python -m dimos.learning.training.train vla --output --pretrained [...] 
""" raise NotImplementedError diff --git a/dimos/learning/training/trainer_module.py b/dimos/learning/training/trainer_module.py index 490b1b73ee..0d49509097 100644 --- a/dimos/learning/training/trainer_module.py +++ b/dimos/learning/training/trainer_module.py @@ -35,9 +35,9 @@ from __future__ import annotations +from pathlib import Path import subprocess import threading -from pathlib import Path from typing import Any, Literal from pydantic import BaseModel @@ -46,7 +46,6 @@ from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out - # ───────────────────────────────────────────────────────────────────────────── # Message types # ───────────────────────────────────────────────────────────────────────────── From 779d100f62574bcabd029e0d19befd649c6bb958 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Fri, 1 May 2026 15:47:22 -0700 Subject: [PATCH 04/45] learning pipeline spec --- dimos/control/tasks/action_replayer_task.py | 73 +++++ dimos/learning/collection/blueprint.py | 65 +--- dimos/learning/collection/episode_monitor.py | 97 +++--- dimos/learning/dataprep.py | 289 ++++++------------ dimos/learning/dataprep_module.py | 92 ++++++ dimos/learning/dataset.example.yaml | 88 ------ dimos/learning/formats/hdf5.py | 15 +- dimos/learning/formats/lerobot.py | 16 +- dimos/learning/formats/rlds.py | 15 +- dimos/learning/inference/action_replayer.py | 133 -------- dimos/learning/inference/blueprint.py | 102 +------ .../learning/inference/chunk_policy_module.py | 120 +++----- dimos/learning/inference/obs_builder.py | 72 ----- dimos/learning/learning_spec.md | 164 ---------- dimos/learning/policy/base.py | 53 +--- dimos/learning/policy/lerobot_policy.py | 50 +-- dimos/learning/spec.py | 179 ----------- dimos/learning/specs/datacollection.md | 205 +++++-------- dimos/learning/specs/inference.md | 131 ++++---- dimos/learning/specs/structure.md | 193 ++++++------ dimos/learning/specs/training.md | 52 ++-- dimos/learning/training/blueprint.py | 
92 +----- dimos/learning/training/configs.py | 52 +--- dimos/learning/training/monitor_module.py | 89 ------ dimos/learning/training/split.py | 41 --- dimos/learning/training/stats.py | 108 ------- dimos/learning/training/train.py | 127 +++----- dimos/learning/training/trainer_module.py | 227 ++------------ 28 files changed, 751 insertions(+), 2189 deletions(-) create mode 100644 dimos/control/tasks/action_replayer_task.py create mode 100644 dimos/learning/dataprep_module.py delete mode 100644 dimos/learning/dataset.example.yaml delete mode 100644 dimos/learning/inference/action_replayer.py delete mode 100644 dimos/learning/inference/obs_builder.py delete mode 100644 dimos/learning/learning_spec.md delete mode 100644 dimos/learning/spec.py delete mode 100644 dimos/learning/training/monitor_module.py delete mode 100644 dimos/learning/training/split.py delete mode 100644 dimos/learning/training/stats.py diff --git a/dimos/control/tasks/action_replayer_task.py b/dimos/control/tasks/action_replayer_task.py new file mode 100644 index 0000000000..a56c5fb267 --- /dev/null +++ b/dimos/control/tasks/action_replayer_task.py @@ -0,0 +1,73 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Replay policy ActionChunks at coordinator tick rate (100 Hz). + +Slow producer (ChunkPolicyModule @ ~30 Hz, jittery) → buffer → +deterministic 100 Hz JointCommandOutput. 
Stale-chunk and policy-stall +handling keep hardware safe when the policy falls behind or dies. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from dimos.control.task import ( + BaseControlTask, + CoordinatorState, + JointCommandOutput, + ResourceClaim, +) +from dimos.learning.policy.base import ActionChunk + + +@dataclass +class ActionReplayerConfig: + joint_names: list[str] + priority: int = 10 + max_chunk_age_s: float = 0.5 # drop chunks older than this at receive time + hold_on_stall: bool = True # hold last position if buffer drains + temporal_ensemble: bool = False # ACT trick; off in v1 + + +class ActionReplayer(BaseControlTask): + """Buffer latest chunk; interpolate per tick; emit JointCommandOutput.""" + + def __init__(self, name: str, config: ActionReplayerConfig) -> None: + raise NotImplementedError + + @property + def name(self) -> str: + raise NotImplementedError + + def claim(self) -> ResourceClaim: + """Claim `config.joint_names` at `config.priority`.""" + raise NotImplementedError + + def is_active(self) -> bool: + """True iff buffer has a non-stale target for now (or hold_on_stall).""" + raise NotImplementedError + + def compute(self, state: CoordinatorState) -> JointCommandOutput | None: + """Pure lookup / interpolate over the buffer at `state.now`. + Must complete in << 10 ms.""" + raise NotImplementedError + + def on_action_chunk(self, msg: ActionChunk) -> None: + """Latest-wins push. 
Drop if msg too old; drop buffered entries + at/after msg's first target_ts; append new (target_ts, positions).""" + raise NotImplementedError + + def _interpolate(self, t: float) -> JointCommandOutput | None: + raise NotImplementedError diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index 40cb39b032..0ee92a298a 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -12,16 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Collection blueprints for the DimOS Learning Framework. - -Each blueprint composes a teleop session + a camera + the -EpisodeMonitorModule (for live operator feedback). RecordReplay is NOT a -Module — it intercepts at the transport layer and is enabled via the CLI -flag `--record-path session.db`. - -Usage: - dimos run learning-collect-quest-xarm7 --record-path data/pick_red.db -""" +"""Recording blueprints. RecordReplay is enabled via `--record-path`.""" from __future__ import annotations @@ -41,64 +32,40 @@ ) from dimos.teleop.quest.quest_types import Buttons -# ── XArm7 + Quest ──────────────────────────────────────────────────────────── +_DEFAULT_BUTTON_MAP = {"start": "A", "save": "B", "discard": "X"} +_TRANSPORTS = { + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), +} + learning_collect_quest_xarm7 = autoconnect( teleop_quest_xarm7, RealSenseCamera.blueprint(enable_pointcloud=False), - EpisodeMonitorModule.blueprint(), -).transports( - { - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), - } -) - + 
EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), +).transports(_TRANSPORTS) -# ── Piper + Quest ──────────────────────────────────────────────────────────── learning_collect_quest_piper = autoconnect( teleop_quest_piper, RealSenseCamera.blueprint(enable_pointcloud=False), - EpisodeMonitorModule.blueprint(), -).transports( - { - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), - } -) - + EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), +).transports(_TRANSPORTS) -# ── XArm6 + Quest ──────────────────────────────────────────────────────────── learning_collect_quest_xarm6 = autoconnect( teleop_quest_xarm6, RealSenseCamera.blueprint(enable_pointcloud=False), - EpisodeMonitorModule.blueprint(), -).transports( - { - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), - } -) + EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), +).transports(_TRANSPORTS) -# ── Dual arm (XArm6 + Piper) + Quest ───────────────────────────────────────── - learning_collect_quest_dual = autoconnect( teleop_quest_dual, RealSenseCamera.blueprint(enable_pointcloud=False), - EpisodeMonitorModule.blueprint(), -).transports( - { - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), - } -) + EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), +).transports(_TRANSPORTS) __all__ = [ diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index daa79f0251..437c844336 100644 --- 
a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -12,20 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Live episode-status feedback during teleop recording. +"""Single point of teleop-input → EpisodeStatus translation. -Watches the buttons stream and runs the same start/save/discard state -machine that `DataPrep.extract_episodes` runs offline — but here it runs -live so the operator can see counters update in real time. Pure observability: -this module does NOT write anything. The recording itself is RecordReplay's -job; episode boundary extraction still happens post-hoc inside DataPrep. - -Why a separate live state-machine instead of just consuming DataPrep's offline -output? Because the operator wants feedback *during* the session ("episodes -saved: 12") to know when to stop, retry a bad demo, etc. - -Agent surface: `get_status()` returns the latest counters; `reset_counters()` -zeroes them between recording sessions without restarting the blueprint. +Watches buttons / keyboard, runs the start/save/discard state machine, +publishes EpisodeStatus on every transition. RecordReplay captures that +stream into session.db; DataPrep reads only the recorded EpisodeStatus +events offline — never raw buttons or keypresses. """ from __future__ import annotations @@ -39,36 +31,53 @@ from dimos.core.stream import In, Out from dimos.teleop.quest.quest_types import Buttons +# Friendly names → Quest Buttons attribute names. Override by supplying an +# attribute name directly in `button_map`. 
+BUTTON_ALIASES: dict[str, str] = { + "A": "right_primary", + "B": "right_secondary", + "X": "left_primary", + "Y": "left_secondary", + "LT": "left_trigger", + "RT": "right_trigger", + "LG": "left_grip", + "RG": "right_grip", + "MENU_L": "left_menu", + "MENU_R": "right_menu", +} -class EpisodeStatus(BaseModel): - """Live counters published every state transition.""" +class EpisodeStatus(BaseModel): state: Literal["idle", "recording"] episodes_saved: int episodes_discarded: int - current_episode_start_ts: float | None # None when state == "idle" + current_episode_start_ts: float | None last_event: Literal["start", "save", "discard", "init"] = "init" + task_label: str | None = None -class EpisodeMonitorModuleConfig(ModuleConfig): - """Match the same fields used by `EpisodeConfig` in the dataset spec - so the live monitor and the offline extractor agree on what each button - means. Friendly names ("A", "B", "X") resolve via BUTTON_ALIASES. - """ +class KeyPress(BaseModel): + """Single keypress event from a keyboard input source.""" - button_stream: str = "buttons" - start: str = "A" - save: str = "B" - discard: str = "X" + key: str + ts: float -class EpisodeMonitorModule(Module): - """Live operator feedback for teleop recording sessions.""" +class EpisodeMonitorModuleConfig(ModuleConfig): + button_map: dict[Literal["start", "save", "discard"], str] = { + "start": "A", + "save": "B", + "discard": "X", + } + keyboard_map: dict[Literal["start", "save", "discard"], str] = {} + default_task_label: str | None = None + +class EpisodeMonitorModule(Module): config: EpisodeMonitorModuleConfig buttons: In[Buttons] - + keyboard: In[KeyPress] status: Out[EpisodeStatus] def __init__(self, **kwargs: Any) -> None: @@ -77,40 +86,38 @@ def __init__(self, **kwargs: Any) -> None: self._saved: int = 0 self._discarded: int = 0 self._current_start_ts: float | None = None - # Previous bit-state of each watched button, for rising-edge detection. 
- self._prev_bits: dict[str, bool] = {} + self._prev_bits: dict[str, bool] = {} # rising-edge detection for buttons @rpc def start(self) -> None: - """Subscribe to `buttons` and emit an initial idle status.""" raise NotImplementedError @rpc def stop(self) -> None: - """Unsubscribe and call super().stop().""" raise NotImplementedError @rpc def reset_counters(self) -> EpisodeStatus: - """Zero the saved/discarded counters and force state back to idle. - Returns the new status.""" raise NotImplementedError @rpc def get_status(self) -> EpisodeStatus: - """Return the current EpisodeStatus snapshot.""" raise NotImplementedError - # ── internals ──────────────────────────────────────────────────────────── - def _on_buttons(self, msg: Buttons) -> None: - """Detect rising edges on start/save/discard buttons; advance state - machine; publish EpisodeStatus on every transition. - - State machine — must mirror DataPrep.extract_episodes in BUTTONS mode: - IDLE --start press--> RECORDING (begin) - RECORDING --save press---> IDLE (saved += 1) - RECORDING --discard -----> IDLE (discarded += 1) - RECORDING --start press--> RECORDING (auto-commit prev, begin new) + """Rising-edge detect against `config.button_map`; advance state machine.""" + raise NotImplementedError + + def _on_keyboard(self, msg: KeyPress) -> None: + """Match `msg.key` against `config.keyboard_map`; advance state machine.""" + raise NotImplementedError + + def _transition(self, event: Literal["start", "save", "discard"], ts: float) -> None: + """Apply the state-machine transition and publish EpisodeStatus. 
+ + IDLE --start--> RECORDING + RECORDING --save--> IDLE (commit, saved += 1) + RECORDING --discard--> IDLE (drop, discarded += 1) + RECORDING --start--> RECORDING (auto-commit prev, begin new) """ raise NotImplementedError diff --git a/dimos/learning/dataprep.py b/dimos/learning/dataprep.py index c531302f67..9ba4dbb0bf 100644 --- a/dimos/learning/dataprep.py +++ b/dimos/learning/dataprep.py @@ -12,235 +12,138 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Dataset builder/loader for the DimOS Learning Framework. +"""Dataset-shape types + pure helpers. -`DataPrep` is the single user-facing entry point. It reads a `DatasetSpec` -(see `dimos.learning.spec`) and either: - - builds a training-ready dataset on disk in HDF5/RLDS/LeRobot, or - - returns a PyTorch Dataset for training. +Sub-configs (StreamField, SyncConfig, OutputConfig, EpisodeExtractor) and +data records (Episode, Sample) live here. So do the stateless functions +that walk samples — `resolve_field`, `compute_stats`, `extract_episodes`, +`iter_episode_samples`. Importable without booting a Module. -Stateless helpers (episode extraction, sample iteration, field resolution) -live as `@staticmethod`s on `DataPrep` so they share one namespace and are -callable without an instance — the live `ObsBuilder` at inference time -reuses `DataPrep.resolve_field` for that reason. - -Workflow: - # 1. Record a teleop session (Sam's PR #1708) - dimos --blueprint quest_teleop_xarm7 --record-path session.db - - # 2. Build a training-ready dataset - python -m dimos.learning.dataprep build dataset.yaml - - # 3. Train using the same spec - from dimos.learning.dataprep import DataPrep - dp = DataPrep.from_file("dataset.yaml") - ds = dp.load() +`DataPrepModule` (in `dataprep_module.py`) is a thin wrapper that runs +these helpers on a thread. 
""" from __future__ import annotations from collections.abc import Callable, Iterator from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import numpy as np +from pydantic import BaseModel, ConfigDict, Field -from dimos.learning.spec import ( - DatasetSpec, - Episode, - EpisodeConfig, - FilterConfig, - OutputConfig, - Sample, - StreamField, -) +from dimos.protocol.service.spec import BaseConfig -Writer = Callable[[Iterator[Sample], OutputConfig], Path] +Writer = Callable[[Iterator["Sample"], "OutputConfig"], Path] if TYPE_CHECKING: - import torch - from dimos.memory2.store.sqlite import SqliteStore # ───────────────────────────────────────────────────────────────────────────── -# DataPrep — the only thing this module exports besides `main()` +# Sub-configs # ───────────────────────────────────────────────────────────────────────────── -class DataPrep: - """Build / load / inspect a dataset from a `DatasetSpec`. +class EpisodeExtractor(BaseConfig): + extractor: Literal["episode_status", "ranges", "whole_session"] = "episode_status" + status_stream: str = "episode_status" + ranges: list[tuple[float, float]] | None = None - Holds the open `SqliteStore` and the cached, filtered episode list so - repeated operations on the same spec (e.g. `inspect()` then `build()`) - don't redo work. Construction is cheap — the store and episodes are - computed lazily on first access. - Not a DimOS Module: no ports, no runtime lifecycle. It's a stateful - façade over the static helpers below. - """ +class StreamField(BaseConfig): + stream: str + field: str | None = None - # ── construction ───────────────────────────────────────────────────────── - def __init__(self, spec: DatasetSpec) -> None: - """Bind to a spec. 
Does not open the store or extract episodes yet.""" - raise NotImplementedError +class SyncConfig(BaseConfig): + anchor: str + rate_hz: float + tolerance_ms: float + strategy: Literal["nearest", "interp"] = "nearest" - @classmethod - def from_file(cls, path: str | Path) -> DataPrep: - """Convenience: `DataPrep.from_file("dataset.yaml")`.""" - raise NotImplementedError - # ── lazy-cached state ──────────────────────────────────────────────────── +class OutputConfig(BaseConfig): + format: Literal["lerobot", "hdf5", "rlds"] = "lerobot" + path: Path + metadata: dict[str, Any] = {} - @property - def store(self) -> SqliteStore: - """Open the recording's SqliteStore on first access; cached thereafter.""" - raise NotImplementedError + +# ───────────────────────────────────────────────────────────────────────────── +# Data records +# ───────────────────────────────────────────────────────────────────────────── + + +class Episode(BaseModel): + id: str + start_ts: float + end_ts: float + task_label: str | None = None + success: bool = True + metadata: dict[str, Any] = Field(default_factory=dict) @property - def episodes(self) -> list[Episode]: - """Extract + filter episodes on first access; cached thereafter. - - Equivalent to: - DataPrep.filter_episodes( - DataPrep.extract_episodes(store, spec.episodes), - spec.filters, - ) - """ - raise NotImplementedError - - # ── operations ─────────────────────────────────────────────────────────── - - def iter_samples(self) -> Iterator[Sample]: - """Yield synced Samples across every episode, in episode order.""" - raise NotImplementedError - - def build(self) -> Path: - """End-to-end: source session.db -> on-disk dataset in spec.output.format. - - Returns the path written. Requires `spec.output` to be set. Dispatches - to the appropriate writer in `dimos.learning.formats` via `_get_writer`. 
- """ - raise NotImplementedError - - def load(self) -> torch.utils.data.Dataset[Sample]: - """Training-time loader: returns a PyTorch Dataset over the source recording. - - Materializes Samples on-the-fly (lazy). Does not require `spec.output`. - Pre-extracts episodes once and indexes anchor timestamps for O(1) - `__getitem__`. - """ - raise NotImplementedError - - def inspect(self) -> dict[str, Any]: - """Stats for a session: episode count, duration distribution, - per-stream sample counts. Used by `python -m dimos.learning.dataprep inspect`. - """ - raise NotImplementedError - - def close(self) -> None: - """Close the underlying SqliteStore. Safe to call multiple times.""" - raise NotImplementedError - - def __enter__(self) -> DataPrep: - return self - - def __exit__(self, *exc: object) -> None: - self.close() - - # ── stateless helpers ──────────────────────────────────────────────────── - # - # Static so they're callable without an instance. `resolve_field` in - # particular is reused by `dimos.learning.inference.obs_builder` to build - # live observations, so train and infer share exactly one code path for - # field projection + preprocess. - - @staticmethod - def extract_episodes(store: SqliteStore, cfg: EpisodeConfig) -> list[Episode]: - """Extract episode boundaries per the configured strategy. - - BUTTONS: scan cfg.button_stream for rising edges on cfg.start/save/discard. - State machine: - IDLE --start press--> RECORDING (begin episode) - RECORDING --save press--> IDLE (commit, success=True) - RECORDING --discard press--> IDLE (drop) - RECORDING --start press--> RECORDING (auto-commit, begin new) - session ends mid-episode: always discard - - RANGES: emit one Episode per (start_ts, end_ts) tuple in cfg.ranges. - - WHOLE: emit a single Episode covering the entire recording's time range. 
- """ - raise NotImplementedError - - @staticmethod - def filter_episodes(eps: list[Episode], cfg: FilterConfig | None) -> list[Episode]: - """Apply success / duration / label whitelist filters. `None` = pass-through. - - Note: train/val split fields on FilterConfig (`val_episode_ids`, - `val_ratio`) are *not* applied here — they're consumed by the trainer, - which needs the full episode list to materialize both splits. - """ - raise NotImplementedError - - @staticmethod - def iter_episode_samples( - store: SqliteStore, - episode: Episode, - spec: DatasetSpec, - ) -> Iterator[Sample]: - """Yield synced (obs, action) Samples for a single episode. - - Walks the anchor stream at sync.rate_hz between episode.start_ts and - episode.end_ts. For each anchor timestamp, pulls the nearest observation/ - action from each configured stream within sync.tolerance_ms. Applies any - declared preprocess (e.g. jpeg_decode for Image, field projection for - JointState). Skips frames where any required stream lacks a sample - within tolerance. - """ - raise NotImplementedError - - @staticmethod - def resolve_field(msg: Any, ref: StreamField) -> np.ndarray: - """Pull a single field from a stream message and convert to np.ndarray. - - Applies ref.field projection (attribute access) and ref.preprocess hook - (named transform like jpeg_decode). Returns a numpy array suitable for - inclusion in a Sample. - - Reused by the live ObsBuilder at inference time — single source of - truth for observation construction across train and infer. - """ - raise NotImplementedError - - @staticmethod - def _get_writer(format_name: str) -> Writer: - """Lazy-import the `write` function for a given format. Avoids loading - heavy deps (h5py, tfds, lerobot) for unused formats. 
- """ - if format_name == "lerobot": - from dimos.learning.formats.lerobot import write - elif format_name == "hdf5": - from dimos.learning.formats.hdf5 import write - elif format_name == "rlds": - from dimos.learning.formats.rlds import write - else: - raise ValueError( - f"Unknown dataset format: {format_name!r}. Supported: lerobot, hdf5, rlds." - ) - return write + def duration(self) -> float: + return self.end_ts - self.start_ts + + +class Sample(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + ts: float + episode_id: str + observation: dict[str, np.ndarray] + action: dict[str, np.ndarray] # ───────────────────────────────────────────────────────────────────────────── -# CLI +# Pure helpers — used by ChunkPolicyModule, format writers, DataPrepModule # ───────────────────────────────────────────────────────────────────────────── -def main() -> None: - """CLI entrypoint: `build` / `inspect` / `review` a dataset spec.""" +def resolve_field(msg: Any, ref: StreamField) -> np.ndarray: + """Project `msg` through `ref` (attribute access). Single source of + truth for obs/action construction across train and live.""" + raise NotImplementedError + + +def extract_episodes(store: SqliteStore, cfg: EpisodeExtractor) -> list[Episode]: + """Walk recorded EpisodeStatus events (or ranges/whole_session) into Episodes.""" + raise NotImplementedError + + +def iter_episode_samples( + store: SqliteStore, + episode: Episode, + streams: dict[str, StreamField], # observation ∪ action + sync: SyncConfig, +) -> Iterator[Sample]: + """Yield synced (obs, action) Samples for one episode.""" + raise NotImplementedError + + +def compute_stats( + samples: Iterator[Sample], + image_subsample: int = 10, + quantile_reservoir: int = 10_000, + seed: int = 0, +) -> dict[str, Any]: + """Per-feature mean/std/min/max/q01/q99 in one pass. + + Welford for mean/std; reservoir sample for quantiles. Image features + subsampled (every Nth frame) to bound cost. 
+ """ raise NotImplementedError -if __name__ == "__main__": - main() +def get_writer(format_name: str) -> Writer: + """Lazy-import the format writer's `write` function.""" + if format_name == "lerobot": + from dimos.learning.formats.lerobot import write + elif format_name == "hdf5": + from dimos.learning.formats.hdf5 import write + elif format_name == "rlds": + from dimos.learning.formats.rlds import write + else: + raise ValueError(f"Unknown format: {format_name!r}") + return write diff --git a/dimos/learning/dataprep_module.py b/dimos/learning/dataprep_module.py new file mode 100644 index 0000000000..e08b660907 --- /dev/null +++ b/dimos/learning/dataprep_module.py @@ -0,0 +1,92 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DataPrepModule — wraps the dataprep pipeline as a Module with RPC surface. + +All dataset-shape types and pure helpers live in `dataprep.py`. This file +just adds the Module lifecycle + thread + status tracking. 
+""" + +from __future__ import annotations + +import threading +from typing import Any + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.learning.dataprep import ( + EpisodeExtractor, + OutputConfig, + StreamField, + SyncConfig, +) + + +class DataPrepModuleConfig(ModuleConfig): + source: str + episodes: EpisodeExtractor + observation: dict[str, StreamField] + action: dict[str, StreamField] + sync: SyncConfig + output: OutputConfig + auto_run: bool = False + + +class DataPrepModule(Module): + """Wraps a long-running dataset build job.""" + + config: DataPrepModuleConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._thread: threading.Thread | None = None + self._lock = threading.Lock() + self._status: dict[str, Any] = { + "state": "idle", # idle | running | succeeded | failed + "current_phase": None, # scan_episodes | write | done + "progress_pct": 0.0, + "dataset_path": None, + "error": None, + } + + @rpc + def start(self) -> None: + raise NotImplementedError + + @rpc + def stop(self) -> None: + raise NotImplementedError + + @rpc + def build(self) -> None: + """Spawn a daemon thread running the build pipeline. Returns immediately.""" + raise NotImplementedError + + @rpc + def get_status(self) -> dict[str, Any]: + raise NotImplementedError + + @rpc + def inspect(self) -> dict[str, Any]: + """Read-only summary: episode count, drop rates, joint names, stats presence.""" + raise NotImplementedError + + def _run_build(self) -> None: + """Thread target. Opens session.db, calls extract_episodes / + iter_episode_samples / format writer, snapshots config to + /dimos_meta.json. 
Updates _status under _lock.""" + raise NotImplementedError + + +__all__ = ["DataPrepModule", "DataPrepModuleConfig"] diff --git a/dimos/learning/dataset.example.yaml b/dimos/learning/dataset.example.yaml deleted file mode 100644 index 83a774fd68..0000000000 --- a/dimos/learning/dataset.example.yaml +++ /dev/null @@ -1,88 +0,0 @@ -# DimOS Learning — DatasetSpec template -# -# This file is the contract between data collection and training. Same spec is -# used by `python -m dimos.learning.dataprep build` (export to disk) and by -# `load_dataset(spec)` (training-time PyTorch Dataset). -# -# Stream names below must match the topic names recorded by RecordReplay. To -# discover them: `python -m dimos.learning.dataprep inspect dataset.yaml` -# (or look in the SQLite registry of session.db). - -# ─── Source recording ──────────────────────────────────────────────────────── -source: ./session.db - -# ─── How to slice the session into episodes ────────────────────────────────── -episodes: - extractor: buttons # buttons | ranges | whole_session - - # BUTTONS extractor: state machine over the recorded button stream - button_stream: buttons # stream name (matches LCM topic, sanitized) - start: A # rising edge -> begin episode - save: B # rising edge -> end + save - discard: X # rising edge -> end + drop - # Note: if recording stops mid-episode without an explicit save/discard, - # the in-progress episode is always discarded. - - # RANGES extractor: explicit absolute timestamps (only used when extractor: ranges) - # ranges: - # - [1730000000.0, 1730000045.5] - # - [1730000060.0, 1730000110.2] - - default_task_label: pick_red_cube # optional; short categorical tag - # Free-form natural-language task string used as language conditioning for VLAs (π₀, π₀.₅). 
- default_task_description: "pick up the red cube and place it on the blue plate" - -# ─── What goes into each timestep ──────────────────────────────────────────── -# Each entry: dataset_key -> { stream, type?, field?, preprocess? } -# stream — recorded stream name (LCM topic, sanitized) -# type — optional dotted message type (for codec dispatch) -# field — attribute on the message; omit to keep the whole message -# preprocess — named transform applied after field projection - -observation: - image: - stream: camera_color_image - type: sensor_msgs.Image - preprocess: jpeg_decode # raw JPEG bytes -> HxWx3 uint8 - joint_pos: - stream: coordinator_joint_state - type: sensor_msgs.JointState - field: position - joint_vel: - stream: coordinator_joint_state - type: sensor_msgs.JointState - field: velocity - -action: - target_pos: - stream: coordinator_joint_command - type: sensor_msgs.JointState - field: position - -# ─── Synchronization (build per-timestep samples) ──────────────────────────── -sync: - anchor: image # which observation key drives the timeline - rate_hz: 30 # downsample anchor to this rate; 0 = native - tolerance_ms: 50 # max time delta when picking nearest sample - strategy: nearest # nearest | interp - -# ─── Per-episode filters (optional) ────────────────────────────────────────── -filters: - success_only: true - min_duration_s: 1.0 - # max_duration_s: 60.0 - # task_labels: [pick_red_cube, pick_blue_cube] - - # Train/val split — both optional. val_episode_ids takes precedence over val_ratio. - # val_episode_ids: [0, 7, 13] - # val_ratio: 0.1 - # val_split_seed: 0 - -# ─── Output (only required when calling build_dataset / `... 
build`) ───────── -output: - format: lerobot # lerobot | hdf5 | rlds - path: ./datasets/pick_red/ - metadata: - fps: 30 - robot: xarm7 - task_label: pick_red_cube diff --git a/dimos/learning/formats/hdf5.py b/dimos/learning/formats/hdf5.py index b40324fd02..57ddb096a0 100644 --- a/dimos/learning/formats/hdf5.py +++ b/dimos/learning/formats/hdf5.py @@ -12,24 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""HDF5 dataset writer. - -Produces a single .hdf5 file with one group per episode: - /data/episode_000000/ - observation/ # one dataset per observation key (T, ...) - action/ # one dataset per action key (T, ...) - ts # timestamps (T,) - /metadata # JSON-encoded spec + per-episode tags -""" +"""HDF5 dataset writer. Single .hdf5 with one group per episode + stats group.""" from __future__ import annotations from collections.abc import Iterator from pathlib import Path -from dimos.learning.spec import OutputConfig, Sample +from dimos.learning.dataprep import OutputConfig, Sample def write(samples: Iterator[Sample], output: OutputConfig) -> Path: - """Write samples to a single HDF5 file. Returns the file path.""" + """Write samples to a single HDF5 file (stats as group attrs). + Returns the file path.""" raise NotImplementedError diff --git a/dimos/learning/formats/lerobot.py b/dimos/learning/formats/lerobot.py index bb82161e3c..e5e348af44 100644 --- a/dimos/learning/formats/lerobot.py +++ b/dimos/learning/formats/lerobot.py @@ -14,12 +14,13 @@ """LeRobot v2 dataset writer. 
-Produces a directory layout compatible with HuggingFace LeRobot: +Layout: / - meta/info.json # schema, fps, total episodes/frames - meta/episodes.jsonl # per-episode metadata (length, task) - data/chunk-000/episode_000000.parquet # tabular obs+action - videos/chunk-000//episode_000000.mp4 # encoded image streams + meta/info.json schema, fps, total episodes/frames + meta/episodes.jsonl per-episode metadata + meta/stats.json per-feature stats (from DataPrep.compute_stats) + data/chunk-000/episode_*.parquet + videos/chunk-000//episode_*.mp4 """ from __future__ import annotations @@ -27,9 +28,10 @@ from collections.abc import Iterator from pathlib import Path -from dimos.learning.spec import OutputConfig, Sample +from dimos.learning.dataprep import OutputConfig, Sample def write(samples: Iterator[Sample], output: OutputConfig) -> Path: - """Write samples in LeRobot v2 layout. Returns the dataset root path.""" + """Drain samples, write parquet+MP4, call DataPrep.compute_stats, + serialize stats to meta/stats.json. Return the dataset root path.""" raise NotImplementedError diff --git a/dimos/learning/formats/rlds.py b/dimos/learning/formats/rlds.py index 3fbf2d40db..23126279dc 100644 --- a/dimos/learning/formats/rlds.py +++ b/dimos/learning/formats/rlds.py @@ -12,23 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""RLDS / TFDS dataset writer. - -Produces a TFDS-on-disk layout (TFRecord shards + dataset_info.json) following -the RLDS Episode/Step protocol used by Open X-Embodiment, RT-X, etc. - -Each TF Example encodes one Episode as a sequence of Steps with: - observation/, action/, reward, discount, is_first, is_last, is_terminal -""" +"""RLDS / TFDS dataset writer. 
TFRecord shards + dataset_info.json +following the RLDS Episode/Step protocol.""" from __future__ import annotations from collections.abc import Iterator from pathlib import Path -from dimos.learning.spec import OutputConfig, Sample +from dimos.learning.dataprep import OutputConfig, Sample def write(samples: Iterator[Sample], output: OutputConfig) -> Path: - """Write samples as TFDS/RLDS shards. Returns the dataset directory path.""" + """Write samples as TFDS/RLDS shards (stats in features metadata). + Returns the dataset directory path.""" raise NotImplementedError diff --git a/dimos/learning/inference/action_replayer.py b/dimos/learning/inference/action_replayer.py deleted file mode 100644 index 3903ea26f7..0000000000 --- a/dimos/learning/inference/action_replayer.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Replay policy-emitted ActionChunks at the coordinator's tick rate. - -`ChunkPolicyModule` runs slow (1–30 Hz) and emits sequences of future actions -(ActionChunks). The coordinator runs at 100 Hz. This task bridges them: -subscribe to the chunk topic, maintain a small buffer of pending (target_ts, -positions) entries, and on each tick interpolate to the current time. - -Lives in the tick loop because hardware writes happen there. Designed so a -slow / stalled policy doesn't crash the controller — see "fault behavior" -below. 
-""" - -from __future__ import annotations - -from dataclasses import dataclass - -from dimos.control.task import ( - BaseControlTask, - CoordinatorState, - JointCommandOutput, - ResourceClaim, -) -from dimos.learning.policy.base import ActionChunk - - -@dataclass -class ActionReplayerConfig: - """Configuration for ActionReplayer. - - Attributes: - joint_names: joints this task commands. Must match the policy's - `joint_names` (caller is responsible — typically wired from the - checkpoint's `dimos_meta.json`). - chunk_topic: name of the topic ChunkPolicyModule publishes on. - ActionReplayer subscribes via the coordinator's transport. - priority: tick-loop arbitration priority. - max_chunk_age_s: drop any incoming chunk whose `ts` is more than this - many seconds old at receive time. Guards against stalls. - hold_on_stall: if the buffer empties (policy fell behind / died), - hold the last commanded position instead of returning None - (which would let lower-priority tasks take over). - temporal_ensemble: when overlapping chunks arrive, exponentially - weight predictions for the same target time (ACT trick). - Off by default; v1 nice-to-have. - """ - - joint_names: list[str] - chunk_topic: str = "action_chunk" - priority: int = 10 - max_chunk_age_s: float = 0.5 - hold_on_stall: bool = True - temporal_ensemble: bool = False - - -class ActionReplayer(BaseControlTask): - """ControlTask that replays policy chunks into joint commands at tick rate. - - Behavior: - - On each new chunk, drop any buffered targets at or after the new - chunk's first target_ts (latest chunk wins). - - On each tick, interpolate (or look up nearest) target for `state.now`. - - If `state.now` is past the buffer end: - - hold last position if `hold_on_stall=True` - - else go inactive (return None) - - Stale chunks (`now - chunk.ts > max_chunk_age_s`) are dropped. - - Fault behavior: - - Policy dies / module crashes: buffer drains, behavior degrades to - "hold last position" (or inactive). 
Hardware never sees zero or NaN. - """ - - def __init__(self, name: str, config: ActionReplayerConfig) -> None: - """Initialize. Subscription to `chunk_topic` is set up by the coordinator - when the task is registered (we expose `on_action_chunk` for it to call). - """ - raise NotImplementedError - - # ── ControlTask interface ──────────────────────────────────────────────── - - @property - def name(self) -> str: - raise NotImplementedError - - def claim(self) -> ResourceClaim: - """Claim `config.joint_names` at `config.priority`.""" - raise NotImplementedError - - def is_active(self) -> bool: - """Active iff the buffer has a non-stale target for `now` (or - `hold_on_stall` is true and we've ever received a chunk).""" - raise NotImplementedError - - def compute(self, state: CoordinatorState) -> JointCommandOutput | None: - """Return interpolated joint targets for `state.now`. - - Pure lookup over the buffered chunk; no model inference happens here. - Must complete in well under 10 ms to not jeopardize the 100 Hz loop. - """ - raise NotImplementedError - - # ── chunk handling ─────────────────────────────────────────────────────── - - def on_action_chunk(self, msg: ActionChunk) -> None: - """Push a new chunk's actions into the buffer. - - Steps: - 1. If `time_now - msg.ts > max_chunk_age_s`: drop and log. - 2. Compute target_ts for each action: `msg.ts + i * msg.dt`. - 3. Drop any buffered entries with target_ts >= msg.ts + msg.dt. - 4. Append the new (target_ts, positions) pairs in order. - """ - raise NotImplementedError - - # ── internals ──────────────────────────────────────────────────────────── - - def _interpolate(self, t: float) -> JointCommandOutput | None: - """Look up or linearly interpolate the buffer at time `t`. 
Returns - None if `t` is outside the buffered range and `hold_on_stall=False`.""" - raise NotImplementedError diff --git a/dimos/learning/inference/blueprint.py b/dimos/learning/inference/blueprint.py index 507124ec75..485ec7a2d5 100644 --- a/dimos/learning/inference/blueprint.py +++ b/dimos/learning/inference/blueprint.py @@ -12,41 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Inference blueprints for the DimOS Learning Framework. - -Each blueprint composes: - Camera (publishes color_image) - ChunkPolicyModule (consumes obs, publishes ActionChunk at policy rate) - ControlCoordinator with ActionReplayer task (replays chunks at 100 Hz) - Hardware (consumes joint_command) - -The same blueprint serves both ACT (vision) and pi0/pi0.5 (vision + -language) — `ChunkPolicyModule` auto-detects from the loaded checkpoint -whether the policy expects language. For VLA, an LLM agent skill or a -language-source Module publishes to the `language_text` topic. - -Note: ActionReplayer is a ControlTask, not a Module. It runs inside the -ControlCoordinator and is registered via `task_type="action_replayer"` in -the coordinator's task config. The coordinator variants below currently -reference the existing teleop-IK coordinator blueprints; v1 implementation -adds learning-specific coordinator blueprints under -`dimos/control/blueprints/learning.py` that swap teleop_ik for action_replayer -(see plan §7 critical files). - -Usage: - dimos run learning-infer-xarm7 \\ - --ChunkPolicyModule.config.spec_path dataset.yaml \\ - --ChunkPolicyModule.config.policy_path runs/act_pick_red \\ - --ChunkPolicyModule.config.inference_rate_hz 30 +"""ACT inference blueprint. ActionReplayer is registered by the per-robot +coordinator blueprint (passed in below). v1 placeholder uses the existing +teleop coordinator; replace with a coordinator that registers ActionReplayer. 
""" from __future__ import annotations -from dimos.control.blueprints.teleop import ( - coordinator_teleop_piper, - coordinator_teleop_xarm6, - coordinator_teleop_xarm7, -) +from dimos.control.blueprints.teleop import coordinator_teleop_xarm7 from dimos.core.coordination.blueprints import autoconnect from dimos.core.transport import LCMTransport from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera @@ -55,82 +28,25 @@ from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.JointState import JointState -# Topics shared across variants. _T_COLOR_IMAGE = "/camera/color_image" _T_JOINT_STATE = "/coordinator/joint_state" -_T_LANGUAGE = "/learning/language_text" _T_ACTION_CHUNK = "/learning/action_chunk" -# ── XArm7 (ACT-rate, 30 Hz) ────────────────────────────────────────────────── - learning_infer_xarm7 = autoconnect( RealSenseCamera.blueprint(enable_pointcloud=False), - ChunkPolicyModule.blueprint(inference_rate_hz=30.0), + ChunkPolicyModule.blueprint( + policy_path="data/runs/act_pick_red", + inference_rate_hz=30.0, + ), coordinator_teleop_xarm7, # TODO: replace with coordinator_action_replayer_xarm7 ).transports( { ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), ("joint_state", JointState): LCMTransport(_T_JOINT_STATE, JointState), - ("language_text", str): LCMTransport(_T_LANGUAGE, str), - ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), - } -) - - -# ── Piper (ACT-rate) ───────────────────────────────────────────────────────── - -learning_infer_piper = autoconnect( - RealSenseCamera.blueprint(enable_pointcloud=False), - ChunkPolicyModule.blueprint(inference_rate_hz=30.0), - coordinator_teleop_piper, -).transports( - { - ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), - ("joint_state", JointState): LCMTransport(_T_JOINT_STATE, JointState), - ("language_text", str): LCMTransport(_T_LANGUAGE, str), - ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), - 
} -) - - -# ── XArm6 (ACT-rate) ───────────────────────────────────────────────────────── - -learning_infer_xarm6 = autoconnect( - RealSenseCamera.blueprint(enable_pointcloud=False), - ChunkPolicyModule.blueprint(inference_rate_hz=30.0), - coordinator_teleop_xarm6, -).transports( - { - ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), - ("joint_state", JointState): LCMTransport(_T_JOINT_STATE, JointState), - ("language_text", str): LCMTransport(_T_LANGUAGE, str), - ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), - } -) - - -# ── XArm7 (VLA-rate, 5 Hz) ─────────────────────────────────────────────────── -# Same wiring; only the policy thread rate differs. pi0/pi0.5 are slow -# enough that running them at 30 Hz wastes GPU. - -learning_infer_vla_xarm7 = autoconnect( - RealSenseCamera.blueprint(enable_pointcloud=False), - ChunkPolicyModule.blueprint(inference_rate_hz=5.0), - coordinator_teleop_xarm7, -).transports( - { - ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), - ("joint_state", JointState): LCMTransport(_T_JOINT_STATE, JointState), - ("language_text", str): LCMTransport(_T_LANGUAGE, str), ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), } ) -__all__ = [ - "learning_infer_piper", - "learning_infer_vla_xarm7", - "learning_infer_xarm6", - "learning_infer_xarm7", -] +__all__ = ["learning_infer_xarm7"] diff --git a/dimos/learning/inference/chunk_policy_module.py b/dimos/learning/inference/chunk_policy_module.py index 4cccdfbf55..811dd136f7 100644 --- a/dimos/learning/inference/chunk_policy_module.py +++ b/dimos/learning/inference/chunk_policy_module.py @@ -12,141 +12,107 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Vision/VLA policy as a DimOS Module — produces action chunks at policy rate. +"""ACT inference Module @ ~30 Hz. 
-One module covers both v1 inference targets: - - ACT (10–30 Hz, vision + joint state) - - pi0/pi0.5 (1–5 Hz, vision + joint state + language) +Reads `/dimos_meta.json` at start() to recover obs schema +(StreamField map + sync). Latches In ports; calls predict_chunk on the +freshest snapshot every tick; publishes ActionChunk over LCM. -`ChunkPolicyModule` runs the policy in a background thread at `inference_rate_hz`, -publishes each output as an `ActionChunk` message, and is consumed by -`ActionReplayer` (in the coordinator's tick loop) which interpolates to 100 Hz. - -Heavy ML deps (`lerobot`, `torch`) are imported lazily via `LeRobotPolicy.load`, -not at module import time — so just having this in a blueprint doesn't pull -CUDA into every install. +Heavy ML deps (`lerobot`, `torch`) imported lazily inside `start()` — +this file is import-light. """ from __future__ import annotations import threading +import time from typing import Any +import numpy as np + from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.learning.inference.obs_builder import ObsBuilder +from dimos.learning.dataprep import StreamField, SyncConfig from dimos.learning.policy.base import ActionChunk, Policy from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.JointState import JointState class ChunkPolicyModuleConfig(ModuleConfig): - """Config for ChunkPolicyModule.""" - - spec_path: str # path to dataset.yaml — supplies obs construction - policy_path: str # path to lerobot checkpoint dir - inference_rate_hz: float = 5.0 # 5 Hz default for VLA; 30 Hz for ACT + policy_path: str + inference_rate_hz: float = 30.0 device: str = "cuda" - default_language: str = "" # used when `language_text` port has no value yet class ChunkPolicyModule(Module): - """Runs a Policy at `inference_rate_hz`, publishes ActionChunks. 
- - Live message latching: - - `color_image` and `joint_state` are cached on every receive; the - policy thread reads the latest cached value at each tick. - - `language_text` is optional; if the policy doesn't expect language - (`policy.expects_language is False`) the port is ignored. - - The thread loop is best-effort wrt `inference_rate_hz`: if a forward pass - takes longer than the period, the next tick fires immediately; we never - queue stale work. - """ - config: ChunkPolicyModuleConfig color_image: In[Image] joint_state: In[JointState] - language_text: In[str] - action_chunk: Out[ActionChunk] def __init__(self, **kwargs: Any) -> None: - """Defer all heavy init to `start()`.""" super().__init__(**kwargs) - # Latched live messages — written by port callbacks, read by policy thread. + # Latched live messages self._latest_image: Image | None = None self._latest_joint_state: JointState | None = None - self._latest_language: str | None = None self._latch_lock = threading.Lock() # Filled in start(): self._policy: Policy | None = None - self._obs_builder: ObsBuilder | None = None + self._observation: dict[str, StreamField] = {} + self._sync: SyncConfig | None = None self._chunk_id: int = 0 - # Thread control self._thread: threading.Thread | None = None self._stop = threading.Event() - # ── lifecycle ──────────────────────────────────────────────────────────── - @rpc def start(self) -> None: - """Load spec + policy, subscribe to ports, spawn the inference thread. - - Steps: - 1. `spec = DatasetSpec.from_file(config.spec_path)` - 2. `self._policy = LeRobotPolicy.load(config.policy_path, device=config.device)` - 3. `self._obs_builder = ObsBuilder(spec)` - 4. Subscribe color_image / joint_state / language_text -> latch handlers. - 5. Start the policy thread targeting `_run_loop`. 
- """ + """Lazy-import LeRobotPolicy; load checkpoint; read dimos_meta.json + for observation/sync; subscribe to ports; spawn the loop thread.""" raise NotImplementedError @rpc def stop(self) -> None: - """Stop the inference thread and call `super().stop()`.""" - raise NotImplementedError - - # ── agent surface ──────────────────────────────────────────────────────── - - @rpc - def set_language(self, text: str) -> None: - """Override the language conditioning text without touching the - upstream `language_text` port. Useful when an LLM agent skill drives - VLA task switching.""" raise NotImplementedError @rpc def reload_policy(self, policy_path: str, device: str | None = None) -> None: - """Hot-swap the policy checkpoint without restarting the blueprint. - Stops the inference thread, loads the new checkpoint, restarts.""" + """Hot-swap the checkpoint without restarting the blueprint.""" raise NotImplementedError @rpc def get_status(self) -> dict[str, Any]: - """Return {'running': bool, 'chunk_count': int, 'policy_path': str, - 'expects_language': bool, 'last_chunk_ts': float | None}.""" + """{'running', 'chunk_count', 'policy_path', 'last_chunk_ts'}.""" raise NotImplementedError - # ── inference loop ─────────────────────────────────────────────────────── - def _run_loop(self) -> None: - """Background thread. Sleep to next deadline, build obs, call policy, - publish chunk. Logs and continues on any per-tick error so a single - bad observation doesn't kill inference. - """ - raise NotImplementedError - - def _build_live_obs(self) -> dict[str, Any] | None: - """Snapshot the latched messages and assemble the dict the ObsBuilder wants. - - Returns None if any required stream hasn't received a message yet - (the loop will skip this tick and try again). 
- """ + period = 1.0 / self.config.inference_rate_hz + while not self._stop.is_set(): + t0 = time.monotonic() + obs = self._build_live_obs() + if obs is None: + time.sleep(period) + continue + + positions = self._policy.predict_chunk(obs) # (T, action_dim) + self.action_chunk.publish( + ActionChunk( + ts=time.time(), + joint_names=self._policy.joint_names, + positions=positions, + dt=period, + chunk_id=self._next_chunk_id(), + ) + ) + time.sleep(max(0.0, period - (time.monotonic() - t0))) + + def _build_live_obs(self) -> dict[str, np.ndarray] | None: + """Snapshot latched messages under a lock, project each obs key + through `resolve_field` using `self._observation`. Returns + None if any required stream hasn't received a message yet.""" raise NotImplementedError def _next_chunk_id(self) -> int: diff --git a/dimos/learning/inference/obs_builder.py b/dimos/learning/inference/obs_builder.py deleted file mode 100644 index f08db973d6..0000000000 --- a/dimos/learning/inference/obs_builder.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Live observation construction for inference. - -`ObsBuilder` is the inference-time counterpart to `DataPrep.iter_episode_samples`. -At training time, samples are built by walking recorded streams; at inference -time we have the *latest* message on each live stream. 
The transformation -from per-stream messages to a model-ready obs dict must be identical between -the two paths or we get train/serve skew. - -To guarantee that, `ObsBuilder` reuses `DataPrep.resolve_field` for field -projection + preprocess. The only thing it adds is a mapping from the spec's -`stream:` names to live message objects supplied by the caller (the -ChunkPolicyModule, which has the actual `In[Image]` / `In[JointState]` ports). -""" - -from __future__ import annotations - -from typing import Any - -import numpy as np - -from dimos.learning.spec import DatasetSpec - - -class ObsBuilder: - """Builds the model-input dict from the latest live messages. - - Construction takes a `DatasetSpec`. `build()` takes a `{stream_name: msg}` - dict where keys are the recorded stream names referenced by - `spec.observation[*].stream`, and values are the live LCM messages. - - The caller (ChunkPolicyModule) is responsible for resolving its In ports - to those stream names — that's a small static mapping it sets up once. - """ - - def __init__(self, spec: DatasetSpec) -> None: - """Cache the observation StreamFields for fast lookup at tick rate.""" - raise NotImplementedError - - def build(self, live_messages: dict[str, Any]) -> dict[str, np.ndarray]: - """Project + preprocess the latest message on each obs stream. - - Args: - live_messages: stream_name -> latest message object (e.g. - {"camera_color_image": , "coordinator_joint_state": }). - Every stream referenced by `spec.observation` must be present; - missing streams raise. - - Returns: - obs dict keyed by `spec.observation` keys (e.g. "cam_high", - "joint_pos"). Values are np.ndarrays whose shapes/dtypes match - what `iter_episode_samples` produced at training time. - """ - raise NotImplementedError - - def required_streams(self) -> set[str]: - """Stream names this builder reads from. 
Used by ChunkPolicyModule - to wire its In ports + assert the live_messages dict is complete.""" - raise NotImplementedError diff --git a/dimos/learning/learning_spec.md b/dimos/learning/learning_spec.md deleted file mode 100644 index 469a565c28..0000000000 --- a/dimos/learning/learning_spec.md +++ /dev/null @@ -1,164 +0,0 @@ -### Data collection - -Two phases: -A. Recording - teleop/drive robot, streams recorded at `session.db` -B. DataPrep - convert `session.db` to `dataset/` - -Phase A - Recording -``` python - -spec = DatasetConfig.from_file("datasets/pick_cube.yaml") - -learning_collect_quest_xarm7 = autoconnect( - teleop_quest_xarm7, - RealSenseCamera.blueprint(enable_pointcloud=False), - EpisodeMonitorModule.blueprint(spec=spec), -) - -# hardware -# sim -``` - -`RecordReplay` is a transport-layer hook (`--record-path`); captures every -transport in the blueprint (including `episode_status`) into `session.db`. - -Open: feedback to operators — display episode number / status in the VR -headset. Unify episode-boundary declaration across VR / keyboard / active+passive -/ future inputs into one `episode_status` stream in the `.db` file. - -``` python -from dimos.learning.config import DatasetConfig, EpisodeStatus - -class EpisodeMonitorModuleConfig(ModuleConfig): - spec: DatasetConfig - -class EpisodeMonitorModule(Module): - config: EpisodeMonitorModuleConfig - - buttons: In[Buttons] - keyboard: In[KeyPress] - status: Out[EpisodeStatus] - - @rpc - def reset_counters(self) -> EpisodeStatus: ... - @rpc - def get_status(self) -> EpisodeStatus: ... - - def _on_buttons(self, msg: Buttons) -> None: ... - def _on_keyboard(self, msg: KeyPress) -> None: ... 
-``` - -Config -```yaml -source: session.db - -episodes: - extractor: episode_status - default_task_label: pick_red_cube - button_map: {start: A, save: B, discard: X} - keyboard_map: {start: space, save: s, discard: d} - -observation: - cam: - stream: camera_color_image - field: image - joint_pos: - stream: coordinator_joint_state - field: position - -action: - joint_target: - stream: coordinator_joint_command - field: position - -sync: - anchor: cam - rate_hz: 30 - tolerance_ms: 50 - strategy: nearest - -output: - format: lerobot - path: datasets/pick_red/ - metadata: {fps: 30, robot: xarm7} -``` - -`learning/config.py` - -``` python -from dimos.protocol.service.spec import BaseConfig - - -class EpisodeStatus(BaseModel): # runtime message (not BaseConfig — built in code, not from YAML) - state: Literal["idle", "recording"] - episodes_saved: int - episodes_discarded: int - current_episode_start_ts: float | None - last_event: Literal["start", "save", "discard", "init"] = "init" - task_label: str | None = None - - -class EpisodeConfig(BaseConfig): - extractor: Literal["episode_status", "ranges", "whole_session"] - ranges: list[tuple[float, float]] | None = None - default_task_label: str | None = None - button_map: dict[Literal["start", "save", "discard"], str] = {"start": "A", "save": "B", "discard": "X"} - keyboard_map: dict[Literal["start", "save", "discard"], str] = {} - -class StreamField(BaseConfig): - stream: str - field: str | None = None - -class SyncConfig(BaseConfig): - anchor: str - rate_hz: float - tolerance_ms: float - strategy: Literal["nearest", "interp"] = "nearest" - -class OutputConfig(BaseConfig): - format: Literal["lerobot", "hdf5", "rlds"] = "lerobot" - path: Path - metadata: dict[str, Any] = {} - - -class DatasetConfig(BaseConfig): - source: str - episodes: EpisodeConfig - observation: dict[str, StreamField] - action: dict[str, StreamField] - sync: SyncConfig - output: OutputConfig - - @classmethod - def from_file(cls, path: str | Path) -> 
DatasetConfig: ... -``` - -Phase B - DataPrep - -``` -spec = DatasetConfig.from_file("datasets/pick_red.yaml") - -learning_dataprep = autoconnect( - DataPrepModule.blueprint(spec=spec), -).transports({}) -``` - -- DataPrepModule - -``` python -class DataPrepModuleConfig(ModuleConfig): - spec: DatasetConfig - output_dir: str | None = None - -class DataPrepModule(Module): - config: DataPrepModuleConfig - - @rpc - def build(self, output_dir: str | None = None) -> None: ... - @rpc - def cancel(self) -> bool: ... - @rpc - def get_status(self) -> dict[str, Any]: ... - @rpc - def inspect(self) -> dict[str, Any]: ... -``` diff --git a/dimos/learning/policy/base.py b/dimos/learning/policy/base.py index 021f95229c..dc31a4fa10 100644 --- a/dimos/learning/policy/base.py +++ b/dimos/learning/policy/base.py @@ -12,16 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Policy abstraction — what `ChunkPolicyModule` calls every inference tick. +"""ActionChunk message + Policy backend Protocol. -The Policy protocol decouples model format (lerobot PreTrainedPolicy in v1, -ONNX/TorchScript in v2) from the inference module. Anything that satisfies -this protocol is droppable into a blueprint. - -`ActionChunk` is the typed message published by `ChunkPolicyModule` and -consumed by `ActionReplayer`. v1 uses a pydantic model; v2 will replace it -with a generated LCM type so it can flow over the wire — the field layout -here matches what that LCM type will look like. +`ChunkPolicyModule` produces ActionChunks; `ActionReplayer` consumes them. +Any policy backend (lerobot in v1) just needs to satisfy `Policy`. """ from __future__ import annotations @@ -34,15 +28,10 @@ class ActionChunk(BaseModel): - """A predicted sequence of joint targets, plus the metadata to replay it. + """T future joint targets + the metadata to replay them. - Fields: - ts: wall-clock time the chunk was produced (seconds). 
- joint_names: names matching the action key ordering used at training. - positions: shape (T, N) — T future steps, N = len(joint_names). - dt: expected interval between successive actions (seconds). - Replayer uses ts + i*dt as the target time for action i. - chunk_id: monotonic id for ordering / dedup at the replayer. + positions: shape (T, N), N = len(joint_names). + Replayer uses ts + i*dt as the target time for positions[i]. """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -59,35 +48,13 @@ class Policy(Protocol): """What ChunkPolicyModule needs from any policy implementation.""" @classmethod - def load(cls, path: str | Path, device: str = "cuda") -> Policy: - """Load a checkpoint directory. `path` is a lerobot checkpoint dir in v1. - - Implementations should also load the sidecar `dimos_meta.json` and - `meta/stats.json` so `predict_chunk` can normalize/unnormalize without - the caller doing it. - """ - ... - - @property - def chunk_size(self) -> int: - """Number of actions emitted per `predict_chunk` call (T).""" - ... + def load(cls, path: str | Path, device: str = "cuda") -> Policy: ... @property - def joint_names(self) -> list[str]: - """Action joint names, matching the spec's action key ordering.""" - ... - + def chunk_size(self) -> int: ... @property - def expects_language(self) -> bool: - """True if the policy reads `obs['language_text']` (VLAs); False otherwise.""" - ... + def joint_names(self) -> list[str]: ... def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: - """Return shape (chunk_size, action_dim) — already unnormalized to joint space. - - `obs` keys must match `spec.observation`. The policy applies its own - input normalization internally (using the stats it loaded with the - checkpoint). - """ + """Return shape (chunk_size, action_dim), unnormalized to joint space.""" ... 
diff --git a/dimos/learning/policy/lerobot_policy.py b/dimos/learning/policy/lerobot_policy.py index 6d9cbd3d8c..c6d764ed23 100644 --- a/dimos/learning/policy/lerobot_policy.py +++ b/dimos/learning/policy/lerobot_policy.py @@ -12,15 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""LeRobot policy wrapper. - -Wraps any `lerobot.PreTrainedPolicy` (ACT, Diffusion, pi0, pi0.5) behind the -`Policy` protocol. This is the only Policy implementation in v1 — both -training entry points produce checkpoints loadable by this class. - -Heavy deps (`lerobot`, `torch`) are imported lazily inside `load()` so simply -importing this module does not require a CUDA install. -""" +"""LeRobot ACT policy wrapper. Lazy-imports lerobot/torch in load().""" from __future__ import annotations @@ -33,15 +25,10 @@ class LeRobotPolicy: - """Adapter for lerobot's PreTrainedPolicy → DimOS Policy protocol.""" - - # Type-erased to keep this file import-light. Concrete type: - # _model: lerobot.policies.pretrained.PreTrainedPolicy - _model: Any + _model: Any # lerobot.policies.pretrained.PreTrainedPolicy _stats: dict[str, Any] _chunk_size: int _joint_names: list[str] - _expects_language: bool _device: str def __init__( @@ -50,24 +37,13 @@ def __init__( stats: dict[str, Any], chunk_size: int, joint_names: list[str], - expects_language: bool, device: str, ) -> None: - """Direct constructor — prefer `LeRobotPolicy.load(path)` in user code.""" raise NotImplementedError @classmethod def load(cls, path: str | Path, device: str = "cuda") -> LeRobotPolicy: - """Load a lerobot checkpoint directory. - - Expected layout under `path`: - config.json / model.safetensors - the lerobot checkpoint - meta/stats.json - normalization stats - dimos_meta.json - DimOS sidecar (spec + provenance) - - Auto-detects the policy class (act / diffusion / pi0 / pi0_5) from - the lerobot config and sets `expects_language` accordingly. 
- """ + """Load checkpoint dir: model.safetensors + meta/stats.json + dimos_meta.json.""" raise NotImplementedError @property @@ -78,26 +54,10 @@ def chunk_size(self) -> int: def joint_names(self) -> list[str]: return self._joint_names - @property - def expects_language(self) -> bool: - return self._expects_language - def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: - """Run one forward pass; return (chunk_size, action_dim). - - Steps: - 1. Normalize obs via `self._stats` (image: /255 + per-channel norm; - vector: (x - mean) / std). - 2. Convert to torch tensors on `self._device`, add batch dim. - 3. Call `self._model.select_action_chunk(obs)` (or equivalent). - 4. Move back to numpy, drop batch dim. - 5. Unnormalize actions via `self._stats`. - - Matches the pipeline used inside lerobot's training loop, so live - inference sees the same numerics as training-time evaluation. - """ + """Normalize obs → forward pass → unnormalize → (chunk_size, action_dim).""" raise NotImplementedError -# Sanity check: make the protocol relationship explicit at import time. +# Protocol conformance check at import time. _: type[Policy] = LeRobotPolicy diff --git a/dimos/learning/spec.py b/dimos/learning/spec.py deleted file mode 100644 index da5ac79028..0000000000 --- a/dimos/learning/spec.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Data definitions for the DimOS Learning Framework. 
- -Contains the YAML/JSON-backed `DatasetSpec` schema and the runtime data -classes (`Episode`, `Sample`) shared between collection, training, and -inference. No logic — just typed records and constants. Safe to import -from anywhere (no circular dependencies). -""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any, Literal - -import numpy as np -from pydantic import BaseModel, ConfigDict, Field - -# ───────────────────────────────────────────────────────────────────────────── -# DatasetSpec — the YAML/JSON schema -# ───────────────────────────────────────────────────────────────────────────── - - -class DatasetSpec(BaseModel): - """Top-level spec. Same instance used at build, load, and inference time. - - A `DatasetSpec` (loaded from YAML/JSON) is the contract between data - collection (raw RecordReplay session -> on-disk dataset) and training - (loading the same spec to feed a model). The same spec also drives - inference observation construction. - """ - - source: Path # path to session.db produced by RecordReplay - episodes: EpisodeConfig - observation: dict[str, StreamField] # obs key -> stream field - action: dict[str, StreamField] # action key -> stream field - sync: SyncConfig - filters: FilterConfig | None = None - output: OutputConfig | None = None # only required by DataPrep.build() - - @classmethod - def from_file(cls, path: str | Path) -> DatasetSpec: - """Load from .yaml/.yml/.json (dispatch by extension).""" - raise NotImplementedError - - def save(self, path: str | Path) -> None: - """Write to .yaml/.yml/.json (round-trip safe).""" - raise NotImplementedError - - -class EpisodeConfig(BaseModel): - """How to slice the continuous recording into episodes.""" - - extractor: Literal["buttons", "ranges", "whole_session"] = "buttons" - - # BUTTONS extractor: friendly names map to Quest Buttons attrs via BUTTON_ALIASES. 
- # The state machine always discards an in-progress episode if the recording ends - # without an explicit save/discard press. - button_stream: str = "buttons" - start: str = "A" # rising edge -> begin episode - save: str = "B" # rising edge -> end + save - discard: str = "X" # rising edge -> end + drop - - # RANGES extractor: explicit absolute timestamps - ranges: list[tuple[float, float]] | None = None - - # Default label/description applied to every extracted episode unless overridden. - # task_description is the free-form natural-language string used as language - # conditioning for VLA policies (e.g. "pick up the red cube and place it on the blue plate"). - default_task_label: str | None = None - default_task_description: str | None = None - - -class StreamField(BaseModel): - """Pointer to a field in a recorded stream — one (obs|action) key's data source.""" - - stream: str # LCM stream / topic name as recorded in session.db - type: str | None = None # optional dotted type (e.g. "sensor_msgs.Image"); for codec dispatch - field: str | None = None # attribute on the message; None = whole message - preprocess: str | None = None # named preprocess hook (e.g. "jpeg_decode", "normalize_image") - - -class SyncConfig(BaseModel): - """How to build per-timestep samples by aligning multiple streams.""" - - anchor: str # key in `observation` that drives the timeline - rate_hz: float = 30.0 # downsample anchor to this rate; 0 = use anchor's native rate - tolerance_ms: float = 50.0 # max allowed time delta when picking nearest sample - strategy: Literal["nearest", "interp"] = "nearest" - - -class FilterConfig(BaseModel): - """Per-episode filters applied after extraction.""" - - success_only: bool = True - min_duration_s: float = 0.0 - max_duration_s: float | None = None - task_labels: list[str] | None = None # whitelist; None = all - - # Train/val split. Episodes whose index lands in val become the validation set - # at training time; everything else is train. 
`val_episode_ids` takes precedence - # over `val_ratio`. Both None = no split (everything is train). - val_episode_ids: list[int] | None = None - val_ratio: float | None = None - val_split_seed: int = 0 - - -class OutputConfig(BaseModel): - """Where and how to write the built dataset.""" - - format: Literal["lerobot", "hdf5", "rlds"] - path: Path - metadata: dict[str, Any] = Field(default_factory=dict) - - -# ───────────────────────────────────────────────────────────────────────────── -# Runtime data -# ───────────────────────────────────────────────────────────────────────────── - - -# Friendly Quest controller names -> Buttons attribute names. -# Override by supplying an attribute name directly in the spec. -BUTTON_ALIASES: dict[str, str] = { - "A": "right_primary", - "B": "right_secondary", - "X": "left_primary", - "Y": "left_secondary", - "LT": "left_trigger", - "RT": "right_trigger", - "LG": "left_grip", - "RG": "right_grip", - "MENU_L": "left_menu", - "MENU_R": "right_menu", -} - - -class Episode(BaseModel): - """A single demonstration carved from a session.""" - - id: str - start_ts: float - end_ts: float - task_label: str | None = None # short categorical tag (e.g. "pick_red_cube") - task_description: str | None = None # free-form natural-language string for VLA conditioning - success: bool = True - metadata: dict[str, Any] = Field(default_factory=dict) - - @property - def duration(self) -> float: - return self.end_ts - self.start_ts - - -class Sample(BaseModel): - """One synchronized timestep: aligned obs + action at ts.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - ts: float - episode_id: str - observation: dict[str, np.ndarray] - action: dict[str, np.ndarray] - - -# DatasetSpec is defined before its referenced subclasses so it reads as the -# top-of-file entry point. Resolve those forward references now that every -# referenced class exists in the module namespace. 
-DatasetSpec.model_rebuild() diff --git a/dimos/learning/specs/datacollection.md b/dimos/learning/specs/datacollection.md index e348f1454b..b6e7afadc0 100644 --- a/dimos/learning/specs/datacollection.md +++ b/dimos/learning/specs/datacollection.md @@ -1,11 +1,7 @@ # Stage 1 — Data -Two phases: - -1. **Recording** — live; operator drives the robot. `RecordReplay` writes streams to `session.db`. -2. **DataPrep** — offline; convert `session.db` → `dataset/`. - -One `dataset.yaml` drives both, parsed once into a `DatasetConfig`. +1. **Recording** — live; `RecordReplay` writes streams to `session.db`. +2. **DataPrep** — offline; `session.db` → `dataset/` (LeRobot v2). --- @@ -15,12 +11,13 @@ One `dataset.yaml` drives both, parsed once into a `DatasetConfig`. ```python # dimos/learning/collection/blueprint.py -spec = DatasetConfig.from_file("datasets/pick_red.yaml") - learning_collect_quest_xarm7 = autoconnect( teleop_quest_xarm7, RealSenseCamera.blueprint(enable_pointcloud=False), - EpisodeMonitorModule.blueprint(spec=spec), + EpisodeMonitorModule.blueprint( + button_map={"start": "A", "save": "B", "discard": "X"}, + default_task_label="pick_red_cube", + ), ).transports({ ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), @@ -28,108 +25,12 @@ learning_collect_quest_xarm7 = autoconnect( }) ``` -`RecordReplay` (`--record-path`) captures every transport above, including `episode_status`. 
- ---- - -### Dataset spec — YAML - -```yaml -# datasets/pick_red.yaml -source: session.db - -episodes: - extractor: episode_status - status_stream: episode_status - default_task_label: pick_red_cube - button_map: {start: A, save: B, discard: X} - keyboard_map: {start: space, save: s, discard: d} - -observation: - cam: - stream: camera_color_image - field: image - joint_pos: - stream: coordinator_joint_state - field: position - -action: - joint_target: - stream: coordinator_joint_command - field: position - -sync: - anchor: cam - rate_hz: 30 - tolerance_ms: 50 - strategy: nearest - -output: - format: lerobot - path: datasets/pick_red/ - metadata: {fps: 30, robot: xarm7} -``` - ---- - -### Dataset spec — pydantic classes - -```python -# dimos/learning/config.py -from dimos.protocol.service.spec import BaseConfig # extra="forbid" - - -class DatasetConfig(BaseConfig): - source: str - episodes: EpisodeConfig - observation: dict[str, StreamField] - action: dict[str, StreamField] - sync: SyncConfig - output: OutputConfig - - @classmethod - def from_file(cls, path: str | Path) -> DatasetConfig: ... 
- - -class EpisodeConfig(BaseConfig): - extractor: Literal["episode_status", "ranges", "whole_session"] = "episode_status" - status_stream: str = "episode_status" - ranges: list[tuple[float, float]] | None = None - default_task_label: str | None = None - button_map: dict[Literal["start", "save", "discard"], str] = {"start": "A", "save": "B", "discard": "X"} - keyboard_map: dict[Literal["start", "save", "discard"], str] = {} - - -class StreamField(BaseConfig): - stream: str - field: str | None = None - - -class SyncConfig(BaseConfig): - anchor: str - rate_hz: float - tolerance_ms: float - strategy: Literal["nearest", "interp"] = "nearest" - - -class OutputConfig(BaseConfig): - format: Literal["lerobot", "hdf5", "rlds"] = "lerobot" - path: Path - metadata: dict[str, Any] = {} -``` - -| Module | Reads | -|---|---| -| `EpisodeMonitorModule` | `spec.episodes.button_map` / `keyboard_map` / `default_task_label` | -| `DataPrepModule` | full spec | -| `ChunkPolicyModule` | `spec.observation`, `spec.sync` | - ---- +`RecordReplay` (`--record-path`) captures every transport above into `session.db`. ### EpisodeMonitorModule -Translates teleop input (buttons, keyboard, future inputs) into a canonical -`EpisodeStatus` stream. `DataPrep` reads only that stream — never raw inputs. +Translates teleop input (buttons, keyboard) into the canonical +`EpisodeStatus` stream. DataPrep reads only this stream — never raw inputs. 
```python # dimos/learning/collection/episode_monitor.py @@ -144,7 +45,9 @@ class EpisodeStatus(BaseModel): class EpisodeMonitorModuleConfig(ModuleConfig): - spec: DatasetConfig + button_map: dict[Literal["start", "save", "discard"], str] = {"start": "A", "save": "B", "discard": "X"} + keyboard_map: dict[Literal["start", "save", "discard"], str] = {} + default_task_label: str | None = None class EpisodeMonitorModule(Module): @@ -173,31 +76,35 @@ RECORDING --start--> RECORDING (auto-commit prev) session end mid-episode: always discard ``` ---- - ### Run ```bash -dimos run learning-collect-quest-xarm7 \ - --spec-path datasets/pick_red.yaml \ - --record-path data/pick_red.db +dimos run learning-collect-quest-xarm7 --record-path data/sessions/pick_red.db ``` --- ## Phase B — DataPrep -Reads `session.db`, slices on `episode_status`, syncs streams, writes -`dataset/`. Heavy deps run in a subprocess. - ### Blueprint ```python # dimos/learning/dataprep/blueprint.py -spec = DatasetConfig.from_file("datasets/pick_red.yaml") - learning_dataprep = autoconnect( - DataPrepModule.blueprint(spec=spec, auto_run=True), + DataPrepModule.blueprint( + source="data/sessions/pick_red.db", + episodes=EpisodeExtractor(), + observation={ + "cam": StreamField(stream="camera_color_image", field="image"), + "joint_pos": StreamField(stream="coordinator_joint_state", field="position"), + }, + action={ + "joint_target": StreamField(stream="coordinator_joint_command", field="position"), + }, + sync=SyncConfig(anchor="cam", rate_hz=30, tolerance_ms=50), + output=OutputConfig(format="lerobot", path=Path("data/datasets/pick_red/")), + auto_run=True, + ), ).transports({}) ``` @@ -205,30 +112,62 @@ learning_dataprep = autoconnect( ```python # dimos/learning/dataprep_module.py +from dimos.protocol.service.spec import BaseConfig + + +class EpisodeExtractor(BaseConfig): + extractor: Literal["episode_status", "ranges", "whole_session"] = "episode_status" + status_stream: str = "episode_status" + ranges: 
list[tuple[float, float]] | None = None + + +class StreamField(BaseConfig): + stream: str + field: str | None = None + + +class SyncConfig(BaseConfig): + anchor: str + rate_hz: float + tolerance_ms: float + strategy: Literal["nearest", "interp"] = "nearest" + + +class OutputConfig(BaseConfig): + format: Literal["lerobot", "hdf5", "rlds"] = "lerobot" + path: Path + metadata: dict[str, Any] = {} + class DataPrepModuleConfig(ModuleConfig): - spec: DatasetConfig - output_dir: str | None = None - auto_run: bool = False + source: str + episodes: EpisodeExtractor + observation: dict[str, StreamField] + action: dict[str, StreamField] + sync: SyncConfig + output: OutputConfig + auto_run: bool = False class DataPrepModule(Module): config: DataPrepModuleConfig @rpc - def build(self, output_dir: str | None = None) -> None: ... - @rpc - def cancel(self) -> bool: ... + def build(self) -> None: ... @rpc def get_status(self) -> dict[str, Any]: ... @rpc def inspect(self) -> dict[str, Any]: ... ``` +`build()` iterates samples, hands them to the format writer, and snapshots +`config.model_dump()` into `meta/dimos_meta.json` under `config.output.path`. Stats are +written into `meta/stats.json` by `DataPrep.compute_stats`.
+ ### Run ```bash -dimos run learning-dataprep --spec-path datasets/pick_red.yaml +dimos run learning-dataprep ``` --- @@ -236,12 +175,14 @@ dimos run learning-dataprep --spec-path datasets/pick_red.yaml ## End-to-end ```bash -SPEC=datasets/pick_red.yaml - -dimos run learning-collect-quest-xarm7 --spec-path $SPEC --record-path data/pick_red.db -dimos run learning-dataprep --spec-path $SPEC +dimos run learning-collect-quest-xarm7 --record-path data/sessions/pick_red.db +dimos run learning-dataprep ``` ``` -session.db ─► dataset/ + meta/stats.json +data/sessions/pick_red.db ─► data/datasets/pick_red/ + ├── data/ (parquet) + ├── videos/ (MP4) + └── meta/ (info.json, episodes.jsonl, + stats.json, dimos_meta.json) ``` diff --git a/dimos/learning/specs/inference.md b/dimos/learning/specs/inference.md index 10e48b7701..88e60569f1 100644 --- a/dimos/learning/specs/inference.md +++ b/dimos/learning/specs/inference.md @@ -1,10 +1,20 @@ # Stage 3 — Inference -- **`ChunkPolicyModule`** — Module @ 1–30 Hz. Builds obs via `ObsBuilder`, - calls `policy.predict_chunk(obs)`, emits `ActionChunk`. -- **`ActionReplayer`** — `BaseControlTask` in the 100 Hz `ControlCoordinator` - tick loop. Buffers chunks (latest-wins), interpolates to `state.now`, - emits `JointCommandOutput`. Holds last position on stall. +ACT only. Two pieces: + +- **`ChunkPolicyModule`** (`learning/inference/`) — Module @ ~30 Hz. + Builds obs, calls `policy.predict_chunk(obs)`, emits `ActionChunk`. +- **`ActionReplayer`** (`control/tasks/`) — `BaseControlTask` in the + 100 Hz `ControlCoordinator` tick loop. Buffers chunks, interpolates + to `state.now`, emits `JointCommandOutput`. 
+ +``` +ChunkPolicyModule (~15-30 Hz) + │ ActionChunk (LCM) + ▼ +ControlCoordinator @ 100 Hz + └─ ActionReplayer.compute(state) → JointCommandOutput → hardware +``` --- @@ -12,31 +22,26 @@ ```python # dimos/learning/inference/blueprint.py -from dimos.learning.config import ActionChunk +from dimos.learning.policy.base import ActionChunk learning_infer_xarm7 = autoconnect( RealSenseCamera.blueprint(enable_pointcloud=False), ChunkPolicyModule.blueprint( - policy_path="runs/act_pick_red", + policy_path="data/runs/act_pick_red", inference_rate_hz=30.0, ), - coordinator_action_replayer_xarm7, + coordinator_action_replayer_xarm7, # registers ActionReplayer with the coordinator ).transports({ - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("language_text", str): LCMTransport("/learning/language_text", str), - ("action_chunk", ActionChunk): LCMTransport("/learning/action_chunk", ActionChunk), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("action_chunk", ActionChunk): LCMTransport("/learning/action_chunk", ActionChunk), }) ``` ## Message types -`ActionChunk` lives in `dimos/learning/config.py` next to `EpisodeStatus` -and `DatasetConfig` — single import for all cross-stage contracts. -`Policy` is the backend Protocol; lives in `policy/base.py`. - ```python -# dimos/learning/config.py +# dimos/learning/policy/base.py class ActionChunk(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -45,63 +50,55 @@ class ActionChunk(BaseModel): positions: np.ndarray # (T, N) dt: float chunk_id: int -``` -```python -# dimos/learning/policy/base.py @runtime_checkable class Policy(Protocol): @classmethod def load(cls, path: str | Path, device: str = "cuda") -> Policy: ... - - @property - def chunk_size(self) -> int: ... 
@property - def joint_names(self) -> list[str]: ... + def chunk_size(self) -> int: ... @property - def expects_language(self) -> bool: ... - + def joint_names(self) -> list[str]: ... def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: ... ``` +`LeRobotPolicy` (`policy/lerobot_policy.py`) is the v1 implementation. + ## ChunkPolicyModule ```python # dimos/learning/inference/chunk_policy_module.py +from dimos.learning.dataprep import StreamField, SyncConfig, resolve_field + class ChunkPolicyModuleConfig(ModuleConfig): - policy_path: str # spec read from /dimos_meta.json - inference_rate_hz: float = 5.0 + policy_path: str + inference_rate_hz: float = 30.0 device: str = "cuda" - default_language: str = "" class ChunkPolicyModule(Module): config: ChunkPolicyModuleConfig - color_image: In[Image] - joint_state: In[JointState] - language_text: In[str] - action_chunk: Out[ActionChunk] + color_image: In[Image] + joint_state: In[JointState] + action_chunk: Out[ActionChunk] - @rpc - def set_language(self, text: str) -> None: ... @rpc def reload_policy(self, policy_path: str, device: str | None = None) -> None: ... @rpc def get_status(self) -> dict[str, Any]: ... - # Lifecycle: start() loads policy + spawns the loop thread; stop() joins it. def _run_loop(self) -> None: period = 1.0 / self.config.inference_rate_hz while not self._stop.is_set(): t0 = time.monotonic() obs = self._build_live_obs() - if obs is None: # waiting for first frames + if obs is None: time.sleep(period); continue - positions = self.policy.predict_chunk(obs) # (T, action_dim) + positions = self.policy.predict_chunk(obs) self.action_chunk.publish(ActionChunk( ts=time.time(), joint_names=self.policy.joint_names, @@ -112,35 +109,25 @@ class ChunkPolicyModule(Module): time.sleep(max(0.0, period - (time.monotonic() - t0))) def _build_live_obs(self) -> dict[str, np.ndarray] | None: - # snapshot latched In[Image] / In[JointState] / In[str] under a lock, - # hand to ObsBuilder.build(...) 
→ returns obs dict or None if not ready + # snapshot latched In[Image] / In[JointState], project via + # resolve_field using self._observation (StreamField map + # loaded from dimos_meta.json under config.policy_path at start()). ... ``` -## ObsBuilder - -`ChunkPolicyModule.start()` reads the embedded spec from -`dimos_meta.json` under `config.policy_path` and constructs the `ObsBuilder` from it -— no `--spec-path` needed at inference. - -```python -# dimos/learning/inference/obs_builder.py - -class ObsBuilder: - def __init__(self, spec: DatasetConfig) -> None: ... - def build(self, live_messages: dict[str, Any]) -> dict[str, np.ndarray]: ... - def required_streams(self) -> set[str]: ... -``` +`start()` reads `dimos_meta.json` from `config.policy_path`, reconstructs +`observation: dict[str, StreamField]` and `sync: SyncConfig`, and stores +them as instance state. `_build_live_obs` calls `resolve_field` +on each entry — same projection as training, no train/serve skew. ## ActionReplayer ```python -# dimos/learning/inference/action_replayer.py +# dimos/control/tasks/action_replayer_task.py @dataclass class ActionReplayerConfig: joint_names: list[str] - chunk_topic: str = "action_chunk" priority: int = 10 max_chunk_age_s: float = 0.5 hold_on_stall: bool = True @@ -158,27 +145,39 @@ class ActionReplayer(BaseControlTask): def on_action_chunk(self, msg: ActionChunk) -> None: ... ``` +## ControlCoordinator wiring + +`ControlCoordinator` gains a new port + dispatcher (mirrors how +`cartesian_command` / `twist_command` are routed): + +```python +class ControlCoordinator(Module): + # ... existing ports ...
+ action_chunk: In[ActionChunk] + + def _on_action_chunk(self, msg: ActionChunk) -> None: + for task in self._tasks: + if isinstance(task, ActionReplayer): + task.on_action_chunk(msg) +``` + --- ## Run ```bash -dimos run learning-infer-xarm7 --policy-path runs/act_pick_red +dimos run learning-infer-xarm7 ``` ---- - ## End-to-end ```bash -SPEC=datasets/pick_red.yaml - -dimos run learning-collect-quest-xarm7 --spec-path $SPEC --record-path data/pick_red.db -dimos run learning-dataprep --spec-path $SPEC -dimos run learning-train --dataset-path dataset/ --output-dir runs/act_pick_red -dimos run learning-infer-xarm7 --policy-path runs/act_pick_red +dimos run learning-collect-quest-xarm7 --record-path data/sessions/pick_red.db +dimos run learning-dataprep +dimos run learning-train +dimos run learning-infer-xarm7 ``` ``` -session.db ─► dataset/ ─► checkpoint/ ─► live policy +data/sessions/pick_red.db ─► data/datasets/pick_red/ ─► data/runs/act_pick_red/ ─► live policy ``` diff --git a/dimos/learning/specs/structure.md b/dimos/learning/specs/structure.md index 0f8ce55da8..b5577ed87f 100644 --- a/dimos/learning/specs/structure.md +++ b/dimos/learning/specs/structure.md @@ -1,141 +1,126 @@ # Folder Structure -The four spec docs in this directory are the source of truth. The code -tree below is the implementation layout — each file maps to a section in -one of the three stage docs. +Per-producer types: each Module owns its config + emitted message types +in its own file. No shared `config.py`, no umbrella class, no shared YAML. 
``` dimos/learning/ │ -├── specs/ # ← spec docs (you are here) -│ ├── structure.md # this file — folder layout -│ ├── datacollection.md # Stage 1 — recording + dataprep + inspect -│ ├── training.md # Stage 2 — TrainerModule -│ └── inference.md # Stage 3 — ChunkPolicyModule + ActionReplayer +├── specs/ +│ ├── structure.md +│ ├── datacollection.md # Stage 1 +│ ├── training.md # Stage 2 +│ └── inference.md # Stage 3 │ -├── __init__.py -├── config.py # DatasetConfig + sub-configs (pydantic BaseConfig) -├── dataset.example.yaml # annotated example spec +├── dataprep.py # types + pure helpers (no Module) +│ # - Episode, Sample +│ # - StreamField, SyncConfig, OutputConfig, EpisodeExtractor +│ # - resolve_field, compute_stats, +│ # extract_episodes, iter_episode_samples +├── dataprep_module.py # DataPrepModule(Config) only │ -├── dataprep.py # DataPrep façade + resolve_field staticmethod -│ # `python -m dimos.learning.dataprep build|inspect` -├── dataprep_module.py # DataPrepModule (wraps the subprocess for blueprint UX) +├── collection/ +│ ├── episode_monitor.py # EpisodeStatus + EpisodeMonitorModule(Config) +│ └── blueprint.py # learning_collect_quest_ │ -├── collection/ # ── Stage 1 / Phase A: live recording ── -│ ├── __init__.py -│ ├── episode_monitor.py # EpisodeStatus, EpisodeMonitorModule(Config) -│ └── blueprint.py # learning_collect_quest_{xarm7,xarm6,piper,dual} -│ -├── formats/ # ── dataset writers (DataPrep._get_writer dispatches) ── -│ ├── __init__.py +├── formats/ # dataset writers; each calls DataPrep.compute_stats │ ├── lerobot.py # LeRobot v2 (parquet + MP4 + meta/stats.json) -│ ├── hdf5.py # flat HDF5 -│ └── rlds.py # RLDS / TFDS +│ ├── hdf5.py +│ └── rlds.py │ -├── training/ # ── Stage 2: offline training ── -│ ├── __init__.py -│ ├── trainer_module.py # TrainProgress, TrainDone, TrainerModule(Config) -│ ├── train.py # subprocess CLI -│ # `python -m dimos.learning.training.train {bc|vla}` -│ ├── configs.py # bc / vla training configs -│ ├── split.py # 
train/val episode-level split -│ ├── stats.py # meta/stats.json computation (norm/unnorm) +├── training/ +│ ├── trainer_module.py # TrainerModule(Config); runs train_bc on a thread +│ ├── train.py # train_bc + train_val_split (lazy lerobot/torch) +│ ├── configs.py # BCConfig │ └── blueprint.py # learning_train │ -├── policy/ # ── policy backends (live + checkpoint loading) ── -│ ├── __init__.py -│ ├── base.py # ActionChunk pydantic + Policy Protocol -│ └── lerobot_policy.py # LeRobotPolicy.load → reads dimos_meta.json + stats.json +├── policy/ +│ ├── base.py # ActionChunk + Policy Protocol +│ └── lerobot_policy.py # LeRobotPolicy.load │ -└── inference/ # ── Stage 3: live policy serving ── - ├── __init__.py - ├── chunk_policy_module.py # ChunkPolicyModule(Config); slow Module @ 1–30 Hz - ├── obs_builder.py # ObsBuilder; calls DataPrep.resolve_field - ├── action_replayer.py # ActionReplayer (BaseControlTask, NOT a Module) - └── blueprint.py # learning_infer_{xarm7,xarm6,piper} - # + learning_infer_vla_{xarm7,...} +└── inference/ + ├── chunk_policy_module.py # ChunkPolicyModule(Config); ~30 Hz + │ # (obs construction is a private method; + │ # uses DataPrep.resolve_field) + └── blueprint.py # learning_infer_ ``` ---- +`ActionReplayer` is a `ControlTask`, not a learning Module — it lives +with the other coordinator tasks: -## Where each artifact is produced / consumed +``` +dimos/control/ +├── coordinator.py # adds action_chunk: In[ActionChunk] +│ # _on_action_chunk → ActionReplayer +└── tasks/ + ├── teleop_task.py + ├── ... 
+ └── action_replayer_task.py # NEW; imports ActionChunk from learning/policy/base.py +``` -| Artifact | Producer | Consumer | -|---|---|---| -| `dataset.yaml` | human (operator) | `DataPrep`, `ObsBuilder` | -| `session.db` | `RecordReplay` (transport hook, `--record-path`) | `DataPrep` | -| `dataset/` + stats | `dataprep build` → `formats/.py` | `lerobot.LeRobotDataset`, `train.py` | -| `checkpoint/` + meta | `train.py` | `LeRobotPolicy.load`, `ChunkPolicyModule` | -| `ActionChunk` (live) | `ChunkPolicyModule` (Module, LCM) | `ActionReplayer` (BaseControlTask) | -| `JointCommandOutput` | `ActionReplayer` (in 100 Hz tick loop) | `ControlCoordinator` → hardware | +Dependency: `control → learning.policy` (one-way). --- -## `DatasetConfig` as the single source of truth +## Per-producer typed contracts -`DatasetConfig` (loaded once from `dataset.yaml`) drives module configs -across stages — same instance, no drift between train and serve. +| Class | Lives in | Used by | +|---|---|---| +| `EpisodeStatus`, `EpisodeMonitorModuleConfig` | `learning/collection/episode_monitor.py` | `EpisodeMonitorModule`; `DataPrep` | +| `EpisodeExtractor`, `StreamField`, `SyncConfig`, `OutputConfig`, `Episode`, `Sample` | `learning/dataprep.py` | `DataPrepModule`, `ChunkPolicyModule`, format writers | +| `DataPrepModuleConfig` | `learning/dataprep_module.py` | `DataPrepModule` | +| `BCConfig` | `learning/training/configs.py` | `train_bc` | +| `TrainerModuleConfig` | `learning/training/trainer_module.py` | `TrainerModule` | +| `ActionChunk`, `Policy` Protocol | `learning/policy/base.py` | `ChunkPolicyModule`, `ActionReplayer`, `ControlCoordinator` | +| `ChunkPolicyModuleConfig` | `learning/inference/chunk_policy_module.py` | `ChunkPolicyModule` | +| `ActionReplayerConfig` | `control/tasks/action_replayer_task.py` | `ActionReplayer` | -```python -# Top-level, in each blueprint factory: -spec = DatasetConfig.from_file(spec_path) +--- -# Passed as a typed field on each module's config: 
-EpisodeMonitorModule.blueprint(spec=spec) # Stage 1: spec.episodes -DataPrepModule.blueprint(spec=spec) # Stage 1: full spec -ChunkPolicyModule.blueprint(spec=spec, ...) # Stage 3: spec.observation, spec.sync -``` +## Artifact flow -| Stage | Module | How it gets the spec | -|---|---|---| -| 1A | `EpisodeMonitorModule` | passed in via blueprint (`spec=spec`); reads `spec.episodes` for button maps | -| 1B | `DataPrepModule` | passed in via blueprint; reads full spec. **DataPrep snapshots the spec into `dataset/dataset.yaml`** so downstream stages don't need the YAML. | -| 2 | `TrainerModule` | reads `dataset/dataset.yaml` + LeRobot `info.json`; copies spec snapshot into `checkpoint/dimos_meta.json` | -| 3 | `ChunkPolicyModule` | reads `/dimos_meta.json` at `start()`; constructs `ObsBuilder` from the embedded spec. **No `--spec-path` flag needed at inference.** | +All generated artifacts live under `data/` (gitignored at repo root): -The operator only ever passes `--spec-path` for Recording and DataPrep -(stages where the spec is the input). After DataPrep, the spec rides -with the data. +``` +data/ +├── sessions/.db ← RecordReplay +├── datasets// ← DataPrepModule.build() +│ ├── data/ (parquet) +│ ├── videos/ (MP4) +│ └── meta/ +│ ├── info.json +│ ├── episodes.jsonl +│ ├── stats.json (DataPrep.compute_stats) +│ └── dimos_meta.json (DataPrepModuleConfig.model_dump()) +└── runs// ← train_bc + ├── *.safetensors + └── dimos_meta.json (dataset snapshot + policy fields) +``` -Same `resolve_field` is invoked from `DataPrep.iter_episode_samples` -(Stage 1B) and `ObsBuilder.build` (Stage 3). One source of truth → -no train/serve skew. +`dimos_meta.json` rides with the data: DataPrep writes it; training +copies it forward + adds policy fields; inference reads it at `start()`. +Operator never passes a spec path. --- -## What's deliberately not in this tree - -- **`RecordReplay`** — transport-layer hook (in `dimos/core/`), not a - `learning/` Module. 
Enabled by `--record-path` at the CLI; unaware of - what's recording. -- **`coordinator_action_replayer_`** — per-robot coordinator - blueprints that register the `ActionReplayer` task. These live next - to the rest of the per-robot wiring (likely - `dimos/robot//blueprints.py`), not under `learning/`. -- **A second `ControlCoordinator`** — the existing one is reused. We add - one task type (`ActionReplayer`), not a parallel control stack. -- **New transports** — v1 is LCM-only on the wire. -- **New LCM message types** — `ActionChunk` is local-only pydantic in v1. - Promote to a generated LCM type in v2 only if cross-language consumers - need it. +## Configuration ---- +All module config is set as kwargs in the blueprint. No CLI flags on +our modules. Framework CLI surface is `GlobalConfig` only (env vars, +`.env`, things like `--record-path`). -## Module / non-Module split (one rule) +--- -A class becomes a **Module** when it: -- has long-lived state worth `start()/stop()` lifecycle, **and** -- needs typed I/O ports across process boundaries. +## Module / non-Module split -Otherwise it stays a plain class or a `BaseControlTask`: +A class becomes a **Module** when it has long-lived state with +`start()/stop()` lifecycle **and** typed I/O ports. 
| Class | Type | Why | |---|---|---| -| `EpisodeMonitorModule` | Module | Long-lived; subscribes to buttons; publishes status | -| `DataPrepModule` | Module | Wraps subprocess; agent-callable via `@skill` | -| `TrainerModule` | Module | Wraps subprocess; long-running; agent-callable | -| `ChunkPolicyModule` | Module | Long-lived inference thread; latched In ports | -| `DataPrep` | plain class | Stateless façade over static helpers; no ports | -| `ObsBuilder` | plain class | Pure function over latched messages | -| `ActionReplayer` | `BaseControlTask` | Must run in coordinator's 100 Hz thread, not via transport | -| `RecordReplay` | transport hook | Captures every stream uniformly; not a Module | +| `EpisodeMonitorModule` | Module | Long-lived; subscribes to inputs; publishes status | +| `DataPrepModule` | Module | Long-running build job | +| `TrainerModule` | Module | Runs training on a daemon thread | +| `ChunkPolicyModule` | Module | Long-lived inference thread | +| `ActionReplayer` | `BaseControlTask` | Runs in coordinator's 100 Hz thread | +| `RecordReplay` | transport hook | Captures every stream uniformly | diff --git a/dimos/learning/specs/training.md b/dimos/learning/specs/training.md index e732507f41..fdd731cb5a 100644 --- a/dimos/learning/specs/training.md +++ b/dimos/learning/specs/training.md @@ -1,9 +1,10 @@ # Stage 2 — Training -Offline. Reads `dataset/`, writes `checkpoint/` + `dimos_meta.json`. +ACT only (BC). Reads a dataset directory under `data/datasets/`, writes a run directory under `data/runs/`. -`TrainerModule` is an RPC façade over a training subprocess. Metrics → -TensorBoard. Lifecycle → `get_status()`. +`TrainerModule` runs `train_bc(...)` on a daemon thread inside its own +worker. Lazy imports keep `lerobot` / `torch` / CUDA out of the worker +until `train()` is called. Metrics → TensorBoard. No `cancel()` in v1. --- @@ -12,18 +13,23 @@ TensorBoard. Lifecycle → `get_status()`.
```python # dimos/learning/training/blueprint.py learning_train = autoconnect( - TrainerModule.blueprint(auto_run=True), + TrainerModule.blueprint( + dataset_path="data/datasets/pick_red/", + output_dir="data/runs/act_pick_red", + auto_run=True, + ), ).transports({}) ``` ## Module ```python +# dimos/learning/training/trainer_module.py + class TrainerModuleConfig(ModuleConfig): dataset_path: str = "" output_dir: str = "" - config_kind: Literal["bc", "vla"] = "bc" - config_path: str | None = None + config_path: str | None = None # optional BCConfig YAML override auto_run: bool = False tensorboard_port: int = 6006 @@ -36,25 +42,37 @@ class TrainerModule(Module): self, dataset_path: str | None = None, output_dir: str | None = None, - config_kind: Literal["bc", "vla"] | None = None, config_overrides: dict[str, Any] | None = None, ) -> None: ... @rpc - def cancel(self) -> bool: ... - @rpc def get_status(self) -> dict[str, Any]: ... ``` +## Training entry point + +```python +# dimos/learning/training/train.py + +def train_bc( + dataset_path: str | Path, + cfg: BCConfig, + output_dir: str | Path, + config_overrides: dict[str, Any] | None = None, +) -> Path: + """Lazy-import lerobot, build Hydra-style argv from BCConfig, call + lerobot's training entry point, append `dimos_meta.json` to output_dir, + return the checkpoint path.""" +``` + +`BCConfig` (ACT hyperparams) lives in `training/configs.py`. +`train_val_split()` lives next to `train_bc` in `training/train.py`. + ## Run ```bash -dimos run learning-train \ - --dataset-path dataset/ \ - --output-dir runs/act_pick_red \ - --config-kind bc - -tensorboard --logdir runs/act_pick_red +dimos run learning-train +tensorboard --logdir data/runs/act_pick_red ``` -Artifact: `checkpoint/` = `*.safetensors` + `dimos_meta.json` (spec snapshot, -`joint_names`, `chunk_size`, `policy_type`, `expects_language`). 
+Artifact: `data/runs/act_pick_red/` = `*.safetensors` + `dimos_meta.json` +(dataset snapshot + `joint_names`, `chunk_size`, `policy_type`). diff --git a/dimos/learning/training/blueprint.py b/dimos/learning/training/blueprint.py index 7c7c97852b..cb170f45b4 100644 --- a/dimos/learning/training/blueprint.py +++ b/dimos/learning/training/blueprint.py @@ -12,92 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Training blueprints for the DimOS Learning Framework. - -Each blueprint composes TrainerModule + LearningMonitorModule. TrainerModule -handles both the dataset-build and the training subprocesses internally -(see its docstring) — there is no separate builder Module in v1. - -Variants: - learning_train_act - auto build (if needed) then train ACT (BC) - learning_train_vla - auto build (if needed) then finetune pi0/pi0.5 - learning_train_idle - module idle, agent drives via @rpc - -Defaults (spec_path, output_dir, ...) are placeholders; override at run -time via CLI flags or @rpc calls. Per-job overrides on the trigger payload -take precedence over module config. - -Usage: - dimos run learning-train-act \\ - --TrainerModule.config.spec_path dataset.yaml \\ - --TrainerModule.config.output_dir runs/act_pick_red -""" +"""ACT training blueprint. RPC-only surface (no streams).""" from __future__ import annotations from dimos.core.coordination.blueprints import autoconnect -from dimos.core.transport import LCMTransport -from dimos.learning.training.monitor_module import LearningMonitorModule -from dimos.learning.training.trainer_module import ( - TrainDone, - TrainerModule, - TrainProgress, -) - -# Topic names — shared across all variants so monitors / agents subscribe once. 
-_T_TRAIN_PROGRESS = "/learning/train/progress" -_T_TRAIN_DONE = "/learning/train/done" - - -# ── ACT (BC) — auto build (if needed) then train ───────────────────────────── - -learning_train_act = autoconnect( - TrainerModule.blueprint(config_kind="bc", auto_run=True), - LearningMonitorModule.blueprint(log_to_rerun=True), -).transports( - { - ("progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), - ("done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), - ("train_progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), - ("train_done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), - } -) - - -# ── pi0 / pi0.5 (VLA finetune) — auto build (if needed) then train ─────────── - -learning_train_vla = autoconnect( - TrainerModule.blueprint(config_kind="vla", auto_run=True), - LearningMonitorModule.blueprint(log_to_rerun=True), -).transports( - { - ("progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), - ("done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), - ("train_progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), - ("train_done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), - } -) - - -# ── Idle — TrainerModule waits for explicit @rpc / external trigger ────────── -# Agent-driven: agent skill calls TrainerModule.train(...) (or .build_only(...)) -# over RPC. Same module, no auto behavior. 
+from dimos.learning.training.trainer_module import TrainerModule -learning_train_idle = autoconnect( - TrainerModule.blueprint(auto_run=False), - LearningMonitorModule.blueprint(log_to_rerun=True), -).transports( - { - ("progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), - ("done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), - ("train_progress", TrainProgress): LCMTransport(_T_TRAIN_PROGRESS, TrainProgress), - ("train_done", TrainDone): LCMTransport(_T_TRAIN_DONE, TrainDone), - } -) +learning_train = autoconnect( + TrainerModule.blueprint( + dataset_path="data/datasets/pick_red/", + output_dir="data/runs/act_pick_red", + auto_run=True, + ), +).transports({}) -__all__ = [ - "learning_train_act", - "learning_train_idle", - "learning_train_vla", -] +__all__ = ["learning_train"] diff --git a/dimos/learning/training/configs.py b/dimos/learning/training/configs.py index 1630d925dd..a9428d7c29 100644 --- a/dimos/learning/training/configs.py +++ b/dimos/learning/training/configs.py @@ -12,16 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Trainer configs for v1. - -Two pydantic configs, one per training entry point: - - BCConfig -> consumed by train_bc (ACT, optionally Diffusion) - - VLAConfig -> consumed by finetune_vla (pi0, pi0.5) - -Both are translated into a `lerobot` training config inside the trainer; -fields here are the small, opinionated subset DimOS users actually need to -tune. Anything not exposed falls back to the lerobot default. -""" +"""ACT training config (v1). Fields are the opinionated subset DimOS exposes; +unset = lerobot default. 
Translated to Hydra-style argv inside `train_bc`.""" from __future__ import annotations @@ -31,15 +23,13 @@ class BCConfig(BaseModel): - """Behavior-cloning trainer config (v1: ACT, with Diffusion as a flag).""" - - policy_type: Literal["act", "diffusion"] = "act" + policy_type: Literal["act"] = "act" # Action chunking - chunk_size: int = 50 # number of future actions predicted per inference call - n_obs_steps: int = 1 # observation history length passed to the policy + chunk_size: int = 50 # future actions per inference call + n_obs_steps: int = 1 # obs history length - # ACT model arch (ignored for Diffusion) + # ACT model arch hidden_dim: int = 512 n_layers: int = 4 n_heads: int = 8 @@ -62,33 +52,3 @@ class BCConfig(BaseModel): eval_every: int = 5_000 seed: int = 0 device: str = "cuda" - - -class VLAConfig(BaseModel): - """VLA finetune config (v1: pi0, pi0.5).""" - - policy_type: Literal["pi0", "pi0_5"] = "pi0_5" - - # Pretrained checkpoint — HF hub id or local path - pretrained_path: str - - # Finetune mode - finetune_mode: Literal["full", "lora"] = "lora" - lora_rank: int = 16 - freeze_vision: bool = True - freeze_language: bool = True - - # Action chunking — pi0/pi0.5 default - chunk_size: int = 50 - - # Optim - steps: int = 30_000 - batch_size: int = 4 - lr: float = 5e-5 - weight_decay: float = 1e-4 - - # Eval / checkpointing - save_every: int = 5_000 - eval_every: int = 2_500 - seed: int = 0 - device: str = "cuda" diff --git a/dimos/learning/training/monitor_module.py b/dimos/learning/training/monitor_module.py deleted file mode 100644 index e771933df6..0000000000 --- a/dimos/learning/training/monitor_module.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Visualize / log training progress. - -Subscribes to the unified `TrainProgress` + `TrainDone` streams from -`TrainerModule` (which covers both build and train phases via `phase` field); -logs to: - - rerun (if the rerun bridge is available — already a DimOS dep) - - JSONL file (structured, post-hoc analysis) - - stdout (always, terse summary line per event) - -Optional in any blueprint. Sits passively on the bus so the same training -session can have multiple monitors (one in dev, one writing to a server). -""" - -from __future__ import annotations - -from typing import Any - -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import In -from dimos.learning.training.trainer_module import TrainDone, TrainProgress - - -class LearningMonitorModuleConfig(ModuleConfig): - """Where to send progress events. - - Attributes: - log_to_rerun: forward every event to the rerun bridge if importable. - log_to_stdout: print one terse summary line per event. - jsonl_path: if set, append JSON-per-line to this file. - train_loss_smoothing: EMA smoothing factor for the rerun loss curve. - """ - - log_to_rerun: bool = True - log_to_stdout: bool = True - jsonl_path: str | None = None - train_loss_smoothing: float = 0.9 - - -class LearningMonitorModule(Module): - """Pure subscriber. 
Owns no work, just visualization fan-out.""" - - config: LearningMonitorModuleConfig - - train_progress: In[TrainProgress] - train_done: In[TrainDone] - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._jsonl_handle: Any = None # opened in start() - self._train_loss_ema: float | None = None - - @rpc - def start(self) -> None: - """Open the JSONL file (if configured), subscribe to both ports.""" - raise NotImplementedError - - @rpc - def stop(self) -> None: - """Flush + close the JSONL file; super().stop().""" - raise NotImplementedError - - # ── handlers (called from port subscriptions) ──────────────────────────── - - def _on_train_progress(self, msg: TrainProgress) -> None: - """Forward to enabled sinks. Routes by `msg.phase`: - - phase == "build": log dataset progress (episodes, samples) - - phase in {"train","eval"}: log loss curves (with EMA smoothing for rerun) - - other phases: log message line only. - """ - raise NotImplementedError - - def _on_train_done(self, msg: TrainDone) -> None: - """Final summary line; close JSONL section.""" - raise NotImplementedError diff --git a/dimos/learning/training/split.py b/dimos/learning/training/split.py deleted file mode 100644 index b4a6f27b0f..0000000000 --- a/dimos/learning/training/split.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Episode-level train/val split. 
- -LeRobot v2 supports filtering by episode index at training time, so we don't -materialize two datasets. We compute the partition once and pass the index -lists to the trainer. - -Resolution order (first non-None wins): - 1. `cfg.val_episode_ids` — explicit whitelist - 2. `cfg.val_ratio` — deterministic random split via cfg.val_split_seed - 3. neither set — empty val (everything is train) -""" - -from __future__ import annotations - -from dimos.learning.spec import Episode, FilterConfig - - -def train_val_split( - episodes: list[Episode], - cfg: FilterConfig | None, -) -> tuple[list[int], list[int]]: - """Partition `episodes` (already filtered) into (train_ids, val_ids). - - Returns lists of episode *indices* into `episodes`, not Episode objects. - LeRobot consumes index lists. Determinism is via `cfg.val_split_seed`. - """ - raise NotImplementedError diff --git a/dimos/learning/training/stats.py b/dimos/learning/training/stats.py deleted file mode 100644 index 23ca337c24..0000000000 --- a/dimos/learning/training/stats.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Per-feature dataset statistics — written once, read by trainers + inference. - -LeRobot expects `meta/stats.json` next to the dataset with mean/std/min/max/q01/q99 -for every observation and action key. 
The same dict is consumed by: - - the trainer (normalize inputs / unnormalize predicted actions) - - the inference `ObsBuilder` and `ActionReplayer` (same normalization, live) - -`Stats` is a streaming Welford accumulator: feed it `Sample` instances one at a -time and call `.result()` at the end. Used both inside `formats.lerobot.write` -(so the dataset is self-describing on first build) and as a standalone pass via -`compute_stats(dp)` if a dataset on disk needs stats recomputed. - -Image stats are computed on a subsample (every Nth frame) to bound cost. -""" - -from __future__ import annotations - -from collections.abc import Iterator -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from dimos.learning.spec import Sample - -if TYPE_CHECKING: - from dimos.learning.dataprep import DataPrep - - -class Stats: - """Streaming Welford accumulator for per-feature mean/std/min/max/q01/q99. - - Update with one Sample at a time; call `.result()` at the end to get the - serializable dict written to `meta/stats.json`. - - Quantiles (q01, q99) are computed from a reservoir sample of size - `quantile_reservoir` per feature — bounded memory for unbounded streams. - """ - - def __init__( - self, - image_subsample: int = 10, - quantile_reservoir: int = 10_000, - seed: int = 0, - ) -> None: - """Configure cost knobs. - - Args: - image_subsample: include every Nth image frame in stats; N=1 for full - accuracy, larger N for faster builds on long sessions. - quantile_reservoir: reservoir size per feature for q01/q99. - seed: for the reservoir sampler. - """ - raise NotImplementedError - - def update(self, sample: Sample) -> None: - """Fold one Sample into the running statistics for every obs/action key.""" - raise NotImplementedError - - def result(self) -> dict[str, Any]: - """Return the LeRobot-compatible stats dict. - - Schema: - { - "observation.": {"mean": [...], "std": [...], "min": [...], - "max": [...], "q01": [...], "q99": [...]}, - "action.": {... 
same keys ...}, - ... - } - """ - raise NotImplementedError - - def save(self, path: str | Path) -> None: - """Write `result()` to `path` as JSON.""" - raise NotImplementedError - - @classmethod - def load(cls, path: str | Path) -> dict[str, Any]: - """Read a stats JSON from disk. Returns the raw dict, not a Stats instance.""" - raise NotImplementedError - - -def compute_stats(samples: Iterator[Sample], **kw: Any) -> dict[str, Any]: - """One-shot helper: drain `samples`, return the stats dict. - - Equivalent to `s = Stats(**kw); for x in samples: s.update(x); return s.result()`. - """ - raise NotImplementedError - - -def compute_stats_from_prep(dp: DataPrep, **kw: Any) -> dict[str, Any]: - """One-shot helper that pulls samples from a DataPrep instance. - - Convenience for "I have a built dataset on disk and need to recompute stats." - """ - raise NotImplementedError diff --git a/dimos/learning/training/train.py b/dimos/learning/training/train.py index 8d0545288a..57a46d1f91 100644 --- a/dimos/learning/training/train.py +++ b/dimos/learning/training/train.py @@ -12,118 +12,59 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Training entry points for v1. - -Two functions, both thin wrappers around `lerobot`: - - train_bc(spec, cfg, output_dir) -> ACT (or Diffusion) BC training - - finetune_vla(spec, cfg, output_dir) -> pi0 / pi0.5 finetune - -Both: - 1. Materialize the dataset via `DataPrep.build()` if `spec.output.path` - doesn't already exist (idempotent). - 2. Open the materialized dataset as a `lerobot.LeRobotDataset`. - 3. Translate the DimOS config to a lerobot config, build the policy. - 4. Compute / load `meta/stats.json`. - 5. Compute the train/val split. - 6. Call lerobot's training loop. - 7. Write the checkpoint + a sidecar `dimos_meta.json` so inference can - reconstruct everything from `output_dir` alone. 
- -We do NOT reimplement the training loop, optimizer schedule, normalization, -action chunking, language tokenization, or checkpoint format. Riding on -lerobot keeps this file small and means a `pi0.5` upgrade is a config bump. +"""ACT training entry point. Called directly by TrainerModule. + +`train_bc` lazy-imports lerobot/torch, builds Hydra-style argv from +BCConfig, calls lerobot's trainer in-process, appends `dimos_meta.json` +to output_dir, returns the checkpoint path. + +Stats live at `/meta/stats.json` (written by DataPrep). Training +reads them via lerobot's loader; never recomputes. """ from __future__ import annotations from pathlib import Path +from typing import Any -from dimos.learning.spec import DatasetSpec -from dimos.learning.training.configs import BCConfig, VLAConfig +from dimos.learning.dataprep import Episode +from dimos.learning.training.configs import BCConfig -# Sidecar written next to the lerobot checkpoint so inference can recover -# the spec + dataset path that produced this policy. DIMOS_META_FILENAME = "dimos_meta.json" -def train_bc(spec: DatasetSpec, cfg: BCConfig, output_dir: str | Path) -> Path: - """Train an ACT (or Diffusion) BC policy on `spec`. - - Returns the path to the final checkpoint directory. The returned dir - contains the lerobot checkpoint + `dimos_meta.json` linking back to the - spec and dataset used. - """ - raise NotImplementedError - - -def finetune_vla(spec: DatasetSpec, cfg: VLAConfig, output_dir: str | Path) -> Path: - """Finetune a pretrained pi0 / pi0.5 on `spec`. - - Loads `cfg.pretrained_path` (HF hub id or local), wraps it for the - requested `finetune_mode` (full or LoRA), runs lerobot's training loop, - and writes the resulting checkpoint to `output_dir`. - - Returns the checkpoint directory path. - """ - raise NotImplementedError - - -# ───────────────────────────────────────────────────────────────────────────── -# Internals — translate DimOS configs into the lerobot training entry point. 
-# ───────────────────────────────────────────────────────────────────────────── - - -def _ensure_dataset(spec: DatasetSpec) -> Path: - """If `spec.output.path` doesn't exist on disk yet, run `DataPrep.build()`. - - Returns the resolved dataset path. Raises if `spec.output` is None - (training requires a materialized dataset). - """ - raise NotImplementedError - - -def _build_lerobot_config_bc(spec: DatasetSpec, cfg: BCConfig, dataset_path: Path) -> object: - """Translate a DimOS BCConfig + spec into a lerobot training config. - - Returns the lerobot config object opaque to the rest of this file — - everything lerobot-specific stays inside the implementation. - """ +def train_bc( + dataset_path: str | Path, + cfg: BCConfig, + output_dir: str | Path, + config_overrides: dict[str, Any] | None = None, +) -> Path: + """Train ACT on a prepared dataset. Returns checkpoint dir.""" raise NotImplementedError -def _build_lerobot_config_vla(spec: DatasetSpec, cfg: VLAConfig, dataset_path: Path) -> object: - """Translate a DimOS VLAConfig + spec into a lerobot training config.""" +def _build_lerobot_argv(cfg: BCConfig, dataset_path: Path, output_dir: Path) -> list[str]: + """Translate BCConfig → Hydra-style CLI args for lerobot's trainer.""" raise NotImplementedError -def _write_dimos_meta(output_dir: Path, spec: DatasetSpec, dataset_path: Path) -> None: - """Write `dimos_meta.json` next to the checkpoint. - - Schema: - { - "dimos_version": "...", - "spec": , - "dataset_path": "", - "lerobot_version": "..." - } - Used by inference to rehydrate the spec without a separate yaml. 
- """ +def _write_dimos_meta(output_dir: Path, dataset_path: Path) -> None: + """Read /dimos_meta.json, add policy fields + (joint_names, chunk_size, policy_type), write to /dimos_meta.json.""" raise NotImplementedError -# ───────────────────────────────────────────────────────────────────────────── -# CLI -# ───────────────────────────────────────────────────────────────────────────── - +def train_val_split( + episodes: list[Episode], + val_episode_ids: list[int] | None = None, + val_ratio: float | None = None, + seed: int = 0, +) -> tuple[list[int], list[int]]: + """Partition `episodes` indices into (train_ids, val_ids). -def main() -> None: - """CLI entrypoint: - - python -m dimos.learning.training.train bc --output [...] - python -m dimos.learning.training.train vla --output --pretrained [...] + Resolution order (first non-None wins): + 1. `val_episode_ids` — explicit whitelist + 2. `val_ratio` — deterministic random split via `seed` + 3. both None — empty val (everything is train) """ raise NotImplementedError - - -if __name__ == "__main__": - main() diff --git a/dimos/learning/training/trainer_module.py b/dimos/learning/training/trainer_module.py index 0d49509097..d828dbeaf9 100644 --- a/dimos/learning/training/trainer_module.py +++ b/dimos/learning/training/trainer_module.py @@ -12,251 +12,74 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""DimOS Module wrapper around the v1 training pipeline. +"""ACT training Module (v1, inline). -A single training job is two subprocesses run in sequence: - 1. `python -m dimos.learning.dataprep build` (skipped if output exists) - 2. `python -m dimos.learning.training.train ...` - -`TrainerModule` runs both, parses their progress lines, and republishes them -under one unified `TrainProgress` stream with a `phase` field. 
There is no -separate builder Module — building is always a precursor to training in v1, -so the wiring tax of two Modules + a chain port wasn't worth it. - -Why subprocess: keeps `lerobot`, `torch`, CUDA out of the runtime's import -graph. Process isolation also means a CUDA OOM doesn't poison the runtime. - -Wiring patterns: - - Default blueprint: `auto_run=True` -> module fires on start() - - Agent skill: agent calls `@rpc train(...)` directly - - Build-only (rare): agent calls `@rpc build_only(spec_path)` - - External trigger: publish `TrainTrigger` on the trigger port +Runs `train_bc` on a daemon thread inside its own worker. No subprocess. +Lazy imports keep `lerobot` / `torch` / CUDA out of the worker until +`train()` is called. Metrics → TensorBoard. No cancel() in v1. """ from __future__ import annotations -from pathlib import Path -import subprocess import threading -from typing import Any, Literal - -from pydantic import BaseModel +from typing import Any from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import In, Out - -# ───────────────────────────────────────────────────────────────────────────── -# Message types -# ───────────────────────────────────────────────────────────────────────────── - - -class TrainTrigger(BaseModel): - """Start a training job. Empty trigger uses module config defaults.""" - - spec_path: str | None = None - output_dir: str | None = None - config_kind: Literal["bc", "vla"] | None = None - config_overrides: dict[str, Any] = {} # merged onto BCConfig/VLAConfig - skip_build: bool = False # set when caller knows the dataset is already built - job_id: str | None = None - - -class TrainProgress(BaseModel): - """Unified progress event covering both build and train phases. - - `phase` indicates which subprocess the event came from. For phase=="build" - the train-specific fields (loss, val_loss, step counts) are zero/None. - For phase=="train" the build-specific fields are zero. 
- """ - - job_id: str - phase: Literal["build", "load", "train", "eval", "save", "done", "failed"] - message: str = "" - - # Build-phase fields (meaningful only when phase == "build") - samples_written: int = 0 - current_episode: int = 0 - total_episodes: int = 0 - - # Train-phase fields (meaningful only when phase in {"load","train","eval","save"}) - step: int = 0 - total_steps: int = 0 - loss: float | None = None - val_loss: float | None = None - eta_s: float | None = None - - -class TrainDone(BaseModel): - """Terminal event with the final checkpoint dir or an error.""" - - job_id: str - success: bool - dataset_path: Path | None = None # the (possibly newly-built) dataset - checkpoint_dir: Path | None = None # None on failure - error: str | None = None - - -# ───────────────────────────────────────────────────────────────────────────── -# Module -# ───────────────────────────────────────────────────────────────────────────── class TrainerModuleConfig(ModuleConfig): - """Trainer module config. - - Attributes: - spec_path: default spec path (used for both build and train). - output_dir: default checkpoint output directory. - config_kind: "bc" (ACT/Diffusion) or "vla" (pi0/pi0.5). - config_path: optional BCConfig/VLAConfig YAML override. - python_executable: subprocess python; "" = current sys.executable. - skip_build_if_exists: if the dataset path already exists on disk, - skip the build phase. Default True (idempotent). - auto_run: if True, start a job on `start()`. - max_concurrent: cap on simultaneous jobs. v1 uses 1. 
- """ - - spec_path: str = "" + dataset_path: str = "" output_dir: str = "" - config_kind: Literal["bc", "vla"] = "bc" - config_path: str | None = None - python_executable: str = "" - skip_build_if_exists: bool = True + config_path: str | None = None # optional BCConfig YAML override auto_run: bool = False - max_concurrent: int = 1 + tensorboard_port: int = 6006 class TrainerModule(Module): - """Spawns dataprep build (if needed) then train; reports unified progress.""" - config: TrainerModuleConfig - trigger: In[TrainTrigger] - - progress: Out[TrainProgress] - done: Out[TrainDone] - def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - # job_id -> (build_proc, train_proc, watcher_thread). Each proc may be None. - self._jobs: dict[str, dict[str, Any]] = {} - self._jobs_lock = threading.Lock() - self._next_job_id = 0 - - # ── lifecycle ──────────────────────────────────────────────────────────── + self._thread: threading.Thread | None = None + self._lock = threading.Lock() + self._status: dict[str, Any] = { + "state": "idle", # idle | running | succeeded | failed + "checkpoint_dir": None, + "error": None, + } @rpc def start(self) -> None: - """Subscribe to `trigger`. If `auto_run`, kick off one training job.""" raise NotImplementedError @rpc def stop(self) -> None: - """Cancel all in-flight jobs, then super().stop().""" raise NotImplementedError - # ── agent / external surface ───────────────────────────────────────────── - @rpc def train( self, - spec_path: str | None = None, + dataset_path: str | None = None, output_dir: str | None = None, - config_kind: Literal["bc", "vla"] | None = None, config_overrides: dict[str, Any] | None = None, - skip_build: bool = False, - ) -> str: - """Start a build-then-train job. Returns job_id. - - All arguments override `config` for this job only. If `skip_build` or - `config.skip_build_if_exists` and the dataset is on disk, the build - phase is skipped. 
- """ - raise NotImplementedError - - @rpc - def build_only(self, spec_path: str | None = None) -> str: - """Run only the dataset-build subprocess; do not train. - - Convenience for the rare standalone case (CI dataset bake, debugging - a new spec). Returns job_id; emits TrainProgress with phase=="build" - events and a TrainDone with checkpoint_dir=None on completion. - """ - raise NotImplementedError - - @rpc - def cancel(self, job_id: str) -> bool: - """SIGTERM the active subprocess (build or train); True if cancelled.""" - raise NotImplementedError - - @rpc - def list_jobs(self) -> list[str]: - """Return active job ids.""" + ) -> None: + """Spawn a daemon thread running train_bc; returns immediately. + Raises if a run is already in progress.""" raise NotImplementedError @rpc - def list_checkpoints(self, output_dir: str | None = None) -> list[str]: - """Scan `output_dir` (defaults to config.output_dir) and return paths - to checkpoint subdirectories. Useful for agent flows like 'train then - deploy the latest checkpoint'.""" + def get_status(self) -> dict[str, Any]: raise NotImplementedError - # ── internals ──────────────────────────────────────────────────────────── - - def _on_trigger(self, msg: TrainTrigger) -> None: - """Port handler — calls `self.train(...)`.""" - raise NotImplementedError - - def _run_job( + def _run_training( self, - job_id: str, - spec_path: str, + dataset_path: str, output_dir: str, - config_kind: Literal["bc", "vla"], - config_overrides: dict[str, Any], - skip_build: bool, - train: bool, + config_overrides: dict[str, Any] | None, ) -> None: - """Background thread driving one job through its phases. - - Sequence: - 1. Resolve dataset path from spec. - 2. If `skip_build` is False and dataset doesn't exist (or - `skip_build_if_exists` is False), spawn `dataprep build` and - stream its progress as phase=="build". - 3. If `train` is True, spawn `train` and stream progress as - phase in {"load","train","eval","save"}. - 4. 
Emit terminal TrainDone. - """ - raise NotImplementedError - - def _spawn_build(self, spec_path: str, job_id: str) -> subprocess.Popen[str]: - """Build argv for `python -m dimos.learning.dataprep build --progress-json`.""" - raise NotImplementedError - - def _spawn_train( - self, - spec_path: str, - output_dir: str, - config_kind: Literal["bc", "vla"], - config_overrides: dict[str, Any], - job_id: str, - ) -> subprocess.Popen[str]: - """Build argv for `python -m dimos.learning.training.train --progress-json`.""" + """Thread target. Lazy-imports train_bc + BCConfig; updates _status.""" raise NotImplementedError - def _stream_build_progress(self, job_id: str, proc: subprocess.Popen[str]) -> int: - """Read stdout JSON-per-line from the build subprocess; publish each - as TrainProgress(phase="build"). Returns subprocess exit code.""" - raise NotImplementedError - - def _stream_train_progress(self, job_id: str, proc: subprocess.Popen[str]) -> int: - """Read stdout JSON-per-line from the train subprocess; publish each - as TrainProgress(phase in {"load","train","eval","save"}). 
Returns exit code.""" - raise NotImplementedError - def _allocate_job_id(self) -> str: - jid = f"train-{self._next_job_id}" - self._next_job_id += 1 - return jid +__all__ = ["TrainerModule", "TrainerModuleConfig"] From 0cdbe557c70ca9b93aa3afa63c8732a49aff3efc Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Tue, 5 May 2026 15:48:24 -0700 Subject: [PATCH 05/45] tested commit slop --- dimos/learning/collection/episode_monitor.py | 96 ++++- dimos/learning/dataprep.py | 228 ++++++++++- dimos/learning/dataprep_blueprint.py | 86 ++++ dimos/learning/dataprep_module.py | 210 +++++++++- dimos/learning/formats/_stats.py | 116 ++++++ dimos/learning/formats/hdf5.py | 122 +++++- dimos/learning/formats/lerobot.py | 277 ++++++++++++- dimos/learning/formats/rlds.py | 184 ++++++++- dimos/learning/inference/blueprint.py | 42 +- .../learning/inference/chunk_policy_module.py | 177 +++++++-- dimos/learning/policy/lerobot_policy.py | 199 +++++++++- dimos/learning/specs/spec_v2.md | 375 ++++++++++++++++++ dimos/learning/training/blueprint.py | 18 +- dimos/learning/training/train.py | 209 +++++++++- dimos/learning/training/trainer_module.py | 152 ++++++- dimos/memory2/codecs/jpeg.py | 5 + dimos/robot/all_blueprints.py | 12 + 17 files changed, 2352 insertions(+), 156 deletions(-) create mode 100644 dimos/learning/dataprep_blueprint.py create mode 100644 dimos/learning/formats/_stats.py create mode 100644 dimos/learning/specs/spec_v2.md diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index 437c844336..3f9edd1b83 100644 --- a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -15,13 +15,16 @@ """Single point of teleop-input → EpisodeStatus translation. Watches buttons / keyboard, runs the start/save/discard state machine, -publishes EpisodeStatus on every transition. 
RecordReplay captures that -stream into session.db; DataPrep reads only the recorded EpisodeStatus -events offline — never raw buttons or keypresses. +publishes EpisodeStatus on every transition. RecordReplay (or whatever +records the bus) captures that stream into session.db; DataPrep reads +only the recorded EpisodeStatus events offline — never raw buttons or +keypresses. """ from __future__ import annotations +import threading +import time from typing import Any, Literal from pydantic import BaseModel @@ -87,37 +90,96 @@ def __init__(self, **kwargs: Any) -> None: self._discarded: int = 0 self._current_start_ts: float | None = None self._prev_bits: dict[str, bool] = {} # rising-edge detection for buttons + self._lock = threading.Lock() @rpc def start(self) -> None: - raise NotImplementedError + super().start() + self.buttons.subscribe(self._on_buttons) + self.keyboard.subscribe(self._on_keyboard) + # Emit an initial idle status so subscribers (and recorders) have a + # known starting point in the timeline. 
+ self._publish("init") @rpc def stop(self) -> None: - raise NotImplementedError + super().stop() @rpc def reset_counters(self) -> EpisodeStatus: - raise NotImplementedError + with self._lock: + self._state = "idle" + self._saved = 0 + self._discarded = 0 + self._current_start_ts = None + self._prev_bits = {} + return self._publish("init") @rpc def get_status(self) -> EpisodeStatus: - raise NotImplementedError + with self._lock: + return EpisodeStatus( + state=self._state, + episodes_saved=self._saved, + episodes_discarded=self._discarded, + current_episode_start_ts=self._current_start_ts, + last_event="init", + task_label=self.config.default_task_label, + ) + + # ── port handlers ──────────────────────────────────────────────────────── def _on_buttons(self, msg: Buttons) -> None: """Rising-edge detect against `config.button_map`; advance state machine.""" - raise NotImplementedError + ts = time.time() + for event_name, alias_or_attr in self.config.button_map.items(): + attr = BUTTON_ALIASES.get(alias_or_attr, alias_or_attr) + try: + pressed = bool(getattr(msg, attr)) + except AttributeError: + continue + prev = self._prev_bits.get(attr, False) + self._prev_bits[attr] = pressed + if pressed and not prev: # rising edge + self._transition(event_name, ts) def _on_keyboard(self, msg: KeyPress) -> None: """Match `msg.key` against `config.keyboard_map`; advance state machine.""" - raise NotImplementedError + for event_name, key in self.config.keyboard_map.items(): + if msg.key == key: + self._transition(event_name, msg.ts) + break def _transition(self, event: Literal["start", "save", "discard"], ts: float) -> None: - """Apply the state-machine transition and publish EpisodeStatus. - - IDLE --start--> RECORDING - RECORDING --save--> IDLE (commit, saved += 1) - RECORDING --discard--> IDLE (drop, discarded += 1) - RECORDING --start--> RECORDING (auto-commit prev, begin new) - """ - raise NotImplementedError + """State-machine transition. 
Publishes EpisodeStatus on every change.""" + with self._lock: + if event == "start": + # Auto-commit any in-progress episode (matches DataPrep extractor). + if self._state == "recording" and self._current_start_ts is not None: + self._saved += 1 + self._state = "recording" + self._current_start_ts = ts + elif event == "save": + if self._state == "recording": + self._saved += 1 + self._state = "idle" + self._current_start_ts = None + elif event == "discard": + if self._state == "recording": + self._discarded += 1 + self._state = "idle" + self._current_start_ts = None + self._publish(event) + + def _publish(self, last_event: Literal["start", "save", "discard", "init"]) -> EpisodeStatus: + with self._lock: + status = EpisodeStatus( + state=self._state, + episodes_saved=self._saved, + episodes_discarded=self._discarded, + current_episode_start_ts=self._current_start_ts, + last_event=last_event, + task_label=self.config.default_task_label, + ) + self.status.publish(status) + return status diff --git a/dimos/learning/dataprep.py b/dimos/learning/dataprep.py index 9ba4dbb0bf..246eafaaaf 100644 --- a/dimos/learning/dataprep.py +++ b/dimos/learning/dataprep.py @@ -25,6 +25,7 @@ from __future__ import annotations +import bisect from collections.abc import Callable, Iterator from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -102,14 +103,116 @@ class Sample(BaseModel): def resolve_field(msg: Any, ref: StreamField) -> np.ndarray: - """Project `msg` through `ref` (attribute access). Single source of - truth for obs/action construction across train and live.""" - raise NotImplementedError + """Project `msg` through `ref` (attribute access) and coerce to ndarray. + + Single source of truth for obs/action construction across train and + live inference. Behavior: + - `ref.field is None`: best-effort coerce the whole message + (Image → `.data`, ndarray pass-through, list/tuple → asarray). 
+ - `ref.field` set: `getattr(msg, ref.field)` (or `msg[ref.field]` + for dict payloads) then coerce. + """ + if ref.field is None: + value: Any = msg + elif isinstance(msg, dict): + value = msg[ref.field] + else: + value = getattr(msg, ref.field) + + if isinstance(value, np.ndarray): + return value + if hasattr(value, "data") and isinstance(value.data, np.ndarray): + # e.g. Image → use its underlying ndarray + return value.data + return np.asarray(value) def extract_episodes(store: SqliteStore, cfg: EpisodeExtractor) -> list[Episode]: - """Walk recorded EpisodeStatus events (or ranges/whole_session) into Episodes.""" - raise NotImplementedError + """Walk recorded events into Episodes per the configured strategy. + + EPISODE_STATUS: scan `cfg.status_stream` for state transitions emitted + by `EpisodeMonitorModule`. State machine (mirrors the live monitor): + ev.last_event == "start": begin (auto-commit any prior pending) + ev.last_event == "save": commit (success=True) + ev.last_event == "discard": drop (success=False) + end of stream with pending: dropped (matches live spec) + + RANGES: emit one Episode per (start, end) tuple in `cfg.ranges`. + + WHOLE_SESSION: one Episode covering the full time range of every stream. + """ + if cfg.extractor == "ranges": + if not cfg.ranges: + return [] + return [ + Episode(id=f"ep_{i:06d}", start_ts=t0, end_ts=t1) + for i, (t0, t1) in enumerate(cfg.ranges) + ] + + if cfg.extractor == "whole_session": + # Span every stream's time range. 
+ names = store.list_streams() + if not names: + return [] + starts: list[float] = [] + ends: list[float] = [] + for name in names: + try: + stream = store.stream(name) + t0, t1 = stream.get_time_range() + starts.append(t0) + ends.append(t1) + except Exception: + continue + if not starts: + return [] + return [Episode(id="ep_000000", start_ts=min(starts), end_ts=max(ends))] + + # episode_status (default) + status_stream = store.stream(cfg.status_stream) + events = list(status_stream) # observations in storage order + + episodes: list[Episode] = [] + pending_start_ts: float | None = None + pending_label: str | None = None + counter = 0 + + def _commit(end_ts: float, success: bool, label: str | None) -> None: + nonlocal counter, pending_start_ts, pending_label + if pending_start_ts is None: + return + episodes.append( + Episode( + id=f"ep_{counter:06d}", + start_ts=pending_start_ts, + end_ts=end_ts, + task_label=label, + success=success, + ) + ) + counter += 1 + pending_start_ts = None + pending_label = None + + for obs in events: + ev = obs.data + last_event = getattr(ev, "last_event", None) + ts = obs.ts + label = getattr(ev, "task_label", None) + + if last_event == "start": + # Auto-commit any prior pending episode (success=True per state-machine spec). + _commit(ts, success=True, label=pending_label) + pending_start_ts = getattr(ev, "current_episode_start_ts", None) or ts + pending_label = label + elif last_event == "save": + _commit(ts, success=True, label=pending_label or label) + elif last_event == "discard": + _commit(ts, success=False, label=pending_label or label) + # "init" and unknown events are no-ops. + + # Anything still pending at end-of-stream is dropped (state-machine spec). 
+ return episodes def iter_episode_samples( @@ -117,9 +220,98 @@ def iter_episode_samples( episode: Episode, streams: dict[str, StreamField], # observation ∪ action sync: SyncConfig, + obs_keys: set[str] | None = None, + action_keys: set[str] | None = None, ) -> Iterator[Sample]: - """Yield synced (obs, action) Samples for one episode.""" - raise NotImplementedError + """Yield synced (obs, action) Samples for one episode. + + Walks the anchor stream at `sync.rate_hz` between `episode.start_ts` and + `episode.end_ts`. For each anchor timestamp, picks the nearest sample + from each configured stream within `sync.tolerance_ms`. Skips frames + where any required stream lacks a nearby sample. + + `obs_keys` / `action_keys` partition `streams` into observation vs + action. If omitted, every key is treated as observation (used by + callers that only need raw aligned data). + """ + if sync.anchor not in streams: + raise ValueError(f"sync.anchor {sync.anchor!r} not in streams: {sorted(streams)}") + + obs_keys = obs_keys if obs_keys is not None else set(streams) + action_keys = action_keys if action_keys is not None else set() + + tolerance_s = sync.tolerance_ms / 1000.0 + + # Materialize each stream's (timestamps, messages) once per episode. + cached: dict[str, tuple[list[float], list[Any]]] = {} + for key, ref in streams.items(): + sub = store.stream(ref.stream).time_range(episode.start_ts, episode.end_ts) + ts_list: list[float] = [] + msg_list: list[Any] = [] + for obs in sub: + ts_list.append(obs.ts) + msg_list.append(obs.data) + # Keep them sorted by time — query order is usually already sorted, but be safe. 
+ if ts_list and any(ts_list[i] > ts_list[i + 1] for i in range(len(ts_list) - 1)): + order = sorted(range(len(ts_list)), key=ts_list.__getitem__) + ts_list = [ts_list[i] for i in order] + msg_list = [msg_list[i] for i in order] + cached[key] = (ts_list, msg_list) + + anchor_ts, _ = cached[sync.anchor] + if not anchor_ts: + return + + # Build the sequence of target timestamps for this episode. + if sync.rate_hz > 0: + period = 1.0 / sync.rate_hz + targets: list[float] = [] + t = anchor_ts[0] + end = anchor_ts[-1] + while t <= end: + targets.append(t) + t += period + else: + targets = list(anchor_ts) + + def _nearest(key: str, t: float) -> Any | None: + ts_list, msg_list = cached[key] + if not ts_list: + return None + i = bisect.bisect_left(ts_list, t) + candidates: list[int] = [] + if i < len(ts_list): + candidates.append(i) + if i > 0: + candidates.append(i - 1) + best: int | None = None + best_dt = float("inf") + for c in candidates: + dt = abs(ts_list[c] - t) + if dt < best_dt: + best = c + best_dt = dt + if best is None or best_dt > tolerance_s: + return None + return msg_list[best] + + for t in targets: + obs_dict: dict[str, np.ndarray] = {} + act_dict: dict[str, np.ndarray] = {} + skip = False + for key, ref in streams.items(): + msg = _nearest(key, t) + if msg is None: + skip = True + break + arr = resolve_field(msg, ref) + if key in action_keys: + act_dict[key] = arr + elif key in obs_keys: + obs_dict[key] = arr + if skip: + continue + yield Sample(ts=t, episode_id=episode.id, observation=obs_dict, action=act_dict) def compute_stats( @@ -128,12 +320,26 @@ def compute_stats( quantile_reservoir: int = 10_000, seed: int = 0, ) -> dict[str, Any]: - """Per-feature mean/std/min/max/q01/q99 in one pass. + """Single-pass per-feature stats over a Sample iterator. + + Output schema matches LeRobot v2 ``stats.json``:: + + { "observation.": {"mean", "std", "min", "max", "q01", "q99"}, + "action.": {...} } - Welford for mean/std; reservoir sample for quantiles. 
Image features - subsampled (every Nth frame) to bound cost. + Thin wrapper over :class:`StreamingStats` so format writers and + ad-hoc callers share the exact same accumulator. """ - raise NotImplementedError + from dimos.learning.formats._stats import StreamingStats + + s = StreamingStats(image_subsample=image_subsample, + quantile_reservoir=quantile_reservoir, seed=seed) + for sample in samples: + for k, v in sample.observation.items(): + s.update(f"observation.{k}", np.asarray(v)) + for k, v in sample.action.items(): + s.update(f"action.{k}", np.asarray(v)) + return s.finalize() def get_writer(format_name: str) -> Writer: diff --git a/dimos/learning/dataprep_blueprint.py b/dimos/learning/dataprep_blueprint.py new file mode 100644 index 0000000000..53066da29f --- /dev/null +++ b/dimos/learning/dataprep_blueprint.py @@ -0,0 +1,86 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset-build blueprints. + +Wraps `DataPrepModule` so users can run:: + + dimos run learning-dataprep + dimos run learning-dataprep -o dataprepmodule.source=data/recordings/foo.db \\ + -o dataprepmodule.output.path=data/datasets/foo + +The defaults below target the included pickplace_001 demo. For single-demo +recordings without an `episode_status` stream, `learning_dataprep_whole_session` +treats the entire recording as one episode. 
+""" + +from __future__ import annotations + +from dimos.core.coordination.blueprints import autoconnect +from dimos.learning.dataprep import ( + EpisodeExtractor, + OutputConfig, + StreamField, + SyncConfig, +) +from dimos.learning.dataprep_module import DataPrepModule + +learning_dataprep = autoconnect( + DataPrepModule.blueprint( + source="data/recordings/pickplace_001.db", + episodes=EpisodeExtractor( + extractor="ranges", + ranges=[(1777931622.11, 1777931646.61)], + ), + observation={ + "image": StreamField(stream="color_image", field="data"), + "joint_state": StreamField(stream="joint_state", field="position"), + }, + action={ + "joint_target": StreamField(stream="joint_state", field="position"), + }, + sync=SyncConfig(anchor="image", rate_hz=14.0, tolerance_ms=80.0), + output=OutputConfig( + format="lerobot", + path="data/datasets/pickplace_001", + metadata={"fps": 14, "robot": "xarm7", "default_task_label": "pick_and_place"}, + ), + auto_run=True, + ), +).transports({}) + + +learning_dataprep_whole_session = autoconnect( + DataPrepModule.blueprint( + source="data/session.db", + episodes=EpisodeExtractor(extractor="whole_session"), + observation={ + "image": StreamField(stream="camera_color_image", field="data"), + "joint_state": StreamField(stream="coordinator_joint_state", field="position"), + }, + action={ + "joint_target": StreamField(stream="coordinator_joint_command", field="position"), + }, + sync=SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0), + output=OutputConfig( + format="lerobot", + path="data/datasets/default", + metadata={"fps": 30, "robot": "xarm7"}, + ), + auto_run=True, + ), +).transports({}) + + +__all__ = ["learning_dataprep", "learning_dataprep_whole_session"] diff --git a/dimos/learning/dataprep_module.py b/dimos/learning/dataprep_module.py index e08b660907..d6e80ba4e3 100644 --- a/dimos/learning/dataprep_module.py +++ b/dimos/learning/dataprep_module.py @@ -20,7 +20,11 @@ from __future__ import annotations +import json 
import threading +import traceback +from collections.abc import Iterator +from pathlib import Path from typing import Any from dimos.core.core import rpc @@ -28,18 +32,27 @@ from dimos.learning.dataprep import ( EpisodeExtractor, OutputConfig, + Sample, StreamField, SyncConfig, + extract_episodes, + get_writer, + iter_episode_samples, ) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() class DataPrepModuleConfig(ModuleConfig): - source: str - episodes: EpisodeExtractor - observation: dict[str, StreamField] - action: dict[str, StreamField] - sync: SyncConfig - output: OutputConfig + # Fields are defaulted so partial CLI overrides (e.g. just `source=...`) + # pass blueprint validation; blueprint atoms supply real values. + source: str = "" + episodes: EpisodeExtractor = EpisodeExtractor() + observation: dict[str, StreamField] = {} + action: dict[str, StreamField] = {} + sync: SyncConfig = SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0) + output: OutputConfig = OutputConfig(format="lerobot", path="data/datasets/default") auto_run: bool = False @@ -58,35 +71,202 @@ def __init__(self, **kwargs: Any) -> None: "progress_pct": 0.0, "dataset_path": None, "error": None, + "episodes_seen": 0, + "samples_seen": 0, } + # ── lifecycle ──────────────────────────────────────────────────────────── + @rpc def start(self) -> None: - raise NotImplementedError + super().start() + if self.config.auto_run: + self.build() @rpc def stop(self) -> None: - raise NotImplementedError + # Build thread is daemon: dies with the process. No mid-iteration interrupt. + super().stop() @rpc def build(self) -> None: """Spawn a daemon thread running the build pipeline. 
Returns immediately.""" - raise NotImplementedError + with self._lock: + if self._status["state"] == "running": + return + self._status.update( + state="running", + current_phase=None, + progress_pct=0.0, + dataset_path=None, + error=None, + episodes_seen=0, + samples_seen=0, + ) + self._thread = threading.Thread(target=self._run_build, daemon=True) + self._thread.start() @rpc def get_status(self) -> dict[str, Any]: - raise NotImplementedError + with self._lock: + return dict(self._status) @rpc def inspect(self) -> dict[str, Any]: """Read-only summary: episode count, drop rates, joint names, stats presence.""" - raise NotImplementedError + from dimos.memory2.store.sqlite import SqliteStore + + store = SqliteStore(path=self.config.source, must_exist=True) + try: + episodes = extract_episodes(store, self.config.episodes) + saved = sum(1 for e in episodes if e.success) + dropped = sum(1 for e in episodes if not e.success) + durations = [e.duration for e in episodes if e.success] + return { + "source": self.config.source, + "streams": store.list_streams(), + "episodes_saved": saved, + "episodes_dropped": dropped, + "duration_s": { + "min": min(durations) if durations else 0.0, + "max": max(durations) if durations else 0.0, + "mean": (sum(durations) / len(durations)) if durations else 0.0, + }, + } + finally: + store.stop() + + # ── internals ──────────────────────────────────────────────────────────── def _run_build(self) -> None: - """Thread target. Opens session.db, calls extract_episodes / - iter_episode_samples / format writer, snapshots config to - /dimos_meta.json. Updates _status under _lock.""" - raise NotImplementedError + """Thread target. Opens session.db, walks samples episode-by-episode, + drives the format writer, snapshots config to /dimos_meta.json. + Updates _status under _lock. 
+ """ + try: + logger.info( + "[dataprep] starting build source=%s extractor=%s output=%s", + self.config.source, + self.config.episodes.extractor, + self.config.output.path, + ) + self._update_status(current_phase="scan_episodes") + + from dimos.memory2.store.sqlite import SqliteStore + + store = SqliteStore(path=self.config.source, must_exist=True) + try: + logger.info("[dataprep] streams in source: %s", store.list_streams()) + all_eps = extract_episodes(store, self.config.episodes) + episodes = [e for e in all_eps if e.success] + logger.info( + "[dataprep] episodes extracted: %d total / %d successful", + len(all_eps), len(episodes), + ) + self._update_status(episodes_seen=len(episodes)) + + if not episodes: + raise RuntimeError( + f"No successful episodes extracted from {self.config.source!r} " + f"using extractor={self.config.episodes.extractor!r}. " + f"Available streams: {store.list_streams()}. " + f"For a single-demo .db with no episode_status stream, use " + f"extractor='whole_session' or 'ranges'." 
+ ) + + streams = {**self.config.observation, **self.config.action} + obs_keys = set(self.config.observation) + action_keys = set(self.config.action) + logger.info( + "[dataprep] obs streams=%s action streams=%s sync=%s", + sorted(obs_keys), sorted(action_keys), + self.config.sync.model_dump(), + ) + + writer = get_writer(self.config.output.format) + + self._update_status(current_phase="write") + logger.info( + "[dataprep] writing %s dataset to %s", + self.config.output.format, self.config.output.path, + ) + + samples_seen = 0 + episodes_done = 0 + total = len(episodes) + + def _all_samples() -> Iterator[Sample]: + nonlocal samples_seen, episodes_done + for ep in episodes: + for sample in iter_episode_samples( + store=store, + episode=ep, + streams=streams, + sync=self.config.sync, + obs_keys=obs_keys, + action_keys=action_keys, + ): + samples_seen += 1 + if samples_seen % 50 == 0: + self._update_status( + samples_seen=samples_seen, + progress_pct=100.0 * episodes_done / total, + ) + logger.info( + "[dataprep] %.1f%% samples=%d ep %d/%d", + 100.0 * episodes_done / total, + samples_seen, episodes_done, total, + ) + yield sample + episodes_done += 1 + self._update_status( + samples_seen=samples_seen, + progress_pct=100.0 * episodes_done / total, + ) + + dataset_path = writer(_all_samples(), self.config.output) + + self._write_dimos_meta(Path(dataset_path), episodes) + + self._update_status( + state="succeeded", + current_phase="done", + progress_pct=100.0, + dataset_path=str(dataset_path), + ) + logger.info( + "[dataprep] succeeded — wrote %d samples across %d episodes to %s", + samples_seen, total, dataset_path, + ) + finally: + store.stop() + except Exception as e: + err = f"{type(e).__name__}: {e}\n{traceback.format_exc()}" + self._update_status(state="failed", error=err) + logger.error("[dataprep] FAILED: %s", err) + + def _write_dimos_meta(self, dataset_path: Path, episodes: list[Any]) -> None: + """Sidecar describing how this dataset was built. 
ChunkPolicyModule + reads it at inference time to recover the obs/action schema.""" + meta = { + "source": self.config.source, + "observation": {k: v.model_dump() for k, v in self.config.observation.items()}, + "action": {k: v.model_dump() for k, v in self.config.action.items()}, + "sync": self.config.sync.model_dump(), + "episodes": [ + {"id": e.id, "start_ts": e.start_ts, "end_ts": e.end_ts, + "task_label": e.task_label, "success": e.success} + for e in episodes + ], + "format": self.config.output.format, + "metadata": self.config.output.metadata, + } + with open(dataset_path / "dimos_meta.json", "w") as f: + json.dump(meta, f, indent=2, default=str) + + def _update_status(self, **kwargs: Any) -> None: + with self._lock: + self._status.update(kwargs) __all__ = ["DataPrepModule", "DataPrepModuleConfig"] diff --git a/dimos/learning/formats/_stats.py b/dimos/learning/formats/_stats.py new file mode 100644 index 0000000000..92d5be648c --- /dev/null +++ b/dimos/learning/formats/_stats.py @@ -0,0 +1,116 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Streaming feature stats — shared by every format writer. + +Welford for mean/std and a reservoir sample for q01/q99 over scalar / +low-dim features. Image-like (≥3D) features are subsampled and reduced +to per-channel summaries so per-pixel stats don't blow up memory. 
+""" + +from __future__ import annotations + +import random +from dataclasses import dataclass, field +from typing import Any + +import numpy as np + + +@dataclass +class FeatureAggregator: + is_image: bool + n: int = 0 + mean: np.ndarray | None = None + m2: np.ndarray | None = None + minv: np.ndarray | None = None + maxv: np.ndarray | None = None + reservoir: list[np.ndarray] = field(default_factory=list) + image_seen: int = 0 + shape: tuple[int, ...] | None = None + dtype: str | None = None + + +class StreamingStats: + """Single-pass mean/std/min/max/quantile aggregator across many features.""" + + def __init__(self, image_subsample: int = 10, quantile_reservoir: int = 10_000, + seed: int = 0) -> None: + self.image_subsample = image_subsample + self.quantile_reservoir = quantile_reservoir + self._rng = random.Random(seed) + self.aggs: dict[str, FeatureAggregator] = {} + + def update(self, name: str, value: np.ndarray) -> None: + a = np.asarray(value) + is_image = a.ndim >= 3 + agg = self.aggs.setdefault( + name, FeatureAggregator(is_image=is_image, shape=tuple(a.shape), dtype=str(a.dtype)), + ) + + if is_image: + agg.image_seen += 1 + if (agg.image_seen - 1) % self.image_subsample != 0: + return + v = a.astype(np.float32).mean(axis=(0, 1)) if a.ndim == 3 else a.astype(np.float32).reshape(-1) + else: + v = a.astype(np.float64) + + if agg.mean is None: + agg.mean = np.zeros(v.shape, dtype=np.float64) + agg.m2 = np.zeros(v.shape, dtype=np.float64) + agg.minv = np.full(v.shape, np.inf, dtype=np.float64) + agg.maxv = np.full(v.shape, -np.inf, dtype=np.float64) + + agg.n += 1 + delta = v - agg.mean + agg.mean += delta / agg.n + assert agg.m2 is not None + agg.m2 += delta * (v - agg.mean) + assert agg.minv is not None and agg.maxv is not None + np.minimum(agg.minv, v, out=agg.minv) + np.maximum(agg.maxv, v, out=agg.maxv) + + if not is_image: + if len(agg.reservoir) < self.quantile_reservoir: + agg.reservoir.append(v.copy()) + else: + j = self._rng.randint(0, agg.n - 1) 
+ if j < self.quantile_reservoir: + agg.reservoir[j] = v.copy() + + def finalize(self) -> dict[str, dict[str, Any]]: + out: dict[str, dict[str, Any]] = {} + for name, agg in self.aggs.items(): + if agg.mean is None: + continue + n = max(1, agg.n) + var = agg.m2 / n if agg.n > 1 else np.zeros_like(agg.mean) + std = np.sqrt(var) + entry: dict[str, Any] = { + "mean": agg.mean.tolist(), + "std": std.tolist(), + "min": agg.minv.tolist() if agg.minv is not None else None, + "max": agg.maxv.tolist() if agg.maxv is not None else None, + "count": int(agg.n), + } + if agg.reservoir: + stacked = np.stack(agg.reservoir, axis=0) + entry["q01"] = np.quantile(stacked, 0.01, axis=0).tolist() + entry["q99"] = np.quantile(stacked, 0.99, axis=0).tolist() + out[name] = entry + return out + + +__all__ = ["FeatureAggregator", "StreamingStats"] diff --git a/dimos/learning/formats/hdf5.py b/dimos/learning/formats/hdf5.py index 57ddb096a0..acbe647c9b 100644 --- a/dimos/learning/formats/hdf5.py +++ b/dimos/learning/formats/hdf5.py @@ -12,17 +12,131 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""HDF5 dataset writer. Single .hdf5 with one group per episode + stats group.""" +"""HDF5 dataset writer. + +Single ``.hdf5`` file with one group per episode plus a stats group. +Layout:: + + .hdf5 (or output.path. if a dir was given) + / attrs: codebase_version, robot, fps, + num_episodes, num_frames, num_tasks + /tasks attrs: task_ = "" + /stats/ attrs: count + datasets mean/std/min/max[/q01/q99] + /episodes/episode_NNNNNN + timestamp (T,) float32 + (T, ...) as recorded + (T, ...) as recorded + attrs: length, start_ts, task_index + +This is the ACT-original style adapted to one file with multiple episodes. 
+""" from __future__ import annotations from collections.abc import Iterator from pathlib import Path +import numpy as np + from dimos.learning.dataprep import OutputConfig, Sample +from dimos.learning.formats._stats import StreamingStats def write(samples: Iterator[Sample], output: OutputConfig) -> Path: - """Write samples to a single HDF5 file (stats as group attrs). - Returns the file path.""" - raise NotImplementedError + """Drain `samples` into a single .hdf5 file. Returns the file path.""" + try: + import h5py + except ImportError as e: + raise RuntimeError("HDF5 writer requires h5py — install with `pip install h5py`") from e + + out = Path(output.path) + if out.suffix not in (".h5", ".hdf5"): + out = out.with_suffix(".hdf5") + out.parent.mkdir(parents=True, exist_ok=True) + + stats = StreamingStats( + image_subsample=int(output.metadata.get("image_subsample", 10)), + quantile_reservoir=int(output.metadata.get("quantile_reservoir", 10_000)), + seed=int(output.metadata.get("stats_seed", 0)), + ) + + default_task_label: str = output.metadata.get("default_task_label", "task") + fps = float(output.metadata.get("fps", 30.0)) + + tasks_index: dict[str, int] = {} + + # Per-episode buffers — flushed at episode boundary. 
+ cur_id: str | None = None + cur_idx = -1 + cur_start_ts: float | None = None + buf_ts: list[float] = [] + buf_obs: dict[str, list[np.ndarray]] = {} + buf_act: dict[str, list[np.ndarray]] = {} + + total_frames = 0 + + with h5py.File(out, "w") as h5: + episodes_g = h5.create_group("episodes") + + def _flush() -> None: + if cur_idx < 0 or not buf_ts: + return + ep = episodes_g.create_group(f"episode_{cur_idx:06d}") + ep.attrs["length"] = len(buf_ts) + ep.attrs["start_ts"] = float(cur_start_ts or 0.0) + ep.attrs["task_index"] = tasks_index[default_task_label] + ep.create_dataset("timestamp", data=np.asarray(buf_ts, dtype=np.float32)) + for k, frames in buf_obs.items(): + arr = np.stack(frames, axis=0) + ep.create_dataset(f"observation/{k}", data=arr, + compression="gzip" if arr.ndim >= 3 else None, + compression_opts=4 if arr.ndim >= 3 else None) + for k, frames in buf_act.items(): + ep.create_dataset(f"action/{k}", data=np.stack(frames, axis=0)) + buf_ts.clear() + buf_obs.clear() + buf_act.clear() + + for sample in samples: + if sample.episode_id != cur_id: + _flush() + cur_id = sample.episode_id + cur_idx += 1 + cur_start_ts = float(sample.ts) + if default_task_label not in tasks_index: + tasks_index[default_task_label] = len(tasks_index) + + buf_ts.append(float(sample.ts) - (cur_start_ts or 0.0)) + for k, v in sample.observation.items(): + a = np.asarray(v) + buf_obs.setdefault(k, []).append(a) + stats.update(f"observation.{k}", a) + for k, v in sample.action.items(): + a = np.asarray(v) + buf_act.setdefault(k, []).append(a) + stats.update(f"action.{k}", a) + total_frames += 1 + + _flush() + + # ── meta ──────────────────────────────────────────────────────────── + h5.attrs["codebase_version"] = "dimos-v1" + h5.attrs["robot"] = output.metadata.get("robot", "unknown") + h5.attrs["fps"] = fps + h5.attrs["num_episodes"] = len(episodes_g) + h5.attrs["num_frames"] = total_frames + h5.attrs["num_tasks"] = len(tasks_index) + + tasks_g = h5.create_group("tasks") + for 
task, idx in tasks_index.items(): + tasks_g.attrs[f"task_{idx}"] = task + + stats_g = h5.create_group("stats") + for name, entry in stats.finalize().items(): + g = stats_g.create_group(name) + g.attrs["count"] = entry["count"] + for k in ("mean", "std", "min", "max", "q01", "q99"): + if k in entry and entry[k] is not None: + g.create_dataset(k, data=np.asarray(entry[k], dtype=np.float64)) + + return out diff --git a/dimos/learning/formats/lerobot.py b/dimos/learning/formats/lerobot.py index e5e348af44..acb1f27bbb 100644 --- a/dimos/learning/formats/lerobot.py +++ b/dimos/learning/formats/lerobot.py @@ -14,24 +14,285 @@ """LeRobot v2 dataset writer. -Layout: +Layout:: + / - meta/info.json schema, fps, total episodes/frames + meta/info.json schema, fps, total episodes/frames, features meta/episodes.jsonl per-episode metadata - meta/stats.json per-feature stats (from DataPrep.compute_stats) - data/chunk-000/episode_*.parquet - videos/chunk-000//episode_*.mp4 + meta/tasks.jsonl task descriptions for language conditioning + meta/stats.json per-feature mean/std/min/max/q01/q99 + data/chunk-000/episode_NNNNNN.parquet + videos/chunk-000/observation.images./episode_NNNNNN.mp4 + +Single pass: streams samples to disk per-episode and accumulates stats in +parallel. Image frames go to MP4 (one per camera, per episode); their +columns are excluded from the parquet — lerobot loads them from MP4 at +``__getitem__`` time using the ``video_path`` template + episode_index + +timestamp. 
""" from __future__ import annotations +import json from collections.abc import Iterator from pathlib import Path +from typing import Any + +import numpy as np from dimos.learning.dataprep import OutputConfig, Sample +from dimos.learning.formats._stats import StreamingStats + +CHUNK = "chunk-000" +DATA_DIR = "data" +VIDEO_DIR = "videos" +META_DIR = "meta" + + +def _feature_name(prefix: str, key: str, is_image: bool, + single_action: bool, single_state: bool = False) -> str: + """Translate (prefix, key) into the LeRobot v2 feature name. + + Canonical names lerobot policies (ACT, Diffusion, π₀) expect: + observation.state single proprio vector + action single action vector + observation.images. per-camera RGB + Multi-key fallbacks: ``observation.`` / ``action.``. + """ + if prefix == "action" and single_action: + return "action" + if is_image: + return f"observation.images.{key}" + if prefix == "observation" and single_state: + return "observation.state" + if prefix == "observation": + return f"observation.{key}" + return f"action.{key}" def write(samples: Iterator[Sample], output: OutputConfig) -> Path: - """Drain samples, write parquet+MP4, call DataPrep.compute_stats, - serialize stats to meta/stats.json. Return the dataset root path.""" - raise NotImplementedError + """Drain `samples`, write parquet+MP4+meta in LeRobot v2 layout. + Returns the dataset root path. 
+ """ + try: + import cv2 + except ImportError as e: + raise RuntimeError("LeRobot writer requires opencv-python (cv2) for MP4 encoding") from e + try: + import pyarrow as pa + import pyarrow.parquet as pq + except ImportError as e: + raise RuntimeError("LeRobot writer requires pyarrow for parquet writes") from e + + root = Path(output.path) + (root / META_DIR).mkdir(parents=True, exist_ok=True) + (root / DATA_DIR / CHUNK).mkdir(parents=True, exist_ok=True) + (root / VIDEO_DIR / CHUNK).mkdir(parents=True, exist_ok=True) + + fps = float(output.metadata.get("fps", 30.0)) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + + stats = StreamingStats( + image_subsample=int(output.metadata.get("image_subsample", 10)), + quantile_reservoir=int(output.metadata.get("quantile_reservoir", 10_000)), + seed=int(output.metadata.get("stats_seed", 0)), + ) + + image_keys: set[str] = set() + state_keys: list[str] = [] + action_keys: list[str] = [] + feature_shapes: dict[str, tuple[int, ...]] = {} + feature_dtypes: dict[str, str] = {} + + episodes_meta: list[dict[str, Any]] = [] + tasks_index: dict[str, int] = {} + default_task_label = output.metadata.get("default_task_label", "task") + + current_episode_id: str | None = None + current_episode_index = -1 + current_episode_start_ts: float | None = None + current_frames: list[dict[str, Any]] = [] + current_video_writers: dict[str, Any] = {} + global_index = 0 + + def _episode_path_parquet(ep_idx: int) -> Path: + return root / DATA_DIR / CHUNK / f"episode_{ep_idx:06d}.parquet" + + def _episode_path_video(image_key: str, ep_idx: int) -> Path: + feat_name = _feature_name("observation", image_key, is_image=True, single_action=False) + d = root / VIDEO_DIR / CHUNK / feat_name + d.mkdir(parents=True, exist_ok=True) + return d / f"episode_{ep_idx:06d}.mp4" + + def _open_video(image_key: str, ep_idx: int, frame: np.ndarray) -> Any: + # Frames written as-is; cv2.VideoWriter is BGR-native. RGB frames will + # encode OK but decode color-swapped. 
Standardize upstream if needed. + h, w = frame.shape[:2] + path = _episode_path_video(image_key, ep_idx) + writer = cv2.VideoWriter(str(path), fourcc, fps, (w, h)) + if not writer.isOpened(): + raise RuntimeError(f"Failed to open VideoWriter for {path}") + return writer + + def _flush_episode() -> None: + nonlocal current_frames, current_video_writers, current_episode_index + if not current_frames: + return + for vw in current_video_writers.values(): + vw.release() + current_video_writers = {} + + cols: dict[str, list[Any]] = { + "timestamp": [f["timestamp"] for f in current_frames], + "frame_index": [f["frame_index"] for f in current_frames], + "episode_index": [f["episode_index"] for f in current_frames], + "index": [f["index"] for f in current_frames], + "task_index": [f["task_index"] for f in current_frames], + } + single_state = len(state_keys) == 1 + for k in state_keys: + name = _feature_name("observation", k, is_image=False, + single_action=False, single_state=single_state) + cols[name] = [f["obs"][k].tolist() for f in current_frames] + single_action = len(action_keys) == 1 + for k in action_keys: + name = _feature_name("action", k, is_image=False, single_action=single_action) + cols[name] = [f["act"][k].tolist() for f in current_frames] + # Video columns intentionally omitted: lerobot's hf_features schema + # skips dtype="video" and reads frames from MP4 at __getitem__ time. 
+ + table = pa.Table.from_pydict(cols) + pq.write_table(table, _episode_path_parquet(current_episode_index)) + + episodes_meta.append({ + "episode_index": current_episode_index, + "tasks": [list(tasks_index.keys())[current_frames[0]["task_index"]]], + "length": len(current_frames), + }) + current_frames = [] + + for sample in samples: + if sample.episode_id != current_episode_id: + _flush_episode() + current_episode_id = sample.episode_id + current_episode_index += 1 + # Episode-relative timestamps so they fit float32 with sub-ms + # precision; lerobot's check_timestamps_sync compares against 1/fps. + current_episode_start_ts = float(sample.ts) + label = default_task_label + if label not in tasks_index: + tasks_index[label] = len(tasks_index) + + # Schema discovery + stats accumulation. + n_low_dim_obs = sum(1 for _, v in sample.observation.items() if np.asarray(v).ndim < 3) + single_state = n_low_dim_obs == 1 + for k, arr in sample.observation.items(): + a = np.asarray(arr) + is_image = a.ndim >= 3 + name = _feature_name("observation", k, is_image=is_image, + single_action=False, single_state=single_state) + if name not in feature_shapes: + feature_shapes[name] = tuple(a.shape) + feature_dtypes[name] = "video" if is_image else str(a.dtype) + if is_image: + image_keys.add(k) + elif k not in state_keys: + state_keys.append(k) + stats.update(name, a) + + for k, arr in sample.action.items(): + a = np.asarray(arr) + single_action = len(sample.action) == 1 + name = _feature_name("action", k, is_image=False, single_action=single_action) + if name not in feature_shapes: + feature_shapes[name] = tuple(a.shape) + feature_dtypes[name] = str(a.dtype) + if k not in action_keys: + action_keys.append(k) + stats.update(name, a) + + # Video frame write + parquet row buffer. 
+ frame_index = len(current_frames) + for k, arr in sample.observation.items(): + a = np.asarray(arr) + if a.ndim >= 3: + if k not in current_video_writers: + current_video_writers[k] = _open_video(k, current_episode_index, a) + current_video_writers[k].write(a) + + rel_ts = float(sample.ts) - (current_episode_start_ts or 0.0) + current_frames.append({ + "timestamp": rel_ts, + "frame_index": frame_index, + "episode_index": current_episode_index, + "index": global_index, + "task_index": tasks_index[default_task_label], + "obs": {k: np.asarray(v) for k, v in sample.observation.items() if np.asarray(v).ndim < 3}, + "act": {k: np.asarray(v) for k, v in sample.action.items()}, + }) + global_index += 1 + + _flush_episode() + + # ── meta files ─────────────────────────────────────────────────────────── + total_episodes = len(episodes_meta) + total_frames = global_index + + features: dict[str, Any] = {} + for name, shape in feature_shapes.items(): + if feature_dtypes[name] == "video": + features[name] = { + "dtype": "video", + "shape": list(shape), + "names": ["height", "width", "channel"], + "info": { + "video.fps": fps, + "video.height": int(shape[0]), + "video.width": int(shape[1]), + "video.channels": int(shape[2]) if len(shape) > 2 else 3, + "video.codec": "mp4v", + "video.pix_fmt": "yuv420p", + "video.is_depth_map": False, + "has_audio": False, + }, + } + else: + # Per-dim names; downstream loaders only require len(names) == shape[0]. 
+ n = int(shape[0]) if shape else 0 + base = name.split(".")[-1] + features[name] = { + "dtype": feature_dtypes[name], + "shape": list(shape), + "names": [f"{base}_{i}" for i in range(n)], + } + for col, dt in [("timestamp", "float32"), ("frame_index", "int64"), + ("episode_index", "int64"), ("index", "int64"), ("task_index", "int64")]: + features[col] = {"dtype": dt, "shape": [1], "names": None} + + info = { + "codebase_version": "v2.0", + "robot_type": output.metadata.get("robot", "unknown"), + "total_episodes": total_episodes, + "total_frames": total_frames, + "total_tasks": len(tasks_index), + "total_videos": total_episodes * len(image_keys), + "total_chunks": 1, + "chunks_size": max(1, total_episodes), + "fps": fps, + "splits": {"train": f"0:{total_episodes}"}, + "data_path": "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet", + "video_path": "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4", + "features": features, + } + with open(root / META_DIR / "info.json", "w") as f: + json.dump(info, f, indent=2) + with open(root / META_DIR / "episodes.jsonl", "w") as f: + for ep in episodes_meta: + f.write(json.dumps(ep) + "\n") + with open(root / META_DIR / "tasks.jsonl", "w") as f: + for task, idx in tasks_index.items(): + f.write(json.dumps({"task_index": idx, "task": task}) + "\n") + with open(root / META_DIR / "stats.json", "w") as f: + json.dump(stats.finalize(), f, indent=2) + + return root diff --git a/dimos/learning/formats/rlds.py b/dimos/learning/formats/rlds.py index 23126279dc..8003b7c938 100644 --- a/dimos/learning/formats/rlds.py +++ b/dimos/learning/formats/rlds.py @@ -12,18 +12,192 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""RLDS / TFDS dataset writer. TFRecord shards + dataset_info.json -following the RLDS Episode/Step protocol.""" +"""RLDS / TFDS dataset writer. + +Emits a single TFRecord shard following the RLDS Episode/Step protocol. 
+Each example is a ``tf.train.SequenceExample`` whose feature_lists hold +per-step observation / action / reward / discount / is_first / is_last / +is_terminal arrays, plus a context for episode metadata. + +Layout:: + + / + features.json schema (feature names + shapes + dtypes) + dataset_info.json episode count, step count, fps, robot, stats + rlds-00000-of-00001.tfrecord + +This is RLDS-shaped on the wire; loading it as a `tfds.builder` requires +matching the schema in a TFDS DatasetBuilder. See the RLDS docs for that +glue. For OpenX-Embodiment contributions, point your TFDS builder at +this directory. +""" from __future__ import annotations +import json from collections.abc import Iterator from pathlib import Path +from typing import Any + +import numpy as np from dimos.learning.dataprep import OutputConfig, Sample +from dimos.learning.formats._stats import StreamingStats def write(samples: Iterator[Sample], output: OutputConfig) -> Path: - """Write samples as TFDS/RLDS shards (stats in features metadata). - Returns the dataset directory path.""" - raise NotImplementedError + """Drain `samples` into RLDS-style TFRecord shards. Returns the dataset dir.""" + try: + import tensorflow as tf # type: ignore[import-not-found] + except ImportError as e: + raise RuntimeError( + "RLDS writer requires tensorflow — install with " + "`pip install tensorflow_datasets` (pulls in tf)" + ) from e + + root = Path(output.path) + root.mkdir(parents=True, exist_ok=True) + + stats = StreamingStats( + image_subsample=int(output.metadata.get("image_subsample", 10)), + quantile_reservoir=int(output.metadata.get("quantile_reservoir", 10_000)), + seed=int(output.metadata.get("stats_seed", 0)), + ) + + default_task_label = output.metadata.get("default_task_label", "task") + fps = float(output.metadata.get("fps", 30.0)) + + feature_shapes: dict[str, tuple[int, ...]] = {} + feature_dtypes: dict[str, str] = {} + tasks_index: dict[str, int] = {} + + # Per-episode buffers. 
+ cur_id: str | None = None + cur_idx = -1 + cur_start_ts: float | None = None + buf_ts: list[float] = [] + buf_obs: dict[str, list[np.ndarray]] = {} + buf_act: dict[str, list[np.ndarray]] = {} + + total_frames = 0 + episodes_meta: list[dict[str, Any]] = [] + shard_path = root / "rlds-00000-of-00001.tfrecord" + + def _bytes(arr: np.ndarray) -> "tf.train.Feature": + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[arr.tobytes()])) + + def _flush(writer: "tf.io.TFRecordWriter") -> None: + nonlocal cur_idx, cur_start_ts + if cur_idx < 0 or not buf_ts: + return + + T = len(buf_ts) + feature_lists: dict[str, "tf.train.FeatureList"] = {} + + def _make_list(arrs: list[np.ndarray]) -> "tf.train.FeatureList": + return tf.train.FeatureList(feature=[_bytes(a) for a in arrs]) + + for k, frames in buf_obs.items(): + feature_lists[f"observation/{k}"] = _make_list(frames) + for k, frames in buf_act.items(): + feature_lists[f"action/{k}"] = _make_list(frames) + + ts_arr = np.asarray(buf_ts, dtype=np.float32) + feature_lists["timestamp"] = tf.train.FeatureList( + feature=[tf.train.Feature(float_list=tf.train.FloatList(value=[t])) for t in ts_arr] + ) + # RLDS step booleans. + is_first = [i == 0 for i in range(T)] + is_last = [i == T - 1 for i in range(T)] + for name, vals in (("is_first", is_first), ("is_last", is_last), ("is_terminal", is_last)): + feature_lists[name] = tf.train.FeatureList(feature=[ + tf.train.Feature(int64_list=tf.train.Int64List(value=[int(v)])) for v in vals + ]) + # Default reward / discount per RLDS convention. 
+ feature_lists["reward"] = tf.train.FeatureList(feature=[ + tf.train.Feature(float_list=tf.train.FloatList(value=[0.0])) for _ in range(T) + ]) + feature_lists["discount"] = tf.train.FeatureList(feature=[ + tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])) for _ in range(T) + ]) + + ctx = tf.train.Features(feature={ + "episode_index": tf.train.Feature(int64_list=tf.train.Int64List(value=[cur_idx])), + "length": tf.train.Feature(int64_list=tf.train.Int64List(value=[T])), + "start_ts": tf.train.Feature(float_list=tf.train.FloatList(value=[float(cur_start_ts or 0.0)])), + "task": tf.train.Feature(bytes_list=tf.train.BytesList(value=[default_task_label.encode()])), + }) + ex = tf.train.SequenceExample( + context=ctx, + feature_lists=tf.train.FeatureLists(feature_list=feature_lists), + ) + writer.write(ex.SerializeToString()) + + episodes_meta.append({ + "episode_index": cur_idx, + "length": T, + "start_ts": float(cur_start_ts or 0.0), + "task": default_task_label, + }) + buf_ts.clear() + buf_obs.clear() + buf_act.clear() + + with tf.io.TFRecordWriter(str(shard_path)) as writer: + for sample in samples: + if sample.episode_id != cur_id: + _flush(writer) + cur_id = sample.episode_id + cur_idx += 1 + cur_start_ts = float(sample.ts) + if default_task_label not in tasks_index: + tasks_index[default_task_label] = len(tasks_index) + + buf_ts.append(float(sample.ts) - (cur_start_ts or 0.0)) + for k, v in sample.observation.items(): + a = np.asarray(v) + buf_obs.setdefault(k, []).append(a) + stats.update(f"observation.{k}", a) + if k not in feature_shapes: + feature_shapes[f"observation/{k}"] = tuple(a.shape) + feature_dtypes[f"observation/{k}"] = str(a.dtype) + for k, v in sample.action.items(): + a = np.asarray(v) + buf_act.setdefault(k, []).append(a) + stats.update(f"action.{k}", a) + if k not in feature_shapes: + feature_shapes[f"action/{k}"] = tuple(a.shape) + feature_dtypes[f"action/{k}"] = str(a.dtype) + total_frames += 1 + + _flush(writer) + + # ── sidecar 
metadata ───────────────────────────────────────────────────── + features_meta = { + name: {"shape": list(shape), "dtype": feature_dtypes[name]} + for name, shape in feature_shapes.items() + } + features_meta["timestamp"] = {"shape": [], "dtype": "float32"} + features_meta["is_first"] = {"shape": [], "dtype": "int64"} + features_meta["is_last"] = {"shape": [], "dtype": "int64"} + features_meta["is_terminal"] = {"shape": [], "dtype": "int64"} + features_meta["reward"] = {"shape": [], "dtype": "float32"} + features_meta["discount"] = {"shape": [], "dtype": "float32"} + + info = { + "format_version": "rlds-1.0", + "robot": output.metadata.get("robot", "unknown"), + "fps": fps, + "num_episodes": len(episodes_meta), + "num_steps": total_frames, + "num_tasks": len(tasks_index), + "tasks": {idx: task for task, idx in tasks_index.items()}, + "episodes": episodes_meta, + "stats": stats.finalize(), + } + with open(root / "features.json", "w") as f: + json.dump(features_meta, f, indent=2) + with open(root / "dataset_info.json", "w") as f: + json.dump(info, f, indent=2) + + return root diff --git a/dimos/learning/inference/blueprint.py b/dimos/learning/inference/blueprint.py index 485ec7a2d5..1b921e4fc9 100644 --- a/dimos/learning/inference/blueprint.py +++ b/dimos/learning/inference/blueprint.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ACT inference blueprint. ActionReplayer is registered by the per-robot -coordinator blueprint (passed in below). v1 placeholder uses the existing -teleop coordinator; replace with a coordinator that registers ActionReplayer. +"""ACT inference blueprints. + +`ChunkPolicyModule` publishes `joint_command` directly so a coordinator's +servo / position task can consume it without an `ActionReplayer` task in +the tick loop. 
Compose with the user's coordinator blueprint at the call +site, e.g.:: + + autoconnect(learning_infer_chunkpolicy_only, my_servo_coordinator) """ from __future__ import annotations -from dimos.control.blueprints.teleop import coordinator_teleop_xarm7 from dimos.core.coordination.blueprints import autoconnect from dimos.core.transport import LCMTransport from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera @@ -28,25 +32,27 @@ from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.JointState import JointState -_T_COLOR_IMAGE = "/camera/color_image" -_T_JOINT_STATE = "/coordinator/joint_state" -_T_ACTION_CHUNK = "/learning/action_chunk" +# Stable topics so external tools (lcmspy, dimos topic echo) work without rebuild. +_T_COLOR_IMAGE = "/camera/color_image" +_T_JOINT_STATE = "/coordinator/joint_state" +_T_ACTION_CHUNK = "/learning/action_chunk" +_T_JOINT_COMMAND = "/teleop/joint_command" # matches coordinator_servo_* default + +_INFER_TRANSPORTS = { + ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), + ("joint_state", JointState): LCMTransport(_T_JOINT_STATE, JointState), + ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), + ("joint_command", JointState): LCMTransport(_T_JOINT_COMMAND, JointState), +} -learning_infer_xarm7 = autoconnect( +learning_infer_chunkpolicy_only = autoconnect( RealSenseCamera.blueprint(enable_pointcloud=False), ChunkPolicyModule.blueprint( - policy_path="data/runs/act_pick_red", + policy_path="data/runs/act_pickplace_001", inference_rate_hz=30.0, ), - coordinator_teleop_xarm7, # TODO: replace with coordinator_action_replayer_xarm7 -).transports( - { - ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), - ("joint_state", JointState): LCMTransport(_T_JOINT_STATE, JointState), - ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), - } -) +).transports(_INFER_TRANSPORTS) -__all__ = ["learning_infer_xarm7"] +__all__ = 
["learning_infer_chunkpolicy_only"] diff --git a/dimos/learning/inference/chunk_policy_module.py b/dimos/learning/inference/chunk_policy_module.py index 811dd136f7..f27e736345 100644 --- a/dimos/learning/inference/chunk_policy_module.py +++ b/dimos/learning/inference/chunk_policy_module.py @@ -12,20 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ACT inference Module @ ~30 Hz. +"""ACT inference Module — runs the policy in a background thread. -Reads `/dimos_meta.json` at start() to recover obs schema -(StreamField map + sync). Latches In ports; calls predict_chunk on the -freshest snapshot every tick; publishes ActionChunk over LCM. +Reads ``/dimos_meta.json`` at start() to recover the obs +schema (StreamField map + sync). Latches In ports; calls +``policy.predict_chunk`` on the freshest snapshot every tick; publishes +each result as an ActionChunk over LCM. -Heavy ML deps (`lerobot`, `torch`) imported lazily inside `start()` — -this file is import-light. +When ``publish_joint_command=True`` (default), also emits the first +action of each chunk as a JointState on ``joint_command``. This lets a +coordinator's servo task consume the policy directly without an +``ActionReplayer`` in the 100 Hz tick loop. Once ActionReplayer is wired +in, callers can ignore that port. + +Heavy ML deps (``lerobot``, ``torch``) are imported lazily inside +``start()`` so this file is import-light. 
""" from __future__ import annotations +import json import threading import time +import traceback +from pathlib import Path from typing import Any import numpy as np @@ -33,16 +43,21 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.learning.dataprep import StreamField, SyncConfig +from dimos.learning.dataprep import StreamField, SyncConfig, resolve_field from dimos.learning.policy.base import ActionChunk, Policy from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.JointState import JointState +DIMOS_META_FILENAME = "dimos_meta.json" + class ChunkPolicyModuleConfig(ModuleConfig): policy_path: str inference_rate_hz: float = 30.0 device: str = "cuda" + # Emit chunk[0] as a JointState on `joint_command` each tick. Useful when + # the downstream coordinator has a servo task but no ActionReplayer task. + publish_joint_command: bool = True class ChunkPolicyModule(Module): @@ -51,10 +66,10 @@ class ChunkPolicyModule(Module): color_image: In[Image] joint_state: In[JointState] action_chunk: Out[ActionChunk] + joint_command: Out[JointState] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - # Latched live messages self._latest_image: Image | None = None self._latest_joint_state: JointState | None = None self._latch_lock = threading.Lock() @@ -64,56 +79,156 @@ def __init__(self, **kwargs: Any) -> None: self._observation: dict[str, StreamField] = {} self._sync: SyncConfig | None = None self._chunk_id: int = 0 + self._last_chunk_ts: float | None = None self._thread: threading.Thread | None = None self._stop = threading.Event() @rpc def start(self) -> None: - """Lazy-import LeRobotPolicy; load checkpoint; read dimos_meta.json - for observation/sync; subscribe to ports; spawn the loop thread.""" - raise NotImplementedError + """Load checkpoint, read dimos_meta, subscribe ports, spawn loop thread.""" + super().start() + + meta_path = 
Path(self.config.policy_path) / DIMOS_META_FILENAME + if not meta_path.exists(): + raise FileNotFoundError(f"Missing {DIMOS_META_FILENAME} in {self.config.policy_path}") + with open(meta_path) as f: + meta = json.load(f) + + self._observation = {k: StreamField(**v) for k, v in (meta.get("observation") or {}).items()} + sync_cfg = meta.get("sync") or {} + if sync_cfg: + self._sync = SyncConfig(**sync_cfg) + + from dimos.learning.policy.lerobot_policy import LeRobotPolicy + self._policy = LeRobotPolicy.load(self.config.policy_path, device=self.config.device) + + self.color_image.subscribe(self._on_color_image) + self.joint_state.subscribe(self._on_joint_state) + + self._stop.clear() + self._thread = threading.Thread(target=self._run_loop, daemon=True) + self._thread.start() @rpc def stop(self) -> None: - raise NotImplementedError + self._stop.set() + t = self._thread + if t is not None and t.is_alive(): + t.join(timeout=2.0) + self._thread = None + super().stop() @rpc def reload_policy(self, policy_path: str, device: str | None = None) -> None: """Hot-swap the checkpoint without restarting the blueprint.""" - raise NotImplementedError + self.stop() + self.config.policy_path = policy_path + if device is not None: + self.config.device = device + self.start() @rpc def get_status(self) -> dict[str, Any]: - """{'running', 'chunk_count', 'policy_path', 'last_chunk_ts'}.""" - raise NotImplementedError + return { + "running": self._thread is not None and self._thread.is_alive(), + "chunk_count": self._chunk_id, + "policy_path": self.config.policy_path, + "last_chunk_ts": self._last_chunk_ts, + } + + # ── port handlers ──────────────────────────────────────────────────────── + + def _on_color_image(self, msg: Image) -> None: + with self._latch_lock: + self._latest_image = msg + + def _on_joint_state(self, msg: JointState) -> None: + with self._latch_lock: + self._latest_joint_state = msg + + # ── loop ───────────────────────────────────────────────────────────────── def 
_run_loop(self) -> None: + assert self._policy is not None period = 1.0 / self.config.inference_rate_hz while not self._stop.is_set(): t0 = time.monotonic() - obs = self._build_live_obs() - if obs is None: - time.sleep(period) - continue - - positions = self._policy.predict_chunk(obs) # (T, action_dim) - self.action_chunk.publish( - ActionChunk( - ts=time.time(), + try: + obs = self._build_live_obs() + if obs is None: + self._stop.wait(timeout=period) + continue + + positions = self._policy.predict_chunk(obs) # (T, action_dim) + chunk_ts = time.time() + self.action_chunk.publish(ActionChunk( + ts=chunk_ts, joint_names=self._policy.joint_names, positions=positions, dt=period, chunk_id=self._next_chunk_id(), - ) - ) - time.sleep(max(0.0, period - (time.monotonic() - t0))) + )) + self._last_chunk_ts = chunk_ts + + if self.config.publish_joint_command: + js = JointState( + name=self._policy.joint_names, + position=[float(x) for x in positions[0]], + velocity=[], + ) + js.ts = chunk_ts + self.joint_command.publish(js) + except Exception: + # Single bad tick must not kill the loop. + traceback.print_exc() + + elapsed = time.monotonic() - t0 + if elapsed < period: + self._stop.wait(timeout=period - elapsed) def _build_live_obs(self) -> dict[str, np.ndarray] | None: - """Snapshot latched messages under a lock, project each obs key - through `resolve_field` using `self._observation`. Returns - None if any required stream hasn't received a message yet.""" - raise NotImplementedError + """Snapshot latched messages and project each obs key through `resolve_field`. + Returns None if any required stream hasn't received a message yet. + """ + with self._latch_lock: + latest_image = self._latest_image + latest_joints = self._latest_joint_state + + if not self._observation: + # No spec — fall back to canonical port names. 
+ if latest_image is None or latest_joints is None: + return None + return { + "image": np.asarray(latest_image.data), + "joint_state": np.asarray(latest_joints.position), + } + + out: dict[str, np.ndarray] = {} + for obs_key, sf in self._observation.items(): + port = self._guess_port(sf.stream) + if port == "color_image": + if latest_image is None: + return None + out[obs_key] = resolve_field(latest_image, sf) + elif port == "joint_state": + if latest_joints is None: + return None + out[obs_key] = resolve_field(latest_joints, sf) + else: + # Extend here when adding In ports for new sensor types. + return None + return out + + @staticmethod + def _guess_port(stream_name: str) -> str: + """Route a recorded stream name to one of this module's In ports.""" + n = stream_name.lower() + if "image" in n or "camera" in n or "rgb" in n: + return "color_image" + if "joint_state" in n: + return "joint_state" + return n def _next_chunk_id(self) -> int: cid = self._chunk_id diff --git a/dimos/learning/policy/lerobot_policy.py b/dimos/learning/policy/lerobot_policy.py index c6d764ed23..a14675d7c6 100644 --- a/dimos/learning/policy/lerobot_policy.py +++ b/dimos/learning/policy/lerobot_policy.py @@ -16,6 +16,7 @@ from __future__ import annotations +import json from pathlib import Path from typing import Any @@ -23,10 +24,13 @@ from dimos.learning.policy.base import Policy +DIMOS_META_FILENAME = "dimos_meta.json" + class LeRobotPolicy: - _model: Any # lerobot.policies.pretrained.PreTrainedPolicy + _model: Any # lerobot ACTPolicy / PreTrainedPolicy _stats: dict[str, Any] + _dimos_meta: dict[str, Any] _chunk_size: int _joint_names: list[str] _device: str @@ -35,16 +39,66 @@ def __init__( self, model: Any, stats: dict[str, Any], + dimos_meta: dict[str, Any], chunk_size: int, joint_names: list[str], device: str, ) -> None: - raise NotImplementedError + self._model = model + self._stats = stats + self._dimos_meta = dimos_meta + self._chunk_size = chunk_size + self._joint_names = 
joint_names + self._device = device @classmethod def load(cls, path: str | Path, device: str = "cuda") -> LeRobotPolicy: - """Load checkpoint dir: model.safetensors + meta/stats.json + dimos_meta.json.""" - raise NotImplementedError + """Load checkpoint + dataset stats + dimos_meta sidecar. + + ``path`` may be a TrainerModule output dir (we walk into + ``checkpoints/last/pretrained_model/``) or an exact + ``pretrained_model`` dir. + """ + path = Path(path) + pretrained_dir, run_dir = _resolve_checkpoint_dirs(path) + + meta_path = run_dir / DIMOS_META_FILENAME + if not meta_path.exists(): + raise FileNotFoundError(f"Missing {DIMOS_META_FILENAME} in {run_dir}") + with open(meta_path) as f: + dimos_meta = json.load(f) + + stats_path = _find_stats(run_dir, dimos_meta) + with open(stats_path) as f: + stats = json.load(f) + + # Lazy import — keeps torch/CUDA out of the dimos runtime at module load. + try: + from lerobot.policies.act.modeling_act import ACTPolicy + except ImportError: + try: + from lerobot.common.policies.act.modeling_act import ACTPolicy + except ImportError as e: + raise RuntimeError( + "lerobot is required to load a checkpoint; install with " + "`pip install lerobot` (>=0.3)" + ) from e + + model = ACTPolicy.from_pretrained(str(pretrained_dir)) + model.eval() + model.to(device) + + chunk_size = int(dimos_meta.get("chunk_size", 50)) + joint_names = dimos_meta.get("joint_names") or _infer_joint_names(model) + + return cls( + model=model, + stats=stats, + dimos_meta=dimos_meta, + chunk_size=chunk_size, + joint_names=joint_names, + device=device, + ) @property def chunk_size(self) -> int: @@ -55,9 +109,140 @@ def joint_names(self) -> list[str]: return self._joint_names def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: - """Normalize obs → forward pass → unnormalize → (chunk_size, action_dim).""" - raise NotImplementedError + """Forward pass on ``obs``. Returns shape ``(chunk_size, action_dim)``. 
+ + ``obs`` keys are the dataset spec's observation keys (e.g. "image", + "joint_state"). They are translated to lerobot's canonical names + (``observation.images.*`` / ``observation.state``) using the + dimos_meta sidecar so train and infer agree. + """ + import torch # lazy + + batch = self._build_batch(obs) + with torch.inference_mode(): + chunk = self._forward_chunk(batch) + if chunk.ndim == 3: + chunk = chunk[0] # (B, T, A) → (T, A) + return chunk.detach().cpu().numpy() + + # ── internals ──────────────────────────────────────────────────────────── + + def _build_batch(self, obs: dict[str, np.ndarray]) -> dict[str, Any]: + import torch + + batch: dict[str, Any] = {} + observation_map = self._dimos_meta.get("observation", {}) + for user_key, value in obs.items(): + arr = np.asarray(value) + if arr.ndim >= 3: + # HWC uint8 → 1xCxHxW float32 / 255 (lerobot's expected layout). + chw = np.transpose(arr, (2, 0, 1)) if arr.shape[-1] in (1, 3, 4) else arr + t = torch.from_numpy(chw.astype(np.float32) / 255.0).unsqueeze(0) + feat = f"observation.images.{user_key}" + else: + t = torch.from_numpy(arr.astype(np.float32)).unsqueeze(0) + # Single low-dim observation is canonical "observation.state". 
+ low_dim_other = any( + k != user_key and k in obs and np.asarray(obs[k]).ndim < 3 + for k in observation_map + ) + feat = f"observation.{user_key}" if low_dim_other else "observation.state" + batch[feat] = t.to(self._device) + return batch + + def _forward_chunk(self, batch: dict[str, Any]) -> Any: + """Prefer ``predict_action_chunk`` (newer API); fall back to repeated + ``select_action`` after ``reset()`` to assemble a chunk.""" + if hasattr(self._model, "predict_action_chunk"): + return self._model.predict_action_chunk(batch) + if hasattr(self._model, "select_action"): + import torch + + self._model.reset() + actions = [self._model.select_action(batch) for _ in range(self._chunk_size)] + return torch.stack(actions, dim=1) # (B, T, A) + raise RuntimeError("lerobot policy has neither predict_action_chunk nor select_action") + + +def _resolve_checkpoint_dirs(path: Path) -> tuple[Path, Path]: + """Return ``(pretrained_model_dir, run_dir)`` for any supported input path. + + Run dir layout (lerobot 0.3+):: + + / + dimos_meta.json + checkpoints//pretrained_model/ # lerobot safetensors + checkpoints/last -> symlink to latest + """ + if (path / "model.safetensors").exists(): + # `…/checkpoints//pretrained_model` → run_dir is 3 parents up. + return path, path.parent.parent.parent + + last = path / "checkpoints" / "last" / "pretrained_model" + if last.exists(): + return last, path + + ckpts = path / "checkpoints" + if ckpts.is_dir(): + numeric = sorted( + (p for p in ckpts.iterdir() if p.is_dir() and p.name.isdigit()), + key=lambda p: int(p.name), + ) + if numeric and (numeric[-1] / "pretrained_model").exists(): + return numeric[-1] / "pretrained_model", path + + raise FileNotFoundError( + f"No lerobot checkpoint found under {path}. " + f"Expected {path}/checkpoints/last/pretrained_model/ " + f"or a numeric checkpoint subdir." + ) + + +def _find_stats(run_dir: Path, dimos_meta: dict[str, Any]) -> Path: + """Locate ``stats.json`` near a checkpoint. + + Lookup order: + 1. 
``/meta/stats.json`` + 2. dimos_meta's recorded ``dataset_path`` / ``source`` + 3. ``/../datasets//meta/stats.json`` (sibling convention) + """ + candidates: list[Path] = [ + run_dir / "meta" / "stats.json", + run_dir / "stats.json", + ] + metadata = dimos_meta.get("metadata") or {} + for key in ("dataset_path", "source"): + v = metadata.get(key) or dimos_meta.get(key) + if v and Path(v).suffix not in (".db", ".sqlite"): + candidates.append(Path(v) / "meta" / "stats.json") + + for parent in (run_dir.parent, run_dir.parent.parent): + if parent and (parent / "datasets").is_dir(): + for d in (parent / "datasets").iterdir(): + if (d / "meta" / "stats.json").is_file(): + candidates.append(d / "meta" / "stats.json") + + for c in candidates: + if c.exists(): + return c + raise FileNotFoundError(f"stats.json not found near {run_dir}; tried: {candidates}") + + +def _infer_joint_names(model: Any) -> list[str]: + """Synthetic joint-name fallback when dimos_meta didn't record any.""" + cfg = getattr(model, "config", None) + action_dim: int | None = None + if cfg is not None: + out_shapes = getattr(cfg, "output_shapes", None) or {} + if "action" in out_shapes: + action_dim = out_shapes["action"][-1] + if action_dim is None: + af = getattr(cfg, "action_feature", None) + action_dim = getattr(af, "shape", [None])[-1] if af is not None else None + if action_dim is None: + action_dim = 7 + return [f"joint{i}" for i in range(action_dim)] -# Protocol conformance check at import time. +# Protocol conformance assertion at import time. _: type[Policy] = LeRobotPolicy diff --git a/dimos/learning/specs/spec_v2.md b/dimos/learning/specs/spec_v2.md new file mode 100644 index 0000000000..cd4518aefe --- /dev/null +++ b/dimos/learning/specs/spec_v2.md @@ -0,0 +1,375 @@ +# Stage 1 — Data (v2) + +Two phases: + +1. **Recording** — live; teleop + camera + `EpisodeMonitorModule` produce + LCM streams. 
`RecordReplay` (CLI flag) captures every active topic + into `session.db` (memory2 `SqliteStore` format). +2. **DataPrep** — offline; `DataPrepModule` reads `session.db` and writes + a training-ready dataset on disk in one of three formats (LeRobot v2, + HDF5, RLDS). + +Same code paths drive both phases as DimOS blueprints. Format dispatch is +config-only; the same module backs every output format. + +--- + +## Phase A — Recording + +### Blueprint + +```python +# dimos/learning/collection/blueprint.py +_DEFAULT_BUTTON_MAP = {"start": "A", "save": "B", "discard": "X"} +_TRANSPORTS = { + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), +} + +learning_collect_quest_xarm7 = autoconnect( + teleop_quest_xarm7, + RealSenseCamera.blueprint(enable_pointcloud=False), + EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), +).transports(_TRANSPORTS) + +# Variants (one per arm config) follow the same pattern: +# learning_collect_quest_xarm6 +# learning_collect_quest_piper +# learning_collect_quest_dual +``` + +`RecordReplay` (`--record-path`) captures every transport above into +`session.db`. Recording is a transport-layer hook, not a Module — every +LCM stream is recorded uniformly. + +### EpisodeMonitorModule + +Translates teleop input (Quest buttons, optional keyboard) into the +canonical `EpisodeStatus` stream. DataPrep reads only this stream, never +raw button presses. 
+ +```python +# dimos/learning/collection/episode_monitor.py + +class EpisodeStatus(BaseModel): + state: Literal["idle", "recording"] + episodes_saved: int + episodes_discarded: int + current_episode_start_ts: float | None + last_event: Literal["start", "save", "discard", "init"] = "init" + task_label: str | None = None + + +class KeyPress(BaseModel): + key: str + ts: float + + +class EpisodeMonitorModuleConfig(ModuleConfig): + button_map: dict[Literal["start", "save", "discard"], str] = {"start": "A", "save": "B", "discard": "X"} + keyboard_map: dict[Literal["start", "save", "discard"], str] = {} + default_task_label: str | None = None + + +class EpisodeMonitorModule(Module): + config: EpisodeMonitorModuleConfig + + buttons: In[Buttons] + keyboard: In[KeyPress] + status: Out[EpisodeStatus] + + @rpc + def reset_counters(self) -> EpisodeStatus: ... + @rpc + def get_status(self) -> EpisodeStatus: ... +``` + +State machine (mirrored offline by `DataPrep.extract_episodes`): + +``` +IDLE --start--> RECORDING +RECORDING --save--> IDLE (commit, saved += 1) +RECORDING --discard--> IDLE (drop, discarded += 1) +RECORDING --start--> RECORDING (auto-commit prev, begin new) +session end mid-episode: always discarded +``` + +Friendly button names (`A`/`B`/`X`/...) resolve to `Buttons` attributes +via `BUTTON_ALIASES` (e.g. `"A"` → `right_primary`). Override with the +attribute name directly in `button_map`. 
+ +### Run + +```bash +dimos run learning-collect-quest-xarm7 --record-path data/recordings/pick_red.db +``` + +--- + +## Phase B — DataPrep + +### Blueprints + +```python +# dimos/learning/dataprep_blueprint.py + +learning_dataprep = autoconnect( + DataPrepModule.blueprint( + source="data/recordings/pickplace_001.db", + episodes=EpisodeExtractor(extractor="ranges", ranges=[(t0, t1)]), + observation={ + "image": StreamField(stream="color_image", field="data"), + "joint_state": StreamField(stream="joint_state", field="position"), + }, + action={ + "joint_target": StreamField(stream="joint_state", field="position"), + }, + sync=SyncConfig(anchor="image", rate_hz=14.0, tolerance_ms=80.0), + output=OutputConfig( + format="lerobot", + path="data/datasets/pickplace_001", + metadata={"fps": 14, "robot": "xarm7", "default_task_label": "pick_and_place"}, + ), + auto_run=True, + ), +).transports({}) + + +# Variant for one-demo-per-file recordings (no episode_status stream). +learning_dataprep_whole_session = autoconnect( + DataPrepModule.blueprint( + source="data/session.db", + episodes=EpisodeExtractor(extractor="whole_session"), + observation={...}, + action={...}, + sync=SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0), + output=OutputConfig(format="lerobot", path="data/datasets/default", + metadata={"fps": 30, "robot": "xarm7"}), + auto_run=True, + ), +).transports({}) +``` + +All `DataPrepModuleConfig` fields are defaulted — the DimOS CLI's +per-module override path validates user kwargs in isolation, so +required-without-default fields would reject partial `-o ...` overrides. +Real values come from the blueprint atom; CLI flags overlay on top. 
+ +### DataPrepModule + +```python +# dimos/learning/dataprep.py +from dimos.protocol.service.spec import BaseConfig + + +class EpisodeExtractor(BaseConfig): + extractor: Literal["episode_status", "ranges", "whole_session"] = "episode_status" + status_stream: str = "episode_status" + ranges: list[tuple[float, float]] | None = None + + +class StreamField(BaseConfig): + stream: str + field: str | None = None + + +class SyncConfig(BaseConfig): + anchor: str + rate_hz: float + tolerance_ms: float + strategy: Literal["nearest", "interp"] = "nearest" + + +class OutputConfig(BaseConfig): + format: Literal["lerobot", "hdf5", "rlds"] = "lerobot" + path: Path + metadata: dict[str, Any] = {} +``` + +```python +# dimos/learning/dataprep_module.py + +class DataPrepModuleConfig(ModuleConfig): + source: str = "" + episodes: EpisodeExtractor = EpisodeExtractor() + observation: dict[str, StreamField] = {} + action: dict[str, StreamField] = {} + sync: SyncConfig = SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0) + output: OutputConfig = OutputConfig(format="lerobot", path="data/datasets/default") + auto_run: bool = False + + +class DataPrepModule(Module): + config: DataPrepModuleConfig + + @rpc + def build(self) -> None: ... # spawns build thread; returns immediately + @rpc + def get_status(self) -> dict[str, Any]: ... # state, current_phase, progress_pct, samples_seen, error + @rpc + def inspect(self) -> dict[str, Any]: ... # streams, episode counts, duration distribution +``` + +`build()` opens the `SqliteStore`, walks samples episode-by-episode, +hands them to the configured format writer, and snapshots the spec +(`config.model_dump()`) into `<output.path>/dimos_meta.json`. The build +thread is a daemon — there is no mid-iteration cancel. + +### Pure helpers (in `dataprep.py`) + +Stateless, importable without booting a Module. Reused by every format +writer **and** by `ChunkPolicyModule._build_live_obs` at inference time +(single source of truth for obs construction).
+ +```python +def resolve_field(msg: Any, ref: StreamField) -> np.ndarray: ... + +def extract_episodes(store: SqliteStore, cfg: EpisodeExtractor) -> list[Episode]: + """ + episode_status: replay EpisodeMonitorModule's state machine over + the recorded EpisodeStatus events. + ranges: emit one Episode per (start, end) tuple. + whole_session: one Episode covering every stream's combined time range. + """ + +def iter_episode_samples( + store: SqliteStore, + episode: Episode, + streams: dict[str, StreamField], # observation ∪ action + sync: SyncConfig, + obs_keys: set[str] | None = None, + action_keys: set[str] | None = None, +) -> Iterator[Sample]: + """ + Anchor-rate timestep walker. Caches each stream once per episode, + bisect-nearest within tolerance_ms; skips frames where any required + stream lacks a nearby sample. + """ + +def compute_stats( + samples: Iterator[Sample], + image_subsample: int = 10, + quantile_reservoir: int = 10_000, + seed: int = 0, +) -> dict[str, Any]: + """Welford mean/std + reservoir quantiles. Image features (≥3D) + subsampled and reduced to per-channel summaries.""" +``` + +### Format writers (`dimos/learning/formats/`) + +All three writers consume `Iterator[Sample] + OutputConfig` and accumulate +stats via the shared `StreamingStats` (`formats/_stats.py`) so format- +agnostic stats logic exists in exactly one place. + +```python +# dimos/learning/formats/_stats.py +class StreamingStats: + def __init__(self, image_subsample=10, quantile_reservoir=10_000, seed=0): ... + def update(self, name: str, value: np.ndarray) -> None: ... + def finalize(self) -> dict[str, dict[str, Any]]: ... 
+``` + +| Format | Layout | Heavy dep | +|---|---|---| +| `lerobot` | `meta/{info,episodes,tasks,stats}.json` + `data/chunk-000/episode_NNNNNN.parquet` + `videos/chunk-000/observation.images.<key>/episode_NNNNNN.mp4` | `pyarrow`, `opencv-python` | +| `hdf5` | single `.hdf5` with `/episodes/episode_NNNNNN/{timestamp, observation/, action/}` + `/stats/` + `/tasks` + root attrs | `h5py` | +| `rlds` | `rlds-NNNNN-of-MMMMM.tfrecord` (one `SequenceExample` per episode, RLDS step protocol) + `features.json` + `dataset_info.json` | `tensorflow` | + +#### LeRobot v2 specifics + +- **Image columns are NOT in the parquet** — lerobot's + `get_hf_features_from_features` skips dtype="video" and reads frames + from MP4 at `__getitem__` time. +- **Timestamps are episode-relative** (subtract `episode.start_ts`) + because lerobot stores `timestamp` as float32 and validates frame-to- + frame deltas against `1/fps`. Absolute Unix epoch values would collide + in float32. +- **Feature naming follows lerobot convention** — single low-dim obs ⇒ + `observation.state`, single action key ⇒ `action`, image keys ⇒ + `observation.images.<key>`. +- **`info.json` features include per-dim `names` lists** (required by + lerobot 0.3+). + +### dimos_meta.json sidecar + +Written into every dataset directory; describes how it was built. Used +downstream by training (copies + adds policy fields) and by inference +(reads it at `start()` to recover the obs schema — no operator-supplied +spec path).
+ +```json +{ + "source": "data/recordings/pickplace_001.db", + "observation": {"image": {...}, "joint_state": {...}}, + "action": {"joint_target": {...}}, + "sync": {"anchor": "image", "rate_hz": 14.0, ...}, + "episodes": [{"id": "ep_000000", "start_ts": ..., "end_ts": ..., "task_label": ...}], + "format": "lerobot", + "metadata": {"fps": 14, "robot": "xarm7", ...} +} +``` + +### Run + +```bash +dimos run learning-dataprep +``` + +Override per run: + +```bash +dimos run learning-dataprep \ + -o dataprepmodule.source=data/recordings/foo.db \ + -o dataprepmodule.output.path=data/datasets/foo \ + -o dataprepmodule.output.format=hdf5 +``` + +For complex nested overrides (observation/action stream maps), use a JSON +config: + +```bash +dimos run learning-dataprep -c data/foo_dataset.json +``` + +--- + +## End-to-end + +```bash +dimos run learning-collect-quest-xarm7 --record-path data/recordings/pick_red.db +dimos run learning-dataprep \ + -o dataprepmodule.source=data/recordings/pick_red.db \ + -o dataprepmodule.output.path=data/datasets/pick_red +``` + +``` +data/recordings/pick_red.db ─► data/datasets/pick_red/ + ├── data/ (parquet) ─┐ + ├── videos/ (MP4) ├─ format=lerobot + └── meta/ (info, episodes, ─┘ + tasks, stats) + └── dimos_meta.json (always) +``` + +--- + +## Compatibility note — JpegCodec + +Recordings made before commit `<commit-sha>` ship Image blobs with a +1-byte format tag (`b'J'`) preceding the LCM envelope. `JpegCodec.decode` +strips it transparently so old + new sessions both read cleanly. Affects +any consumer of recorded JPEG-encoded `Image` streams, not just learning.
+ +--- + +## Module / non-Module split for Stage 1 + +| Component | Type | Why | +|---|---|---| +| `EpisodeMonitorModule` | `Module` | Long-lived; subscribes to teleop input; publishes status | +| `DataPrepModule` | `Module` | Long-running build job; thread + `get_status` RPC | +| `RecordReplay` | transport hook | Captures every stream uniformly; not a per-Module concern | +| `StreamingStats` | helper class | No lifecycle, no I/O — pure accumulator | +| `extract_episodes` / `iter_episode_samples` / `resolve_field` / `compute_stats` | functions | Pure helpers; reused by inference | diff --git a/dimos/learning/training/blueprint.py b/dimos/learning/training/blueprint.py index cb170f45b4..2bfd96d0dd 100644 --- a/dimos/learning/training/blueprint.py +++ b/dimos/learning/training/blueprint.py @@ -12,18 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ACT training blueprint. RPC-only surface (no streams).""" +"""ACT training blueprint. RPC-only surface (no streams). + +Defaults are tuned for the local pickplace_001 demo: 2k steps, batch=4, CPU. 
+For real training, override via: + dimos run learning-train -o trainermodule.bc.steps=100000 -o trainermodule.bc.device=cuda +""" from __future__ import annotations from dimos.core.coordination.blueprints import autoconnect +from dimos.learning.training.configs import BCConfig from dimos.learning.training.trainer_module import TrainerModule learning_train = autoconnect( TrainerModule.blueprint( - dataset_path="data/datasets/pick_red/", - output_dir="data/runs/act_pick_red", + dataset_path="data/datasets/pickplace_001", + output_dir="data/runs/act_pickplace_001", + bc=BCConfig( + steps=2000, + batch_size=4, + device="cpu", + ), auto_run=True, + overwrite=True, ), ).transports({}) diff --git a/dimos/learning/training/train.py b/dimos/learning/training/train.py index 57a46d1f91..4252aa3a7d 100644 --- a/dimos/learning/training/train.py +++ b/dimos/learning/training/train.py @@ -12,18 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ACT training entry point. Called directly by TrainerModule. +"""ACT training entry point. -`train_bc` lazy-imports lerobot/torch, builds Hydra-style argv from -BCConfig, calls lerobot's trainer in-process, appends `dimos_meta.json` -to output_dir, returns the checkpoint path. - -Stats live at `/meta/stats.json` (written by DataPrep). Training -reads them via lerobot's loader; never recomputes. +`train_bc` subprocesses ``python -m lerobot.scripts.train`` with argv +translated from `BCConfig`. Lerobot is never imported in-process so the +dimos runtime stays free of torch/CUDA. After a successful run we write +``dimos_meta.json`` next to the checkpoint so `LeRobotPolicy.load` can +recover the obs/action schema. 
""" from __future__ import annotations +import argparse +import json +import random +import shutil +import subprocess +import sys from pathlib import Path from typing import Any @@ -38,20 +43,122 @@ def train_bc( cfg: BCConfig, output_dir: str | Path, config_overrides: dict[str, Any] | None = None, + overwrite: bool = True, + resume: bool = False, ) -> Path: - """Train ACT on a prepared dataset. Returns checkpoint dir.""" - raise NotImplementedError + """Train ACT on a prepared LeRobot v2 dataset. Returns the checkpoint dir. + Args: + overwrite: if True (default) wipes ``output_dir`` before launching. + Lerobot's ``cfg.validate()`` refuses to run if the dir exists. + resume: pass ``--resume=true`` to lerobot. Takes precedence over + ``overwrite``. + """ + dataset_path = Path(dataset_path) + output_dir = Path(output_dir) + + if resume: + overwrite = False + elif overwrite and output_dir.exists(): + print(f"[train_bc] removing existing {output_dir}", flush=True) + shutil.rmtree(output_dir) + + argv = _build_lerobot_argv(cfg, dataset_path, output_dir) + if resume: + argv.append("--resume=true") + if config_overrides: + for k, v in config_overrides.items(): + argv.append(f"--{k}={v}") + + print(f"[train_bc] launching lerobot ({len(argv)} args, output → {output_dir})", flush=True) + # Log alongside output_dir, not inside — lerobot creates output_dir itself + # and its validate() refuses if it already exists. + output_dir.parent.mkdir(parents=True, exist_ok=True) + log_path = output_dir.parent / f"{output_dir.name}.lerobot.log" + + # Stream lerobot stdout to a full log file + filtered terminal output. 
+ interesting = ( + "step:", "loss:", "Logs will be", "Creating dataset", "Output dir", + "Saved checkpoint", "epoch", "loss", "lr=", "ETA", "ERROR", "Error", + "WARNING", "Traceback", + ) + proc = subprocess.Popen(argv, stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, text=True, bufsize=1) + with open(log_path, "w") as logf: + assert proc.stdout is not None + for line in proc.stdout: + logf.write(line) + logf.flush() + if any(kw in line for kw in interesting): + sys.stdout.write(line) + sys.stdout.flush() + proc.wait() + if proc.returncode != 0: + print(f"[train_bc] lerobot exited {proc.returncode} — full log: {log_path}", flush=True) + raise subprocess.CalledProcessError(proc.returncode, argv) + + _write_dimos_meta(output_dir, dataset_path, cfg) + return output_dir -def _build_lerobot_argv(cfg: BCConfig, dataset_path: Path, output_dir: Path) -> list[str]: - """Translate BCConfig → Hydra-style CLI args for lerobot's trainer.""" - raise NotImplementedError +def _build_lerobot_argv(cfg: BCConfig, dataset_path: Path, output_dir: Path) -> list[str]: + """Translate BCConfig → argv for ``lerobot.scripts.train`` (lerobot 0.3.x). -def _write_dimos_meta(output_dir: Path, dataset_path: Path) -> None: - """Read /dimos_meta.json, add policy fields - (joint_names, chunk_size, policy_type), write to /dimos_meta.json.""" - raise NotImplementedError + LeRobot 0.4.x renamed the entry point to ``lerobot.scripts.lerobot_train`` + and adjusted some draccus flag names — adjust this function if you pin + a different version. 
+ """ + return [ + sys.executable, "-m", "lerobot.scripts.train", + f"--policy.type={cfg.policy_type}", + f"--policy.chunk_size={cfg.chunk_size}", + f"--policy.n_action_steps={cfg.chunk_size}", + f"--policy.n_obs_steps={cfg.n_obs_steps}", + f"--policy.dim_model={cfg.hidden_dim}", + f"--policy.n_encoder_layers={cfg.n_layers}", + f"--policy.n_decoder_layers={cfg.n_layers}", + f"--policy.n_heads={cfg.n_heads}", + f"--policy.use_vae={str(cfg.use_vae).lower()}", + f"--policy.kl_weight={cfg.kl_weight}", + f"--policy.vision_backbone={cfg.vision_backbone}", + f"--policy.pretrained_backbone_weights={'ResNet18_Weights.IMAGENET1K_V1' if cfg.pretrained else 'null'}", + f"--policy.device={cfg.device}", + # push_to_hub defaults True in lerobot and triggers a repo_id requirement. + "--policy.push_to_hub=false", + "--dataset.repo_id=local", + f"--dataset.root={dataset_path}", + f"--steps={cfg.steps}", + f"--batch_size={cfg.batch_size}", + f"--optimizer.lr={cfg.lr}", + f"--optimizer.weight_decay={cfg.weight_decay}", + f"--save_freq={cfg.save_every}", + f"--eval_freq={cfg.eval_every}", + "--wandb.enable=false", + f"--seed={cfg.seed}", + f"--output_dir={output_dir}", + # Note: do NOT pass --env — its choice-class decoder rejects "none"; + # leaving it unset disables eval cleanly. + ] + + +def _write_dimos_meta(output_dir: Path, dataset_path: Path, cfg: BCConfig) -> None: + """Write the inference sidecar at ``/dimos_meta.json``. + + Combines the dataset's dimos_meta (obs/action streams, sync) with policy + fields (type, chunk_size, n_obs_steps) and the dataset_path so + `LeRobotPolicy.load` can resolve `meta/stats.json`. 
+ """ + src = dataset_path / DIMOS_META_FILENAME + base: dict[str, Any] = json.load(open(src)) if src.exists() else {} + base.update({ + "dataset_path": str(dataset_path), + "policy_type": cfg.policy_type, + "chunk_size": cfg.chunk_size, + "n_obs_steps": cfg.n_obs_steps, + "joint_names": base.get("joint_names"), # often None; inference falls back + }) + with open(output_dir / DIMOS_META_FILENAME, "w") as f: + json.dump(base, f, indent=2, default=str) def train_val_split( @@ -60,11 +167,69 @@ def train_val_split( val_ratio: float | None = None, seed: int = 0, ) -> tuple[list[int], list[int]]: - """Partition `episodes` indices into (train_ids, val_ids). + """Partition episode indices into (train_ids, val_ids). - Resolution order (first non-None wins): - 1. `val_episode_ids` — explicit whitelist - 2. `val_ratio` — deterministic random split via `seed` - 3. both None — empty val (everything is train) + Resolution order: ``val_episode_ids`` (whitelist) > ``val_ratio`` + (deterministic via ``seed``) > both None (everything in train). 
""" - raise NotImplementedError + n = len(episodes) + all_ids = list(range(n)) + + if val_episode_ids is not None: + val_set = set(val_episode_ids) + return ([i for i in all_ids if i not in val_set], + [i for i in all_ids if i in val_set]) + + if val_ratio is not None: + rng = random.Random(seed) + shuffled = all_ids[:] + rng.shuffle(shuffled) + n_val = int(round(n * val_ratio)) + return sorted(shuffled[n_val:]), sorted(shuffled[:n_val]) + + return all_ids, [] + + +# ───────────────────────────────────────────────────────────────────────────── +# CLI: `python -m dimos.learning.training.train bc --output ...` +# ───────────────────────────────────────────────────────────────────────────── + + +def main() -> None: + parser = argparse.ArgumentParser(prog="dimos.learning.training.train") + sub = parser.add_subparsers(dest="kind", required=True) + + p_bc = sub.add_parser("bc", help="Train an ACT (BC) policy") + p_bc.add_argument("dataset", help="path to LeRobot v2 dataset directory") + p_bc.add_argument("--output", required=True, help="checkpoint output directory") + p_bc.add_argument("--config", help="path to BCConfig JSON override") + p_bc.add_argument("--steps", type=int) + p_bc.add_argument("--batch-size", type=int) + p_bc.add_argument("--chunk-size", type=int) + p_bc.add_argument("--device", type=str) + p_bc.add_argument("-o", "--override", action="append", default=[], + help="extra lerobot CLI override, e.g. 
-o optimizer.lr=5e-5") + + args = parser.parse_args() + + if args.kind == "bc": + cfg_kwargs: dict[str, Any] = json.load(open(args.config)) if args.config else {} + for k, v in (("steps", args.steps), ("batch_size", args.batch_size), + ("chunk_size", args.chunk_size), ("device", args.device)): + if v is not None: + cfg_kwargs[k] = v + cfg = BCConfig(**cfg_kwargs) + + overrides: dict[str, Any] = {} + for o in args.override: + if "=" not in o: + parser.error(f"--override must be key=value, got {o!r}") + k, v = o.split("=", 1) + overrides[k] = v + + out = train_bc(args.dataset, cfg, args.output, config_overrides=overrides) + print(f"[train_bc] checkpoint at: {out}") + + +if __name__ == "__main__": + main() diff --git a/dimos/learning/training/trainer_module.py b/dimos/learning/training/trainer_module.py index d828dbeaf9..eb32e530e9 100644 --- a/dimos/learning/training/trainer_module.py +++ b/dimos/learning/training/trainer_module.py @@ -12,28 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ACT training Module (v1, inline). +"""ACT training Module — thin wrapper around `train_bc`. -Runs `train_bc` on a daemon thread inside its own worker. No subprocess. -Lazy imports keep `lerobot` / `torch` / CUDA out of the worker until -`train()` is called. Metrics → TensorBoard. No cancel() in v1. +Spawns `train_bc` on a daemon thread (which subprocesses +``python -m lerobot.scripts.train``). Exposes: + + @rpc start() lifecycle (auto-fires train if auto_run) + @rpc train(...) kick off a training job + @rpc get_status() current state + checkpoint dir + @rpc stop() best-effort shutdown + +There is no cancel(): the lerobot subprocess is sent SIGTERM only on +process exit. Heavy deps (torch, lerobot) stay in the subprocess. 
""" from __future__ import annotations +import json +import shutil +import subprocess import threading +import traceback +from pathlib import Path from typing import Any from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig +from dimos.learning.training.configs import BCConfig +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() class TrainerModuleConfig(ModuleConfig): dataset_path: str = "" output_dir: str = "" - config_path: str | None = None # optional BCConfig YAML override + # ACT hyperparams. CLI override pattern: + # -o trainermodule.bc.steps=2000 -o trainermodule.bc.batch_size=4 + bc: BCConfig = BCConfig() + # Optional JSON file with BCConfig overrides; merged on top of `bc`. + config_path: str | None = None auto_run: bool = False + overwrite: bool = True # wipe output_dir before training (lerobot refuses to overwrite) + resume: bool = False # pass --resume=true to lerobot + # Lerobot 0.3.x does not write tensorboard event files; the launch is a + # no-op there and shows an empty UI. Disabled until we wire a stdout + # parser → SummaryWriter on our side. + tensorboard: bool = False tensorboard_port: int = 6006 + tensorboard_host: str = "0.0.0.0" class TrainerModule(Module): @@ -44,18 +71,30 @@ def __init__(self, **kwargs: Any) -> None: self._thread: threading.Thread | None = None self._lock = threading.Lock() self._status: dict[str, Any] = { - "state": "idle", # idle | running | succeeded | failed - "checkpoint_dir": None, - "error": None, + "state": "idle", # idle | running | succeeded | failed + "checkpoint_dir": None, + "tensorboard_url": None, + "error": None, } + self._tb_proc: subprocess.Popen[bytes] | None = None @rpc def start(self) -> None: - raise NotImplementedError + super().start() + if self.config.auto_run: + self.train() @rpc def stop(self) -> None: - raise NotImplementedError + # Train thread is daemon: dies with the process. No mid-run interrupt. 
+ if self._tb_proc is not None and self._tb_proc.poll() is None: + logger.info("[trainer] stopping tensorboard pid=%s", self._tb_proc.pid) + self._tb_proc.terminate() + try: + self._tb_proc.wait(timeout=2.0) + except subprocess.TimeoutExpired: + self._tb_proc.kill() + super().stop() @rpc def train( @@ -64,13 +103,34 @@ def train( output_dir: str | None = None, config_overrides: dict[str, Any] | None = None, ) -> None: - """Spawn a daemon thread running train_bc; returns immediately. + """Spawn a daemon thread running ``train_bc``; returns immediately. Raises if a run is already in progress.""" - raise NotImplementedError + with self._lock: + if self._status["state"] == "running": + raise RuntimeError("training already in progress") + self._status.update(state="running", checkpoint_dir=None, error=None) + + ds = dataset_path or self.config.dataset_path + od = output_dir or self.config.output_dir + if not ds or not od: + with self._lock: + self._status.update(state="failed", + error="dataset_path and output_dir are required") + raise ValueError("dataset_path and output_dir are required") + + self._maybe_start_tensorboard(Path(od)) + + self._thread = threading.Thread( + target=self._run_training, + args=(ds, od, config_overrides), + daemon=True, + ) + self._thread.start() @rpc def get_status(self) -> dict[str, Any]: - raise NotImplementedError + with self._lock: + return dict(self._status) def _run_training( self, @@ -78,8 +138,70 @@ def _run_training( output_dir: str, config_overrides: dict[str, Any] | None, ) -> None: - """Thread target. 
Lazy-imports train_bc + BCConfig; updates _status.""" - raise NotImplementedError + try: + from dimos.learning.training.train import train_bc + + cfg_kwargs = self.config.bc.model_dump() + if self.config.config_path: + with open(self.config.config_path) as f: + cfg_kwargs.update(json.load(f)) + cfg = BCConfig(**cfg_kwargs) + + ckpt = train_bc( + dataset_path, cfg, output_dir, + config_overrides=config_overrides, + overwrite=self.config.overwrite, + resume=self.config.resume, + ) + + with self._lock: + self._status.update(state="succeeded", checkpoint_dir=str(ckpt)) + except Exception as e: + with self._lock: + self._status.update( + state="failed", + error=f"{type(e).__name__}: {e}\n{traceback.format_exc()}", + ) + + # ── tensorboard ────────────────────────────────────────────────────────── + + def _maybe_start_tensorboard(self, logdir: Path) -> None: + """Spawn ``tensorboard --logdir `` if enabled and available.""" + if not self.config.tensorboard or self.config.tensorboard_port == 0: + return + if self._tb_proc is not None and self._tb_proc.poll() is None: + return + + tb_bin = shutil.which("tensorboard") + if tb_bin is None: + logger.warning( + "[trainer] tensorboard binary not found on PATH — skipping. " + "Install with: pip install tensorboard" + ) + return + + # Do NOT pre-create logdir — lerobot's cfg.validate() refuses if the + # output dir exists. Tensorboard polls happily on a missing dir. 
+ port = self.config.tensorboard_port + host = self.config.tensorboard_host + try: + self._tb_proc = subprocess.Popen( + [tb_bin, "--logdir", str(logdir), + "--port", str(port), "--host", host, + "--reload_interval", "5"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except Exception as e: + logger.warning("[trainer] failed to launch tensorboard: %s", e) + return + + view_host = "localhost" if host in ("0.0.0.0", "") else host + url = f"http://{view_host}:{port}/" + with self._lock: + self._status["tensorboard_url"] = url + logger.info("[trainer] tensorboard launched pid=%s — view at %s", + self._tb_proc.pid, url) __all__ = ["TrainerModule", "TrainerModuleConfig"] diff --git a/dimos/memory2/codecs/jpeg.py b/dimos/memory2/codecs/jpeg.py index 3d854400b1..80f22696a2 100644 --- a/dimos/memory2/codecs/jpeg.py +++ b/dimos/memory2/codecs/jpeg.py @@ -36,4 +36,9 @@ def encode(self, value: Image) -> bytes: def decode(self, data: bytes) -> Image: from dimos.msgs.sensor_msgs.Image import Image + # Some recordings include a 1-byte format tag before the LCM envelope + # (b'J' for JPEG-encoded Image). Strip it on read so old + new sessions + # both decode cleanly. 
+ if data and data[0:1] == b"J": + data = data[1:] return Image.lcm_jpeg_decode(data) diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 8d544dff70..0123127394 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -55,6 +55,14 @@ "keyboard-teleop-piper": "dimos.robot.manipulators.piper.blueprints:keyboard_teleop_piper", "keyboard-teleop-xarm6": "dimos.robot.manipulators.xarm.blueprints:keyboard_teleop_xarm6", "keyboard-teleop-xarm7": "dimos.robot.manipulators.xarm.blueprints:keyboard_teleop_xarm7", + "learning-collect-quest-dual": "dimos.learning.collection.blueprint:learning_collect_quest_dual", + "learning-collect-quest-piper": "dimos.learning.collection.blueprint:learning_collect_quest_piper", + "learning-collect-quest-xarm6": "dimos.learning.collection.blueprint:learning_collect_quest_xarm6", + "learning-collect-quest-xarm7": "dimos.learning.collection.blueprint:learning_collect_quest_xarm7", + "learning-dataprep": "dimos.learning.dataprep_blueprint:learning_dataprep", + "learning-dataprep-whole-session": "dimos.learning.dataprep_blueprint:learning_dataprep_whole_session", + "learning-infer-chunkpolicy-only": "dimos.learning.inference.blueprint:learning_infer_chunkpolicy_only", + "learning-train": "dimos.learning.training.blueprint:learning_train", "mid360": "dimos.hardware.sensors.lidar.livox.livox_blueprints:mid360", "mid360-fastlio": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio", "mid360-fastlio-voxels": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_voxels", @@ -116,8 +124,10 @@ "b1-connection-module": "dimos.robot.unitree.b1.connection.B1ConnectionModule", "camera-module": "dimos.hardware.sensors.camera.module.CameraModule", "cartesian-motion-controller": "dimos.manipulation.control.servo_control.cartesian_motion_controller.CartesianMotionController", + "chunk-policy-module": "dimos.learning.inference.chunk_policy_module.ChunkPolicyModule", 
"control-coordinator": "dimos.control.coordinator.ControlCoordinator", "cost-mapper": "dimos.mapping.costmapper.CostMapper", + "data-prep-module": "dimos.learning.dataprep_module.DataPrepModule", "demo-calculator-skill": "dimos.agents.skills.demo_calculator_skill.DemoCalculatorSkill", "demo-robot": "dimos.agents.skills.demo_robot.DemoRobot", "detection2-d-module": "dimos.perception.detection.module2D.Detection2DModule", @@ -127,6 +137,7 @@ "drone-tracking-module": "dimos.robot.drone.drone_tracking_module.DroneTrackingModule", "embedding-memory": "dimos.memory.embedding.EmbeddingMemory", "emitter-module": "dimos.utils.demo_image_encoding.EmitterModule", + "episode-monitor-module": "dimos.learning.collection.episode_monitor.EpisodeMonitorModule", "fast-lio2": "dimos.hardware.sensors.lidar.fastlio2.module.FastLio2", "foxglove-bridge": "dimos.robot.foxglove_bridge.FoxgloveBridge", "g1-connection": "dimos.robot.unitree.g1.connection.G1Connection", @@ -182,6 +193,7 @@ "spatial-memory": "dimos.perception.spatial_perception.SpatialMemory", "speak-skill": "dimos.agents.skills.speak_skill.SpeakSkill", "temporal-memory": "dimos.perception.experimental.temporal_memory.temporal_memory.TemporalMemory", + "trainer-module": "dimos.learning.training.trainer_module.TrainerModule", "twist-teleop-module": "dimos.teleop.quest.quest_extensions.TwistTeleopModule", "unitree-g1-skill-container": "dimos.robot.unitree.g1.skill_container.UnitreeG1SkillContainer", "unitree-skill-container": "dimos.robot.unitree.unitree_skill_container.UnitreeSkillContainer", From 353d0b427b8615b56dd3352385a7d5c2059c8164 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Fri, 29 May 2026 15:29:26 -0700 Subject: [PATCH 06/45] feat: xarm7 inference --- dimos/learning/inference/blueprint.py | 48 ++++++++++++++++++++++++--- dimos/robot/all_blueprints.py | 1 + 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/dimos/learning/inference/blueprint.py b/dimos/learning/inference/blueprint.py index 
1b921e4fc9..8b2f237d93 100644 --- a/dimos/learning/inference/blueprint.py +++ b/dimos/learning/inference/blueprint.py @@ -16,21 +16,32 @@ `ChunkPolicyModule` publishes `joint_command` directly so a coordinator's servo / position task can consume it without an `ActionReplayer` task in -the tick loop. Compose with the user's coordinator blueprint at the call -site, e.g.:: +the tick loop. Two variants are provided: - autoconnect(learning_infer_chunkpolicy_only, my_servo_coordinator) +* ``learning_infer_chunkpolicy_only`` — policy + camera only. Compose at + the call site with your own coordinator:: + + autoconnect(learning_infer_chunkpolicy_only, my_servo_coordinator) + +* ``learning_infer_xarm7`` — sample fully wired blueprint: policy + + camera + XArm7 ControlCoordinator running a servo task. The + coordinator publishes ``joint_state`` and consumes ``joint_command`` + on the same LCM topics the policy uses, so ``autoconnect`` closes the + loop with no extra glue. Use as a template for other arms. """ from __future__ import annotations +from dimos.control.coordinator import ControlCoordinator from dimos.core.coordination.blueprints import autoconnect +from dimos.core.global_config import global_config from dimos.core.transport import LCMTransport from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera from dimos.learning.inference.chunk_policy_module import ChunkPolicyModule from dimos.learning.policy.base import ActionChunk from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.robot.catalog.ufactory import xarm7 as _catalog_xarm7 # Stable topics so external tools (lcmspy, dimos topic echo) work without rebuild. _T_COLOR_IMAGE = "/camera/color_image" @@ -55,4 +66,33 @@ ).transports(_INFER_TRANSPORTS) -__all__ = ["learning_infer_chunkpolicy_only"] +# Sample end-to-end inference blueprint: policy → coordinator → hardware. 
+# Mirror the arm config used in collection (xarm7 + gripper) so trained +# joint_names line up with what the servo task claims. +_xarm7_infer_cfg = _catalog_xarm7( + name="arm", + adapter_type="xarm", + address=global_config.xarm7_ip, + add_gripper=True, +) + +learning_infer_xarm7 = autoconnect( + RealSenseCamera.blueprint(enable_pointcloud=False), + ChunkPolicyModule.blueprint( + policy_path="data/runs/act_pickplace_001", + inference_rate_hz=30.0, + publish_joint_command=True, + ), + ControlCoordinator.blueprint( + hardware=[_xarm7_infer_cfg.to_hardware_component()], + tasks=[ + _xarm7_infer_cfg.to_task_config(task_type="servo", task_name="servo_arm"), + ], + ), +).transports(_INFER_TRANSPORTS) + + +__all__ = [ + "learning_infer_chunkpolicy_only", + "learning_infer_xarm7", +] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 0b6d6bd533..0b449c45f7 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -71,6 +71,7 @@ "learning-dataprep": "dimos.learning.dataprep_blueprint:learning_dataprep", "learning-dataprep-whole-session": "dimos.learning.dataprep_blueprint:learning_dataprep_whole_session", "learning-infer-chunkpolicy-only": "dimos.learning.inference.blueprint:learning_infer_chunkpolicy_only", + "learning-infer-xarm7": "dimos.learning.inference.blueprint:learning_infer_xarm7", "learning-train": "dimos.learning.training.blueprint:learning_train", "mid360": "dimos.hardware.sensors.lidar.livox.livox_blueprints:mid360", "mid360-fastlio": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio", From fd0c05a0929b619c88c6c6c6f5c24c6485a2b247 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 4 Jun 2026 11:33:16 -0700 Subject: [PATCH 07/45] remove training and inference codes --- dimos/control/tasks/action_replayer_task.py | 73 ------ dimos/learning/dataprep.py | 6 +- dimos/learning/dataprep_module.py | 4 +- dimos/learning/formats/rlds.py | 203 -------------- 
dimos/learning/inference/blueprint.py | 98 ------- .../learning/inference/chunk_policy_module.py | 236 ----------------- dimos/learning/policy/base.py | 60 ----- dimos/learning/policy/lerobot_policy.py | 248 ------------------ dimos/learning/specs/inference.md | 183 ------------- dimos/learning/specs/spec_v2.md | 20 +- dimos/learning/specs/structure.md | 74 +----- dimos/learning/specs/training.md | 78 ------ dimos/learning/training/blueprint.py | 43 --- dimos/learning/training/configs.py | 54 ---- dimos/learning/training/train.py | 235 ----------------- dimos/learning/training/trainer_module.py | 207 --------------- dimos/robot/all_blueprints.py | 5 - 17 files changed, 26 insertions(+), 1801 deletions(-) delete mode 100644 dimos/control/tasks/action_replayer_task.py delete mode 100644 dimos/learning/formats/rlds.py delete mode 100644 dimos/learning/inference/blueprint.py delete mode 100644 dimos/learning/inference/chunk_policy_module.py delete mode 100644 dimos/learning/policy/base.py delete mode 100644 dimos/learning/policy/lerobot_policy.py delete mode 100644 dimos/learning/specs/inference.md delete mode 100644 dimos/learning/specs/training.md delete mode 100644 dimos/learning/training/blueprint.py delete mode 100644 dimos/learning/training/configs.py delete mode 100644 dimos/learning/training/train.py delete mode 100644 dimos/learning/training/trainer_module.py diff --git a/dimos/control/tasks/action_replayer_task.py b/dimos/control/tasks/action_replayer_task.py deleted file mode 100644 index a56c5fb267..0000000000 --- a/dimos/control/tasks/action_replayer_task.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Replay policy ActionChunks at coordinator tick rate (100 Hz). - -Slow producer (ChunkPolicyModule @ ~30 Hz, jittery) → buffer → -deterministic 100 Hz JointCommandOutput. Stale-chunk and policy-stall -handling keep hardware safe when the policy falls behind or dies. -""" - -from __future__ import annotations - -from dataclasses import dataclass - -from dimos.control.task import ( - BaseControlTask, - CoordinatorState, - JointCommandOutput, - ResourceClaim, -) -from dimos.learning.policy.base import ActionChunk - - -@dataclass -class ActionReplayerConfig: - joint_names: list[str] - priority: int = 10 - max_chunk_age_s: float = 0.5 # drop chunks older than this at receive time - hold_on_stall: bool = True # hold last position if buffer drains - temporal_ensemble: bool = False # ACT trick; off in v1 - - -class ActionReplayer(BaseControlTask): - """Buffer latest chunk; interpolate per tick; emit JointCommandOutput.""" - - def __init__(self, name: str, config: ActionReplayerConfig) -> None: - raise NotImplementedError - - @property - def name(self) -> str: - raise NotImplementedError - - def claim(self) -> ResourceClaim: - """Claim `config.joint_names` at `config.priority`.""" - raise NotImplementedError - - def is_active(self) -> bool: - """True iff buffer has a non-stale target for now (or hold_on_stall).""" - raise NotImplementedError - - def compute(self, state: CoordinatorState) -> JointCommandOutput | None: - """Pure lookup / interpolate over the buffer at `state.now`. 
- Must complete in << 10 ms.""" - raise NotImplementedError - - def on_action_chunk(self, msg: ActionChunk) -> None: - """Latest-wins push. Drop if msg too old; drop buffered entries - at/after msg's first target_ts; append new (target_ts, positions).""" - raise NotImplementedError - - def _interpolate(self, t: float) -> JointCommandOutput | None: - raise NotImplementedError diff --git a/dimos/learning/dataprep.py b/dimos/learning/dataprep.py index 246eafaaaf..cc15665334 100644 --- a/dimos/learning/dataprep.py +++ b/dimos/learning/dataprep.py @@ -65,7 +65,7 @@ class SyncConfig(BaseConfig): class OutputConfig(BaseConfig): - format: Literal["lerobot", "hdf5", "rlds"] = "lerobot" + format: Literal["lerobot", "hdf5"] = "lerobot" path: Path metadata: dict[str, Any] = {} @@ -98,7 +98,7 @@ class Sample(BaseModel): # ───────────────────────────────────────────────────────────────────────────── -# Pure helpers — used by ChunkPolicyModule, format writers, DataPrepModule +# Pure helpers — used by format writers, DataPrepModule # ───────────────────────────────────────────────────────────────────────────── @@ -348,8 +348,6 @@ def get_writer(format_name: str) -> Writer: from dimos.learning.formats.lerobot import write elif format_name == "hdf5": from dimos.learning.formats.hdf5 import write - elif format_name == "rlds": - from dimos.learning.formats.rlds import write else: raise ValueError(f"Unknown format: {format_name!r}") return write diff --git a/dimos/learning/dataprep_module.py b/dimos/learning/dataprep_module.py index d6e80ba4e3..4d7caeb5fa 100644 --- a/dimos/learning/dataprep_module.py +++ b/dimos/learning/dataprep_module.py @@ -246,8 +246,8 @@ def _all_samples() -> Iterator[Sample]: logger.error("[dataprep] FAILED: %s", err) def _write_dimos_meta(self, dataset_path: Path, episodes: list[Any]) -> None: - """Sidecar describing how this dataset was built. 
ChunkPolicyModule - reads it at inference time to recover the obs/action schema.""" + """Sidecar describing how this dataset was built, recording the + obs/action schema alongside the dataset.""" meta = { "source": self.config.source, "observation": {k: v.model_dump() for k, v in self.config.observation.items()}, diff --git a/dimos/learning/formats/rlds.py b/dimos/learning/formats/rlds.py deleted file mode 100644 index 8003b7c938..0000000000 --- a/dimos/learning/formats/rlds.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""RLDS / TFDS dataset writer. - -Emits a single TFRecord shard following the RLDS Episode/Step protocol. -Each example is a ``tf.train.SequenceExample`` whose feature_lists hold -per-step observation / action / reward / discount / is_first / is_last / -is_terminal arrays, plus a context for episode metadata. - -Layout:: - - / - features.json schema (feature names + shapes + dtypes) - dataset_info.json episode count, step count, fps, robot, stats - rlds-00000-of-00001.tfrecord - -This is RLDS-shaped on the wire; loading it as a `tfds.builder` requires -matching the schema in a TFDS DatasetBuilder. See the RLDS docs for that -glue. For OpenX-Embodiment contributions, point your TFDS builder at -this directory. 
-""" - -from __future__ import annotations - -import json -from collections.abc import Iterator -from pathlib import Path -from typing import Any - -import numpy as np - -from dimos.learning.dataprep import OutputConfig, Sample -from dimos.learning.formats._stats import StreamingStats - - -def write(samples: Iterator[Sample], output: OutputConfig) -> Path: - """Drain `samples` into RLDS-style TFRecord shards. Returns the dataset dir.""" - try: - import tensorflow as tf # type: ignore[import-not-found] - except ImportError as e: - raise RuntimeError( - "RLDS writer requires tensorflow — install with " - "`pip install tensorflow_datasets` (pulls in tf)" - ) from e - - root = Path(output.path) - root.mkdir(parents=True, exist_ok=True) - - stats = StreamingStats( - image_subsample=int(output.metadata.get("image_subsample", 10)), - quantile_reservoir=int(output.metadata.get("quantile_reservoir", 10_000)), - seed=int(output.metadata.get("stats_seed", 0)), - ) - - default_task_label = output.metadata.get("default_task_label", "task") - fps = float(output.metadata.get("fps", 30.0)) - - feature_shapes: dict[str, tuple[int, ...]] = {} - feature_dtypes: dict[str, str] = {} - tasks_index: dict[str, int] = {} - - # Per-episode buffers. 
- cur_id: str | None = None - cur_idx = -1 - cur_start_ts: float | None = None - buf_ts: list[float] = [] - buf_obs: dict[str, list[np.ndarray]] = {} - buf_act: dict[str, list[np.ndarray]] = {} - - total_frames = 0 - episodes_meta: list[dict[str, Any]] = [] - shard_path = root / "rlds-00000-of-00001.tfrecord" - - def _bytes(arr: np.ndarray) -> "tf.train.Feature": - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[arr.tobytes()])) - - def _flush(writer: "tf.io.TFRecordWriter") -> None: - nonlocal cur_idx, cur_start_ts - if cur_idx < 0 or not buf_ts: - return - - T = len(buf_ts) - feature_lists: dict[str, "tf.train.FeatureList"] = {} - - def _make_list(arrs: list[np.ndarray]) -> "tf.train.FeatureList": - return tf.train.FeatureList(feature=[_bytes(a) for a in arrs]) - - for k, frames in buf_obs.items(): - feature_lists[f"observation/{k}"] = _make_list(frames) - for k, frames in buf_act.items(): - feature_lists[f"action/{k}"] = _make_list(frames) - - ts_arr = np.asarray(buf_ts, dtype=np.float32) - feature_lists["timestamp"] = tf.train.FeatureList( - feature=[tf.train.Feature(float_list=tf.train.FloatList(value=[t])) for t in ts_arr] - ) - # RLDS step booleans. - is_first = [i == 0 for i in range(T)] - is_last = [i == T - 1 for i in range(T)] - for name, vals in (("is_first", is_first), ("is_last", is_last), ("is_terminal", is_last)): - feature_lists[name] = tf.train.FeatureList(feature=[ - tf.train.Feature(int64_list=tf.train.Int64List(value=[int(v)])) for v in vals - ]) - # Default reward / discount per RLDS convention. 
- feature_lists["reward"] = tf.train.FeatureList(feature=[ - tf.train.Feature(float_list=tf.train.FloatList(value=[0.0])) for _ in range(T) - ]) - feature_lists["discount"] = tf.train.FeatureList(feature=[ - tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])) for _ in range(T) - ]) - - ctx = tf.train.Features(feature={ - "episode_index": tf.train.Feature(int64_list=tf.train.Int64List(value=[cur_idx])), - "length": tf.train.Feature(int64_list=tf.train.Int64List(value=[T])), - "start_ts": tf.train.Feature(float_list=tf.train.FloatList(value=[float(cur_start_ts or 0.0)])), - "task": tf.train.Feature(bytes_list=tf.train.BytesList(value=[default_task_label.encode()])), - }) - ex = tf.train.SequenceExample( - context=ctx, - feature_lists=tf.train.FeatureLists(feature_list=feature_lists), - ) - writer.write(ex.SerializeToString()) - - episodes_meta.append({ - "episode_index": cur_idx, - "length": T, - "start_ts": float(cur_start_ts or 0.0), - "task": default_task_label, - }) - buf_ts.clear() - buf_obs.clear() - buf_act.clear() - - with tf.io.TFRecordWriter(str(shard_path)) as writer: - for sample in samples: - if sample.episode_id != cur_id: - _flush(writer) - cur_id = sample.episode_id - cur_idx += 1 - cur_start_ts = float(sample.ts) - if default_task_label not in tasks_index: - tasks_index[default_task_label] = len(tasks_index) - - buf_ts.append(float(sample.ts) - (cur_start_ts or 0.0)) - for k, v in sample.observation.items(): - a = np.asarray(v) - buf_obs.setdefault(k, []).append(a) - stats.update(f"observation.{k}", a) - if k not in feature_shapes: - feature_shapes[f"observation/{k}"] = tuple(a.shape) - feature_dtypes[f"observation/{k}"] = str(a.dtype) - for k, v in sample.action.items(): - a = np.asarray(v) - buf_act.setdefault(k, []).append(a) - stats.update(f"action.{k}", a) - if k not in feature_shapes: - feature_shapes[f"action/{k}"] = tuple(a.shape) - feature_dtypes[f"action/{k}"] = str(a.dtype) - total_frames += 1 - - _flush(writer) - - # ── sidecar 
metadata ───────────────────────────────────────────────────── - features_meta = { - name: {"shape": list(shape), "dtype": feature_dtypes[name]} - for name, shape in feature_shapes.items() - } - features_meta["timestamp"] = {"shape": [], "dtype": "float32"} - features_meta["is_first"] = {"shape": [], "dtype": "int64"} - features_meta["is_last"] = {"shape": [], "dtype": "int64"} - features_meta["is_terminal"] = {"shape": [], "dtype": "int64"} - features_meta["reward"] = {"shape": [], "dtype": "float32"} - features_meta["discount"] = {"shape": [], "dtype": "float32"} - - info = { - "format_version": "rlds-1.0", - "robot": output.metadata.get("robot", "unknown"), - "fps": fps, - "num_episodes": len(episodes_meta), - "num_steps": total_frames, - "num_tasks": len(tasks_index), - "tasks": {idx: task for task, idx in tasks_index.items()}, - "episodes": episodes_meta, - "stats": stats.finalize(), - } - with open(root / "features.json", "w") as f: - json.dump(features_meta, f, indent=2) - with open(root / "dataset_info.json", "w") as f: - json.dump(info, f, indent=2) - - return root diff --git a/dimos/learning/inference/blueprint.py b/dimos/learning/inference/blueprint.py deleted file mode 100644 index 8b2f237d93..0000000000 --- a/dimos/learning/inference/blueprint.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ACT inference blueprints. 
- -`ChunkPolicyModule` publishes `joint_command` directly so a coordinator's -servo / position task can consume it without an `ActionReplayer` task in -the tick loop. Two variants are provided: - -* ``learning_infer_chunkpolicy_only`` — policy + camera only. Compose at - the call site with your own coordinator:: - - autoconnect(learning_infer_chunkpolicy_only, my_servo_coordinator) - -* ``learning_infer_xarm7`` — sample fully wired blueprint: policy + - camera + XArm7 ControlCoordinator running a servo task. The - coordinator publishes ``joint_state`` and consumes ``joint_command`` - on the same LCM topics the policy uses, so ``autoconnect`` closes the - loop with no extra glue. Use as a template for other arms. -""" - -from __future__ import annotations - -from dimos.control.coordinator import ControlCoordinator -from dimos.core.coordination.blueprints import autoconnect -from dimos.core.global_config import global_config -from dimos.core.transport import LCMTransport -from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera -from dimos.learning.inference.chunk_policy_module import ChunkPolicyModule -from dimos.learning.policy.base import ActionChunk -from dimos.msgs.sensor_msgs.Image import Image -from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.robot.catalog.ufactory import xarm7 as _catalog_xarm7 - -# Stable topics so external tools (lcmspy, dimos topic echo) work without rebuild. 
-_T_COLOR_IMAGE = "/camera/color_image" -_T_JOINT_STATE = "/coordinator/joint_state" -_T_ACTION_CHUNK = "/learning/action_chunk" -_T_JOINT_COMMAND = "/teleop/joint_command" # matches coordinator_servo_* default - -_INFER_TRANSPORTS = { - ("color_image", Image): LCMTransport(_T_COLOR_IMAGE, Image), - ("joint_state", JointState): LCMTransport(_T_JOINT_STATE, JointState), - ("action_chunk", ActionChunk): LCMTransport(_T_ACTION_CHUNK, ActionChunk), - ("joint_command", JointState): LCMTransport(_T_JOINT_COMMAND, JointState), -} - - -learning_infer_chunkpolicy_only = autoconnect( - RealSenseCamera.blueprint(enable_pointcloud=False), - ChunkPolicyModule.blueprint( - policy_path="data/runs/act_pickplace_001", - inference_rate_hz=30.0, - ), -).transports(_INFER_TRANSPORTS) - - -# Sample end-to-end inference blueprint: policy → coordinator → hardware. -# Mirror the arm config used in collection (xarm7 + gripper) so trained -# joint_names line up with what the servo task claims. -_xarm7_infer_cfg = _catalog_xarm7( - name="arm", - adapter_type="xarm", - address=global_config.xarm7_ip, - add_gripper=True, -) - -learning_infer_xarm7 = autoconnect( - RealSenseCamera.blueprint(enable_pointcloud=False), - ChunkPolicyModule.blueprint( - policy_path="data/runs/act_pickplace_001", - inference_rate_hz=30.0, - publish_joint_command=True, - ), - ControlCoordinator.blueprint( - hardware=[_xarm7_infer_cfg.to_hardware_component()], - tasks=[ - _xarm7_infer_cfg.to_task_config(task_type="servo", task_name="servo_arm"), - ], - ), -).transports(_INFER_TRANSPORTS) - - -__all__ = [ - "learning_infer_chunkpolicy_only", - "learning_infer_xarm7", -] diff --git a/dimos/learning/inference/chunk_policy_module.py b/dimos/learning/inference/chunk_policy_module.py deleted file mode 100644 index f27e736345..0000000000 --- a/dimos/learning/inference/chunk_policy_module.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2026 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ACT inference Module — runs the policy in a background thread. - -Reads ``/dimos_meta.json`` at start() to recover the obs -schema (StreamField map + sync). Latches In ports; calls -``policy.predict_chunk`` on the freshest snapshot every tick; publishes -each result as an ActionChunk over LCM. - -When ``publish_joint_command=True`` (default), also emits the first -action of each chunk as a JointState on ``joint_command``. This lets a -coordinator's servo task consume the policy directly without an -``ActionReplayer`` in the 100 Hz tick loop. Once ActionReplayer is wired -in, callers can ignore that port. - -Heavy ML deps (``lerobot``, ``torch``) are imported lazily inside -``start()`` so this file is import-light. 
-""" - -from __future__ import annotations - -import json -import threading -import time -import traceback -from pathlib import Path -from typing import Any - -import numpy as np - -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import In, Out -from dimos.learning.dataprep import StreamField, SyncConfig, resolve_field -from dimos.learning.policy.base import ActionChunk, Policy -from dimos.msgs.sensor_msgs.Image import Image -from dimos.msgs.sensor_msgs.JointState import JointState - -DIMOS_META_FILENAME = "dimos_meta.json" - - -class ChunkPolicyModuleConfig(ModuleConfig): - policy_path: str - inference_rate_hz: float = 30.0 - device: str = "cuda" - # Emit chunk[0] as a JointState on `joint_command` each tick. Useful when - # the downstream coordinator has a servo task but no ActionReplayer task. - publish_joint_command: bool = True - - -class ChunkPolicyModule(Module): - config: ChunkPolicyModuleConfig - - color_image: In[Image] - joint_state: In[JointState] - action_chunk: Out[ActionChunk] - joint_command: Out[JointState] - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._latest_image: Image | None = None - self._latest_joint_state: JointState | None = None - self._latch_lock = threading.Lock() - - # Filled in start(): - self._policy: Policy | None = None - self._observation: dict[str, StreamField] = {} - self._sync: SyncConfig | None = None - self._chunk_id: int = 0 - self._last_chunk_ts: float | None = None - - self._thread: threading.Thread | None = None - self._stop = threading.Event() - - @rpc - def start(self) -> None: - """Load checkpoint, read dimos_meta, subscribe ports, spawn loop thread.""" - super().start() - - meta_path = Path(self.config.policy_path) / DIMOS_META_FILENAME - if not meta_path.exists(): - raise FileNotFoundError(f"Missing {DIMOS_META_FILENAME} in {self.config.policy_path}") - with open(meta_path) as f: - meta = json.load(f) - - 
self._observation = {k: StreamField(**v) for k, v in (meta.get("observation") or {}).items()} - sync_cfg = meta.get("sync") or {} - if sync_cfg: - self._sync = SyncConfig(**sync_cfg) - - from dimos.learning.policy.lerobot_policy import LeRobotPolicy - self._policy = LeRobotPolicy.load(self.config.policy_path, device=self.config.device) - - self.color_image.subscribe(self._on_color_image) - self.joint_state.subscribe(self._on_joint_state) - - self._stop.clear() - self._thread = threading.Thread(target=self._run_loop, daemon=True) - self._thread.start() - - @rpc - def stop(self) -> None: - self._stop.set() - t = self._thread - if t is not None and t.is_alive(): - t.join(timeout=2.0) - self._thread = None - super().stop() - - @rpc - def reload_policy(self, policy_path: str, device: str | None = None) -> None: - """Hot-swap the checkpoint without restarting the blueprint.""" - self.stop() - self.config.policy_path = policy_path - if device is not None: - self.config.device = device - self.start() - - @rpc - def get_status(self) -> dict[str, Any]: - return { - "running": self._thread is not None and self._thread.is_alive(), - "chunk_count": self._chunk_id, - "policy_path": self.config.policy_path, - "last_chunk_ts": self._last_chunk_ts, - } - - # ── port handlers ──────────────────────────────────────────────────────── - - def _on_color_image(self, msg: Image) -> None: - with self._latch_lock: - self._latest_image = msg - - def _on_joint_state(self, msg: JointState) -> None: - with self._latch_lock: - self._latest_joint_state = msg - - # ── loop ───────────────────────────────────────────────────────────────── - - def _run_loop(self) -> None: - assert self._policy is not None - period = 1.0 / self.config.inference_rate_hz - while not self._stop.is_set(): - t0 = time.monotonic() - try: - obs = self._build_live_obs() - if obs is None: - self._stop.wait(timeout=period) - continue - - positions = self._policy.predict_chunk(obs) # (T, action_dim) - chunk_ts = time.time() - 
self.action_chunk.publish(ActionChunk( - ts=chunk_ts, - joint_names=self._policy.joint_names, - positions=positions, - dt=period, - chunk_id=self._next_chunk_id(), - )) - self._last_chunk_ts = chunk_ts - - if self.config.publish_joint_command: - js = JointState( - name=self._policy.joint_names, - position=[float(x) for x in positions[0]], - velocity=[], - ) - js.ts = chunk_ts - self.joint_command.publish(js) - except Exception: - # Single bad tick must not kill the loop. - traceback.print_exc() - - elapsed = time.monotonic() - t0 - if elapsed < period: - self._stop.wait(timeout=period - elapsed) - - def _build_live_obs(self) -> dict[str, np.ndarray] | None: - """Snapshot latched messages and project each obs key through `resolve_field`. - Returns None if any required stream hasn't received a message yet. - """ - with self._latch_lock: - latest_image = self._latest_image - latest_joints = self._latest_joint_state - - if not self._observation: - # No spec — fall back to canonical port names. - if latest_image is None or latest_joints is None: - return None - return { - "image": np.asarray(latest_image.data), - "joint_state": np.asarray(latest_joints.position), - } - - out: dict[str, np.ndarray] = {} - for obs_key, sf in self._observation.items(): - port = self._guess_port(sf.stream) - if port == "color_image": - if latest_image is None: - return None - out[obs_key] = resolve_field(latest_image, sf) - elif port == "joint_state": - if latest_joints is None: - return None - out[obs_key] = resolve_field(latest_joints, sf) - else: - # Extend here when adding In ports for new sensor types. 
- return None - return out - - @staticmethod - def _guess_port(stream_name: str) -> str: - """Route a recorded stream name to one of this module's In ports.""" - n = stream_name.lower() - if "image" in n or "camera" in n or "rgb" in n: - return "color_image" - if "joint_state" in n: - return "joint_state" - return n - - def _next_chunk_id(self) -> int: - cid = self._chunk_id - self._chunk_id += 1 - return cid diff --git a/dimos/learning/policy/base.py b/dimos/learning/policy/base.py deleted file mode 100644 index dc31a4fa10..0000000000 --- a/dimos/learning/policy/base.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ActionChunk message + Policy backend Protocol. - -`ChunkPolicyModule` produces ActionChunks; `ActionReplayer` consumes them. -Any policy backend (lerobot in v1) just needs to satisfy `Policy`. -""" - -from __future__ import annotations - -from pathlib import Path -from typing import Protocol, runtime_checkable - -import numpy as np -from pydantic import BaseModel, ConfigDict - - -class ActionChunk(BaseModel): - """T future joint targets + the metadata to replay them. - - positions: shape (T, N), N = len(joint_names). - Replayer uses ts + i*dt as the target time for positions[i]. 
- """ - - model_config = ConfigDict(arbitrary_types_allowed=True) - - ts: float - joint_names: list[str] - positions: np.ndarray # (T, N) - dt: float - chunk_id: int - - -@runtime_checkable -class Policy(Protocol): - """What ChunkPolicyModule needs from any policy implementation.""" - - @classmethod - def load(cls, path: str | Path, device: str = "cuda") -> Policy: ... - - @property - def chunk_size(self) -> int: ... - @property - def joint_names(self) -> list[str]: ... - - def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: - """Return shape (chunk_size, action_dim), unnormalized to joint space.""" - ... diff --git a/dimos/learning/policy/lerobot_policy.py b/dimos/learning/policy/lerobot_policy.py deleted file mode 100644 index a14675d7c6..0000000000 --- a/dimos/learning/policy/lerobot_policy.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""LeRobot ACT policy wrapper. 
Lazy-imports lerobot/torch in load().""" - -from __future__ import annotations - -import json -from pathlib import Path -from typing import Any - -import numpy as np - -from dimos.learning.policy.base import Policy - -DIMOS_META_FILENAME = "dimos_meta.json" - - -class LeRobotPolicy: - _model: Any # lerobot ACTPolicy / PreTrainedPolicy - _stats: dict[str, Any] - _dimos_meta: dict[str, Any] - _chunk_size: int - _joint_names: list[str] - _device: str - - def __init__( - self, - model: Any, - stats: dict[str, Any], - dimos_meta: dict[str, Any], - chunk_size: int, - joint_names: list[str], - device: str, - ) -> None: - self._model = model - self._stats = stats - self._dimos_meta = dimos_meta - self._chunk_size = chunk_size - self._joint_names = joint_names - self._device = device - - @classmethod - def load(cls, path: str | Path, device: str = "cuda") -> LeRobotPolicy: - """Load checkpoint + dataset stats + dimos_meta sidecar. - - ``path`` may be a TrainerModule output dir (we walk into - ``checkpoints/last/pretrained_model/``) or an exact - ``pretrained_model`` dir. - """ - path = Path(path) - pretrained_dir, run_dir = _resolve_checkpoint_dirs(path) - - meta_path = run_dir / DIMOS_META_FILENAME - if not meta_path.exists(): - raise FileNotFoundError(f"Missing {DIMOS_META_FILENAME} in {run_dir}") - with open(meta_path) as f: - dimos_meta = json.load(f) - - stats_path = _find_stats(run_dir, dimos_meta) - with open(stats_path) as f: - stats = json.load(f) - - # Lazy import — keeps torch/CUDA out of the dimos runtime at module load. 
- try: - from lerobot.policies.act.modeling_act import ACTPolicy - except ImportError: - try: - from lerobot.common.policies.act.modeling_act import ACTPolicy - except ImportError as e: - raise RuntimeError( - "lerobot is required to load a checkpoint; install with " - "`pip install lerobot` (>=0.3)" - ) from e - - model = ACTPolicy.from_pretrained(str(pretrained_dir)) - model.eval() - model.to(device) - - chunk_size = int(dimos_meta.get("chunk_size", 50)) - joint_names = dimos_meta.get("joint_names") or _infer_joint_names(model) - - return cls( - model=model, - stats=stats, - dimos_meta=dimos_meta, - chunk_size=chunk_size, - joint_names=joint_names, - device=device, - ) - - @property - def chunk_size(self) -> int: - return self._chunk_size - - @property - def joint_names(self) -> list[str]: - return self._joint_names - - def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: - """Forward pass on ``obs``. Returns shape ``(chunk_size, action_dim)``. - - ``obs`` keys are the dataset spec's observation keys (e.g. "image", - "joint_state"). They are translated to lerobot's canonical names - (``observation.images.*`` / ``observation.state``) using the - dimos_meta sidecar so train and infer agree. - """ - import torch # lazy - - batch = self._build_batch(obs) - with torch.inference_mode(): - chunk = self._forward_chunk(batch) - if chunk.ndim == 3: - chunk = chunk[0] # (B, T, A) → (T, A) - return chunk.detach().cpu().numpy() - - # ── internals ──────────────────────────────────────────────────────────── - - def _build_batch(self, obs: dict[str, np.ndarray]) -> dict[str, Any]: - import torch - - batch: dict[str, Any] = {} - observation_map = self._dimos_meta.get("observation", {}) - for user_key, value in obs.items(): - arr = np.asarray(value) - if arr.ndim >= 3: - # HWC uint8 → 1xCxHxW float32 / 255 (lerobot's expected layout). 
- chw = np.transpose(arr, (2, 0, 1)) if arr.shape[-1] in (1, 3, 4) else arr - t = torch.from_numpy(chw.astype(np.float32) / 255.0).unsqueeze(0) - feat = f"observation.images.{user_key}" - else: - t = torch.from_numpy(arr.astype(np.float32)).unsqueeze(0) - # Single low-dim observation is canonical "observation.state". - low_dim_other = any( - k != user_key and k in obs and np.asarray(obs[k]).ndim < 3 - for k in observation_map - ) - feat = f"observation.{user_key}" if low_dim_other else "observation.state" - batch[feat] = t.to(self._device) - return batch - - def _forward_chunk(self, batch: dict[str, Any]) -> Any: - """Prefer ``predict_action_chunk`` (newer API); fall back to repeated - ``select_action`` after ``reset()`` to assemble a chunk.""" - if hasattr(self._model, "predict_action_chunk"): - return self._model.predict_action_chunk(batch) - if hasattr(self._model, "select_action"): - import torch - - self._model.reset() - actions = [self._model.select_action(batch) for _ in range(self._chunk_size)] - return torch.stack(actions, dim=1) # (B, T, A) - raise RuntimeError("lerobot policy has neither predict_action_chunk nor select_action") - - -def _resolve_checkpoint_dirs(path: Path) -> tuple[Path, Path]: - """Return ``(pretrained_model_dir, run_dir)`` for any supported input path. - - Run dir layout (lerobot 0.3+):: - - / - dimos_meta.json - checkpoints//pretrained_model/ # lerobot safetensors - checkpoints/last -> symlink to latest - """ - if (path / "model.safetensors").exists(): - # `…/checkpoints//pretrained_model` → run_dir is 3 parents up. 
- return path, path.parent.parent.parent - - last = path / "checkpoints" / "last" / "pretrained_model" - if last.exists(): - return last, path - - ckpts = path / "checkpoints" - if ckpts.is_dir(): - numeric = sorted( - (p for p in ckpts.iterdir() if p.is_dir() and p.name.isdigit()), - key=lambda p: int(p.name), - ) - if numeric and (numeric[-1] / "pretrained_model").exists(): - return numeric[-1] / "pretrained_model", path - - raise FileNotFoundError( - f"No lerobot checkpoint found under {path}. " - f"Expected {path}/checkpoints/last/pretrained_model/ " - f"or a numeric checkpoint subdir." - ) - - -def _find_stats(run_dir: Path, dimos_meta: dict[str, Any]) -> Path: - """Locate ``stats.json`` near a checkpoint. - - Lookup order: - 1. ``/meta/stats.json`` - 2. dimos_meta's recorded ``dataset_path`` / ``source`` - 3. ``/../datasets//meta/stats.json`` (sibling convention) - """ - candidates: list[Path] = [ - run_dir / "meta" / "stats.json", - run_dir / "stats.json", - ] - metadata = dimos_meta.get("metadata") or {} - for key in ("dataset_path", "source"): - v = metadata.get(key) or dimos_meta.get(key) - if v and Path(v).suffix not in (".db", ".sqlite"): - candidates.append(Path(v) / "meta" / "stats.json") - - for parent in (run_dir.parent, run_dir.parent.parent): - if parent and (parent / "datasets").is_dir(): - for d in (parent / "datasets").iterdir(): - if (d / "meta" / "stats.json").is_file(): - candidates.append(d / "meta" / "stats.json") - - for c in candidates: - if c.exists(): - return c - raise FileNotFoundError(f"stats.json not found near {run_dir}; tried: {candidates}") - - -def _infer_joint_names(model: Any) -> list[str]: - """Synthetic joint-name fallback when dimos_meta didn't record any.""" - cfg = getattr(model, "config", None) - action_dim: int | None = None - if cfg is not None: - out_shapes = getattr(cfg, "output_shapes", None) or {} - if "action" in out_shapes: - action_dim = out_shapes["action"][-1] - if action_dim is None: - af = getattr(cfg, 
"action_feature", None) - action_dim = getattr(af, "shape", [None])[-1] if af is not None else None - if action_dim is None: - action_dim = 7 - return [f"joint{i}" for i in range(action_dim)] - - -# Protocol conformance assertion at import time. -_: type[Policy] = LeRobotPolicy diff --git a/dimos/learning/specs/inference.md b/dimos/learning/specs/inference.md deleted file mode 100644 index 88e60569f1..0000000000 --- a/dimos/learning/specs/inference.md +++ /dev/null @@ -1,183 +0,0 @@ -# Stage 3 — Inference - -ACT only. Two pieces: - -- **`ChunkPolicyModule`** (`learning/inference/`) — Module @ ~30 Hz. - Builds obs, calls `policy.predict_chunk(obs)`, emits `ActionChunk`. -- **`ActionReplayer`** (`control/tasks/`) — `BaseControlTask` in the - 100 Hz `ControlCoordinator` tick loop. Buffers chunks, interpolates - to `state.now`, emits `JointCommandOutput`. - -``` -ChunkPolicyModule (~15-30 Hz) - │ ActionChunk (LCM) - ▼ -ControlCoordinator @ 100 Hz - └─ ActionReplayer.compute(state) → JointCommandOutput → hardware -``` - ---- - -## Blueprint - -```python -# dimos/learning/inference/blueprint.py -from dimos.learning.policy.base import ActionChunk - -learning_infer_xarm7 = autoconnect( - RealSenseCamera.blueprint(enable_pointcloud=False), - ChunkPolicyModule.blueprint( - policy_path="data/runs/act_pick_red", - inference_rate_hz=30.0, - ), - coordinator_action_replayer_xarm7, # registers ActionReplayer with the coordinator -).transports({ - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("action_chunk", ActionChunk): LCMTransport("/learning/action_chunk", ActionChunk), -}) -``` - -## Message types - -```python -# dimos/learning/policy/base.py - -class ActionChunk(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - ts: float - joint_names: list[str] - positions: np.ndarray # (T, N) - dt: float - chunk_id: int - - -@runtime_checkable -class 
Policy(Protocol): - @classmethod - def load(cls, path: str | Path, device: str = "cuda") -> Policy: ... - @property - def chunk_size(self) -> int: ... - @property - def joint_names(self) -> list[str]: ... - def predict_chunk(self, obs: dict[str, np.ndarray]) -> np.ndarray: ... -``` - -`LeRobotPolicy` (`policy/lerobot_policy.py`) is the v1 implementation. - -## ChunkPolicyModule - -```python -# dimos/learning/inference/chunk_policy_module.py -from dimos.learning.dataprep import StreamField, SyncConfig, resolve_field - - -class ChunkPolicyModuleConfig(ModuleConfig): - policy_path: str - inference_rate_hz: float = 30.0 - device: str = "cuda" - - -class ChunkPolicyModule(Module): - config: ChunkPolicyModuleConfig - - color_image: In[Image] - joint_state: In[JointState] - action_chunk: Out[ActionChunk] - - @rpc - def reload_policy(self, policy_path: str, device: str | None = None) -> None: ... - @rpc - def get_status(self) -> dict[str, Any]: ... - - def _run_loop(self) -> None: - period = 1.0 / self.config.inference_rate_hz - while not self._stop.is_set(): - t0 = time.monotonic() - obs = self._build_live_obs() - if obs is None: - time.sleep(period); continue - - positions = self.policy.predict_chunk(obs) - self.action_chunk.publish(ActionChunk( - ts=time.time(), - joint_names=self.policy.joint_names, - positions=positions, - dt=period, - chunk_id=self._next_chunk_id(), - )) - time.sleep(max(0.0, period - (time.monotonic() - t0))) - - def _build_live_obs(self) -> dict[str, np.ndarray] | None: - # snapshot latched In[Image] / In[JointState], project via - # resolve_field using self._observation (StreamField map - # loaded from /dimos_meta.json at start()). - ... -``` - -`start()` reads `/dimos_meta.json`, reconstructs -`observation: dict[str, StreamField]` and `sync: SyncConfig`, and stores -them as instance state. `_build_live_obs` calls `resolve_field` -on each entry — same projection as training, no train/serve skew. 
- -## ActionReplayer - -```python -# dimos/control/tasks/action_replayer_task.py - -@dataclass -class ActionReplayerConfig: - joint_names: list[str] - priority: int = 10 - max_chunk_age_s: float = 0.5 - hold_on_stall: bool = True - temporal_ensemble: bool = False - - -class ActionReplayer(BaseControlTask): - def __init__(self, name: str, config: ActionReplayerConfig) -> None: ... - - @property - def name(self) -> str: ... - def claim(self) -> ResourceClaim: ... - def is_active(self) -> bool: ... - def compute(self, state: CoordinatorState) -> JointCommandOutput | None: ... - def on_action_chunk(self, msg: ActionChunk) -> None: ... -``` - -## ControlCoordinator wiring - -`ControlCoordinator` gains a new port + dispatcher (mirrors how -`cartesian_command` / `twist_command` are routed): - -```python -class ControlCoordinator(Module): - # ... existing ports ... - action_chunk: In[ActionChunk] - - def _on_action_chunk(self, msg: ActionChunk) -> None: - for task in self._tasks: - if isinstance(task, ActionReplayer): - task.on_action_chunk(msg) -``` - ---- - -## Run - -```bash -dimos run learning-infer-xarm7 -``` - -## End-to-end - -```bash -dimos run learning-collect-quest-xarm7 --record-path data/sessions/pick_red.db -dimos run learning-dataprep -dimos run learning-train -dimos run learning-infer-xarm7 -``` - -``` -data/sessions/pick_red.db ─► data/datasets/pick_red/ ─► data/runs/act_pick_red/ ─► live policy -``` diff --git a/dimos/learning/specs/spec_v2.md b/dimos/learning/specs/spec_v2.md index cd4518aefe..94e5c85d71 100644 --- a/dimos/learning/specs/spec_v2.md +++ b/dimos/learning/specs/spec_v2.md @@ -6,8 +6,8 @@ Two phases: LCM streams. `RecordReplay` (CLI flag) captures every active topic into `session.db` (memory2 `SqliteStore` format). 2. **DataPrep** — offline; `DataPrepModule` reads `session.db` and writes - a training-ready dataset on disk in one of three formats (LeRobot v2, - HDF5, RLDS). 
+ a training-ready dataset on disk in one of two formats (LeRobot v2, + HDF5). Same code paths drive both phases as DimOS blueprints. Format dispatch is config-only; the same module backs every output format. @@ -182,7 +182,7 @@ class SyncConfig(BaseConfig): class OutputConfig(BaseConfig): - format: Literal["lerobot", "hdf5", "rlds"] = "lerobot" + format: Literal["lerobot", "hdf5"] = "lerobot" path: Path metadata: dict[str, Any] = {} ``` @@ -219,8 +219,7 @@ thread is a daemon — there is no mid-iteration cancel. ### Pure helpers (in `dataprep.py`) Stateless, importable without booting a Module. Reused by every format -writer **and** by `ChunkPolicyModule._build_live_obs` at inference time -(single source of truth for obs construction). +writer (single source of truth for obs construction). ```python def resolve_field(msg: Any, ref: StreamField) -> np.ndarray: ... @@ -259,7 +258,7 @@ def compute_stats( ### Format writers (`dimos/learning/formats/`) -All three writers consume `Iterator[Sample] + OutputConfig` and accumulate +Both writers consume `Iterator[Sample] + OutputConfig` and accumulate stats via the shared `StreamingStats` (`formats/_stats.py`) so format- agnostic stats logic exists in exactly one place. @@ -275,7 +274,6 @@ class StreamingStats: |---|---|---| | `lerobot` | `meta/{info,episodes,tasks,stats}.json` + `data/chunk-000/episode_NNNNNN.parquet` + `videos/chunk-000/observation.images./episode_NNNNNN.mp4` | `pyarrow`, `opencv-python` | | `hdf5` | single `.hdf5` with `/episodes/episode_NNNNNN/{timestamp, observation/, action/}` + `/stats/` + `/tasks` + root attrs | `h5py` | -| `rlds` | `rlds-NNNNN-of-MMMMM.tfrecord` (one `SequenceExample` per episode, RLDS step protocol) + `features.json` + `dataset_info.json` | `tensorflow` | #### LeRobot v2 specifics @@ -294,10 +292,8 @@ class StreamingStats: ### dimos_meta.json sidecar -Written into every dataset directory; describes how it was built. 
Used -downstream by training (copies + adds policy fields) and by inference -(reads it at `start()` to recover the obs schema — no operator-supplied -spec path). +Written into every dataset directory; describes how it was built and +records the obs/action schema alongside the data. ```json { @@ -372,4 +368,4 @@ any consumer of recorded JPEG-encoded `Image` streams, not just learning. | `DataPrepModule` | `Module` | Long-running build job; thread + `get_status` RPC | | `RecordReplay` | transport hook | Captures every stream uniformly; not a per-Module concern | | `StreamingStats` | helper class | No lifecycle, no I/O — pure accumulator | -| `extract_episodes` / `iter_episode_samples` / `resolve_field` / `compute_stats` | functions | Pure helpers; reused by inference | +| `extract_episodes` / `iter_episode_samples` / `resolve_field` / `compute_stats` | functions | Pure helpers; reused by every format writer | diff --git a/dimos/learning/specs/structure.md b/dimos/learning/specs/structure.md index b5577ed87f..8dce1b1fa3 100644 --- a/dimos/learning/specs/structure.md +++ b/dimos/learning/specs/structure.md @@ -8,9 +8,7 @@ dimos/learning/ │ ├── specs/ │ ├── structure.md -│ ├── datacollection.md # Stage 1 -│ ├── training.md # Stage 2 -│ └── inference.md # Stage 3 +│ └── datacollection.md # Stage 1 │ ├── dataprep.py # types + pure helpers (no Module) │ # - Episode, Sample @@ -23,42 +21,10 @@ dimos/learning/ │ ├── episode_monitor.py # EpisodeStatus + EpisodeMonitorModule(Config) │ └── blueprint.py # learning_collect_quest_ │ -├── formats/ # dataset writers; each calls DataPrep.compute_stats -│ ├── lerobot.py # LeRobot v2 (parquet + MP4 + meta/stats.json) -│ ├── hdf5.py -│ └── rlds.py -│ -├── training/ -│ ├── trainer_module.py # TrainerModule(Config); runs train_bc on a thread -│ ├── train.py # train_bc + train_val_split (lazy lerobot/torch) -│ ├── configs.py # BCConfig -│ └── blueprint.py # learning_train -│ -├── policy/ -│ ├── base.py # ActionChunk + Policy Protocol -│ 
└── lerobot_policy.py # LeRobotPolicy.load -│ -└── inference/ - ├── chunk_policy_module.py # ChunkPolicyModule(Config); ~30 Hz - │ # (obs construction is a private method; - │ # uses DataPrep.resolve_field) - └── blueprint.py # learning_infer_ -``` - -`ActionReplayer` is a `ControlTask`, not a learning Module — it lives -with the other coordinator tasks: - +└── formats/ # dataset writers; each calls DataPrep.compute_stats + ├── lerobot.py # LeRobot v2 (parquet + MP4 + meta/stats.json) + └── hdf5.py ``` -dimos/control/ -├── coordinator.py # adds action_chunk: In[ActionChunk] -│ # _on_action_chunk → ActionReplayer -└── tasks/ - ├── teleop_task.py - ├── ... - └── action_replayer_task.py # NEW; imports ActionChunk from learning/policy/base.py -``` - -Dependency: `control → learning.policy` (one-way). --- @@ -69,11 +35,6 @@ Dependency: `control → learning.policy` (one-way). | `EpisodeStatus`, `EpisodeMonitorModuleConfig` | `learning/collection/episode_monitor.py` | `EpisodeMonitorModule`; `DataPrep` | | `EpisodeExtractor`, `StreamField`, `SyncConfig`, `OutputConfig`, `Episode`, `Sample` | `learning/dataprep.py` | `DataPrepModule`, `ChunkPolicyModule`, format writers | | `DataPrepModuleConfig` | `learning/dataprep_module.py` | `DataPrepModule` | -| `BCConfig` | `learning/training/configs.py` | `train_bc` | -| `TrainerModuleConfig` | `learning/training/trainer_module.py` | `TrainerModule` | -| `ActionChunk`, `Policy` Protocol | `learning/policy/base.py` | `ChunkPolicyModule`, `ActionReplayer`, `ControlCoordinator` | -| `ChunkPolicyModuleConfig` | `learning/inference/chunk_policy_module.py` | `ChunkPolicyModule` | -| `ActionReplayerConfig` | `control/tasks/action_replayer_task.py` | `ActionReplayer` | --- @@ -84,22 +45,18 @@ All generated artifacts live under `data/` (gitignored at repo root): ``` data/ ├── sessions/.db ← RecordReplay -├── datasets// ← DataPrepModule.build() -│ ├── data/ (parquet) -│ ├── videos/ (MP4) -│ └── meta/ -│ ├── info.json -│ ├── episodes.jsonl -│ 
├── stats.json (DataPrep.compute_stats) -│ └── dimos_meta.json (DataPrepModuleConfig.model_dump()) -└── runs// ← train_bc - ├── *.safetensors - └── dimos_meta.json (dataset snapshot + policy fields) +└── datasets// ← DataPrepModule.build() + ├── data/ (parquet) + ├── videos/ (MP4) + └── meta/ + ├── info.json + ├── episodes.jsonl + ├── stats.json (DataPrep.compute_stats) + └── dimos_meta.json (DataPrepModuleConfig.model_dump()) ``` -`dimos_meta.json` rides with the data: DataPrep writes it; training -copies it forward + adds policy fields; inference reads it at `start()`. -Operator never passes a spec path. +`dimos_meta.json` rides with the data: DataPrep writes it alongside the +dataset to record the obs/action schema. --- @@ -120,7 +77,4 @@ A class becomes a **Module** when it has long-lived state with |---|---|---| | `EpisodeMonitorModule` | Module | Long-lived; subscribes to inputs; publishes status | | `DataPrepModule` | Module | Long-running build job | -| `TrainerModule` | Module | Runs training on a daemon thread | -| `ChunkPolicyModule` | Module | Long-lived inference thread | -| `ActionReplayer` | `BaseControlTask` | Runs in coordinator's 100 Hz thread | | `RecordReplay` | transport hook | Captures every stream uniformly | diff --git a/dimos/learning/specs/training.md b/dimos/learning/specs/training.md deleted file mode 100644 index fdd731cb5a..0000000000 --- a/dimos/learning/specs/training.md +++ /dev/null @@ -1,78 +0,0 @@ -# Stage 2 — Training - -ACT only (BC). Reads `data/datasets//`, writes `data/runs//`. - -`TrainerModule` runs `train_bc(...)` on a daemon thread inside its own -worker. Lazy imports keep `lerobot` / `torch` / CUDA out of the worker -until `train()` is called. Metrics → TensorBoard. No `cancel()` in v1. 
- ---- - -## Blueprint - -```python -# dimos/learning/training/blueprint.py -learning_train = autoconnect( - TrainerModule.blueprint( - dataset_path="data/datasets/pick_red/", - output_dir="data/runs/act_pick_red", - auto_run=True, - ), -).transports({}) -``` - -## Module - -```python -# dimos/learning/training/trainer_module.py - -class TrainerModuleConfig(ModuleConfig): - dataset_path: str = "" - output_dir: str = "" - config_path: str | None = None # optional BCConfig YAML override - auto_run: bool = False - tensorboard_port: int = 6006 - - -class TrainerModule(Module): - config: TrainerModuleConfig - - @rpc - def train( - self, - dataset_path: str | None = None, - output_dir: str | None = None, - config_overrides: dict[str, Any] | None = None, - ) -> None: ... - @rpc - def get_status(self) -> dict[str, Any]: ... -``` - -## Training entry point - -```python -# dimos/learning/training/train.py - -def train_bc( - dataset_path: str | Path, - cfg: BCConfig, - output_dir: str | Path, - config_overrides: dict[str, Any] | None = None, -) -> Path: - """Lazy-import lerobot, build Hydra-style argv from BCConfig, call - lerobot's training entry point, append `dimos_meta.json` to output_dir, - return the checkpoint path.""" -``` - -`BCConfig` (ACT hyperparams) lives in `training/configs.py`. -`train_val_split()` lives next to `train_bc` in `training/train.py`. - -## Run - -```bash -dimos run learning-train -tensorboard --logdir data/runs/act_pick_red -``` - -Artifact: `data/runs/act_pick_red/` = `*.safetensors` + `dimos_meta.json` -(dataset snapshot + `joint_names`, `chunk_size`, `policy_type`). diff --git a/dimos/learning/training/blueprint.py b/dimos/learning/training/blueprint.py deleted file mode 100644 index 2bfd96d0dd..0000000000 --- a/dimos/learning/training/blueprint.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2026 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ACT training blueprint. RPC-only surface (no streams). - -Defaults are tuned for the local pickplace_001 demo: 2k steps, batch=4, CPU. -For real training, override via: - dimos run learning-train -o trainermodule.bc.steps=100000 -o trainermodule.bc.device=cuda -""" - -from __future__ import annotations - -from dimos.core.coordination.blueprints import autoconnect -from dimos.learning.training.configs import BCConfig -from dimos.learning.training.trainer_module import TrainerModule - -learning_train = autoconnect( - TrainerModule.blueprint( - dataset_path="data/datasets/pickplace_001", - output_dir="data/runs/act_pickplace_001", - bc=BCConfig( - steps=2000, - batch_size=4, - device="cpu", - ), - auto_run=True, - overwrite=True, - ), -).transports({}) - - -__all__ = ["learning_train"] diff --git a/dimos/learning/training/configs.py b/dimos/learning/training/configs.py deleted file mode 100644 index a9428d7c29..0000000000 --- a/dimos/learning/training/configs.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ACT training config (v1). Fields are the opinionated subset DimOS exposes; -unset = lerobot default. Translated to Hydra-style argv inside `train_bc`.""" - -from __future__ import annotations - -from typing import Literal - -from pydantic import BaseModel - - -class BCConfig(BaseModel): - policy_type: Literal["act"] = "act" - - # Action chunking - chunk_size: int = 50 # future actions per inference call - n_obs_steps: int = 1 # obs history length - - # ACT model arch - hidden_dim: int = 512 - n_layers: int = 4 - n_heads: int = 8 - use_vae: bool = True - kl_weight: float = 10.0 - - # Vision backbone - vision_backbone: str = "resnet18" - pretrained: bool = True - - # Optim - steps: int = 100_000 - batch_size: int = 8 - lr: float = 1e-5 - lr_backbone: float = 1e-5 - weight_decay: float = 1e-4 - - # Eval / checkpointing - save_every: int = 10_000 - eval_every: int = 5_000 - seed: int = 0 - device: str = "cuda" diff --git a/dimos/learning/training/train.py b/dimos/learning/training/train.py deleted file mode 100644 index 4252aa3a7d..0000000000 --- a/dimos/learning/training/train.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ACT training entry point. - -`train_bc` subprocesses ``python -m lerobot.scripts.train`` with argv -translated from `BCConfig`. Lerobot is never imported in-process so the -dimos runtime stays free of torch/CUDA. After a successful run we write -``dimos_meta.json`` next to the checkpoint so `LeRobotPolicy.load` can -recover the obs/action schema. -""" - -from __future__ import annotations - -import argparse -import json -import random -import shutil -import subprocess -import sys -from pathlib import Path -from typing import Any - -from dimos.learning.dataprep import Episode -from dimos.learning.training.configs import BCConfig - -DIMOS_META_FILENAME = "dimos_meta.json" - - -def train_bc( - dataset_path: str | Path, - cfg: BCConfig, - output_dir: str | Path, - config_overrides: dict[str, Any] | None = None, - overwrite: bool = True, - resume: bool = False, -) -> Path: - """Train ACT on a prepared LeRobot v2 dataset. Returns the checkpoint dir. - - Args: - overwrite: if True (default) wipes ``output_dir`` before launching. - Lerobot's ``cfg.validate()`` refuses to run if the dir exists. - resume: pass ``--resume=true`` to lerobot. Takes precedence over - ``overwrite``. 
- """ - dataset_path = Path(dataset_path) - output_dir = Path(output_dir) - - if resume: - overwrite = False - elif overwrite and output_dir.exists(): - print(f"[train_bc] removing existing {output_dir}", flush=True) - shutil.rmtree(output_dir) - - argv = _build_lerobot_argv(cfg, dataset_path, output_dir) - if resume: - argv.append("--resume=true") - if config_overrides: - for k, v in config_overrides.items(): - argv.append(f"--{k}={v}") - - print(f"[train_bc] launching lerobot ({len(argv)} args, output → {output_dir})", flush=True) - # Log alongside output_dir, not inside — lerobot creates output_dir itself - # and its validate() refuses if it already exists. - output_dir.parent.mkdir(parents=True, exist_ok=True) - log_path = output_dir.parent / f"{output_dir.name}.lerobot.log" - - # Stream lerobot stdout to a full log file + filtered terminal output. - interesting = ( - "step:", "loss:", "Logs will be", "Creating dataset", "Output dir", - "Saved checkpoint", "epoch", "loss", "lr=", "ETA", "ERROR", "Error", - "WARNING", "Traceback", - ) - proc = subprocess.Popen(argv, stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, text=True, bufsize=1) - with open(log_path, "w") as logf: - assert proc.stdout is not None - for line in proc.stdout: - logf.write(line) - logf.flush() - if any(kw in line for kw in interesting): - sys.stdout.write(line) - sys.stdout.flush() - proc.wait() - if proc.returncode != 0: - print(f"[train_bc] lerobot exited {proc.returncode} — full log: {log_path}", flush=True) - raise subprocess.CalledProcessError(proc.returncode, argv) - - _write_dimos_meta(output_dir, dataset_path, cfg) - return output_dir - - -def _build_lerobot_argv(cfg: BCConfig, dataset_path: Path, output_dir: Path) -> list[str]: - """Translate BCConfig → argv for ``lerobot.scripts.train`` (lerobot 0.3.x). - - LeRobot 0.4.x renamed the entry point to ``lerobot.scripts.lerobot_train`` - and adjusted some draccus flag names — adjust this function if you pin - a different version. 
- """ - return [ - sys.executable, "-m", "lerobot.scripts.train", - f"--policy.type={cfg.policy_type}", - f"--policy.chunk_size={cfg.chunk_size}", - f"--policy.n_action_steps={cfg.chunk_size}", - f"--policy.n_obs_steps={cfg.n_obs_steps}", - f"--policy.dim_model={cfg.hidden_dim}", - f"--policy.n_encoder_layers={cfg.n_layers}", - f"--policy.n_decoder_layers={cfg.n_layers}", - f"--policy.n_heads={cfg.n_heads}", - f"--policy.use_vae={str(cfg.use_vae).lower()}", - f"--policy.kl_weight={cfg.kl_weight}", - f"--policy.vision_backbone={cfg.vision_backbone}", - f"--policy.pretrained_backbone_weights={'ResNet18_Weights.IMAGENET1K_V1' if cfg.pretrained else 'null'}", - f"--policy.device={cfg.device}", - # push_to_hub defaults True in lerobot and triggers a repo_id requirement. - "--policy.push_to_hub=false", - "--dataset.repo_id=local", - f"--dataset.root={dataset_path}", - f"--steps={cfg.steps}", - f"--batch_size={cfg.batch_size}", - f"--optimizer.lr={cfg.lr}", - f"--optimizer.weight_decay={cfg.weight_decay}", - f"--save_freq={cfg.save_every}", - f"--eval_freq={cfg.eval_every}", - "--wandb.enable=false", - f"--seed={cfg.seed}", - f"--output_dir={output_dir}", - # Note: do NOT pass --env — its choice-class decoder rejects "none"; - # leaving it unset disables eval cleanly. - ] - - -def _write_dimos_meta(output_dir: Path, dataset_path: Path, cfg: BCConfig) -> None: - """Write the inference sidecar at ``/dimos_meta.json``. - - Combines the dataset's dimos_meta (obs/action streams, sync) with policy - fields (type, chunk_size, n_obs_steps) and the dataset_path so - `LeRobotPolicy.load` can resolve `meta/stats.json`. 
- """ - src = dataset_path / DIMOS_META_FILENAME - base: dict[str, Any] = json.load(open(src)) if src.exists() else {} - base.update({ - "dataset_path": str(dataset_path), - "policy_type": cfg.policy_type, - "chunk_size": cfg.chunk_size, - "n_obs_steps": cfg.n_obs_steps, - "joint_names": base.get("joint_names"), # often None; inference falls back - }) - with open(output_dir / DIMOS_META_FILENAME, "w") as f: - json.dump(base, f, indent=2, default=str) - - -def train_val_split( - episodes: list[Episode], - val_episode_ids: list[int] | None = None, - val_ratio: float | None = None, - seed: int = 0, -) -> tuple[list[int], list[int]]: - """Partition episode indices into (train_ids, val_ids). - - Resolution order: ``val_episode_ids`` (whitelist) > ``val_ratio`` - (deterministic via ``seed``) > both None (everything in train). - """ - n = len(episodes) - all_ids = list(range(n)) - - if val_episode_ids is not None: - val_set = set(val_episode_ids) - return ([i for i in all_ids if i not in val_set], - [i for i in all_ids if i in val_set]) - - if val_ratio is not None: - rng = random.Random(seed) - shuffled = all_ids[:] - rng.shuffle(shuffled) - n_val = int(round(n * val_ratio)) - return sorted(shuffled[n_val:]), sorted(shuffled[:n_val]) - - return all_ids, [] - - -# ───────────────────────────────────────────────────────────────────────────── -# CLI: `python -m dimos.learning.training.train bc --output ...` -# ───────────────────────────────────────────────────────────────────────────── - - -def main() -> None: - parser = argparse.ArgumentParser(prog="dimos.learning.training.train") - sub = parser.add_subparsers(dest="kind", required=True) - - p_bc = sub.add_parser("bc", help="Train an ACT (BC) policy") - p_bc.add_argument("dataset", help="path to LeRobot v2 dataset directory") - p_bc.add_argument("--output", required=True, help="checkpoint output directory") - p_bc.add_argument("--config", help="path to BCConfig JSON override") - p_bc.add_argument("--steps", type=int) - 
p_bc.add_argument("--batch-size", type=int) - p_bc.add_argument("--chunk-size", type=int) - p_bc.add_argument("--device", type=str) - p_bc.add_argument("-o", "--override", action="append", default=[], - help="extra lerobot CLI override, e.g. -o optimizer.lr=5e-5") - - args = parser.parse_args() - - if args.kind == "bc": - cfg_kwargs: dict[str, Any] = json.load(open(args.config)) if args.config else {} - for k, v in (("steps", args.steps), ("batch_size", args.batch_size), - ("chunk_size", args.chunk_size), ("device", args.device)): - if v is not None: - cfg_kwargs[k] = v - cfg = BCConfig(**cfg_kwargs) - - overrides: dict[str, Any] = {} - for o in args.override: - if "=" not in o: - parser.error(f"--override must be key=value, got {o!r}") - k, v = o.split("=", 1) - overrides[k] = v - - out = train_bc(args.dataset, cfg, args.output, config_overrides=overrides) - print(f"[train_bc] checkpoint at: {out}") - - -if __name__ == "__main__": - main() diff --git a/dimos/learning/training/trainer_module.py b/dimos/learning/training/trainer_module.py deleted file mode 100644 index eb32e530e9..0000000000 --- a/dimos/learning/training/trainer_module.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ACT training Module — thin wrapper around `train_bc`. - -Spawns `train_bc` on a daemon thread (which subprocesses -``python -m lerobot.scripts.train``). 
Exposes: - - @rpc start() lifecycle (auto-fires train if auto_run) - @rpc train(...) kick off a training job - @rpc get_status() current state + checkpoint dir - @rpc stop() best-effort shutdown - -There is no cancel(): the lerobot subprocess is sent SIGTERM only on -process exit. Heavy deps (torch, lerobot) stay in the subprocess. -""" - -from __future__ import annotations - -import json -import shutil -import subprocess -import threading -import traceback -from pathlib import Path -from typing import Any - -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.learning.training.configs import BCConfig -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() - - -class TrainerModuleConfig(ModuleConfig): - dataset_path: str = "" - output_dir: str = "" - # ACT hyperparams. CLI override pattern: - # -o trainermodule.bc.steps=2000 -o trainermodule.bc.batch_size=4 - bc: BCConfig = BCConfig() - # Optional JSON file with BCConfig overrides; merged on top of `bc`. - config_path: str | None = None - auto_run: bool = False - overwrite: bool = True # wipe output_dir before training (lerobot refuses to overwrite) - resume: bool = False # pass --resume=true to lerobot - # Lerobot 0.3.x does not write tensorboard event files; the launch is a - # no-op there and shows an empty UI. Disabled until we wire a stdout - # parser → SummaryWriter on our side. 
- tensorboard: bool = False - tensorboard_port: int = 6006 - tensorboard_host: str = "0.0.0.0" - - -class TrainerModule(Module): - config: TrainerModuleConfig - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._thread: threading.Thread | None = None - self._lock = threading.Lock() - self._status: dict[str, Any] = { - "state": "idle", # idle | running | succeeded | failed - "checkpoint_dir": None, - "tensorboard_url": None, - "error": None, - } - self._tb_proc: subprocess.Popen[bytes] | None = None - - @rpc - def start(self) -> None: - super().start() - if self.config.auto_run: - self.train() - - @rpc - def stop(self) -> None: - # Train thread is daemon: dies with the process. No mid-run interrupt. - if self._tb_proc is not None and self._tb_proc.poll() is None: - logger.info("[trainer] stopping tensorboard pid=%s", self._tb_proc.pid) - self._tb_proc.terminate() - try: - self._tb_proc.wait(timeout=2.0) - except subprocess.TimeoutExpired: - self._tb_proc.kill() - super().stop() - - @rpc - def train( - self, - dataset_path: str | None = None, - output_dir: str | None = None, - config_overrides: dict[str, Any] | None = None, - ) -> None: - """Spawn a daemon thread running ``train_bc``; returns immediately. 
- Raises if a run is already in progress.""" - with self._lock: - if self._status["state"] == "running": - raise RuntimeError("training already in progress") - self._status.update(state="running", checkpoint_dir=None, error=None) - - ds = dataset_path or self.config.dataset_path - od = output_dir or self.config.output_dir - if not ds or not od: - with self._lock: - self._status.update(state="failed", - error="dataset_path and output_dir are required") - raise ValueError("dataset_path and output_dir are required") - - self._maybe_start_tensorboard(Path(od)) - - self._thread = threading.Thread( - target=self._run_training, - args=(ds, od, config_overrides), - daemon=True, - ) - self._thread.start() - - @rpc - def get_status(self) -> dict[str, Any]: - with self._lock: - return dict(self._status) - - def _run_training( - self, - dataset_path: str, - output_dir: str, - config_overrides: dict[str, Any] | None, - ) -> None: - try: - from dimos.learning.training.train import train_bc - - cfg_kwargs = self.config.bc.model_dump() - if self.config.config_path: - with open(self.config.config_path) as f: - cfg_kwargs.update(json.load(f)) - cfg = BCConfig(**cfg_kwargs) - - ckpt = train_bc( - dataset_path, cfg, output_dir, - config_overrides=config_overrides, - overwrite=self.config.overwrite, - resume=self.config.resume, - ) - - with self._lock: - self._status.update(state="succeeded", checkpoint_dir=str(ckpt)) - except Exception as e: - with self._lock: - self._status.update( - state="failed", - error=f"{type(e).__name__}: {e}\n{traceback.format_exc()}", - ) - - # ── tensorboard ────────────────────────────────────────────────────────── - - def _maybe_start_tensorboard(self, logdir: Path) -> None: - """Spawn ``tensorboard --logdir `` if enabled and available.""" - if not self.config.tensorboard or self.config.tensorboard_port == 0: - return - if self._tb_proc is not None and self._tb_proc.poll() is None: - return - - tb_bin = shutil.which("tensorboard") - if tb_bin is None: - 
logger.warning( - "[trainer] tensorboard binary not found on PATH — skipping. " - "Install with: pip install tensorboard" - ) - return - - # Do NOT pre-create logdir — lerobot's cfg.validate() refuses if the - # output dir exists. Tensorboard polls happily on a missing dir. - port = self.config.tensorboard_port - host = self.config.tensorboard_host - try: - self._tb_proc = subprocess.Popen( - [tb_bin, "--logdir", str(logdir), - "--port", str(port), "--host", host, - "--reload_interval", "5"], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - except Exception as e: - logger.warning("[trainer] failed to launch tensorboard: %s", e) - return - - view_host = "localhost" if host in ("0.0.0.0", "") else host - url = f"http://{view_host}:{port}/" - with self._lock: - self._status["tensorboard_url"] = url - logger.info("[trainer] tensorboard launched pid=%s — view at %s", - self._tb_proc.pid, url) - - -__all__ = ["TrainerModule", "TrainerModuleConfig"] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 2a42490a6a..83b5ed8922 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -69,9 +69,6 @@ "learning-collect-quest-xarm7": "dimos.learning.collection.blueprint:learning_collect_quest_xarm7", "learning-dataprep": "dimos.learning.dataprep_blueprint:learning_dataprep", "learning-dataprep-whole-session": "dimos.learning.dataprep_blueprint:learning_dataprep_whole_session", - "learning-infer-chunkpolicy-only": "dimos.learning.inference.blueprint:learning_infer_chunkpolicy_only", - "learning-infer-xarm7": "dimos.learning.inference.blueprint:learning_infer_xarm7", - "learning-train": "dimos.learning.training.blueprint:learning_train", "mid360": "dimos.hardware.sensors.lidar.livox.livox_blueprints:mid360", "mid360-fastlio": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio", "mid360-fastlio-ray-trace": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_ray_trace", @@ 
-141,7 +138,6 @@ "b1-connection-module": "dimos.robot.unitree.b1.connection.B1ConnectionModule", "camera-module": "dimos.hardware.sensors.camera.module.CameraModule", "cartesian-motion-controller": "dimos.manipulation.control.servo_control.cartesian_motion_controller.CartesianMotionController", - "chunk-policy-module": "dimos.learning.inference.chunk_policy_module.ChunkPolicyModule", "control-coordinator": "dimos.control.coordinator.ControlCoordinator", "cost-mapper": "dimos.mapping.costmapper.CostMapper", "data-prep-module": "dimos.learning.dataprep_module.DataPrepModule", @@ -226,7 +222,6 @@ "temporal-memory": "dimos.perception.experimental.temporal_memory.temporal_memory.TemporalMemory", "terrain-analysis": "dimos.navigation.nav_stack.modules.terrain_analysis.terrain_analysis.TerrainAnalysis", "terrain-map-ext": "dimos.navigation.nav_stack.modules.terrain_map_ext.terrain_map_ext.TerrainMapExt", - "trainer-module": "dimos.learning.training.trainer_module.TrainerModule", "twist-teleop-module": "dimos.teleop.quest.quest_extensions.TwistTeleopModule", "unitree-g1-skill-container": "dimos.robot.unitree.g1.skill_container.UnitreeG1SkillContainer", "unitree-skill-container": "dimos.robot.unitree.unitree_skill_container.UnitreeSkillContainer", From 95fa2397dec687c39d160cdd32df2fe300c4807c Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 4 Jun 2026 12:05:44 -0700 Subject: [PATCH 08/45] docs: remove readme --- dimos/learning/specs/datacollection.md | 188 ------------- dimos/learning/specs/spec_v2.md | 371 ------------------------- dimos/learning/specs/structure.md | 80 ------ 3 files changed, 639 deletions(-) delete mode 100644 dimos/learning/specs/datacollection.md delete mode 100644 dimos/learning/specs/spec_v2.md delete mode 100644 dimos/learning/specs/structure.md diff --git a/dimos/learning/specs/datacollection.md b/dimos/learning/specs/datacollection.md deleted file mode 100644 index b6e7afadc0..0000000000 --- a/dimos/learning/specs/datacollection.md +++ 
/dev/null @@ -1,188 +0,0 @@ -# Stage 1 — Data - -1. **Recording** — live; `RecordReplay` writes streams to `session.db`. -2. **DataPrep** — offline; `session.db` → `dataset/` (LeRobot v2). - ---- - -## Phase A — Recording - -### Blueprint - -```python -# dimos/learning/collection/blueprint.py -learning_collect_quest_xarm7 = autoconnect( - teleop_quest_xarm7, - RealSenseCamera.blueprint(enable_pointcloud=False), - EpisodeMonitorModule.blueprint( - button_map={"start": "A", "save": "B", "discard": "X"}, - default_task_label="pick_red_cube", - ), -).transports({ - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), -}) -``` - -`RecordReplay` (`--record-path`) captures every transport above into `session.db`. - -### EpisodeMonitorModule - -Translates teleop input (buttons, keyboard) into the canonical -`EpisodeStatus` stream. DataPrep reads only this stream — never raw inputs. - -```python -# dimos/learning/collection/episode_monitor.py - -class EpisodeStatus(BaseModel): - state: Literal["idle", "recording"] - episodes_saved: int - episodes_discarded: int - current_episode_start_ts: float | None - last_event: Literal["start", "save", "discard", "init"] = "init" - task_label: str | None = None - - -class EpisodeMonitorModuleConfig(ModuleConfig): - button_map: dict[Literal["start", "save", "discard"], str] = {"start": "A", "save": "B", "discard": "X"} - keyboard_map: dict[Literal["start", "save", "discard"], str] = {} - default_task_label: str | None = None - - -class EpisodeMonitorModule(Module): - config: EpisodeMonitorModuleConfig - - buttons: In[Buttons] - keyboard: In[KeyPress] - status: Out[EpisodeStatus] - - @rpc - def reset_counters(self) -> EpisodeStatus: ... - @rpc - def get_status(self) -> EpisodeStatus: ... - - def _on_buttons(self, msg: Buttons) -> None: ... 
- def _on_keyboard(self, msg: KeyPress) -> None: ... -``` - -State machine: - -``` -IDLE --start--> RECORDING -RECORDING --save--> IDLE (commit) -RECORDING --discard--> IDLE (drop) -RECORDING --start--> RECORDING (auto-commit prev) -session end mid-episode: always discard -``` - -### Run - -```bash -dimos run learning-collect-quest-xarm7 --record-path data/sessions/pick_red.db -``` - ---- - -## Phase B — DataPrep - -### Blueprint - -```python -# dimos/learning/dataprep/blueprint.py -learning_dataprep = autoconnect( - DataPrepModule.blueprint( - source="data/sessions/pick_red.db", - episodes=EpisodeExtractor(), - observation={ - "cam": StreamField(stream="camera_color_image", field="image"), - "joint_pos": StreamField(stream="coordinator_joint_state", field="position"), - }, - action={ - "joint_target": StreamField(stream="coordinator_joint_command", field="position"), - }, - sync=SyncConfig(anchor="cam", rate_hz=30, tolerance_ms=50), - output=OutputConfig(format="lerobot", path=Path("data/datasets/pick_red/")), - auto_run=True, - ), -).transports({}) -``` - -### DataPrepModule - -```python -# dimos/learning/dataprep_module.py -from dimos.protocol.service.spec import BaseConfig - - -class EpisodeExtractor(BaseConfig): - extractor: Literal["episode_status", "ranges", "whole_session"] = "episode_status" - status_stream: str = "episode_status" - ranges: list[tuple[float, float]] | None = None - - -class StreamField(BaseConfig): - stream: str - field: str | None = None - - -class SyncConfig(BaseConfig): - anchor: str - rate_hz: float - tolerance_ms: float - strategy: Literal["nearest", "interp"] = "nearest" - - -class OutputConfig(BaseConfig): - format: Literal["lerobot", "hdf5", "rlds"] = "lerobot" - path: Path - metadata: dict[str, Any] = {} - - -class DataPrepModuleConfig(ModuleConfig): - source: str - episodes: EpisodeExtractor - observation: dict[str, StreamField] - action: dict[str, StreamField] - sync: SyncConfig - output: OutputConfig - auto_run: bool = False - 
- -class DataPrepModule(Module): - config: DataPrepModuleConfig - - @rpc - def build(self) -> None: ... - @rpc - def get_status(self) -> dict[str, Any]: ... - @rpc - def inspect(self) -> dict[str, Any]: ... -``` - -`build()` iterates samples, hands them to the format writer, and snapshots -`config.model_dump()` into `/dimos_meta.json`. Stats are -written into `meta/stats.json` by `DataPrep.compute_stats`. - -### Run - -```bash -dimos run learning-dataprep -``` - ---- - -## End-to-end - -```bash -dimos run learning-collect-quest-xarm7 --record-path data/sessions/pick_red.db -dimos run learning-dataprep -``` - -``` -data/sessions/pick_red.db ─► data/datasets/pick_red/ - ├── data/ (parquet) - ├── videos/ (MP4) - └── meta/ (info.json, episodes.jsonl, - stats.json, dimos_meta.json) -``` diff --git a/dimos/learning/specs/spec_v2.md b/dimos/learning/specs/spec_v2.md deleted file mode 100644 index 94e5c85d71..0000000000 --- a/dimos/learning/specs/spec_v2.md +++ /dev/null @@ -1,371 +0,0 @@ -# Stage 1 — Data (v2) - -Two phases: - -1. **Recording** — live; teleop + camera + `EpisodeMonitorModule` produce - LCM streams. `RecordReplay` (CLI flag) captures every active topic - into `session.db` (memory2 `SqliteStore` format). -2. **DataPrep** — offline; `DataPrepModule` reads `session.db` and writes - a training-ready dataset on disk in one of two formats (LeRobot v2, - HDF5). - -Same code paths drive both phases as DimOS blueprints. Format dispatch is -config-only; the same module backs every output format. 
- ---- - -## Phase A — Recording - -### Blueprint - -```python -# dimos/learning/collection/blueprint.py -_DEFAULT_BUTTON_MAP = {"start": "A", "save": "B", "discard": "X"} -_TRANSPORTS = { - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), -} - -learning_collect_quest_xarm7 = autoconnect( - teleop_quest_xarm7, - RealSenseCamera.blueprint(enable_pointcloud=False), - EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), -).transports(_TRANSPORTS) - -# Variants (one per arm config) follow the same pattern: -# learning_collect_quest_xarm6 -# learning_collect_quest_piper -# learning_collect_quest_dual -``` - -`RecordReplay` (`--record-path`) captures every transport above into -`session.db`. Recording is a transport-layer hook, not a Module — every -LCM stream is recorded uniformly. - -### EpisodeMonitorModule - -Translates teleop input (Quest buttons, optional keyboard) into the -canonical `EpisodeStatus` stream. DataPrep reads only this stream, never -raw button presses. 
- -```python -# dimos/learning/collection/episode_monitor.py - -class EpisodeStatus(BaseModel): - state: Literal["idle", "recording"] - episodes_saved: int - episodes_discarded: int - current_episode_start_ts: float | None - last_event: Literal["start", "save", "discard", "init"] = "init" - task_label: str | None = None - - -class KeyPress(BaseModel): - key: str - ts: float - - -class EpisodeMonitorModuleConfig(ModuleConfig): - button_map: dict[Literal["start", "save", "discard"], str] = {"start": "A", "save": "B", "discard": "X"} - keyboard_map: dict[Literal["start", "save", "discard"], str] = {} - default_task_label: str | None = None - - -class EpisodeMonitorModule(Module): - config: EpisodeMonitorModuleConfig - - buttons: In[Buttons] - keyboard: In[KeyPress] - status: Out[EpisodeStatus] - - @rpc - def reset_counters(self) -> EpisodeStatus: ... - @rpc - def get_status(self) -> EpisodeStatus: ... -``` - -State machine (mirrored offline by `DataPrep.extract_episodes`): - -``` -IDLE --start--> RECORDING -RECORDING --save--> IDLE (commit, saved += 1) -RECORDING --discard--> IDLE (drop, discarded += 1) -RECORDING --start--> RECORDING (auto-commit prev, begin new) -session end mid-episode: always discarded -``` - -Friendly button names (`A`/`B`/`X`/...) resolve to `Buttons` attributes -via `BUTTON_ALIASES` (e.g. `"A"` → `right_primary`). Override with the -attribute name directly in `button_map`. 
- -### Run - -```bash -dimos run learning-collect-quest-xarm7 --record-path data/recordings/pick_red.db -``` - ---- - -## Phase B — DataPrep - -### Blueprints - -```python -# dimos/learning/dataprep_blueprint.py - -learning_dataprep = autoconnect( - DataPrepModule.blueprint( - source="data/recordings/pickplace_001.db", - episodes=EpisodeExtractor(extractor="ranges", ranges=[(t0, t1)]), - observation={ - "image": StreamField(stream="color_image", field="data"), - "joint_state": StreamField(stream="joint_state", field="position"), - }, - action={ - "joint_target": StreamField(stream="joint_state", field="position"), - }, - sync=SyncConfig(anchor="image", rate_hz=14.0, tolerance_ms=80.0), - output=OutputConfig( - format="lerobot", - path="data/datasets/pickplace_001", - metadata={"fps": 14, "robot": "xarm7", "default_task_label": "pick_and_place"}, - ), - auto_run=True, - ), -).transports({}) - - -# Variant for one-demo-per-file recordings (no episode_status stream). -learning_dataprep_whole_session = autoconnect( - DataPrepModule.blueprint( - source="data/session.db", - episodes=EpisodeExtractor(extractor="whole_session"), - observation={...}, - action={...}, - sync=SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0), - output=OutputConfig(format="lerobot", path="data/datasets/default", - metadata={"fps": 30, "robot": "xarm7"}), - auto_run=True, - ), -).transports({}) -``` - -All `DataPrepModuleConfig` fields are defaulted — the DimOS CLI's -per-module override path validates user kwargs in isolation, so -required-without-default fields would reject partial `-o ...` overrides. -Real values come from the blueprint atom; CLI flags overlay on top. 
- -### DataPrepModule - -```python -# dimos/learning/dataprep.py -from dimos.protocol.service.spec import BaseConfig - - -class EpisodeExtractor(BaseConfig): - extractor: Literal["episode_status", "ranges", "whole_session"] = "episode_status" - status_stream: str = "episode_status" - ranges: list[tuple[float, float]] | None = None - - -class StreamField(BaseConfig): - stream: str - field: str | None = None - - -class SyncConfig(BaseConfig): - anchor: str - rate_hz: float - tolerance_ms: float - strategy: Literal["nearest", "interp"] = "nearest" - - -class OutputConfig(BaseConfig): - format: Literal["lerobot", "hdf5"] = "lerobot" - path: Path - metadata: dict[str, Any] = {} -``` - -```python -# dimos/learning/dataprep_module.py - -class DataPrepModuleConfig(ModuleConfig): - source: str = "" - episodes: EpisodeExtractor = EpisodeExtractor() - observation: dict[str, StreamField] = {} - action: dict[str, StreamField] = {} - sync: SyncConfig = SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0) - output: OutputConfig = OutputConfig(format="lerobot", path="data/datasets/default") - auto_run: bool = False - - -class DataPrepModule(Module): - config: DataPrepModuleConfig - - @rpc - def build(self) -> None: ... # spawns build thread; returns immediately - @rpc - def get_status(self) -> dict[str, Any]: ... # state, current_phase, progress_pct, samples_seen, error - @rpc - def inspect(self) -> dict[str, Any]: ... # streams, episode counts, duration distribution -``` - -`build()` opens the `SqliteStore`, walks samples episode-by-episode, -hands them to the configured format writer, and snapshots the spec -(`config.model_dump()`) into `/dimos_meta.json`. The build -thread is a daemon — there is no mid-iteration cancel. - -### Pure helpers (in `dataprep.py`) - -Stateless, importable without booting a Module. Reused by every format -writer (single source of truth for obs construction). - -```python -def resolve_field(msg: Any, ref: StreamField) -> np.ndarray: ... 
- -def extract_episodes(store: SqliteStore, cfg: EpisodeExtractor) -> list[Episode]: - """ - episode_status: replay EpisodeMonitorModule's state machine over - the recorded EpisodeStatus events. - ranges: emit one Episode per (start, end) tuple. - whole_session: one Episode covering every stream's combined time range. - """ - -def iter_episode_samples( - store: SqliteStore, - episode: Episode, - streams: dict[str, StreamField], # observation ∪ action - sync: SyncConfig, - obs_keys: set[str] | None = None, - action_keys: set[str] | None = None, -) -> Iterator[Sample]: - """ - Anchor-rate timestep walker. Caches each stream once per episode, - bisect-nearest within tolerance_ms; skips frames where any required - stream lacks a nearby sample. - """ - -def compute_stats( - samples: Iterator[Sample], - image_subsample: int = 10, - quantile_reservoir: int = 10_000, - seed: int = 0, -) -> dict[str, Any]: - """Welford mean/std + reservoir quantiles. Image features (≥3D) - subsampled and reduced to per-channel summaries.""" -``` - -### Format writers (`dimos/learning/formats/`) - -Both writers consume `Iterator[Sample] + OutputConfig` and accumulate -stats via the shared `StreamingStats` (`formats/_stats.py`) so format- -agnostic stats logic exists in exactly one place. - -```python -# dimos/learning/formats/_stats.py -class StreamingStats: - def __init__(self, image_subsample=10, quantile_reservoir=10_000, seed=0): ... - def update(self, name: str, value: np.ndarray) -> None: ... - def finalize(self) -> dict[str, dict[str, Any]]: ... 
-``` - -| Format | Layout | Heavy dep | -|---|---|---| -| `lerobot` | `meta/{info,episodes,tasks,stats}.json` + `data/chunk-000/episode_NNNNNN.parquet` + `videos/chunk-000/observation.images./episode_NNNNNN.mp4` | `pyarrow`, `opencv-python` | -| `hdf5` | single `.hdf5` with `/episodes/episode_NNNNNN/{timestamp, observation/, action/}` + `/stats/` + `/tasks` + root attrs | `h5py` | - -#### LeRobot v2 specifics - -- **Image columns are NOT in the parquet** — lerobot's - `get_hf_features_from_features` skips dtype="video" and reads frames - from MP4 at `__getitem__` time. -- **Timestamps are episode-relative** (subtract `episode.start_ts`) - because lerobot stores `timestamp` as float32 and validates frame-to- - frame deltas against `1/fps`. Absolute Unix epoch values would collide - in float32. -- **Feature naming follows lerobot convention** — single low-dim obs ⇒ - `observation.state`, single action key ⇒ `action`, image keys ⇒ - `observation.images.`. -- **`info.json` features include per-dim `names` lists** (required by - lerobot 0.3+). - -### dimos_meta.json sidecar - -Written into every dataset directory; describes how it was built and -records the obs/action schema alongside the data. 
- -```json -{ - "source": "data/recordings/pickplace_001.db", - "observation": {"image": {...}, "joint_state": {...}}, - "action": {"joint_target": {...}}, - "sync": {"anchor": "image", "rate_hz": 14.0, ...}, - "episodes": [{"id": "ep_000000", "start_ts": ..., "end_ts": ..., "task_label": ...}], - "format": "lerobot", - "metadata": {"fps": 14, "robot": "xarm7", ...} -} -``` - -### Run - -```bash -dimos run learning-dataprep -``` - -Override per run: - -```bash -dimos run learning-dataprep \ - -o dataprepmodule.source=data/recordings/foo.db \ - -o dataprepmodule.output.path=data/datasets/foo \ - -o dataprepmodule.output.format=hdf5 -``` - -For complex nested overrides (observation/action stream maps), use a JSON -config: - -```bash -dimos run learning-dataprep -c data/foo_dataset.json -``` - ---- - -## End-to-end - -```bash -dimos run learning-collect-quest-xarm7 --record-path data/recordings/pick_red.db -dimos run learning-dataprep \ - -o dataprepmodule.source=data/recordings/pick_red.db \ - -o dataprepmodule.output.path=data/datasets/pick_red -``` - -``` -data/recordings/pick_red.db ─► data/datasets/pick_red/ - ├── data/ (parquet) ─┐ - ├── videos/ (MP4) ├─ format=lerobot - └── meta/ (info, episodes, ─┘ - tasks, stats) - └── dimos_meta.json (always) -``` - ---- - -## Compatibility note — JpegCodec - -Recordings made before commit `` ship Image blobs with a -1-byte format tag (`b'J'`) preceding the LCM envelope. `JpegCodec.decode` -strips it transparently so old + new sessions both read cleanly. Affects -any consumer of recorded JPEG-encoded `Image` streams, not just learning. 
- ---- - -## Module / non-Module split for Stage 1 - -| Component | Type | Why | -|---|---|---| -| `EpisodeMonitorModule` | `Module` | Long-lived; subscribes to teleop input; publishes status | -| `DataPrepModule` | `Module` | Long-running build job; thread + `get_status` RPC | -| `RecordReplay` | transport hook | Captures every stream uniformly; not a per-Module concern | -| `StreamingStats` | helper class | No lifecycle, no I/O — pure accumulator | -| `extract_episodes` / `iter_episode_samples` / `resolve_field` / `compute_stats` | functions | Pure helpers; reused by every format writer | diff --git a/dimos/learning/specs/structure.md b/dimos/learning/specs/structure.md deleted file mode 100644 index 8dce1b1fa3..0000000000 --- a/dimos/learning/specs/structure.md +++ /dev/null @@ -1,80 +0,0 @@ -# Folder Structure - -Per-producer types: each Module owns its config + emitted message types -in its own file. No shared `config.py`, no umbrella class, no shared YAML. - -``` -dimos/learning/ -│ -├── specs/ -│ ├── structure.md -│ └── datacollection.md # Stage 1 -│ -├── dataprep.py # types + pure helpers (no Module) -│ # - Episode, Sample -│ # - StreamField, SyncConfig, OutputConfig, EpisodeExtractor -│ # - resolve_field, compute_stats, -│ # extract_episodes, iter_episode_samples -├── dataprep_module.py # DataPrepModule(Config) only -│ -├── collection/ -│ ├── episode_monitor.py # EpisodeStatus + EpisodeMonitorModule(Config) -│ └── blueprint.py # learning_collect_quest_ -│ -└── formats/ # dataset writers; each calls DataPrep.compute_stats - ├── lerobot.py # LeRobot v2 (parquet + MP4 + meta/stats.json) - └── hdf5.py -``` - ---- - -## Per-producer typed contracts - -| Class | Lives in | Used by | -|---|---|---| -| `EpisodeStatus`, `EpisodeMonitorModuleConfig` | `learning/collection/episode_monitor.py` | `EpisodeMonitorModule`; `DataPrep` | -| `EpisodeExtractor`, `StreamField`, `SyncConfig`, `OutputConfig`, `Episode`, `Sample` | `learning/dataprep.py` | `DataPrepModule`, 
`ChunkPolicyModule`, format writers | -| `DataPrepModuleConfig` | `learning/dataprep_module.py` | `DataPrepModule` | - ---- - -## Artifact flow - -All generated artifacts live under `data/` (gitignored at repo root): - -``` -data/ -├── sessions/.db ← RecordReplay -└── datasets// ← DataPrepModule.build() - ├── data/ (parquet) - ├── videos/ (MP4) - └── meta/ - ├── info.json - ├── episodes.jsonl - ├── stats.json (DataPrep.compute_stats) - └── dimos_meta.json (DataPrepModuleConfig.model_dump()) -``` - -`dimos_meta.json` rides with the data: DataPrep writes it alongside the -dataset to record the obs/action schema. - ---- - -## Configuration - -All module config is set as kwargs in the blueprint. No CLI flags on -our modules. Framework CLI surface is `GlobalConfig` only (env vars, -`.env`, things like `--record-path`). - ---- - -## Module / non-Module split - -A class becomes a **Module** when it has long-lived state with -`start()/stop()` lifecycle **and** typed I/O ports. - -| Class | Type | Why | -|---|---|---| -| `EpisodeMonitorModule` | Module | Long-lived; subscribes to inputs; publishes status | -| `DataPrepModule` | Module | Long-running build job | -| `RecordReplay` | transport hook | Captures every stream uniformly | From 475565b76d58b4bb8e178338367b49aa3dadf475 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 4 Jun 2026 12:43:46 -0700 Subject: [PATCH 09/45] feat: dataprep folder --- dimos/learning/collection/blueprint.py | 39 +++++++------------ .../blueprint.py} | 34 +++------------- .../{dataprep.py => dataprep/core.py} | 6 +-- .../learning/{ => dataprep}/formats/_stats.py | 0 dimos/learning/{ => dataprep}/formats/hdf5.py | 4 +- .../{ => dataprep}/formats/lerobot.py | 4 +- .../module.py} | 2 +- dimos/robot/all_blueprints.py | 7 +--- 8 files changed, 30 insertions(+), 66 deletions(-) rename dimos/learning/{dataprep_blueprint.py => dataprep/blueprint.py} (62%) rename dimos/learning/{dataprep.py => dataprep/core.py} (98%) rename dimos/learning/{ => 
dataprep}/formats/_stats.py (100%) rename dimos/learning/{ => dataprep}/formats/hdf5.py (97%) rename dimos/learning/{ => dataprep}/formats/lerobot.py (98%) rename dimos/learning/{dataprep_module.py => dataprep/module.py} (99%) diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index 0ee92a298a..5b6e52b81e 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -25,52 +25,41 @@ ) from dimos.msgs.sensor_msgs.Image import Image from dimos.teleop.quest.blueprints import ( - teleop_quest_dual, teleop_quest_piper, - teleop_quest_xarm6, teleop_quest_xarm7, ) from dimos.teleop.quest.quest_types import Buttons _DEFAULT_BUTTON_MAP = {"start": "A", "save": "B", "discard": "X"} -_TRANSPORTS = { - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), -} +# Transports are written inline per blueprint (not factored into a shared +# variable) so each recording config is self-contained and readable on its +# own: buttons drive the episode state machine, color_image is the camera +# stream, and status carries the canonical EpisodeStatus that DataPrep reads. 
learning_collect_quest_xarm7 = autoconnect( teleop_quest_xarm7, RealSenseCamera.blueprint(enable_pointcloud=False), EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), -).transports(_TRANSPORTS) +).transports({ + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), +}) learning_collect_quest_piper = autoconnect( teleop_quest_piper, RealSenseCamera.blueprint(enable_pointcloud=False), EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), -).transports(_TRANSPORTS) - - -learning_collect_quest_xarm6 = autoconnect( - teleop_quest_xarm6, - RealSenseCamera.blueprint(enable_pointcloud=False), - EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), -).transports(_TRANSPORTS) - - -learning_collect_quest_dual = autoconnect( - teleop_quest_dual, - RealSenseCamera.blueprint(enable_pointcloud=False), - EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), -).transports(_TRANSPORTS) +).transports({ + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), +}) __all__ = [ - "learning_collect_quest_dual", "learning_collect_quest_piper", - "learning_collect_quest_xarm6", "learning_collect_quest_xarm7", ] diff --git a/dimos/learning/dataprep_blueprint.py b/dimos/learning/dataprep/blueprint.py similarity index 62% rename from dimos/learning/dataprep_blueprint.py rename to dimos/learning/dataprep/blueprint.py index 53066da29f..6558cd0d3c 100644 --- a/dimos/learning/dataprep_blueprint.py +++ b/dimos/learning/dataprep/blueprint.py @@ -20,21 +20,21 @@ dimos run learning-dataprep -o dataprepmodule.source=data/recordings/foo.db \\ -o dataprepmodule.output.path=data/datasets/foo -The defaults below target the included 
pickplace_001 demo. For single-demo -recordings without an `episode_status` stream, `learning_dataprep_whole_session` -treats the entire recording as one episode. +The defaults below target the included pickplace_001 demo. Episodes are +always segmented from the recording (the `episode_status` stream or +explicit `ranges`) — we never collapse a session into a single episode. """ from __future__ import annotations from dimos.core.coordination.blueprints import autoconnect -from dimos.learning.dataprep import ( +from dimos.learning.dataprep.core import ( EpisodeExtractor, OutputConfig, StreamField, SyncConfig, ) -from dimos.learning.dataprep_module import DataPrepModule +from dimos.learning.dataprep.module import DataPrepModule learning_dataprep = autoconnect( DataPrepModule.blueprint( @@ -61,26 +61,4 @@ ).transports({}) -learning_dataprep_whole_session = autoconnect( - DataPrepModule.blueprint( - source="data/session.db", - episodes=EpisodeExtractor(extractor="whole_session"), - observation={ - "image": StreamField(stream="camera_color_image", field="data"), - "joint_state": StreamField(stream="coordinator_joint_state", field="position"), - }, - action={ - "joint_target": StreamField(stream="coordinator_joint_command", field="position"), - }, - sync=SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0), - output=OutputConfig( - format="lerobot", - path="data/datasets/default", - metadata={"fps": 30, "robot": "xarm7"}, - ), - auto_run=True, - ), -).transports({}) - - -__all__ = ["learning_dataprep", "learning_dataprep_whole_session"] +__all__ = ["learning_dataprep"] diff --git a/dimos/learning/dataprep.py b/dimos/learning/dataprep/core.py similarity index 98% rename from dimos/learning/dataprep.py rename to dimos/learning/dataprep/core.py index cc15665334..9a04a839c5 100644 --- a/dimos/learning/dataprep.py +++ b/dimos/learning/dataprep/core.py @@ -330,7 +330,7 @@ def compute_stats( Thin wrapper over :class:`StreamingStats` so format writers and ad-hoc callers 
share the exact same accumulator. """ - from dimos.learning.formats._stats import StreamingStats + from dimos.learning.dataprep.formats._stats import StreamingStats s = StreamingStats(image_subsample=image_subsample, quantile_reservoir=quantile_reservoir, seed=seed) @@ -345,9 +345,9 @@ def compute_stats( def get_writer(format_name: str) -> Writer: """Lazy-import the format writer's `write` function.""" if format_name == "lerobot": - from dimos.learning.formats.lerobot import write + from dimos.learning.dataprep.formats.lerobot import write elif format_name == "hdf5": - from dimos.learning.formats.hdf5 import write + from dimos.learning.dataprep.formats.hdf5 import write else: raise ValueError(f"Unknown format: {format_name!r}") return write diff --git a/dimos/learning/formats/_stats.py b/dimos/learning/dataprep/formats/_stats.py similarity index 100% rename from dimos/learning/formats/_stats.py rename to dimos/learning/dataprep/formats/_stats.py diff --git a/dimos/learning/formats/hdf5.py b/dimos/learning/dataprep/formats/hdf5.py similarity index 97% rename from dimos/learning/formats/hdf5.py rename to dimos/learning/dataprep/formats/hdf5.py index acbe647c9b..1dcb4f378b 100644 --- a/dimos/learning/formats/hdf5.py +++ b/dimos/learning/dataprep/formats/hdf5.py @@ -38,8 +38,8 @@ import numpy as np -from dimos.learning.dataprep import OutputConfig, Sample -from dimos.learning.formats._stats import StreamingStats +from dimos.learning.dataprep.core import OutputConfig, Sample +from dimos.learning.dataprep.formats._stats import StreamingStats def write(samples: Iterator[Sample], output: OutputConfig) -> Path: diff --git a/dimos/learning/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py similarity index 98% rename from dimos/learning/formats/lerobot.py rename to dimos/learning/dataprep/formats/lerobot.py index acb1f27bbb..51ffd511ad 100644 --- a/dimos/learning/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -40,8 +40,8 @@ import numpy 
as np -from dimos.learning.dataprep import OutputConfig, Sample -from dimos.learning.formats._stats import StreamingStats +from dimos.learning.dataprep.core import OutputConfig, Sample +from dimos.learning.dataprep.formats._stats import StreamingStats CHUNK = "chunk-000" DATA_DIR = "data" diff --git a/dimos/learning/dataprep_module.py b/dimos/learning/dataprep/module.py similarity index 99% rename from dimos/learning/dataprep_module.py rename to dimos/learning/dataprep/module.py index 4d7caeb5fa..3303eb6bed 100644 --- a/dimos/learning/dataprep_module.py +++ b/dimos/learning/dataprep/module.py @@ -29,7 +29,7 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig -from dimos.learning.dataprep import ( +from dimos.learning.dataprep.core import ( EpisodeExtractor, OutputConfig, Sample, diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index ac4a4bc39f..fd987568b2 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -64,12 +64,9 @@ "keyboard-teleop-piper": "dimos.robot.manipulators.piper.blueprints:keyboard_teleop_piper", "keyboard-teleop-xarm6": "dimos.robot.manipulators.xarm.blueprints:keyboard_teleop_xarm6", "keyboard-teleop-xarm7": "dimos.robot.manipulators.xarm.blueprints:keyboard_teleop_xarm7", - "learning-collect-quest-dual": "dimos.learning.collection.blueprint:learning_collect_quest_dual", "learning-collect-quest-piper": "dimos.learning.collection.blueprint:learning_collect_quest_piper", - "learning-collect-quest-xarm6": "dimos.learning.collection.blueprint:learning_collect_quest_xarm6", "learning-collect-quest-xarm7": "dimos.learning.collection.blueprint:learning_collect_quest_xarm7", - "learning-dataprep": "dimos.learning.dataprep_blueprint:learning_dataprep", - "learning-dataprep-whole-session": "dimos.learning.dataprep_blueprint:learning_dataprep_whole_session", + "learning-dataprep": "dimos.learning.dataprep.blueprint:learning_dataprep", "mid360": 
"dimos.hardware.sensors.lidar.livox.livox_blueprints:mid360", "mid360-fastlio": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio", "mid360-fastlio-ray-trace": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_ray_trace", @@ -142,7 +139,7 @@ "click-start-goal-router": "dimos.navigation.nav_stack.modules.click_start_goal_router.click_start_goal_router.ClickStartGoalRouter", "control-coordinator": "dimos.control.coordinator.ControlCoordinator", "cost-mapper": "dimos.mapping.costmapper.CostMapper", - "data-prep-module": "dimos.learning.dataprep_module.DataPrepModule", + "data-prep-module": "dimos.learning.dataprep.module.DataPrepModule", "demo-calculator-skill": "dimos.agents.skills.demo_calculator_skill.DemoCalculatorSkill", "demo-monitoring": "dimos.agents.demos.demo_capabilities.DemoMonitoring", "demo-robot": "dimos.agents.skills.demo_robot.DemoRobot", From 38d434cef30042e8cd81c8e98d62d08a60cb8bf6 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 4 Jun 2026 13:09:43 -0700 Subject: [PATCH 10/45] feat: add recorder --- dimos/learning/collection/blueprint.py | 10 ++++- dimos/learning/collection/recorder.py | 53 ++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 dimos/learning/collection/recorder.py diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index 5b6e52b81e..fc4279affb 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Recording blueprints. RecordReplay is enabled via `--record-path`.""" +"""Recording blueprints. + +`CollectionRecorder` (a memory2 Recorder) captures the obs/action/status +streams to a SQLite session DB during the run and flushes it durably on +shutdown. DataPrep reads that DB afterwards. 
+""" from __future__ import annotations @@ -23,6 +28,7 @@ EpisodeMonitorModule, EpisodeStatus, ) +from dimos.learning.collection.recorder import CollectionRecorder from dimos.msgs.sensor_msgs.Image import Image from dimos.teleop.quest.blueprints import ( teleop_quest_piper, @@ -41,6 +47,7 @@ teleop_quest_xarm7, RealSenseCamera.blueprint(enable_pointcloud=False), EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), + CollectionRecorder.blueprint(), ).transports({ ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), @@ -52,6 +59,7 @@ teleop_quest_piper, RealSenseCamera.blueprint(enable_pointcloud=False), EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), + CollectionRecorder.blueprint(), ).transports({ ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), diff --git a/dimos/learning/collection/recorder.py b/dimos/learning/collection/recorder.py new file mode 100644 index 0000000000..63f0a662cf --- /dev/null +++ b/dimos/learning/collection/recorder.py @@ -0,0 +1,53 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CollectionRecorder — captures teleop collection streams to a memory2 DB. + +A `Recorder` (memory2) subscribes each declared `In` port and appends every +message to a SQLite store, flushing durably on stop(). 
Only *connected* +streams are recorded, so the same recorder works for any arm whose +coordinator publishes `joint_state`. + +The recorded stream names match what `DataPrepModule` reads: `color_image` +and `joint_state` (observation), `status` (episode segmentation). +""" + +from __future__ import annotations + +from pathlib import Path + +from dimos.core.stream import In +from dimos.learning.collection.episode_monitor import EpisodeStatus +from dimos.memory2.module import Recorder, RecorderConfig +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.teleop.quest.quest_types import Buttons + + +class CollectionRecorderConfig(RecorderConfig): + db_path: str | Path = "data/recordings/session.db" + + +class CollectionRecorder(Recorder): + """Records the streams DataPrep consumes from a teleop session.""" + + config: CollectionRecorderConfig + + color_image: In[Image] # observation (camera) + joint_state: In[JointState] # observation + action (measured/next state) + status: In[EpisodeStatus] # episode start/save/discard segmentation + buttons: In[Buttons] # raw teleop input (kept for debugging) + + +__all__ = ["CollectionRecorder", "CollectionRecorderConfig"] From b39ceaf78af2f49f8a25aca0a554c018b7d339bc Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 4 Jun 2026 13:10:02 -0700 Subject: [PATCH 11/45] fix: ore-commit --- dimos/learning/collection/blueprint.py | 24 ++++--- dimos/learning/collection/recorder.py | 8 +-- dimos/learning/dataprep/blueprint.py | 2 +- dimos/learning/dataprep/core.py | 5 +- dimos/learning/dataprep/formats/_stats.py | 24 ++++--- dimos/learning/dataprep/formats/hdf5.py | 9 ++- dimos/learning/dataprep/formats/lerobot.py | 70 ++++++++++++-------- dimos/learning/dataprep/module.py | 76 +++++++++++++--------- 8 files changed, 130 insertions(+), 88 deletions(-) diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index fc4279affb..b3ae3cec5a 
100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -48,11 +48,13 @@ RealSenseCamera.blueprint(enable_pointcloud=False), EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), CollectionRecorder.blueprint(), -).transports({ - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), -}) +).transports( + { + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), + } +) learning_collect_quest_piper = autoconnect( @@ -60,11 +62,13 @@ RealSenseCamera.blueprint(enable_pointcloud=False), EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), CollectionRecorder.blueprint(), -).transports({ - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), -}) +).transports( + { + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), + } +) __all__ = [ diff --git a/dimos/learning/collection/recorder.py b/dimos/learning/collection/recorder.py index 63f0a662cf..ad6183c27d 100644 --- a/dimos/learning/collection/recorder.py +++ b/dimos/learning/collection/recorder.py @@ -44,10 +44,10 @@ class CollectionRecorder(Recorder): config: CollectionRecorderConfig - color_image: In[Image] # observation (camera) - joint_state: In[JointState] # observation + action (measured/next state) - status: In[EpisodeStatus] # episode start/save/discard segmentation - buttons: In[Buttons] # raw 
teleop input (kept for debugging) + color_image: In[Image] # observation (camera) + joint_state: In[JointState] # observation + action (measured/next state) + status: In[EpisodeStatus] # episode start/save/discard segmentation + buttons: In[Buttons] # raw teleop input (kept for debugging) __all__ = ["CollectionRecorder", "CollectionRecorderConfig"] diff --git a/dimos/learning/dataprep/blueprint.py b/dimos/learning/dataprep/blueprint.py index 6558cd0d3c..10da3f13e9 100644 --- a/dimos/learning/dataprep/blueprint.py +++ b/dimos/learning/dataprep/blueprint.py @@ -44,7 +44,7 @@ ranges=[(1777931622.11, 1777931646.61)], ), observation={ - "image": StreamField(stream="color_image", field="data"), + "image": StreamField(stream="color_image", field="data"), "joint_state": StreamField(stream="joint_state", field="position"), }, action={ diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index 9a04a839c5..dced77ef9a 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -332,8 +332,9 @@ def compute_stats( """ from dimos.learning.dataprep.formats._stats import StreamingStats - s = StreamingStats(image_subsample=image_subsample, - quantile_reservoir=quantile_reservoir, seed=seed) + s = StreamingStats( + image_subsample=image_subsample, quantile_reservoir=quantile_reservoir, seed=seed + ) for sample in samples: for k, v in sample.observation.items(): s.update(f"observation.{k}", np.asarray(v)) diff --git a/dimos/learning/dataprep/formats/_stats.py b/dimos/learning/dataprep/formats/_stats.py index 92d5be648c..8d10ebdf35 100644 --- a/dimos/learning/dataprep/formats/_stats.py +++ b/dimos/learning/dataprep/formats/_stats.py @@ -21,8 +21,8 @@ from __future__ import annotations -import random from dataclasses import dataclass, field +import random from typing import Any import numpy as np @@ -45,8 +45,9 @@ class FeatureAggregator: class StreamingStats: """Single-pass mean/std/min/max/quantile aggregator across many features.""" 
- def __init__(self, image_subsample: int = 10, quantile_reservoir: int = 10_000, - seed: int = 0) -> None: + def __init__( + self, image_subsample: int = 10, quantile_reservoir: int = 10_000, seed: int = 0 + ) -> None: self.image_subsample = image_subsample self.quantile_reservoir = quantile_reservoir self._rng = random.Random(seed) @@ -56,14 +57,19 @@ def update(self, name: str, value: np.ndarray) -> None: a = np.asarray(value) is_image = a.ndim >= 3 agg = self.aggs.setdefault( - name, FeatureAggregator(is_image=is_image, shape=tuple(a.shape), dtype=str(a.dtype)), + name, + FeatureAggregator(is_image=is_image, shape=tuple(a.shape), dtype=str(a.dtype)), ) if is_image: agg.image_seen += 1 if (agg.image_seen - 1) % self.image_subsample != 0: return - v = a.astype(np.float32).mean(axis=(0, 1)) if a.ndim == 3 else a.astype(np.float32).reshape(-1) + v = ( + a.astype(np.float32).mean(axis=(0, 1)) + if a.ndim == 3 + else a.astype(np.float32).reshape(-1) + ) else: v = a.astype(np.float64) @@ -99,10 +105,10 @@ def finalize(self) -> dict[str, dict[str, Any]]: var = agg.m2 / n if agg.n > 1 else np.zeros_like(agg.mean) std = np.sqrt(var) entry: dict[str, Any] = { - "mean": agg.mean.tolist(), - "std": std.tolist(), - "min": agg.minv.tolist() if agg.minv is not None else None, - "max": agg.maxv.tolist() if agg.maxv is not None else None, + "mean": agg.mean.tolist(), + "std": std.tolist(), + "min": agg.minv.tolist() if agg.minv is not None else None, + "max": agg.maxv.tolist() if agg.maxv is not None else None, "count": int(agg.n), } if agg.reservoir: diff --git a/dimos/learning/dataprep/formats/hdf5.py b/dimos/learning/dataprep/formats/hdf5.py index 1dcb4f378b..3cb5d3cd1a 100644 --- a/dimos/learning/dataprep/formats/hdf5.py +++ b/dimos/learning/dataprep/formats/hdf5.py @@ -88,9 +88,12 @@ def _flush() -> None: ep.create_dataset("timestamp", data=np.asarray(buf_ts, dtype=np.float32)) for k, frames in buf_obs.items(): arr = np.stack(frames, axis=0) - 
ep.create_dataset(f"observation/{k}", data=arr, - compression="gzip" if arr.ndim >= 3 else None, - compression_opts=4 if arr.ndim >= 3 else None) + ep.create_dataset( + f"observation/{k}", + data=arr, + compression="gzip" if arr.ndim >= 3 else None, + compression_opts=4 if arr.ndim >= 3 else None, + ) for k, frames in buf_act.items(): ep.create_dataset(f"action/{k}", data=np.stack(frames, axis=0)) buf_ts.clear() diff --git a/dimos/learning/dataprep/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py index 51ffd511ad..89d0e0de06 100644 --- a/dimos/learning/dataprep/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -33,8 +33,8 @@ from __future__ import annotations -import json from collections.abc import Iterator +import json from pathlib import Path from typing import Any @@ -49,8 +49,9 @@ META_DIR = "meta" -def _feature_name(prefix: str, key: str, is_image: bool, - single_action: bool, single_state: bool = False) -> str: +def _feature_name( + prefix: str, key: str, is_image: bool, single_action: bool, single_state: bool = False +) -> str: """Translate (prefix, key) into the LeRobot v2 feature name. 
Canonical names lerobot policies (ACT, Diffusion, π₀) expect: @@ -143,16 +144,17 @@ def _flush_episode() -> None: current_video_writers = {} cols: dict[str, list[Any]] = { - "timestamp": [f["timestamp"] for f in current_frames], - "frame_index": [f["frame_index"] for f in current_frames], + "timestamp": [f["timestamp"] for f in current_frames], + "frame_index": [f["frame_index"] for f in current_frames], "episode_index": [f["episode_index"] for f in current_frames], - "index": [f["index"] for f in current_frames], - "task_index": [f["task_index"] for f in current_frames], + "index": [f["index"] for f in current_frames], + "task_index": [f["task_index"] for f in current_frames], } single_state = len(state_keys) == 1 for k in state_keys: - name = _feature_name("observation", k, is_image=False, - single_action=False, single_state=single_state) + name = _feature_name( + "observation", k, is_image=False, single_action=False, single_state=single_state + ) cols[name] = [f["obs"][k].tolist() for f in current_frames] single_action = len(action_keys) == 1 for k in action_keys: @@ -164,11 +166,13 @@ def _flush_episode() -> None: table = pa.Table.from_pydict(cols) pq.write_table(table, _episode_path_parquet(current_episode_index)) - episodes_meta.append({ - "episode_index": current_episode_index, - "tasks": [list(tasks_index.keys())[current_frames[0]["task_index"]]], - "length": len(current_frames), - }) + episodes_meta.append( + { + "episode_index": current_episode_index, + "tasks": [list(tasks_index.keys())[current_frames[0]["task_index"]]], + "length": len(current_frames), + } + ) current_frames = [] for sample in samples: @@ -189,8 +193,9 @@ def _flush_episode() -> None: for k, arr in sample.observation.items(): a = np.asarray(arr) is_image = a.ndim >= 3 - name = _feature_name("observation", k, is_image=is_image, - single_action=False, single_state=single_state) + name = _feature_name( + "observation", k, is_image=is_image, single_action=False, single_state=single_state + 
) if name not in feature_shapes: feature_shapes[name] = tuple(a.shape) feature_dtypes[name] = "video" if is_image else str(a.dtype) @@ -221,15 +226,21 @@ def _flush_episode() -> None: current_video_writers[k].write(a) rel_ts = float(sample.ts) - (current_episode_start_ts or 0.0) - current_frames.append({ - "timestamp": rel_ts, - "frame_index": frame_index, - "episode_index": current_episode_index, - "index": global_index, - "task_index": tasks_index[default_task_label], - "obs": {k: np.asarray(v) for k, v in sample.observation.items() if np.asarray(v).ndim < 3}, - "act": {k: np.asarray(v) for k, v in sample.action.items()}, - }) + current_frames.append( + { + "timestamp": rel_ts, + "frame_index": frame_index, + "episode_index": current_episode_index, + "index": global_index, + "task_index": tasks_index[default_task_label], + "obs": { + k: np.asarray(v) + for k, v in sample.observation.items() + if np.asarray(v).ndim < 3 + }, + "act": {k: np.asarray(v) for k, v in sample.action.items()}, + } + ) global_index += 1 _flush_episode() @@ -265,8 +276,13 @@ def _flush_episode() -> None: "shape": list(shape), "names": [f"{base}_{i}" for i in range(n)], } - for col, dt in [("timestamp", "float32"), ("frame_index", "int64"), - ("episode_index", "int64"), ("index", "int64"), ("task_index", "int64")]: + for col, dt in [ + ("timestamp", "float32"), + ("frame_index", "int64"), + ("episode_index", "int64"), + ("index", "int64"), + ("task_index", "int64"), + ]: features[col] = {"dtype": dt, "shape": [1], "names": None} info = { diff --git a/dimos/learning/dataprep/module.py b/dimos/learning/dataprep/module.py index 3303eb6bed..75c1c91ce8 100644 --- a/dimos/learning/dataprep/module.py +++ b/dimos/learning/dataprep/module.py @@ -20,11 +20,11 @@ from __future__ import annotations +from collections.abc import Iterator import json +from pathlib import Path import threading import traceback -from collections.abc import Iterator -from pathlib import Path from typing import Any from 
dimos.core.core import rpc @@ -47,13 +47,13 @@ class DataPrepModuleConfig(ModuleConfig): # Fields are defaulted so partial CLI overrides (e.g. just `source=...`) # pass blueprint validation; blueprint atoms supply real values. - source: str = "" - episodes: EpisodeExtractor = EpisodeExtractor() + source: str = "" + episodes: EpisodeExtractor = EpisodeExtractor() observation: dict[str, StreamField] = {} - action: dict[str, StreamField] = {} - sync: SyncConfig = SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0) - output: OutputConfig = OutputConfig(format="lerobot", path="data/datasets/default") - auto_run: bool = False + action: dict[str, StreamField] = {} + sync: SyncConfig = SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0) + output: OutputConfig = OutputConfig(format="lerobot", path="data/datasets/default") + auto_run: bool = False class DataPrepModule(Module): @@ -66,13 +66,13 @@ def __init__(self, **kwargs: Any) -> None: self._thread: threading.Thread | None = None self._lock = threading.Lock() self._status: dict[str, Any] = { - "state": "idle", # idle | running | succeeded | failed - "current_phase": None, # scan_episodes | write | done - "progress_pct": 0.0, - "dataset_path": None, - "error": None, + "state": "idle", # idle | running | succeeded | failed + "current_phase": None, # scan_episodes | write | done + "progress_pct": 0.0, + "dataset_path": None, + "error": None, "episodes_seen": 0, - "samples_seen": 0, + "samples_seen": 0, } # ── lifecycle ──────────────────────────────────────────────────────────── @@ -123,13 +123,13 @@ def inspect(self) -> dict[str, Any]: dropped = sum(1 for e in episodes if not e.success) durations = [e.duration for e in episodes if e.success] return { - "source": self.config.source, - "streams": store.list_streams(), + "source": self.config.source, + "streams": store.list_streams(), "episodes_saved": saved, "episodes_dropped": dropped, - "duration_s": { - "min": min(durations) if durations else 0.0, - "max": 
max(durations) if durations else 0.0, + "duration_s": { + "min": min(durations) if durations else 0.0, + "max": max(durations) if durations else 0.0, "mean": (sum(durations) / len(durations)) if durations else 0.0, }, } @@ -161,7 +161,8 @@ def _run_build(self) -> None: episodes = [e for e in all_eps if e.success] logger.info( "[dataprep] episodes extracted: %d total / %d successful", - len(all_eps), len(episodes), + len(all_eps), + len(episodes), ) self._update_status(episodes_seen=len(episodes)) @@ -179,7 +180,8 @@ def _run_build(self) -> None: action_keys = set(self.config.action) logger.info( "[dataprep] obs streams=%s action streams=%s sync=%s", - sorted(obs_keys), sorted(action_keys), + sorted(obs_keys), + sorted(action_keys), self.config.sync.model_dump(), ) @@ -188,7 +190,8 @@ def _run_build(self) -> None: self._update_status(current_phase="write") logger.info( "[dataprep] writing %s dataset to %s", - self.config.output.format, self.config.output.path, + self.config.output.format, + self.config.output.path, ) samples_seen = 0 @@ -215,7 +218,9 @@ def _all_samples() -> Iterator[Sample]: logger.info( "[dataprep] %.1f%% samples=%d ep %d/%d", 100.0 * episodes_done / total, - samples_seen, episodes_done, total, + samples_seen, + episodes_done, + total, ) yield sample episodes_done += 1 @@ -236,7 +241,9 @@ def _all_samples() -> Iterator[Sample]: ) logger.info( "[dataprep] succeeded — wrote %d samples across %d episodes to %s", - samples_seen, total, dataset_path, + samples_seen, + total, + dataset_path, ) finally: store.stop() @@ -249,17 +256,22 @@ def _write_dimos_meta(self, dataset_path: Path, episodes: list[Any]) -> None: """Sidecar describing how this dataset was built, recording the obs/action schema alongside the dataset.""" meta = { - "source": self.config.source, + "source": self.config.source, "observation": {k: v.model_dump() for k, v in self.config.observation.items()}, - "action": {k: v.model_dump() for k, v in self.config.action.items()}, - "sync": 
self.config.sync.model_dump(), - "episodes": [ - {"id": e.id, "start_ts": e.start_ts, "end_ts": e.end_ts, - "task_label": e.task_label, "success": e.success} + "action": {k: v.model_dump() for k, v in self.config.action.items()}, + "sync": self.config.sync.model_dump(), + "episodes": [ + { + "id": e.id, + "start_ts": e.start_ts, + "end_ts": e.end_ts, + "task_label": e.task_label, + "success": e.success, + } for e in episodes ], - "format": self.config.output.format, - "metadata": self.config.output.metadata, + "format": self.config.output.format, + "metadata": self.config.output.metadata, } with open(dataset_path / "dimos_meta.json", "w") as f: json.dump(meta, f, indent=2, default=str) From 6a4847595c984d0845d1ddbc8d3b0e60b61f3234 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 4 Jun 2026 13:30:29 -0700 Subject: [PATCH 12/45] fix: episodeextractor default --- dimos/learning/dataprep/blueprint.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dimos/learning/dataprep/blueprint.py b/dimos/learning/dataprep/blueprint.py index 10da3f13e9..ee01de52e2 100644 --- a/dimos/learning/dataprep/blueprint.py +++ b/dimos/learning/dataprep/blueprint.py @@ -39,10 +39,7 @@ learning_dataprep = autoconnect( DataPrepModule.blueprint( source="data/recordings/pickplace_001.db", - episodes=EpisodeExtractor( - extractor="ranges", - ranges=[(1777931622.11, 1777931646.61)], - ), + episodes=EpisodeExtractor(), observation={ "image": StreamField(stream="color_image", field="data"), "joint_state": StreamField(stream="joint_state", field="position"), From 6fc0cff6e550d1395119edd2fe2b14ff26005045 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Tue, 9 Jun 2026 18:48:58 -0700 Subject: [PATCH 13/45] refactor: dimos dataprep subcommand with build and inspect --- dimos/learning/collection/recorder.py | 2 +- dimos/learning/dataprep/blueprint.py | 61 ----- dimos/learning/dataprep/build.py | 181 +++++++++++++ dimos/learning/dataprep/cli.py | 104 +++++++ 
dimos/learning/dataprep/core.py | 106 ++++---- dimos/learning/dataprep/example_config.json | 20 ++ dimos/learning/dataprep/formats/hdf5.py | 52 ++++ dimos/learning/dataprep/formats/lerobot.py | 44 +++ dimos/learning/dataprep/module.py | 284 -------------------- dimos/robot/all_blueprints.py | 2 - dimos/robot/cli/dimos.py | 34 +++ 11 files changed, 482 insertions(+), 408 deletions(-) delete mode 100644 dimos/learning/dataprep/blueprint.py create mode 100644 dimos/learning/dataprep/build.py create mode 100644 dimos/learning/dataprep/cli.py create mode 100644 dimos/learning/dataprep/example_config.json delete mode 100644 dimos/learning/dataprep/module.py diff --git a/dimos/learning/collection/recorder.py b/dimos/learning/collection/recorder.py index ad6183c27d..f8c1b0f45d 100644 --- a/dimos/learning/collection/recorder.py +++ b/dimos/learning/collection/recorder.py @@ -19,7 +19,7 @@ streams are recorded, so the same recorder works for any arm whose coordinator publishes `joint_state`. -The recorded stream names match what `DataPrepModule` reads: `color_image` +The recorded stream names match what DataPrep reads: `color_image` and `joint_state` (observation), `status` (episode segmentation). """ diff --git a/dimos/learning/dataprep/blueprint.py b/dimos/learning/dataprep/blueprint.py deleted file mode 100644 index ee01de52e2..0000000000 --- a/dimos/learning/dataprep/blueprint.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Dataset-build blueprints. - -Wraps `DataPrepModule` so users can run:: - - dimos run learning-dataprep - dimos run learning-dataprep -o dataprepmodule.source=data/recordings/foo.db \\ - -o dataprepmodule.output.path=data/datasets/foo - -The defaults below target the included pickplace_001 demo. Episodes are -always segmented from the recording (the `episode_status` stream or -explicit `ranges`) — we never collapse a session into a single episode. -""" - -from __future__ import annotations - -from dimos.core.coordination.blueprints import autoconnect -from dimos.learning.dataprep.core import ( - EpisodeExtractor, - OutputConfig, - StreamField, - SyncConfig, -) -from dimos.learning.dataprep.module import DataPrepModule - -learning_dataprep = autoconnect( - DataPrepModule.blueprint( - source="data/recordings/pickplace_001.db", - episodes=EpisodeExtractor(), - observation={ - "image": StreamField(stream="color_image", field="data"), - "joint_state": StreamField(stream="joint_state", field="position"), - }, - action={ - "joint_target": StreamField(stream="joint_state", field="position"), - }, - sync=SyncConfig(anchor="image", rate_hz=14.0, tolerance_ms=80.0), - output=OutputConfig( - format="lerobot", - path="data/datasets/pickplace_001", - metadata={"fps": 14, "robot": "xarm7", "default_task_label": "pick_and_place"}, - ), - auto_run=True, - ), -).transports({}) - - -__all__ = ["learning_dataprep"] diff --git a/dimos/learning/dataprep/build.py b/dimos/learning/dataprep/build.py new file mode 100644 index 0000000000..2cf9eb53c9 --- /dev/null +++ b/dimos/learning/dataprep/build.py @@ -0,0 +1,181 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DataPrep build orchestration — the impure layer over `core.py`. + +`run_dataprep` (build) and `inspect_dataset` (read-back) own the I/O and side +effects — open/close the store, drive the writer/reader, emit logs, write +files; they compose the pure helpers in `core.py` and the per-format +readers/writers. Exposed by the `dimos dataprep` subcommand. +""" + +from __future__ import annotations + +from collections.abc import Iterator +import json +from pathlib import Path +from typing import Any + +from dimos.learning.dataprep.core import ( + DataPrepConfig, + Episode, + Sample, + extract_episodes, + get_writer, + iter_episode_samples, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +def _write_dimos_meta(dataset_path: Path, config: DataPrepConfig, episodes: list[Episode]) -> None: + """Sidecar describing how this dataset was built, recording the obs/action + schema alongside the dataset.""" + meta = { + "source": config.source, + "observation": {k: v.model_dump() for k, v in config.observation.items()}, + "action": {k: v.model_dump() for k, v in config.action.items()}, + "sync": config.sync.model_dump(), + "episodes": [ + { + "id": e.id, + "start_ts": e.start_ts, + "end_ts": e.end_ts, + "task_label": e.task_label, + "success": e.success, + } + for e in episodes + ], + "format": config.output.format, + "metadata": config.output.metadata, + } + with open(dataset_path / "dimos_meta.json", "w") as f: + json.dump(meta, f, indent=2, default=str) + + +def run_dataprep(config: DataPrepConfig) -> Path: + """Build a dataset 
from a recording and return the dataset path. + + Opens the source store, extracts episodes, streams samples through the + configured format writer, and writes `dimos_meta.json`. Synchronous — + raises on failure so the caller owns the exit code. + """ + from dimos.memory2.store.sqlite import SqliteStore + + logger.info( + "[dataprep] starting build source=%s extractor=%s output=%s", + config.source, + config.episodes.extractor, + config.output.path, + ) + store = SqliteStore(path=config.source, must_exist=True) + try: + logger.info("[dataprep] streams in source: %s", store.list_streams()) + all_eps = extract_episodes(store, config.episodes) + episodes = [e for e in all_eps if e.success] + logger.info( + "[dataprep] episodes extracted: %d total / %d successful", + len(all_eps), + len(episodes), + ) + + if not episodes: + raise RuntimeError( + f"No successful episodes extracted from {config.source!r} " + f"using extractor={config.episodes.extractor!r}. " + f"Available streams: {store.list_streams()}. " + f"For a recording with no episode_status stream, set " + f"extractor='ranges' with explicit (start, end) tuples." 
+ ) + + streams = {**config.observation, **config.action} + obs_keys = set(config.observation) + action_keys = set(config.action) + logger.info( + "[dataprep] obs streams=%s action streams=%s sync=%s", + sorted(obs_keys), + sorted(action_keys), + config.sync.model_dump(), + ) + + writer = get_writer(config.output.format) + logger.info( + "[dataprep] writing %s dataset to %s", config.output.format, config.output.path + ) + + samples_seen = 0 + episodes_done = 0 + total = len(episodes) + + def _all_samples() -> Iterator[Sample]: + nonlocal samples_seen, episodes_done + for ep in episodes: + for sample in iter_episode_samples( + store=store, + episode=ep, + streams=streams, + sync=config.sync, + obs_keys=obs_keys, + action_keys=action_keys, + ): + samples_seen += 1 + if samples_seen % 50 == 0: + logger.info( + "[dataprep] %.1f%% samples=%d ep %d/%d", + 100.0 * episodes_done / total, + samples_seen, + episodes_done, + total, + ) + yield sample + episodes_done += 1 + + dataset_path = Path(writer(_all_samples(), config.output)) + _write_dimos_meta(dataset_path, config, episodes) + logger.info( + "[dataprep] succeeded — wrote %d samples across %d episodes to %s", + samples_seen, + total, + dataset_path, + ) + return dataset_path + finally: + store.stop() + + +def inspect_dataset(path: Path | str, fmt: str | None = None) -> dict[str, Any]: + """Summarize a built dataset: observation/action features (shape + dtype), + episode/frame counts, and whether shapes/lengths are uniform. + + `fmt` is auto-detected when omitted: a `.hdf5`/`.h5` file → hdf5; a + directory containing `meta/info.json` → lerobot. + """ + from dimos.learning.dataprep.core import get_inspector + + p = Path(path) + if fmt is None: + if p.suffix in (".h5", ".hdf5"): + fmt = "hdf5" + elif (p / "meta" / "info.json").exists(): + fmt = "lerobot" + else: + raise ValueError( + f"Cannot detect dataset format at {p}: expected a .hdf5 file or a " + f"lerobot directory with meta/info.json. Pass --format explicitly." 
+ ) + return get_inspector(fmt)(p) + + +__all__ = ["inspect_dataset", "run_dataprep"] diff --git a/dimos/learning/dataprep/cli.py b/dimos/learning/dataprep/cli.py new file mode 100644 index 0000000000..fe6c5c81f2 --- /dev/null +++ b/dimos/learning/dataprep/cli.py @@ -0,0 +1,104 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of the `dimos dataprep` subcommand (build + inspect). + +DataPrep is a one-shot batch transform, not a long-lived module, so it runs +as a plain command over the pure helpers in `dimos.learning.dataprep.core` +and exits with a 0/1 status — no coordinator, no blocking loop. + +The obs/action stream maps are nested, so they come from a JSON +`DataPrepConfig` via `--config`; simple flags override `source`/`output`/ +`format` on top. See `dimos/learning/dataprep/example_config.json`. 
+""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import TYPE_CHECKING + +import typer + +if TYPE_CHECKING: + from dimos.learning.dataprep.core import DataPrepConfig + + +def _load_config( + config_path: Path | None, + source: Path | None, + output: Path | None, + output_format: str | None, +) -> DataPrepConfig: + """Build a DataPrepConfig from an optional JSON file + flag overrides.""" + from dimos.learning.dataprep.core import DataPrepConfig, OutputConfig + + if config_path is not None: + cfg = DataPrepConfig.model_validate_json(Path(config_path).read_text()) + else: + cfg = DataPrepConfig() + + updates: dict[str, object] = {} + if source is not None: + updates["source"] = str(source) + if output is not None or output_format is not None: + updates["output"] = OutputConfig( + format=output_format or cfg.output.format, + path=output or cfg.output.path, + metadata=cfg.output.metadata, + ) + return cfg.model_copy(update=updates) if updates else cfg + + +def build( + config_path: Path | None, + source: Path | None, + output: Path | None, + output_format: str | None, +) -> None: + from dimos.learning.dataprep.build import run_dataprep + + cfg = _load_config(config_path, source, output, output_format) + if not cfg.source: + typer.echo("error: no source given (use --source or set it in --config)", err=True) + raise typer.Exit(2) + if not cfg.observation and not cfg.action: + typer.echo( + "error: no observation/action streams configured; pass --config with the " + "stream maps (see dimos/learning/dataprep/example_config.json)", + err=True, + ) + raise typer.Exit(2) + + try: + path = run_dataprep(cfg) + except Exception as e: + typer.echo(f"dataprep build failed: {e}", err=True) + raise typer.Exit(1) + typer.echo(f"✓ wrote dataset to {path}") + + +def inspect(dataset: Path | None, output_format: str | None) -> None: + from dimos.learning.dataprep.build import inspect_dataset + + if dataset is None: + typer.echo("error: no 
dataset given (pass a .hdf5 file or a lerobot directory)", err=True) + raise typer.Exit(2) + + try: + info = inspect_dataset(dataset, output_format) + except Exception as e: + typer.echo(f"dataprep inspect failed: {e}", err=True) + raise typer.Exit(1) + typer.echo(json.dumps(info, indent=2, default=str)) diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index dced77ef9a..263a886511 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -16,11 +16,12 @@ Sub-configs (StreamField, SyncConfig, OutputConfig, EpisodeExtractor) and data records (Episode, Sample) live here. So do the stateless functions -that walk samples — `resolve_field`, `compute_stats`, `extract_episodes`, -`iter_episode_samples`. Importable without booting a Module. +that walk samples — `resolve_field`, `extract_episodes`, +`iter_episode_samples`. Pure and side-effect-free; importable without +booting a Module. -`DataPrepModule` (in `dataprep_module.py`) is a thin wrapper that runs -these helpers on a thread. +The impure orchestration that composes these (opening the store, driving +the writer, writing files) lives in `build.py`. """ from __future__ import annotations @@ -47,7 +48,7 @@ class EpisodeExtractor(BaseConfig): - extractor: Literal["episode_status", "ranges", "whole_session"] = "episode_status" + extractor: Literal["episode_status", "ranges"] = "episode_status" status_stream: str = "episode_status" ranges: list[tuple[float, float]] | None = None @@ -70,6 +71,22 @@ class OutputConfig(BaseConfig): metadata: dict[str, Any] = {} +class DataPrepConfig(BaseConfig): + """Everything needed to turn a recording into a dataset. + + `source` is a recording `.db`; `observation`/`action` map dataset feature + names to recorded streams; `sync` resamples them onto a common timeline; + `output` selects format + path. Consumed by `build.run_dataprep`. 
+ """ + + source: str = "" + episodes: EpisodeExtractor = EpisodeExtractor() + observation: dict[str, StreamField] = {} + action: dict[str, StreamField] = {} + sync: SyncConfig = SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0) + output: OutputConfig = OutputConfig(format="lerobot", path="data/datasets/default") + + # ───────────────────────────────────────────────────────────────────────────── # Data records # ───────────────────────────────────────────────────────────────────────────── @@ -83,10 +100,6 @@ class Episode(BaseModel): success: bool = True metadata: dict[str, Any] = Field(default_factory=dict) - @property - def duration(self) -> float: - return self.end_ts - self.start_ts - class Sample(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -98,7 +111,7 @@ class Sample(BaseModel): # ───────────────────────────────────────────────────────────────────────────── -# Pure helpers — used by format writers, DataPrepModule +# Pure helpers — used by format writers and run_dataprep # ───────────────────────────────────────────────────────────────────────────── @@ -138,8 +151,6 @@ def extract_episodes(store: SqliteStore, cfg: EpisodeExtractor) -> list[Episode] end of stream with pending: dropped (matches live spec) RANGES: emit one Episode per (start, end) tuple in `cfg.ranges`. - - WHOLE_SESSION: one Episode covering the full time range of every stream. """ if cfg.extractor == "ranges": if not cfg.ranges: @@ -149,25 +160,6 @@ def extract_episodes(store: SqliteStore, cfg: EpisodeExtractor) -> list[Episode] for i, (t0, t1) in enumerate(cfg.ranges) ] - if cfg.extractor == "whole_session": - # Span every stream's time range. 
- names = store.list_streams() - if not names: - return [] - starts: list[float] = [] - ends: list[float] = [] - for name in names: - try: - stream = store.stream(name) - t0, t1 = stream.get_time_range() - starts.append(t0) - ends.append(t1) - except Exception: - continue - if not starts: - return [] - return [Episode(id="ep_000000", start_ts=min(starts), end_ts=max(ends))] - # episode_status (default) status_stream = store.stream(cfg.status_stream) events = list(status_stream) # observations in storage order @@ -314,35 +306,6 @@ def _nearest(key: str, t: float) -> Any | None: yield Sample(ts=t, episode_id=episode.id, observation=obs_dict, action=act_dict) -def compute_stats( - samples: Iterator[Sample], - image_subsample: int = 10, - quantile_reservoir: int = 10_000, - seed: int = 0, -) -> dict[str, Any]: - """Single-pass per-feature stats over a Sample iterator. - - Output schema matches LeRobot v2 ``stats.json``:: - - { "observation.": {"mean", "std", "min", "max", "q01", "q99"}, - "action.": {...} } - - Thin wrapper over :class:`StreamingStats` so format writers and - ad-hoc callers share the exact same accumulator. 
- """ - from dimos.learning.dataprep.formats._stats import StreamingStats - - s = StreamingStats( - image_subsample=image_subsample, quantile_reservoir=quantile_reservoir, seed=seed - ) - for sample in samples: - for k, v in sample.observation.items(): - s.update(f"observation.{k}", np.asarray(v)) - for k, v in sample.action.items(): - s.update(f"action.{k}", np.asarray(v)) - return s.finalize() - - def get_writer(format_name: str) -> Writer: """Lazy-import the format writer's `write` function.""" if format_name == "lerobot": @@ -352,3 +315,26 @@ def get_writer(format_name: str) -> Writer: else: raise ValueError(f"Unknown format: {format_name!r}") return write + + +def get_inspector(format_name: str) -> Callable[[Path], dict[str, Any]]: + """Lazy-import the format reader's `inspect` function.""" + if format_name == "lerobot": + from dimos.learning.dataprep.formats.lerobot import inspect + elif format_name == "hdf5": + from dimos.learning.dataprep.formats.hdf5 import inspect + else: + raise ValueError(f"Unknown format: {format_name!r}") + return inspect + + +def summarize_lengths(lengths: list[int]) -> dict[str, Any]: + """Min/max/mean of per-episode frame counts + whether they're all equal.""" + if not lengths: + return {"min": 0, "max": 0, "mean": 0.0, "uniform": True} + return { + "min": min(lengths), + "max": max(lengths), + "mean": sum(lengths) / len(lengths), + "uniform": min(lengths) == max(lengths), + } diff --git a/dimos/learning/dataprep/example_config.json b/dimos/learning/dataprep/example_config.json new file mode 100644 index 0000000000..17c3bbcebd --- /dev/null +++ b/dimos/learning/dataprep/example_config.json @@ -0,0 +1,20 @@ +{ + "source": "data/recordings/session.db", + "episodes": { + "extractor": "episode_status", + "status_stream": "status" + }, + "observation": { + "image": {"stream": "color_image", "field": "data"}, + "joint_state": {"stream": "joint_state", "field": "position"} + }, + "action": { + "joint_target": {"stream": "joint_state", 
"field": "position"} + }, + "sync": {"anchor": "image", "rate_hz": 14.0, "tolerance_ms": 80.0}, + "output": { + "format": "lerobot", + "path": "data/datasets/session", + "metadata": {"fps": 14, "robot": "xarm7", "default_task_label": "pick_and_place"} + } +} diff --git a/dimos/learning/dataprep/formats/hdf5.py b/dimos/learning/dataprep/formats/hdf5.py index 3cb5d3cd1a..e81de0fbd3 100644 --- a/dimos/learning/dataprep/formats/hdf5.py +++ b/dimos/learning/dataprep/formats/hdf5.py @@ -35,6 +35,7 @@ from collections.abc import Iterator from pathlib import Path +from typing import Any import numpy as np @@ -143,3 +144,54 @@ def _flush() -> None: g.create_dataset(k, data=np.asarray(entry[k], dtype=np.float64)) return out + + +def inspect(path: Path) -> dict[str, Any]: + """Summarize an .hdf5 dataset: features (per-frame shape/dtype), episode + counts, and whether feature shapes are uniform across episodes.""" + try: + import h5py + except ImportError as e: + raise RuntimeError("HDF5 inspect requires h5py — install with `pip install h5py`") from e + + from dimos.learning.dataprep.core import summarize_lengths + + out = Path(path) + with h5py.File(out, "r") as h5: + eps_g = h5["episodes"] + ep_names = sorted(eps_g.keys()) + lengths = [int(eps_g[e].attrs.get("length", 0)) for e in ep_names] + + observation: dict[str, Any] = {} + action: dict[str, Any] = {} + # Feature schema from the first episode (per-frame shape = dataset.shape[1:]). + if ep_names: + first = eps_g[ep_names[0]] + for grp, ref in (("observation", observation), ("action", action)): + if grp in first: + for k, d in first[grp].items(): + ref[k] = {"shape": list(d.shape[1:]), "dtype": str(d.dtype)} + + # Are per-frame shapes consistent across every episode? 
+ shapes_uniform = True + for e in ep_names[1:]: + g = eps_g[e] + for grp, ref in (("observation", observation), ("action", action)): + if grp in g: + for k, d in g[grp].items(): + if k in ref and list(d.shape[1:]) != ref[k]["shape"]: + shapes_uniform = False + + return { + "format": "hdf5", + "path": str(out), + "episodes": int(h5.attrs.get("num_episodes", len(ep_names))), + "frames": int(h5.attrs.get("num_frames", sum(lengths))), + "fps": float(h5.attrs.get("fps", 0.0)), + "robot": str(h5.attrs.get("robot", "unknown")), + "observation": observation, + "action": action, + "episode_lengths": summarize_lengths(lengths), + "shapes_uniform": shapes_uniform, + "has_stats": "stats" in h5, + } diff --git a/dimos/learning/dataprep/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py index 89d0e0de06..cc79d007ed 100644 --- a/dimos/learning/dataprep/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -312,3 +312,47 @@ def _flush_episode() -> None: json.dump(stats.finalize(), f, indent=2) return root + + +_META_COLS = {"timestamp", "frame_index", "episode_index", "index", "task_index"} + + +def inspect(path: Path) -> dict[str, Any]: + """Summarize a LeRobot v2 dataset from its meta/ files (no parquet load).""" + from dimos.learning.dataprep.core import summarize_lengths + + root = Path(path) + info = json.loads((root / META_DIR / "info.json").read_text()) + features = info.get("features", {}) + + observation: dict[str, Any] = {} + action: dict[str, Any] = {} + for name, feat in features.items(): + if name in _META_COLS: + continue + entry = {"shape": feat.get("shape"), "dtype": feat.get("dtype")} + if name.startswith("observation"): + observation[name] = entry + elif name.startswith("action"): + action[name] = entry + + lengths: list[int] = [] + ep_path = root / META_DIR / "episodes.jsonl" + if ep_path.exists(): + for line in ep_path.read_text().splitlines(): + if line.strip(): + lengths.append(int(json.loads(line).get("length", 0))) + + return 
{ + "format": "lerobot", + "path": str(root), + "episodes": info.get("total_episodes"), + "frames": info.get("total_frames"), + "fps": info.get("fps"), + "robot": info.get("robot_type"), + "observation": observation, + "action": action, + "episode_lengths": summarize_lengths(lengths), + "shapes_uniform": True, # LeRobot declares one global feature schema + "has_stats": (root / META_DIR / "stats.json").exists(), + } diff --git a/dimos/learning/dataprep/module.py b/dimos/learning/dataprep/module.py deleted file mode 100644 index 75c1c91ce8..0000000000 --- a/dimos/learning/dataprep/module.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DataPrepModule — wraps the dataprep pipeline as a Module with RPC surface. - -All dataset-shape types and pure helpers live in `dataprep.py`. This file -just adds the Module lifecycle + thread + status tracking. 
-""" - -from __future__ import annotations - -from collections.abc import Iterator -import json -from pathlib import Path -import threading -import traceback -from typing import Any - -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.learning.dataprep.core import ( - EpisodeExtractor, - OutputConfig, - Sample, - StreamField, - SyncConfig, - extract_episodes, - get_writer, - iter_episode_samples, -) -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() - - -class DataPrepModuleConfig(ModuleConfig): - # Fields are defaulted so partial CLI overrides (e.g. just `source=...`) - # pass blueprint validation; blueprint atoms supply real values. - source: str = "" - episodes: EpisodeExtractor = EpisodeExtractor() - observation: dict[str, StreamField] = {} - action: dict[str, StreamField] = {} - sync: SyncConfig = SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0) - output: OutputConfig = OutputConfig(format="lerobot", path="data/datasets/default") - auto_run: bool = False - - -class DataPrepModule(Module): - """Wraps a long-running dataset build job.""" - - config: DataPrepModuleConfig - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._thread: threading.Thread | None = None - self._lock = threading.Lock() - self._status: dict[str, Any] = { - "state": "idle", # idle | running | succeeded | failed - "current_phase": None, # scan_episodes | write | done - "progress_pct": 0.0, - "dataset_path": None, - "error": None, - "episodes_seen": 0, - "samples_seen": 0, - } - - # ── lifecycle ──────────────────────────────────────────────────────────── - - @rpc - def start(self) -> None: - super().start() - if self.config.auto_run: - self.build() - - @rpc - def stop(self) -> None: - # Build thread is daemon: dies with the process. No mid-iteration interrupt. - super().stop() - - @rpc - def build(self) -> None: - """Spawn a daemon thread running the build pipeline. 
Returns immediately.""" - with self._lock: - if self._status["state"] == "running": - return - self._status.update( - state="running", - current_phase=None, - progress_pct=0.0, - dataset_path=None, - error=None, - episodes_seen=0, - samples_seen=0, - ) - self._thread = threading.Thread(target=self._run_build, daemon=True) - self._thread.start() - - @rpc - def get_status(self) -> dict[str, Any]: - with self._lock: - return dict(self._status) - - @rpc - def inspect(self) -> dict[str, Any]: - """Read-only summary: episode count, drop rates, joint names, stats presence.""" - from dimos.memory2.store.sqlite import SqliteStore - - store = SqliteStore(path=self.config.source, must_exist=True) - try: - episodes = extract_episodes(store, self.config.episodes) - saved = sum(1 for e in episodes if e.success) - dropped = sum(1 for e in episodes if not e.success) - durations = [e.duration for e in episodes if e.success] - return { - "source": self.config.source, - "streams": store.list_streams(), - "episodes_saved": saved, - "episodes_dropped": dropped, - "duration_s": { - "min": min(durations) if durations else 0.0, - "max": max(durations) if durations else 0.0, - "mean": (sum(durations) / len(durations)) if durations else 0.0, - }, - } - finally: - store.stop() - - # ── internals ──────────────────────────────────────────────────────────── - - def _run_build(self) -> None: - """Thread target. Opens session.db, walks samples episode-by-episode, - drives the format writer, snapshots config to /dimos_meta.json. - Updates _status under _lock. 
- """ - try: - logger.info( - "[dataprep] starting build source=%s extractor=%s output=%s", - self.config.source, - self.config.episodes.extractor, - self.config.output.path, - ) - self._update_status(current_phase="scan_episodes") - - from dimos.memory2.store.sqlite import SqliteStore - - store = SqliteStore(path=self.config.source, must_exist=True) - try: - logger.info("[dataprep] streams in source: %s", store.list_streams()) - all_eps = extract_episodes(store, self.config.episodes) - episodes = [e for e in all_eps if e.success] - logger.info( - "[dataprep] episodes extracted: %d total / %d successful", - len(all_eps), - len(episodes), - ) - self._update_status(episodes_seen=len(episodes)) - - if not episodes: - raise RuntimeError( - f"No successful episodes extracted from {self.config.source!r} " - f"using extractor={self.config.episodes.extractor!r}. " - f"Available streams: {store.list_streams()}. " - f"For a single-demo .db with no episode_status stream, use " - f"extractor='whole_session' or 'ranges'." 
- ) - - streams = {**self.config.observation, **self.config.action} - obs_keys = set(self.config.observation) - action_keys = set(self.config.action) - logger.info( - "[dataprep] obs streams=%s action streams=%s sync=%s", - sorted(obs_keys), - sorted(action_keys), - self.config.sync.model_dump(), - ) - - writer = get_writer(self.config.output.format) - - self._update_status(current_phase="write") - logger.info( - "[dataprep] writing %s dataset to %s", - self.config.output.format, - self.config.output.path, - ) - - samples_seen = 0 - episodes_done = 0 - total = len(episodes) - - def _all_samples() -> Iterator[Sample]: - nonlocal samples_seen, episodes_done - for ep in episodes: - for sample in iter_episode_samples( - store=store, - episode=ep, - streams=streams, - sync=self.config.sync, - obs_keys=obs_keys, - action_keys=action_keys, - ): - samples_seen += 1 - if samples_seen % 50 == 0: - self._update_status( - samples_seen=samples_seen, - progress_pct=100.0 * episodes_done / total, - ) - logger.info( - "[dataprep] %.1f%% samples=%d ep %d/%d", - 100.0 * episodes_done / total, - samples_seen, - episodes_done, - total, - ) - yield sample - episodes_done += 1 - self._update_status( - samples_seen=samples_seen, - progress_pct=100.0 * episodes_done / total, - ) - - dataset_path = writer(_all_samples(), self.config.output) - - self._write_dimos_meta(Path(dataset_path), episodes) - - self._update_status( - state="succeeded", - current_phase="done", - progress_pct=100.0, - dataset_path=str(dataset_path), - ) - logger.info( - "[dataprep] succeeded — wrote %d samples across %d episodes to %s", - samples_seen, - total, - dataset_path, - ) - finally: - store.stop() - except Exception as e: - err = f"{type(e).__name__}: {e}\n{traceback.format_exc()}" - self._update_status(state="failed", error=err) - logger.error("[dataprep] FAILED: %s", err) - - def _write_dimos_meta(self, dataset_path: Path, episodes: list[Any]) -> None: - """Sidecar describing how this dataset was built, 
recording the - obs/action schema alongside the dataset.""" - meta = { - "source": self.config.source, - "observation": {k: v.model_dump() for k, v in self.config.observation.items()}, - "action": {k: v.model_dump() for k, v in self.config.action.items()}, - "sync": self.config.sync.model_dump(), - "episodes": [ - { - "id": e.id, - "start_ts": e.start_ts, - "end_ts": e.end_ts, - "task_label": e.task_label, - "success": e.success, - } - for e in episodes - ], - "format": self.config.output.format, - "metadata": self.config.output.metadata, - } - with open(dataset_path / "dimos_meta.json", "w") as f: - json.dump(meta, f, indent=2, default=str) - - def _update_status(self, **kwargs: Any) -> None: - with self._lock: - self._status.update(kwargs) - - -__all__ = ["DataPrepModule", "DataPrepModuleConfig"] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index fd987568b2..0a0b7383a9 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -66,7 +66,6 @@ "keyboard-teleop-xarm7": "dimos.robot.manipulators.xarm.blueprints:keyboard_teleop_xarm7", "learning-collect-quest-piper": "dimos.learning.collection.blueprint:learning_collect_quest_piper", "learning-collect-quest-xarm7": "dimos.learning.collection.blueprint:learning_collect_quest_xarm7", - "learning-dataprep": "dimos.learning.dataprep.blueprint:learning_dataprep", "mid360": "dimos.hardware.sensors.lidar.livox.livox_blueprints:mid360", "mid360-fastlio": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio", "mid360-fastlio-ray-trace": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_ray_trace", @@ -139,7 +138,6 @@ "click-start-goal-router": "dimos.navigation.nav_stack.modules.click_start_goal_router.click_start_goal_router.ClickStartGoalRouter", "control-coordinator": "dimos.control.coordinator.ControlCoordinator", "cost-mapper": "dimos.mapping.costmapper.CostMapper", - "data-prep-module": 
"dimos.learning.dataprep.module.DataPrepModule", "demo-calculator-skill": "dimos.agents.skills.demo_calculator_skill.DemoCalculatorSkill", "demo-monitoring": "dimos.agents.demos.demo_capabilities.DemoMonitoring", "demo-robot": "dimos.agents.skills.demo_robot.DemoRobot", diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index 3f94a6be4e..6c80ea88d5 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -681,6 +681,40 @@ def send( map_app = typer.Typer(help="Voxel-map tools over recorded sqlite datasets") main.add_typer(map_app, name="map") map_app.command("global")(_map_main) + + +dataprep_app = typer.Typer(help="Build and inspect learning datasets from recordings") +main.add_typer(dataprep_app, name="dataprep") + + +@dataprep_app.command("build") +def dataprep_build( + source: Path = typer.Option(None, "--source", "-s", help="Recording .db to read"), + output: Path = typer.Option(None, "--output", help="Dataset output directory"), + output_format: str = typer.Option(None, "--format", "-f", help="Output format: lerobot | hdf5"), + config_path: Path = typer.Option( + None, "--config", "-c", help="JSON DataPrepConfig (needed for obs/action stream maps)" + ), +) -> None: + """Build a dataset from a recording (lerobot/hdf5 + dimos_meta.json).""" + from dimos.learning.dataprep.cli import build + + build(config_path, source, output, output_format) + + +@dataprep_app.command("inspect") +def dataprep_inspect( + dataset: Path = typer.Argument(None, help="Built dataset: a .hdf5 file or a lerobot directory"), + output_format: str = typer.Option( + None, "--format", "-f", help="lerobot | hdf5 (auto-detected from the path if omitted)" + ), +) -> None: + """Summarize a built dataset: features, shapes, episode/frame counts, uniformity.""" + from dimos.learning.dataprep.cli import inspect + + inspect(dataset, output_format) + + map_app.command("summary")(_map_summary_main) map_app.command("rename")(_map_rename_main) 
map_app.command("pose-fill")(_map_pose_fill_main) From 45bebf633a4fd3108e28b15693ca30a7630f35a0 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Tue, 9 Jun 2026 18:50:00 -0700 Subject: [PATCH 14/45] fix: pre-commit fixes --- dimos/learning/dataprep/build.py | 4 +-- dimos/learning/dataprep/example_config.json | 27 +++++++++++++++++---- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/dimos/learning/dataprep/build.py b/dimos/learning/dataprep/build.py index 2cf9eb53c9..dc02ce47c9 100644 --- a/dimos/learning/dataprep/build.py +++ b/dimos/learning/dataprep/build.py @@ -111,9 +111,7 @@ def run_dataprep(config: DataPrepConfig) -> Path: ) writer = get_writer(config.output.format) - logger.info( - "[dataprep] writing %s dataset to %s", config.output.format, config.output.path - ) + logger.info("[dataprep] writing %s dataset to %s", config.output.format, config.output.path) samples_seen = 0 episodes_done = 0 diff --git a/dimos/learning/dataprep/example_config.json b/dimos/learning/dataprep/example_config.json index 17c3bbcebd..ed582caf8e 100644 --- a/dimos/learning/dataprep/example_config.json +++ b/dimos/learning/dataprep/example_config.json @@ -5,16 +5,33 @@ "status_stream": "status" }, "observation": { - "image": {"stream": "color_image", "field": "data"}, - "joint_state": {"stream": "joint_state", "field": "position"} + "image": { + "stream": "color_image", + "field": "data" + }, + "joint_state": { + "stream": "joint_state", + "field": "position" + } }, "action": { - "joint_target": {"stream": "joint_state", "field": "position"} + "joint_target": { + "stream": "joint_state", + "field": "position" + } + }, + "sync": { + "anchor": "image", + "rate_hz": 14.0, + "tolerance_ms": 80.0 }, - "sync": {"anchor": "image", "rate_hz": 14.0, "tolerance_ms": 80.0}, "output": { "format": "lerobot", "path": "data/datasets/session", - "metadata": {"fps": 14, "robot": "xarm7", "default_task_label": "pick_and_place"} + "metadata": { + "fps": 14, + "robot": "xarm7", + 
"default_task_label": "pick_and_place" + } } } From 65058374c8021291094e108aa8057450feec8d58 Mon Sep 17 00:00:00 2001 From: Ruthwik Date: Mon, 15 Jun 2026 16:24:07 -0700 Subject: [PATCH 15/45] feat: live logs of episode status --- dimos/learning/collection/episode_monitor.py | 23 ++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index 3f9edd1b83..f28e888d10 100644 --- a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -33,6 +33,9 @@ from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.teleop.quest.quest_types import Buttons +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() # Friendly names → Quest Buttons attribute names. Override by supplying an # attribute name directly in `button_map`. @@ -182,4 +185,24 @@ def _publish(self, last_event: Literal["start", "save", "discard", "init"]) -> E task_label=self.config.default_task_label, ) self.status.publish(status) + self._log_status(status) return status + + def _log_status(self, status: EpisodeStatus) -> None: + """Print a one-line operator-facing status to the terminal on every + transition — the only live feedback during a collection session.""" + verb = { + "start": "▶ RECORDING episode", + "save": "✓ SAVED episode", + "discard": "✗ DISCARDED episode", + "init": "· ready", + }.get(status.last_event, status.last_event) + label = f" [{status.task_label}]" if status.task_label else "" + logger.info( + "[collect] %s%s (state=%s saved=%d discarded=%d)", + verb, + label, + status.state, + status.episodes_saved, + status.episodes_discarded, + ) From 76f67192ef9d4943d2941728720ca98da3599ca7 Mon Sep 17 00:00:00 2001 From: Ruthwik Date: Mon, 15 Jun 2026 17:24:55 -0700 Subject: [PATCH 16/45] =?UTF-8?q?fix:=20dataprep=20status=5Fstream=20defau?= 
=?UTF-8?q?lt,=20rgb=E2=86=92bgr,=20drop=20button=20recording?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dimos/learning/collection/blueprint.py | 10 +- dimos/learning/collection/episode_monitor.py | 3 +- dimos/learning/collection/recorder.py | 2 - dimos/learning/dataprep/core.py | 31 +- dimos/learning/dataprep/formats/lerobot.py | 8 +- dimos/learning/dataprep/test_core.py | 308 +++++++++++++++++++ 6 files changed, 349 insertions(+), 13 deletions(-) create mode 100644 dimos/learning/dataprep/test_core.py diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index b3ae3cec5a..d50bbaa7bc 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -30,6 +30,7 @@ ) from dimos.learning.collection.recorder import CollectionRecorder from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.quest.blueprints import ( teleop_quest_piper, teleop_quest_xarm7, @@ -39,10 +40,9 @@ _DEFAULT_BUTTON_MAP = {"start": "A", "save": "B", "discard": "X"} -# Transports are written inline per blueprint (not factored into a shared -# variable) so each recording config is self-contained and readable on its -# own: buttons drive the episode state machine, color_image is the camera -# stream, and status carries the canonical EpisodeStatus that DataPrep reads. +# Transports inline per blueprint so each recording config is self-contained. +# joint_state is declared explicitly (not left to autoconnect) so it keeps +# recording if the recorder moves to its own process. 
learning_collect_quest_xarm7 = autoconnect( teleop_quest_xarm7, RealSenseCamera.blueprint(enable_pointcloud=False), @@ -52,6 +52,7 @@ { ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), } ) @@ -66,6 +67,7 @@ { ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), } ) diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index f28e888d10..b0722798da 100644 --- a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -189,8 +189,7 @@ def _publish(self, last_event: Literal["start", "save", "discard", "init"]) -> E return status def _log_status(self, status: EpisodeStatus) -> None: - """Print a one-line operator-facing status to the terminal on every - transition — the only live feedback during a collection session.""" + """One-line operator feedback to the terminal on every transition.""" verb = { "start": "▶ RECORDING episode", "save": "✓ SAVED episode", diff --git a/dimos/learning/collection/recorder.py b/dimos/learning/collection/recorder.py index f8c1b0f45d..3e205074db 100644 --- a/dimos/learning/collection/recorder.py +++ b/dimos/learning/collection/recorder.py @@ -32,7 +32,6 @@ from dimos.memory2.module import Recorder, RecorderConfig from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.teleop.quest.quest_types import Buttons class CollectionRecorderConfig(RecorderConfig): @@ -47,7 +46,6 @@ class 
CollectionRecorder(Recorder): color_image: In[Image] # observation (camera) joint_state: In[JointState] # observation + action (measured/next state) status: In[EpisodeStatus] # episode start/save/discard segmentation - buttons: In[Buttons] # raw teleop input (kept for debugging) __all__ = ["CollectionRecorder", "CollectionRecorderConfig"] diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index 263a886511..ab439ac384 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -49,7 +49,9 @@ class EpisodeExtractor(BaseConfig): extractor: Literal["episode_status", "ranges"] = "episode_status" - status_stream: str = "episode_status" + # Recorded stream name for EpisodeStatus events. Must match the recorder's + # `status` In port (CollectionRecorder records it as "status"). + status_stream: str = "status" ranges: list[tuple[float, float]] | None = None @@ -63,6 +65,9 @@ class SyncConfig(BaseConfig): rate_hz: float tolerance_ms: float strategy: Literal["nearest", "interp"] = "nearest" + # Action = state this many frames ahead (default 1 = next state, for BC). + # Trailing frames with no successor are dropped. Set 0 for action==state. + action_shift: int = 1 class OutputConfig(BaseConfig): @@ -225,6 +230,9 @@ def iter_episode_samples( `obs_keys` / `action_keys` partition `streams` into observation vs action. If omitted, every key is treated as observation (used by callers that only need raw aligned data). + + With `sync.action_shift > 0` (default 1), each frame's action is taken + `action_shift` frames later (next-state target); the tail is dropped. """ if sync.anchor not in streams: raise ValueError(f"sync.anchor {sync.anchor!r} not in streams: {sorted(streams)}") @@ -287,6 +295,9 @@ def _nearest(key: str, t: float) -> Any | None: return None return msg_list[best] + # Buffer the episode in order; the action shift below pairs obs[i] with + # action[i + shift]. 
+ frames: list[Sample] = [] for t in targets: obs_dict: dict[str, np.ndarray] = {} act_dict: dict[str, np.ndarray] = {} @@ -303,7 +314,23 @@ def _nearest(key: str, t: float) -> Any | None: obs_dict[key] = arr if skip: continue - yield Sample(ts=t, episode_id=episode.id, observation=obs_dict, action=act_dict) + frames.append(Sample(ts=t, episode_id=episode.id, observation=obs_dict, action=act_dict)) + + shift = max(0, sync.action_shift) + if shift == 0 or not action_keys: + yield from frames + return + + # frame i keeps its obs but takes frame i+shift's action; tail dropped. + for i in range(len(frames) - shift): + cur = frames[i] + nxt = frames[i + shift] + yield Sample( + ts=cur.ts, + episode_id=cur.episode_id, + observation=cur.observation, + action=nxt.action, + ) def get_writer(format_name: str) -> Writer: diff --git a/dimos/learning/dataprep/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py index cc79d007ed..06b485e75c 100644 --- a/dimos/learning/dataprep/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -126,8 +126,7 @@ def _episode_path_video(image_key: str, ep_idx: int) -> Path: return d / f"episode_{ep_idx:06d}.mp4" def _open_video(image_key: str, ep_idx: int, frame: np.ndarray) -> Any: - # Frames written as-is; cv2.VideoWriter is BGR-native. RGB frames will - # encode OK but decode color-swapped. Standardize upstream if needed. + # Frames are RGB→BGR converted at write time (see the write loop below). h, w = frame.shape[:2] path = _episode_path_video(image_key, ep_idx) writer = cv2.VideoWriter(str(path), fourcc, fps, (w, h)) @@ -223,7 +222,10 @@ def _flush_episode() -> None: if a.ndim >= 3: if k not in current_video_writers: current_video_writers[k] = _open_video(k, current_episode_index, a) - current_video_writers[k].write(a) + # Frames are RGB; cv2.VideoWriter is BGR-native — convert or + # the MP4 decodes color-swapped. 
+ bgr = cv2.cvtColor(a, cv2.COLOR_RGB2BGR) if a.shape[-1] == 3 else a + current_video_writers[k].write(bgr) rel_ts = float(sample.ts) - (current_episode_start_ts or 0.0) current_frames.append( diff --git a/dimos/learning/dataprep/test_core.py b/dimos/learning/dataprep/test_core.py new file mode 100644 index 0000000000..8a8cdf3b26 --- /dev/null +++ b/dimos/learning/dataprep/test_core.py @@ -0,0 +1,308 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the pure DataPrep helpers in `core.py`. + +No I/O: a tiny in-memory fake stands in for `SqliteStore`, exposing only the +surface the helpers touch (`stream(name)` → iterable of `.ts`/`.data` records, +with `.time_range(t0, t1)`). Keeps these fast and dependency-free. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import numpy as np +import pytest + +from dimos.learning.dataprep.core import ( + Episode, + EpisodeExtractor, + StreamField, + SyncConfig, + extract_episodes, + iter_episode_samples, + resolve_field, + summarize_lengths, +) + +# ── fakes ──────────────────────────────────────────────────────────────────── + + +@dataclass +class _Obs: + ts: float + data: Any + + +class _FakeStream: + def __init__(self, obs: list[_Obs]) -> None: + self._obs = sorted(obs, key=lambda o: o.ts) + + def __iter__(self): + return iter(self._obs) + + def time_range(self, t0: float, t1: float) -> _FakeStream: + return _FakeStream([o for o in self._obs if t0 <= o.ts <= t1]) + + +class _FakeStore: + def __init__(self, streams: dict[str, list[_Obs]]) -> None: + self._streams = {k: _FakeStream(v) for k, v in streams.items()} + + def stream(self, name: str) -> _FakeStream: + return self._streams.get(name, _FakeStream([])) + + def list_streams(self) -> list[str]: + return list(self._streams) + + +@dataclass +class _Status: + """Mimics EpisodeStatus fields the extractor reads via getattr.""" + + last_event: str + current_episode_start_ts: float | None = None + task_label: str | None = None + + +def _status(events: list[tuple[float, str, float | None, str | None]]) -> list[_Obs]: + """events = [(ts, last_event, start_ts, label), ...]""" + return [ + _Obs(ts=ts, data=_Status(last_event=ev, current_episode_start_ts=start, task_label=lbl)) + for ts, ev, start, lbl in events + ] + + +# ── resolve_field ──────────────────────────────────────────────────────────── + + +def test_resolve_field_attribute(): + @dataclass + class Msg: + position: list[float] + + arr = resolve_field(Msg(position=[1.0, 2.0, 3.0]), StreamField(stream="x", field="position")) + assert isinstance(arr, np.ndarray) + np.testing.assert_array_equal(arr, np.array([1.0, 2.0, 3.0])) + + +def test_resolve_field_dict_payload(): + arr = 
resolve_field({"q": [4, 5]}, StreamField(stream="x", field="q")) + np.testing.assert_array_equal(arr, np.array([4, 5])) + + +def test_resolve_field_none_passthrough_ndarray(): + src = np.arange(6).reshape(2, 3) + out = resolve_field(src, StreamField(stream="x", field=None)) + assert out is src # ndarray passes straight through + + +def test_resolve_field_none_unwraps_data_attr(): + @dataclass + class Image: + data: np.ndarray + + img = Image(data=np.ones((2, 2))) + out = resolve_field(img, StreamField(stream="x", field=None)) + np.testing.assert_array_equal(out, np.ones((2, 2))) + + +# ── extract_episodes: episode_status ───────────────────────────────────────── + + +def test_extract_start_save(): + store = _FakeStore( + {"status": _status([(1.0, "start", 1.0, "pick"), (5.0, "save", None, None)])} + ) + eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) + assert len(eps) == 1 + assert eps[0].start_ts == 1.0 and eps[0].end_ts == 5.0 + assert eps[0].success is True + assert eps[0].task_label == "pick" + + +def test_extract_discard_marks_failure(): + store = _FakeStore( + {"status": _status([(1.0, "start", 1.0, None), (3.0, "discard", None, None)])} + ) + eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) + assert len(eps) == 1 + assert eps[0].success is False + + +def test_extract_auto_commit_on_restart(): + # start, then another start without save → first auto-commits (success=True) + store = _FakeStore( + { + "status": _status( + [ + (1.0, "start", 1.0, None), + (4.0, "start", 4.0, None), + (8.0, "save", None, None), + ] + ) + } + ) + eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) + assert len(eps) == 2 + assert eps[0].start_ts == 1.0 and eps[0].end_ts == 4.0 and eps[0].success is True + assert eps[1].start_ts == 4.0 and eps[1].end_ts == 8.0 + + +def test_extract_pending_at_eof_dropped(): + store = _FakeStore({"status": _status([(1.0, "start", 1.0, None)])}) + eps = extract_episodes(store, 
EpisodeExtractor(status_stream="status")) + assert eps == [] + + +def test_extract_init_and_unknown_are_noops(): + store = _FakeStore( + { + "status": _status( + [(0.5, "init", None, None), (1.0, "start", 1.0, None), (5.0, "save", None, None)] + ) + } + ) + eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) + assert len(eps) == 1 + + +def test_extract_save_without_start_emits_nothing(): + store = _FakeStore({"status": _status([(2.0, "save", None, None)])}) + assert extract_episodes(store, EpisodeExtractor(status_stream="status")) == [] + + +# ── extract_episodes: ranges ───────────────────────────────────────────────── + + +def test_extract_ranges(): + cfg = EpisodeExtractor(extractor="ranges", ranges=[(0.0, 1.0), (2.0, 3.0)]) + eps = extract_episodes(_FakeStore({}), cfg) + assert [(e.start_ts, e.end_ts) for e in eps] == [(0.0, 1.0), (2.0, 3.0)] + + +def test_extract_ranges_empty(): + cfg = EpisodeExtractor(extractor="ranges", ranges=None) + assert extract_episodes(_FakeStore({}), cfg) == [] + + +# ── iter_episode_samples ───────────────────────────────────────────────────── + + +def _scalar_stream(values: list[tuple[float, float]]) -> list[_Obs]: + """values = [(ts, scalar), ...] 
→ messages with a `.position` 1-vector.""" + + @dataclass + class S: + position: list[float] + + return [_Obs(ts=ts, data=S(position=[v])) for ts, v in values] + + +def test_sync_basic_no_shift(): + # obs == action, shift disabled → one sample per anchor target + store = _FakeStore( + { + "js": _scalar_stream([(0.0, 10.0), (1.0, 11.0), (2.0, 12.0)]), + } + ) + ep = Episode(id="ep_0", start_ts=0.0, end_ts=2.0) + streams = { + "state": StreamField(stream="js", field="position"), + "act": StreamField(stream="js", field="position"), + } + sync = SyncConfig(anchor="state", rate_hz=1.0, tolerance_ms=100.0, action_shift=0) + samples = list( + iter_episode_samples( + store, ep, streams, sync, obs_keys={"state"}, action_keys={"act"} + ) + ) + assert len(samples) == 3 + # action equals state at the same frame + np.testing.assert_array_equal(samples[0].observation["state"], samples[0].action["act"]) + + +def test_sync_action_shift_next_state(): + store = _FakeStore({"js": _scalar_stream([(0.0, 10.0), (1.0, 11.0), (2.0, 12.0)])}) + ep = Episode(id="ep_0", start_ts=0.0, end_ts=2.0) + streams = { + "state": StreamField(stream="js", field="position"), + "act": StreamField(stream="js", field="position"), + } + sync = SyncConfig(anchor="state", rate_hz=1.0, tolerance_ms=100.0, action_shift=1) + samples = list( + iter_episode_samples( + store, ep, streams, sync, obs_keys={"state"}, action_keys={"act"} + ) + ) + # 3 frames, shift 1 → 2 emitted; trailing frame dropped + assert len(samples) == 2 + # frame 0: obs is state@0 (10), action is state@1 (11) + np.testing.assert_array_equal(samples[0].observation["state"], [10.0]) + np.testing.assert_array_equal(samples[0].action["act"], [11.0]) + np.testing.assert_array_equal(samples[1].observation["state"], [11.0]) + np.testing.assert_array_equal(samples[1].action["act"], [12.0]) + + +def test_sync_tolerance_skips_unmatched_frame(): + # anchor ticks every 1s, but the second stream has a big gap around t=1 + store = _FakeStore( + { + 
"anchor": _scalar_stream([(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]), + "other": _scalar_stream([(0.0, 5.0), (2.0, 7.0)]), # nothing near t=1 + } + ) + ep = Episode(id="ep_0", start_ts=0.0, end_ts=2.0) + streams = { + "anchor": StreamField(stream="anchor", field="position"), + "other": StreamField(stream="other", field="position"), + } + sync = SyncConfig(anchor="anchor", rate_hz=1.0, tolerance_ms=100.0, action_shift=0) + samples = list(iter_episode_samples(store, ep, streams, sync, obs_keys={"anchor", "other"})) + # t=1 dropped (no `other` within 100ms) → only t=0 and t=2 survive + assert [round(s.ts) for s in samples] == [0, 2] + + +def test_sync_missing_anchor_raises(): + ep = Episode(id="ep_0", start_ts=0.0, end_ts=1.0) + streams = {"x": StreamField(stream="x", field="position")} + sync = SyncConfig(anchor="not_there", rate_hz=1.0, tolerance_ms=10.0) + with pytest.raises(ValueError, match="anchor"): + list(iter_episode_samples(_FakeStore({}), ep, streams, sync)) + + +def test_sync_empty_anchor_yields_nothing(): + store = _FakeStore({"a": []}) + ep = Episode(id="ep_0", start_ts=0.0, end_ts=1.0) + streams = {"a": StreamField(stream="a", field="position")} + sync = SyncConfig(anchor="a", rate_hz=1.0, tolerance_ms=10.0) + assert list(iter_episode_samples(store, ep, streams, sync)) == [] + + +# ── summarize_lengths ──────────────────────────────────────────────────────── + + +def test_summarize_lengths_uniform(): + assert summarize_lengths([5, 5, 5]) == {"min": 5, "max": 5, "mean": 5.0, "uniform": True} + + +def test_summarize_lengths_varied(): + s = summarize_lengths([2, 4, 6]) + assert s == {"min": 2, "max": 6, "mean": 4.0, "uniform": False} + + +def test_summarize_lengths_empty(): + assert summarize_lengths([]) == {"min": 0, "max": 0, "mean": 0.0, "uniform": True} From 06b1c8a71bb5daf9c25396a3a81cee9a397db7c0 Mon Sep 17 00:00:00 2001 From: Ruthwik Date: Mon, 15 Jun 2026 17:25:14 -0700 Subject: [PATCH 17/45] fix: pre-commit checks --- 
dimos/learning/dataprep/test_core.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/dimos/learning/dataprep/test_core.py b/dimos/learning/dataprep/test_core.py index 8a8cdf3b26..ca1f67fd40 100644 --- a/dimos/learning/dataprep/test_core.py +++ b/dimos/learning/dataprep/test_core.py @@ -226,9 +226,7 @@ def test_sync_basic_no_shift(): } sync = SyncConfig(anchor="state", rate_hz=1.0, tolerance_ms=100.0, action_shift=0) samples = list( - iter_episode_samples( - store, ep, streams, sync, obs_keys={"state"}, action_keys={"act"} - ) + iter_episode_samples(store, ep, streams, sync, obs_keys={"state"}, action_keys={"act"}) ) assert len(samples) == 3 # action equals state at the same frame @@ -244,9 +242,7 @@ def test_sync_action_shift_next_state(): } sync = SyncConfig(anchor="state", rate_hz=1.0, tolerance_ms=100.0, action_shift=1) samples = list( - iter_episode_samples( - store, ep, streams, sync, obs_keys={"state"}, action_keys={"act"} - ) + iter_episode_samples(store, ep, streams, sync, obs_keys={"state"}, action_keys={"act"}) ) # 3 frames, shift 1 → 2 emitted; trailing frame dropped assert len(samples) == 2 From feb93c68cb31f6417fa19bfac95684493ebd9384 Mon Sep 17 00:00:00 2001 From: Ruthwik Date: Wed, 17 Jun 2026 12:39:48 -0700 Subject: [PATCH 18/45] feat: dataprep action-shift + collection status log, fixes --- dimos/learning/collection/episode_monitor.py | 4 +- dimos/learning/dataprep/build.py | 9 ++++- dimos/learning/dataprep/core.py | 9 +++-- dimos/learning/dataprep/example_config.json | 3 +- dimos/learning/dataprep/test_core.py | 42 ++++++++++++++++++++ 5 files changed, 60 insertions(+), 7 deletions(-) diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index b0722798da..0325660072 100644 --- a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -92,6 +92,7 @@ def __init__(self, **kwargs: Any) -> None: self._saved: int = 0 self._discarded: 
int = 0 self._current_start_ts: float | None = None + self._last_event: Literal["start", "save", "discard", "init"] = "init" self._prev_bits: dict[str, bool] = {} # rising-edge detection for buttons self._lock = threading.Lock() @@ -126,7 +127,7 @@ def get_status(self) -> EpisodeStatus: episodes_saved=self._saved, episodes_discarded=self._discarded, current_episode_start_ts=self._current_start_ts, - last_event="init", + last_event=self._last_event, task_label=self.config.default_task_label, ) @@ -176,6 +177,7 @@ def _transition(self, event: Literal["start", "save", "discard"], ts: float) -> def _publish(self, last_event: Literal["start", "save", "discard", "init"]) -> EpisodeStatus: with self._lock: + self._last_event = last_event status = EpisodeStatus( state=self._state, episodes_saved=self._saved, diff --git a/dimos/learning/dataprep/build.py b/dimos/learning/dataprep/build.py index dc02ce47c9..13dfc3e0b8 100644 --- a/dimos/learning/dataprep/build.py +++ b/dimos/learning/dataprep/build.py @@ -61,7 +61,13 @@ def _write_dimos_meta(dataset_path: Path, config: DataPrepConfig, episodes: list "format": config.output.format, "metadata": config.output.metadata, } - with open(dataset_path / "dimos_meta.json", "w") as f: + # Writers return a directory (lerobot) or a file (hdf5). Put the sidecar + # *inside* a directory, or *beside* a file (`.dimos_meta.json`). 
+ if dataset_path.is_dir(): + meta_path = dataset_path / "dimos_meta.json" + else: + meta_path = dataset_path.with_name(f"{dataset_path.stem}.dimos_meta.json") + with open(meta_path, "w") as f: json.dump(meta, f, indent=2, default=str) @@ -109,7 +115,6 @@ def run_dataprep(config: DataPrepConfig) -> Path: sorted(action_keys), config.sync.model_dump(), ) - writer = get_writer(config.output.format) logger.info("[dataprep] writing %s dataset to %s", config.output.format, config.output.path) diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index ab439ac384..2c1e7bf01c 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -65,8 +65,9 @@ class SyncConfig(BaseConfig): rate_hz: float tolerance_ms: float strategy: Literal["nearest", "interp"] = "nearest" - # Action = state this many frames ahead (default 1 = next state, for BC). - # Trailing frames with no successor are dropped. Set 0 for action==state. + # Action = state this many frames ahead (default 1 = next-state BC). Set 0 + # for action==state. Use 0 for ACT — actions stay flat (one per frame) and + # ACT builds its own chunk at train time via delta_timestamps. action_shift: int = 1 @@ -200,7 +201,9 @@ def _commit(end_ts: float, success: bool, label: str | None) -> None: if last_event == "start": # Auto-commit any prior pending episode (success=True per state-machine spec). _commit(ts, success=True, label=pending_label) - pending_start_ts = getattr(ev, "current_episode_start_ts", None) or ts + # None check, not `or ts`: a start at absolute ts 0.0 is valid. 
+ ep_start = getattr(ev, "current_episode_start_ts", None) + pending_start_ts = ts if ep_start is None else ep_start pending_label = label elif last_event == "save": _commit(ts, success=True, label=pending_label or label) diff --git a/dimos/learning/dataprep/example_config.json b/dimos/learning/dataprep/example_config.json index ed582caf8e..6e173afad8 100644 --- a/dimos/learning/dataprep/example_config.json +++ b/dimos/learning/dataprep/example_config.json @@ -23,7 +23,8 @@ "sync": { "anchor": "image", "rate_hz": 14.0, - "tolerance_ms": 80.0 + "tolerance_ms": 80.0, + "action_shift": 1 }, "output": { "format": "lerobot", diff --git a/dimos/learning/dataprep/test_core.py b/dimos/learning/dataprep/test_core.py index ca1f67fd40..987e6c9d33 100644 --- a/dimos/learning/dataprep/test_core.py +++ b/dimos/learning/dataprep/test_core.py @@ -302,3 +302,45 @@ def test_summarize_lengths_varied(): def test_summarize_lengths_empty(): assert summarize_lengths([]) == {"min": 0, "max": 0, "mean": 0.0, "uniform": True} + + +# ── dimos_meta sidecar ─────────────────────────────────────────────────────── + + +def test_dimos_meta_records_sync_and_action_shift(tmp_path): + import json + + from dimos.learning.dataprep.build import _write_dimos_meta + from dimos.learning.dataprep.core import DataPrepConfig, OutputConfig, StreamField + + cfg = DataPrepConfig( + source="s.db", + observation={"state": StreamField(stream="js", field="position")}, + action={"action": StreamField(stream="js", field="position")}, + sync=SyncConfig(anchor="state", rate_hz=14.0, tolerance_ms=80.0, action_shift=0), + output=OutputConfig(format="lerobot", path=tmp_path, metadata={"fps": 14}), + ) + _write_dimos_meta(tmp_path, cfg, episodes=[]) + + meta = json.loads((tmp_path / "dimos_meta.json").read_text()) + assert meta["sync"]["action_shift"] == 0 + assert meta["source"] == "s.db" + + +def test_dimos_meta_beside_file_for_hdf5(tmp_path): + """hdf5 writer returns a FILE path; the sidecar must land beside it, not + 
inside it (which would treat the .hdf5 file as a directory and crash).""" + import json + + from dimos.learning.dataprep.build import _write_dimos_meta + from dimos.learning.dataprep.core import DataPrepConfig, OutputConfig + + ds_file = tmp_path / "session.hdf5" + ds_file.write_bytes(b"\x89HDF\r\n") # stand-in for a real .hdf5 + cfg = DataPrepConfig(source="s.db", output=OutputConfig(format="hdf5", path=ds_file)) + + _write_dimos_meta(ds_file, cfg, episodes=[]) + + sidecar = tmp_path / "session.dimos_meta.json" + assert sidecar.exists() # beside the file, not session.hdf5/dimos_meta.json + assert json.loads(sidecar.read_text())["format"] == "hdf5" From 8b0da13e5773c728f1b25ef58fa6d5ee409bbc22 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Wed, 17 Jun 2026 13:32:04 -0700 Subject: [PATCH 19/45] fix: db path + cam sim support --- dimos/learning/collection/blueprint.py | 22 ++++++++++++++++------ dimos/memory2/store/sqlite.py | 4 ++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index d50bbaa7bc..4d7723ea3e 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -21,8 +21,9 @@ from __future__ import annotations -from dimos.core.coordination.blueprints import autoconnect -from dimos.core.transport import LCMTransport +from dimos.core.coordination.blueprints import Blueprint, autoconnect +from dimos.core.global_config import global_config +from dimos.core.transport import LCMTransport, pLCMTransport from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera from dimos.learning.collection.episode_monitor import ( EpisodeMonitorModule, @@ -40,12 +41,21 @@ _DEFAULT_BUTTON_MAP = {"start": "A", "save": "B", "discard": "X"} +def _camera_if_real() -> tuple[Blueprint, ...]: + """Real RealSense only off-sim. 
In `--simulation` the teleop coordinator's + MujocoSimModule already publishes color_image on /camera/color_image, so a + real camera would be redundant (and fail with no device connected).""" + if global_config.simulation: + return () + return (RealSenseCamera.blueprint(enable_pointcloud=False),) + + # Transports inline per blueprint so each recording config is self-contained. # joint_state is declared explicitly (not left to autoconnect) so it keeps # recording if the recorder moves to its own process. learning_collect_quest_xarm7 = autoconnect( teleop_quest_xarm7, - RealSenseCamera.blueprint(enable_pointcloud=False), + *_camera_if_real(), EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), CollectionRecorder.blueprint(), ).transports( @@ -53,14 +63,14 @@ ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), + ("status", EpisodeStatus): pLCMTransport("/learning/episode_status"), } ) learning_collect_quest_piper = autoconnect( teleop_quest_piper, - RealSenseCamera.blueprint(enable_pointcloud=False), + *_camera_if_real(), EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), CollectionRecorder.blueprint(), ).transports( @@ -68,7 +78,7 @@ ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), ("color_image", Image): LCMTransport("/camera/color_image", Image), ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("status", EpisodeStatus): LCMTransport("/learning/episode_status", EpisodeStatus), + ("status", EpisodeStatus): pLCMTransport("/learning/episode_status"), } ) diff --git a/dimos/memory2/store/sqlite.py b/dimos/memory2/store/sqlite.py index bb2b735c1c..229961a126 100644 --- a/dimos/memory2/store/sqlite.py +++ b/dimos/memory2/store/sqlite.py @@ -54,6 
+54,10 @@ def __init__(self, **kwargs: Any) -> None: raise FileNotFoundError( f"SQLite database not found: {os.path.abspath(self.config.path)}" ) + if not self.config.must_exist: + parent = os.path.dirname(self.config.path) + if parent: + os.makedirs(parent, exist_ok=True) self._registry_conn = self._open_connection() self._registry = RegistryStore(conn=self._registry_conn) From d42270886d17c562eb6b63b612656980fe3519c2 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Wed, 17 Jun 2026 13:51:40 -0700 Subject: [PATCH 20/45] session_db file name with datetime --- dimos/learning/collection/blueprint.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index 4d7723ea3e..f3235d6852 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -21,6 +21,8 @@ from __future__ import annotations +from datetime import datetime + from dimos.core.coordination.blueprints import Blueprint, autoconnect from dimos.core.global_config import global_config from dimos.core.transport import LCMTransport, pLCMTransport @@ -40,6 +42,8 @@ _DEFAULT_BUTTON_MAP = {"start": "A", "save": "B", "discard": "X"} +_SESSION_DB = f"data/recordings/session_{datetime.now():%Y%m%d_%H%M%S}.db" + def _camera_if_real() -> tuple[Blueprint, ...]: """Real RealSense only off-sim. 
In `--simulation` the teleop coordinator's @@ -57,7 +61,7 @@ def _camera_if_real() -> tuple[Blueprint, ...]: teleop_quest_xarm7, *_camera_if_real(), EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), - CollectionRecorder.blueprint(), + CollectionRecorder.blueprint(db_path=_SESSION_DB), ).transports( { ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), @@ -72,7 +76,7 @@ def _camera_if_real() -> tuple[Blueprint, ...]: teleop_quest_piper, *_camera_if_real(), EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), - CollectionRecorder.blueprint(), + CollectionRecorder.blueprint(db_path=_SESSION_DB), ).transports( { ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), From d10b955fc60769895372b01e25aa191cd2f33493 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Wed, 17 Jun 2026 17:00:32 -0700 Subject: [PATCH 21/45] fix: episode toggle button --- dimos/learning/collection/blueprint.py | 2 +- dimos/learning/collection/episode_monitor.py | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index f3235d6852..a449c17a63 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -40,7 +40,7 @@ ) from dimos.teleop.quest.quest_types import Buttons -_DEFAULT_BUTTON_MAP = {"start": "A", "save": "B", "discard": "X"} +_DEFAULT_BUTTON_MAP = {"toggle": "B", "discard": "Y"} _SESSION_DB = f"data/recordings/session_{datetime.now():%Y%m%d_%H%M%S}.db" diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index 0325660072..1d5b1d8cfd 100644 --- a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -70,12 +70,11 @@ class KeyPress(BaseModel): class EpisodeMonitorModuleConfig(ModuleConfig): - button_map: dict[Literal["start", "save", "discard"], str] = { - "start": "A", - "save": "B", - 
"discard": "X", + button_map: dict[Literal["start", "save", "discard", "toggle"], str] = { + "toggle": "B", + "discard": "Y", } - keyboard_map: dict[Literal["start", "save", "discard"], str] = {} + keyboard_map: dict[Literal["start", "save", "discard", "toggle"], str] = {} default_task_label: str | None = None @@ -154,9 +153,16 @@ def _on_keyboard(self, msg: KeyPress) -> None: self._transition(event_name, msg.ts) break - def _transition(self, event: Literal["start", "save", "discard"], ts: float) -> None: - """State-machine transition. Publishes EpisodeStatus on every change.""" + def _transition(self, event: Literal["start", "save", "discard", "toggle"], ts: float) -> None: + """State-machine transition. Publishes EpisodeStatus on every change. + + ``toggle`` resolves to ``start`` when idle and ``save`` when recording, + so one button can begin and end a take. The resolved event is what gets + published (DataPrep only ever sees start/save/discard). + """ with self._lock: + if event == "toggle": + event = "save" if self._state == "recording" else "start" if event == "start": # Auto-commit any in-progress episode (matches DataPrep extractor). 
if self._state == "recording" and self._current_start_ts is not None: From 66a31d6c511ec18f8e65cf7e5886fffce437199c Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Wed, 17 Jun 2026 17:00:44 -0700 Subject: [PATCH 22/45] fix: dataprep float32 + lerobot timestamp/stats fixes --- dimos/learning/dataprep/core.py | 2 ++ dimos/learning/dataprep/formats/lerobot.py | 19 +++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index 2c1e7bf01c..0a373c5770 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -311,6 +311,8 @@ def _nearest(key: str, t: float) -> Any | None: skip = True break arr = resolve_field(msg, ref) + if arr.ndim < 3: + arr = arr.astype(np.float32, copy=False) if key in action_keys: act_dict[key] = arr elif key in obs_keys: diff --git a/dimos/learning/dataprep/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py index 06b485e75c..075ebfa630 100644 --- a/dimos/learning/dataprep/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -111,7 +111,6 @@ def write(samples: Iterator[Sample], output: OutputConfig) -> Path: current_episode_id: str | None = None current_episode_index = -1 - current_episode_start_ts: float | None = None current_frames: list[dict[str, Any]] = [] current_video_writers: dict[str, Any] = {} global_index = 0 @@ -149,16 +148,17 @@ def _flush_episode() -> None: "index": [f["index"] for f in current_frames], "task_index": [f["task_index"] for f in current_frames], } + f32_list = pa.list_(pa.float32()) single_state = len(state_keys) == 1 for k in state_keys: name = _feature_name( "observation", k, is_image=False, single_action=False, single_state=single_state ) - cols[name] = [f["obs"][k].tolist() for f in current_frames] + cols[name] = pa.array([f["obs"][k].tolist() for f in current_frames], type=f32_list) single_action = len(action_keys) == 1 for k in action_keys: name = 
_feature_name("action", k, is_image=False, single_action=single_action) - cols[name] = [f["act"][k].tolist() for f in current_frames] + cols[name] = pa.array([f["act"][k].tolist() for f in current_frames], type=f32_list) # Video columns intentionally omitted: lerobot's hf_features schema # skips dtype="video" and reads frames from MP4 at __getitem__ time. @@ -179,9 +179,6 @@ def _flush_episode() -> None: _flush_episode() current_episode_id = sample.episode_id current_episode_index += 1 - # Episode-relative timestamps so they fit float32 with sub-ms - # precision; lerobot's check_timestamps_sync compares against 1/fps. - current_episode_start_ts = float(sample.ts) label = default_task_label if label not in tasks_index: tasks_index[label] = len(tasks_index) @@ -227,7 +224,7 @@ def _flush_episode() -> None: bgr = cv2.cvtColor(a, cv2.COLOR_RGB2BGR) if a.shape[-1] == 3 else a current_video_writers[k].write(bgr) - rel_ts = float(sample.ts) - (current_episode_start_ts or 0.0) + rel_ts = frame_index / fps current_frames.append( { "timestamp": rel_ts, @@ -310,8 +307,14 @@ def _flush_episode() -> None: with open(root / META_DIR / "tasks.jsonl", "w") as f: for task, idx in tasks_index.items(): f.write(json.dumps({"task_index": idx, "task": task}) + "\n") + final_stats = stats.finalize() + for name, entry in final_stats.items(): + if feature_dtypes.get(name) == "video": + for k in ("mean", "std", "min", "max"): + if entry.get(k) is not None: + entry[k] = [[[c]] for c in entry[k]] with open(root / META_DIR / "stats.json", "w") as f: - json.dump(stats.finalize(), f, indent=2) + json.dump(final_stats, f, indent=2) return root From 750b085846ae7f98b75172947c90dbc33c8cd99d Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Wed, 17 Jun 2026 19:56:19 -0700 Subject: [PATCH 23/45] hey jeff, in the rui interview meeting. 
just here to see if anyone joins --- dimos/learning/collection/blueprint.py | 6 +- dimos/learning/collection/episode_monitor.py | 26 ++-- dimos/learning/dataprep/cli.py | 3 + dimos/learning/dataprep/formats/lerobot.py | 144 ++++++++++--------- pyproject.toml | 6 + 5 files changed, 104 insertions(+), 81 deletions(-) diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index a449c17a63..cd1999abd4 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -40,8 +40,6 @@ ) from dimos.teleop.quest.quest_types import Buttons -_DEFAULT_BUTTON_MAP = {"toggle": "B", "discard": "Y"} - _SESSION_DB = f"data/recordings/session_{datetime.now():%Y%m%d_%H%M%S}.db" @@ -60,7 +58,7 @@ def _camera_if_real() -> tuple[Blueprint, ...]: learning_collect_quest_xarm7 = autoconnect( teleop_quest_xarm7, *_camera_if_real(), - EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), + EpisodeMonitorModule.blueprint(), # default button_map: toggle=B, discard=Y CollectionRecorder.blueprint(db_path=_SESSION_DB), ).transports( { @@ -75,7 +73,7 @@ def _camera_if_real() -> tuple[Blueprint, ...]: learning_collect_quest_piper = autoconnect( teleop_quest_piper, *_camera_if_real(), - EpisodeMonitorModule.blueprint(button_map=_DEFAULT_BUTTON_MAP), + EpisodeMonitorModule.blueprint(), # default button_map: toggle=B, discard=Y CollectionRecorder.blueprint(db_path=_SESSION_DB), ).transports( { diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index 1d5b1d8cfd..fc8eebec3a 100644 --- a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -135,16 +135,22 @@ def get_status(self) -> EpisodeStatus: def _on_buttons(self, msg: Buttons) -> None: """Rising-edge detect against `config.button_map`; advance state machine.""" ts = time.time() - for event_name, alias_or_attr in self.config.button_map.items(): - attr = 
BUTTON_ALIASES.get(alias_or_attr, alias_or_attr) - try: - pressed = bool(getattr(msg, attr)) - except AttributeError: - continue - prev = self._prev_bits.get(attr, False) - self._prev_bits[attr] = pressed - if pressed and not prev: # rising edge - self._transition(event_name, ts) + # Edge-detect under the lock (it shares `_prev_bits` with reset_counters), + # then fire transitions outside it — `_transition` takes the same lock. + fired: list[Literal["start", "save", "discard", "toggle"]] = [] + with self._lock: + for event_name, alias_or_attr in self.config.button_map.items(): + attr = BUTTON_ALIASES.get(alias_or_attr, alias_or_attr) + try: + pressed = bool(getattr(msg, attr)) + except AttributeError: + continue + prev = self._prev_bits.get(attr, False) + self._prev_bits[attr] = pressed + if pressed and not prev: # rising edge + fired.append(event_name) + for event_name in fired: + self._transition(event_name, ts) def _on_keyboard(self, msg: KeyPress) -> None: """Match `msg.key` against `config.keyboard_map`; advance state machine.""" diff --git a/dimos/learning/dataprep/cli.py b/dimos/learning/dataprep/cli.py index fe6c5c81f2..434055a922 100644 --- a/dimos/learning/dataprep/cli.py +++ b/dimos/learning/dataprep/cli.py @@ -84,6 +84,8 @@ def build( try: path = run_dataprep(cfg) except Exception as e: + # CLI boundary: any failure becomes a clean message + non-zero exit + # instead of a traceback. run_dataprep raises specific errors internally. typer.echo(f"dataprep build failed: {e}", err=True) raise typer.Exit(1) typer.echo(f"✓ wrote dataset to {path}") @@ -99,6 +101,7 @@ def inspect(dataset: Path | None, output_format: str | None) -> None: try: info = inspect_dataset(dataset, output_format) except Exception as e: + # CLI boundary: surface failures as a message + non-zero exit, not a traceback. 
typer.echo(f"dataprep inspect failed: {e}", err=True) raise typer.Exit(1) typer.echo(json.dumps(info, indent=2, default=str)) diff --git a/dimos/learning/dataprep/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py index 075ebfa630..691a158502 100644 --- a/dimos/learning/dataprep/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -174,75 +174,85 @@ def _flush_episode() -> None: ) current_frames = [] - for sample in samples: - if sample.episode_id != current_episode_id: - _flush_episode() - current_episode_id = sample.episode_id - current_episode_index += 1 - label = default_task_label - if label not in tasks_index: - tasks_index[label] = len(tasks_index) - - # Schema discovery + stats accumulation. - n_low_dim_obs = sum(1 for _, v in sample.observation.items() if np.asarray(v).ndim < 3) - single_state = n_low_dim_obs == 1 - for k, arr in sample.observation.items(): - a = np.asarray(arr) - is_image = a.ndim >= 3 - name = _feature_name( - "observation", k, is_image=is_image, single_action=False, single_state=single_state + try: + for sample in samples: + if sample.episode_id != current_episode_id: + _flush_episode() + current_episode_id = sample.episode_id + current_episode_index += 1 + label = default_task_label + if label not in tasks_index: + tasks_index[label] = len(tasks_index) + + # Schema discovery + stats accumulation. 
+ n_low_dim_obs = sum(1 for _, v in sample.observation.items() if np.asarray(v).ndim < 3) + single_state = n_low_dim_obs == 1 + for k, arr in sample.observation.items(): + a = np.asarray(arr) + is_image = a.ndim >= 3 + name = _feature_name( + "observation", + k, + is_image=is_image, + single_action=False, + single_state=single_state, + ) + if name not in feature_shapes: + feature_shapes[name] = tuple(a.shape) + feature_dtypes[name] = "video" if is_image else str(a.dtype) + if is_image: + image_keys.add(k) + elif k not in state_keys: + state_keys.append(k) + stats.update(name, a) + + for k, arr in sample.action.items(): + a = np.asarray(arr) + single_action = len(sample.action) == 1 + name = _feature_name("action", k, is_image=False, single_action=single_action) + if name not in feature_shapes: + feature_shapes[name] = tuple(a.shape) + feature_dtypes[name] = str(a.dtype) + if k not in action_keys: + action_keys.append(k) + stats.update(name, a) + + # Video frame write + parquet row buffer. + frame_index = len(current_frames) + for k, arr in sample.observation.items(): + a = np.asarray(arr) + if a.ndim >= 3: + if k not in current_video_writers: + current_video_writers[k] = _open_video(k, current_episode_index, a) + # Frames are RGB; cv2.VideoWriter is BGR-native — convert or + # the MP4 decodes color-swapped. 
+ bgr = cv2.cvtColor(a, cv2.COLOR_RGB2BGR) if a.shape[-1] == 3 else a + current_video_writers[k].write(bgr) + + rel_ts = frame_index / fps + current_frames.append( + { + "timestamp": rel_ts, + "frame_index": frame_index, + "episode_index": current_episode_index, + "index": global_index, + "task_index": tasks_index[default_task_label], + "obs": { + k: np.asarray(v) + for k, v in sample.observation.items() + if np.asarray(v).ndim < 3 + }, + "act": {k: np.asarray(v) for k, v in sample.action.items()}, + } ) - if name not in feature_shapes: - feature_shapes[name] = tuple(a.shape) - feature_dtypes[name] = "video" if is_image else str(a.dtype) - if is_image: - image_keys.add(k) - elif k not in state_keys: - state_keys.append(k) - stats.update(name, a) - - for k, arr in sample.action.items(): - a = np.asarray(arr) - single_action = len(sample.action) == 1 - name = _feature_name("action", k, is_image=False, single_action=single_action) - if name not in feature_shapes: - feature_shapes[name] = tuple(a.shape) - feature_dtypes[name] = str(a.dtype) - if k not in action_keys: - action_keys.append(k) - stats.update(name, a) - - # Video frame write + parquet row buffer. - frame_index = len(current_frames) - for k, arr in sample.observation.items(): - a = np.asarray(arr) - if a.ndim >= 3: - if k not in current_video_writers: - current_video_writers[k] = _open_video(k, current_episode_index, a) - # Frames are RGB; cv2.VideoWriter is BGR-native — convert or - # the MP4 decodes color-swapped. 
- bgr = cv2.cvtColor(a, cv2.COLOR_RGB2BGR) if a.shape[-1] == 3 else a - current_video_writers[k].write(bgr) - - rel_ts = frame_index / fps - current_frames.append( - { - "timestamp": rel_ts, - "frame_index": frame_index, - "episode_index": current_episode_index, - "index": global_index, - "task_index": tasks_index[default_task_label], - "obs": { - k: np.asarray(v) - for k, v in sample.observation.items() - if np.asarray(v).ndim < 3 - }, - "act": {k: np.asarray(v) for k, v in sample.action.items()}, - } - ) - global_index += 1 + global_index += 1 - _flush_episode() + _flush_episode() + finally: + # If the drain raised mid-episode, release any writers still open so we + # don't leak file handles / leave half-written MP4s locked. + for vw in current_video_writers.values(): + vw.release() # ── meta files ─────────────────────────────────────────────────────────── total_episodes = len(episodes_meta) diff --git a/pyproject.toml b/pyproject.toml index be01eb012d..b62df6531e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,6 +193,12 @@ visualization = [ "dimos-viewer==0.32.0a1", ] +learning = [ + # dimos.learning.dataprep dataset writers (lazy-imported per format) + "pyarrow", # LeRobot v2 parquet writer + "h5py", # HDF5 writer +] + agents = [ "langchain>=1.2.3,<2", "langchain-chroma>=1,<2", From a7fc7a72b1745c760b5e5c82c88670b5720b85a1 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Wed, 17 Jun 2026 19:56:30 -0700 Subject: [PATCH 24/45] feat: tests --- dimos/learning/dataprep/formats/test_hdf5.py | 85 ++++++++++++++ .../learning/dataprep/formats/test_lerobot.py | 111 ++++++++++++++++++ dimos/learning/dataprep/formats/test_stats.py | 85 ++++++++++++++ 3 files changed, 281 insertions(+) create mode 100644 dimos/learning/dataprep/formats/test_hdf5.py create mode 100644 dimos/learning/dataprep/formats/test_lerobot.py create mode 100644 dimos/learning/dataprep/formats/test_stats.py diff --git a/dimos/learning/dataprep/formats/test_hdf5.py 
b/dimos/learning/dataprep/formats/test_hdf5.py new file mode 100644 index 0000000000..54a0a2486f --- /dev/null +++ b/dimos/learning/dataprep/formats/test_hdf5.py @@ -0,0 +1,85 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Round-trip tests for the HDF5 dataset writer/reader. + +Builds a tiny in-memory Sample stream (no SqliteStore), writes it, then reads +it back through `inspect` and asserts the episode/frame counts, per-feature +shapes, and that stats landed. Skips cleanly if h5py isn't installed (it lives +in the `learning` optional-dependency group). 
+""" + +from __future__ import annotations + +from collections.abc import Iterator + +import numpy as np +import pytest + +h5py = pytest.importorskip("h5py") + +from dimos.learning.dataprep.core import OutputConfig, Sample +from dimos.learning.dataprep.formats.hdf5 import inspect, write + + +def _samples(n_episodes: int = 2, n_frames: int = 3) -> Iterator[Sample]: + """obs `state` (4-vec) + `action` (2-vec), `n_frames` per episode.""" + for ep in range(n_episodes): + for i in range(n_frames): + yield Sample( + ts=float(i), + episode_id=f"ep_{ep:06d}", + observation={"state": (np.arange(4, dtype=np.float32) + i)}, + action={"action": np.full(2, float(i), dtype=np.float32)}, + ) + + +def test_hdf5_roundtrip_counts_and_shapes(tmp_path): + out = OutputConfig( + format="hdf5", + path=tmp_path / "session", + metadata={"fps": 20.0, "robot": "xarm7"}, + ) + path = write(_samples(), out) + assert path.suffix == ".hdf5" + assert path.exists() + + info = inspect(path) + assert info["format"] == "hdf5" + assert info["episodes"] == 2 + assert info["frames"] == 6 + assert info["fps"] == 20.0 + assert info["robot"] == "xarm7" + assert info["observation"]["state"]["shape"] == [4] + assert info["action"]["action"]["shape"] == [2] + assert info["shapes_uniform"] is True + assert info["has_stats"] is True + assert info["episode_lengths"] == {"min": 3, "max": 3, "mean": 3.0, "uniform": True} + + +def test_hdf5_extension_appended_when_missing(tmp_path): + # path with no suffix → writer appends .hdf5 + out = OutputConfig(format="hdf5", path=tmp_path / "noext") + path = write(_samples(n_episodes=1, n_frames=2), out) + assert path.name == "noext.hdf5" + + +def test_hdf5_stats_values_match(tmp_path): + out = OutputConfig(format="hdf5", path=tmp_path / "s.hdf5") + path = write(_samples(n_episodes=1, n_frames=3), out) + with h5py.File(path, "r") as f: + assert "observation.state" in f["stats"] + # state = [0..3]+i for i in 0,1,2 → per-dim mean = base + mean(0,1,2)=base+1 + mean = 
f["stats"]["observation.state"]["mean"][:] + np.testing.assert_allclose(mean, np.arange(4) + 1.0) diff --git a/dimos/learning/dataprep/formats/test_lerobot.py b/dimos/learning/dataprep/formats/test_lerobot.py new file mode 100644 index 0000000000..f8b1fdd7a3 --- /dev/null +++ b/dimos/learning/dataprep/formats/test_lerobot.py @@ -0,0 +1,111 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Smoke tests for the LeRobot v2 writer/reader. + +The state-only test always runs (parquet + meta + stats, exercises the +try/finally cleanup with no writers open). The image test additionally checks +the MP4 path + canonical `observation.images.*` naming, and skips if no mp4v +codec is available in the environment. Skips entirely if pyarrow isn't +installed (`learning` optional-dependency group). 
+""" + +from __future__ import annotations + +from collections.abc import Iterator +import json + +import numpy as np +import pytest + +pytest.importorskip("pyarrow") +cv2 = pytest.importorskip("cv2") + +from dimos.learning.dataprep.core import OutputConfig, Sample +from dimos.learning.dataprep.formats.lerobot import inspect, write + + +def _state_samples(n: int = 4) -> Iterator[Sample]: + for i in range(n): + yield Sample( + ts=float(i), + episode_id="ep_000000", + observation={"state": np.arange(6, dtype=np.float32)}, + action={"action": np.full(6, float(i), dtype=np.float32)}, + ) + + +def _image_samples(n: int = 4) -> Iterator[Sample]: + for i in range(n): + yield Sample( + ts=float(i), + episode_id="ep_000000", + observation={ + "state": np.arange(6, dtype=np.float32), + "cam": np.full((16, 16, 3), i, dtype=np.uint8), + }, + action={"action": np.zeros(6, dtype=np.float32)}, + ) + + +def test_lerobot_state_only_layout_and_naming(tmp_path): + out = OutputConfig( + format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0, "robot": "xarm7"} + ) + root = write(_state_samples(), out) + + assert (root / "meta" / "info.json").exists() + assert (root / "meta" / "episodes.jsonl").exists() + assert (root / "meta" / "tasks.jsonl").exists() + assert (root / "meta" / "stats.json").exists() + assert (root / "data" / "chunk-000" / "episode_000000.parquet").exists() + + info = json.loads((root / "meta" / "info.json").read_text()) + assert info["total_episodes"] == 1 + assert info["total_frames"] == 4 + assert info["fps"] == 10.0 + # single low-dim state + single action → canonical names + assert "observation.state" in info["features"] + assert "action" in info["features"] + + +def test_lerobot_inspect_state_only(tmp_path): + out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) + root = write(_state_samples(), out) + info = inspect(root) + assert info["format"] == "lerobot" + assert info["episodes"] == 1 + assert info["frames"] == 4 + assert 
"observation.state" in info["observation"] + assert "action" in info["action"] + assert info["has_stats"] is True + + +def test_lerobot_with_images_writes_mp4_and_video_feature(tmp_path): + out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) + try: + root = write(_image_samples(), out) + except RuntimeError as e: + if "VideoWriter" in str(e): + pytest.skip(f"no mp4v encoder available in this environment: {e}") + raise + + mp4 = root / "videos" / "chunk-000" / "observation.images.cam" / "episode_000000.mp4" + assert mp4.exists() and mp4.stat().st_size > 0 + + info = json.loads((root / "meta" / "info.json").read_text()) + assert info["features"]["observation.images.cam"]["dtype"] == "video" + # image column is excluded from parquet; state/action remain + assert "observation.state" in info["features"] + assert info["total_frames"] == 4 diff --git a/dimos/learning/dataprep/formats/test_stats.py b/dimos/learning/dataprep/formats/test_stats.py new file mode 100644 index 0000000000..012d12dd9a --- /dev/null +++ b/dimos/learning/dataprep/formats/test_stats.py @@ -0,0 +1,85 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the streaming feature-stats aggregator (`_stats.py`). + +Pure numpy — no I/O, no optional deps. Verifies the Welford mean/std/min/max, +the low-dim quantile reservoir, and the image per-channel reduction + +subsampling. 
+""" + +from __future__ import annotations + +import numpy as np + +from dimos.learning.dataprep.formats._stats import StreamingStats + + +def test_scalar_mean_std_minmax_count(): + s = StreamingStats() + for v in ([1.0, 10.0], [2.0, 20.0], [3.0, 30.0]): + s.update("state", np.array(v)) + out = s.finalize()["state"] + # population variance (m2 / n): [1,2,3] → 2/3 ; [10,20,30] → 200/3 + np.testing.assert_allclose(out["mean"], [2.0, 20.0]) + np.testing.assert_allclose(out["std"], [np.sqrt(2 / 3), np.sqrt(200 / 3)]) + assert out["min"] == [1.0, 10.0] + assert out["max"] == [3.0, 30.0] + assert out["count"] == 3 + + +def test_single_sample_has_zero_std(): + s = StreamingStats() + s.update("x", np.array([5.0, 7.0])) + out = s.finalize()["x"] + assert out["std"] == [0.0, 0.0] + assert out["count"] == 1 + + +def test_lowdim_quantiles_present_and_bounded(): + s = StreamingStats() + for i in range(100): + s.update("x", np.array([float(i)])) + out = s.finalize()["x"] + assert "q01" in out and "q99" in out + assert out["min"][0] <= out["q01"][0] <= out["q99"][0] <= out["max"][0] + + +def test_image_reduced_to_per_channel_no_quantiles(): + # image_subsample=1 → every frame counts. Constant per-channel values. + s = StreamingStats(image_subsample=1) + img = np.zeros((4, 4, 3), dtype=np.uint8) + img[..., 0], img[..., 1], img[..., 2] = 10, 20, 30 + for _ in range(5): + s.update("cam", img) + out = s.finalize()["cam"] + np.testing.assert_allclose(out["mean"], [10.0, 20.0, 30.0]) + np.testing.assert_allclose(out["min"], [10.0, 20.0, 30.0]) + np.testing.assert_allclose(out["max"], [10.0, 20.0, 30.0]) + assert out["count"] == 5 + # Images skip the quantile reservoir (per-pixel stats would blow up memory). 
+ assert "q01" not in out and "q99" not in out + + +def test_image_subsampling_counts_every_nth_frame(): + s = StreamingStats(image_subsample=10) + img = np.zeros((2, 2, 3), dtype=np.uint8) + for _ in range(25): + s.update("cam", img) + # frames 0, 10, 20 sampled → count 3 + assert s.finalize()["cam"]["count"] == 3 + + +def test_empty_aggregator_finalizes_empty(): + assert StreamingStats().finalize() == {} From e1d0134bc7951bc4d7a285b5a811855dce06aefd Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jun 2026 02:57:19 +0000 Subject: [PATCH 25/45] [autofix.ci] apply automated fixes --- uv.lock | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index e9e2d61c0f..1693240613 100644 --- a/uv.lock +++ b/uv.lock @@ -2038,6 +2038,10 @@ dds = [ drone = [ { name = "pymavlink" }, ] +learning = [ + { name = "h5py" }, + { name = "pyarrow" }, +] manipulation = [ { name = "a750-control", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "drake", version = "1.45.0", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and sys_platform == 'darwin'" }, @@ -2362,6 +2366,7 @@ requires-dist = [ { name = "gdown", marker = "extra == 'misc'", specifier = ">=5.2.2" }, { name = "googlemaps", marker = "extra == 'misc'", specifier = ">=4.10.0" }, { name = "gtsam-extended", marker = "extra == 'mapping'", specifier = ">=4.3a1.post1" }, + { name = "h5py", marker = "extra == 'learning'" }, { name = "hydra-core", marker = "extra == 'perception'", specifier = ">=1.3.0" }, { name = "ipykernel", marker = "extra == 'misc'" }, { name = "jinja2", marker = "extra == 'web'", specifier = ">=3.1.6" }, @@ -2407,6 +2412,7 @@ requires-dist = [ { name = "protobuf", specifier = ">=6.33.5,<7" }, { name = "psutil", specifier = ">=7.0.0" }, { name = "psycopg2-binary", marker = "extra == 
'psql'", specifier = ">=2.9.11" }, + { name = "pyarrow", marker = "extra == 'learning'" }, { name = "pycollada", marker = "extra == 'manipulation'" }, { name = "pydantic" }, { name = "pydantic-settings", specifier = ">=2.11.0,<3" }, @@ -2453,7 +2459,7 @@ requires-dist = [ { name = "xformers", marker = "platform_machine == 'x86_64' and extra == 'cuda'", specifier = ">=0.0.20" }, { name = "yapf", marker = "extra == 'misc'", specifier = "==0.40.2" }, ] -provides-extras = ["misc", "visualization", "agents", "web", "perception", "unitree", "unitree-dds", "manipulation", "cpu", "cuda", "psql", "sim", "mapping", "drone", "dds", "base", "apriltag", "all"] +provides-extras = ["misc", "visualization", "learning", "agents", "web", "perception", "unitree", "unitree-dds", "manipulation", "cpu", "cuda", "psql", "sim", "mapping", "drone", "dds", "base", "apriltag", "all"] [package.metadata.requires-dev] autofix = [{ name = "ruff", specifier = "==0.14.3" }] @@ -3576,6 +3582,65 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" }, ] +[[package]] +name = "h5py" +version = "3.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/db/33/acd0ce6863b6c0d7735007df01815403f5589a21ff8c2e1ee2587a38f548/h5py-3.16.0.tar.gz", hash = "sha256:a0dbaad796840ccaa67a4c144a0d0c8080073c34c76d5a6941d6818678ef2738", size = 446526, upload-time = "2026-03-06T13:49:08.07Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/3a/6b/231413e58a787a89b316bb0d1777da3c62257e4797e09afd8d17ad3549dc/h5py-3.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e06f864bedb2c8e7c1358e6c73af48519e317457c444d6f3d332bb4e8fa6d7d9", size = 3724137, upload-time = "2026-03-06T13:47:35.242Z" }, + { url = "https://files.pythonhosted.org/packages/74/f9/557ce3aad0fe8471fb5279bab0fc56ea473858a022c4ce8a0b8f303d64e9/h5py-3.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ec86d4fffd87a0f4cb3d5796ceb5a50123a2a6d99b43e616e5504e66a953eca3", size = 3090112, upload-time = "2026-03-06T13:47:37.634Z" }, + { url = "https://files.pythonhosted.org/packages/7a/f5/e15b3d0dc8a18e56409a839e6468d6fb589bc5207c917399c2e0706eeb44/h5py-3.16.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:86385ea895508220b8a7e45efa428aeafaa586bd737c7af9ee04661d8d84a10d", size = 4844847, upload-time = "2026-03-06T13:47:39.811Z" }, + { url = "https://files.pythonhosted.org/packages/cb/92/a8851d936547efe30cc0ce5245feac01f3ec6171f7899bc3f775c72030b3/h5py-3.16.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:8975273c2c5921c25700193b408e28d6bdd0111c37468b2d4e25dcec4cd1d84d", size = 5065352, upload-time = "2026-03-06T13:47:41.489Z" }, + { url = "https://files.pythonhosted.org/packages/2b/ae/f2adc5d0ca9626db3277a3d87516e124cbc5d0eea0bd79bc085702d04f2c/h5py-3.16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1677ad48b703f44efc9ea0c3ab284527f81bc4f318386aaaebc5fede6bbae56f", size = 4839173, upload-time = "2026-03-06T13:47:43.586Z" }, + { url = "https://files.pythonhosted.org/packages/64/0b/e0c8c69da1d8838da023a50cd3080eae5d475691f7636b35eff20bb6ef20/h5py-3.16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7c4dd4cf5f0a4e36083f73172f6cfc25a5710789269547f132a20975bfe2434c", size = 5076216, upload-time = "2026-03-06T13:47:45.315Z" }, + { url = 
"https://files.pythonhosted.org/packages/66/35/d88fd6718832133c885004c61ceeeb24dbd6397ef877dbed6b3a64d6a286/h5py-3.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:bdef06507725b455fccba9c16529121a5e1fbf56aa375f7d9713d9e8ff42454d", size = 3183639, upload-time = "2026-03-06T13:47:47.041Z" }, + { url = "https://files.pythonhosted.org/packages/ba/95/a825894f3e45cbac7554c4e97314ce886b233a20033787eda755ca8fecc7/h5py-3.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:719439d14b83f74eeb080e9650a6c7aa6d0d9ea0ca7f804347b05fac6fbf18af", size = 3721663, upload-time = "2026-03-06T13:47:49.599Z" }, + { url = "https://files.pythonhosted.org/packages/bf/3b/38ff88b347c3e346cda1d3fc1b65a7aa75d40632228d8b8a5d7b58508c24/h5py-3.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c3f0a0e136f2e95dd0b67146abb6668af4f1a69c81ef8651a2d316e8e01de447", size = 3087630, upload-time = "2026-03-06T13:47:51.249Z" }, + { url = "https://files.pythonhosted.org/packages/98/a8/2594cef906aee761601eff842c7dc598bea2b394a3e1c00966832b8eeb7c/h5py-3.16.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a6fbc5367d4046801f9b7db9191b31895f22f1c6df1f9987d667854cac493538", size = 4823472, upload-time = "2026-03-06T13:47:53.085Z" }, + { url = "https://files.pythonhosted.org/packages/52/a0/c1f604538ff6db22a0690be2dc44ab59178e115f63c917794e529356ab23/h5py-3.16.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:fb1720028d99040792bb2fb31facb8da44a6f29df7697e0b84f0d79aff2e9bd3", size = 5027150, upload-time = "2026-03-06T13:47:55.043Z" }, + { url = "https://files.pythonhosted.org/packages/2e/fd/301739083c2fc4fd89950f9bcfce75d6e14b40b0ca3d40e48a8993d1722c/h5py-3.16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:314b6054fe0b1051c2b0cb2df5cbdab15622fb05e80f202e3b6a5eee0d6fe365", size = 4814544, upload-time = "2026-03-06T13:47:56.893Z" }, + { url = 
"https://files.pythonhosted.org/packages/4c/42/2193ed41ccee78baba8fcc0cff2c925b8b9ee3793305b23e1f22c20bf4c7/h5py-3.16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ffbab2fedd6581f6aa31cf1639ca2cb86e02779de525667892ebf4cc9fd26434", size = 5034013, upload-time = "2026-03-06T13:47:59.01Z" }, + { url = "https://files.pythonhosted.org/packages/f7/20/e6c0ff62ca2ad1a396a34f4380bafccaaf8791ff8fccf3d995a1fc12d417/h5py-3.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:17d1f1630f92ad74494a9a7392ab25982ce2b469fc62da6074c0ce48366a2999", size = 3191673, upload-time = "2026-03-06T13:48:00.626Z" }, + { url = "https://files.pythonhosted.org/packages/f2/48/239cbe352ac4f2b8243a8e620fa1a2034635f633731493a7ff1ed71e8658/h5py-3.16.0-cp311-cp311-win_arm64.whl", hash = "sha256:85b9c49dd58dc44cf70af944784e2c2038b6f799665d0dcbbc812a26e0faa859", size = 2673834, upload-time = "2026-03-06T13:48:02.579Z" }, + { url = "https://files.pythonhosted.org/packages/c8/c0/5d4119dba94093bbafede500d3defd2f5eab7897732998c04b54021e530b/h5py-3.16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c5313566f4643121a78503a473f0fb1e6dcc541d5115c44f05e037609c565c4d", size = 3685604, upload-time = "2026-03-06T13:48:04.198Z" }, + { url = "https://files.pythonhosted.org/packages/b0/42/c84efcc1d4caebafb1ecd8be4643f39c85c47a80fe254d92b8b43b1eadaf/h5py-3.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:42b012933a83e1a558c673176676a10ce2fd3759976a0fedee1e672d1e04fc9d", size = 3061940, upload-time = "2026-03-06T13:48:05.783Z" }, + { url = "https://files.pythonhosted.org/packages/89/84/06281c82d4d1686fde1ac6b0f307c50918f1c0151062445ab3b6fa5a921d/h5py-3.16.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:ff24039e2573297787c3063df64b60aab0591980ac898329a08b0320e0cf2527", size = 5198852, upload-time = "2026-03-06T13:48:07.482Z" }, + { url = 
"https://files.pythonhosted.org/packages/9e/e9/1a19e42cd43cc1365e127db6aae85e1c671da1d9a5d746f4d34a50edb577/h5py-3.16.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:dfc21898ff025f1e8e67e194965a95a8d4754f452f83454538f98f8a3fcb207e", size = 5405250, upload-time = "2026-03-06T13:48:09.628Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8e/9790c1655eabeb85b92b1ecab7d7e62a2069e53baefd58c98f0909c7a948/h5py-3.16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:698dd69291272642ffda44a0ecd6cd3bda5faf9621452d255f57ce91487b9794", size = 5190108, upload-time = "2026-03-06T13:48:11.26Z" }, + { url = "https://files.pythonhosted.org/packages/51/d7/ab693274f1bd7e8c5f9fdd6c7003a88d59bedeaf8752716a55f532924fbb/h5py-3.16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2b2c02b0a160faed5fb33f1ba8a264a37ee240b22e049ecc827345d0d9043074", size = 5419216, upload-time = "2026-03-06T13:48:13.322Z" }, + { url = "https://files.pythonhosted.org/packages/03/c1/0976b235cf29ead553e22f2fb6385a8252b533715e00d0ae52ed7b900582/h5py-3.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:96b422019a1c8975c2d5dadcf61d4ba6f01c31f92bbde6e4649607885fe502d6", size = 3182868, upload-time = "2026-03-06T13:48:15.759Z" }, + { url = "https://files.pythonhosted.org/packages/14/d9/866b7e570b39070f92d47b0ff1800f0f8239b6f9e45f02363d7112336c1f/h5py-3.16.0-cp312-cp312-win_arm64.whl", hash = "sha256:39c2838fb1e8d97bcf1755e60ad1f3dd76a7b2a475928dc321672752678b96db", size = 2653286, upload-time = "2026-03-06T13:48:17.279Z" }, + { url = "https://files.pythonhosted.org/packages/0f/9e/6142ebfda0cb6e9349c091eae73c2e01a770b7659255248d637bec54a88b/h5py-3.16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:370a845f432c2c9619db8eed334d1e610c6015796122b0e57aa46312c22617d9", size = 3671808, upload-time = "2026-03-06T13:48:19.737Z" }, + { url = "https://files.pythonhosted.org/packages/b0/65/5e088a45d0f43cd814bc5bec521c051d42005a472e804b1a36c48dada09b/h5py-3.16.0-cp313-cp313-macosx_11_0_arm64.whl", 
hash = "sha256:42108e93326c50c2810025aade9eac9d6827524cdccc7d4b75a546e5ab308edb", size = 3045837, upload-time = "2026-03-06T13:48:21.854Z" }, + { url = "https://files.pythonhosted.org/packages/da/1e/6172269e18cc5a484e2913ced33339aad588e02ba407fafd00d369e22ef3/h5py-3.16.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:099f2525c9dcf28de366970a5fb34879aab20491589fa89ce2863a84218bb524", size = 5193860, upload-time = "2026-03-06T13:48:24.071Z" }, + { url = "https://files.pythonhosted.org/packages/bd/98/ef2b6fe2903e377cbe870c3b2800d62552f1e3dbe81ce49e1923c53d1c5c/h5py-3.16.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9300ad32dea9dfc5171f94d5f6948e159ed93e4701280b0f508773b3f582f402", size = 5400417, upload-time = "2026-03-06T13:48:25.728Z" }, + { url = "https://files.pythonhosted.org/packages/bc/81/5b62d760039eed64348c98129d17061fdfc7839fc9c04eaaad6dee1004e4/h5py-3.16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:171038f23bccddfc23f344cadabdfc9917ff554db6a0d417180d2747fe4c75a7", size = 5185214, upload-time = "2026-03-06T13:48:27.436Z" }, + { url = "https://files.pythonhosted.org/packages/28/c4/532123bcd9080e250696779c927f2cb906c8bf3447df98f5ceb8dcded539/h5py-3.16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7e420b539fb6023a259a1b14d4c9f6df8cf50d7268f48e161169987a57b737ff", size = 5414598, upload-time = "2026-03-06T13:48:29.49Z" }, + { url = "https://files.pythonhosted.org/packages/c3/d9/a27997f84341fc0dfcdd1fe4179b6ba6c32a7aa880fdb8c514d4dad6fba3/h5py-3.16.0-cp313-cp313-win_amd64.whl", hash = "sha256:18f2bbcd545e6991412253b98727374c356d67caa920e68dc79eab36bf5fedad", size = 3175509, upload-time = "2026-03-06T13:48:31.131Z" }, + { url = "https://files.pythonhosted.org/packages/a5/23/bb8647521d4fd770c30a76cfc6cb6a2f5495868904054e92f2394c5a78ff/h5py-3.16.0-cp313-cp313-win_arm64.whl", hash = "sha256:656f00e4d903199a1d58df06b711cf3ca632b874b4207b7dbec86185b5c8c7d4", size = 2647362, upload-time = "2026-03-06T13:48:33.411Z" }, + { url = 
"https://files.pythonhosted.org/packages/48/3c/7fcd9b4c9eed82e91fb15568992561019ae7a829d1f696b2c844355d95dd/h5py-3.16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9c9d307c0ef862d1cd5714f72ecfafe0a5d7529c44845afa8de9f46e5ba8bd65", size = 3678608, upload-time = "2026-03-06T13:48:35.183Z" }, + { url = "https://files.pythonhosted.org/packages/6a/b7/9366ed44ced9b7ef357ab48c94205280276db9d7f064aa3012a97227e966/h5py-3.16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8c1eff849cdd53cbc73c214c30ebdb6f1bb8b64790b4b4fc36acdb5e43570210", size = 3054773, upload-time = "2026-03-06T13:48:37.139Z" }, + { url = "https://files.pythonhosted.org/packages/58/a5/4964bc0e91e86340c2bbda83420225b2f770dcf1eb8a39464871ad769436/h5py-3.16.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:e2c04d129f180019e216ee5f9c40b78a418634091c8782e1f723a6ca3658b965", size = 5198886, upload-time = "2026-03-06T13:48:38.879Z" }, + { url = "https://files.pythonhosted.org/packages/f1/16/d905e7f53e661ce2c24686c38048d8e2b750ffc4350009d41c4e6c6c9826/h5py-3.16.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:e4360f15875a532bc7b98196c7592ed4fc92672a57c0a621355961cafb17a6dd", size = 5404883, upload-time = "2026-03-06T13:48:41.324Z" }, + { url = "https://files.pythonhosted.org/packages/4b/f2/58f34cb74af46d39f4cd18ea20909a8514960c5a3e5b92fd06a28161e0a8/h5py-3.16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:3fae9197390c325e62e0a1aa977f2f62d994aa87aab182abbea85479b791197c", size = 5192039, upload-time = "2026-03-06T13:48:43.117Z" }, + { url = "https://files.pythonhosted.org/packages/ce/ca/934a39c24ce2e2db017268c08da0537c20fa0be7e1549be3e977313fc8f5/h5py-3.16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:43259303989ac8adacc9986695b31e35dba6fd1e297ff9c6a04b7da5542139cc", size = 5421526, upload-time = "2026-03-06T13:48:44.838Z" }, + { url = 
"https://files.pythonhosted.org/packages/3e/14/615a450205e1b56d16c6783f5ccd116cde05550faad70ae077c955654a75/h5py-3.16.0-cp314-cp314-win_amd64.whl", hash = "sha256:fa48993a0b799737ba7fd21e2350fa0a60701e58180fae9f2de834bc39a147ab", size = 3183263, upload-time = "2026-03-06T13:48:47.117Z" }, + { url = "https://files.pythonhosted.org/packages/7b/48/a6faef5ed632cae0c65ac6b214a6614a0b510c3183532c521bdb0055e117/h5py-3.16.0-cp314-cp314-win_arm64.whl", hash = "sha256:1897a771a7f40d05c262fc8f37376ec37873218544b70216872876c627640f63", size = 2663450, upload-time = "2026-03-06T13:48:48.707Z" }, + { url = "https://files.pythonhosted.org/packages/5d/32/0c8bb8aedb62c772cf7c1d427c7d1951477e8c2835f872bc0a13d1f85f86/h5py-3.16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:15922e485844f77c0b9d275396d435db3baa58292a9c2176a386e072e0cf2491", size = 3760693, upload-time = "2026-03-06T13:48:50.453Z" }, + { url = "https://files.pythonhosted.org/packages/1d/1f/fcc5977d32d6387c5c9a694afee716a5e20658ac08b3ff24fdec79fb05f2/h5py-3.16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:df02dd29bd247f98674634dfe41f89fd7c16ba3d7de8695ec958f58404a4e618", size = 3181305, upload-time = "2026-03-06T13:48:52.221Z" }, + { url = "https://files.pythonhosted.org/packages/f5/a1/af87f64b9f986889884243643621ebbd4ac72472ba8ec8cec891ac8e2ca1/h5py-3.16.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:0f456f556e4e2cebeebd9d66adf8dc321770a42593494a0b6f0af54a7567b242", size = 5074061, upload-time = "2026-03-06T13:48:54.089Z" }, + { url = "https://files.pythonhosted.org/packages/cc/d0/146f5eaff3dc246a9c7f6e5e4f42bd45cc613bce16693bcd4d1f7c958bf5/h5py-3.16.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:3e6cb3387c756de6a9492d601553dffea3fe11b5f22b443aac708c69f3f55e16", size = 5279216, upload-time = "2026-03-06T13:48:56.75Z" }, + { url = 
"https://files.pythonhosted.org/packages/a1/9d/12a13424f1e604fc7df9497b73c0356fb78c2fb206abd7465ce47226e8fd/h5py-3.16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:8389e13a1fd745ad2856873e8187fd10268b2d9677877bb667b41aebd771d8b7", size = 5070068, upload-time = "2026-03-06T13:48:59.169Z" }, + { url = "https://files.pythonhosted.org/packages/41/8c/bbe98f813722b4873818a8db3e15aa3e625b59278566905ac439725e8070/h5py-3.16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:346df559a0f7dcb31cf8e44805319e2ab24b8957c45e7708ce503b2ec79ba725", size = 5300253, upload-time = "2026-03-06T13:49:02.033Z" }, + { url = "https://files.pythonhosted.org/packages/32/9e/87e6705b4d6890e7cecdf876e2a7d3e40654a2ae37482d79a6f1b87f7b92/h5py-3.16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:4c6ab014ab704b4feaa719ae783b86522ed0bf1f82184704ed3c9e4e3228796e", size = 3381671, upload-time = "2026-03-06T13:49:04.351Z" }, + { url = "https://files.pythonhosted.org/packages/96/91/9fad90cfc5f9b2489c7c26ad897157bce82f0e9534a986a221b99760b23b/h5py-3.16.0-cp314-cp314t-win_arm64.whl", hash = "sha256:faca8fb4e4319c09d83337adc80b2ca7d5c5a343c2d6f1b6388f32cfecca13c1", size = 2740706, upload-time = "2026-03-06T13:49:06.347Z" }, +] + [[package]] name = "hf-xet" version = "1.2.0" From 819febd915e9b24676c764ea047a6fce554b08eb Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Wed, 17 Jun 2026 20:29:55 -0700 Subject: [PATCH 26/45] revert: jpeg debug --- dimos/memory2/codecs/jpeg.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/dimos/memory2/codecs/jpeg.py b/dimos/memory2/codecs/jpeg.py index 80f22696a2..3d854400b1 100644 --- a/dimos/memory2/codecs/jpeg.py +++ b/dimos/memory2/codecs/jpeg.py @@ -36,9 +36,4 @@ def encode(self, value: Image) -> bytes: def decode(self, data: bytes) -> Image: from dimos.msgs.sensor_msgs.Image import Image - # Some recordings include a 1-byte format tag before the LCM envelope - # (b'J' for JPEG-encoded Image). 
Strip it on read so old + new sessions - # both decode cleanly. - if data and data[0:1] == b"J": - data = data[1:] return Image.lcm_jpeg_decode(data) From b3e8d822ab94bc3accb8e4074a1cb8dd665d690c Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Wed, 17 Jun 2026 23:41:21 -0700 Subject: [PATCH 27/45] fix: mypy issues --- dimos/learning/dataprep/cli.py | 9 +++++++-- dimos/learning/dataprep/core.py | 9 ++++++--- dimos/learning/dataprep/formats/_stats.py | 1 + dimos/learning/dataprep/formats/hdf5.py | 4 ++-- dimos/learning/dataprep/formats/lerobot.py | 2 +- pyproject.toml | 4 ++++ 6 files changed, 21 insertions(+), 8 deletions(-) diff --git a/dimos/learning/dataprep/cli.py b/dimos/learning/dataprep/cli.py index 434055a922..654e31fc56 100644 --- a/dimos/learning/dataprep/cli.py +++ b/dimos/learning/dataprep/cli.py @@ -27,7 +27,7 @@ import json from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, cast import typer @@ -53,8 +53,13 @@ def _load_config( if source is not None: updates["source"] = str(source) if output is not None or output_format is not None: + fmt = ( + cast("Literal['lerobot', 'hdf5']", output_format) + if output_format + else cfg.output.format + ) updates["output"] = OutputConfig( - format=output_format or cfg.output.format, + format=fmt, path=output or cfg.output.path, metadata=cfg.output.metadata, ) diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index 0a373c5770..544dff3b7e 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -40,6 +40,7 @@ if TYPE_CHECKING: from dimos.memory2.store.sqlite import SqliteStore + from dimos.memory2.stream import Stream # ───────────────────────────────────────────────────────────────────────────── @@ -90,7 +91,7 @@ class DataPrepConfig(BaseConfig): observation: dict[str, StreamField] = {} action: dict[str, StreamField] = {} sync: SyncConfig = SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0) - 
output: OutputConfig = OutputConfig(format="lerobot", path="data/datasets/default") + output: OutputConfig = OutputConfig(format="lerobot", path=Path("data/datasets/default")) # ───────────────────────────────────────────────────────────────────────────── @@ -167,7 +168,7 @@ def extract_episodes(store: SqliteStore, cfg: EpisodeExtractor) -> list[Episode] ] # episode_status (default) - status_stream = store.stream(cfg.status_stream) + status_stream: Stream[Any, Any] = store.stream(cfg.status_stream) events = list(status_stream) # observations in storage order episodes: list[Episode] = [] @@ -248,7 +249,9 @@ def iter_episode_samples( # Materialize each stream's (timestamps, messages) once per episode. cached: dict[str, tuple[list[float], list[Any]]] = {} for key, ref in streams.items(): - sub = store.stream(ref.stream).time_range(episode.start_ts, episode.end_ts) + sub: Stream[Any, Any] = store.stream(ref.stream).time_range( + episode.start_ts, episode.end_ts + ) ts_list: list[float] = [] msg_list: list[Any] = [] for obs in sub: diff --git a/dimos/learning/dataprep/formats/_stats.py b/dimos/learning/dataprep/formats/_stats.py index 8d10ebdf35..0793b2f1cb 100644 --- a/dimos/learning/dataprep/formats/_stats.py +++ b/dimos/learning/dataprep/formats/_stats.py @@ -102,6 +102,7 @@ def finalize(self) -> dict[str, dict[str, Any]]: if agg.mean is None: continue n = max(1, agg.n) + assert agg.m2 is not None var = agg.m2 / n if agg.n > 1 else np.zeros_like(agg.mean) std = np.sqrt(var) entry: dict[str, Any] = { diff --git a/dimos/learning/dataprep/formats/hdf5.py b/dimos/learning/dataprep/formats/hdf5.py index e81de0fbd3..e752febdca 100644 --- a/dimos/learning/dataprep/formats/hdf5.py +++ b/dimos/learning/dataprep/formats/hdf5.py @@ -174,8 +174,8 @@ def inspect(path: Path) -> dict[str, Any]: # Are per-frame shapes consistent across every episode? 
shapes_uniform = True - for e in ep_names[1:]: - g = eps_g[e] + for ep_name in ep_names[1:]: + g = eps_g[ep_name] for grp, ref in (("observation", observation), ("action", action)): if grp in g: for k, d in g[grp].items(): diff --git a/dimos/learning/dataprep/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py index 691a158502..983c5efd46 100644 --- a/dimos/learning/dataprep/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -91,7 +91,7 @@ def write(samples: Iterator[Sample], output: OutputConfig) -> Path: (root / VIDEO_DIR / CHUNK).mkdir(parents=True, exist_ok=True) fps = float(output.metadata.get("fps", 30.0)) - fourcc = cv2.VideoWriter_fourcc(*"mp4v") + fourcc = cv2.VideoWriter.fourcc(*"mp4v") stats = StreamingStats( image_subsample=int(output.metadata.get("image_subsample", 10)), diff --git a/pyproject.toml b/pyproject.toml index b62df6531e..b669721907 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -524,6 +524,8 @@ module = [ "etils", "faster_whisper", "geometry_msgs.*", + "h5py", + "h5py.*", "lazy_loader", "mcap", "mcap.*", @@ -541,6 +543,8 @@ module = [ "pycuda.*", "pydrake", "pydrake.*", + "pyarrow", + "pyarrow.*", "pyzed", "pyzed.*", "rclpy.*", From b3f234f9148325567929d35077ebe567a1f1a16f Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 12:39:27 -0700 Subject: [PATCH 28/45] misc: test fixes + module list --- dimos/memory2/store/sqlite.py | 4 ---- dimos/robot/all_blueprints.py | 1 + 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/dimos/memory2/store/sqlite.py b/dimos/memory2/store/sqlite.py index 229961a126..bb2b735c1c 100644 --- a/dimos/memory2/store/sqlite.py +++ b/dimos/memory2/store/sqlite.py @@ -54,10 +54,6 @@ def __init__(self, **kwargs: Any) -> None: raise FileNotFoundError( f"SQLite database not found: {os.path.abspath(self.config.path)}" ) - if not self.config.must_exist: - parent = os.path.dirname(self.config.path) - if parent: - os.makedirs(parent, exist_ok=True) 
self._registry_conn = self._open_connection() self._registry = RegistryStore(conn=self._registry_conn) diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 0a0b7383a9..053082935e 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -136,6 +136,7 @@ "camera-module": "dimos.hardware.sensors.camera.module.CameraModule", "cartesian-motion-controller": "dimos.manipulation.control.servo_control.cartesian_motion_controller.CartesianMotionController", "click-start-goal-router": "dimos.navigation.nav_stack.modules.click_start_goal_router.click_start_goal_router.ClickStartGoalRouter", + "collection-recorder": "dimos.learning.collection.recorder.CollectionRecorder", "control-coordinator": "dimos.control.coordinator.ControlCoordinator", "cost-mapper": "dimos.mapping.costmapper.CostMapper", "demo-calculator-skill": "dimos.agents.skills.demo_calculator_skill.DemoCalculatorSkill", From b4e84ab1d3d86d10ae307729f35da32fafde8607 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 13:07:38 -0700 Subject: [PATCH 29/45] fix: greptile comments --- dimos/learning/collection/blueprint.py | 9 +++-- dimos/learning/dataprep/core.py | 45 +++++++++++----------- dimos/learning/dataprep/formats/hdf5.py | 13 ++++--- dimos/learning/dataprep/formats/lerobot.py | 11 +++--- 4 files changed, 41 insertions(+), 37 deletions(-) diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index cd1999abd4..e920bfbe2b 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -40,7 +40,10 @@ ) from dimos.teleop.quest.quest_types import Buttons -_SESSION_DB = f"data/recordings/session_{datetime.now():%Y%m%d_%H%M%S}.db" + +def _session_db(robot: str) -> str: + """Timestamped session DB path, namespaced by robot.""" + return f"data/recordings/session_{robot}_{datetime.now():%Y%m%d_%H%M%S}.db" def _camera_if_real() -> tuple[Blueprint, ...]: @@ -59,7 +62,7 @@ def 
_camera_if_real() -> tuple[Blueprint, ...]: teleop_quest_xarm7, *_camera_if_real(), EpisodeMonitorModule.blueprint(), # default button_map: toggle=B, discard=Y - CollectionRecorder.blueprint(db_path=_SESSION_DB), + CollectionRecorder.blueprint(db_path=_session_db("xarm7")), ).transports( { ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), @@ -74,7 +77,7 @@ def _camera_if_real() -> tuple[Blueprint, ...]: teleop_quest_piper, *_camera_if_real(), EpisodeMonitorModule.blueprint(), # default button_map: toggle=B, discard=Y - CollectionRecorder.blueprint(db_path=_SESSION_DB), + CollectionRecorder.blueprint(db_path=_session_db("piper")), ).transports( { ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index 544dff3b7e..6119ee489b 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -301,35 +301,34 @@ def _nearest(key: str, t: float) -> Any | None: return None return msg_list[best] - # Buffer the episode in order; the action shift below pairs obs[i] with - # action[i + shift]. 
- frames: list[Sample] = [] - for t in targets: - obs_dict: dict[str, np.ndarray] = {} - act_dict: dict[str, np.ndarray] = {} - skip = False - for key, ref in streams.items(): - msg = _nearest(key, t) - if msg is None: - skip = True - break - arr = resolve_field(msg, ref) - if arr.ndim < 3: - arr = arr.astype(np.float32, copy=False) - if key in action_keys: - act_dict[key] = arr - elif key in obs_keys: - obs_dict[key] = arr - if skip: - continue - frames.append(Sample(ts=t, episode_id=episode.id, observation=obs_dict, action=act_dict)) + def _build_frames() -> Iterator[Sample]: + for t in targets: + obs_dict: dict[str, np.ndarray] = {} + act_dict: dict[str, np.ndarray] = {} + skip = False + for key, ref in streams.items(): + msg = _nearest(key, t) + if msg is None: + skip = True + break + arr = resolve_field(msg, ref) + if arr.ndim < 3: + arr = arr.astype(np.float32, copy=False) + if key in action_keys: + act_dict[key] = arr + elif key in obs_keys: + obs_dict[key] = arr + if skip: + continue + yield Sample(ts=t, episode_id=episode.id, observation=obs_dict, action=act_dict) shift = max(0, sync.action_shift) if shift == 0 or not action_keys: - yield from frames + yield from _build_frames() return # frame i keeps its obs but takes frame i+shift's action; tail dropped. + frames = list(_build_frames()) for i in range(len(frames) - shift): cur = frames[i] nxt = frames[i + shift] diff --git a/dimos/learning/dataprep/formats/hdf5.py b/dimos/learning/dataprep/formats/hdf5.py index e752febdca..9b31b52a09 100644 --- a/dimos/learning/dataprep/formats/hdf5.py +++ b/dimos/learning/dataprep/formats/hdf5.py @@ -68,7 +68,7 @@ def write(samples: Iterator[Sample], output: OutputConfig) -> Path: # Per-episode buffers — flushed at episode boundary. 
cur_id: str | None = None - cur_idx = -1 + cur_idx = 0 cur_start_ts: float | None = None buf_ts: list[float] = [] buf_obs: dict[str, list[np.ndarray]] = {} @@ -79,9 +79,9 @@ def write(samples: Iterator[Sample], output: OutputConfig) -> Path: with h5py.File(out, "w") as h5: episodes_g = h5.create_group("episodes") - def _flush() -> None: - if cur_idx < 0 or not buf_ts: - return + def _flush() -> bool: + if not buf_ts: + return False ep = episodes_g.create_group(f"episode_{cur_idx:06d}") ep.attrs["length"] = len(buf_ts) ep.attrs["start_ts"] = float(cur_start_ts or 0.0) @@ -100,12 +100,13 @@ def _flush() -> None: buf_ts.clear() buf_obs.clear() buf_act.clear() + return True for sample in samples: if sample.episode_id != cur_id: - _flush() + if _flush(): + cur_idx += 1 cur_id = sample.episode_id - cur_idx += 1 cur_start_ts = float(sample.ts) if default_task_label not in tasks_index: tasks_index[default_task_label] = len(tasks_index) diff --git a/dimos/learning/dataprep/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py index 983c5efd46..8cf4c49711 100644 --- a/dimos/learning/dataprep/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -110,7 +110,7 @@ def write(samples: Iterator[Sample], output: OutputConfig) -> Path: default_task_label = output.metadata.get("default_task_label", "task") current_episode_id: str | None = None - current_episode_index = -1 + current_episode_index = 0 current_frames: list[dict[str, Any]] = [] current_video_writers: dict[str, Any] = {} global_index = 0 @@ -133,10 +133,10 @@ def _open_video(image_key: str, ep_idx: int, frame: np.ndarray) -> Any: raise RuntimeError(f"Failed to open VideoWriter for {path}") return writer - def _flush_episode() -> None: + def _flush_episode() -> bool: nonlocal current_frames, current_video_writers, current_episode_index if not current_frames: - return + return False for vw in current_video_writers.values(): vw.release() current_video_writers = {} @@ -173,13 +173,14 @@ def 
_flush_episode() -> None: } ) current_frames = [] + return True try: for sample in samples: if sample.episode_id != current_episode_id: - _flush_episode() + if _flush_episode(): + current_episode_index += 1 current_episode_id = sample.episode_id - current_episode_index += 1 label = default_task_label if label not in tasks_index: tasks_index[label] = len(tasks_index) From 621289a5b203282a7f3ad489947e00cad5504805 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 13:39:20 -0700 Subject: [PATCH 30/45] fix: dataprep fps sync + episode index/leak/lock fixes, monitor tests --- dimos/learning/collection/episode_monitor.py | 42 +++-- .../collection/test_episode_monitor.py | 145 ++++++++++++++++++ dimos/learning/dataprep/build.py | 17 +- dimos/learning/dataprep/core.py | 6 +- 4 files changed, 188 insertions(+), 22 deletions(-) create mode 100644 dimos/learning/collection/test_episode_monitor.py diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index fc8eebec3a..c2a96b3ef0 100644 --- a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -28,6 +28,7 @@ from typing import Any, Literal from pydantic import BaseModel +from reactivex.disposable import Disposable from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig @@ -98,11 +99,14 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self.buttons.subscribe(self._on_buttons) - self.keyboard.subscribe(self._on_keyboard) + # Registered so the base Module.stop() disposes them on shutdown. + self.register_disposable(Disposable(self.buttons.subscribe(self._on_buttons))) + self.register_disposable(Disposable(self.keyboard.subscribe(self._on_keyboard))) # Emit an initial idle status so subscribers (and recorders) have a # known starting point in the timeline. 
- self._publish("init") + with self._lock: + status = self._snapshot("init") + self._emit(status) @rpc def stop(self) -> None: @@ -116,7 +120,8 @@ def reset_counters(self) -> EpisodeStatus: self._discarded = 0 self._current_start_ts = None self._prev_bits = {} - return self._publish("init") + status = self._snapshot("init") + return self._emit(status) @rpc def get_status(self) -> EpisodeStatus: @@ -185,19 +190,24 @@ def _transition(self, event: Literal["start", "save", "discard", "toggle"], ts: self._discarded += 1 self._state = "idle" self._current_start_ts = None - self._publish(event) + # Snapshot under the mutation's lock so the event matches the state. + status = self._snapshot(event) + self._emit(status) + + def _snapshot(self, last_event: Literal["start", "save", "discard", "init"]) -> EpisodeStatus: + """Build a status from current state. Caller must hold `self._lock`.""" + self._last_event = last_event + return EpisodeStatus( + state=self._state, + episodes_saved=self._saved, + episodes_discarded=self._discarded, + current_episode_start_ts=self._current_start_ts, + last_event=last_event, + task_label=self.config.default_task_label, + ) - def _publish(self, last_event: Literal["start", "save", "discard", "init"]) -> EpisodeStatus: - with self._lock: - self._last_event = last_event - status = EpisodeStatus( - state=self._state, - episodes_saved=self._saved, - episodes_discarded=self._discarded, - current_episode_start_ts=self._current_start_ts, - last_event=last_event, - task_label=self.config.default_task_label, - ) + def _emit(self, status: EpisodeStatus) -> EpisodeStatus: + """Publish + log a snapshot. 
Must run outside the lock (does I/O).""" self.status.publish(status) self._log_status(status) return status diff --git a/dimos/learning/collection/test_episode_monitor.py b/dimos/learning/collection/test_episode_monitor.py new file mode 100644 index 0000000000..0a947ba311 --- /dev/null +++ b/dimos/learning/collection/test_episode_monitor.py @@ -0,0 +1,145 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the EpisodeMonitor state machine. + +Drives the button/keyboard handlers directly and captures published +EpisodeStatus events via a stubbed `status` Out port. The module is built with +`object.__new__` + the subclass' own state init so the test exercises just the +state machine, not the RPC/transport machinery a full `Module()` boots up. +Mirrors the offline `extract_episodes` state machine these events feed. 
+""" + +from __future__ import annotations + +import threading + +from dimos.learning.collection.episode_monitor import ( + EpisodeMonitorModule, + EpisodeMonitorModuleConfig, + EpisodeStatus, + KeyPress, +) +from dimos.teleop.quest.quest_types import Buttons + + +class _CaptureOut: + """Stand-in for the `status` Out port that records published events.""" + + def __init__(self) -> None: + self.events: list[EpisodeStatus] = [] + + def publish(self, status: EpisodeStatus) -> None: + self.events.append(status) + + +def _monitor(**config: object) -> tuple[EpisodeMonitorModule, _CaptureOut]: + m = EpisodeMonitorModule.__new__(EpisodeMonitorModule) + m.config = EpisodeMonitorModuleConfig(**config) # type: ignore[assignment] + m._state = "idle" + m._saved = 0 + m._discarded = 0 + m._current_start_ts = None + m._last_event = "init" + m._prev_bits = {} + m._lock = threading.Lock() + out = _CaptureOut() + m.status = out # type: ignore[assignment] + return m, out + + +def _press(monitor: EpisodeMonitorModule, alias: str, ts: float) -> None: + """Rising edge: release-then-press the given Quest button alias.""" + from dimos.learning.collection.episode_monitor import BUTTON_ALIASES + + attr = BUTTON_ALIASES[alias] + released = Buttons() + pressed = Buttons() + setattr(pressed, attr, True) + monitor._on_buttons(released) + monitor._on_buttons(pressed) + + +def test_toggle_starts_then_saves() -> None: + m, out = _monitor() # default map: toggle=B, discard=Y + _press(m, "B", ts=1.0) # idle → recording + _press(m, "B", ts=2.0) # recording → idle (saved) + + events = [e.last_event for e in out.events] + assert events == ["start", "save"] + assert out.events[-1].state == "idle" + assert out.events[-1].episodes_saved == 1 + assert out.events[-1].episodes_discarded == 0 + + +def test_discard_does_not_count_as_saved() -> None: + m, out = _monitor() + _press(m, "B", ts=1.0) # start + _press(m, "Y", ts=2.0) # discard + + assert out.events[-1].state == "idle" + assert 
out.events[-1].episodes_saved == 0 + assert out.events[-1].episodes_discarded == 1 + + +def test_start_while_recording_autocommits_previous() -> None: + # toggle (start), then an explicit start via keyboard while still recording: + # the in-progress episode auto-commits (matches the offline extractor). + m, out = _monitor(keyboard_map={"start": "r"}) + _press(m, "B", ts=1.0) # recording + m._on_keyboard(KeyPress(key="r", ts=2.0)) # start again → auto-commit prior + + assert out.events[-1].last_event == "start" + assert out.events[-1].state == "recording" + assert out.events[-1].episodes_saved == 1 # the auto-committed one + + +def test_no_event_without_rising_edge() -> None: + m, out = _monitor() + pressed = Buttons() + pressed.right_secondary = True # B held + m._on_buttons(pressed) + m._on_buttons(pressed) # still held — no new edge + assert [e.last_event for e in out.events] == ["start"] + + +def test_published_status_is_internally_consistent() -> None: + # Every published event's counters/state must match the event it carries — + # the snapshot is taken under the same lock as the mutation. 
+ m, out = _monitor() + _press(m, "B", 1.0) # start + _press(m, "B", 2.0) # save (1) + _press(m, "B", 3.0) # start + _press(m, "B", 4.0) # save (2) + _press(m, "B", 5.0) # start + _press(m, "Y", 6.0) # discard (1) + + for e in out.events: + if e.last_event == "start": + assert e.state == "recording" + elif e.last_event in ("save", "discard"): + assert e.state == "idle" + assert out.events[-1].episodes_saved == 2 + assert out.events[-1].episodes_discarded == 1 + + +def test_reset_counters() -> None: + m, out = _monitor() + _press(m, "B", 1.0) + _press(m, "B", 2.0) + status = m.reset_counters() + assert status.episodes_saved == 0 + assert status.episodes_discarded == 0 + assert status.state == "idle" + assert status.last_event == "init" diff --git a/dimos/learning/dataprep/build.py b/dimos/learning/dataprep/build.py index 13dfc3e0b8..b673b3dd9b 100644 --- a/dimos/learning/dataprep/build.py +++ b/dimos/learning/dataprep/build.py @@ -90,7 +90,11 @@ def run_dataprep(config: DataPrepConfig) -> Path: try: logger.info("[dataprep] streams in source: %s", store.list_streams()) all_eps = extract_episodes(store, config.episodes) - episodes = [e for e in all_eps if e.success] + # Reindex survivors so sidecar ids match the writers' episode_index. + episodes = [ + e.model_copy(update={"id": f"ep_{i:06d}"}) + for i, e in enumerate(e for e in all_eps if e.success) + ] logger.info( "[dataprep] episodes extracted: %d total / %d successful", len(all_eps), @@ -116,7 +120,14 @@ def run_dataprep(config: DataPrepConfig) -> Path: config.sync.model_dump(), ) writer = get_writer(config.output.format) - logger.info("[dataprep] writing %s dataset to %s", config.output.format, config.output.path) + # fps drives written timestamps + video rate, so tie it to the resample + # rate; an explicit metadata.fps still wins. 
+ output = config.output + if config.sync.rate_hz > 0 and "fps" not in output.metadata: + output = output.model_copy( + update={"metadata": {**output.metadata, "fps": config.sync.rate_hz}} + ) + logger.info("[dataprep] writing %s dataset to %s", config.output.format, output.path) samples_seen = 0 episodes_done = 0 @@ -145,7 +156,7 @@ def _all_samples() -> Iterator[Sample]: yield sample episodes_done += 1 - dataset_path = Path(writer(_all_samples(), config.output)) + dataset_path = Path(writer(_all_samples(), output)) _write_dimos_meta(dataset_path, config, episodes) logger.info( "[dataprep] succeeded — wrote %d samples across %d episodes to %s", diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index 6119ee489b..ab26e16210 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -75,7 +75,7 @@ class SyncConfig(BaseConfig): class OutputConfig(BaseConfig): format: Literal["lerobot", "hdf5"] = "lerobot" path: Path - metadata: dict[str, Any] = {} + metadata: dict[str, Any] = Field(default_factory=dict) class DataPrepConfig(BaseConfig): @@ -88,8 +88,8 @@ class DataPrepConfig(BaseConfig): source: str = "" episodes: EpisodeExtractor = EpisodeExtractor() - observation: dict[str, StreamField] = {} - action: dict[str, StreamField] = {} + observation: dict[str, StreamField] = Field(default_factory=dict) + action: dict[str, StreamField] = Field(default_factory=dict) sync: SyncConfig = SyncConfig(anchor="image", rate_hz=30.0, tolerance_ms=50.0) output: OutputConfig = OutputConfig(format="lerobot", path=Path("data/datasets/default")) From 0b30ef1197f6203b67e0e91acff767cb6400b7c7 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 13:40:36 -0700 Subject: [PATCH 31/45] fix: add blueprints to self hosted list --- dimos/robot/test_all_blueprints.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dimos/robot/test_all_blueprints.py b/dimos/robot/test_all_blueprints.py index cdf72e9b6b..ec51e681f5 100644 --- 
a/dimos/robot/test_all_blueprints.py +++ b/dimos/robot/test_all_blueprints.py @@ -50,6 +50,8 @@ "coordinator-xarm6", "coordinator-xarm7", "dual-xarm6-planner", + "learning-collect-quest-piper", + "learning-collect-quest-xarm7", "teleop-hosted-go2", "teleop-hosted-xarm7", "teleop-quest-dual", From 2e613774cb9620d84d060ac812ccb93d9e60b13f Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 13:53:13 -0700 Subject: [PATCH 32/45] fix: redundant transport descriptions --- dimos/learning/collection/blueprint.py | 29 ++++---------------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/dimos/learning/collection/blueprint.py b/dimos/learning/collection/blueprint.py index e920bfbe2b..27a96be586 100644 --- a/dimos/learning/collection/blueprint.py +++ b/dimos/learning/collection/blueprint.py @@ -25,20 +25,13 @@ from dimos.core.coordination.blueprints import Blueprint, autoconnect from dimos.core.global_config import global_config -from dimos.core.transport import LCMTransport, pLCMTransport from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera -from dimos.learning.collection.episode_monitor import ( - EpisodeMonitorModule, - EpisodeStatus, -) +from dimos.learning.collection.episode_monitor import EpisodeMonitorModule from dimos.learning.collection.recorder import CollectionRecorder -from dimos.msgs.sensor_msgs.Image import Image -from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.quest.blueprints import ( teleop_quest_piper, teleop_quest_xarm7, ) -from dimos.teleop.quest.quest_types import Buttons def _session_db(robot: str) -> str: @@ -55,21 +48,14 @@ def _camera_if_real() -> tuple[Blueprint, ...]: return (RealSenseCamera.blueprint(enable_pointcloud=False),) -# Transports inline per blueprint so each recording config is self-contained. -# joint_state is declared explicitly (not left to autoconnect) so it keeps -# recording if the recorder moves to its own process. 
+# buttons / color_image / joint_state / status are left to autoconnect — each +# name is unique across the composed blueprint, so it resolves to a stable +# / topic shared by producer and recorder. learning_collect_quest_xarm7 = autoconnect( teleop_quest_xarm7, *_camera_if_real(), EpisodeMonitorModule.blueprint(), # default button_map: toggle=B, discard=Y CollectionRecorder.blueprint(db_path=_session_db("xarm7")), -).transports( - { - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("status", EpisodeStatus): pLCMTransport("/learning/episode_status"), - } ) @@ -78,13 +64,6 @@ def _camera_if_real() -> tuple[Blueprint, ...]: *_camera_if_real(), EpisodeMonitorModule.blueprint(), # default button_map: toggle=B, discard=Y CollectionRecorder.blueprint(db_path=_session_db("piper")), -).transports( - { - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - ("color_image", Image): LCMTransport("/camera/color_image", Image), - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("status", EpisodeStatus): pLCMTransport("/learning/episode_status"), - } ) From da17bea4be3d392e509def6696c49b818dbf8560 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 14:36:12 -0700 Subject: [PATCH 33/45] feat: questaliases --- dimos/learning/collection/episode_monitor.py | 17 +---------------- .../collection/test_episode_monitor.py | 2 +- dimos/teleop/quest/quest_types.py | 19 ++++++++++++++++++- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index c2a96b3ef0..37ed1a624f 100644 --- a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -33,26 +33,11 @@ from dimos.core.core import rpc from dimos.core.module 
import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.teleop.quest.quest_types import Buttons +from dimos.teleop.quest.quest_types import BUTTON_ALIASES, Buttons from dimos.utils.logging_config import setup_logger logger = setup_logger() -# Friendly names → Quest Buttons attribute names. Override by supplying an -# attribute name directly in `button_map`. -BUTTON_ALIASES: dict[str, str] = { - "A": "right_primary", - "B": "right_secondary", - "X": "left_primary", - "Y": "left_secondary", - "LT": "left_trigger", - "RT": "right_trigger", - "LG": "left_grip", - "RG": "right_grip", - "MENU_L": "left_menu", - "MENU_R": "right_menu", -} - class EpisodeStatus(BaseModel): state: Literal["idle", "recording"] diff --git a/dimos/learning/collection/test_episode_monitor.py b/dimos/learning/collection/test_episode_monitor.py index 0a947ba311..4048fa29bf 100644 --- a/dimos/learning/collection/test_episode_monitor.py +++ b/dimos/learning/collection/test_episode_monitor.py @@ -61,7 +61,7 @@ def _monitor(**config: object) -> tuple[EpisodeMonitorModule, _CaptureOut]: def _press(monitor: EpisodeMonitorModule, alias: str, ts: float) -> None: """Rising edge: release-then-press the given Quest button alias.""" - from dimos.learning.collection.episode_monitor import BUTTON_ALIASES + from dimos.teleop.quest.quest_types import BUTTON_ALIASES attr = BUTTON_ALIASES[alias] released = Buttons() diff --git a/dimos/teleop/quest/quest_types.py b/dimos/teleop/quest/quest_types.py index 7e7cfc7620..7757a926b5 100644 --- a/dimos/teleop/quest/quest_types.py +++ b/dimos/teleop/quest/quest_types.py @@ -195,4 +195,21 @@ def from_controllers( return buttons -__all__ = ["Buttons", "QuestControllerState", "ThumbstickState"] +# Quest controller face-button labels → Buttons attribute names. Callers can +# also pass a raw attribute name (e.g. "right_grip") directly where an alias is +# accepted. 
+BUTTON_ALIASES: dict[str, str] = { + "A": "right_primary", + "B": "right_secondary", + "X": "left_primary", + "Y": "left_secondary", + "LT": "left_trigger", + "RT": "right_trigger", + "LG": "left_grip", + "RG": "right_grip", + "MENU_L": "left_menu", + "MENU_R": "right_menu", +} + + +__all__ = ["BUTTON_ALIASES", "Buttons", "QuestControllerState", "ThumbstickState"] From d006d623724a0bea86d71cd03697ef33a1e90f48 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 17:28:13 -0700 Subject: [PATCH 34/45] refactor: source-stamp EpisodeStatus.ts, drop redundant start_ts --- dimos/learning/collection/episode_monitor.py | 23 ++++++------- .../collection/test_episode_monitor.py | 1 - dimos/learning/dataprep/core.py | 6 ++-- dimos/learning/dataprep/test_core.py | 34 ++++++------------- 4 files changed, 24 insertions(+), 40 deletions(-) diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index 37ed1a624f..eb01fd6cc6 100644 --- a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -40,10 +40,10 @@ class EpisodeStatus(BaseModel): + ts: float state: Literal["idle", "recording"] episodes_saved: int episodes_discarded: int - current_episode_start_ts: float | None last_event: Literal["start", "save", "discard", "init"] = "init" task_label: str | None = None @@ -76,7 +76,6 @@ def __init__(self, **kwargs: Any) -> None: self._state: Literal["idle", "recording"] = "idle" self._saved: int = 0 self._discarded: int = 0 - self._current_start_ts: float | None = None self._last_event: Literal["start", "save", "discard", "init"] = "init" self._prev_bits: dict[str, bool] = {} # rising-edge detection for buttons self._lock = threading.Lock() @@ -90,7 +89,7 @@ def start(self) -> None: # Emit an initial idle status so subscribers (and recorders) have a # known starting point in the timeline. 
with self._lock: - status = self._snapshot("init") + status = self._snapshot("init", time.time()) self._emit(status) @rpc @@ -103,19 +102,18 @@ def reset_counters(self) -> EpisodeStatus: self._state = "idle" self._saved = 0 self._discarded = 0 - self._current_start_ts = None self._prev_bits = {} - status = self._snapshot("init") + status = self._snapshot("init", time.time()) return self._emit(status) @rpc def get_status(self) -> EpisodeStatus: with self._lock: return EpisodeStatus( + ts=time.time(), state=self._state, episodes_saved=self._saved, episodes_discarded=self._discarded, - current_episode_start_ts=self._current_start_ts, last_event=self._last_event, task_label=self.config.default_task_label, ) @@ -161,32 +159,31 @@ def _transition(self, event: Literal["start", "save", "discard", "toggle"], ts: event = "save" if self._state == "recording" else "start" if event == "start": # Auto-commit any in-progress episode (matches DataPrep extractor). - if self._state == "recording" and self._current_start_ts is not None: + if self._state == "recording": self._saved += 1 self._state = "recording" - self._current_start_ts = ts elif event == "save": if self._state == "recording": self._saved += 1 self._state = "idle" - self._current_start_ts = None elif event == "discard": if self._state == "recording": self._discarded += 1 self._state = "idle" - self._current_start_ts = None # Snapshot under the mutation's lock so the event matches the state. - status = self._snapshot(event) + status = self._snapshot(event, ts) self._emit(status) - def _snapshot(self, last_event: Literal["start", "save", "discard", "init"]) -> EpisodeStatus: + def _snapshot( + self, last_event: Literal["start", "save", "discard", "init"], ts: float + ) -> EpisodeStatus: """Build a status from current state. 
Caller must hold `self._lock`.""" self._last_event = last_event return EpisodeStatus( + ts=ts, state=self._state, episodes_saved=self._saved, episodes_discarded=self._discarded, - current_episode_start_ts=self._current_start_ts, last_event=last_event, task_label=self.config.default_task_label, ) diff --git a/dimos/learning/collection/test_episode_monitor.py b/dimos/learning/collection/test_episode_monitor.py index 4048fa29bf..e8abc3f807 100644 --- a/dimos/learning/collection/test_episode_monitor.py +++ b/dimos/learning/collection/test_episode_monitor.py @@ -50,7 +50,6 @@ def _monitor(**config: object) -> tuple[EpisodeMonitorModule, _CaptureOut]: m._state = "idle" m._saved = 0 m._discarded = 0 - m._current_start_ts = None m._last_event = "init" m._prev_bits = {} m._lock = threading.Lock() diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index ab26e16210..3bf0ec770e 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -202,9 +202,9 @@ def _commit(end_ts: float, success: bool, label: str | None) -> None: if last_event == "start": # Auto-commit any prior pending episode (success=True per state-machine spec). _commit(ts, success=True, label=pending_label) - # None check, not `or ts`: a start at absolute ts 0.0 is valid. - ep_start = getattr(ev, "current_episode_start_ts", None) - pending_start_ts = ts if ep_start is None else ep_start + # obs.ts is the press time — the recorder stamps EpisodeStatus from + # its own `.ts` field (set at the button press, not at record time). 
+ pending_start_ts = ts pending_label = label elif last_event == "save": _commit(ts, success=True, label=pending_label or label) diff --git a/dimos/learning/dataprep/test_core.py b/dimos/learning/dataprep/test_core.py index 987e6c9d33..27119a5f7a 100644 --- a/dimos/learning/dataprep/test_core.py +++ b/dimos/learning/dataprep/test_core.py @@ -74,16 +74,12 @@ class _Status: """Mimics EpisodeStatus fields the extractor reads via getattr.""" last_event: str - current_episode_start_ts: float | None = None task_label: str | None = None -def _status(events: list[tuple[float, str, float | None, str | None]]) -> list[_Obs]: - """events = [(ts, last_event, start_ts, label), ...]""" - return [ - _Obs(ts=ts, data=_Status(last_event=ev, current_episode_start_ts=start, task_label=lbl)) - for ts, ev, start, lbl in events - ] +def _status(events: list[tuple[float, str, str | None]]) -> list[_Obs]: + """events = [(ts, last_event, label), ...]""" + return [_Obs(ts=ts, data=_Status(last_event=ev, task_label=lbl)) for ts, ev, lbl in events] # ── resolve_field ──────────────────────────────────────────────────────────── @@ -124,9 +120,7 @@ class Image: def test_extract_start_save(): - store = _FakeStore( - {"status": _status([(1.0, "start", 1.0, "pick"), (5.0, "save", None, None)])} - ) + store = _FakeStore({"status": _status([(1.0, "start", "pick"), (5.0, "save", None)])}) eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) assert len(eps) == 1 assert eps[0].start_ts == 1.0 and eps[0].end_ts == 5.0 @@ -135,9 +129,7 @@ def test_extract_start_save(): def test_extract_discard_marks_failure(): - store = _FakeStore( - {"status": _status([(1.0, "start", 1.0, None), (3.0, "discard", None, None)])} - ) + store = _FakeStore({"status": _status([(1.0, "start", None), (3.0, "discard", None)])}) eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) assert len(eps) == 1 assert eps[0].success is False @@ -149,9 +141,9 @@ def 
test_extract_auto_commit_on_restart(): { "status": _status( [ - (1.0, "start", 1.0, None), - (4.0, "start", 4.0, None), - (8.0, "save", None, None), + (1.0, "start", None), + (4.0, "start", None), + (8.0, "save", None), ] ) } @@ -163,25 +155,21 @@ def test_extract_auto_commit_on_restart(): def test_extract_pending_at_eof_dropped(): - store = _FakeStore({"status": _status([(1.0, "start", 1.0, None)])}) + store = _FakeStore({"status": _status([(1.0, "start", None)])}) eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) assert eps == [] def test_extract_init_and_unknown_are_noops(): store = _FakeStore( - { - "status": _status( - [(0.5, "init", None, None), (1.0, "start", 1.0, None), (5.0, "save", None, None)] - ) - } + {"status": _status([(0.5, "init", None), (1.0, "start", None), (5.0, "save", None)])} ) eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) assert len(eps) == 1 def test_extract_save_without_start_emits_nothing(): - store = _FakeStore({"status": _status([(2.0, "save", None, None)])}) + store = _FakeStore({"status": _status([(2.0, "save", None)])}) assert extract_episodes(store, EpisodeExtractor(status_stream="status")) == [] From 1a71102f93b5228a98353fbed1d74cb731266ee2 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 17:31:04 -0700 Subject: [PATCH 35/45] misc: todo for later --- dimos/learning/collection/episode_monitor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dimos/learning/collection/episode_monitor.py b/dimos/learning/collection/episode_monitor.py index eb01fd6cc6..0d194ea427 100644 --- a/dimos/learning/collection/episode_monitor.py +++ b/dimos/learning/collection/episode_monitor.py @@ -68,6 +68,8 @@ class EpisodeMonitorModule(Module): config: EpisodeMonitorModuleConfig buttons: In[Buttons] + # TODO: no KeyPress producer exists yet — add a pygame keyboard module that + # publishes KeyPress so this port is actually fed (today only buttons drive it). 
keyboard: In[KeyPress] status: Out[EpisodeStatus] From 8eca87300f9559d89dad322ee7b00ee96414528f Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 18:06:19 -0700 Subject: [PATCH 36/45] writer and inspector format validate --- dimos/learning/dataprep/core.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index 3bf0ec770e..7e203309f3 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -36,12 +36,15 @@ from dimos.protocol.service.spec import BaseConfig -Writer = Callable[[Iterator["Sample"], "OutputConfig"], Path] - if TYPE_CHECKING: from dimos.memory2.store.sqlite import SqliteStore from dimos.memory2.stream import Stream +# A dataset format is a `formats/.py` module exposing `write` and +# `inspect` with these signatures, registered in `get_writer`/`get_inspector`. +Writer = Callable[[Iterator["Sample"], "OutputConfig"], Path] +Inspector = Callable[[Path], dict[str, Any]] + # ───────────────────────────────────────────────────────────────────────────── # Sub-configs @@ -351,7 +354,7 @@ def get_writer(format_name: str) -> Writer: return write -def get_inspector(format_name: str) -> Callable[[Path], dict[str, Any]]: +def get_inspector(format_name: str) -> Inspector: """Lazy-import the format reader's `inspect` function.""" if format_name == "lerobot": from dimos.learning.dataprep.formats.lerobot import inspect From 97c97262e0f3cc738acbb73f78050f4de9670363 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 18:25:49 -0700 Subject: [PATCH 37/45] misc: simplification nearest check --- dimos/learning/dataprep/core.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index 7e203309f3..8310074233 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -68,10 +68,9 @@ class 
SyncConfig(BaseConfig): anchor: str rate_hz: float tolerance_ms: float - strategy: Literal["nearest", "interp"] = "nearest" - # Action = state this many frames ahead (default 1 = next-state BC). Set 0 - # for action==state. Use 0 for ACT — actions stay flat (one per frame) and - # ACT builds its own chunk at train time via delta_timestamps. + # TODO: add "interp" — do it per-stream (lerp low-dim vectors, force nearest + # for ndim>=3 images, can't blend frames). Only "nearest" is wired today. + strategy: Literal["nearest"] = "nearest" action_shift: int = 1 @@ -287,22 +286,15 @@ def _nearest(key: str, t: float) -> Any | None: ts_list, msg_list = cached[key] if not ts_list: return None + # Nearest is i (first sample ≥ t) or i-1 (last sample < t). i = bisect.bisect_left(ts_list, t) - candidates: list[int] = [] - if i < len(ts_list): - candidates.append(i) - if i > 0: - candidates.append(i - 1) - best: int | None = None - best_dt = float("inf") - for c in candidates: - dt = abs(ts_list[c] - t) - if dt < best_dt: - best = c - best_dt = dt - if best is None or best_dt > tolerance_s: - return None - return msg_list[best] + if i == 0: + best = 0 + elif i == len(ts_list): + best = i - 1 + else: + best = i if (ts_list[i] - t) < (t - ts_list[i - 1]) else i - 1 + return msg_list[best] if abs(ts_list[best] - t) <= tolerance_s else None def _build_frames() -> Iterator[Sample]: for t in targets: From 28f85abe6ed2b0305b2192969948d1acef0d1215 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 20:11:43 -0700 Subject: [PATCH 38/45] misc: comments instructions --- dimos/learning/dataprep/core.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index 8310074233..202fd0d706 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -272,6 +272,8 @@ def iter_episode_samples( # Build the sequence of target timestamps for this episode. 
if sync.rate_hz > 0: + # Uniform 1/rate_hz grid, phase-locked to the first anchor sample — + # what LeRobot expects (it assumes contiguous fixed-fps frames). period = 1.0 / sync.rate_hz targets: list[float] = [] t = anchor_ts[0] @@ -280,6 +282,9 @@ def iter_episode_samples( targets.append(t) t += period else: + # rate_hz=0: follow the anchor's own timestamps (no image resampling). + # dt is irregular if the camera jitters — fine for hdf5/custom trainers, + # but not LeRobot-uniform. targets = list(anchor_ts) def _nearest(key: str, t: float) -> Any | None: From c69d4dcc3d1727dd2a337c36010461da0de0a459 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Thu, 18 Jun 2026 21:03:44 -0700 Subject: [PATCH 39/45] fix: None retun for tests --- .../collection/test_episode_monitor.py | 10 ++--- dimos/learning/dataprep/formats/test_hdf5.py | 7 +-- .../learning/dataprep/formats/test_lerobot.py | 7 +-- dimos/learning/dataprep/formats/test_stats.py | 12 ++--- dimos/learning/dataprep/test_core.py | 45 ++++++++++--------- 5 files changed, 41 insertions(+), 40 deletions(-) diff --git a/dimos/learning/collection/test_episode_monitor.py b/dimos/learning/collection/test_episode_monitor.py index e8abc3f807..328bb04c75 100644 --- a/dimos/learning/collection/test_episode_monitor.py +++ b/dimos/learning/collection/test_episode_monitor.py @@ -31,10 +31,10 @@ EpisodeStatus, KeyPress, ) -from dimos.teleop.quest.quest_types import Buttons +from dimos.teleop.quest.quest_types import BUTTON_ALIASES, Buttons -class _CaptureOut: +class FakeStatusOut: """Stand-in for the `status` Out port that records published events.""" def __init__(self) -> None: @@ -44,7 +44,7 @@ def publish(self, status: EpisodeStatus) -> None: self.events.append(status) -def _monitor(**config: object) -> tuple[EpisodeMonitorModule, _CaptureOut]: +def _monitor(**config: object) -> tuple[EpisodeMonitorModule, FakeStatusOut]: m = EpisodeMonitorModule.__new__(EpisodeMonitorModule) m.config = EpisodeMonitorModuleConfig(**config) # 
type: ignore[assignment] m._state = "idle" @@ -53,15 +53,13 @@ def _monitor(**config: object) -> tuple[EpisodeMonitorModule, _CaptureOut]: m._last_event = "init" m._prev_bits = {} m._lock = threading.Lock() - out = _CaptureOut() + out = FakeStatusOut() m.status = out # type: ignore[assignment] return m, out def _press(monitor: EpisodeMonitorModule, alias: str, ts: float) -> None: """Rising edge: release-then-press the given Quest button alias.""" - from dimos.teleop.quest.quest_types import BUTTON_ALIASES - attr = BUTTON_ALIASES[alias] released = Buttons() pressed = Buttons() diff --git a/dimos/learning/dataprep/formats/test_hdf5.py b/dimos/learning/dataprep/formats/test_hdf5.py index 54a0a2486f..336b7674be 100644 --- a/dimos/learning/dataprep/formats/test_hdf5.py +++ b/dimos/learning/dataprep/formats/test_hdf5.py @@ -23,6 +23,7 @@ from __future__ import annotations from collections.abc import Iterator +from pathlib import Path import numpy as np import pytest @@ -45,7 +46,7 @@ def _samples(n_episodes: int = 2, n_frames: int = 3) -> Iterator[Sample]: ) -def test_hdf5_roundtrip_counts_and_shapes(tmp_path): +def test_hdf5_roundtrip_counts_and_shapes(tmp_path: Path) -> None: out = OutputConfig( format="hdf5", path=tmp_path / "session", @@ -68,14 +69,14 @@ def test_hdf5_roundtrip_counts_and_shapes(tmp_path): assert info["episode_lengths"] == {"min": 3, "max": 3, "mean": 3.0, "uniform": True} -def test_hdf5_extension_appended_when_missing(tmp_path): +def test_hdf5_extension_appended_when_missing(tmp_path: Path) -> None: # path with no suffix → writer appends .hdf5 out = OutputConfig(format="hdf5", path=tmp_path / "noext") path = write(_samples(n_episodes=1, n_frames=2), out) assert path.name == "noext.hdf5" -def test_hdf5_stats_values_match(tmp_path): +def test_hdf5_stats_values_match(tmp_path: Path) -> None: out = OutputConfig(format="hdf5", path=tmp_path / "s.hdf5") path = write(_samples(n_episodes=1, n_frames=3), out) with h5py.File(path, "r") as f: diff --git 
a/dimos/learning/dataprep/formats/test_lerobot.py b/dimos/learning/dataprep/formats/test_lerobot.py index f8b1fdd7a3..cba373c8cb 100644 --- a/dimos/learning/dataprep/formats/test_lerobot.py +++ b/dimos/learning/dataprep/formats/test_lerobot.py @@ -25,6 +25,7 @@ from collections.abc import Iterator import json +from pathlib import Path import numpy as np import pytest @@ -59,7 +60,7 @@ def _image_samples(n: int = 4) -> Iterator[Sample]: ) -def test_lerobot_state_only_layout_and_naming(tmp_path): +def test_lerobot_state_only_layout_and_naming(tmp_path: Path) -> None: out = OutputConfig( format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0, "robot": "xarm7"} ) @@ -80,7 +81,7 @@ def test_lerobot_state_only_layout_and_naming(tmp_path): assert "action" in info["features"] -def test_lerobot_inspect_state_only(tmp_path): +def test_lerobot_inspect_state_only(tmp_path: Path) -> None: out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) root = write(_state_samples(), out) info = inspect(root) @@ -92,7 +93,7 @@ def test_lerobot_inspect_state_only(tmp_path): assert info["has_stats"] is True -def test_lerobot_with_images_writes_mp4_and_video_feature(tmp_path): +def test_lerobot_with_images_writes_mp4_and_video_feature(tmp_path: Path) -> None: out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) try: root = write(_image_samples(), out) diff --git a/dimos/learning/dataprep/formats/test_stats.py b/dimos/learning/dataprep/formats/test_stats.py index 012d12dd9a..d04cd6dc17 100644 --- a/dimos/learning/dataprep/formats/test_stats.py +++ b/dimos/learning/dataprep/formats/test_stats.py @@ -26,7 +26,7 @@ from dimos.learning.dataprep.formats._stats import StreamingStats -def test_scalar_mean_std_minmax_count(): +def test_scalar_mean_std_minmax_count() -> None: s = StreamingStats() for v in ([1.0, 10.0], [2.0, 20.0], [3.0, 30.0]): s.update("state", np.array(v)) @@ -39,7 +39,7 @@ def test_scalar_mean_std_minmax_count(): 
assert out["count"] == 3 -def test_single_sample_has_zero_std(): +def test_single_sample_has_zero_std() -> None: s = StreamingStats() s.update("x", np.array([5.0, 7.0])) out = s.finalize()["x"] @@ -47,7 +47,7 @@ def test_single_sample_has_zero_std(): assert out["count"] == 1 -def test_lowdim_quantiles_present_and_bounded(): +def test_lowdim_quantiles_present_and_bounded() -> None: s = StreamingStats() for i in range(100): s.update("x", np.array([float(i)])) @@ -56,7 +56,7 @@ def test_lowdim_quantiles_present_and_bounded(): assert out["min"][0] <= out["q01"][0] <= out["q99"][0] <= out["max"][0] -def test_image_reduced_to_per_channel_no_quantiles(): +def test_image_reduced_to_per_channel_no_quantiles() -> None: # image_subsample=1 → every frame counts. Constant per-channel values. s = StreamingStats(image_subsample=1) img = np.zeros((4, 4, 3), dtype=np.uint8) @@ -72,7 +72,7 @@ def test_image_reduced_to_per_channel_no_quantiles(): assert "q01" not in out and "q99" not in out -def test_image_subsampling_counts_every_nth_frame(): +def test_image_subsampling_counts_every_nth_frame() -> None: s = StreamingStats(image_subsample=10) img = np.zeros((2, 2, 3), dtype=np.uint8) for _ in range(25): @@ -81,5 +81,5 @@ def test_image_subsampling_counts_every_nth_frame(): assert s.finalize()["cam"]["count"] == 3 -def test_empty_aggregator_finalizes_empty(): +def test_empty_aggregator_finalizes_empty() -> None: assert StreamingStats().finalize() == {} diff --git a/dimos/learning/dataprep/test_core.py b/dimos/learning/dataprep/test_core.py index 27119a5f7a..c522529428 100644 --- a/dimos/learning/dataprep/test_core.py +++ b/dimos/learning/dataprep/test_core.py @@ -22,6 +22,7 @@ from __future__ import annotations from dataclasses import dataclass +from pathlib import Path from typing import Any import numpy as np @@ -85,7 +86,7 @@ def _status(events: list[tuple[float, str, str | None]]) -> list[_Obs]: # ── resolve_field ──────────────────────────────────────────────────────────── -def 
test_resolve_field_attribute(): +def test_resolve_field_attribute() -> None: @dataclass class Msg: position: list[float] @@ -95,18 +96,18 @@ class Msg: np.testing.assert_array_equal(arr, np.array([1.0, 2.0, 3.0])) -def test_resolve_field_dict_payload(): +def test_resolve_field_dict_payload() -> None: arr = resolve_field({"q": [4, 5]}, StreamField(stream="x", field="q")) np.testing.assert_array_equal(arr, np.array([4, 5])) -def test_resolve_field_none_passthrough_ndarray(): +def test_resolve_field_none_passthrough_ndarray() -> None: src = np.arange(6).reshape(2, 3) out = resolve_field(src, StreamField(stream="x", field=None)) assert out is src # ndarray passes straight through -def test_resolve_field_none_unwraps_data_attr(): +def test_resolve_field_none_unwraps_data_attr() -> None: @dataclass class Image: data: np.ndarray @@ -119,7 +120,7 @@ class Image: # ── extract_episodes: episode_status ───────────────────────────────────────── -def test_extract_start_save(): +def test_extract_start_save() -> None: store = _FakeStore({"status": _status([(1.0, "start", "pick"), (5.0, "save", None)])}) eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) assert len(eps) == 1 @@ -128,14 +129,14 @@ def test_extract_start_save(): assert eps[0].task_label == "pick" -def test_extract_discard_marks_failure(): +def test_extract_discard_marks_failure() -> None: store = _FakeStore({"status": _status([(1.0, "start", None), (3.0, "discard", None)])}) eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) assert len(eps) == 1 assert eps[0].success is False -def test_extract_auto_commit_on_restart(): +def test_extract_auto_commit_on_restart() -> None: # start, then another start without save → first auto-commits (success=True) store = _FakeStore( { @@ -154,13 +155,13 @@ def test_extract_auto_commit_on_restart(): assert eps[1].start_ts == 4.0 and eps[1].end_ts == 8.0 -def test_extract_pending_at_eof_dropped(): +def test_extract_pending_at_eof_dropped() -> 
None: store = _FakeStore({"status": _status([(1.0, "start", None)])}) eps = extract_episodes(store, EpisodeExtractor(status_stream="status")) assert eps == [] -def test_extract_init_and_unknown_are_noops(): +def test_extract_init_and_unknown_are_noops() -> None: store = _FakeStore( {"status": _status([(0.5, "init", None), (1.0, "start", None), (5.0, "save", None)])} ) @@ -168,7 +169,7 @@ def test_extract_init_and_unknown_are_noops(): assert len(eps) == 1 -def test_extract_save_without_start_emits_nothing(): +def test_extract_save_without_start_emits_nothing() -> None: store = _FakeStore({"status": _status([(2.0, "save", None)])}) assert extract_episodes(store, EpisodeExtractor(status_stream="status")) == [] @@ -176,13 +177,13 @@ def test_extract_save_without_start_emits_nothing(): # ── extract_episodes: ranges ───────────────────────────────────────────────── -def test_extract_ranges(): +def test_extract_ranges() -> None: cfg = EpisodeExtractor(extractor="ranges", ranges=[(0.0, 1.0), (2.0, 3.0)]) eps = extract_episodes(_FakeStore({}), cfg) assert [(e.start_ts, e.end_ts) for e in eps] == [(0.0, 1.0), (2.0, 3.0)] -def test_extract_ranges_empty(): +def test_extract_ranges_empty() -> None: cfg = EpisodeExtractor(extractor="ranges", ranges=None) assert extract_episodes(_FakeStore({}), cfg) == [] @@ -200,7 +201,7 @@ class S: return [_Obs(ts=ts, data=S(position=[v])) for ts, v in values] -def test_sync_basic_no_shift(): +def test_sync_basic_no_shift() -> None: # obs == action, shift disabled → one sample per anchor target store = _FakeStore( { @@ -221,7 +222,7 @@ def test_sync_basic_no_shift(): np.testing.assert_array_equal(samples[0].observation["state"], samples[0].action["act"]) -def test_sync_action_shift_next_state(): +def test_sync_action_shift_next_state() -> None: store = _FakeStore({"js": _scalar_stream([(0.0, 10.0), (1.0, 11.0), (2.0, 12.0)])}) ep = Episode(id="ep_0", start_ts=0.0, end_ts=2.0) streams = { @@ -241,7 +242,7 @@ def 
test_sync_action_shift_next_state(): np.testing.assert_array_equal(samples[1].action["act"], [12.0]) -def test_sync_tolerance_skips_unmatched_frame(): +def test_sync_tolerance_skips_unmatched_frame() -> None: # anchor ticks every 1s, but the second stream has a big gap around t=1 store = _FakeStore( { @@ -260,7 +261,7 @@ def test_sync_tolerance_skips_unmatched_frame(): assert [round(s.ts) for s in samples] == [0, 2] -def test_sync_missing_anchor_raises(): +def test_sync_missing_anchor_raises() -> None: ep = Episode(id="ep_0", start_ts=0.0, end_ts=1.0) streams = {"x": StreamField(stream="x", field="position")} sync = SyncConfig(anchor="not_there", rate_hz=1.0, tolerance_ms=10.0) @@ -268,7 +269,7 @@ def test_sync_missing_anchor_raises(): list(iter_episode_samples(_FakeStore({}), ep, streams, sync)) -def test_sync_empty_anchor_yields_nothing(): +def test_sync_empty_anchor_yields_nothing() -> None: store = _FakeStore({"a": []}) ep = Episode(id="ep_0", start_ts=0.0, end_ts=1.0) streams = {"a": StreamField(stream="a", field="position")} @@ -279,23 +280,23 @@ def test_sync_empty_anchor_yields_nothing(): # ── summarize_lengths ──────────────────────────────────────────────────────── -def test_summarize_lengths_uniform(): +def test_summarize_lengths_uniform() -> None: assert summarize_lengths([5, 5, 5]) == {"min": 5, "max": 5, "mean": 5.0, "uniform": True} -def test_summarize_lengths_varied(): +def test_summarize_lengths_varied() -> None: s = summarize_lengths([2, 4, 6]) assert s == {"min": 2, "max": 6, "mean": 4.0, "uniform": False} -def test_summarize_lengths_empty(): +def test_summarize_lengths_empty() -> None: assert summarize_lengths([]) == {"min": 0, "max": 0, "mean": 0.0, "uniform": True} # ── dimos_meta sidecar ─────────────────────────────────────────────────────── -def test_dimos_meta_records_sync_and_action_shift(tmp_path): +def test_dimos_meta_records_sync_and_action_shift(tmp_path: Path) -> None: import json from dimos.learning.dataprep.build import 
_write_dimos_meta @@ -315,7 +316,7 @@ def test_dimos_meta_records_sync_and_action_shift(tmp_path): assert meta["source"] == "s.db" -def test_dimos_meta_beside_file_for_hdf5(tmp_path): +def test_dimos_meta_beside_file_for_hdf5(tmp_path: Path) -> None: """hdf5 writer returns a FILE path; the sidecar must land beside it, not inside it (which would treat the .hdf5 file as a directory and crash).""" import json From a1497dc0b8f2296f429c711190360a55e78e7560 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Fri, 19 Jun 2026 14:10:15 -0700 Subject: [PATCH 40/45] feat: lerobot v3.0 --- dimos/learning/dataprep/formats/lerobot.py | 468 +++++++++++------- .../learning/dataprep/formats/test_lerobot.py | 75 ++- pyproject.toml | 3 +- 3 files changed, 349 insertions(+), 197 deletions(-) diff --git a/dimos/learning/dataprep/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py index 8cf4c49711..73a21996d3 100644 --- a/dimos/learning/dataprep/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -12,23 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""LeRobot v2 dataset writer. +"""LeRobot v3.0 dataset writer. + +v3.0 differs structurally from v2.x: instead of one parquet + one MP4 *per +episode*, episodes are **concatenated** into shared chunked files, and all +per-episode bookkeeping (frame/byte ranges, video time offsets, per-episode +stats) moves into an episodes *parquet*. Layout:: / - meta/info.json schema, fps, total episodes/frames, features - meta/episodes.jsonl per-episode metadata - meta/tasks.jsonl task descriptions for language conditioning - meta/stats.json per-feature mean/std/min/max/q01/q99 - data/chunk-000/episode_NNNNNN.parquet - videos/chunk-000/observation.images./episode_NNNNNN.mp4 - -Single pass: streams samples to disk per-episode and accumulates stats in -parallel. 
Image frames go to MP4 (one per camera, per episode); their -columns are excluded from the parquet — lerobot loads them from MP4 at -``__getitem__`` time using the ``video_path`` template + episode_index + -timestamp. + meta/info.json schema, fps, totals, features + meta/tasks.parquet task strings (indexed by `task`) + meta/stats.json aggregated per-feature stats + meta/episodes/chunk-000/file-000.parquet one row per episode (+ stats) + data/chunk-000/file-000.parquet ALL episodes' frames concatenated + videos//chunk-000/file-000.mp4 ALL episodes for a camera, concatenated + +This writer emits a **single** data file and a single MP4 per camera (chunk +000 / file 000); LeRobot supports multi-file rolling at size limits, which we +don't need yet (logged if a soft limit is exceeded). A frame's `timestamp` is +relative to its episode; the episode's `videos//from_timestamp` gives its +offset inside the shared MP4, so `from_timestamp + timestamp` locates the frame. """ from __future__ import annotations @@ -42,17 +47,27 @@ from dimos.learning.dataprep.core import OutputConfig, Sample from dimos.learning.dataprep.formats._stats import StreamingStats +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() CHUNK = "chunk-000" +FILE = "file-000" DATA_DIR = "data" VIDEO_DIR = "videos" META_DIR = "meta" +EPISODES_DIR = "episodes" + +# LeRobot defaults; we write a single file but warn past these soft limits. +DATA_FILE_SIZE_MB = 100 +VIDEO_FILE_SIZE_MB = 200 +CHUNKS_SIZE = 1000 def _feature_name( prefix: str, key: str, is_image: bool, single_action: bool, single_state: bool = False ) -> str: - """Translate (prefix, key) into the LeRobot v2 feature name. + """Translate (prefix, key) into the LeRobot feature name. 
Canonical names lerobot policies (ACT, Diffusion, π₀) expect: observation.state single proprio vector @@ -71,10 +86,35 @@ def _feature_name( return f"action.{key}" -def write(samples: Iterator[Sample], output: OutputConfig) -> Path: - """Drain `samples`, write parquet+MP4+meta in LeRobot v2 layout. - Returns the dataset root path. +def _nest_image_stat(vals: list[float]) -> list[list[list[float]]]: + """Per-channel [c0,c1,c2] → shape (C,1,1) [[[c0]],[[c1]],[[c2]]] (lerobot image stats).""" + return [[[float(c)]] for c in vals] + + +def _flatten_episode_stats( + final: dict[str, dict[str, Any]], feature_dtypes: dict[str, str] +) -> dict[str, Any]: + """Flatten a per-episode StreamingStats result into ``stats//`` columns. + + Image features get the (C,1,1) nesting lerobot expects; low-dim stay flat. """ + out: dict[str, Any] = {} + for feat, entry in final.items(): + is_video = feature_dtypes.get(feat) == "video" + for k in ("mean", "std", "min", "max"): + v = entry.get(k) + if v is None: + continue + out[f"stats/{feat}/{k}"] = _nest_image_stat(v) if is_video else v + out[f"stats/{feat}/count"] = int(entry["count"]) + for q in ("q01", "q99"): + if q in entry: + out[f"stats/{feat}/{q}"] = _nest_image_stat(entry[q]) if is_video else entry[q] + return out + + +def write(samples: Iterator[Sample], output: OutputConfig) -> Path: + """Drain `samples`, write a LeRobot v3.0 dataset. 
Returns the dataset root path.""" try: import cv2 except ImportError as e: @@ -84,181 +124,236 @@ def write(samples: Iterator[Sample], output: OutputConfig) -> Path: import pyarrow.parquet as pq except ImportError as e: raise RuntimeError("LeRobot writer requires pyarrow for parquet writes") from e + try: + import pandas as pd + except ImportError as e: + raise RuntimeError("LeRobot writer requires pandas for tasks.parquet") from e root = Path(output.path) - (root / META_DIR).mkdir(parents=True, exist_ok=True) + (root / META_DIR / EPISODES_DIR / CHUNK).mkdir(parents=True, exist_ok=True) (root / DATA_DIR / CHUNK).mkdir(parents=True, exist_ok=True) - (root / VIDEO_DIR / CHUNK).mkdir(parents=True, exist_ok=True) fps = float(output.metadata.get("fps", 30.0)) fourcc = cv2.VideoWriter.fourcc(*"mp4v") + default_task_label = output.metadata.get("default_task_label", "task") - stats = StreamingStats( - image_subsample=int(output.metadata.get("image_subsample", 10)), - quantile_reservoir=int(output.metadata.get("quantile_reservoir", 10_000)), - seed=int(output.metadata.get("stats_seed", 0)), - ) + def _stats() -> StreamingStats: + return StreamingStats( + image_subsample=int(output.metadata.get("image_subsample", 10)), + quantile_reservoir=int(output.metadata.get("quantile_reservoir", 10_000)), + seed=int(output.metadata.get("stats_seed", 0)), + ) - image_keys: set[str] = set() + global_stats = _stats() # aggregated across all frames → meta/stats.json + + # Schema discovery (filled as samples flow). + image_keys: list[str] = [] state_keys: list[str] = [] action_keys: list[str] = [] feature_shapes: dict[str, tuple[int, ...]] = {} feature_dtypes: dict[str, str] = {} - episodes_meta: list[dict[str, Any]] = [] tasks_index: dict[str, int] = {} - default_task_label = output.metadata.get("default_task_label", "task") + episode_rows: list[dict[str, Any]] = [] + + # Single concatenated data file (opened on first flush). 
+ data_path = root / DATA_DIR / CHUNK / f"{FILE}.parquet" + data_writer: Any = None + + # One MP4 per camera, persisting across episodes; from/to timestamps per episode. + video_writers: dict[str, Any] = {} + video_cum_frames: dict[str, int] = {} # frames written per camera so far - current_episode_id: str | None = None - current_episode_index = 0 - current_frames: list[dict[str, Any]] = [] - current_video_writers: dict[str, Any] = {} global_index = 0 + episode_index = -1 - def _episode_path_parquet(ep_idx: int) -> Path: - return root / DATA_DIR / CHUNK / f"episode_{ep_idx:06d}.parquet" + # Per-episode buffers. + cur_id: str | None = None + cur_rows: list[dict[str, Any]] = [] + cur_ep_stats = _stats() - def _episode_path_video(image_key: str, ep_idx: int) -> Path: - feat_name = _feature_name("observation", image_key, is_image=True, single_action=False) - d = root / VIDEO_DIR / CHUNK / feat_name + def _video_path(image_key: str) -> Path: + feat = _feature_name("observation", image_key, is_image=True, single_action=False) + d = root / VIDEO_DIR / feat / CHUNK d.mkdir(parents=True, exist_ok=True) - return d / f"episode_{ep_idx:06d}.mp4" + return d / f"{FILE}.mp4" - def _open_video(image_key: str, ep_idx: int, frame: np.ndarray) -> Any: - # Frames are RGB→BGR converted at write time (see the write loop below). 
+ def _open_video(image_key: str, frame: np.ndarray) -> Any: h, w = frame.shape[:2] - path = _episode_path_video(image_key, ep_idx) - writer = cv2.VideoWriter(str(path), fourcc, fps, (w, h)) - if not writer.isOpened(): + path = _video_path(image_key) + vw = cv2.VideoWriter(str(path), fourcc, fps, (w, h)) + if not vw.isOpened(): raise RuntimeError(f"Failed to open VideoWriter for {path}") - return writer - - def _flush_episode() -> bool: - nonlocal current_frames, current_video_writers, current_episode_index - if not current_frames: - return False - for vw in current_video_writers.values(): - vw.release() - current_video_writers = {} - - cols: dict[str, list[Any]] = { - "timestamp": [f["timestamp"] for f in current_frames], - "frame_index": [f["frame_index"] for f in current_frames], - "episode_index": [f["episode_index"] for f in current_frames], - "index": [f["index"] for f in current_frames], - "task_index": [f["task_index"] for f in current_frames], + return vw + + def _flush_episode() -> None: + nonlocal data_writer + if not cur_rows: + return + length = len(cur_rows) + single_state = len(state_keys) == 1 + single_action = len(action_keys) == 1 + + cols: dict[str, Any] = { + "timestamp": pa.array([r["timestamp"] for r in cur_rows], pa.float32()), + "frame_index": pa.array([r["frame_index"] for r in cur_rows], pa.int64()), + "episode_index": pa.array([r["episode_index"] for r in cur_rows], pa.int64()), + "index": pa.array([r["index"] for r in cur_rows], pa.int64()), + "task_index": pa.array([r["task_index"] for r in cur_rows], pa.int64()), } f32_list = pa.list_(pa.float32()) - single_state = len(state_keys) == 1 for k in state_keys: - name = _feature_name( - "observation", k, is_image=False, single_action=False, single_state=single_state - ) - cols[name] = pa.array([f["obs"][k].tolist() for f in current_frames], type=f32_list) - single_action = len(action_keys) == 1 + name = _feature_name("observation", k, False, False, single_state=single_state) + cols[name] = 
pa.array([r["obs"][k].tolist() for r in cur_rows], type=f32_list) for k in action_keys: - name = _feature_name("action", k, is_image=False, single_action=single_action) - cols[name] = pa.array([f["act"][k].tolist() for f in current_frames], type=f32_list) - # Video columns intentionally omitted: lerobot's hf_features schema - # skips dtype="video" and reads frames from MP4 at __getitem__ time. - + name = _feature_name("action", k, False, single_action=single_action) + cols[name] = pa.array([r["act"][k].tolist() for r in cur_rows], type=f32_list) table = pa.Table.from_pydict(cols) - pq.write_table(table, _episode_path_parquet(current_episode_index)) - - episodes_meta.append( + if data_writer is None: + data_writer = pq.ParquetWriter(data_path, table.schema, compression="snappy") + data_writer.write_table(table) + + # Episode metadata row. + row: dict[str, Any] = { + "episode_index": episode_index, + "tasks": [list(tasks_index.keys())[cur_rows[0]["task_index"]]], + "length": length, + "data/chunk_index": 0, + "data/file_index": 0, + "dataset_from_index": global_index - length, + "dataset_to_index": global_index, + "meta/episodes/chunk_index": 0, + "meta/episodes/file_index": 0, + } + for k in image_keys: + feat = _feature_name("observation", k, is_image=True, single_action=False) + cum = video_cum_frames.get(k, 0) + row[f"videos/{feat}/chunk_index"] = 0 + row[f"videos/{feat}/file_index"] = 0 + row[f"videos/{feat}/from_timestamp"] = (cum - length) / fps + row[f"videos/{feat}/to_timestamp"] = cum / fps + row.update(_flatten_episode_stats(cur_ep_stats.finalize(), feature_dtypes)) + episode_rows.append(row) + cur_rows.clear() + + for sample in samples: + if sample.episode_id != cur_id: + _flush_episode() + cur_id = sample.episode_id + episode_index += 1 + cur_ep_stats = _stats() + if default_task_label not in tasks_index: + tasks_index[default_task_label] = len(tasks_index) + + # Schema discovery + stats (global + per-episode). 
+ n_low_dim_obs = sum(1 for v in sample.observation.values() if np.asarray(v).ndim < 3) + single_state = n_low_dim_obs == 1 + for k, arr in sample.observation.items(): + a = np.asarray(arr) + is_image = a.ndim >= 3 + name = _feature_name("observation", k, is_image, False, single_state=single_state) + if name not in feature_shapes: + feature_shapes[name] = tuple(a.shape) + feature_dtypes[name] = "video" if is_image else str(a.dtype) + if is_image: + if k not in image_keys: + image_keys.append(k) + elif k not in state_keys: + state_keys.append(k) + global_stats.update(name, a) + cur_ep_stats.update(name, a) + single_action = len(sample.action) == 1 + for k, arr in sample.action.items(): + a = np.asarray(arr) + name = _feature_name("action", k, is_image=False, single_action=single_action) + if name not in feature_shapes: + feature_shapes[name] = tuple(a.shape) + feature_dtypes[name] = str(a.dtype) + if k not in action_keys: + action_keys.append(k) + global_stats.update(name, a) + cur_ep_stats.update(name, a) + + # Append image frames to the per-camera MP4 (RGB→BGR; cv2 is BGR-native). 
+ for k, arr in sample.observation.items(): + a = np.asarray(arr) + if a.ndim >= 3: + if k not in video_writers: + video_writers[k] = _open_video(k, a) + bgr = cv2.cvtColor(a, cv2.COLOR_RGB2BGR) if a.shape[-1] == 3 else a + video_writers[k].write(bgr) + video_cum_frames[k] = video_cum_frames.get(k, 0) + 1 + + frame_index = len(cur_rows) + cur_rows.append( { - "episode_index": current_episode_index, - "tasks": [list(tasks_index.keys())[current_frames[0]["task_index"]]], - "length": len(current_frames), + "timestamp": frame_index / fps, # relative to this episode + "frame_index": frame_index, + "episode_index": episode_index, + "index": global_index, + "task_index": tasks_index[default_task_label], + "obs": { + k: np.asarray(v) + for k, v in sample.observation.items() + if np.asarray(v).ndim < 3 + }, + "act": {k: np.asarray(v) for k, v in sample.action.items()}, } ) - current_frames = [] - return True + global_index += 1 - try: - for sample in samples: - if sample.episode_id != current_episode_id: - if _flush_episode(): - current_episode_index += 1 - current_episode_id = sample.episode_id - label = default_task_label - if label not in tasks_index: - tasks_index[label] = len(tasks_index) - - # Schema discovery + stats accumulation. 
- n_low_dim_obs = sum(1 for _, v in sample.observation.items() if np.asarray(v).ndim < 3) - single_state = n_low_dim_obs == 1 - for k, arr in sample.observation.items(): - a = np.asarray(arr) - is_image = a.ndim >= 3 - name = _feature_name( - "observation", - k, - is_image=is_image, - single_action=False, - single_state=single_state, - ) - if name not in feature_shapes: - feature_shapes[name] = tuple(a.shape) - feature_dtypes[name] = "video" if is_image else str(a.dtype) - if is_image: - image_keys.add(k) - elif k not in state_keys: - state_keys.append(k) - stats.update(name, a) - - for k, arr in sample.action.items(): - a = np.asarray(arr) - single_action = len(sample.action) == 1 - name = _feature_name("action", k, is_image=False, single_action=single_action) - if name not in feature_shapes: - feature_shapes[name] = tuple(a.shape) - feature_dtypes[name] = str(a.dtype) - if k not in action_keys: - action_keys.append(k) - stats.update(name, a) - - # Video frame write + parquet row buffer. - frame_index = len(current_frames) - for k, arr in sample.observation.items(): - a = np.asarray(arr) - if a.ndim >= 3: - if k not in current_video_writers: - current_video_writers[k] = _open_video(k, current_episode_index, a) - # Frames are RGB; cv2.VideoWriter is BGR-native — convert or - # the MP4 decodes color-swapped. 
- bgr = cv2.cvtColor(a, cv2.COLOR_RGB2BGR) if a.shape[-1] == 3 else a - current_video_writers[k].write(bgr) - - rel_ts = frame_index / fps - current_frames.append( - { - "timestamp": rel_ts, - "frame_index": frame_index, - "episode_index": current_episode_index, - "index": global_index, - "task_index": tasks_index[default_task_label], - "obs": { - k: np.asarray(v) - for k, v in sample.observation.items() - if np.asarray(v).ndim < 3 - }, - "act": {k: np.asarray(v) for k, v in sample.action.items()}, - } - ) - global_index += 1 - - _flush_episode() - finally: - # If the drain raised mid-episode, release any writers still open so we - # don't leak file handles / leave half-written MP4s locked. - for vw in current_video_writers.values(): - vw.release() - - # ── meta files ─────────────────────────────────────────────────────────── - total_episodes = len(episodes_meta) + _flush_episode() + if data_writer is not None: + data_writer.close() + for vw in video_writers.values(): + vw.release() + + total_episodes = len(episode_rows) total_frames = global_index + if data_path.exists() and data_path.stat().st_size > DATA_FILE_SIZE_MB * 1e6: + logger.warning( + "[dataprep] data file exceeds %d MB (single-file writer, no rolling): %s", + DATA_FILE_SIZE_MB, + data_path, + ) + _write_meta( + root, + fps=fps, + total_episodes=total_episodes, + total_frames=total_frames, + feature_shapes=feature_shapes, + feature_dtypes=feature_dtypes, + image_keys=image_keys, + tasks_index=tasks_index, + episode_rows=episode_rows, + global_stats=global_stats, + robot=output.metadata.get("robot", "unknown"), + pa=pa, + pq=pq, + pd=pd, + ) + return root + + +def _write_meta( + root: Path, + *, + fps: float, + total_episodes: int, + total_frames: int, + feature_shapes: dict[str, tuple[int, ...]], + feature_dtypes: dict[str, str], + image_keys: list[str], + tasks_index: dict[str, int], + episode_rows: list[dict[str, Any]], + global_stats: StreamingStats, + robot: str, + pa: Any, + pq: Any, + pd: Any, +) 
-> None: + """Write info.json, tasks.parquet, episodes parquet, and aggregated stats.json.""" features: dict[str, Any] = {} for name, shape in feature_shapes.items(): if feature_dtypes[name] == "video": @@ -278,7 +373,6 @@ def _flush_episode() -> bool: }, } else: - # Per-dim names; downstream loaders only require len(names) == shape[0]. n = int(shape[0]) if shape else 0 base = name.split(".")[-1] features[name] = { @@ -296,45 +390,54 @@ def _flush_episode() -> bool: features[col] = {"dtype": dt, "shape": [1], "names": None} info = { - "codebase_version": "v2.0", - "robot_type": output.metadata.get("robot", "unknown"), + "codebase_version": "v3.0", + "robot_type": robot, "total_episodes": total_episodes, "total_frames": total_frames, "total_tasks": len(tasks_index), - "total_videos": total_episodes * len(image_keys), - "total_chunks": 1, - "chunks_size": max(1, total_episodes), + "chunks_size": CHUNKS_SIZE, + "data_files_size_in_mb": DATA_FILE_SIZE_MB, + "video_files_size_in_mb": VIDEO_FILE_SIZE_MB, "fps": fps, "splits": {"train": f"0:{total_episodes}"}, - "data_path": "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet", - "video_path": "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4", + "data_path": "data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet", + "video_path": "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4", "features": features, } with open(root / META_DIR / "info.json", "w") as f: json.dump(info, f, indent=2) - with open(root / META_DIR / "episodes.jsonl", "w") as f: - for ep in episodes_meta: - f.write(json.dumps(ep) + "\n") - with open(root / META_DIR / "tasks.jsonl", "w") as f: - for task, idx in tasks_index.items(): - f.write(json.dumps({"task_index": idx, "task": task}) + "\n") - final_stats = stats.finalize() + + # tasks.parquet — task strings as the (named) index + a task_index column. 
+ tasks_df = pd.DataFrame( + {"task_index": list(tasks_index.values())}, + index=pd.Index(list(tasks_index.keys()), name="task"), + ) + tasks_df.to_parquet(root / META_DIR / "tasks.parquet") + + # episodes parquet — one row per episode (+ flattened per-episode stats). + ep_table = pa.Table.from_pylist(episode_rows) + pq.write_table( + ep_table, root / META_DIR / EPISODES_DIR / CHUNK / f"{FILE}.parquet", compression="snappy" + ) + + # Aggregated stats.json (image features nested to (C,1,1)). + final_stats = global_stats.finalize() for name, entry in final_stats.items(): if feature_dtypes.get(name) == "video": for k in ("mean", "std", "min", "max"): if entry.get(k) is not None: - entry[k] = [[[c]] for c in entry[k]] + entry[k] = _nest_image_stat(entry[k]) with open(root / META_DIR / "stats.json", "w") as f: json.dump(final_stats, f, indent=2) - return root - _META_COLS = {"timestamp", "frame_index", "episode_index", "index", "task_index"} def inspect(path: Path) -> dict[str, Any]: - """Summarize a LeRobot v2 dataset from its meta/ files (no parquet load).""" + """Summarize a LeRobot v3.0 dataset from meta/ (info.json + episodes parquet).""" + import pyarrow.parquet as pq + from dimos.learning.dataprep.core import summarize_lengths root = Path(path) @@ -353,14 +456,13 @@ def inspect(path: Path) -> dict[str, Any]: action[name] = entry lengths: list[int] = [] - ep_path = root / META_DIR / "episodes.jsonl" - if ep_path.exists(): - for line in ep_path.read_text().splitlines(): - if line.strip(): - lengths.append(int(json.loads(line).get("length", 0))) + ep_file = root / META_DIR / EPISODES_DIR / CHUNK / f"{FILE}.parquet" + if ep_file.exists(): + lengths = pq.read_table(ep_file, columns=["length"]).column("length").to_pylist() return { "format": "lerobot", + "version": info.get("codebase_version"), "path": str(root), "episodes": info.get("total_episodes"), "frames": info.get("total_frames"), diff --git a/dimos/learning/dataprep/formats/test_lerobot.py 
b/dimos/learning/dataprep/formats/test_lerobot.py index cba373c8cb..d9cb72daec 100644 --- a/dimos/learning/dataprep/formats/test_lerobot.py +++ b/dimos/learning/dataprep/formats/test_lerobot.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Smoke tests for the LeRobot v2 writer/reader. +"""Smoke tests for the LeRobot v3.0 writer/reader. -The state-only test always runs (parquet + meta + stats, exercises the -try/finally cleanup with no writers open). The image test additionally checks -the MP4 path + canonical `observation.images.*` naming, and skips if no mp4v -codec is available in the environment. Skips entirely if pyarrow isn't -installed (`learning` optional-dependency group). +Asserts the v3.0 layout: a single concatenated data parquet, parquet meta +(tasks + episodes, no jsonl), and one MP4 per camera under +`videos/{video_key}/chunk-000/`. The image test skips if no mp4v codec is available; +the whole module skips if pyarrow/pandas aren't installed (`learning` extra). 
""" from __future__ import annotations @@ -31,6 +30,7 @@ import pytest pytest.importorskip("pyarrow") +pytest.importorskip("pandas") cv2 = pytest.importorskip("cv2") from dimos.learning.dataprep.core import OutputConfig, Sample @@ -47,6 +47,17 @@ def _state_samples(n: int = 4) -> Iterator[Sample]: ) +def _two_episode_samples() -> Iterator[Sample]: + for ep in range(2): + for i in range(3): + yield Sample( + ts=float(ep * 3 + i), + episode_id=f"ep_{ep:06d}", + observation={"state": np.arange(6, dtype=np.float32) + ep}, + action={"action": np.full(6, float(i), dtype=np.float32)}, + ) + + def _image_samples(n: int = 4) -> Iterator[Sample]: for i in range(n): yield Sample( @@ -60,32 +71,66 @@ def _image_samples(n: int = 4) -> Iterator[Sample]: ) -def test_lerobot_state_only_layout_and_naming(tmp_path: Path) -> None: +def test_lerobot_v3_state_only_layout_and_naming(tmp_path: Path) -> None: out = OutputConfig( format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0, "robot": "xarm7"} ) root = write(_state_samples(), out) + # v3.0: concatenated single data file + parquet meta (no jsonl, no per-episode parquet) assert (root / "meta" / "info.json").exists() - assert (root / "meta" / "episodes.jsonl").exists() - assert (root / "meta" / "tasks.jsonl").exists() + assert (root / "meta" / "tasks.parquet").exists() assert (root / "meta" / "stats.json").exists() - assert (root / "data" / "chunk-000" / "episode_000000.parquet").exists() + assert (root / "meta" / "episodes" / "chunk-000" / "file-000.parquet").exists() + assert (root / "data" / "chunk-000" / "file-000.parquet").exists() + assert not (root / "meta" / "episodes.jsonl").exists() + assert not (root / "meta" / "tasks.jsonl").exists() info = json.loads((root / "meta" / "info.json").read_text()) + assert info["codebase_version"] == "v3.0" assert info["total_episodes"] == 1 assert info["total_frames"] == 4 assert info["fps"] == 10.0 + assert info["data_path"] == 
"data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet" # single low-dim state + single action → canonical names assert "observation.state" in info["features"] assert "action" in info["features"] -def test_lerobot_inspect_state_only(tmp_path: Path) -> None: +def test_lerobot_v3_episode_metadata_columns(tmp_path: Path) -> None: + import pyarrow.parquet as pq + + out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) + # two episodes so dataset_from/to_index advance + root = write(_two_episode_samples(), out) + ep = pq.read_table(root / "meta" / "episodes" / "chunk-000" / "file-000.parquet") + cols = set(ep.column_names) + for required in ( + "episode_index", + "tasks", + "length", + "dataset_from_index", + "dataset_to_index", + "data/chunk_index", + "data/file_index", + "meta/episodes/chunk_index", + "meta/episodes/file_index", + ): + assert required in cols, f"missing episode column {required}" + # per-episode stats are embedded (flattened) + assert any(c.startswith("stats/observation.state/") for c in cols) + rows = ep.to_pylist() + assert [r["episode_index"] for r in rows] == [0, 1] + assert rows[0]["dataset_from_index"] == 0 and rows[0]["dataset_to_index"] == 3 + assert rows[1]["dataset_from_index"] == 3 and rows[1]["dataset_to_index"] == 6 + + +def test_lerobot_v3_inspect_state_only(tmp_path: Path) -> None: out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) root = write(_state_samples(), out) info = inspect(root) assert info["format"] == "lerobot" + assert info["version"] == "v3.0" assert info["episodes"] == 1 assert info["frames"] == 4 assert "observation.state" in info["observation"] @@ -93,7 +138,7 @@ def test_lerobot_inspect_state_only(tmp_path: Path) -> None: assert info["has_stats"] is True -def test_lerobot_with_images_writes_mp4_and_video_feature(tmp_path: Path) -> None: +def test_lerobot_v3_with_images_writes_concatenated_mp4(tmp_path: Path) -> None: out = OutputConfig(format="lerobot", 
path=tmp_path / "ds", metadata={"fps": 10.0}) try: root = write(_image_samples(), out) @@ -102,10 +147,14 @@ def test_lerobot_with_images_writes_mp4_and_video_feature(tmp_path: Path) -> Non pytest.skip(f"no mp4v encoder available in this environment: {e}") raise - mp4 = root / "videos" / "chunk-000" / "observation.images.cam" / "episode_000000.mp4" + # v3.0 video path: videos/{video_key}/chunk-000/file-000.mp4 (key before chunk, one per camera) + mp4 = root / "videos" / "observation.images.cam" / "chunk-000" / "file-000.mp4" assert mp4.exists() and mp4.stat().st_size > 0 info = json.loads((root / "meta" / "info.json").read_text()) + assert ( + info["video_path"] == "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4" + ) assert info["features"]["observation.images.cam"]["dtype"] == "video" # image column is excluded from parquet; state/action remain assert "observation.state" in info["features"] diff --git a/pyproject.toml b/pyproject.toml index ea76547b8c..6bc535d820 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -201,7 +201,8 @@ visualization = [ learning = [ # dimos.learning.dataprep dataset writers (lazy-imported per format) - "pyarrow", # LeRobot v2 parquet writer + "pyarrow", # LeRobot v3.0 data/episodes parquet + "pandas", # LeRobot v3.0 tasks.parquet (task-indexed) "h5py", # HDF5 writer ] From 6cf9c7e925b51472faeef86bcbe807b5047e6610 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jun 2026 21:52:30 +0000 Subject: [PATCH 41/45] [autofix.ci] apply automated fixes --- uv.lock | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/uv.lock b/uv.lock index 7870f34e43..3f1b7149cd 100644 --- a/uv.lock +++ b/uv.lock @@ -2042,6 +2042,8 @@ drone = [ ] learning = [ { name = "h5py" }, + { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "pandas", version = "3.0.0", 
source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pyarrow" }, ] manipulation = [ @@ -2407,6 +2409,7 @@ requires-dist = [ { name = "openai", marker = "extra == 'agents'" }, { name = "opencv-contrib-python", marker = "extra == 'apriltag'", specifier = "==4.10.0.84" }, { name = "opencv-python" }, + { name = "pandas", marker = "extra == 'learning'" }, { name = "pillow", marker = "extra == 'perception'" }, { name = "pin", specifier = ">=3.3.0" }, { name = "pin-pink", marker = "extra == 'manipulation'", specifier = ">=4.2.0" }, @@ -7148,16 +7151,17 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'darwin'", "python_full_version < '3.11' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'win32'", "python_full_version < '3.11' and platform_machine != 'x86_64' and sys_platform == 'win32'", "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform != 'darwin' and sys_platform != 'win32'", "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'linux') or (python_full_version < '3.11' and platform_machine != 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", ] dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and sys_platform != 'linux')" }, - { name = "python-dateutil", marker = "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and sys_platform != 'linux')" }, - { name = "pytz", marker 
= "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and sys_platform != 'linux')" }, - { name = "tzdata", marker = "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and sys_platform != 'linux')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "python-dateutil", marker = "python_full_version < '3.11'" }, + { name = "pytz", marker = "python_full_version < '3.11'" }, + { name = "tzdata", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/33/01/d40b85317f86cf08d853a4f495195c73815fdf205eef3993821720274518/pandas-2.3.3.tar.gz", hash = "sha256:e05e1af93b977f7eafa636d043f9f94c7ee3ac81af99c13508215942e64c993b", size = 4495223, upload-time = "2025-09-29T23:34:51.853Z" } wheels = [ @@ -7221,6 +7225,9 @@ resolution-markers = [ "python_full_version == '3.13.*' and platform_machine != 'x86_64' and sys_platform == 'darwin'", "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", "python_full_version == '3.12.*' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'win32'", "python_full_version >= '3.14' and platform_machine != 'x86_64' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'win32'", @@ -7235,14 +7242,15 @@ resolution-markers = [ "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and 
sys_platform == 'linux') or (python_full_version == '3.12.*' and platform_machine != 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and platform_machine != 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform == 'win32'", "python_full_version == '3.11.*' and platform_machine != 'x86_64' and sys_platform == 'win32'", "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform != 'darwin' and sys_platform != 'win32'", "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and platform_machine != 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", ] dependencies = [ - { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and platform_machine != 'aarch64') or (python_full_version >= '3.11' and sys_platform != 'linux')" }, - { name = "python-dateutil", marker = "(python_full_version >= '3.11' and platform_machine != 'aarch64') or (python_full_version >= '3.11' and sys_platform != 'linux')" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "python-dateutil", marker = "python_full_version >= '3.11'" }, { name = "tzdata", marker = "(python_full_version >= '3.11' and sys_platform == 'emscripten') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/de/da/b1dc0481ab8d55d0f46e343cfe67d4551a0e14fcee52bd38ca1bd73258d8/pandas-3.0.0.tar.gz", hash = "sha256:0facf7e87d38f721f0af46fe70d97373a37701b1c09f7ed7aeeb292ade5c050f", size = 4633005, upload-time = "2026-01-21T15:52:04.726Z" } From 01c1a57d7188cc7c38aa34571b96793b0eddfdc4 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Fri, 19 Jun 2026 15:51:49 -0700 Subject: [PATCH 42/45] fix: greptile issues --- dimos/learning/dataprep/formats/lerobot.py | 147 +++++++++--------- .../learning/dataprep/formats/test_lerobot.py | 32 ++++ 2 files changed, 108 insertions(+), 71 deletions(-) diff --git a/dimos/learning/dataprep/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py index 73a21996d3..c4ea4d5243 100644 --- a/dimos/learning/dataprep/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -236,77 +236,82 @@ def _flush_episode() -> None: episode_rows.append(row) cur_rows.clear() - for sample in samples: - if sample.episode_id != cur_id: - _flush_episode() - cur_id = sample.episode_id - episode_index += 1 - cur_ep_stats = _stats() - if default_task_label not in tasks_index: - tasks_index[default_task_label] = len(tasks_index) - - # Schema discovery + stats (global + per-episode). 
- n_low_dim_obs = sum(1 for v in sample.observation.values() if np.asarray(v).ndim < 3) - single_state = n_low_dim_obs == 1 - for k, arr in sample.observation.items(): - a = np.asarray(arr) - is_image = a.ndim >= 3 - name = _feature_name("observation", k, is_image, False, single_state=single_state) - if name not in feature_shapes: - feature_shapes[name] = tuple(a.shape) - feature_dtypes[name] = "video" if is_image else str(a.dtype) - if is_image: - if k not in image_keys: - image_keys.append(k) - elif k not in state_keys: - state_keys.append(k) - global_stats.update(name, a) - cur_ep_stats.update(name, a) - single_action = len(sample.action) == 1 - for k, arr in sample.action.items(): - a = np.asarray(arr) - name = _feature_name("action", k, is_image=False, single_action=single_action) - if name not in feature_shapes: - feature_shapes[name] = tuple(a.shape) - feature_dtypes[name] = str(a.dtype) - if k not in action_keys: - action_keys.append(k) - global_stats.update(name, a) - cur_ep_stats.update(name, a) - - # Append image frames to the per-camera MP4 (RGB→BGR; cv2 is BGR-native). 
- for k, arr in sample.observation.items(): - a = np.asarray(arr) - if a.ndim >= 3: - if k not in video_writers: - video_writers[k] = _open_video(k, a) - bgr = cv2.cvtColor(a, cv2.COLOR_RGB2BGR) if a.shape[-1] == 3 else a - video_writers[k].write(bgr) - video_cum_frames[k] = video_cum_frames.get(k, 0) + 1 - - frame_index = len(cur_rows) - cur_rows.append( - { - "timestamp": frame_index / fps, # relative to this episode - "frame_index": frame_index, - "episode_index": episode_index, - "index": global_index, - "task_index": tasks_index[default_task_label], - "obs": { - k: np.asarray(v) - for k, v in sample.observation.items() - if np.asarray(v).ndim < 3 - }, - "act": {k: np.asarray(v) for k, v in sample.action.items()}, - } - ) - global_index += 1 - - _flush_episode() - if data_writer is not None: - data_writer.close() - for vw in video_writers.values(): - vw.release() + # try/finally so the parquet footer is written and MP4s are released even if + # the drain raises mid-stream — otherwise the data file is unreadable (no + # footer) and the videos lose their index. + try: + for sample in samples: + if sample.episode_id != cur_id: + _flush_episode() + cur_id = sample.episode_id + episode_index += 1 + cur_ep_stats = _stats() + if default_task_label not in tasks_index: + tasks_index[default_task_label] = len(tasks_index) + + # Schema discovery + stats (global + per-episode). 
+ n_low_dim_obs = sum(1 for v in sample.observation.values() if np.asarray(v).ndim < 3) + single_state = n_low_dim_obs == 1 + for k, arr in sample.observation.items(): + a = np.asarray(arr) + is_image = a.ndim >= 3 + name = _feature_name("observation", k, is_image, False, single_state=single_state) + if name not in feature_shapes: + feature_shapes[name] = tuple(a.shape) + feature_dtypes[name] = "video" if is_image else str(a.dtype) + if is_image: + if k not in image_keys: + image_keys.append(k) + elif k not in state_keys: + state_keys.append(k) + global_stats.update(name, a) + cur_ep_stats.update(name, a) + single_action = len(sample.action) == 1 + for k, arr in sample.action.items(): + a = np.asarray(arr) + name = _feature_name("action", k, is_image=False, single_action=single_action) + if name not in feature_shapes: + feature_shapes[name] = tuple(a.shape) + feature_dtypes[name] = str(a.dtype) + if k not in action_keys: + action_keys.append(k) + global_stats.update(name, a) + cur_ep_stats.update(name, a) + + # Append image frames to the per-camera MP4 (RGB→BGR; cv2 is BGR-native). 
+ for k, arr in sample.observation.items(): + a = np.asarray(arr) + if a.ndim >= 3: + if k not in video_writers: + video_writers[k] = _open_video(k, a) + bgr = cv2.cvtColor(a, cv2.COLOR_RGB2BGR) if a.shape[-1] == 3 else a + video_writers[k].write(bgr) + video_cum_frames[k] = video_cum_frames.get(k, 0) + 1 + + frame_index = len(cur_rows) + cur_rows.append( + { + "timestamp": frame_index / fps, # relative to this episode + "frame_index": frame_index, + "episode_index": episode_index, + "index": global_index, + "task_index": tasks_index[default_task_label], + "obs": { + k: np.asarray(v) + for k, v in sample.observation.items() + if np.asarray(v).ndim < 3 + }, + "act": {k: np.asarray(v) for k, v in sample.action.items()}, + } + ) + global_index += 1 + + _flush_episode() + finally: + if data_writer is not None: + data_writer.close() + for vw in video_writers.values(): + vw.release() total_episodes = len(episode_rows) total_frames = global_index diff --git a/dimos/learning/dataprep/formats/test_lerobot.py b/dimos/learning/dataprep/formats/test_lerobot.py index d9cb72daec..7fcca06e64 100644 --- a/dimos/learning/dataprep/formats/test_lerobot.py +++ b/dimos/learning/dataprep/formats/test_lerobot.py @@ -125,6 +125,38 @@ def test_lerobot_v3_episode_metadata_columns(tmp_path: Path) -> None: assert rows[1]["dataset_from_index"] == 3 and rows[1]["dataset_to_index"] == 6 +def test_lerobot_v3_writer_closed_on_midstream_error(tmp_path: Path) -> None: + """If the drain raises after an episode was flushed, the data parquet must + still be readable (footer written by the finally), not a headerless stub.""" + import pyarrow.parquet as pq + + def bad_samples() -> Iterator[Sample]: + for i in range(3): # episode 0 + yield Sample( + ts=float(i), + episode_id="ep_000000", + observation={"state": np.arange(6, dtype=np.float32)}, + action={"action": np.full(6, float(i), dtype=np.float32)}, + ) + # first frame of episode 1 flushes episode 0 (opens + writes the parquet)… + yield Sample( + 
ts=3.0, + episode_id="ep_000001", + observation={"state": np.arange(6, dtype=np.float32)}, + action={"action": np.zeros(6, dtype=np.float32)}, + ) + raise RuntimeError("boom mid-stream") # …then blow up before the final flush + + out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) + with pytest.raises(RuntimeError, match="boom"): + write(bad_samples(), out) + + # episode 0's 3 frames were flushed; the file must have a valid footer. + data = tmp_path / "ds" / "data" / "chunk-000" / "file-000.parquet" + assert data.exists() + assert pq.read_table(data).num_rows == 3 # raises ArrowInvalid if footer missing + + def test_lerobot_v3_inspect_state_only(tmp_path: Path) -> None: out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) root = write(_state_samples(), out) From d1f891669aef487b2bc136b62ca9436abee4243e Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Fri, 19 Jun 2026 20:15:17 -0700 Subject: [PATCH 43/45] =?UTF-8?q?fix:=20address=20greptile=20review=20?= =?UTF-8?q?=E2=80=94=20writer=20resource=20guard=20+=20per-episode=20task?= =?UTF-8?q?=20labels?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dimos/learning/dataprep/core.py | 10 +++++- dimos/learning/dataprep/formats/hdf5.py | 8 +++-- dimos/learning/dataprep/formats/lerobot.py | 9 +++-- .../learning/dataprep/formats/test_lerobot.py | 33 +++++++++++++++++++ 4 files changed, 53 insertions(+), 7 deletions(-) diff --git a/dimos/learning/dataprep/core.py b/dimos/learning/dataprep/core.py index 202fd0d706..add219b4ba 100644 --- a/dimos/learning/dataprep/core.py +++ b/dimos/learning/dataprep/core.py @@ -117,6 +117,7 @@ class Sample(BaseModel): episode_id: str observation: dict[str, np.ndarray] action: dict[str, np.ndarray] + task_label: str | None = None # carried from the episode for multi-task datasets # ───────────────────────────────────────────────────────────────────────────── @@ -320,7 +321,13 @@ def 
_build_frames() -> Iterator[Sample]: obs_dict[key] = arr if skip: continue - yield Sample(ts=t, episode_id=episode.id, observation=obs_dict, action=act_dict) + yield Sample( + ts=t, + episode_id=episode.id, + observation=obs_dict, + action=act_dict, + task_label=episode.task_label, + ) shift = max(0, sync.action_shift) if shift == 0 or not action_keys: @@ -337,6 +344,7 @@ def _build_frames() -> Iterator[Sample]: episode_id=cur.episode_id, observation=cur.observation, action=nxt.action, + task_label=cur.task_label, ) diff --git a/dimos/learning/dataprep/formats/hdf5.py b/dimos/learning/dataprep/formats/hdf5.py index 9b31b52a09..ba10924669 100644 --- a/dimos/learning/dataprep/formats/hdf5.py +++ b/dimos/learning/dataprep/formats/hdf5.py @@ -69,6 +69,7 @@ def write(samples: Iterator[Sample], output: OutputConfig) -> Path: # Per-episode buffers — flushed at episode boundary. cur_id: str | None = None cur_idx = 0 + cur_task = default_task_label # actual label for the in-progress episode cur_start_ts: float | None = None buf_ts: list[float] = [] buf_obs: dict[str, list[np.ndarray]] = {} @@ -85,7 +86,7 @@ def _flush() -> bool: ep = episodes_g.create_group(f"episode_{cur_idx:06d}") ep.attrs["length"] = len(buf_ts) ep.attrs["start_ts"] = float(cur_start_ts or 0.0) - ep.attrs["task_index"] = tasks_index[default_task_label] + ep.attrs["task_index"] = tasks_index[cur_task] ep.create_dataset("timestamp", data=np.asarray(buf_ts, dtype=np.float32)) for k, frames in buf_obs.items(): arr = np.stack(frames, axis=0) @@ -108,8 +109,9 @@ def _flush() -> bool: cur_idx += 1 cur_id = sample.episode_id cur_start_ts = float(sample.ts) - if default_task_label not in tasks_index: - tasks_index[default_task_label] = len(tasks_index) + cur_task = sample.task_label or default_task_label + if cur_task not in tasks_index: + tasks_index[cur_task] = len(tasks_index) buf_ts.append(float(sample.ts) - (cur_start_ts or 0.0)) for k, v in sample.observation.items(): diff --git 
a/dimos/learning/dataprep/formats/lerobot.py b/dimos/learning/dataprep/formats/lerobot.py index c4ea4d5243..db9d4eeb50 100644 --- a/dimos/learning/dataprep/formats/lerobot.py +++ b/dimos/learning/dataprep/formats/lerobot.py @@ -171,6 +171,7 @@ def _stats() -> StreamingStats: cur_id: str | None = None cur_rows: list[dict[str, Any]] = [] cur_ep_stats = _stats() + cur_task = default_task_label # actual label for the in-progress episode def _video_path(image_key: str) -> Path: feat = _feature_name("observation", image_key, is_image=True, single_action=False) @@ -246,8 +247,10 @@ def _flush_episode() -> None: cur_id = sample.episode_id episode_index += 1 cur_ep_stats = _stats() - if default_task_label not in tasks_index: - tasks_index[default_task_label] = len(tasks_index) + # Per-episode task label (falls back to the config default). + cur_task = sample.task_label or default_task_label + if cur_task not in tasks_index: + tasks_index[cur_task] = len(tasks_index) # Schema discovery + stats (global + per-episode). 
n_low_dim_obs = sum(1 for v in sample.observation.values() if np.asarray(v).ndim < 3) @@ -295,7 +298,7 @@ def _flush_episode() -> None: "frame_index": frame_index, "episode_index": episode_index, "index": global_index, - "task_index": tasks_index[default_task_label], + "task_index": tasks_index[cur_task], "obs": { k: np.asarray(v) for k, v in sample.observation.items() diff --git a/dimos/learning/dataprep/formats/test_lerobot.py b/dimos/learning/dataprep/formats/test_lerobot.py index 7fcca06e64..c6c2aef39f 100644 --- a/dimos/learning/dataprep/formats/test_lerobot.py +++ b/dimos/learning/dataprep/formats/test_lerobot.py @@ -157,6 +157,39 @@ def bad_samples() -> Iterator[Sample]: assert pq.read_table(data).num_rows == 3 # raises ArrowInvalid if footer missing +def test_lerobot_v3_per_episode_task_labels(tmp_path: Path) -> None: + """Episodes with distinct task_labels must produce distinct tasks + task_index + (multi-task recordings must not collapse to one task).""" + import pandas as pd + import pyarrow.parquet as pq + + def samples() -> Iterator[Sample]: + for ep, task in ((0, "pick"), (1, "place")): + for i in range(3): + yield Sample( + ts=float(ep * 3 + i), + episode_id=f"ep_{ep:06d}", + observation={"state": np.arange(6, dtype=np.float32)}, + action={"action": np.zeros(6, dtype=np.float32)}, + task_label=task, + ) + + out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) + root = write(samples(), out) + + tasks = pd.read_parquet(root / "meta" / "tasks.parquet") + assert set(tasks.index) == {"pick", "place"} + + ep = pq.read_table(root / "meta" / "episodes" / "chunk-000" / "file-000.parquet").to_pylist() + assert ep[0]["tasks"] == ["pick"] + assert ep[1]["tasks"] == ["place"] + + data = pq.read_table(root / "data" / "chunk-000" / "file-000.parquet") + ti = data.column("task_index").to_pylist() + assert ti[:3] == [0, 0, 0] # episode 0 → task 0 (pick) + assert ti[3:] == [1, 1, 1] # episode 1 → task 1 (place) + + def 
test_lerobot_v3_inspect_state_only(tmp_path: Path) -> None: out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) root = write(_state_samples(), out) From d5427675f72bc6ba5a863136190ccea03725b4ff Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Sat, 20 Jun 2026 00:18:27 -0700 Subject: [PATCH 44/45] fix(dataprep): reject shared obs/action feature keys instead of silently dropping the obs --- dimos/learning/dataprep/build.py | 9 ++++++++- dimos/learning/dataprep/test_core.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/dimos/learning/dataprep/build.py b/dimos/learning/dataprep/build.py index b673b3dd9b..5fc9f7620f 100644 --- a/dimos/learning/dataprep/build.py +++ b/dimos/learning/dataprep/build.py @@ -80,6 +80,13 @@ def run_dataprep(config: DataPrepConfig) -> Path: """ from dimos.memory2.store.sqlite import SqliteStore + shared = set(config.observation) & set(config.action) + if shared: + raise ValueError( + f"observation and action share feature name(s) {sorted(shared)}; " + f"give each a distinct key (they may still map to the same stream)." + ) + logger.info( "[dataprep] starting build source=%s extractor=%s output=%s", config.source, @@ -110,9 +117,9 @@ def run_dataprep(config: DataPrepConfig) -> Path: f"extractor='ranges' with explicit (start, end) tuples." 
) - streams = {**config.observation, **config.action} obs_keys = set(config.observation) action_keys = set(config.action) + streams = {**config.observation, **config.action} logger.info( "[dataprep] obs streams=%s action streams=%s sync=%s", sorted(obs_keys), diff --git a/dimos/learning/dataprep/test_core.py b/dimos/learning/dataprep/test_core.py index c522529428..a6c4b9963f 100644 --- a/dimos/learning/dataprep/test_core.py +++ b/dimos/learning/dataprep/test_core.py @@ -333,3 +333,18 @@ def test_dimos_meta_beside_file_for_hdf5(tmp_path: Path) -> None: sidecar = tmp_path / "session.dimos_meta.json" assert sidecar.exists() # beside the file, not session.hdf5/dimos_meta.json assert json.loads(sidecar.read_text())["format"] == "hdf5" + + +def test_run_dataprep_rejects_shared_obs_action_key() -> None: + """A name in both obs and action would silently drop the obs feature when the + two maps merge; run_dataprep must reject it before opening the store.""" + from dimos.learning.dataprep.build import run_dataprep + from dimos.learning.dataprep.core import DataPrepConfig, StreamField + + cfg = DataPrepConfig( + source="nonexistent.db", # never reached — the check runs first + observation={"joints": StreamField(stream="joint_state", field="position")}, + action={"joints": StreamField(stream="joint_state", field="position")}, + ) + with pytest.raises(ValueError, match="share feature name"): + run_dataprep(cfg) From 132ac71332bcfc2291045244e590ef1b96002931 Mon Sep 17 00:00:00 2001 From: ruthwikdasyam Date: Sat, 20 Jun 2026 17:30:33 -0700 Subject: [PATCH 45/45] test(learning): drop __new__ shell for mocker-patched construction; hoist inner imports --- .../collection/test_episode_monitor.py | 129 +++++++++++------- .../learning/dataprep/formats/test_lerobot.py | 9 +- dimos/learning/dataprep/test_core.py | 17 +-- 3 files changed, 85 insertions(+), 70 deletions(-) diff --git a/dimos/learning/collection/test_episode_monitor.py b/dimos/learning/collection/test_episode_monitor.py 
index 328bb04c75..b0a7307714 100644 --- a/dimos/learning/collection/test_episode_monitor.py +++ b/dimos/learning/collection/test_episode_monitor.py @@ -14,48 +14,62 @@ """Unit tests for the EpisodeMonitor state machine. -Drives the button/keyboard handlers directly and captures published -EpisodeStatus events via a stubbed `status` Out port. The module is built with -`object.__new__` + the subclass' own state init so the test exercises just the -state machine, not the RPC/transport machinery a full `Module()` boots up. -Mirrors the offline `extract_episodes` state machine these events feed. +The module is constructed normally; only its boot side effects (the asyncio +loop + RPC transport that `Module.__init__` starts) are patched out, and its +`status` Out port is replaced with a mock so published EpisodeStatus events can +be inspected. Drives the button/keyboard handlers directly and asserts on the +state machine these events feed into `extract_episodes`. """ from __future__ import annotations -import threading +from collections.abc import Callable, Iterator +import pytest +import pytest_mock + +from dimos.core.module import LCMRPC from dimos.learning.collection.episode_monitor import ( EpisodeMonitorModule, - EpisodeMonitorModuleConfig, EpisodeStatus, KeyPress, ) from dimos.teleop.quest.quest_types import BUTTON_ALIASES, Buttons -class FakeStatusOut: - """Stand-in for the `status` Out port that records published events.""" +@pytest.fixture +def make_monitor( + mocker: pytest_mock.MockerFixture, +) -> Iterator[Callable[..., EpisodeMonitorModule]]: + """Factory for an EpisodeMonitorModule with its boot patched out. + + `Module.__init__` starts an asyncio loop + RPC transport; patch both so the + test exercises only the state machine. The `status` port is a mock whose + `publish` calls record the emitted EpisodeStatus. Every built module is + stopped on teardown. 
+ """ + mocker.patch("dimos.core.module.get_loop", return_value=(mocker.MagicMock(), None)) + mocker.patch.object(LCMRPC, "__init__", return_value=None) + mocker.patch.object(LCMRPC, "serve_module_rpc", return_value=None) + mocker.patch.object(LCMRPC, "start", return_value=None) + mocker.patch.object(LCMRPC, "stop", return_value=None) + + built: list[EpisodeMonitorModule] = [] - def __init__(self) -> None: - self.events: list[EpisodeStatus] = [] + def _make(**config: object) -> EpisodeMonitorModule: + m = EpisodeMonitorModule(**config) + m.status = mocker.MagicMock() # type: ignore[assignment] + built.append(m) + return m - def publish(self, status: EpisodeStatus) -> None: - self.events.append(status) + yield _make + for m in built: + m.stop() -def _monitor(**config: object) -> tuple[EpisodeMonitorModule, FakeStatusOut]: - m = EpisodeMonitorModule.__new__(EpisodeMonitorModule) - m.config = EpisodeMonitorModuleConfig(**config) # type: ignore[assignment] - m._state = "idle" - m._saved = 0 - m._discarded = 0 - m._last_event = "init" - m._prev_bits = {} - m._lock = threading.Lock() - out = FakeStatusOut() - m.status = out # type: ignore[assignment] - return m, out +def _events(monitor: EpisodeMonitorModule) -> list[EpisodeStatus]: + """The EpisodeStatus objects published on the monitor's `status` port.""" + return [call.args[0] for call in monitor.status.publish.call_args_list] # type: ignore[attr-defined] def _press(monitor: EpisodeMonitorModule, alias: str, ts: float) -> None: @@ -68,53 +82,63 @@ def _press(monitor: EpisodeMonitorModule, alias: str, ts: float) -> None: monitor._on_buttons(pressed) -def test_toggle_starts_then_saves() -> None: - m, out = _monitor() # default map: toggle=B, discard=Y +def test_toggle_starts_then_saves(make_monitor: Callable[..., EpisodeMonitorModule]) -> None: + m = make_monitor() # default map: toggle=B, discard=Y _press(m, "B", ts=1.0) # idle → recording _press(m, "B", ts=2.0) # recording → idle (saved) - events = [e.last_event for e 
in out.events] - assert events == ["start", "save"] - assert out.events[-1].state == "idle" - assert out.events[-1].episodes_saved == 1 - assert out.events[-1].episodes_discarded == 0 + events = _events(m) + assert [e.last_event for e in events] == ["start", "save"] + assert events[-1].state == "idle" + assert events[-1].episodes_saved == 1 + assert events[-1].episodes_discarded == 0 -def test_discard_does_not_count_as_saved() -> None: - m, out = _monitor() +def test_discard_does_not_count_as_saved( + make_monitor: Callable[..., EpisodeMonitorModule], +) -> None: + m = make_monitor() _press(m, "B", ts=1.0) # start _press(m, "Y", ts=2.0) # discard - assert out.events[-1].state == "idle" - assert out.events[-1].episodes_saved == 0 - assert out.events[-1].episodes_discarded == 1 + last = _events(m)[-1] + assert last.state == "idle" + assert last.episodes_saved == 0 + assert last.episodes_discarded == 1 -def test_start_while_recording_autocommits_previous() -> None: +def test_start_while_recording_autocommits_previous( + make_monitor: Callable[..., EpisodeMonitorModule], +) -> None: # toggle (start), then an explicit start via keyboard while still recording: # the in-progress episode auto-commits (matches the offline extractor). 
- m, out = _monitor(keyboard_map={"start": "r"}) + m = make_monitor(keyboard_map={"start": "r"}) _press(m, "B", ts=1.0) # recording m._on_keyboard(KeyPress(key="r", ts=2.0)) # start again → auto-commit prior - assert out.events[-1].last_event == "start" - assert out.events[-1].state == "recording" - assert out.events[-1].episodes_saved == 1 # the auto-committed one + last = _events(m)[-1] + assert last.last_event == "start" + assert last.state == "recording" + assert last.episodes_saved == 1 # the auto-committed one -def test_no_event_without_rising_edge() -> None: - m, out = _monitor() +def test_no_event_without_rising_edge( + make_monitor: Callable[..., EpisodeMonitorModule], +) -> None: + m = make_monitor() pressed = Buttons() pressed.right_secondary = True # B held m._on_buttons(pressed) m._on_buttons(pressed) # still held — no new edge - assert [e.last_event for e in out.events] == ["start"] + assert [e.last_event for e in _events(m)] == ["start"] -def test_published_status_is_internally_consistent() -> None: +def test_published_status_is_internally_consistent( + make_monitor: Callable[..., EpisodeMonitorModule], +) -> None: # Every published event's counters/state must match the event it carries — # the snapshot is taken under the same lock as the mutation. 
- m, out = _monitor() + m = make_monitor() _press(m, "B", 1.0) # start _press(m, "B", 2.0) # save (1) _press(m, "B", 3.0) # start @@ -122,17 +146,18 @@ def test_published_status_is_internally_consistent() -> None: _press(m, "B", 5.0) # start _press(m, "Y", 6.0) # discard (1) - for e in out.events: + events = _events(m) + for e in events: if e.last_event == "start": assert e.state == "recording" elif e.last_event in ("save", "discard"): assert e.state == "idle" - assert out.events[-1].episodes_saved == 2 - assert out.events[-1].episodes_discarded == 1 + assert events[-1].episodes_saved == 2 + assert events[-1].episodes_discarded == 1 -def test_reset_counters() -> None: - m, out = _monitor() +def test_reset_counters(make_monitor: Callable[..., EpisodeMonitorModule]) -> None: + m = make_monitor() _press(m, "B", 1.0) _press(m, "B", 2.0) status = m.reset_counters() diff --git a/dimos/learning/dataprep/formats/test_lerobot.py b/dimos/learning/dataprep/formats/test_lerobot.py index c6c2aef39f..c01f6bebcb 100644 --- a/dimos/learning/dataprep/formats/test_lerobot.py +++ b/dimos/learning/dataprep/formats/test_lerobot.py @@ -33,6 +33,10 @@ pytest.importorskip("pandas") cv2 = pytest.importorskip("cv2") +# Below the importorskip guards above; used to read back v3.0 meta/parquet. 
+import pandas as pd +import pyarrow.parquet as pq + from dimos.learning.dataprep.core import OutputConfig, Sample from dimos.learning.dataprep.formats.lerobot import inspect, write @@ -98,8 +102,6 @@ def test_lerobot_v3_state_only_layout_and_naming(tmp_path: Path) -> None: def test_lerobot_v3_episode_metadata_columns(tmp_path: Path) -> None: - import pyarrow.parquet as pq - out = OutputConfig(format="lerobot", path=tmp_path / "ds", metadata={"fps": 10.0}) # two episodes so dataset_from/to_index advance root = write(_two_episode_samples(), out) @@ -128,7 +130,6 @@ def test_lerobot_v3_episode_metadata_columns(tmp_path: Path) -> None: def test_lerobot_v3_writer_closed_on_midstream_error(tmp_path: Path) -> None: """If the drain raises after an episode was flushed, the data parquet must still be readable (footer written by the finally), not a headerless stub.""" - import pyarrow.parquet as pq def bad_samples() -> Iterator[Sample]: for i in range(3): # episode 0 @@ -160,8 +161,6 @@ def bad_samples() -> Iterator[Sample]: def test_lerobot_v3_per_episode_task_labels(tmp_path: Path) -> None: """Episodes with distinct task_labels must produce distinct tasks + task_index (multi-task recordings must not collapse to one task).""" - import pandas as pd - import pyarrow.parquet as pq def samples() -> Iterator[Sample]: for ep, task in ((0, "pick"), (1, "place")): diff --git a/dimos/learning/dataprep/test_core.py b/dimos/learning/dataprep/test_core.py index a6c4b9963f..dbabe94f16 100644 --- a/dimos/learning/dataprep/test_core.py +++ b/dimos/learning/dataprep/test_core.py @@ -22,15 +22,19 @@ from __future__ import annotations from dataclasses import dataclass +import json from pathlib import Path from typing import Any import numpy as np import pytest +from dimos.learning.dataprep.build import _write_dimos_meta, run_dataprep from dimos.learning.dataprep.core import ( + DataPrepConfig, Episode, EpisodeExtractor, + OutputConfig, StreamField, SyncConfig, extract_episodes, @@ -297,11 
+301,6 @@ def test_summarize_lengths_empty() -> None: def test_dimos_meta_records_sync_and_action_shift(tmp_path: Path) -> None: - import json - - from dimos.learning.dataprep.build import _write_dimos_meta - from dimos.learning.dataprep.core import DataPrepConfig, OutputConfig, StreamField - cfg = DataPrepConfig( source="s.db", observation={"state": StreamField(stream="js", field="position")}, @@ -319,11 +318,6 @@ def test_dimos_meta_records_sync_and_action_shift(tmp_path: Path) -> None: def test_dimos_meta_beside_file_for_hdf5(tmp_path: Path) -> None: """hdf5 writer returns a FILE path; the sidecar must land beside it, not inside it (which would treat the .hdf5 file as a directory and crash).""" - import json - - from dimos.learning.dataprep.build import _write_dimos_meta - from dimos.learning.dataprep.core import DataPrepConfig, OutputConfig - ds_file = tmp_path / "session.hdf5" ds_file.write_bytes(b"\x89HDF\r\n") # stand-in for a real .hdf5 cfg = DataPrepConfig(source="s.db", output=OutputConfig(format="hdf5", path=ds_file)) @@ -338,9 +332,6 @@ def test_dimos_meta_beside_file_for_hdf5(tmp_path: Path) -> None: def test_run_dataprep_rejects_shared_obs_action_key() -> None: """A name in both obs and action would silently drop the obs feature when the two maps merge; run_dataprep must reject it before opening the store.""" - from dimos.learning.dataprep.build import run_dataprep - from dimos.learning.dataprep.core import DataPrepConfig, StreamField - cfg = DataPrepConfig( source="nonexistent.db", # never reached — the check runs first observation={"joints": StreamField(stream="joint_state", field="position")},