From 99eef1fa00cb07205035eb207bdb3a3d46a30565 Mon Sep 17 00:00:00 2001 From: Reiase Date: Sun, 21 Jun 2026 11:43:20 +0800 Subject: [PATCH 1/5] Update dependencies, enhance documentation, and add new features for probing extensions - Added new dependencies including `probing-hccl-shim`, `probing-nccl-profiler`, and `serde_yaml` to `Cargo.lock` and `Cargo.toml`. - Updated the Makefile to include new targets for HCCL shim library and improved frontend build instructions. - Enhanced documentation to include a new section for the Federated Query Engine and updated existing design documents. - Refactored examples to replace deprecated scripts and improve clarity in tracing examples. - Improved the overall structure of the project by organizing new features and dependencies effectively. --- Cargo.lock | 14 + Cargo.toml | 1 + Makefile | 47 +- docs/mkdocs.yml | 2 + docs/src/api-reference.md | 53 +- docs/src/api-reference.zh.md | 50 +- docs/src/contributing.md | 2 +- docs/src/contributing.zh.md | 2 +- docs/src/design/distributed.zh.md | 2 + docs/src/design/federation.md | 28 + docs/src/design/federation.zh.md | 441 +++++++++++ docs/src/design/index.md | 1 + docs/src/design/index.zh.md | 1 + docs/src/design/modularity.md | 4 +- docs/src/design/modularity.zh.md | 4 +- docs/src/design/training-phase.zh.md | 94 +++ docs/src/guide/concepts.zh.md | 167 +--- docs/src/reference/sql-tables.md | 2 +- docs/src/reference/sql-tables.zh.md | 2 +- examples/README.md | 11 +- examples/events.py | 137 ---- examples/imagenet_with_span.py | 22 +- examples/tracing.py | 130 ++++ probing/core/Cargo.toml | 8 + probing/core/src/core/cluster.rs | 6 + probing/core/src/core/engine.rs | 8 +- .../src/core/federation/aggregate_pushdown.rs | 71 +- .../src/core/federation/cluster_executor.rs | 18 + probing/core/src/core/federation/convert.rs | 241 ++++-- .../core/src/core/federation/global_table.rs | 19 +- probing/core/src/core/federation/mod.rs | 14 +- probing/core/src/core/federation/rewrite.rs | 33 +- probing/core/src/core/federation/route.rs | 150 ++++ probing/core/src/core/metadata_rewrite.rs | 177 +++++ probing/core/src/core/mod.rs | 69 ++ probing/core/src/core/semantic_catalog.rs | 440 +++++++++++ probing/core/src/trace/mod.rs | 2 +- probing/core/src/trace/span.rs | 12 +- probing/core/src/trace/step.rs | 95 ++- probing/core/src/tracing.rs | 10 +- probing/extensions/hccl-shim/Cargo.toml | 24 + probing/extensions/hccl-shim/README.md | 82 ++ probing/extensions/hccl-shim/src/forward.rs | 250 ++++++ probing/extensions/hccl-shim/src/lib.rs | 199 +++++ probing/extensions/hccl-shim/src/msprof.rs | 278 +++++++ probing/extensions/hccl-shim/src/names.rs | 191 +++++ probing/extensions/hccl-shim/src/tables.rs | 184 +++++ probing/extensions/hccl-shim/src/writer.rs | 424 +++++++++++ probing/extensions/nccl-profiler/src/lib.rs | 2 +- .../extensions/nccl-profiler/src/tables.rs | 98 ++- .../python/src/extensions/python.rs | 2 +- .../python/src/extensions/python/exttbls.rs | 148 +++- .../extensions/python/src/features/tracing.rs | 94 ++- probing/memtable/src/discover.rs | 44 +- probing/memtable/src/docs.rs | 180 +++++ probing/memtable/src/lib.rs | 1 + probing/memtable/src/memtable.rs | 1 + probing/memtable/src/schema.rs | 36 +- probing/server/src/server/cluster_fanout.rs | 11 +- probing/server/src/server/training.rs | 19 +- pyproject.toml | 4 + python/probing/__init__.py | 18 +- .../bundled_skills/semantic/tables.yaml | 12 +- python/probing/core/table.py | 44 +- python/probing/ext/ray.py | 7 +- python/probing/ext/torch.py | 4 + python/probing/handlers/pythonext.py | 20 +- python/probing/hccl/__init__.py | 85 +++ python/probing/hccl/__main__.py | 81 ++ python/probing/parallel.py | 2 +- python/probing/profiling/collective/record.py | 35 +- python/probing/profiling/phase_tracker.py | 17 + python/probing/profiling/torch_probe.py | 95 +-- python/probing/skills/__main__.py | 3 +- python/probing/tracing.py | 716 ------------------ python/probing/tracing/__init__.py | 62 ++ python/probing/tracing/_bindings.py | 46 ++ python/probing/tracing/backends.py | 590 +++++++++++++++ python/probing/tracing/coordinates.py | 93 +++ python/probing/tracing/hooks.py | 138 ++++ python/probing/tracing/phases.py | 174 +++++ python/probing/tracing/span.py | 317 ++++++++ python/probing/tracing/table.py | 50 ++ python/probing/web_assets.py | 27 +- skills/semantic/tables.yaml | 380 ++++++++-- src/lib.rs | 12 +- tests/regression/core/test_table_docs.py | 83 ++ .../core/test_table_docs_integration.py | 198 +++++ tests/regression/ext/test_comm_collective.py | 1 - .../regression/ext/test_parallel_topology.py | 8 +- tests/regression/ext/test_phase_tracker.py | 128 ++++ tests/regression/ext/test_step_context.py | 66 +- tests/regression/ext/test_tracing_span.py | 217 ++++-- tests/regression/rust/Cargo.toml | 10 +- .../probing/core/federation_explain_tests.rs | 222 ++++++ .../rust/probing/core/federation_tests.rs | 250 +++++- .../probing/core/table_docs_integration.rs | 152 ++++ tests/regression/rust/src/test_helpers.rs | 10 + .../training_observability/conftest.py | 32 +- .../test_collective_recording.py | 3 +- .../test_collective_tracer_hook.py | 4 +- .../test_step_straggler_sql.py | 24 +- .../test_topology_context.py | 11 +- .../test_training_iteration_e2e.py | 51 +- tests/unit/probing/test_web_assets.py | 27 +- .../probing/tracing/test_phase_transitions.py | 374 +++++++++ tests/unit/probing/tracing/test_phases.py | 50 ++ .../probing/tracing/test_span_backends.py | 153 ++++ web/src/api/traces.rs | 36 +- web/src/components/sidebar/mod.rs | 2 +- web/src/hooks/mod.rs | 10 +- web/src/pages/traces.rs | 8 +- web/src/pages/training.rs | 6 +- 113 files changed, 8376 insertions(+), 1650 deletions(-) create mode 100644 docs/src/design/federation.md create mode 100644 docs/src/design/federation.zh.md create mode 100644 docs/src/design/training-phase.zh.md delete mode 100644 examples/events.py create mode 100644 examples/tracing.py create mode 100644 probing/core/src/core/federation/route.rs create mode 100644 probing/core/src/core/metadata_rewrite.rs create mode 100644 probing/core/src/core/semantic_catalog.rs create mode 100644 probing/extensions/hccl-shim/Cargo.toml create mode 100644 probing/extensions/hccl-shim/README.md create mode 100644 probing/extensions/hccl-shim/src/forward.rs create mode 100644 probing/extensions/hccl-shim/src/lib.rs create mode 100644 probing/extensions/hccl-shim/src/msprof.rs create mode 100644 probing/extensions/hccl-shim/src/names.rs create mode 100644 probing/extensions/hccl-shim/src/tables.rs create mode 100644 probing/extensions/hccl-shim/src/writer.rs create mode 100644 probing/memtable/src/docs.rs create mode 100644 python/probing/hccl/__init__.py create mode 100644 python/probing/hccl/__main__.py create mode 100644 python/probing/profiling/phase_tracker.py delete mode 100644 python/probing/tracing.py create mode 100644 python/probing/tracing/__init__.py create mode 100644 python/probing/tracing/_bindings.py create mode 100644 python/probing/tracing/backends.py create mode 100644 python/probing/tracing/coordinates.py create mode 100644 python/probing/tracing/hooks.py create mode 100644 python/probing/tracing/phases.py create mode 100644 python/probing/tracing/span.py create mode 100644 python/probing/tracing/table.py create mode 100644 tests/regression/core/test_table_docs.py create mode 100644 tests/regression/core/test_table_docs_integration.py create mode 100644 tests/regression/ext/test_phase_tracker.py create mode 100644 tests/regression/rust/probing/core/federation_explain_tests.rs create mode 100644 tests/regression/rust/probing/core/table_docs_integration.rs create mode 100644 tests/unit/probing/tracing/test_phase_transitions.py create mode 100644 tests/unit/probing/tracing/test_phases.py create mode 100644 tests/unit/probing/tracing/test_span_backends.py diff --git a/Cargo.lock b/Cargo.lock index 9dcd39ff..a21e2bf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3362,11 +3362,14 @@ dependencies = [ "libc", "log", "once_cell", + "probing-hccl-shim", "probing-macros", "probing-memtable", + "probing-nccl-profiler", "probing-proto", "serde", "serde_json", + "serde_yaml", "sled", "tempfile", "thiserror 2.0.12", @@ -3393,6 +3396,17 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "probing-hccl-shim" +version = "0.2.5" +dependencies = [ + "libc", + "once_cell", + "parking_lot 0.12.3", + "probing-memtable", + "tempfile", +] + [[package]] name = "probing-macros" version = "0.2.5" diff --git a/Cargo.toml b/Cargo.toml index 56715e1d..17f96177 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "probing/extensions/python", "probing/extensions/gpu", "probing/extensions/nccl-profiler", + "probing/extensions/hccl-shim", "probing/server", "probing/crates/store", ] diff --git a/Makefile b/Makefile index 42df75a2..95ffe384 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ # Probing Makefile # # develop → maturin develop (Rust/Python daily loop) -# frontend → web/dist/ (manual, needs dx) +# frontend → web/dist/ + python/probing/bundled_web/ (manual, needs dx) # wheel → bundle skills + UI, then maturin build # frontend wheel → full release path # @@ -32,6 +32,7 @@ endif endif PYTHON ?= $(shell test -x .venv/bin/python && echo .venv/bin/python || echo python3) +VENV_PYTHON := $(abspath .venv/bin/python) BUILD_PY_DEPS := build wheel toml maturin DEV_PTH := python/probing/dev_pth.py DEV_PY_DEPS := pyyaml pytest pytest-cov coverage ipython ipykernel @@ -53,7 +54,7 @@ help: @echo "" @echo " develop / dev Bootstrap: _core, CLI, pytest, site hook" @echo " core Rebuild probing._core after Rust edits" - @echo " frontend Build web/dist/ (dx; manual)" + @echo " frontend Build web/dist/ + sync bundled_web (dx; manual)" @echo " wheel Build dist/*.whl (needs web/dist/; bundles skills + UI)" @echo " wheel-ci alias for wheel (native build; PyPI uses maturin-action + zig)" @echo " install-wheel pip install dist/probing-*.whl" @@ -83,7 +84,7 @@ install-dev-python-deps: fi # ============================================================================== -.PHONY: core develop dev check-dev frontend wheel wheel-ci install-wheel wheel-bundle nccl-profiler-lib venv venv-wheel install-build-deps install-wheel-test-deps +.PHONY: core develop dev check-dev frontend sync-bundled-web wheel wheel-ci install-wheel wheel-bundle nccl-profiler-lib hccl-shim-lib venv venv-wheel install-build-deps install-wheel-test-deps venv: @test -x .venv/bin/python || $(shell command -v python3 || echo python3) -m venv .venv @@ -98,7 +99,7 @@ install-build-deps: venv install-wheel-test-deps: venv $(PYTHON) -m pip install -q -U pip $(PYTEST_WHEEL_DEPS) -core: nccl-profiler-lib +core: nccl-profiler-lib hccl-shim-lib $(PYTHON) -m maturin develop $(MATURIN_FLAGS) develop: install-build-deps core install-dev-python-deps @@ -125,16 +126,24 @@ frontend: cp -R $(DX_PUBLIC)/. web/dist/ @mkdir -p web/dist/assets @cp -f web/assets/logo.svg web/dist/logo.svg 2>/dev/null || true + @cp -f web/assets/logo.svg web/dist/assets/logo.svg 2>/dev/null || true @cp -f web/assets/tailwind.css web/dist/assets/tailwind.css @echo "web/dist ($$(du -sh web/dist | cut -f1))" + $(MAKE) sync-bundled-web + +sync-bundled-web: + @test -f web/dist/index.html || { echo "error: web/dist missing — run make frontend first"; exit 1; } + rm -rf python/probing/bundled_web + cp -R web/dist python/probing/bundled_web + @echo "python/probing/bundled_web ($$(du -sh python/probing/bundled_web | cut -f1))" wheel-bundle: @test -f web/dist/index.html || { echo "error: run 'make frontend' first"; exit 1; } - rm -rf python/probing/bundled_skills python/probing/bundled_web + rm -rf python/probing/bundled_skills cp -R skills python/probing/bundled_skills - cp -R web/dist python/probing/bundled_web + $(MAKE) sync-bundled-web -wheel: install-build-deps wheel-bundle nccl-profiler-lib +wheel: install-build-deps wheel-bundle nccl-profiler-lib hccl-shim-lib $(PYTHON) -m maturin build $(MATURIN_FLAGS) --out dist wheel-ci: @@ -168,6 +177,22 @@ nccl-profiler-lib: @: endif +# Linux HCCL libprofapi.so shim → python/probing/shim/hccl/ +ifeq ($(UNAME_S),Linux) +ifdef DEBUG +HCCL_SHIM_OUT := target/debug/libprofapi.so +else +HCCL_SHIM_OUT := target/release/libprofapi.so +endif +hccl-shim-lib: + cargo build -p probing-hccl-shim $(CARGO_RELEASE) + mkdir -p python/probing/shim/hccl + cp $(HCCL_SHIM_OUT) python/probing/shim/hccl/ +else +hccl-shim-lib: + @: +endif + # ============================================================================== PYTEST_WHEEL_DEPS := pytest pytest-cov coverage pyyaml websockets pandas torch ipykernel # Installed wheel only — do not pass python/probing (conflicts with site-packages). @@ -184,8 +209,8 @@ test: test-rust test-python test-rust: test-rust-unit test-rust-regression test-rust-unit: - @if test -x .venv/bin/python; then \ - export PYTHON_SYS_EXECUTABLE=.venv/bin/python PYO3_PYTHON=.venv/bin/python; \ + @if test -x $(VENV_PYTHON); then \ + export PYTHON_SYS_EXECUTABLE=$(VENV_PYTHON) PYO3_PYTHON=$(VENV_PYTHON); \ elif command -v pyenv >/dev/null 2>&1; then \ P=$$(pyenv which python3 2>/dev/null); \ test -n "$$P" && export PYTHON_SYS_EXECUTABLE=$$P PYO3_PYTHON=$$P; \ @@ -193,8 +218,8 @@ test-rust-unit: cargo nextest run --lib --workspace --no-default-features --nff test-rust-regression: - @if test -x .venv/bin/python; then \ - export PYTHON_SYS_EXECUTABLE=.venv/bin/python PYO3_PYTHON=.venv/bin/python; \ + @if test -x $(VENV_PYTHON); then \ + export PYTHON_SYS_EXECUTABLE=$(VENV_PYTHON) PYO3_PYTHON=$(VENV_PYTHON); \ elif command -v pyenv >/dev/null 2>&1; then \ P=$$(pyenv which python3 2>/dev/null); \ test -n "$$P" && export PYTHON_SYS_EXECUTABLE=$$P PYO3_PYTHON=$$P; \ diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 8aef6972..a03af7a2 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -99,6 +99,7 @@ plugins: Data Layer: 数据层 Profiling: 性能分析 Distributed: 分布式 + Federated Query Engine: 联邦查询引擎 Cluster with Pulsing: 基于 Pulsing 的集群 Extensibility: 扩展机制 Modularity: 模块化与边界 @@ -142,6 +143,7 @@ nav: - Profiling: design/profiling.md - Debugging: design/debugging.md - Distributed: design/distributed.md + - Federated Query Engine: design/federation.md - NCCL Profiler: design/nccl-profiler.md - Cluster with Pulsing: design/cluster-pulsing.md - Extensibility: design/extensibility.md diff --git a/docs/src/api-reference.md b/docs/src/api-reference.md index 15df4bfb..19e69015 100644 --- a/docs/src/api-reference.md +++ b/docs/src/api-reference.md @@ -100,12 +100,49 @@ import probing df = probing.query("SELECT * FROM python.torch_trace LIMIT 10") ``` -### probing.span / probing.event +### probing.span / probing.event / probing.record_span / probing.step + +Four user-facing verbs: + +```python +import probing + +with probing.span("forward", phase=probing.FORWARD): + probing.event("batch.stats", attributes=[{"loss": 1.25}]) + +probing.record_span("all_reduce", duration_ns=1_000_000) + +probing.attach_training_phases(model, optimizer) # hook-driven forward/backward/optimizer + +probing.step.micro_step # finest counter +probing.step() # micro_step +1 +probing.step(42) # set micro_step +probing.step(micro_batches=10) # gradient-accumulation grouping +probing.step.local_step # micro_step // micro_batches +probing.step.global_step # = local_step +``` + +### probing.tracing primitives (integrators / plugins) + +| Category | Import | Purpose | +|----------|--------|---------| +| Span | `span`, `event`, `record_span`, `current_span` | Instrumentation | +| Step | `step`, `step_fields` | Training coordinates | +| Phase | `FORWARD`, `BACKWARD`, `OPTIMIZER`, `phases`, `attach_training_phases` | Training phase | +| Integrator | `phases.infer_from_stage()` | Torch stage → training phase | +| Context | `span_attrs`, `row_fields`, `step_fields` | Span and table row context fields | +| Backend | `register_backend`, `configure_backends`, `list_backends`, `reset_backends` | Export plugins; built-in: `memtable`, `logger`, `otel` | +| Table | `TraceEvent`, `SPANS_SQL` | SQL / skills | ```python -with probing.span("forward", kind="nn.forward"): +from probing.tracing import register_backend, configure_backends + +with probing.span("forward", phase=probing.FORWARD, source="my_trainer"): ... -probing.event("batch.stats", attributes=[{"loss": 1.25}]) + +register_backend("my_sink", factory) +configure_backends(["memtable", "logger"]) # terminal + memtable +configure_backends(["memtable", "my_sink"]) ``` ### @table (dataclass plugins) @@ -137,14 +174,10 @@ probing.current_role() probing.clear_role() ``` -### probing.tracing.step_snapshot - -```python -from probing.tracing import step_snapshot +### probing.step -snap = step_snapshot() -# snap.local_step, snap.global_step — use in SQL filters and custom tables -``` +Use ``probing.step()`` instead of the removed ``step_snapshot`` / ``sync_local_step`` helpers. +See **probing.span / probing.event / probing.record_span / probing.step** above. --- diff --git a/docs/src/api-reference.zh.md b/docs/src/api-reference.zh.md index 94c3b64e..f1afd293 100644 --- a/docs/src/api-reference.zh.md +++ b/docs/src/api-reference.zh.md @@ -100,12 +100,49 @@ import probing df = probing.query("SELECT * FROM python.torch_trace LIMIT 10") ``` -### probing.span / probing.event +### probing.span / probing.event / probing.record_span / probing.step + +用户面四个动词: ```python -with probing.span("forward", kind="nn.forward"): +import probing + +with probing.span("forward", phase=probing.FORWARD): + probing.event("batch.stats", attributes=[{"loss": 1.25}]) + +probing.record_span("all_reduce", duration_ns=1_000_000) + +probing.attach_training_phases(model, optimizer) # hook 驱动 forward/backward/optimizer + +probing.step.micro_step # 最细计数 +probing.step() # micro_step +1 +probing.step(42) # 设置 micro_step +probing.step(micro_batches=10) # 梯度累积分组 +probing.step.local_step # micro_step // micro_batches +probing.step.global_step # = local_step +``` + +### probing.tracing 原语(集成 / 插件) + +| 类别 | 导入 | 用途 | +|------|------|------| +| Span | `span`, `event`, `record_span`, `current_span` | 插桩 | +| Step | `step`, `step_fields` | 训练坐标 | +| Phase | `FORWARD`, `BACKWARD`, `OPTIMIZER`, `phases`, `attach_training_phases` | 训练阶段 | +| Integrator | `phases.infer_from_stage()` | Torch stage → 训练 phase | +| Context | `span_attrs`, `row_fields`, `step_fields` | span 与表行上下文字段 | +| Backend | `register_backend`, `configure_backends`, `list_backends`, `reset_backends` | 导出插件;内置:`memtable`、`logger`、`otel` | +| Table | `TraceEvent`, `SPANS_SQL` | SQL / skill | + +```python +from probing.tracing import register_backend, configure_backends + +with probing.span("forward", phase=probing.FORWARD, source="my_trainer"): ... -probing.event("batch.stats", attributes=[{"loss": 1.25}]) + +register_backend("my_sink", factory) +configure_backends(["memtable", "logger"]) # 终端 + memtable +configure_backends(["memtable", "my_sink"]) ``` ### @table(dataclass 插件) @@ -139,12 +176,7 @@ probing.clear_role() ### probing.tracing.step_snapshot -```python -from probing.tracing import step_snapshot - -snap = step_snapshot() -# snap.local_step, snap.global_step — 用于 SQL 过滤与自定义表 -``` +已合并为 ``probing.step()`` — 见上文 **probing.step**。 --- diff --git a/docs/src/contributing.md b/docs/src/contributing.md index d4cd7ceb..5a89bcc2 100644 --- a/docs/src/contributing.md +++ b/docs/src/contributing.md @@ -165,7 +165,7 @@ make test-python-wheel # tests against installed wheel + checkout pure Python ```bash uv pip install torch torchvision # or pip install … -PROBING=1 python examples/events.py +PROBING=1 python examples/tracing.py ``` See [examples/README.md](https://github.com/DeepLink-org/probing/blob/main/examples/README.md). diff --git a/docs/src/contributing.zh.md b/docs/src/contributing.zh.md index 8dfd5dbc..99e39039 100644 --- a/docs/src/contributing.zh.md +++ b/docs/src/contributing.zh.md @@ -165,7 +165,7 @@ make test-python-wheel ```bash uv pip install torch torchvision -PROBING=1 python examples/events.py +PROBING=1 python examples/tracing.py ``` 见 [examples/README.md](https://github.com/DeepLink-org/probing/blob/main/examples/README.md)。 diff --git a/docs/src/design/distributed.zh.md b/docs/src/design/distributed.zh.md index f2633747..837ca318 100644 --- a/docs/src/design/distributed.zh.md +++ b/docs/src/design/distributed.zh.md @@ -101,6 +101,8 @@ ORDER BY avg_ms DESC" 通过 torchrun(`setup_torchrun_cluster`)或 `PUT /apis/nodes` 注册节点,`_rank` / `_role` 才能正确解析。训练脚本中可用 `probing.set_role(...)` 运行时覆盖 role。 +引擎实现与正确性测试要求见 **[联邦查询引擎](federation.zh.md)**。 + ## 同步调试 ### 捕获所有堆栈 diff --git a/docs/src/design/federation.md b/docs/src/design/federation.md new file mode 100644 index 00000000..8e4e65b5 --- /dev/null +++ b/docs/src/design/federation.md @@ -0,0 +1,28 @@ +# Federated Query Engine + +Product design for cross-rank SQL in probing: ask the whole training cluster from a +coordinator with one query — who is slow, on which step, compute vs network, which machine. + +!!! note "Language" + Full design (scenarios, SQL, execution paths, acceptance bar) is in + **[中文版 / Chinese](/zh/design/federation/)**. + +## Summary + +| Path | When | Module | +|------|------|--------| +| Local `probe` | Single-process query | DataFusion | +| **Aggregate pushdown (A)** | Single-table `global.*` + `GROUP BY` / safe aggregates | `aggregate_pushdown.rs` | +| **Federated scan (B)** | Single-table scan, filters, raw rows | `FederatedScanExec` | +| **Broadcast (C)** | JOIN / CTE / multi-table on each rank | `cluster_fanout.rs` | + +Federation tags (fixed): `_host`, `_addr`, `_rank`, `_node_rank`, `_local_rank`, `_role`. + +Chinese doc covers: diagnostic scenarios (straggler, heatmap, slowdown, topology, hang), +engine behavior spec (routing, rewrite, tags, paths A/B/C), and the wan-scale «five SQL» bar. + +## Related + +- **[Federation (中文)](/zh/design/federation/)** +- [Distributed](distributed.md) +- [Modularity](modularity.md) diff --git a/docs/src/design/federation.zh.md b/docs/src/design/federation.zh.md new file mode 100644 index 00000000..41c8dd9e --- /dev/null +++ b/docs/src/design/federation.zh.md @@ -0,0 +1,441 @@ +# 联邦查询引擎 + +**产品定位:** 在 coordinator(通常是 rank 0 / master 探针)上用 **一条 SQL** 问整个训练集群——「谁慢、慢在哪一步、是算力还是网络、是哪台机器」——而不必 SSH 逐 rank 拉日志。 + +**设计契约:** 本文定义用户可见语义;实现以 `probing/core/src/core/federation/` 为准,不一致时以本文推进对齐。 + +术语:[核心概念](../guide/concepts.zh.md) · 用法:[分布式](distributed.zh.md) · [SQL 分析](../guide/sql-analytics.zh.md) + +--- + +## 1. 问题与原则 + +每个 rank 的探针 **只写本地 memtable**(`python.comm_collective`、`nccl.proxy_ops`、`python.trace_event` 等)。跨 rank 分析 = coordinator 把查询拆成: + +1. 各 peer 本地执行 `probe.*` +2. HTTP fan-out 拉结果 +3. 合并行并注入 **联邦标签**(标明行来自哪台机器、哪个 rank) + +**原则** + +| 原则 | 含义 | +|------|------| +| 本地写、按需读 | 训练路径零额外中心存储;只有显式 `cluster query` / `global.*` 才 fan-out | +| SQL 统一入口 | CLI、Web、Skill、进程内 `probing.query()` 最终都走同一套 Engine | +| 部分失败可接受 | 单个 peer 超时/不可达:丢弃该分片,返回 `nodes_failed`,不拖垮整查 | +| 不做跨 rank JOIN | 两张表要在 **同一进程** 内 join;不支持 `global.a JOIN global.b` | + +**入口** + +| 谁 | 怎么用 | 是否 fan-out | +|----|--------|--------------| +| 单 rank 调试 | `probing -t query "…"` | 否 | +| 集群诊断 | `probing -t rank0:8080 cluster query "…"` | 是 | +| Web Training 热力图 | `GET /apis/training/step_matrix?cluster=true` | 是 | +| 进程内 | rank 0 上 `probing.query("… global.…")` | 视 SQL | + +--- + +## 2. 两个 Catalog 与联邦标签 + +### 2.1 `probe` 与 `global` + +| Catalog | 含义 | +|---------|------| +| **`probe.*`** | 当前进程本地表 | +| **`global.*`** | 同一张表的 **联邦镜像**:本地 scan + 各 peer lazy fetch,合并后返回 | + +用户写 `FROM python.comm_collective` 且 `cluster=true` 时,引擎 rewrite 为 `global.python.comm_collective` 再路由。 + +Peer 上 **永远** 执行 `probe.*`,避免递归联邦。 + +### 2.2 六列联邦标签(固定) + +与表内 `rank`、`role` 等采集列严格区分——标签表示 **「这行是哪个探针返回的」**。 + +| 标签 | 含义 | 来源 | +|------|------|------| +| `_host` | 源 hostname | `cluster.nodes.host` | +| `_addr` | 探针 `host:port` | `cluster.nodes.addr` | +| `_rank` | 全局 torch rank | `RANK` | +| `_node_rank` | 节点/worker group rank | `GROUP_RANK` | +| `_local_rank` | 节点内 GPU 序号 | `LOCAL_RANK` | +| `_role` | 并行 key,如 `dp=2,pp=1,tp=0` | 注册 / `set_role` | + +整型缺失 → `-1`;`_role` 缺失 → `""`。 + +**投影规则** + +- `SELECT * FROM global.t` → 自动展开六列标签(`EXCLUDE` rewrite) +- `SELECT rank, avg_ms …` → **不**自动带标签;需要则显式写 `_rank` 等 +- 标签在 **coordinator 合并时** inject,不写入 peer memtable + +--- + +## 3. 诊断场景(用户要什么 → 用什么 SQL) + +复杂问题用 **诊断链** 分步收窄,而不是一条 SQL 扫全集群 raw 行。下列场景对标生产 LLM 集群实践([OSDI'25 Straggler / SMon](https://www.usenix.org/conference/osdi25/presentation/lin-jinkun)、[MegaScale](https://arxiv.org/pdf/2402.15627)、NCCL collective 归因),语义对齐即可,非复刻某套系统。 + +### 3.1 Straggler:从 rank 到机器到热力图 + +**链:** rank 榜 → 慢节点 → step×rank 热力图 → 按 op 分解 → NCCL culprit/victim + +**① 各 rank collective 谁最慢** + +```sql +SELECT _role, _rank, rank, + avg(duration_ms) AS avg_ms, max(duration_ms) AS max_ms +FROM global.python.comm_collective +WHERE global_step >= (SELECT max(global_step) - 50 FROM global.python.comm_collective) +GROUP BY _role, _rank, rank +ORDER BY avg_ms DESC; +``` + +**② 慢的是整台机器还是单卡** — 按 `_host` 聚合 + +```sql +SELECT _host, _node_rank, + count(DISTINCT _rank) AS ranks_on_host, + avg(duration_ms) AS avg_comm_ms +FROM global.python.comm_collective +WHERE global_step >= (SELECT max(global_step) - 100 FROM global.python.comm_collective) +GROUP BY _host, _node_rank +ORDER BY avg_comm_ms DESC; +``` + +**③ Step × Rank 热力图** — Training 页矩阵;底层需同 rank 内 span JOIN + +```sql +SELECT s.attributes, s.time AS start_time, + CAST((e.time - s.time) / 1000 AS DOUBLE) AS duration_us +FROM python.trace_event s +JOIN python.trace_event e + ON s.span_id = e.span_id AND e.record_type = 'span_end' +WHERE s.record_type = 'span_start' AND s.name = 'train.step'; +``` + +Coordinator 聚成 `(rank, step) → duration_ms` 供 UI 着色。HTTP:`GET /apis/training/step_matrix?cluster=true`。 + +**④ 仅用 collective 的 step×rank long format**(前端 pivot) + +```sql +SELECT global_step, _rank, _host, op, sum(duration_ms) AS comm_ms +FROM global.python.comm_collective +WHERE global_step >= (SELECT max(global_step) - 120 FROM global.python.comm_collective) +GROUP BY global_step, _rank, _host, op; +``` + +**⑤ NCCL:算力慢还是等网络** + +```sql +SELECT seq, coll_func, _rank, + sum(send_gpu_wait_ns) AS gpu_wait, + sum(recv_wait_ns) AS recv_wait +FROM global.nccl.proxy_ops +WHERE seq >= (SELECT max(seq) - 50 FROM global.nccl.proxy_ops) +GROUP BY seq, coll_func, _rank; +``` + +### 3.2 Slowdown:job 有多慢、是持久还是偶发 + +对标 Byte / SMon 的 what-if:**理想 step 时间** vs **实际** 的比值。SQL 侧用 barrier proxy(非完整离散事件模拟)。 + +**Per-step slowdown ratio** + +```sql +WITH per_rank_step AS ( + SELECT global_step, _rank, max(duration_ms) AS max_op_ms + FROM global.python.comm_collective + WHERE global_step >= (SELECT max(global_step) - 500 FROM global.python.comm_collective) + GROUP BY global_step, _rank +), +step_cp AS ( + SELECT global_step, + avg(max_op_ms) / NULLIF(min(max_op_ms), 0) AS step_slowdown_ratio + FROM per_rank_step + GROUP BY global_step +) +SELECT avg(step_slowdown_ratio) AS job_slowdown_proxy, + count(*) FILTER (WHERE step_slowdown_ratio > 1.1) AS slow_steps +FROM step_cp; +``` + +**某 rank 是否「长期」最慢** — `worst_fraction`:在多少 step 上是最慢 rank + +```sql +-- 含窗口函数;万卡规模宜拆成两步:先 A 出 per_step,再 coordinator 算 +WITH per_step AS ( + SELECT global_step, _rank, _host, sum(duration_ms) AS step_ms + FROM global.python.comm_collective + WHERE global_step >= (SELECT max(global_step) - 300 FROM global.python.comm_collective) + GROUP BY global_step, _rank, _host +) +SELECT _rank, _host, + sum(CASE WHEN step_ms = max(step_ms) OVER (PARTITION BY global_step) THEN 1 ELSE 0 END) + * 1.0 / count(*) AS worst_fraction +FROM per_step +GROUP BY _rank, _host +ORDER BY worst_fraction DESC; +``` + +### 3.3 并行拓扑:PP / TP / DP 与 NCCL seq + +```sql +-- 按 _role 看各 step 通信量 +SELECT _role, _rank, global_step, sum(duration_ms) AS comm_ms +FROM global.python.comm_collective +WHERE global_step >= (SELECT max(global_step) - 80 FROM global.python.comm_collective) +GROUP BY _role, _rank, global_step; + +-- NCCL 细粒度:pp/tp/dp × seq(需 NCCL profiler 插件) +SELECT pp_rank, tp_rank, dp_rank, rank, seq, coll_func, + sum(send_gpu_wait_ns) AS gpu_wait, sum(recv_wait_ns) AS net_wait +FROM global.nccl.proxy_ops +WHERE seq >= (SELECT max(seq) - 100 FROM global.nccl.proxy_ops) +GROUP BY pp_rank, tp_rank, dp_rank, rank, seq, coll_func; +``` + +### 3.4 算力 vs 通信 vs 机器资源 + +**Byte 结论:straggler 常来自计算而非 collective 墙钟。** 必须在 **同 rank 内** join: + +```sql +SELECT c.global_step, c.rank, c.role, + sum(c.duration_ms) AS comm_ms, + sum(CASE WHEN t.stage LIKE 'post forward' THEN t.duration ELSE 0 END) AS compute_sec +FROM python.comm_collective c +JOIN python.torch_trace t + ON c.global_step = t.global_step AND c.rank = t.rank AND c.role = t.role +WHERE c.global_step >= (SELECT max(global_step) - 50 FROM python.comm_collective) +GROUP BY c.global_step, c.rank, c.role; +``` + +Coordinator 收到各 peer 结果后,再 `GROUP BY _rank` 汇总。`comm + gpu` 对齐需 coordinator 侧 merge 两路聚合结果(v2);暂不支持 `global.comm JOIN global.gpu`。 + +### 3.5 Hang 与掉队 + +```sql +-- 哪些 rank 的 global_step 明显落后 +SELECT _host, _rank, _local_rank, max(global_step) AS last_step +FROM global.python.comm_collective +GROUP BY _host, _rank, _local_rank +HAVING max(global_step) < (SELECT max(global_step) - 5 FROM global.python.comm_collective); + +-- 栈是否卡在 collective +SELECT _host, _rank, func, file, lineno +FROM global.python.backtrace +WHERE func LIKE '%collective%' OR func LIKE '%nccl%'; +``` + +--- + +## 4. 引擎行为规格 + +本节从 §3 诊断需求反推 **引擎必须保证的用户可见语义**。实现入口:`probing/core/src/core/federation/`(执行)、`probing/server/src/server/cluster_fanout.rs`(`cluster=true` 路由)。 + +### 4.1 处理流水线 + +```mermaid +flowchart LR + IN[用户 SQL] --> CF{cluster?} + CF -->|否| L0[本地 DataFusion / probe.*] + CF -->|是| RT[路径选型 A/B/C] + RT --> RW[Catalog 改写] + RW --> EX[各分片执行] + EX --> TG[注入联邦标签] + TG --> MG[合并 / 二次聚合] + MG --> OUT[DataFrame + FanoutMeta] +``` + +**统一约定** + +| 项 | 语义 | +|----|------| +| 响应体 | `dataframe` + `meta.nodes_queried` + `meta.nodes_failed` | +| Peer 集合 | `cluster.nodes` 快照,**排除** coordinator 自身 listen addr | +| Peer 执行 | 永远 `probe.*`;禁止 peer 再 fan-out(防递归) | +| 并发 | 各 peer 并行请求;总延迟 ≈ 最慢 peer + coordinator 合并 | +| 超时 | 单 peer 超时记 `nodes_failed`,不拖垮整查(默认 2s,可 `PROBING_REMOTE_QUERY_TIMEOUT_SECS` 覆盖) | + +### 4.2 路径选型 + +`cluster=false` → 仅本地,不走联邦。 + +`cluster=true` 时按 **AST 解析**(非 substring)依次判断: + +```mermaid +flowchart TD + Q[SQL] --> P{单语句 SELECT?} + P -->|否| C[路径 C] + P -->|是| M{单表且无 JOIN/CTE/UNION/子查询?} + M -->|否| C + M -->|是| A{单表 global.* + 可下推 GROUP BY/聚合?} + A -->|是| PA[路径 A] + A -->|否| PB[路径 B] +``` + +| 路径 | 进入条件 | §3 典型场景 | +|------|----------|-------------| +| **A 聚合下推** | 单表 `global.*`;含 `GROUP BY` 和/或 `count/sum/min/max/avg`;聚合可分布式 merge | ①②④⑤、3.3 role 视图 | +| **B 联邦 scan** | 单表 `global.*`;不满足 A(无聚合、或聚合不可 merge、或含 `ORDER BY`/`LIMIT` 且 A 不接) | 拉 raw 行、带 filter 的单表探查 | +| **C broadcast** | JOIN、逗号 join、UNION、CTE、标量/IN/EXISTS 子查询、多语句、解析失败 | ③ span 热力图、3.4 compute join、3.2 含窗口的 CTE | + +解析失败时 **保守走 C**(正确性优先于性能)。 + +### 4.3 Catalog 改写 + +| 阶段 | 输入 | 输出 | +|------|------|------| +| 用户 → coordinator | `python.t` / `probe.t` | `global.{schema}.t`(已知 schema:`cluster/process/files/python/memtable/gpu/rdma`) | +| 用户已写 `global.*` | 不变 | — | +| coordinator → peer | 任意 `global.*` | `probe.*`(字符串替换) | +| `SELECT *` + `global.*` | `SELECT * FROM global.t` | `SELECT * EXCLUDE (六标签), 六标签 FROM global.t` | +| 显式列清单 | `SELECT rank, avg_ms …` | **不**自动追加标签;用户需显式写 `_rank` 等 | + +路径 A 生成 per-node SQL 时:**去掉** SELECT/GROUP BY 中的标签列(peer 表无标签);coordinator merge 后再 inject。 + +### 4.4 联邦标签 + +六列固定、顺序稳定:`_host`, `_addr`, `_rank`, `_node_rank`, `_local_rank`, `_role`。 + +| 规则 | 说明 | +|------|------| +| 注入时机 | coordinator 收到每个分片后,**合并前**按分片来源填同一值 | +| 缺失值 | 整型标签 → `-1`;`_role` → `""` | +| 与数据列区分 | 表内 `rank`/`role` 是采集值;`_rank`/`_role` 是探针 endpoint 身份 | +| 路径一致性 | A / B / C 注入语义相同;`SELECT *` 展开顺序相同 | +| 仅按标签 GROUP BY | 如 `GROUP BY _host`:per-node 不做该 GROUP BY;coordinator 对 inject 后的 partial 再聚合 | + +### 4.5 路径 A — 聚合下推 + +**目标:** §3 中 80% 集群诊断(榜、热力图 long format、NCCL sum)走此路径,避免万卡 raw 行上收。 + +**Per-node SQL 生成** + +1. `FROM probe.{schema}.{table}`(由 `global.*` 改写) +2. 保留原 WHERE(整句下推) +3. SELECT 投影:去掉标签列;保留聚合表达式与非标签 GROUP BY 列 +4. `GROUP BY`:仅 **数据列**(非 `_host` 等) + +**Coordinator merge** + +| 原聚合 | merge 函数 | +|--------|------------| +| `count`, `sum` | `sum` | +| `min` | `min` | +| `max` | `max` | +| `avg` | 不支持精确分布式 merge → **不走 A** | +| `count(distinct)` 且 GROUP BY 含数据列 | 不支持 → **不走 A** | + +merge 后再 `GROUP BY` 数据列 + 用户请求的标签列(若有)。 + +**标签 inject:** merge 前给每个 partial 打六列;若 SELECT/GROUP BY 引用标签,coordinator merge SQL 保留这些列作 group key。 + +**ORDER BY / LIMIT(目标语义)** + +| 子句 | 行为 | +|------|------| +| `ORDER BY` | 在 coordinator **merge 完成之后**排序;不下推到 peer | +| `LIMIT` | 在 coordinator **merge + ORDER BY 之后**截断;全局 top-K | +| 仅 `ORDER BY`/`LIMIT`、无聚合 | 不满足 A → 走路径 B 或 C | + +**部分失败:** 成功 peer 的 partial 仍参与 merge;失败 addr 写入 `nodes_failed`。 + +### 4.6 路径 B — 联邦 scan + +**目标:** 单表探查、带 filter 的 raw 采样;lazy 拉取控制 coordinator 峰值内存。 + +**执行模型** + +| 分区 | 来源 | +|------|------| +| partition 0 | 本地 `probe.*` scan | +| partition 1..N | 每个 peer 一条 lazy partition,首次 poll 才 HTTP 拉取 | + +**下推** + +| 算子 | 规则 | +|------|------| +| Filter | 仅当可 **精确** 翻译为 peer SQL 的谓词才下推(schema 可解析) | +| Projection | 用户选什么列就下推什么列;标签列只在 coordinator 侧 append | +| `LIMIT` | **全局 top-K**:仅 coordinator 在合并流上截断;**不下推**到各 peer(避免每 rank 各取 K 条导致结果错误) | + +**Scan 后算子:** DataFusion 在 federated scan 之上继续执行剩余计划(filter、sort、limit 等),语义以 coordinator 全局为准。 + +### 4.7 路径 C — broadcast + +**目标:** 必须在 **同一进程 memtable** 内完成的 SQL(JOIN、CTE、窗口)。 + +**行为** + +1. coordinator 与每个 peer 执行 **同一句** 用户 SQL(peer 侧若含 `global.*` 先改 `probe.*`) +2. 各分片独立出结果集 +3. coordinator **行拼接**(列对齐,缺列填空) +4. 每个分片 inject 六列标签 + +**不做二次关系代数:** broadcast 结果是 concat,不在 coordinator 对两表做 join。 + +**万卡:** 复杂 CTE + 窗口(§3.2 `worst_fraction`)允许走 C,但产品建议 **拆诊断链**(先 A 出 per_step,再 coordinator 聚合)以控延迟与内存。 + +### 4.8 退化与边界 + +| 情况 | 行为 | +|------|------| +| 无 peer 注册 | 等同单节点;`nodes_queried=1` | +| 全部 peer 失败 | 仅本地分片(若有);或空结果 + 正确 schema | +| 部分 peer 失败 | 合并成功分片;`nodes_failed` 列出 addr | +| 空表 | 空 DataFrame;schema 含用户请求的列 + 标签列(若适用) | +| `cluster.nodes` 与训练 rank 不一致 | 以探针注册为准;标签反映 endpoint 而非理想 torch rank | + +### 4.9 明确不做 + +| 不做 | 原因 | +|------|------| +| `global.a JOIN global.b` | 跨 rank 无共址 join key;改用路径 C 同进程 join | +| 不可精确翻译的 filter 下推 | 避免 silent wrong results | +| 分布式 `count(distinct)` merge | 语义不可分解;fallback B/C | +| 全集群 raw + 多层 CTE 一次跑完 | 万卡内存/延迟不可控;拆 §3 诊断链 | +| Peer 递归 fan-out | HTTP 环路与放大 | + +### 4.10 需求 → 路径速查 + +| §3 需求 | 路径 | 引擎关键点 | +|---------|------|------------| +| rank / 慢节点榜 | A | `GROUP BY _rank` / `_host`;merge + inject | +| step×rank comm 热力图 long | A | `GROUP BY global_step, _rank, _host, op` | +| span 热力图 | C | 同 rank `trace_event` self-join;concat + tag | +| job slowdown / worst_fraction | C 或 A 两步 | 窗口函数 → C;拆链后第二步 coordinator 聚合 | +| compute vs comm | C | `probe.comm JOIN probe.torch_trace`;再 `GROUP BY _rank` | +| hang / backtrace | A 或 B | 单表 `GROUP BY _rank` 或 filter raw | + +--- + +## 5. 万卡验收(最小五连) + +增强引擎时,**先保证这五条**在 mock 多节点与真 cluster 上数值正确: + +| # | 场景 | 核心 SQL | 路径 | +|---|------|----------|------| +| 1 | step×rank comm 热力图 | §3.1 ④ | A | +| 2 | 慢节点榜 | §3.1 ② | A | +| 3 | job slowdown proxy | §3.2 首段 | C 或 A 两步 | +| 4 | compute vs comm | §3.4 | C + coordinator 再聚合 | +| 5 | NCCL compute/network | §3.1 ⑤ | A | + +**标签:** 六列在路径 A/B/C 一致;`SELECT *` 展开顺序稳定;broadcast JOIN 与 scan 标签一致(§4.4)。 + +**联邦:** mock ≥2 peer 时合并行数、`_addr` 来源、`nodes_failed` 行为符合 §4.8;路径 A 的 `sum` 与 naive 全量 `GROUP BY` 一致;路径 B 的 `LIMIT` 为全局 top-K(§4.6)。 + +--- + +## 6. 相关文档 + +| 文档 | 内容 | +|------|------| +| [分布式架构](distributed.zh.md) | `cluster query` 用法 | +| [核心概念](../guide/concepts.zh.md) | 用户向联邦说明 | +| [SQL 表目录](../reference/sql-tables.zh.md) | 表列与 `cluster.nodes` | +| [NCCL Profiler](nccl-profiler.zh.md) | §3.1 ⑤、`nccl.proxy_ops` | +| [API — step_matrix](../api-reference.zh.md) | §3.1 ③ 热力图 | + +实现模块:`probing/core/src/core/federation/` · 控制面:`probing/server/src/server/cluster_fanout.rs` diff --git a/docs/src/design/index.md b/docs/src/design/index.md index 5464b9b7..416b20e1 100644 --- a/docs/src/design/index.md +++ b/docs/src/design/index.md @@ -49,6 +49,7 @@ Shared vocabulary (endpoint, steps, role, federation): **[Core Concepts](../guid | [Profiling](profiling.md) | Performance data collection | | [Debugging](debugging.md) | Debugging capabilities | | [Distributed](distributed.md) | Multi-node support | +| [Federated query engine](federation.md) | Cross-rank SQL: scenarios, execution paths, acceptance bar | | [NCCL Profiler](nccl-profiler.md) | NCCL plugin, culprit/victim, `nccl.proxy_ops` | | [Cluster with Pulsing](cluster-pulsing.md) | Using Pulsing for membership and failure detection | | [Extensibility](extensibility.md) | Custom tables and metrics | diff --git a/docs/src/design/index.zh.md b/docs/src/design/index.zh.md index 3cd32319..24e8e91d 100644 --- a/docs/src/design/index.zh.md +++ b/docs/src/design/index.zh.md @@ -49,6 +49,7 @@ Probing 的核心使命很简单:**让分布式系统重新变得 Pythonic** | [性能分析](profiling.zh.md) | 性能数据收集 | | [调试](debugging.zh.md) | 调试能力 | | [分布式](distributed.zh.md) | 多节点支持 | +| [联邦查询引擎](federation.zh.md) | 跨 rank SQL:诊断场景、三条执行路径、万卡验收 | | [NCCL Profiler](nccl-profiler.zh.md) | NCCL 插件、culprit/victim、`nccl.proxy_ops` | | [基于 Pulsing 的集群](cluster-pulsing.zh.md) | 使用 Pulsing 做成员发现与故障检测 | | [扩展机制](extensibility.zh.md) | 自定义表和指标 | diff --git a/docs/src/design/modularity.md b/docs/src/design/modularity.md index ff15f77e..bed4bb91 100644 --- a/docs/src/design/modularity.md +++ b/docs/src/design/modularity.md @@ -84,7 +84,7 @@ Key core submodules: | Submodule | Path | Contract | |-----------|------|----------| | Engine | `core/engine.rs` | `async_query`, `enable(ProbeDataSource)` | -| Federation | `core/federation/` | `global.*` catalog, tags `_host/_addr/_rank/_role` | +| Federation | `core/federation/` | `global.*` catalog, tags `_host/_addr/_rank/_role` — see [Federated query engine](federation.md) | | Memtable SQL | `core/memtable_sql.rs` | mmap files → `TableProvider` | | Config | `config.rs` | `get` / `set` / `write` KV + extension options | @@ -105,7 +105,7 @@ Python-side collectors (same layer, different language): | Unit | Path | Tables | |------|------|--------| | Torch tracing | `python/probing/profiling/` | `python.torch_trace`, `python.comm_collective` | -| Tracing spans | `python/probing/tracing.py` | `python.trace_event` | +| Tracing spans | `python/probing/tracing/` | `python.trace_event` | | Parallel role | `python/probing/parallel.py` | stamps `role` on rows | | User plugins | `python/probing/ext/` | `python.` via `@table` | diff --git a/docs/src/design/modularity.zh.md b/docs/src/design/modularity.zh.md index d73cd1cd..9aea1678 100644 --- a/docs/src/design/modularity.zh.md +++ b/docs/src/design/modularity.zh.md @@ -80,7 +80,7 @@ flowchart TB | 子模块 | 路径 | 契约 | |--------|------|------| | Engine | `core/engine.rs` | `async_query`、`enable(ProbeDataSource)` | -| Federation | `core/federation/` | `global.*`、标签 `_host/_addr/_rank/_role` | +| Federation | `core/federation/` | `global.*`、六列联邦标签 — [联邦查询引擎](federation.zh.md) | | Memtable SQL | `core/memtable_sql.rs` | mmap → `TableProvider` | | Config | `config.rs` | KV + extension options | @@ -100,7 +100,7 @@ Python 侧采集(同层,不同语言): | 单元 | 路径 | 表 | |------|------|-----| | Torch | `python/probing/profiling/` | `torch_trace`、`comm_collective` | -| Span | `python/probing/tracing.py` | `trace_event` | +| Span | `python/probing/tracing/` | `trace_event` | | 并行 role | `python/probing/parallel.py` | 写入 `role` 列 | | 用户插件 | `python/probing/ext/` | `@table` → `python.*` | diff --git a/docs/src/design/training-phase.zh.md b/docs/src/design/training-phase.zh.md new file mode 100644 index 00000000..207a29ca --- /dev/null +++ b/docs/src/design/training-phase.zh.md @@ -0,0 +1,94 @@ +# 训练 Phase 语义(Tracing) + +本文定义 `probing.phase()`、`train.step` 与 hook/span 协作的 **不变量**。实现见 `python/probing/tracing/phases.py`。 + +## 核心对象 + +| 概念 | 含义 | +|------|------| +| **phase** | 训练阶段枚举:`forward` / `backward` / `optimizer`(span 字段) | +| **`probing.phase()`** | 当前 span 栈上**最内层**带 training phase 的 span;无则为 `idle` | +| **`train.step`** | 分析用 span **名称**(不是 phase);表示一次 logical iteration 的 wall time | +| **`probing.step()`** | 坐标计数器;在 **OPTIMIZER span 退出** 时 +1 `micro_step` | + +## Span 命名(API spec) + +```python +# 规范形式:phase 给定则 name 默认为 phase +with probing.span(phase=probing.FORWARD): + ... + +# 分析用名称(非 training phase) +with probing.span("epoch"): + ... + +# 显式 display name + phase +with probing.span("compute", phase=probing.BACKWARD): + ... +``` + +`resolve_span(name, phase)` 规则: + +1. 仅 `phase` → `(name=phase, phase=phase)` +2. 仅 `name` → `(name, infer(name))` +3. 两者皆有 → `(name, resolve(name, phase))`,至少其一必填 + +## 不变量 + +1. **`phase()` 来自 span 栈**,不是独立全局变量;batch 结束后显示 `idle` 是预期行为。 +2. **`train.step` 起止**:从本 logical iteration 的**第一次 forward**(hook 进入)到 **optimizer hook 退出**;中间梯度累积的 forward/backward **不重置**计时器。 +3. **每个 optimizer 退出**最多写一条 `train.step`(需先出现过 forward);无 forward 的 optimizer 不写。 +4. **同一 phase 同时只有一个活跃 span**:`phase_hook` 在已有同 phase span(manual / torch_probe)时不重复开 span。 +5. **`micro_step`**:每次 OPTIMIZER span 退出 +1;**`local_step = micro_step // micro_batches`**(设置 `probing.step(micro_batches=k)` 对应梯度累积因子)。 + +## TorchProbe × phase hook(ownership) + +| 能力 | phase hook | TorchProbe | +|------|------------|------------| +| iteration phase span(forward/backward/optimizer) | **拥有** | 当 `owns_training_phases(module=…)` 为真时**跳过** | +| `train.step` closed span | **拥有** | 不写 | +| 模块级 `torch_trace` 表(timing / mem) | — | **拥有** | +| 非 training 模块 span(如 init) | — | **拥有** | + +检测 API:`probing.owns_training_phases(model=…)` / `optimizer=…` / `module=…`。 + +典型组合: + +```python +probing.attach_training_phases(model, optimizer) # iteration phase + train.step +configure("on") # TorchProbe 仅写 torch_trace,不再开 training phase span +``` + +仅 TorchProbe、未 attach phase hook 时:TorchProbe 仍会开 training phase span(legacy 路径)。 + +## 组合规则(source) + +| source | 用途 | +|--------|------| +| `manual` | 用户 `probing.span(..., phase=...)` | +| `phase_hook` | `attach_training_phases` hook | +| `torch_probe` | 模块级 TorchProbe span(training phase 可被 hook 抑制) | + +同 phase 已存在活跃 span 时,hook **不再**开同名 phase span。 + +## 梯度累积示例 + +```python +probing.step(micro_batches=4) +probing.attach_training_phases(model, optimizer) + +for i, batch in enumerate(loader): + loss = model(batch) / 4 + loss.backward() + if (i + 1) % 4 == 0: + optimizer.step() + optimizer.zero_grad() +``` + +- 每个 micro-batch:forward/backward phase span 各一对。 +- 仅第 4、8、… 次 micro-batch 触发 optimizer 与 `train.step`。 +- `train.step` attrs 含 `accum_index`、`micro_step`、`local_step`。 + +## 性能:`inspect.stack()` + +自动 `location` **默认关闭**。仅在 `PROBING_SPAN_LOCATION=1` 或显式 `location=` 时,`span.py` 的 `_caller_location()` 会遍历 `inspect.stack()`。TorchProbe 变量追踪在 `torch_probe.py` 另有独立 stack walk。 diff --git a/docs/src/guide/concepts.zh.md b/docs/src/guide/concepts.zh.md index 27fd53fe..c558234e 100644 --- a/docs/src/guide/concepts.zh.md +++ b/docs/src/guide/concepts.zh.md @@ -1,162 +1,23 @@ -# 核心概念 - -本页是教程、指南与设计文档共用的**术语锚点**。建议先读这里,再进入 -[SQL 分析](sql-analytics.zh.md) 或 [分布式](../design/distributed.zh.md)。 - -## 1. Endpoint(端点) - -**CLI** 通过 **endpoint** 连接目标进程上的 probing 服务: - -| 形式 | 示例 | 说明 | -|------|------|------| -| 本地 PID | `12345` | `probing -t 12345 query "…"` | -| host:port | `node-a:8080` | 远程 TCP;训练启动时设 `PROBING_PORT` | - -```bash -export ENDPOINT=12345 # 或 host:8080 -probing $ENDPOINT query "SELECT 1" -``` - -**进程内**(训练脚本):`PROBING=1`(或 Linux 上 inject)后 `import probing`,直接用 -`probing.query()`,无需 endpoint 字符串。 - -**没有** `probing.connect()`;访问远程进程一律用 CLI `-t `。 - ---- - -## 2. 三种 CLI 命令 - -| 命令 | 用法 | 数据 | -|------|------|------| -| **query** | `probing $ENDPOINT query ""` | 已采集的表数据 | -| **eval** | `probing $ENDPOINT eval ""` | 在目标进程执行一次性 Python | -| **backtrace** | `probing $ENDPOINT backtrace` | 瞬时栈 → `python.backtrace` | - -它们是进程外的**主要 CLI 入口**,不等于 Probing 的全部能力(持续采集、联邦、`global.*`、 -skill 等见下文各节)。典型组合:`backtrace` 抓现场 → `eval` 看 live 对象 → `query` 做历史分析。 - ---- - -## 3. 数据表(`python.*`) - -探针数据在 **`python` schema** 下的 **append-only** SQL 表中(另有 `cpu.utilization`、 -`cluster.nodes`、`nccl.proxy_ops` 等内置/扩展表)。 - -| 表 | 记录内容 | -|----|----------| -| `python.torch_trace` | 模块 hook 耗时 + GPU 显存 | -| `python.comm_collective` | `torch.distributed` collective 墙钟时间 | -| `python.trace_event` | Span 起止与自定义事件 | -| `python.backtrace` | 最近一次捕获的栈(非全历史) | -| `python.variables` | 监视变量快照(启用时) | - -自定义插件同样:`@table` dataclass + `.save()` → `python.<表名>`。 -列说明见 **[SQL 表目录](../reference/sql-tables.zh.md)**。 - -表**不是**懒加载快照——事件发生时推送行(hook、collective、span 结束等)。 - ---- - ## 4. Step 坐标 -训练分析需要统一的 **step 索引**。权威来源是 Rust `step_snapshot()`(不是单独的 Python 计数器)。 - -| 字段 | 含义 | -|------|------| -| `local_step` | 每 rank 本地步(与 optimizer step 对齐) | -| `global_step` | 集群全局步(协调时) | - -数据行上: - -- `python.torch_trace.step` → 本地步 -- `python.torch_trace.global_step`、`python.comm_collective` 的 `local_step` / `global_step` +训练分析使用三级 step 索引,权威来源是 Rust 坐标(通过 ``probing.step`` 访问)。 -进程内: +| 字段 | API | 含义 | +|------|-----|------| +| `micro_step` | `probing.step.micro_step` | 最细计数;每次 ``probing.step()`` 或 ``train.step`` span 结束 +1 | +| `local_step` | `probing.step.local_step` | 训练步(每 rank):``micro_step // micro_batches`` | +| `global_step` | `probing.step.global_step` | 与 ``local_step`` 相同(rank 对齐时即集群训练步) | +| `micro_batches` | `probing.step(micro_batches=k)` | 梯度累积倍数:每 k 个 micro_step 合成 1 个 local/global step | ```python -from probing.tracing import step_snapshot -s = step_snapshot() -print(s.local_step, s.global_step, s.rank) -``` - -SQL 与 skill 请用上述字段,**不要**用 `trainer.current_step`。 - ---- - -## 5. Parallel role(并行角色) - -分布式训练里每个进程有并行拓扑(TP / PP / DP / EP / …)。Probing 用可扩展字符串 **`role`** -表示,而不是每种维度一列。 - -**格式:** 按名字排序的 `name=value`,如 `dp=2,pp=1,tp=0`;未设置时为 `""`。 - -| 来源 | 方式 | -|------|------| -| 环境变量 | Megatron 风格 `*_PARALLEL_RANK`,或 `PROBING_ROLE_=` | -| 运行时 | `probing.set_role("dp=2,pp=1,tp=0")` 或 `set_role(dp=2, pp=1)` | -| 读取 | `probing.current_role()`;`clear_role()` 恢复 env 推导 | - -`role` 写入 **`python.torch_trace`**、**`python.comm_collective`**,便于同 rank 上按 role -JOIN / GROUP BY。 - -与 torchrun 的 **`role_name`** / `role_rank`(`cluster.nodes` 上)不同——那是 Elastic 作业 -字段;Probing 的 `role` 是**并行放置 key**,用于分析对齐。 - ---- - -## 6. 联邦查询(`global.*` 与标签) +import probing -跨 rank SQL 使用 **`global` catalog**,例如 `global.python.comm_collective`:向已注册节点 -fan-out 后合并。 - -每行附加**联邦标签**,标识数据来源: - -| 标签 | 含义 | -|------|------| -| `_host` | 来源主机名 | -| `_addr` | 来源 `host:port` | -| `_rank` | `torch.distributed` rank(节点注册表) | -| `_role` | 并行角色 key(注册表 / `set_role`) | - -示例: - -```sql -SELECT _role, _rank, avg(duration_ms) AS avg_ms -FROM global.python.comm_collective -WHERE global_step > 100 -GROUP BY _role, _rank -ORDER BY avg_ms DESC; +probing.step(micro_batches=10) # 10 个 micro-batch → 1 个 training step +probing.step() # micro_step +1 +probing.step(42) # 设置 micro_step +print(probing.step.micro_step, probing.step.local_step, probing.step.global_step) ``` -节点通过 torchrun(`setup_torchrun_cluster`)或 `PUT /apis/nodes` 注册。CLI: -`probing -t cluster nodes` / `cluster query "…"`。详见 -[分布式](../design/distributed.zh.md)。 - -行内列 **`role`** = 该 rank **写入时**的值;标签 **`_role`** = **联邦查询时**节点注册表 -的值(`set_role` 后会 best-effort 重注册以保持一致)。 - ---- - -## 7. 表插件 vs 诊断 skill - -| | **表插件**(路径 1) | **诊断 skill**(路径 2) | -|--|---------------------|-------------------------| -| 贡献 | dataclass 表 + 写行 | `SKILL.md` + 可选 `steps.yaml` | -| 产出 | `python.my_table` | 结论 + SQL 步骤 / Agent 指引 | -| 使用 | `SELECT …` | `probing skill run ` | -| 适用 | 新**指标/事件**要落库 | 新**排查配方** | - -可选**路径 3**:NCCL profiler cdylib → `nccl.proxy_ops`(culprit/victim)。见 -[扩展机制](../design/extensibility.zh.md)。 - ---- - -## 下一步 +SQL 表(``python.comm_collective``、``python.torch_trace``、span attributes)统一使用上述字段名。 -| 目标 | 文档 | -|------|------| -| SQL 模式 | [SQL 分析](sql-analytics.zh.md) | -| 表结构 | [SQL 表目录](../reference/sql-tables.zh.md) | -| 多机 | [分布式](../design/distributed.zh.md) | -| 写插件 | [扩展机制](../design/extensibility.zh.md) | -| CLI / API | [API 参考](../api-reference.zh.md) | +SQL 与 skill 请用 ``local_step`` / ``global_step`` 做训练步过滤,**不要**用 ``trainer.current_step``。 diff --git a/docs/src/reference/sql-tables.md b/docs/src/reference/sql-tables.md index 4b6b887b..695f16eb 100644 --- a/docs/src/reference/sql-tables.md +++ b/docs/src/reference/sql-tables.md @@ -118,7 +118,7 @@ Span start/end and custom events (distributed tracing). | `trace_id` | Trace id shared by related spans | | `span_id` | Unique span id | | `name` | Span or event name | -| `kind` | Semantic kind (e.g. `train.step`, `comm.all_reduce`) | +| `phase` | Training phase (`forward`, `backward`, `optimizer`) or empty | | `time` | Timestamp (nanoseconds since epoch) | | `attributes` | JSON metadata (rank, local_step, …) | diff --git a/docs/src/reference/sql-tables.zh.md b/docs/src/reference/sql-tables.zh.md index 9e5537a5..a38a6077 100644 --- a/docs/src/reference/sql-tables.zh.md +++ b/docs/src/reference/sql-tables.zh.md @@ -114,7 +114,7 @@ Span 起止与自定义事件(分布式 tracing)。 | `trace_id` | 同一 trace 内共享 | | `span_id` | Span 唯一 id | | `name` | Span / 事件名 | -| `kind` | 语义类型(如 `train.step`、`comm.all_reduce`) | +| `phase` | 训练阶段(`forward`、`backward`、`optimizer`)或空 | | `time` | 时间戳(纳秒) | | `attributes` | JSON 元数据(rank、local_step 等) | diff --git a/examples/README.md b/examples/README.md index 90bb716e..d1b114d2 100644 --- a/examples/README.md +++ b/examples/README.md @@ -6,17 +6,18 @@ Runnable scripts under `examples/`. They are **not** installed with `pip install | Script | Extra packages | Notes | |--------|----------------|-------| -| `events.py`, `hooks.py`, `test_probing.py` | none (beyond probing) | Good smoke tests | +| **`tracing.py`** | `torch` | **Tracing 入门**(hook 驱动 phase,~80 行) | +| `hooks.py`, `test_probing.py` | none (beyond probing) | Good smoke tests | | `imagenet.py`, `imagenet_with_span.py` | `torch`, `torchvision` | Needs ImageNet data path | | `ray_tracing_example.py` | `ray` | Optional Ray integration | | `bench_profiler.py` | varies | See script header | -Install ML stack into your dev venv: +Install PyTorch into your dev venv (tracing 示例只需 torch;ImageNet 脚本还需 torchvision): ```bash source .venv/bin/activate -uv pip install torch torchvision -# or: pip install torch torchvision +uv pip install torch +# ImageNet 示例: uv pip install torch torchvision ``` ## Running with probing @@ -25,7 +26,7 @@ Use the project venv after `make develop` (see [Contributing](../docs/src/contri ```bash source .venv/bin/activate -PROBING=1 python examples/events.py +PROBING=1 python examples/tracing.py # tracing 入门(推荐) PROBING=1 python examples/test_probing.py --depth 2 ``` diff --git a/examples/events.py b/examples/events.py deleted file mode 100644 index c9755648..00000000 --- a/examples/events.py +++ /dev/null @@ -1,137 +0,0 @@ -import ast -from dataclasses import dataclass -from typing import Any - -import torch - - -def _get_fullname(m): - return f"{m.__module__}.{m.__class__.__name__}" - - -@dataclass -class TenserDef: - shape: tuple = () - dtype: Any = None - - def __repr__(self): - return f"TenserDef({self.shape}, {self.dtype})" - - -class Event(dict): - """ - Examples - -------- - >>> event = Event() - >>> event.name = "event name" - >>> event.name - 'event name' - """ - - def __getattr__(self, name): - if name in self: - return self[name] - else: - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) - - def __setattr__(self, name: str, value: Any) -> None: - return self.__setitem__(name, value) - - -class TorchEvent(Event): - def __init__(self, name, module, inputs, params={}): - self.name = name - self.module = _get_fullname(module) - try: - if isinstance(inputs, torch.Tensor): - self.inputs = [TenserDef(tuple(inputs.shape), inputs.dtype)] - else: - self.inputs = [TenserDef(tuple(x.shape), x.dtype) for x in inputs] - except: - self.inputs = None - self.params = params - - def __repr__(self) -> str: - return f'{self.name}("{self.module}", {self.inputs}, {self.params})' - - -class ForwardStartEvent(TorchEvent): - def __init__(self, m, inputs, params={}) -> None: - super().__init__("ForwardStartEvent", m, inputs, params=params) - - -class ForwardEndEvent(TorchEvent): - def __init__(self, m, outputs, params={}) -> None: - super().__init__("ForwardEndEvent", m, outputs, params=params) - - -class BackwardStartEvent(TorchEvent): - def __init__(self, m, inputs, params={}) -> None: - super().__init__("BackwardStartEvent", m, inputs, params=params) - - -class BackwardEndEvent(TorchEvent): - def __init__(self, m, outputs, params={}) -> None: - super().__init__("BackwardEndEvent", m, outputs, params=params) - - -def parse(line): - expr = ast.parse(line, mode="eval") - name = expr.body.func.id - args = [ - eval(compile(ast.Expression(body=arg), filename="", mode="eval")) - for arg in expr.body.args - ] - event = Event() - event.name = name - event.module = args[0] - if name == "SpanStartEvent": - event.id = args[1] - elif name == "SpanEndEvent": - event.id = args[1] - event.duration = args[2] - else: - event.inputs = args[1] - event.params = args[2] if len(args) > 2 else {} - return event - - -def parse_tree(lines): - tree = [] - - span_stack = [] - curr_span = None - - event_stack = [] - for line in lines: - event = parse(line) - if event.name == "SpanStartEvent": - new_span = Event() - if curr_span is not None: - curr_span.children.append(new_span) - curr_span = new_span - curr_span.name = "Span" - curr_span.module = event.module - curr_span.id = event.id - curr_span.duration = None - curr_span.children = [] - tree.append(curr_span) - span_stack.append(curr_span) - if event.name == "SpanEndEvent": - curr_span.duration = event.duration - curr_span.children = curr_span.children - curr_span = span_stack.pop() - - if event.name == "ForwardStartEvent": - if curr_span is not None: - curr_span.children.append(event) - event_stack.append(event) - - if event.name == "ForwardEndEvent": - if len(event_stack) > 0: - curr_event = event_stack.pop() - curr_event.outputs = event.inputs[0] - - return tree diff --git a/examples/imagenet_with_span.py b/examples/imagenet_with_span.py index dfe93a10..ae78130d 100644 --- a/examples/imagenet_with_span.py +++ b/examples/imagenet_with_span.py @@ -251,7 +251,7 @@ def main_worker(gpu, ngpus_per_node, args): except Exception: pass # create model - with probing.span("model.init", kind="setup"): + with probing.span("model.init"): if args.pretrained: print(f"=> using pre-trained model '{args.arch}'") model = models.__dict__[args.arch](pretrained=True) @@ -331,7 +331,7 @@ def main_worker(gpu, ngpus_per_node, args): print(f"=> no checkpoint found at '{args.resume}'") # Data loading code - with probing.span("data.load", kind="io"): + with probing.span("data.load"): if args.dummy: print("=> Dummy data is used!") train_dataset = datasets.FakeData( @@ -411,14 +411,14 @@ def main_worker(gpu, ngpus_per_node, args): return for epoch in range(args.start_epoch, args.epochs): - with probing.span("epoch", kind="train"): + with probing.span("epoch"): probing.event("epoch.start", attributes=[{"epoch": epoch}]) if args.distributed: train_sampler.set_epoch(epoch) - with probing.span("train", kind="loop"): + with probing.span("train"): train(train_loader, model, criterion, optimizer, epoch, device, args) - with probing.span("validate", kind="loop"): + with probing.span("validate"): acc1 = validate(val_loader, model, criterion, args) scheduler.step() probing.event("epoch.metrics", attributes=[{"acc1": float(acc1)}]) @@ -428,7 +428,7 @@ def main_worker(gpu, ngpus_per_node, args): if not args.multiprocessing_distributed or ( args.multiprocessing_distributed and args.rank % ngpus_per_node == 0 ): - with probing.span("checkpoint.save", kind="io"): + with probing.span("checkpoint.save"): save_checkpoint( { "epoch": epoch + 1, @@ -468,23 +468,23 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args): end = time.time() for i, (images, target) in enumerate(train_loader): time.sleep(1) - with probing.span("batch", kind="train.step"): + with probing.span("batch"): # measure data loading time data_time.update(time.time() - end) images = images.to(device, non_blocking=True) target = target.to(device, non_blocking=True) - with probing.span("forward", kind="nn.forward"): + with probing.span("forward"): output = model(images) - with probing.span("loss", kind="compute"): + with probing.span("loss"): loss = compute_loss(criterion, output, target) acc1, acc5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), images.size(0)) top1.update(acc1[0], images.size(0)) top5.update(acc5[0], images.size(0)) - with probing.span("backward", kind="nn.backward"): + with probing.span("backward"): optimizer.zero_grad() loss.backward() - with probing.span("step", kind="optim.step"): + with probing.span("step"): optimizer.step() batch_time.update(time.time() - end) end = time.time() diff --git a/examples/tracing.py b/examples/tracing.py new file mode 100644 index 00000000..14cff683 --- /dev/null +++ b/examples/tracing.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +""" +probing tracing 入门 +=================== + +两层 API,分工明确: + + ① ``attach_training_phases(model, optimizer)`` — **推荐,零侵入** + model / optimizer 上的 hook 自动追踪每个 batch 的 + forward → backward → optimizer,并记录 ``train.step`` 整步耗时。 + 训练循环里 **不需要** ``with probing.span("forward")``。 + + 梯度累积:先 ``probing.step(micro_batches=N)``,``train.step`` 覆盖 + N 个 micro-batch 的 wall time;详见 ``docs/src/design/training-phase.zh.md``。 + + ② ``probing.span`` / ``probing.event`` — **可选,粗粒度时间线** + 包住模型初始化、epoch 等;与 ① 的 phase span 互不冲突。 + +运行:: + + PROBING=1 python examples/tracing.py + +终端实时查看 span(与 memtable 同时生效):: + + PROBING=1 PROBING_SPAN_BACKENDS=memtable,logger python examples/tracing.py + +查看最近 span(另开终端):: + + probing -t query " + SELECT s.name, s.phase, + round((e.time - s.time) / 1e6, 2) AS ms + FROM python.trace_event s + JOIN python.trace_event e + ON s.span_id = e.span_id AND e.record_type = 'span_end' + WHERE s.record_type = 'span_start' + ORDER BY s.time DESC LIMIT 12" +""" + +from __future__ import annotations + +import os + +import probing +import torch +import torch.nn as nn +import torch.nn.functional as F + +# --- 尽量小:只依赖 torch,无需真实数据集 --------------------------------- + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +BATCHES = 5 +BATCH_SIZE = 16 +LR = 0.01 + + +class TinyNet(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(8, 16) + self.fc2 = nn.Linear(16, 4) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc2(F.relu(self.fc1(x))) + + +def train_one_batch( + model: nn.Module, + optimizer: torch.optim.Optimizer, + batch_idx: int, +) -> float: + x = torch.randn(BATCH_SIZE, 8, device=DEVICE) + y = torch.randint(0, 4, (BATCH_SIZE,), device=DEVICE) + + logits = model(x) + loss = F.cross_entropy(logits, y) + + optimizer.zero_grad() + loss.backward() + optimizer.step() # hook 在此追踪 optimizer phase,并推进 probing.step() + + # 可选:在 hook 打开的 phase span 上挂 point event + if batch_idx == 0: + probing.event( + "batch.stats", + attributes=[ + {"loss": round(float(loss.item()), 4)}, + {"phase": probing.phase()}, + ], + ) + + return float(loss.item()) + + +def main() -> None: + pid = os.getpid() + print( + f"pid={pid} device={DEVICE} probing={'on' if probing.is_enabled() else 'off'}" + ) + print() + + # --- ② 粗粒度 span:初始化 ----------------------------------------------- + with probing.span("setup"): + model = TinyNet().to(DEVICE) + optimizer = torch.optim.SGD(model.parameters(), lr=LR) + + # --- ① 自动 phase span:一行挂载 ----------------------------------------- + probing.attach_training_phases(model, optimizer) + print("attach_training_phases ✓ (forward / backward / optimizer 由 hook 驱动)") + print() + + # --- 训练:循环内无需手写 phase span ------------------------------------- + with probing.span("epoch"): + for i in range(BATCHES): + loss = train_one_batch(model, optimizer, i) + print( + f" batch {i} loss={loss:.4f} phase={probing.phase()!r} step={probing.step.micro_step}" + ) + + print() + if probing.is_enabled(): + print("完成。hook 已写入 python.trace_event,可用上方 SQL 查询。") + print( + f' probing -t {pid} query "SELECT name, phase FROM python.trace_event LIMIT 12"' + ) + else: + print("完成(未落表)。请用 PROBING=1 重新运行。") + + +if __name__ == "__main__": + main() diff --git a/probing/core/Cargo.toml b/probing/core/Cargo.toml index 5ad17961..3f5baad0 100644 --- a/probing/core/Cargo.toml +++ b/probing/core/Cargo.toml @@ -22,6 +22,11 @@ missing_errors_doc = "allow" missing_panics_doc = "allow" similar_names = "allow" +[features] +test-utils = [] +default = ["builtin-schema-docs"] +builtin-schema-docs = ["dep:profapi", "dep:probing-nccl-profiler"] + [lib] crate-type = ["rlib"] @@ -29,6 +34,8 @@ crate-type = ["rlib"] probing-proto = { path = "../proto" } probing-macros = { path = "../macros" } probing-memtable = { path = "../memtable" } +profapi = { path = "../extensions/hccl-shim", optional = true, package = "probing-hccl-shim" } +probing-nccl-profiler = { path = "../extensions/nccl-profiler", optional = true } anyhow = { workspace = true } arrow = { workspace = true } @@ -38,6 +45,7 @@ once_cell = { workspace = true } tokio = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +serde_yaml = "0.9" thiserror = { workspace = true } async-trait = "0.1.83" diff --git a/probing/core/src/core/cluster.rs b/probing/core/src/core/cluster.rs index f9aa1c5a..5f10600b 100644 --- a/probing/core/src/core/cluster.rs +++ b/probing/core/src/core/cluster.rs @@ -69,3 +69,9 @@ pub fn update_nodes(nodes: Vec) { pub fn get_nodes() -> Vec { CLUSTER.read().unwrap().list() } + +/// Clear in-memory cluster registration (tests only). +#[cfg(any(test, feature = "test-utils"))] +pub fn reset_cluster_for_tests() { + *CLUSTER.write().unwrap() = Cluster::default(); +} diff --git a/probing/core/src/core/engine.rs b/probing/core/src/core/engine.rs index 969d2b6f..c7129a0f 100644 --- a/probing/core/src/core/engine.rs +++ b/probing/core/src/core/engine.rs @@ -18,6 +18,8 @@ use super::probe_extension::ProbeExtensionManager; use super::data_source::{ProbeDataSource, ProbeDataSourceKind}; use super::federation; +use super::metadata_rewrite; +use super::semantic_catalog; /// Core query engine for the Probing system /// @@ -108,7 +110,10 @@ impl Engine { if let Some(df) = federation::try_execute_aggregate_pushdown(self, &original).await? { return Ok(Some(df)); } - let query: String = federation::prepare_global_query(&original); + let default_schema = self.default_namespace(); + let query: String = metadata_rewrite::prepare_metadata_query(&original, &default_schema) + .unwrap_or(original); + let query: String = federation::prepare_global_query(&query); let df = self.sql(query.as_str()).await?; let schema = df.schema().clone(); let batches = df.collect().await?; @@ -271,6 +276,7 @@ impl EngineBuilder { for data_source in self.data_sources { engine.enable(data_source).await?; } + semantic_catalog::install_semantic_catalog(&engine.context)?; federation::install_global_catalog(&engine.context)?; Ok(engine) diff --git a/probing/core/src/core/federation/aggregate_pushdown.rs b/probing/core/src/core/federation/aggregate_pushdown.rs index 8caedd77..ded9303f 100644 --- a/probing/core/src/core/federation/aggregate_pushdown.rs +++ b/probing/core/src/core/federation/aggregate_pushdown.rs @@ -35,10 +35,12 @@ struct PlannedAggregate { merge_fn: Option<&'static str>, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct FederatedAggregatePlan { pub per_node_sql: String, pub coordinator_sql: Option, + /// Suffix applied on the coordinator after merge, e.g. ` ORDER BY avg_ms DESC LIMIT 5`. + pub post_merge_tail: Option, pub inject_tags: bool, } @@ -121,6 +123,12 @@ pub async fn try_execute_aggregate_pushdown( merge_proto_dataframes(&proto_parts)? }; + let result = if let Some(tail) = plan.post_merge_tail { + apply_post_merge_tail(&engine.context, &tail, result).await? + } else { + result + }; + Ok(Some(result)) } @@ -142,9 +150,7 @@ fn parse_single_statement(sql: &str) -> Option { } fn plan_from_query(query: &Query) -> Option { - if query.order_by.is_some() || query.limit_clause.is_some() || query.fetch.is_some() { - return None; - } + let post_merge_tail = build_post_merge_tail(query); let SetExpr::Select(select) = query.body.as_ref() else { return None; }; @@ -189,6 +195,7 @@ fn plan_from_query(query: &Query) -> Option { Some(FederatedAggregatePlan { per_node_sql, coordinator_sql, + post_merge_tail, inject_tags, }) } @@ -429,6 +436,50 @@ fn quote_ident(name: &str) -> String { } } +fn build_post_merge_tail(query: &Query) -> Option { + let order_by = format_order_by(query); + let limit = format_limit_clause(query); + if order_by.is_empty() && limit.is_empty() { + None + } else { + Some(format!("{order_by}{limit}")) + } +} + +fn format_order_by(query: &Query) -> String { + query + .order_by + .as_ref() + .map(|order_by| format!(" {order_by}")) + .unwrap_or_default() +} + +fn format_limit_clause(query: &Query) -> String { + let mut out = query + .limit_clause + .as_ref() + .map(|limit| format!(" {limit}")) + .unwrap_or_default(); + if let Some(fetch) = &query.fetch { + out.push(' '); + out.push_str(&fetch.to_string()); + } + out +} + +async fn apply_post_merge_tail( + ctx: &SessionContext, + tail: &str, + df: probing_proto::prelude::DataFrame, +) -> Result { + let batch = proto_dataframe_to_record_batch(&df)?; + if batch.num_rows() == 0 { + return Ok(df); + } + let sql = format!("SELECT * FROM partials{tail}"); + merge_on_coordinator(ctx, &sql, vec![batch]).await +} + async fn merge_on_coordinator( ctx: &SessionContext, merge_sql: &str, @@ -553,4 +604,16 @@ mod tests { let sql = "SELECT name, count(distinct value) AS n FROM global.process.envs GROUP BY name"; assert!(plan_federated_aggregate_pushdown(sql).is_none()); } + + #[test] + fn plans_order_by_limit_as_post_merge_tail() { + let sql = "SELECT name, count(*) AS n FROM global.process.envs GROUP BY name ORDER BY n DESC LIMIT 3"; + let plan = plan_federated_aggregate_pushdown(sql).expect("plan"); + assert!(plan.coordinator_sql.is_some()); + let tail = plan.post_merge_tail.as_deref().unwrap(); + assert!(tail.contains("ORDER BY n DESC")); + assert!(tail.contains("LIMIT 3")); + assert!(!plan.per_node_sql.to_uppercase().contains("ORDER BY")); + assert!(!plan.per_node_sql.to_uppercase().contains("LIMIT")); + } } diff --git a/probing/core/src/core/federation/cluster_executor.rs b/probing/core/src/core/federation/cluster_executor.rs index ab118930..4f1be781 100644 --- a/probing/core/src/core/federation/cluster_executor.rs +++ b/probing/core/src/core/federation/cluster_executor.rs @@ -6,6 +6,19 @@ use probing_proto::prelude::{DataFrame, Message, Node, Query, QueryDataFormat}; use crate::core::cluster::get_nodes; +#[cfg(any(test, feature = "test-utils"))] +type RemoteQueryHook = Box Result + Send + Sync>; + +#[cfg(any(test, feature = "test-utils"))] +static REMOTE_QUERY_HOOK: LazyLock>> = + LazyLock::new(|| Mutex::new(None)); + +/// Install an in-process remote query handler for federation integration tests. +#[cfg(any(test, feature = "test-utils"))] +pub fn set_remote_query_hook(hook: Option) { + *REMOTE_QUERY_HOOK.lock().unwrap() = hook; +} + /// Default per-node timeout for remote federated queries (seconds). const DEFAULT_REMOTE_QUERY_TIMEOUT_SECS: u64 = 2; /// Env var to override the per-node remote query timeout (seconds). @@ -155,6 +168,11 @@ impl ProbeClusterExecutor { } fn execute_remote(addr: &str, sql: &str) -> Result { + #[cfg(any(test, feature = "test-utils"))] + if let Some(hook) = REMOTE_QUERY_HOOK.lock().unwrap().as_ref() { + return hook(addr, sql); + } + let url = format!("http://{addr}/query"); let request = Message::new(Query { expr: sql.to_string(), diff --git a/probing/core/src/core/federation/convert.rs b/probing/core/src/core/federation/convert.rs index 1b19d8fa..40690152 100644 --- a/probing/core/src/core/federation/convert.rs +++ b/probing/core/src/core/federation/convert.rs @@ -9,15 +9,39 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use datafusion::error::{DataFusionError, Result}; use probing_proto::prelude::{DataFrame, Seq}; -/// Primary node identity column on `global.*` query results (hostname, or addr if host is empty). +/// Legacy alias; prefer the `_host` / `_addr` pair. pub const PROBE_NODE_COL: &str = "_probe_node"; pub const PROBE_HOST_COL: &str = "_host"; pub const PROBE_ADDR_COL: &str = "_addr"; /// Cluster `rank` from `cluster.nodes` for the row's source probing endpoint. pub const PROBE_RANK_COL: &str = "_rank"; +/// Node/worker group rank (`GROUP_RANK` / `group_rank` on the endpoint). +pub const PROBE_NODE_RANK_COL: &str = "_node_rank"; +/// Intra-node GPU index (`LOCAL_RANK` / `local_rank` on the endpoint). +pub const PROBE_LOCAL_RANK_COL: &str = "_local_rank"; /// Parallel-role key (e.g. "dp=2,pp=1,tp=0") for the row's source endpoint. pub const PROBE_ROLE_COL: &str = "_role"; +/// Fixed federation tag columns appended to `global.*` results (stable order). +pub const FEDERATION_TAG_COLUMNS: &[&str] = &[ + PROBE_HOST_COL, + PROBE_ADDR_COL, + PROBE_RANK_COL, + PROBE_NODE_RANK_COL, + PROBE_LOCAL_RANK_COL, + PROBE_ROLE_COL, +]; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FederationEndpointTags { + pub host: String, + pub addr: String, + pub rank: i32, + pub node_rank: i32, + pub local_rank: i32, + pub role: String, +} + #[cfg_attr(not(test), allow(dead_code))] pub fn node_label(host: &str, addr: &str) -> String { if host.is_empty() { @@ -29,17 +53,34 @@ pub fn node_label(host: &str, addr: &str) -> String { /// Resolve cluster rank for a probing endpoint (`host` + `addr` key in CLUSTER). pub fn cluster_rank_for_endpoint(host: &str, addr: &str) -> Option { + cluster_node_field(host, addr, |n| n.rank) +} + +/// Resolve node/worker group rank for a probing endpoint. +pub fn cluster_node_rank_for_endpoint(host: &str, addr: &str) -> Option { + cluster_node_field(host, addr, |n| n.group_rank) +} + +/// Resolve intra-node local rank for a probing endpoint. +pub fn cluster_local_rank_for_endpoint(host: &str, addr: &str) -> Option { + cluster_node_field(host, addr, |n| n.local_rank) +} + +fn cluster_node_field(host: &str, addr: &str, field: F) -> Option +where + F: Fn(&probing_proto::prelude::Node) -> Option, +{ use crate::core::cluster::CLUSTER; CLUSTER .read() .ok() - .and_then(|c| c.get_by_addr(host, addr).and_then(|n| n.rank)) + .and_then(|c| c.get_by_addr(host, addr).and_then(&field)) .or_else(|| { crate::core::cluster::get_nodes() .into_iter() .find(|n| n.addr == addr) - .and_then(|n| n.rank) + .and_then(|n| field(&n)) }) } @@ -60,12 +101,25 @@ pub fn cluster_role_for_endpoint(host: &str, addr: &str) -> Option { .filter(|r| !r.is_empty()) } +pub fn federation_tags_for_endpoint(host: &str, addr: &str) -> FederationEndpointTags { + FederationEndpointTags { + host: host.to_string(), + addr: addr.to_string(), + rank: cluster_rank_for_endpoint(host, addr).unwrap_or(-1), + node_rank: cluster_node_rank_for_endpoint(host, addr).unwrap_or(-1), + local_rank: cluster_local_rank_for_endpoint(host, addr).unwrap_or(-1), + role: cluster_role_for_endpoint(host, addr).unwrap_or_default(), + } +} + pub fn federated_output_schema(local: SchemaRef) -> SchemaRef { let mut fields = local.fields().to_vec(); for (name, dtype, nullable) in [ (PROBE_HOST_COL, DataType::Utf8, false), (PROBE_ADDR_COL, DataType::Utf8, false), (PROBE_RANK_COL, DataType::Int32, true), + (PROBE_NODE_RANK_COL, DataType::Int32, true), + (PROBE_LOCAL_RANK_COL, DataType::Int32, true), (PROBE_ROLE_COL, DataType::Utf8, true), ] { if !fields.iter().any(|f| f.name() == name) { @@ -78,7 +132,13 @@ pub fn federated_output_schema(local: SchemaRef) -> SchemaRef { pub fn is_federation_tag_column(name: &str) -> bool { matches!( name, - PROBE_NODE_COL | PROBE_HOST_COL | PROBE_ADDR_COL | PROBE_RANK_COL | PROBE_ROLE_COL + PROBE_NODE_COL + | PROBE_HOST_COL + | PROBE_ADDR_COL + | PROBE_RANK_COL + | PROBE_NODE_RANK_COL + | PROBE_LOCAL_RANK_COL + | PROBE_ROLE_COL ) } @@ -87,16 +147,34 @@ pub fn tag_proto_dataframe(df: &mut DataFrame, host: &str, addr: &str, rank: Opt if df.is_empty() { return; } + let mut tags = federation_tags_for_endpoint(host, addr); + if let Some(rank) = rank { + tags.rank = rank; + } + tag_proto_dataframe_with_tags(df, &tags); +} + +pub(crate) fn tag_proto_dataframe_with_tags(df: &mut DataFrame, tags: &FederationEndpointTags) { + if df.is_empty() { + return; + } + append_proto_tags(df, tags); +} + +fn append_proto_tags(df: &mut DataFrame, tags: &FederationEndpointTags) { let rows = df.len(); - let role = cluster_role_for_endpoint(host, addr).unwrap_or_default(); df.names.push(PROBE_HOST_COL.to_string()); df.names.push(PROBE_ADDR_COL.to_string()); df.names.push(PROBE_RANK_COL.to_string()); + df.names.push(PROBE_NODE_RANK_COL.to_string()); + df.names.push(PROBE_LOCAL_RANK_COL.to_string()); df.names.push(PROBE_ROLE_COL.to_string()); - df.cols.push(Seq::SeqText(vec![host.to_string(); rows])); - df.cols.push(Seq::SeqText(vec![addr.to_string(); rows])); - df.cols.push(Seq::SeqI32(vec![rank.unwrap_or(-1); rows])); - df.cols.push(Seq::SeqText(vec![role; rows])); + df.cols.push(Seq::SeqText(vec![tags.host.clone(); rows])); + df.cols.push(Seq::SeqText(vec![tags.addr.clone(); rows])); + df.cols.push(Seq::SeqI32(vec![tags.rank; rows])); + df.cols.push(Seq::SeqI32(vec![tags.node_rank; rows])); + df.cols.push(Seq::SeqI32(vec![tags.local_rank; rows])); + df.cols.push(Seq::SeqText(vec![tags.role.clone(); rows])); df.size = df.len() as u64; } @@ -148,10 +226,12 @@ pub fn dataframe_to_record_batch( return Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))); } - let rank = rank.or_else(|| cluster_rank_for_endpoint(host, addr)); - let role = cluster_role_for_endpoint(host, addr).unwrap_or_default(); - let mut columns = Vec::with_capacity(df.cols.len() + 4); - let mut fields = Vec::with_capacity(df.names.len() + 4); + let mut tags = federation_tags_for_endpoint(host, addr); + if let Some(rank) = rank { + tags.rank = rank; + } + let mut columns = Vec::with_capacity(df.cols.len() + FEDERATION_TAG_COLUMNS.len()); + let mut fields = Vec::with_capacity(df.names.len() + FEDERATION_TAG_COLUMNS.len()); for (name, col) in df.names.iter().zip(df.cols.iter()) { fields.push(Field::new(name, array_data_type(col), true)); @@ -162,11 +242,15 @@ pub fn dataframe_to_record_batch( fields.push(Field::new(PROBE_HOST_COL, DataType::Utf8, false)); fields.push(Field::new(PROBE_ADDR_COL, DataType::Utf8, false)); fields.push(Field::new(PROBE_RANK_COL, DataType::Int32, true)); + fields.push(Field::new(PROBE_NODE_RANK_COL, DataType::Int32, true)); + fields.push(Field::new(PROBE_LOCAL_RANK_COL, DataType::Int32, true)); fields.push(Field::new(PROBE_ROLE_COL, DataType::Utf8, true)); - columns.push(Arc::new(StringArray::from(vec![host.to_string(); rows]))); - columns.push(Arc::new(StringArray::from(vec![addr.to_string(); rows]))); - columns.push(Arc::new(Int32Array::from(vec![rank; rows]))); - columns.push(Arc::new(StringArray::from(vec![role; rows]))); + columns.push(Arc::new(StringArray::from(vec![tags.host; rows]))); + columns.push(Arc::new(StringArray::from(vec![tags.addr; rows]))); + columns.push(Arc::new(Int32Array::from(vec![tags.rank; rows]))); + columns.push(Arc::new(Int32Array::from(vec![tags.node_rank; rows]))); + columns.push(Arc::new(Int32Array::from(vec![tags.local_rank; rows]))); + columns.push(Arc::new(StringArray::from(vec![tags.role; rows]))); RecordBatch::try_new(Arc::new(Schema::new(fields)), columns) .map_err(|e| DataFusionError::Execution(format!("dataframe conversion failed: {e}"))) @@ -182,31 +266,59 @@ pub fn tag_record_batch( return Ok(batch); } - let rank = rank.or_else(|| cluster_rank_for_endpoint(host, addr)); + let mut tags = federation_tags_for_endpoint(host, addr); + if let Some(rank) = rank { + tags.rank = rank; + } let rows = batch.num_rows(); let mut fields = batch.schema().fields().to_vec(); let mut columns = batch.columns().to_vec(); + append_batch_tags(&mut fields, &mut columns, rows, &tags)?; + + RecordBatch::try_new(Arc::new(Schema::new(fields)), columns) + .map_err(|e| DataFusionError::Execution(format!("tagging batch failed: {e}"))) +} + +fn append_batch_tags( + fields: &mut Vec>, + columns: &mut Vec, + rows: usize, + tags: &FederationEndpointTags, +) -> Result<()> { if !fields.iter().any(|f| f.name() == PROBE_HOST_COL) { fields.push(Arc::new(Field::new(PROBE_HOST_COL, DataType::Utf8, false))); - columns.push(Arc::new(StringArray::from(vec![host.to_string(); rows]))); + columns.push(Arc::new(StringArray::from(vec![tags.host.as_str(); rows]))); } if !fields.iter().any(|f| f.name() == PROBE_ADDR_COL) { fields.push(Arc::new(Field::new(PROBE_ADDR_COL, DataType::Utf8, false))); - columns.push(Arc::new(StringArray::from(vec![addr.to_string(); rows]))); + columns.push(Arc::new(StringArray::from(vec![tags.addr.as_str(); rows]))); } if !fields.iter().any(|f| f.name() == PROBE_RANK_COL) { fields.push(Arc::new(Field::new(PROBE_RANK_COL, DataType::Int32, true))); - columns.push(Arc::new(Int32Array::from(vec![rank; rows]))); + columns.push(Arc::new(Int32Array::from(vec![tags.rank; rows]))); + } + if !fields.iter().any(|f| f.name() == PROBE_NODE_RANK_COL) { + fields.push(Arc::new(Field::new( + PROBE_NODE_RANK_COL, + DataType::Int32, + true, + ))); + columns.push(Arc::new(Int32Array::from(vec![tags.node_rank; rows]))); + } + if !fields.iter().any(|f| f.name() == PROBE_LOCAL_RANK_COL) { + fields.push(Arc::new(Field::new( + PROBE_LOCAL_RANK_COL, + DataType::Int32, + true, + ))); + columns.push(Arc::new(Int32Array::from(vec![tags.local_rank; rows]))); } if !fields.iter().any(|f| f.name() == PROBE_ROLE_COL) { - let role = cluster_role_for_endpoint(host, addr).unwrap_or_default(); fields.push(Arc::new(Field::new(PROBE_ROLE_COL, DataType::Utf8, true))); - columns.push(Arc::new(StringArray::from(vec![role; rows]))); + columns.push(Arc::new(StringArray::from(vec![tags.role.as_str(); rows]))); } - - RecordBatch::try_new(Arc::new(Schema::new(fields)), columns) - .map_err(|e| DataFusionError::Execution(format!("tagging batch failed: {e}"))) + Ok(()) } pub fn align_batch_to_schema(batch: RecordBatch, schema: &Schema) -> Result { @@ -277,8 +389,10 @@ mod tests { use arrow::array::{Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use probing_proto::prelude::Node; use super::*; + use crate::core::cluster::{reset_cluster_for_tests, update_node}; #[test] fn node_label_prefers_host() { @@ -291,22 +405,40 @@ mod tests { } #[test] - fn federated_schema_includes_tag_columns() { + fn federated_schema_includes_six_tag_columns() { let local = Arc::new(Schema::new(vec![Field::new( "rank", DataType::Int32, false, )])); let schema = federated_output_schema(local); - assert!(schema.index_of(PROBE_HOST_COL).is_ok()); - assert!(schema.index_of(PROBE_ADDR_COL).is_ok()); - assert!(schema.index_of(PROBE_RANK_COL).is_ok()); - assert!(schema.index_of(PROBE_ROLE_COL).is_ok()); + for col in FEDERATION_TAG_COLUMNS { + assert!(schema.index_of(col).is_ok(), "missing tag column {col}"); + } assert!(schema.index_of(PROBE_NODE_COL).is_err()); } #[test] - fn tag_record_batch_adds_probe_columns() { + fn federation_tags_resolve_from_cluster_node() { + reset_cluster_for_tests(); + update_node(Node { + host: "host-a".into(), + addr: "10.0.0.1:8080".into(), + rank: Some(3), + group_rank: Some(1), + local_rank: Some(2), + role: Some("dp=0".into()), + ..Default::default() + }); + let tags = federation_tags_for_endpoint("host-a", "10.0.0.1:8080"); + assert_eq!(tags.rank, 3); + assert_eq!(tags.node_rank, 1); + assert_eq!(tags.local_rank, 2); + assert_eq!(tags.role, "dp=0"); + } + + #[test] + fn tag_record_batch_adds_six_probe_columns() { let local = Arc::new(Schema::new(vec![Field::new( "rank", DataType::Int32, @@ -314,34 +446,10 @@ mod tests { )])); let batch = RecordBatch::try_new(local, vec![Arc::new(Int32Array::from(vec![7]))]).unwrap(); let tagged = tag_record_batch(batch, "host-a", "10.0.0.1:8080", Some(3)).unwrap(); - assert_eq!(tagged.num_columns(), 5); - assert_eq!( - tagged - .column(tagged.schema().index_of(PROBE_HOST_COL).unwrap()) - .as_any() - .downcast_ref::() - .unwrap() - .value(0), - "host-a" - ); - assert_eq!( - tagged - .column(tagged.schema().index_of(PROBE_ADDR_COL).unwrap()) - .as_any() - .downcast_ref::() - .unwrap() - .value(0), - "10.0.0.1:8080" - ); - assert_eq!( - tagged - .column(tagged.schema().index_of(PROBE_RANK_COL).unwrap()) - .as_any() - .downcast_ref::() - .unwrap() - .value(0), - 3 - ); + assert_eq!(tagged.num_columns(), 7); + for col in FEDERATION_TAG_COLUMNS { + assert!(tagged.schema().index_of(col).is_ok()); + } } #[test] @@ -356,19 +464,6 @@ mod tests { assert_eq!(extended, vec![0]); } - #[test] - fn extend_projection_honors_tag_only_selection() { - let local = Arc::new(Schema::new(vec![Field::new( - "rank", - DataType::Int32, - false, - )])); - let schema = federated_output_schema(local); - let rank_idx = schema.index_of(PROBE_RANK_COL).unwrap(); - let extended = extend_projection_with_probe_tags(Some(&vec![rank_idx]), &schema).unwrap(); - assert_eq!(extended, vec![rank_idx]); - } - #[test] fn align_batch_fills_timestamp_column_for_empty_rows() { let batch = RecordBatch::try_new( diff --git a/probing/core/src/core/federation/global_table.rs b/probing/core/src/core/federation/global_table.rs index f6edc8eb..8df153a5 100644 --- a/probing/core/src/core/federation/global_table.rs +++ b/probing/core/src/core/federation/global_table.rs @@ -95,28 +95,31 @@ impl TableProvider for GlobalFederatedTable { let local_schema = self.local.schema(); let local_projection = local_table_projection(projection, &output_schema, &local_schema); + let host = ProbeClusterExecutor::local_host_label(); + let addr = ProbeClusterExecutor::local_addr_label(); + let local_rank = cluster_rank_for_endpoint(&host, &addr); + + reset_fanout_stats(); + let remote_nodes = ProbeClusterExecutor::remote_nodes(); + // With peers registered, LIMIT is global top-K at the coordinator only. + let scan_limit = if remote_nodes.is_empty() { limit } else { None }; + // Local scan stays lazy; coalesce to a single partition so the federated // plan can expose it as partition 0 without losing rows from sub-partitions. let local_plan = self .local - .scan(state, local_projection.as_ref(), filters, limit) + .scan(state, local_projection.as_ref(), filters, scan_limit) .await?; let local_plan: Arc = Arc::new(CoalescePartitionsExec::new(local_plan)); - let host = ProbeClusterExecutor::local_host_label(); - let addr = ProbeClusterExecutor::local_addr_label(); - let local_rank = cluster_rank_for_endpoint(&host, &addr); - - reset_fanout_stats(); let remote_sql = build_remote_table_sql( &self.schema_name, &self.table_name, &local_schema, local_projection.as_ref(), filters, - limit, + scan_limit, ); - let remote_nodes = ProbeClusterExecutor::remote_nodes(); let scan_projection = federated_scan_projection(projection, &output_schema) .unwrap_or_else(|| (0..output_schema.fields().len()).collect()); diff --git a/probing/core/src/core/federation/mod.rs b/probing/core/src/core/federation/mod.rs index 55a3457d..2f4bb859 100644 --- a/probing/core/src/core/federation/mod.rs +++ b/probing/core/src/core/federation/mod.rs @@ -5,21 +5,31 @@ mod federated_scan_exec; mod global_catalog; mod global_table; mod rewrite; +mod route; mod sql_gen; pub use aggregate_pushdown::{ plan_federated_aggregate_pushdown, try_execute_aggregate_pushdown, FederatedAggregatePlan, }; +#[cfg(any(test, feature = "test-utils"))] +pub use cluster_executor::set_remote_query_hook; pub use cluster_executor::{ remote_query_timeout, reset_fanout_stats, set_fanout_stats, take_fanout_stats, FanoutStats, ProbeClusterExecutor, RemoteFanoutResult, }; pub use convert::{ - cluster_rank_for_endpoint, cluster_role_for_endpoint, PROBE_ADDR_COL, PROBE_HOST_COL, - PROBE_NODE_COL, PROBE_RANK_COL, PROBE_ROLE_COL, + cluster_local_rank_for_endpoint, cluster_node_rank_for_endpoint, cluster_rank_for_endpoint, + cluster_role_for_endpoint, federated_output_schema, federation_tags_for_endpoint, + is_federation_tag_column, tag_proto_dataframe, FederationEndpointTags, FEDERATION_TAG_COLUMNS, + PROBE_ADDR_COL, PROBE_HOST_COL, PROBE_LOCAL_RANK_COL, PROBE_NODE_COL, PROBE_NODE_RANK_COL, + PROBE_RANK_COL, PROBE_ROLE_COL, }; pub use global_catalog::{install_global_catalog, GLOBAL_CATALOG}; pub use rewrite::{ can_fanout_via_global_catalog, ensure_global_node_columns, prepare_global_query, rewrite_global_catalog_to_probe, rewrite_sql_for_global_fanout, }; +pub use route::{ + classify_cluster_sql, classify_federated_sql, explain_federation, explain_physical_plan, + FederatedQueryPath, FederationExplainReport, +}; diff --git a/probing/core/src/core/federation/rewrite.rs b/probing/core/src/core/federation/rewrite.rs index 9622e888..8ff32d3a 100644 --- a/probing/core/src/core/federation/rewrite.rs +++ b/probing/core/src/core/federation/rewrite.rs @@ -4,7 +4,7 @@ use datafusion::sql::sqlparser::ast::{Query, SetExpr, Statement}; use datafusion::sql::sqlparser::dialect::GenericDialect; use datafusion::sql::sqlparser::parser::Parser; -use super::convert::{PROBE_ADDR_COL, PROBE_HOST_COL, PROBE_RANK_COL, PROBE_ROLE_COL}; +use super::convert::FEDERATION_TAG_COLUMNS; const KNOWN_SCHEMAS: &[&str] = &[ "cluster", "process", "files", "python", "memtable", "gpu", "rdma", @@ -134,14 +134,22 @@ fn select_list_includes_wildcard(sql: &str) -> bool { } } +fn federation_tags_already_expanded(lower: &str) -> bool { + FEDERATION_TAG_COLUMNS.iter().all(|col| lower.contains(col)) +} + +fn federation_tag_exclude_list() -> String { + FEDERATION_TAG_COLUMNS + .iter() + .map(|col| col.to_string()) + .collect::>() + .join(", ") +} + fn expand_global_select_star(sql: &str) -> String { let trimmed = sql.trim(); let lower = trimmed.to_lowercase(); - if lower.contains(PROBE_HOST_COL) - && lower.contains(PROBE_ADDR_COL) - && lower.contains(PROBE_RANK_COL) - && lower.contains(PROBE_ROLE_COL) - { + if federation_tags_already_expanded(&lower) { return sql.to_string(); } let Some(from_idx) = find_top_level_from(trimmed) else { @@ -154,7 +162,8 @@ fn expand_global_select_star(sql: &str) -> String { } let exclude = format!( - " EXCLUDE ({PROBE_HOST_COL}, {PROBE_ADDR_COL}, {PROBE_RANK_COL}, {PROBE_ROLE_COL}), {PROBE_HOST_COL}, {PROBE_ADDR_COL}, {PROBE_RANK_COL}, {PROBE_ROLE_COL}" + " EXCLUDE ({tags}), {tags}", + tags = federation_tag_exclude_list() ); let new_select = if let Some(dot_star) = select_part.rfind(".*") { let before = &select_part[..dot_star + 2]; @@ -192,6 +201,7 @@ pub fn prepare_global_query(sql: &str) -> String { #[cfg(test)] mod tests { + use super::super::convert::{PROBE_ADDR_COL, PROBE_RANK_COL}; use super::*; #[test] @@ -301,7 +311,7 @@ mod tests { let sql = "SELECT * FROM global.process.envs"; assert_eq!( ensure_global_node_columns(sql), - "SELECT * EXCLUDE (_host, _addr, _rank, _role), _host, _addr, _rank, _role FROM global.process.envs" + "SELECT * EXCLUDE (_host, _addr, _rank, _node_rank, _local_rank, _role), _host, _addr, _rank, _node_rank, _local_rank, _role FROM global.process.envs" ); } @@ -310,20 +320,19 @@ mod tests { let sql = "SELECT e.* FROM global.process.envs e"; assert_eq!( ensure_global_node_columns(sql), - "SELECT e.* EXCLUDE (_host, _addr, _rank, _role), _host, _addr, _rank, _role FROM global.process.envs e" + "SELECT e.* EXCLUDE (_host, _addr, _rank, _node_rank, _local_rank, _role), _host, _addr, _rank, _node_rank, _local_rank, _role FROM global.process.envs e" ); } #[test] fn skips_select_star_wildcard_when_tags_already_present() { - let sql = - "SELECT * EXCLUDE (_host, _addr, _rank, _role), _host, _addr, _rank, _role FROM global.process.envs"; + let sql = "SELECT * EXCLUDE (_host, _addr, _rank, _node_rank, _local_rank, _role), _host, _addr, _rank, _node_rank, _local_rank, _role FROM global.process.envs"; assert_eq!(ensure_global_node_columns(sql), sql); } #[test] fn skips_qualified_select_star_when_already_expanded() { - let sql = "SELECT e.* EXCLUDE (_host, _addr, _rank, _role), _host, _addr, _rank, _role FROM global.process.envs e"; + let sql = "SELECT e.* EXCLUDE (_host, _addr, _rank, _node_rank, _local_rank, _role), _host, _addr, _rank, _node_rank, _local_rank, _role FROM global.process.envs e"; assert_eq!(ensure_global_node_columns(sql), sql); } diff --git a/probing/core/src/core/federation/route.rs b/probing/core/src/core/federation/route.rs new file mode 100644 index 00000000..923f095b --- /dev/null +++ b/probing/core/src/core/federation/route.rs @@ -0,0 +1,150 @@ +//! Federated query routing classification and EXPLAIN helpers. +//! +//! Mirrors the path selection in `docs/src/design/federation.zh.md` §4.2: +//! - **AggregatePushdown** (A): single-table `global.*` + merge-safe aggregates +//! - **FederatedScan** (B): single-table `global.*` scan via `FederatedScanExec` +//! - **Broadcast** (C): JOIN / CTE / subquery — cluster fan-out only +//! - **Local**: `probe.*` or no federation catalog + +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::error::Result; + +use crate::core::Engine; + +use super::aggregate_pushdown::{plan_federated_aggregate_pushdown, FederatedAggregatePlan}; +use super::rewrite::{ + can_fanout_via_global_catalog, prepare_global_query, rewrite_sql_for_global_fanout, +}; + +/// Execution path for a federated SQL statement (coordinator view). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FederatedQueryPath { + /// Single-process `probe.*` (no `global.*` / known schema fan-out). + Local, + /// Path A — partial aggregates on each peer, merge on coordinator. + AggregatePushdown, + /// Path B — lazy `FederatedScanExec` over local + peers. + FederatedScan, + /// Path C — broadcast full SQL to each rank (JOIN / CTE / …). + Broadcast, +} + +/// Snapshot returned by [`explain_federation`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FederationExplainReport { + pub user_sql: String, + pub global_sql: String, + pub execution_path: FederatedQueryPath, + pub aggregate_plan: Option, + /// DataFusion `EXPLAIN` text for the prepared `global.*` statement (path B plan shape). + pub physical_plan: String, +} + +/// Classify a SQL string that already references `global.*` (or will after rewrite). +pub fn classify_federated_sql(sql: &str) -> FederatedQueryPath { + let lower = sql.to_lowercase(); + if !lower.contains("global.") { + return FederatedQueryPath::Local; + } + if !can_fanout_via_global_catalog(sql) { + return FederatedQueryPath::Broadcast; + } + if plan_federated_aggregate_pushdown(sql).is_some() { + return FederatedQueryPath::AggregatePushdown; + } + FederatedQueryPath::FederatedScan +} + +/// Classify user/cluster SQL (`python.t` → `global.*` rewrite applied first). +pub fn classify_cluster_sql(user_sql: &str) -> FederatedQueryPath { + classify_federated_sql(&rewrite_sql_for_global_fanout(user_sql)) +} + +/// Build a full federation explain report: route + optional pushdown plan + physical EXPLAIN. +pub async fn explain_federation( + engine: &Engine, + user_sql: &str, +) -> Result { + let global_sql = prepare_global_query(&rewrite_sql_for_global_fanout(user_sql)); + let execution_path = classify_federated_sql(&global_sql); + let aggregate_plan = plan_federated_aggregate_pushdown(&global_sql); + let physical_plan = explain_physical_plan(engine, &global_sql).await?; + Ok(FederationExplainReport { + user_sql: user_sql.to_string(), + global_sql, + execution_path, + aggregate_plan, + physical_plan, + }) +} + +/// Run `EXPLAIN` on a prepared SQL string and return the plan text. +pub async fn explain_physical_plan(engine: &Engine, sql: &str) -> Result { + let df = engine.context.sql(&format!("EXPLAIN {sql}")).await?; + let batches = df.collect().await?; + Ok(format_explain_batches(&batches)) +} + +fn format_explain_batches(batches: &[RecordBatch]) -> String { + let mut lines = Vec::new(); + for batch in batches { + let schema = batch.schema(); + for row in 0..batch.num_rows() { + let mut parts = Vec::new(); + for col in 0..batch.num_columns() { + let name = schema.field(col).name(); + let array = batch.column(col); + let value = arrow::util::display::array_value_to_string(array, row) + .unwrap_or_else(|_| "?".to_string()); + if parts.is_empty() && schema.fields().len() == 1 { + lines.push(value); + } else { + parts.push(format!("{name}={value}")); + } + } + if !parts.is_empty() { + lines.push(parts.join(" ")); + } + } + } + lines.join("\n") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn classify_aggregate_pushdown() { + let sql = "SELECT global_step, sum(duration_ms) AS ms \ + FROM global.python.comm_collective GROUP BY global_step"; + assert_eq!( + classify_federated_sql(sql), + FederatedQueryPath::AggregatePushdown + ); + } + + #[test] + fn classify_federated_scan() { + let sql = "SELECT rank FROM global.demo.metrics WHERE rank > 0"; + assert_eq!( + classify_federated_sql(sql), + FederatedQueryPath::FederatedScan + ); + } + + #[test] + fn classify_broadcast_join() { + let sql = "SELECT a.x FROM global.python.a JOIN global.python.b ON a.id = b.id"; + assert_eq!(classify_federated_sql(sql), FederatedQueryPath::Broadcast); + } + + #[test] + fn classify_cluster_rewrite_to_global() { + let sql = "SELECT rank, sum(duration_ms) FROM python.comm_collective GROUP BY rank"; + assert_eq!( + classify_cluster_sql(sql), + FederatedQueryPath::AggregatePushdown + ); + } +} diff --git a/probing/core/src/core/metadata_rewrite.rs b/probing/core/src/core/metadata_rewrite.rs new file mode 100644 index 00000000..e77758aa --- /dev/null +++ b/probing/core/src/core/metadata_rewrite.rs @@ -0,0 +1,177 @@ +//! Rewrite `DESCRIBE` / `SHOW CREATE TABLE` into documented catalog queries. + +use datafusion::sql::sqlparser::ast::{DescribeAlias, ObjectName, ShowCreateObject, Statement}; +use datafusion::sql::sqlparser::dialect::GenericDialect; +use datafusion::sql::sqlparser::parser::Parser; + +use super::semantic_catalog::{COLUMN_DOCS, DOCS_SCHEMA, TABLE_DOCS}; + +/// If `sql` is a table metadata statement, return an enriched SELECT; otherwise `None`. +pub fn prepare_metadata_query(sql: &str, default_schema: &str) -> Option { + let dialect = GenericDialect {}; + let mut stmts = Parser::parse_sql(&dialect, sql).ok()?; + if stmts.len() != 1 { + return None; + } + match stmts.remove(0) { + Statement::ExplainTable { + describe_alias: DescribeAlias::Describe | DescribeAlias::Desc, + table_name, + .. + } => { + let (schema, table) = object_name_to_ref(&table_name, default_schema); + Some(describe_table_sql(&schema, &table)) + } + Statement::ShowCreate { + obj_type: ShowCreateObject::Table, + obj_name, + } => { + let (schema, table) = object_name_to_ref(&obj_name, default_schema); + Some(show_create_table_sql(&schema, &table)) + } + _ => None, + } +} + +fn object_name_to_ref(name: &ObjectName, default_schema: &str) -> (String, String) { + let mut parts: Vec = name + .0 + .iter() + .filter_map(|part| part.as_ident().map(|ident| ident.value.clone())) + .collect(); + if matches!( + parts.first().map(|s| s.as_str()), + Some("probe") | Some("global") | Some("datafusion") + ) { + parts.remove(0); + } + match parts.as_slice() { + [] => (default_schema.to_string(), String::new()), + [table] => (default_schema.to_string(), table.clone()), + [schema, table] => (schema.clone(), table.clone()), + [schema, rest @ ..] => (schema.clone(), rest.join(".")), + } +} + +fn sql_literal(s: &str) -> String { + format!("'{}'", s.replace('\'', "''")) +} + +fn describe_table_sql(schema: &str, table: &str) -> String { + format!( + "SELECT \ + c.column_name, \ + c.data_type, \ + c.is_nullable, \ + cd.description AS comment, \ + td.description AS table_comment \ + FROM information_schema.columns c \ + LEFT JOIN probe.{docs}.{column_docs} cd \ + ON c.table_schema = cd.table_schema \ + AND c.table_name = cd.table_name \ + AND c.column_name = cd.column_name \ + LEFT JOIN probe.{docs}.{table_docs} td \ + ON c.table_schema = td.table_schema \ + AND c.table_name = td.table_name \ + WHERE c.table_schema = {schema} \ + AND c.table_name = {table} \ + ORDER BY c.ordinal_position", + docs = DOCS_SCHEMA, + column_docs = COLUMN_DOCS, + table_docs = TABLE_DOCS, + schema = sql_literal(schema), + table = sql_literal(table), + ) +} + +fn show_create_table_sql(schema: &str, table: &str) -> String { + format!( + "SELECT \ + c.table_schema, \ + c.table_name, \ + max(td.description) AS table_comment, \ + max(td.synonyms) AS synonyms, \ + max(td.notes) AS notes, \ + concat( \ + '-- ', coalesce(max(td.description), ''), '\n', \ + CASE WHEN max(td.notes) IS NOT NULL AND max(td.notes) != '' \ + THEN concat('-- ', replace(max(td.notes), '\n', '\n-- '), '\n') \ + ELSE '' END, \ + 'CREATE TABLE ', c.table_schema, '.', c.table_name, ' (\n', \ + string_agg( \ + concat( \ + ' ', c.column_name, ' ', c.data_type, \ + CASE WHEN cd.description IS NOT NULL AND cd.description != '' \ + THEN concat(' -- ', cd.description) \ + ELSE '' END \ + ), \ + ',\n' ORDER BY c.ordinal_position \ + ), \ + '\n);' \ + ) AS create_statement \ + FROM information_schema.columns c \ + LEFT JOIN probe.{docs}.{table_docs} td \ + ON c.table_schema = td.table_schema \ + AND c.table_name = td.table_name \ + LEFT JOIN probe.{docs}.{column_docs} cd \ + ON c.table_schema = cd.table_schema \ + AND c.table_name = cd.table_name \ + AND c.column_name = cd.column_name \ + WHERE c.table_schema = {schema} \ + AND c.table_name = {table} \ + GROUP BY c.table_schema, c.table_name", + docs = DOCS_SCHEMA, + column_docs = COLUMN_DOCS, + table_docs = TABLE_DOCS, + schema = sql_literal(schema), + table = sql_literal(table), + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rewrite_describe_table() { + let sql = prepare_metadata_query("DESCRIBE hccl.tasks", "probe").unwrap(); + assert!(sql.contains("information_schema.columns")); + assert!(sql.contains("probing.column_docs")); + assert!(sql.contains("'hccl'")); + assert!(sql.contains("'tasks'")); + assert!(sql.contains("AS comment")); + } + + #[test] + fn rewrite_desc_short_form() { + let sql = prepare_metadata_query("DESC nccl.proxy_ops", "probe").unwrap(); + assert!(sql.contains("'nccl'")); + assert!(sql.contains("'proxy_ops'")); + } + + #[test] + fn rewrite_show_create_table() { + let sql = prepare_metadata_query("SHOW CREATE TABLE python.torch_trace", "probe").unwrap(); + assert!(sql.contains("create_statement")); + assert!(sql.contains("string_agg")); + assert!(sql.contains("'python'")); + assert!(sql.contains("'torch_trace'")); + } + + #[test] + fn skip_explain_query_plan() { + assert!(prepare_metadata_query("EXPLAIN SELECT 1", "probe").is_none()); + } + + #[test] + fn skip_regular_select() { + assert!(prepare_metadata_query("SELECT 1", "probe").is_none()); + } + + #[test] + fn qualified_with_probe_catalog() { + let sql = prepare_metadata_query("DESCRIBE probe.hccl.collectives", "probe").unwrap(); + assert!(sql.contains("'hccl'")); + assert!(sql.contains("'collectives'")); + } +} diff --git a/probing/core/src/core/mod.rs b/probing/core/src/core/mod.rs index 462c190f..0c15d98c 100644 --- a/probing/core/src/core/mod.rs +++ b/probing/core/src/core/mod.rs @@ -6,8 +6,10 @@ mod engine; mod error; pub mod federation; pub mod memtable_sql; +mod metadata_rewrite; mod plugin_advanced; pub mod probe_extension; +mod semantic_catalog; pub use data_source::ProbeDataSource; pub use data_source::ProbeDataSourceKind; @@ -101,4 +103,71 @@ mod tests { assert_eq!(df.names[0], "val", "Column name should match"); assert!(!df.cols.is_empty(), "Should have data columns"); } + + #[tokio::test] + async fn describe_rewrite_includes_comment() { + let engine = Engine::builder().build().await.unwrap(); + let df = engine + .async_query("DESCRIBE probing.column_docs") + .await + .unwrap() + .unwrap(); + assert!( + df.names.iter().any(|n| n == "comment"), + "DESCRIBE rewrite should expose comment column, got {:?}", + df.names + ); + assert!( + df.names.iter().any(|n| n == "table_comment"), + "DESCRIBE rewrite should expose table_comment column, got {:?}", + df.names + ); + } + + #[tokio::test] + async fn engine_column_docs_serves_code_first_hccl() { + use probing_proto::prelude::Seq; + + let engine = Engine::builder().build().await.unwrap(); + let df = engine + .async_query( + "SELECT description FROM probe.probing.column_docs \ + WHERE table_schema = 'hccl' AND table_name = 'tasks' AND column_name = 'task_name'", + ) + .await + .unwrap() + .expect("column_docs query should return rows"); + assert_eq!(df.names, vec!["description"]); + let desc = match &df.cols[0] { + Seq::SeqText(values) => values.first().cloned().expect("task_name description row"), + other => panic!("expected SeqText, got {other:?}"), + }; + assert!( + desc.contains("Memcpy"), + "expected code-first column doc, got {desc}" + ); + } + + #[tokio::test] + async fn engine_table_docs_serves_code_first_hccl() { + use probing_proto::prelude::Seq; + + let engine = Engine::builder().build().await.unwrap(); + let df = engine + .async_query( + "SELECT description FROM probe.probing.table_docs \ + WHERE table_schema = 'hccl' AND table_name = 'tasks'", + ) + .await + .unwrap() + .expect("table_docs query should return rows"); + let desc = match &df.cols[0] { + Seq::SeqText(values) => values.first().cloned().expect("hccl.tasks description row"), + other => panic!("expected SeqText, got {other:?}"), + }; + assert!( + desc.contains("MsprofHcclInfo"), + "expected code-first table doc, got {desc}" + ); + } } diff --git a/probing/core/src/core/semantic_catalog.rs b/probing/core/src/core/semantic_catalog.rs new file mode 100644 index 00000000..83676be3 --- /dev/null +++ b/probing/core/src/core/semantic_catalog.rs @@ -0,0 +1,440 @@ +//! Semantic table/column documentation for Engine `DESCRIBE` / `SHOW CREATE TABLE`. +//! +//! **Primary source:** in-code [`Schema`] docs registered via [`probing_memtable::docs`] +//! (HCCL/NCCL collectors, mmap `ExposedTable::create`, Python `@table`). +//! +//! **Overlay:** `skills/semantic/tables.yaml` supplies agent synonyms/notes/global_name +//! and fills gaps for tables not yet migrated to code-first docs. + +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array::{RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::catalog::{ + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, +}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::prelude::SessionContext; +use probing_memtable::docs; +use serde::Deserialize; + +use super::plugin_advanced::PluginAdvancedTable; + +const TABLES_YAML: &str = include_str!("../../../../skills/semantic/tables.yaml"); + +pub const DOCS_SCHEMA: &str = "probing"; +pub const TABLE_DOCS: &str = "table_docs"; +pub const COLUMN_DOCS: &str = "column_docs"; + +#[derive(Debug, Deserialize)] +struct SemanticCatalogFile { + tables: HashMap, +} + +#[derive(Debug, Deserialize)] +struct TableEntry { + description: String, + #[serde(default)] + synonyms: Vec, + #[serde(default)] + key_columns: HashMap, + #[serde(default)] + notes: Vec, + #[serde(default)] + global_name: Option, +} + +#[derive(Debug, Clone)] +pub struct ParsedSemanticCatalog { + pub table_rows: Vec, + pub column_rows: Vec, +} + +#[derive(Debug, Clone)] +pub struct TableDocRow { + pub table_schema: String, + pub table_name: String, + pub description: String, + pub synonyms: String, + pub notes: String, + pub global_name: String, +} + +#[derive(Debug, Clone)] +pub struct ColumnDocRow { + pub table_schema: String, + pub table_name: String, + pub column_name: String, + pub description: String, +} + +fn table_key(table_schema: &str, table_name: &str) -> (String, String) { + (table_schema.to_string(), table_name.to_string()) +} + +fn column_key(table_schema: &str, table_name: &str, column_name: &str) -> (String, String, String) { + ( + table_schema.to_string(), + table_name.to_string(), + column_name.to_string(), + ) +} + +/// Register compile-time known collector schemas (HCCL, NCCL, …). +pub fn register_builtin_schema_docs() { + #[cfg(feature = "builtin-schema-docs")] + { + profapi::register_docs(); + probing_nccl_profiler::register_docs(); + } +} + +pub fn parse_semantic_catalog_yaml(yaml: &str) -> Result { + let file: SemanticCatalogFile = serde_yaml::from_str(yaml).map_err(|e| { + DataFusionError::External(format!("failed to parse semantic tables.yaml: {e}").into()) + })?; + + let mut table_rows = Vec::new(); + let mut column_rows = Vec::new(); + + for (qualified, entry) in file.tables { + let Some((table_schema, table_name)) = qualified.split_once('.') else { + continue; + }; + table_rows.push(TableDocRow { + table_schema: table_schema.to_string(), + table_name: table_name.to_string(), + description: entry.description, + synonyms: entry.synonyms.join(", "), + notes: entry.notes.join("\n"), + global_name: entry.global_name.unwrap_or_default(), + }); + for (column_name, description) in entry.key_columns { + column_rows.push(ColumnDocRow { + table_schema: table_schema.to_string(), + table_name: table_name.to_string(), + column_name, + description, + }); + } + } + + sort_catalog_rows(&mut table_rows, &mut column_rows); + Ok(ParsedSemanticCatalog { + table_rows, + column_rows, + }) +} + +fn sort_catalog_rows(table_rows: &mut [TableDocRow], column_rows: &mut [ColumnDocRow]) { + table_rows + .sort_by(|a, b| (&a.table_schema, &a.table_name).cmp(&(&b.table_schema, &b.table_name))); + column_rows.sort_by(|a, b| { + (&a.table_schema, &a.table_name, &a.column_name).cmp(&( + &b.table_schema, + &b.table_name, + &b.column_name, + )) + }); +} + +/// Merge YAML overlay with the in-code doc registry (registry wins for descriptions). +pub fn build_semantic_catalog() -> Result { + register_builtin_schema_docs(); + + let yaml = parse_semantic_catalog_yaml(TABLES_YAML)?; + + let mut table_map: HashMap<(String, String), TableDocRow> = HashMap::new(); + for row in yaml.table_rows { + table_map.insert(table_key(&row.table_schema, &row.table_name), row); + } + + let mut column_map: HashMap<(String, String, String), ColumnDocRow> = HashMap::new(); + for row in yaml.column_rows { + column_map.insert( + column_key(&row.table_schema, &row.table_name, &row.column_name), + row, + ); + } + + for doc in docs::snapshot() { + let key = table_key(&doc.table_schema, &doc.table_name); + let entry = table_map.entry(key).or_insert_with(|| TableDocRow { + table_schema: doc.table_schema.clone(), + table_name: doc.table_name.clone(), + description: String::new(), + synonyms: String::new(), + notes: String::new(), + global_name: String::new(), + }); + if let Some(desc) = &doc.description { + entry.description = desc.clone(); + } + for (column_name, description) in &doc.columns { + column_map.insert( + column_key(&doc.table_schema, &doc.table_name, column_name), + ColumnDocRow { + table_schema: doc.table_schema.clone(), + table_name: doc.table_name.clone(), + column_name: column_name.clone(), + description: description.clone(), + }, + ); + } + } + + let mut table_rows: Vec = table_map.into_values().collect(); + let mut column_rows: Vec = column_map.into_values().collect(); + sort_catalog_rows(&mut table_rows, &mut column_rows); + + Ok(ParsedSemanticCatalog { + table_rows, + column_rows, + }) +} + +fn table_docs_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("table_schema", DataType::Utf8, false), + Field::new("table_name", DataType::Utf8, false), + Field::new("description", DataType::Utf8, false), + Field::new("synonyms", DataType::Utf8, false), + Field::new("notes", DataType::Utf8, false), + Field::new("global_name", DataType::Utf8, false), + ])) +} + +fn column_docs_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("table_schema", DataType::Utf8, false), + Field::new("table_name", DataType::Utf8, false), + Field::new("column_name", DataType::Utf8, false), + Field::new("description", DataType::Utf8, false), + ])) +} + +fn table_docs_batch(rows: &[TableDocRow]) -> Result { + let schema = table_docs_schema(); + let table_schema = StringArray::from( + rows.iter() + .map(|r| r.table_schema.as_str()) + .collect::>(), + ); + let table_name = StringArray::from( + rows.iter() + .map(|r| r.table_name.as_str()) + .collect::>(), + ); + let description = StringArray::from( + rows.iter() + .map(|r| r.description.as_str()) + .collect::>(), + ); + let synonyms = StringArray::from(rows.iter().map(|r| r.synonyms.as_str()).collect::>()); + let notes = StringArray::from(rows.iter().map(|r| r.notes.as_str()).collect::>()); + let global_name = StringArray::from( + rows.iter() + .map(|r| r.global_name.as_str()) + .collect::>(), + ); + RecordBatch::try_new( + schema, + vec![ + Arc::new(table_schema), + Arc::new(table_name), + Arc::new(description), + Arc::new(synonyms), + Arc::new(notes), + Arc::new(global_name), + ], + ) + .map_err(DataFusionError::from) +} + +fn column_docs_batch(rows: &[ColumnDocRow]) -> Result { + let schema = column_docs_schema(); + let table_schema = StringArray::from( + rows.iter() + .map(|r| r.table_schema.as_str()) + .collect::>(), + ); + let table_name = StringArray::from( + rows.iter() + .map(|r| r.table_name.as_str()) + .collect::>(), + ); + let column_name = StringArray::from( + rows.iter() + .map(|r| r.column_name.as_str()) + .collect::>(), + ); + let description = StringArray::from( + rows.iter() + .map(|r| r.description.as_str()) + .collect::>(), + ); + RecordBatch::try_new( + schema, + vec![ + Arc::new(table_schema), + Arc::new(table_name), + Arc::new(column_name), + Arc::new(description), + ], + ) + .map_err(DataFusionError::from) +} + +/// Register `probing.table_docs` and `probing.column_docs` on the `probe` catalog. +pub fn install_semantic_catalog(context: &SessionContext) -> Result<()> { + let parsed = build_semantic_catalog()?; + let catalog: Arc = if let Some(catalog) = context.catalog("probe") { + catalog + } else { + let c: Arc = Arc::new(MemoryCatalogProvider::new()); + context.register_catalog("probe", Arc::clone(&c)); + c + }; + + let schema: Arc = if let Some(schema) = catalog.schema(DOCS_SCHEMA) { + schema + } else { + let s: Arc = Arc::new(MemorySchemaProvider::new()); + catalog.register_schema(DOCS_SCHEMA, Arc::clone(&s))?; + s + }; + + let table_batch = table_docs_batch(&parsed.table_rows)?; + let column_batch = column_docs_batch(&parsed.column_rows)?; + + schema.register_table( + TABLE_DOCS.to_string(), + Arc::new(PluginAdvancedTable::try_new( + TABLE_DOCS, + table_docs_schema(), + vec![table_batch], + )?), + )?; + schema.register_table( + COLUMN_DOCS.to_string(), + Arc::new(PluginAdvancedTable::try_new( + COLUMN_DOCS, + column_docs_schema(), + vec![column_batch], + )?), + )?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + + #[test] + fn parse_embedded_yaml_has_python_tables() { + let parsed = parse_semantic_catalog_yaml(TABLES_YAML).unwrap(); + assert!(parsed + .table_rows + .iter() + .any(|r| r.table_schema == "python" && r.table_name == "torch_trace")); + } + + #[test] + fn build_catalog_prefers_code_docs_for_hccl() { + let parsed = build_semantic_catalog().unwrap(); + let host_ops = parsed + .table_rows + .iter() + .find(|r| r.table_schema == "hccl" && r.table_name == "host_ops") + .expect("hccl.host_ops"); + assert!(host_ops.description.contains("MSProf Host API")); + assert!(parsed.column_rows.iter().any(|r| { + r.table_schema == "hccl" + && r.table_name == "host_ops" + && r.column_name == "event_class" + && r.description.contains("host_hccl_op") + })); + } + + #[test] + fn build_catalog_keeps_yaml_synonyms_for_hccl() { + let parsed = build_semantic_catalog().unwrap(); + let host_ops = parsed + .table_rows + .iter() + .find(|r| r.table_schema == "hccl" && r.table_name == "host_ops") + .expect("hccl.host_ops"); + assert!( + host_ops.synonyms.contains("MSProf"), + "yaml synonyms should be preserved: {}", + host_ops.synonyms + ); + } + + #[test] + fn build_catalog_includes_registry_only_table() { + let table = format!("code_only_{}", std::process::id()); + docs::register_from_name( + &format!("unittest.{table}"), + &probing_memtable::Schema::new() + .table_doc("registry-only table") + .col_doc("id", probing_memtable::DType::I64, "primary id"), + ); + let parsed = build_semantic_catalog().unwrap(); + assert!(parsed.table_rows.iter().any(|r| { + r.table_schema == "unittest" + && r.table_name == table + && r.description.contains("registry-only") + })); + assert!(parsed.column_rows.iter().any(|r| { + r.table_schema == "unittest" && r.table_name == table && r.column_name == "id" + })); + } + + #[test] + fn build_catalog_nccl_culprit_column_from_code() { + let parsed = build_semantic_catalog().unwrap(); + let row = parsed + .column_rows + .iter() + .find(|r| { + r.table_schema == "nccl" + && r.table_name == "proxy_ops" + && r.column_name == "send_gpu_wait_ns" + }) + .expect("nccl.proxy_ops.send_gpu_wait_ns"); + assert!(row.description.contains("Culprit")); + } + + #[test] + fn build_catalog_yaml_only_python_table_still_present() { + let parsed = build_semantic_catalog().unwrap(); + assert!(parsed + .table_rows + .iter() + .any(|r| r.table_schema == "python" && r.table_name == "torch_trace")); + assert!(parsed.column_rows.iter().any(|r| { + r.table_schema == "python" && r.table_name == "torch_trace" && r.column_name == "module" + })); + } + + #[tokio::test] + async fn install_registers_docs_tables() { + let ctx = SessionContext::new(); + install_semantic_catalog(&ctx).unwrap(); + let df = ctx + .sql("SELECT count(*) AS n FROM probe.probing.column_docs") + .await + .unwrap(); + let batches = df.collect().await.unwrap(); + let col = batches[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(col.value(0) > 0); + } +} diff --git a/probing/core/src/trace/mod.rs b/probing/core/src/trace/mod.rs index dbfe32d2..fd26b574 100644 --- a/probing/core/src/trace/mod.rs +++ b/probing/core/src/trace/mod.rs @@ -3,7 +3,7 @@ mod step; pub use span::{attr, Attribute, Ele, Event, Location, Span, SpanStatus, Timestamp}; pub use step::{ - advance_local_step, current_local_step, set_step_bucket_size, step_snapshot, sync_local_step, + advance_micro_step, current_micro_step, set_micro_batches, step_snapshot, sync_micro_step, StepSnapshot, }; diff --git a/probing/core/src/trace/span.rs b/probing/core/src/trace/span.rs index 880fd431..1c9cda28 100644 --- a/probing/core/src/trace/span.rs +++ b/probing/core/src/trace/span.rs @@ -127,7 +127,7 @@ pub struct Span { pub end: Option, // === 元数据 === - pub kind: Option, + pub phase: Option, pub loc: Option, // === 扩展数据 === @@ -137,7 +137,7 @@ pub struct Span { impl Span { /// Creates a new root span (starts a new trace). - pub fn new_root>(name: N, kind: Option<&str>, location: Option<&str>) -> Self { + pub fn new_root>(name: N, phase: Option<&str>, location: Option<&str>) -> Self { let trace_id = NEXT_TRACE_ID.fetch_add(1, Ordering::Relaxed); let span_id = NEXT_SPAN_ID.fetch_add(1, Ordering::Relaxed); let location = location.map(|loc_val| Location::UnknownLocation(loc_val.into())); @@ -151,7 +151,7 @@ impl Span { name: name.into(), start: Timestamp::now(), end: None, - kind: kind.map(|k| k.to_string()), + phase: phase.map(|p| p.to_string()), loc: location, attrs: vec![], events: vec![], @@ -162,7 +162,7 @@ impl Span { pub fn new_child>( parent: &Span, name: N, - kind: Option<&str>, + phase: Option<&str>, location: Option<&str>, ) -> Self { let span_id = NEXT_SPAN_ID.fetch_add(1, Ordering::Relaxed); @@ -177,7 +177,7 @@ impl Span { name: name.into(), start: Timestamp::now(), end: None, - kind: kind.map(|k| k.to_string()), + phase: phase.map(|p| p.to_string()), loc: location, attrs: vec![], events: vec![], @@ -273,7 +273,7 @@ mod tests { ); assert_eq!(span.name, "process_incoming_request"); - assert_eq!(span.kind, Some("server_op".to_string())); + assert_eq!(span.phase, Some("server_op".to_string())); assert_eq!(span.parent_id, None, "Root span has no parent"); assert_eq!( span.status(), diff --git a/probing/core/src/trace/step.rs b/probing/core/src/trace/step.rs index d4ee66f9..64bae4f9 100644 --- a/probing/core/src/trace/step.rs +++ b/probing/core/src/trace/step.rs @@ -1,19 +1,24 @@ use std::cell::RefCell; -/// Snapshot of the training step coordinate system (scheme 2: local step buckets). +/// Training step coordinates. +/// +/// * ``micro_step`` — finest counter (advanced each ``train.step`` / ``probing.step()``). +/// * ``local_step = micro_step / micro_batches`` — per-rank training step. +/// * ``global_step = local_step`` — cluster training step (same value when ranks align). #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct StepSnapshot { + pub micro_step: u64, pub local_step: u64, pub global_step: u64, - pub bucket_size: u64, + pub micro_batches: u64, pub rank: i64, pub world_size: i64, } #[derive(Debug, Clone)] struct StepContext { - local_step: u64, - bucket_size: u64, + micro_step: u64, + micro_batches: u64, rank: i64, world_size: i64, } @@ -21,8 +26,8 @@ struct StepContext { impl Default for StepContext { fn default() -> Self { Self { - local_step: 0, - bucket_size: read_bucket_size(), + micro_step: 0, + micro_batches: read_micro_batches(), rank: read_rank(), world_size: read_world_size(), } @@ -31,27 +36,29 @@ impl Default for StepContext { impl StepContext { fn snapshot(&self) -> StepSnapshot { + let local = training_step_for(self.micro_step, self.micro_batches); StepSnapshot { - local_step: self.local_step, - global_step: global_step_for(self.local_step, self.bucket_size), - bucket_size: self.bucket_size, + micro_step: self.micro_step, + local_step: local, + global_step: local, + micro_batches: self.micro_batches, rank: self.rank, world_size: self.world_size, } } - fn sync_local_step(&mut self, step: u64) -> StepSnapshot { - self.local_step = step; + fn sync_micro_step(&mut self, step: u64) -> StepSnapshot { + self.micro_step = step; self.snapshot() } - fn advance_local_step(&mut self) -> StepSnapshot { - self.local_step = self.local_step.saturating_add(1); + fn advance_micro_step(&mut self) -> StepSnapshot { + self.micro_step = self.micro_step.saturating_add(1); self.snapshot() } - fn set_bucket_size(&mut self, bucket: u64) { - self.bucket_size = bucket.max(1); + fn set_micro_batches(&mut self, micro_batches: u64) { + self.micro_batches = micro_batches.max(1); } } @@ -59,9 +66,8 @@ thread_local! { static STEP_CTX: RefCell = RefCell::new(StepContext::default()); } -fn global_step_for(local_step: u64, bucket_size: u64) -> u64 { - let bucket = bucket_size.max(1); - local_step / bucket +fn training_step_for(micro_step: u64, micro_batches: u64) -> u64 { + micro_step / micro_batches.max(1) } fn read_env_u64(key: &str) -> Option { @@ -72,8 +78,9 @@ fn read_env_i64(key: &str) -> Option { std::env::var(key).ok().and_then(|v| v.trim().parse().ok()) } -fn read_bucket_size() -> u64 { - read_env_u64("PROBING_GLOBAL_STEP_BUCKET") +fn read_micro_batches() -> u64 { + read_env_u64("PROBING_MICRO_BATCHES") + .or_else(|| read_env_u64("PROBING_GLOBAL_STEP_BUCKET")) .or_else(|| read_env_u64("PROBING_STEP_BUCKET")) .unwrap_or(1) .max(1) @@ -95,20 +102,20 @@ pub fn step_snapshot() -> StepSnapshot { with_ctx(|ctx| ctx.snapshot()) } -pub fn sync_local_step(step: u64) -> StepSnapshot { - with_ctx(|ctx| ctx.sync_local_step(step)) +pub fn sync_micro_step(step: u64) -> StepSnapshot { + with_ctx(|ctx| ctx.sync_micro_step(step)) } -pub fn advance_local_step() -> StepSnapshot { - with_ctx(|ctx| ctx.advance_local_step()) +pub fn advance_micro_step() -> StepSnapshot { + with_ctx(|ctx| ctx.advance_micro_step()) } -pub fn set_step_bucket_size(bucket: u64) { - with_ctx(|ctx| ctx.set_bucket_size(bucket)); +pub fn set_micro_batches(micro_batches: u64) { + with_ctx(|ctx| ctx.set_micro_batches(micro_batches)); } -pub fn current_local_step() -> u64 { - step_snapshot().local_step +pub fn current_micro_step() -> u64 { + step_snapshot().micro_step } #[cfg(test)] @@ -116,22 +123,34 @@ mod tests { use super::*; #[test] - fn global_step_uses_bucket_size() { - assert_eq!(global_step_for(0, 1), 0); - assert_eq!(global_step_for(9, 10), 0); - assert_eq!(global_step_for(10, 10), 1); - assert_eq!(global_step_for(42, 10), 4); + fn local_step_uses_micro_batches() { + assert_eq!(training_step_for(0, 1), 0); + assert_eq!(training_step_for(9, 10), 0); + assert_eq!(training_step_for(10, 10), 1); + assert_eq!(training_step_for(42, 10), 4); } #[test] - fn advance_and_sync_local_step() { - let _ = sync_local_step(0); - assert_eq!(step_snapshot().local_step, 0); - let snap = advance_local_step(); + fn global_step_equals_local_step() { + let _ = sync_micro_step(0); + let snap = advance_micro_step(); + assert_eq!(snap.micro_step, 1); assert_eq!(snap.local_step, 1); assert_eq!(snap.global_step, 1); - let snap = sync_local_step(99); + let snap = sync_micro_step(99); + assert_eq!(snap.micro_step, 99); assert_eq!(snap.local_step, 99); assert_eq!(snap.global_step, 99); } + + #[test] + fn micro_batches_groups_training_steps() { + set_micro_batches(10); + let _ = sync_micro_step(0); + let snap = sync_micro_step(15); + assert_eq!(snap.micro_step, 15); + assert_eq!(snap.local_step, 1); + assert_eq!(snap.global_step, 1); + set_micro_batches(1); + } } diff --git a/probing/core/src/tracing.rs b/probing/core/src/tracing.rs index 15bd5807..5f55fa8c 100644 --- a/probing/core/src/tracing.rs +++ b/probing/core/src/tracing.rs @@ -89,7 +89,7 @@ pub struct Span { pub parent_span_id: Option, pub name: String, - pub kind: Option, + pub phase: Option, pub location: Option, pub start_time: Timestamp, @@ -139,12 +139,12 @@ impl LocalSpanManager { pub fn start_span>( &mut self, name: N, - kind: Option<&str>, + phase: Option<&str>, location: Option<&str>, initial_attributes: Option>, ) -> (SpanId, TraceId) { let name = name.into(); - let kind = kind.map(|k_val| k_val.into()); + let phase = phase.map(|p_val| p_val.into()); let start_time = Timestamp::now(); let current_span_sequence = self.next_span_seq; @@ -180,7 +180,7 @@ impl LocalSpanManager { span_id, parent_span_id: parent_span_id_to_store, name, - kind, + phase, location, start_time, end_time: None, @@ -418,7 +418,7 @@ mod tests { .get(&span_id) .expect("Span not found in tracer"); assert_eq!(span.name, "process_incoming_request"); - assert_eq!(span.kind, Some("server_op".to_string())); + assert_eq!(span.phase, Some("server_op".to_string())); assert_eq!(span.parent_span_id, None, "Root span has no parent"); assert_eq!( span.status, diff --git a/probing/extensions/hccl-shim/Cargo.toml b/probing/extensions/hccl-shim/Cargo.toml new file mode 100644 index 00000000..f5d36372 --- /dev/null +++ b/probing/extensions/hccl-shim/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "probing-hccl-shim" +description = "libprofapi.so shim for HCCL — captures Msprof events into hccl.* memtables" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lib] +# Produces libprofapi.so on Linux (HCCL dlopen name). +name = "profapi" +crate-type = ["cdylib", "rlib"] + +[dependencies] +probing-memtable = { path = "../../memtable" } + +once_cell = { workspace = true } +parking_lot = "0.12" + +[target.'cfg(target_os = "linux")'.dependencies] +libc = "0.2" + +[dev-dependencies] +tempfile = "3" diff --git a/probing/extensions/hccl-shim/README.md b/probing/extensions/hccl-shim/README.md new file mode 100644 index 00000000..d9981d08 --- /dev/null +++ b/probing/extensions/hccl-shim/README.md @@ -0,0 +1,82 @@ +# probing-hccl-shim + +HCCL loads MSProf via `dlopen("libprofapi.so")`. This crate builds a **`libprofapi.so` shim** that: + +1. Exports the seven `Msprof*` symbols HCCL `dlsym`s +2. Records events into probing mmap tables (`hccl.*`) +3. Forwards to the real CANN library (`libprofapi.so.real` or `PROBING_HCCL_PROFAPI_REAL`) + +## Build + +```bash +make hccl-shim-lib +# → python/probing/shim/hccl/libprofapi.so +``` + +## Deploy on Ascend training + +```bash +# 1. Copy real MSProf API next to shim (once per CANN version) +python -m probing.hccl --install-real "$ASCEND_HOME/lib64/libprofapi.so" + +# 2. Prefer shim over CANN libprofapi.so +export LD_LIBRARY_PATH="$(python -m probing.hccl --shim-dir):${LD_LIBRARY_PATH:-}" + +# 3. probing memtable + optional debug +export PROBING=2 +export PROBING_HCCL_SHIM_LOG=1 # optional + +# 4. Enable CANN/HCCL profiling as usual, then train +torchrun ... train.py +``` + +Query (same process or after inject): + +```sql +SELECT count(*) FROM hccl.host_ops; +SELECT count(*) FROM hccl.collectives; +SELECT * FROM hccl.tasks LIMIT 20; +SELECT * FROM hccl.mc2_streams LIMIT 10; +SELECT * FROM global.hccl.tasks LIMIT 20; -- multi-rank +``` + +## Tables + +| SQL table | Source | Key columns | +|-----------|--------|-------------| +| `hccl.host_ops` | `MsprofReportApi` | `event_class`, `item_name`, `duration_ns`, timing | +| `hccl.collectives` | `MsprofReportCompactInfo` (HcclOpInfo) + host HCCL op API | `row_source` (`api`/`compact`), `count`, `group_hash`, `alg_hash` | +| `hccl.tasks` | `MsprofReportAdditionalInfo` → `MsprofHcclInfo` | `task_name`, `plane_index`, `rank_in_plane`, `data_size` | +| `hccl.mc2_streams` | MC2 comm AdditionalInfo | `comm_stream_ids`, `rank_id`, `aicpu_kfc_stream_id` | +| `hccl.context_ids` | ContextId AdditionalInfo | `ctx_id_min`, `ctx_id_max` | + +Join collective timing with metadata: + +```sql +SELECT a.op_name, a.duration_ns, c.count, c.group_hash, c.alg_hash +FROM hccl.collectives a +JOIN hccl.collectives c + ON a.thread_id = c.thread_id + AND a.row_source = 'api' + AND c.row_source = 'compact' + AND abs(a.ts - c.ts) < 1000000; +``` + +## Name resolution + +`MsprofRegTypeInfo` and `MsprofGetHashId` populate a hash→name cache. Known HCCL task/op strings are pre-seeded on first call so `item_name` / `event_class` decode without waiting for runtime registration. + +## Real library resolution + +1. `PROBING_HCCL_PROFAPI_REAL` +2. `/libprofapi.so.real` +3. `$ASCEND_HOME/lib64/libprofapi.so` or `$ASCEND_INSTALL_PATH/lib64/libprofapi.so` + +The shim never `dlopen("libprofapi.so")` by bare name (would reload itself). + +## Notes + +- Struct layouts (`MsprofHcclInfo`, `ProfilingDeviceCommResInfo`, etc.) are best-effort from open HCCL sources; pin CANN version and validate columns on first deploy. +- AdditionalInfo routing uses registered type names (`mc2_comm_info`, `context_id_info`) with data-length fallbacks. +- Profiling must be enabled (`GetIfProfile()` / MSProf subscribe) or tables stay empty. +- Non-Linux: crate builds for CI; plugin symbols are Linux-only. diff --git a/probing/extensions/hccl-shim/src/forward.rs b/probing/extensions/hccl-shim/src/forward.rs new file mode 100644 index 00000000..98a495d8 --- /dev/null +++ b/probing/extensions/hccl-shim/src/forward.rs @@ -0,0 +1,250 @@ +//! Lazy forward to the real CANN `libprofapi.so` (never dlopen the shim name). + +#![cfg(target_os = "linux")] + +use std::ffi::{CStr, CString}; +use std::os::raw::{c_char, c_void}; +use std::os::unix::ffi::OsStrExt; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; + +use once_cell::sync::Lazy; +use parking_lot::Mutex; + +type ProfCommandHandle = Option i32>; + +type FnRegisterCallback = unsafe extern "C" fn(u32, ProfCommandHandle) -> i32; +type FnRegTypeInfo = unsafe extern "C" fn(u16, u32, *const c_char) -> i32; +type FnReportApi = unsafe extern "C" fn(u32, *const c_void) -> i32; +type FnReportBlob = unsafe extern "C" fn(u32, *const c_void, u32) -> i32; +type FnGetHashId = unsafe extern "C" fn(*const c_char, u32) -> u64; +type FnSysCycleTime = unsafe extern "C" fn() -> u64; + +struct RealApi { + register_callback: FnRegisterCallback, + reg_type_info: FnRegTypeInfo, + report_api: FnReportApi, + report_compact: FnReportBlob, + report_additional: FnReportBlob, + get_hash_id: FnGetHashId, + sys_cycle_time: FnSysCycleTime, +} + +struct RealLib { + handle: *mut c_void, + api: RealApi, +} + +unsafe impl Send for RealLib {} + +static INIT: Lazy>> = Lazy::new(|| Mutex::new(None)); +static INIT_FAILED: AtomicBool = AtomicBool::new(false); +static LOGGED_INIT: AtomicBool = AtomicBool::new(false); + +const ENV_REAL: &str = "PROBING_HCCL_PROFAPI_REAL"; +const REAL_BASENAME: &str = "libprofapi.so.real"; +const ENV_ASCEND_HOME: &str = "ASCEND_HOME"; +const ENV_ASCEND_INSTALL: &str = "ASCEND_INSTALL_PATH"; + +fn log_once(msg: &str) { + if std::env::var_os("PROBING_HCCL_SHIM_LOG").is_some() + && LOGGED_INIT + .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + eprintln!("[probing-hccl-shim] {msg}"); + } +} + +fn shim_directory() -> Option { + let maps = std::fs::read_to_string("/proc/self/maps").ok()?; + for line in maps.lines() { + if !line.contains("libprofapi.so") { + continue; + } + let path = line.split_whitespace().last()?; + let p = Path::new(path); + if p.is_absolute() { + return p.parent().map(|d| d.to_path_buf()); + } + } + None +} + +fn ascend_lib_dirs() -> Vec { + let mut out = Vec::new(); + for key in [ENV_ASCEND_HOME, ENV_ASCEND_INSTALL] { + if let Ok(v) = std::env::var(key) { + let base = PathBuf::from(v); + out.push(base.join("lib64")); + out.push(base.join("lib")); + } + } + out +} + +fn candidate_real_paths() -> Vec { + let mut out = Vec::new(); + if let Ok(p) = std::env::var(ENV_REAL) { + out.push(PathBuf::from(p)); + } + if let Some(dir) = shim_directory() { + out.push(dir.join(REAL_BASENAME)); + } + for libdir in ascend_lib_dirs() { + out.push(libdir.join("libprofapi.so")); + } + out +} + +unsafe fn load_sym(handle: *mut c_void, name: &CStr) -> Option { + let sym = libc::dlsym(handle, name.as_ptr()); + if sym.is_null() { + None + } else { + Some(std::mem::transmute_copy(&sym)) + } +} + +unsafe fn open_real() -> Option { + for path in candidate_real_paths() { + if !path.is_file() { + continue; + } + let cpath = CString::new(path.as_os_str().as_bytes()).ok()?; + let handle = libc::dlopen(cpath.as_ptr(), libc::RTLD_NOW | libc::RTLD_LOCAL); + if handle.is_null() { + continue; + } + let api = RealApi { + register_callback: load_sym(handle, c"MsprofRegisterCallback")?, + reg_type_info: load_sym(handle, c"MsprofRegTypeInfo")?, + report_api: load_sym(handle, c"MsprofReportApi")?, + report_compact: load_sym(handle, c"MsprofReportCompactInfo")?, + report_additional: load_sym(handle, c"MsprofReportAdditionalInfo")?, + get_hash_id: load_sym(handle, c"MsprofGetHashId")?, + sys_cycle_time: load_sym(handle, c"MsprofSysCycleTime")?, + }; + log_once(&format!("forwarding to {}", path.display())); + return Some(RealLib { handle, api }); + } + None +} + +fn real_lib() -> Option>> { + let mut guard = INIT.lock(); + if guard.is_none() && !INIT_FAILED.load(Ordering::Relaxed) { + unsafe { + *guard = open_real(); + } + if guard.is_none() { + INIT_FAILED.store(true, Ordering::Relaxed); + if LOGGED_INIT + .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + eprintln!( + "[probing-hccl-shim] real libprofapi not found; MSProf forward disabled. \ + Set {ENV_REAL} or place {REAL_BASENAME} next to the shim." + ); + } + } + } + Some(guard) +} + +unsafe extern "C" fn stub_register(_: u32, _: ProfCommandHandle) -> i32 { + 0 +} +unsafe extern "C" fn stub_reg_type(_: u16, _: u32, _: *const c_char) -> i32 { + 0 +} +unsafe extern "C" fn stub_report_api(_: u32, _: *const c_void) -> i32 { + 0 +} +unsafe extern "C" fn stub_report_blob(_: u32, _: *const c_void, _: u32) -> i32 { + 0 +} +unsafe extern "C" fn stub_hash(_: *const c_char, _: u32) -> u64 { + 0 +} +unsafe extern "C" fn stub_time() -> u64 { + 0 +} + +pub fn forward_register(module_id: u32, handle: ProfCommandHandle) -> i32 { + if let Some(guard) = real_lib() { + if let Some(real) = guard.as_ref() { + return unsafe { (real.api.register_callback)(module_id, handle) }; + } + } + unsafe { stub_register(module_id, handle) } +} + +pub fn forward_reg_type_info(level: u16, type_id: u32, type_name: *const c_char) -> i32 { + if let Some(guard) = real_lib() { + if let Some(real) = guard.as_ref() { + return unsafe { (real.api.reg_type_info)(level, type_id, type_name) }; + } + } + unsafe { stub_reg_type(level, type_id, type_name) } +} + +pub fn forward_report_api(aging: u32, api: *const c_void) -> i32 { + if let Some(guard) = real_lib() { + if let Some(real) = guard.as_ref() { + return unsafe { (real.api.report_api)(aging, api) }; + } + } + unsafe { stub_report_api(aging, api) } +} + +pub fn forward_report_compact(aging: u32, data: *const c_void, len: u32) -> i32 { + if let Some(guard) = real_lib() { + if let Some(real) = guard.as_ref() { + return unsafe { (real.api.report_compact)(aging, data, len) }; + } + } + unsafe { stub_report_blob(aging, data, len) } +} + +pub fn forward_report_additional(aging: u32, data: *const c_void, len: u32) -> i32 { + if let Some(guard) = real_lib() { + if let Some(real) = guard.as_ref() { + return unsafe { (real.api.report_additional)(aging, data, len) }; + } + } + unsafe { stub_report_blob(aging, data, len) } +} + +pub fn forward_get_hash_id(hash_info: *const c_char, length: u32) -> u64 { + if hash_info.is_null() || length == 0 { + return 0; + } + if let Some(guard) = real_lib() { + if let Some(real) = guard.as_ref() { + return unsafe { (real.api.get_hash_id)(hash_info, length) }; + } + } + unsafe { stub_hash(hash_info, length) } +} + +pub fn forward_sys_cycle_time() -> u64 { + if let Some(guard) = real_lib() { + if let Some(real) = guard.as_ref() { + return unsafe { (real.api.sys_cycle_time)() }; + } + } + unsafe { stub_time() } +} + +pub fn shutdown() { + let mut guard = INIT.lock(); + if let Some(real) = guard.take() { + unsafe { + if !real.handle.is_null() { + libc::dlclose(real.handle); + } + } + } +} diff --git a/probing/extensions/hccl-shim/src/lib.rs b/probing/extensions/hccl-shim/src/lib.rs new file mode 100644 index 00000000..53a4de83 --- /dev/null +++ b/probing/extensions/hccl-shim/src/lib.rs @@ -0,0 +1,199 @@ +//! `libprofapi.so` shim — intercept MSProf, write `hccl.*` memtables, forward to CANN. + +#![allow(clippy::missing_safety_doc)] +#![cfg_attr(not(target_os = "linux"), allow(dead_code))] + +#[cfg(target_os = "linux")] +mod forward; +mod msprof; +mod names; +pub mod tables; +pub use tables::register_docs; +mod writer; + +#[cfg(not(target_os = "linux"))] +mod forward { + use std::os::raw::{c_char, c_void}; + type ProfCommandHandle = Option i32>; + pub fn forward_register(_: u32, _: ProfCommandHandle) -> i32 { + 0 + } + pub fn forward_reg_type_info(_: u16, _: u32, _: *const c_char) -> i32 { + 0 + } + pub fn forward_report_api(_: u32, _: *const c_void) -> i32 { + 0 + } + pub fn forward_report_compact(_: u32, _: *const c_void, _: u32) -> i32 { + 0 + } + pub fn forward_report_additional(_: u32, _: *const c_void, _: u32) -> i32 { + 0 + } + pub fn forward_get_hash_id(_: *const c_char, _: u32) -> u64 { + 0 + } + pub fn forward_sys_cycle_time() -> u64 { + 0 + } + pub fn shutdown() {} +} + +pub use tables::{ + collectives_schema, context_ids_schema, host_ops_schema, mc2_streams_schema, tasks_schema, + COLLECTIVES_FILE, CONTEXT_IDS_FILE, HOST_OPS_FILE, MC2_STREAMS_FILE, TASKS_FILE, +}; + +use std::os::raw::c_void; + +use once_cell::sync::Lazy; +use parking_lot::Mutex; + +use crate::msprof::{ + classify_additional, is_hccl_op_compact, read_additional_header, read_api, read_compact_header, + read_context_id_info, read_hccl_info, read_hccl_op_info, read_mc2_comm_info, AdditionalKind, + MSPROF_ADDITIONAL_HEADER, MSPROF_BLOB_HEADER, +}; +use crate::names::{lookup_type_id, preseed_hashes}; +use crate::writer::HcclWriter; + +static WRITER: Lazy> = Lazy::new(|| Mutex::new(HcclWriter::new())); + +type ProfCommandHandle = Option i32>; + +fn hash_fn(s: *const std::os::raw::c_char, l: u32) -> u64 { + forward::forward_get_hash_id(s, l) +} + +fn ensure_names() { + preseed_hashes(hash_fn); +} + +fn capture_api(aging: u32, ptr: *const c_void) { + if ptr.is_null() { + return; + } + ensure_names(); + if let Some(api) = read_api(ptr as *const u8, crate::msprof::MSPROF_API_SIZE as u32) { + WRITER.lock().record_api(aging, &api); + } +} + +fn capture_compact(_aging: u32, ptr: *const c_void, len: u32) { + if ptr.is_null() { + return; + } + ensure_names(); + let Some(header) = read_compact_header(ptr as *const u8, len) else { + return; + }; + let data_ptr = unsafe { (ptr as *const u8).add(MSPROF_BLOB_HEADER) }; + let type_name = lookup_type_id(header.type_id); + if is_hccl_op_compact(&type_name, header.data_len) { + if let Some(op) = read_hccl_op_info(data_ptr, header.data_len) { + WRITER.lock().record_compact_hccl_op(&header, &op); + } + } +} + +fn capture_additional(_aging: u32, ptr: *const c_void, len: u32) { + if ptr.is_null() { + return; + } + ensure_names(); + let Some(header) = read_additional_header(ptr as *const u8, len) else { + return; + }; + let data_ptr = unsafe { (ptr as *const u8).add(MSPROF_ADDITIONAL_HEADER) }; + let type_name = lookup_type_id(header.type_id); + match classify_additional(header.type_id, &type_name, header.data_len) { + AdditionalKind::HcclTask => { + if let Some(hccl) = read_hccl_info(data_ptr, header.data_len) { + WRITER + .lock() + .record_task(&header, &hccl, header.data_len as i32); + } + } + AdditionalKind::Mc2Comm => { + if let Some(mc2) = read_mc2_comm_info(data_ptr, header.data_len) { + WRITER.lock().record_mc2(&header, &mc2); + } + } + AdditionalKind::ContextId => { + if let Some(ctx) = read_context_id_info(data_ptr, header.data_len) { + WRITER.lock().record_context(&header, &ctx); + } + } + AdditionalKind::Unknown => {} + } +} + +#[cfg(target_os = "linux")] +mod export { + use std::os::raw::{c_char, c_void}; + + use super::*; + + #[no_mangle] + pub unsafe extern "C" fn MsprofRegisterCallback( + module_id: u32, + handle: ProfCommandHandle, + ) -> i32 { + forward::forward_register(module_id, handle) + } + + #[no_mangle] + pub unsafe extern "C" fn MsprofRegTypeInfo( + level: u16, + type_id: u32, + type_name: *const c_char, + ) -> i32 { + ensure_names(); + crate::names::register_type_info(type_id, type_name, hash_fn); + forward::forward_reg_type_info(level, type_id, type_name) + } + + #[no_mangle] + pub unsafe extern "C" fn MsprofReportApi(aging_flag: u32, api: *const c_void) -> i32 { + capture_api(aging_flag, api); + forward::forward_report_api(aging_flag, api) + } + + #[no_mangle] + pub unsafe extern "C" fn MsprofReportCompactInfo( + aging_flag: u32, + data: *const c_void, + length: u32, + ) -> i32 { + capture_compact(aging_flag, data, length); + forward::forward_report_compact(aging_flag, data, length) + } + + #[no_mangle] + pub unsafe extern "C" fn MsprofReportAdditionalInfo( + aging_flag: u32, + data: *const c_void, + length: u32, + ) -> i32 { + capture_additional(aging_flag, data, length); + forward::forward_report_additional(aging_flag, data, length) + } + + #[no_mangle] + pub unsafe extern "C" fn MsprofGetHashId(hash_info: *const c_char, length: u32) -> u64 { + ensure_names(); + let hash = forward::forward_get_hash_id(hash_info, length); + crate::names::register_hash_string(hash_info, length, hash); + hash + } + + #[no_mangle] + pub unsafe extern "C" fn MsprofSysCycleTime() -> u64 { + forward::forward_sys_cycle_time() + } +} + +#[cfg(not(target_os = "linux"))] +mod stub { + pub const BUILD_NOTE: &str = "probing-hccl-shim: libprofapi.so built on Linux only"; +} diff --git a/probing/extensions/hccl-shim/src/msprof.rs b/probing/extensions/hccl-shim/src/msprof.rs new file mode 100644 index 00000000..49ae7335 --- /dev/null +++ b/probing/extensions/hccl-shim/src/msprof.rs @@ -0,0 +1,278 @@ +//! Best-effort layouts for CANN MSProf structs (toolchain/prof_api.h). +//! +//! Field order follows open HCCL usage in task_profiling.cc / profiling_manager.cc. +//! Validate with `sizeof` checks at runtime; CANN version drift may require updates. + +use std::mem::size_of; + +/// HCCL passes `sizeof(MsprofHcclInfo)` bytes in AdditionalInfo.data for task reports. +#[repr(C)] +#[derive(Clone, Copy, Default)] +pub struct MsprofHcclInfo { + pub item_id: u64, + pub ccl_tag: u64, + pub group_name: u64, + pub local_rank: u32, + pub remote_rank: u32, + pub rank_size: u32, + pub workflow_mode: u32, + pub plane_id: u32, + pub ctx_id: u32, + pub notify_id: u64, + pub stage: u32, + pub role: u32, + pub _pad: u32, + pub duration_estimated: f64, + pub src_addr: u64, + pub dst_addr: u64, + pub data_size: u64, + pub op_type: u32, + pub data_type: u32, + pub link_type: u32, + pub transport_type: u32, + pub rdma_type: u32, +} + +pub const MSPROF_HCCL_INFO_MIN: usize = size_of::(); + +/// Observed HCCL MsprofApi initialization pattern. +#[repr(C)] +#[derive(Clone, Copy, Default)] +pub struct MsprofApi { + pub level: u16, + pub reserve: u16, + pub type_id: u32, + pub thread_id: u32, + pub reserve2: u32, + pub begin_time: u64, + pub end_time: u64, + pub item_id: u64, +} + +pub const MSPROF_API_SIZE: usize = size_of::(); + +/// Header shared by MsprofAdditionalInfo and MsprofCompactInfo. +#[repr(C)] +#[derive(Clone, Copy, Default)] +pub struct MsprofBlobHeader { + pub level: u32, + pub type_id: u32, + pub thread_id: u32, + pub data_len: u32, + pub time_stamp: u64, +} + +pub type MsprofAdditionalInfoHeader = MsprofBlobHeader; +pub type MsprofCompactInfoHeader = MsprofBlobHeader; + +pub const MSPROF_BLOB_HEADER: usize = size_of::(); +pub const MSPROF_ADDITIONAL_HEADER: usize = MSPROF_BLOB_HEADER; + +/// `CallMsprofReportHostHcclOpInfo` payload inside MsprofCompactInfo.data. +#[repr(C)] +#[derive(Clone, Copy, Default)] +pub struct MsprofHCCLOPInfo { + pub relay: u32, + pub retry: u32, + pub data_type: u32, + pub _pad: u32, + pub alg_type: u64, + pub count: u64, + pub group_name: u64, +} + +pub const MSPROF_HCCL_OP_INFO_MIN: usize = size_of::(); + +/// `CallMsprofReportContextIdInfo` payload. +#[repr(C)] +#[derive(Clone, Copy, Default)] +pub struct MsprofContextIdInfo { + pub ctx_id_num: u32, + pub ctx_ids: [u32; 2], +} + +pub const MSPROF_CONTEXT_ID_INFO: usize = size_of::(); + +/// Prefix of `ProfilingDeviceCommResInfo` from hccl_communicator_host.cc. +#[repr(C)] +#[derive(Clone, Copy, Default)] +pub struct ProfilingDeviceCommResInfoHeader { + pub group_name: u64, + pub rank_size: u32, + pub rank_id: u32, + pub usr_rank_id: u32, + pub aicpu_kfc_stream_id: u32, + pub reserve: u32, +} + +pub const MSPROF_MC2_HEADER: usize = size_of::(); + +/// ProfTaskType::TASK_HCCL_INFO +pub const PROF_TASK_HCCL_INFO: u32 = 0; + +pub fn decode_plane(plane_id: u32) -> (i32, i32, i32) { + let id = plane_id as u64; + let plane_index = ((id >> 28) & 0xF) as i32; + let rank_size_plane = ((id >> 16) & 0xFFF) as i32; + let rank_in_plane = (id & 0xFFFF) as i32; + (plane_index, rank_in_plane, rank_size_plane) +} + +pub fn read_api(ptr: *const u8, len: u32) -> Option { + if ptr.is_null() || (len as usize) < MSPROF_API_SIZE { + return None; + } + Some(unsafe { std::ptr::read_unaligned(ptr as *const MsprofApi) }) +} + +pub fn read_blob_header(ptr: *const u8, len: u32) -> Option { + if ptr.is_null() || (len as usize) < MSPROF_BLOB_HEADER { + return None; + } + Some(unsafe { std::ptr::read_unaligned(ptr as *const MsprofBlobHeader) }) +} + +pub fn read_additional_header(ptr: *const u8, len: u32) -> Option { + read_blob_header(ptr, len) +} + +pub fn read_compact_header(ptr: *const u8, len: u32) -> Option { + read_blob_header(ptr, len) +} + +pub fn read_hccl_info(data: *const u8, data_len: u32) -> Option { + if data.is_null() || (data_len as usize) < MSPROF_HCCL_INFO_MIN { + return None; + } + Some(unsafe { std::ptr::read_unaligned(data as *const MsprofHcclInfo) }) +} + +pub fn read_hccl_op_info(data: *const u8, data_len: u32) -> Option { + if data.is_null() || (data_len as usize) < MSPROF_HCCL_OP_INFO_MIN { + return None; + } + Some(unsafe { std::ptr::read_unaligned(data as *const MsprofHCCLOPInfo) }) +} + +pub fn read_context_id_info(data: *const u8, data_len: u32) -> Option { + if data.is_null() || (data_len as usize) < 8 { + return None; + } + Some(unsafe { std::ptr::read_unaligned(data as *const MsprofContextIdInfo) }) +} + +pub struct Mc2CommInfo { + pub header: ProfilingDeviceCommResInfoHeader, + pub comm_stream_size: u32, + pub comm_stream_ids: Vec, +} + +pub fn read_mc2_comm_info(data: *const u8, data_len: u32) -> Option { + if data.is_null() || (data_len as usize) < MSPROF_MC2_HEADER + 4 { + return None; + } + let header = + unsafe { std::ptr::read_unaligned(data as *const ProfilingDeviceCommResInfoHeader) }; + let tail = data_len as usize - MSPROF_MC2_HEADER; + if tail < 4 { + return None; + } + let tail_ptr = unsafe { data.add(MSPROF_MC2_HEADER) }; + // HCCL sets commStreamSize after filling ids; layout is header + ids[] + commStreamSize + // or header + commStreamSize + ids[]. Accept either by checking plausibility. + let comm_stream_size_tail = + unsafe { std::ptr::read_unaligned(tail_ptr.add(tail - 4) as *const u32) }; + let comm_stream_size_head = unsafe { std::ptr::read_unaligned(tail_ptr as *const u32) }; + + let (comm_stream_size, id_offset) = if comm_stream_size_head > 0 && comm_stream_size_head <= 512 + { + (comm_stream_size_head, 4usize) + } else if comm_stream_size_tail > 0 && comm_stream_size_tail <= 512 { + (comm_stream_size_tail, 0usize) + } else { + (0u32, 4usize) + }; + + let ids_bytes = tail.saturating_sub(4); + let ids_ptr = unsafe { tail_ptr.add(id_offset) }; + let max_ids = ids_bytes / 4; + let n = if comm_stream_size > 0 { + (comm_stream_size as usize).min(max_ids) + } else { + max_ids + }; + let mut comm_stream_ids = Vec::with_capacity(n); + for i in 0..n { + let id = unsafe { std::ptr::read_unaligned(ids_ptr.add(i * 4) as *const u32) }; + comm_stream_ids.push(id); + } + let comm_stream_size = if comm_stream_size > 0 { + comm_stream_size + } else { + n as u32 + }; + + Some(Mc2CommInfo { + header, + comm_stream_size, + comm_stream_ids, + }) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AdditionalKind { + HcclTask, + Mc2Comm, + ContextId, + Unknown, +} + +pub fn classify_additional(type_id: u32, type_name: &str, data_len: u32) -> AdditionalKind { + if type_name == "mc2_comm_info" { + return AdditionalKind::Mc2Comm; + } + if type_name == "context_id_info" { + return AdditionalKind::ContextId; + } + if type_id == PROF_TASK_HCCL_INFO || data_len as usize >= MSPROF_HCCL_INFO_MIN { + return AdditionalKind::HcclTask; + } + if data_len as usize >= MSPROF_MC2_HEADER + 8 && data_len <= 4096 { + return AdditionalKind::Mc2Comm; + } + if data_len as usize >= 8 && data_len <= 64 { + return AdditionalKind::ContextId; + } + AdditionalKind::Unknown +} + +pub fn is_hccl_op_compact(type_name: &str, data_len: u32) -> bool { + type_name.contains("hccl_op") || data_len as usize == MSPROF_HCCL_OP_INFO_MIN +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn hccl_info_size_sane() { + const { + assert!(MSPROF_HCCL_INFO_MIN >= 96 && MSPROF_HCCL_INFO_MIN <= 256); + } + } + + #[test] + fn decode_plane_bits() { + // plane=3, rank_size=8, rank=5 -> (3<<28)|(8<<16)|5 + let plane_id = (3u64 << 28) | (8u64 << 16) | 5; + assert_eq!(decode_plane(plane_id as u32), (3, 5, 8)); + } + + #[test] + fn classify_task_by_len() { + assert_eq!( + classify_additional(99, "", MSPROF_HCCL_INFO_MIN as u32), + AdditionalKind::HcclTask + ); + } +} diff --git a/probing/extensions/hccl-shim/src/names.rs b/probing/extensions/hccl-shim/src/names.rs new file mode 100644 index 00000000..d92a8d0e --- /dev/null +++ b/probing/extensions/hccl-shim/src/names.rs @@ -0,0 +1,191 @@ +//! Resolve MSProf `item_id` hashes to human-readable names. + +use std::collections::HashMap; +use std::ffi::CStr; +use std::os::raw::c_char; +use std::sync::atomic::{AtomicBool, Ordering}; + +use once_cell::sync::Lazy; +use parking_lot::Mutex; + +static REGISTRY: Lazy> = Lazy::new(|| Mutex::new(NameRegistry::new())); +static PRESEEDED: AtomicBool = AtomicBool::new(false); + +/// ProfTaskType names from HCCL `PROF_TASK_OP_NAME`. +static KNOWN_TASK_NAMES: &[&str] = &[ + "hccl_info", + "Memcpy", + "RDMASend", + "Reduce_Inline", + "Reduce_TBE", + "Notify_Record", + "Notify_Wait", + "StageX_StepX", + "Flag", + "End", + "Multi_Thread", + "Launch_Ffts", + "AivKernel", + "Wait_Some", + "Coll_Recv_Lookup_Request", + "Coll_Recv_Update_Request", + "Isend_Update_Response", + "Isend_Lookup_Response", + "Update_Imrecv", + "Update_Global_Reduce", + "Lookup_Response_Memcpy", + "Lookup_Response_Isend", + "Share_Memory_Isend_Record", + "Abort_Self", + "Service_Cancel", + "Destroy_Resource", + "Event_Wait", + "unknown", +]; + +/// HcclCMDType names from HCCL `PROF_OP_NAME`. +static KNOWN_OP_NAMES: &[&str] = &[ + "hcom_invalid_", + "hcom_broadcast_", + "hcom_allReduce_", + "hcom_reduce_", + "hcom_send_", + "hcom_receive_", + "hcom_allGather_", + "hcom_reduceScatter_", + "hcom_scatter_", + "hcom_alltoall_", + "hcom_alltoallv_", + "hcom_allGatherv_", + "hcom_reduceScatterv_", + "hcom_alltoallvc_", + "hcom_batchSendRecv_", + "hccl_batchPut_", + "hccl_batchGet_", +]; + +struct NameRegistry { + hash_to_name: HashMap, + type_id_labels: HashMap, +} + +impl NameRegistry { + fn new() -> Self { + Self { + hash_to_name: HashMap::new(), + type_id_labels: HashMap::new(), + } + } + + fn insert_hash(&mut self, hash: u64, name: impl Into) { + if hash != 0 { + self.hash_to_name.entry(hash).or_insert(name.into()); + } + } + + fn lookup(&self, hash: u64) -> String { + self.hash_to_name.get(&hash).cloned().unwrap_or_default() + } + + fn label_for_type(&self, type_id: u32) -> String { + self.type_id_labels + .get(&type_id) + .cloned() + .unwrap_or_default() + } +} + +pub fn register_type_info( + type_id: u32, + type_name: *const c_char, + hash_fn: impl Fn(*const c_char, u32) -> u64, +) { + if type_name.is_null() { + return; + } + let Ok(name) = unsafe { CStr::from_ptr(type_name) }.to_str() else { + return; + }; + let mut reg = REGISTRY.lock(); + reg.type_id_labels + .entry(type_id) + .or_insert_with(|| name.to_string()); + let hash = hash_fn(type_name, name.len() as u32); + reg.insert_hash(hash, name); +} + +pub fn register_hash_string(hash_info: *const c_char, length: u32, hash: u64) { + if hash_info.is_null() || length == 0 || hash == 0 { + return; + } + let bytes = unsafe { std::slice::from_raw_parts(hash_info as *const u8, length as usize) }; + let Ok(name) = std::str::from_utf8(bytes) else { + return; + }; + REGISTRY.lock().insert_hash(hash, name); +} + +pub fn preseed_hashes(hash_fn: impl Fn(*const c_char, u32) -> u64) { + if PRESEEDED.swap(true, Ordering::Relaxed) { + return; + } + let mut reg = REGISTRY.lock(); + for name in KNOWN_TASK_NAMES.iter().chain(KNOWN_OP_NAMES.iter()) { + let c = std::ffi::CString::new(*name).expect("static name"); + let hash = hash_fn(c.as_ptr(), name.len() as u32); + reg.insert_hash(hash, *name); + } +} + +pub fn lookup_hash(hash: u64) -> String { + REGISTRY.lock().lookup(hash) +} + +pub fn lookup_type_id(type_id: u32) -> String { + REGISTRY.lock().label_for_type(type_id) +} + +/// Classify `MsprofReportApi` rows using level/type and resolved item name. +pub fn classify_api_event(level: u16, type_id: u32, item_id: u64) -> &'static str { + let name = lookup_hash(item_id); + if name.starts_with("hcom_") || name.starts_with("hccl_") { + return if level <= 1 { + "host_acl" + } else { + "host_hccl_op" + }; + } + if KNOWN_TASK_NAMES.contains(&name.as_str()) || name.ends_with("Kernel") { + return match type_id { + 2 => "task_slave", + _ => "task_master", + }; + } + match level { + 0 | 1 => "host_acl", + 2 => "node_launch", + 3 => match type_id { + 2 => "task_slave", + 1 => "task_master", + _ => "hccl_node", + }, + _ => "api_other", + } +} + +pub fn is_hccl_op_name(name: &str) -> bool { + name.starts_with("hcom_") || name.starts_with("hccl_") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn classify_by_name_prefix() { + let mut reg = REGISTRY.lock(); + reg.insert_hash(42, "hcom_allReduce_"); + drop(reg); + assert_eq!(classify_api_event(3, 1, 42), "host_hccl_op"); + } +} diff --git a/probing/extensions/hccl-shim/src/tables.rs b/probing/extensions/hccl-shim/src/tables.rs new file mode 100644 index 00000000..46efaee1 --- /dev/null +++ b/probing/extensions/hccl-shim/src/tables.rs @@ -0,0 +1,184 @@ +//! Memtable schemas for HCCL MSProf intercept. + +use probing_memtable::docs; +use probing_memtable::{DType, Schema}; + +pub const HOST_OPS_FILE: &str = "hccl.host_ops"; +pub const TASKS_FILE: &str = "hccl.tasks"; +pub const COLLECTIVES_FILE: &str = "hccl.collectives"; +pub const MC2_STREAMS_FILE: &str = "hccl.mc2_streams"; +pub const CONTEXT_IDS_FILE: &str = "hccl.context_ids"; + +/// Register all HCCL table docs (safe to call from writer or Engine startup). +pub fn register_docs() { + docs::register_from_name(HOST_OPS_FILE, &host_ops_schema()); + docs::register_from_name(TASKS_FILE, &tasks_schema()); + docs::register_from_name(COLLECTIVES_FILE, &collectives_schema()); + docs::register_from_name(MC2_STREAMS_FILE, &mc2_streams_schema()); + docs::register_from_name(CONTEXT_IDS_FILE, &context_ids_schema()); +} + +pub fn host_ops_schema() -> Schema { + Schema::new() + .table_doc("HCCL MSProf Host API 时间线(集合通信 op、ACL、task master/slave)") + .col_doc("ts", DType::I64, "结束时间(CANN sys cycle)") + .col_doc("begin_ns", DType::I64, "开始时间") + .col_doc("end_ns", DType::I64, "结束时间") + .col_doc("duration_ns", DType::I64, "耗时 end - begin") + .col_doc("thread_id", DType::I32, "上报线程 id") + .col_doc("level", DType::I32, "MSProf level") + .col_doc("type_id", DType::I32, "MSProf type id") + .col_doc("item_id", DType::U64, "名称 hash") + .col_doc( + "item_name", + DType::Str, + "解码名称(hcom_allReduce_、Memcpy 等)", + ) + .col_doc( + "event_class", + DType::Str, + "host_hccl_op | task_master | task_slave | host_acl | node_launch", + ) + .col_doc("aging", DType::I32, "MSProf aging flag") +} + +pub fn tasks_schema() -> Schema { + Schema::new() + .table_doc("HCCL 设备侧 task 明细(MsprofHcclInfo L1)") + .col_doc("ts", DType::I64, "上报时间戳") + .col_doc("thread_id", DType::I32, "上报线程") + .col_doc("info_type", DType::I32, "AdditionalInfo type id") + .col_doc("info_level", DType::I32, "AdditionalInfo level") + .col_doc( + "info_type_name", + DType::Str, + "RegTypeInfo 注册名(如 hccl_info)", + ) + .col_doc("item_id", DType::U64, "task 类型 hash") + .col_doc("task_name", DType::Str, "task 类型名(Memcpy、RDMASend…)") + .col_doc("ccl_tag", DType::U64, "CCL tag hash") + .col_doc("group_name", DType::U64, "comm group hash") + .col_doc("local_rank", DType::I32, "本端 rank") + .col_doc("remote_rank", DType::I32, "对端 rank(-1 表示 N/A)") + .col_doc("rank_size", DType::I32, "通信组大小") + .col_doc("workflow_mode", DType::I32, "HCCL workflow mode enum") + .col_doc("plane_id", DType::I32, "原始 plane 编码") + .col_doc( + "plane_index", + DType::I32, + "plane 索引(plane_id bits 28-31)", + ) + .col_doc("rank_in_plane", DType::I32, "plane 内 rank(bits 0-15)") + .col_doc("rank_size_plane", DType::I32, "plane 宽度(bits 16-27)") + .col_doc("ctx_id", DType::I32, "FFTS context id") + .col_doc("notify_id", DType::U64, "notify 对象 id") + .col_doc("stage", DType::I32, "流水线 stage") + .col_doc("role", DType::I32, "task 角色 enum") + .col_doc("data_size", DType::U64, "传输字节数") + .col_doc("op_type", DType::I32, "操作类型 enum") + .col_doc("data_type", DType::I32, "数据类型 enum") + .col_doc("link_type", DType::I32, "链路类型") + .col_doc("transport_type", DType::I32, "传输类型") + .col_doc("rdma_type", DType::I32, "RDMA 类型") + .col_doc("duration_est_us", DType::F64, "估算耗时(微秒)") + .col_doc("payload_len", DType::I32, "AdditionalInfo payload 长度") +} + +pub fn collectives_schema() -> Schema { + Schema::new() + .table_doc("HCCL 集合通信元数据与 Host 耗时(row_source 区分 api/compact 行)") + .col_doc("ts", DType::I64, "事件时间") + .col_doc("thread_id", DType::I32, "上报线程") + .col_doc( + "row_source", + DType::Str, + "api=耗时行;compact=count/group/alg 参数行", + ) + .col_doc("begin_ns", DType::I64, "开始时间(api 行)") + .col_doc("end_ns", DType::I64, "结束时间(api 行)") + .col_doc("duration_ns", DType::I64, "耗时(api 行)") + .col_doc("op_hash", DType::U64, "算子名 hash(api 行)") + .col_doc("op_name", DType::Str, "算子名(api 行)") + .col_doc("group_hash", DType::U64, "comm group 名 hash(compact 行)") + .col_doc("alg_hash", DType::U64, "算法名 hash(compact 行)") + .col_doc("count", DType::U64, "元素个数(compact 行)") + .col_doc("data_type", DType::I32, "HcclDataType enum") + .col_doc("relay", DType::I32, "HCCL relay 标志") + .col_doc("retry", DType::I32, "HCCL retry 计数") + .col_doc( + "compact_type", + DType::I32, + "MsprofReportCompactInfo type id", + ) +} + +pub fn mc2_streams_schema() -> Schema { + Schema::new() + .table_doc("HCCL MC2 communicator stream 拓扑快照") + .col_doc("ts", DType::I64, "上报时间") + .col_doc("thread_id", DType::I32, "上报线程") + .col_doc("info_type", DType::I32, "MSProf AdditionalInfo type") + .col_doc("group_hash", DType::U64, "comm group 名 hash") + .col_doc("rank_size", DType::I32, "组内 rank 数") + .col_doc("rank_id", DType::I32, "rank id") + .col_doc("usr_rank_id", DType::I32, "用户可见 rank id") + .col_doc("aicpu_kfc_stream_id", DType::I32, "KFC stream id") + .col_doc("comm_stream_size", DType::I32, "comm stream 数量") + .col_doc("comm_stream_ids", DType::Str, "逗号分隔的 stream id 列表") +} + +pub fn context_ids_schema() -> Schema { + Schema::new() + .table_doc("HCCL FFTS context id 范围(dispatch 时上报)") + .col_doc("ts", DType::I64, "上报时间") + .col_doc("thread_id", DType::I32, "上报线程") + .col_doc("info_type", DType::I32, "MSProf AdditionalInfo type") + .col_doc("ctx_id_num", DType::I32, "context 数量(HCCL 固定报 2)") + .col_doc("ctx_id_min", DType::I32, "范围起点(通常 0)") + .col_doc("ctx_id_max", DType::I32, "范围终点(ctxIdMax)") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn register_docs_populates_registry() { + register_docs(); + let rows = docs::snapshot(); + assert!(rows + .iter() + .any(|r| r.table_schema == "hccl" && r.table_name == "tasks")); + assert!(rows + .iter() + .any(|r| r.table_schema == "hccl" && r.table_name == "host_ops")); + } + + #[test] + fn tasks_schema_every_column_documented() { + let schema = tasks_schema(); + assert!(schema.table_doc.is_some()); + assert_eq!(schema.cols.len(), 29); + assert!( + schema.cols.iter().all(|c| c.doc.is_some()), + "missing docs: {:?}", + schema + .cols + .iter() + .filter(|c| c.doc.is_none()) + .map(|c| c.name.as_str()) + .collect::>() + ); + } + + #[test] + fn host_ops_event_class_doc() { + let schema = host_ops_schema(); + let event_class = schema + .cols + .iter() + .find(|c| c.name == "event_class") + .expect("event_class column"); + assert!(event_class.doc.as_ref().unwrap().contains("host_hccl_op")); + } +} diff --git a/probing/extensions/hccl-shim/src/writer.rs b/probing/extensions/hccl-shim/src/writer.rs new file mode 100644 index 00000000..c841b610 --- /dev/null +++ b/probing/extensions/hccl-shim/src/writer.rs @@ -0,0 +1,424 @@ +//! Mmap writer for intercepted MSProf events. + +use std::sync::atomic::{AtomicBool, Ordering}; + +use probing_memtable::discover::ExposedTable; +use probing_memtable::Value; + +use crate::msprof::{ + Mc2CommInfo, MsprofAdditionalInfoHeader, MsprofApi, MsprofCompactInfoHeader, + MsprofContextIdInfo, MsprofHCCLOPInfo, MsprofHcclInfo, +}; +use crate::names::{classify_api_event, is_hccl_op_name, lookup_hash, lookup_type_id}; +use crate::tables::{ + collectives_schema, context_ids_schema, host_ops_schema, mc2_streams_schema, tasks_schema, + COLLECTIVES_FILE, CONTEXT_IDS_FILE, HOST_OPS_FILE, MC2_STREAMS_FILE, TASKS_FILE, +}; + +const CHUNK_SIZE: u32 = 16 * 1024; +const NUM_CHUNKS: u32 = 32; + +macro_rules! open_table { + ($self:ident, $field:ident, $failed:ident, $logged:ident, $file:expr, $schema:expr) => {{ + if $self.$field.is_none() && !$self.$failed.load(Ordering::Relaxed) { + match ExposedTable::create($file, &$schema(), CHUNK_SIZE, NUM_CHUNKS) { + Ok(t) => $self.$field = Some(t), + Err(e) => { + $self.$failed.store(true, Ordering::Relaxed); + if $self + .$logged + .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + eprintln!("[probing-hccl-shim] failed to open {}: {e}", $file); + } + } + } + } + $self.$field.as_mut().ok_or(()) + }}; +} + +pub struct HcclWriter { + host_table: Option, + tasks_table: Option, + collectives_table: Option, + mc2_table: Option, + context_table: Option, + host_failed: AtomicBool, + tasks_failed: AtomicBool, + collectives_failed: AtomicBool, + mc2_failed: AtomicBool, + context_failed: AtomicBool, + logged_host: AtomicBool, + logged_tasks: AtomicBool, + logged_collectives: AtomicBool, + logged_mc2: AtomicBool, + logged_context: AtomicBool, +} + +impl HcclWriter { + pub fn new() -> Self { + Self { + host_table: None, + tasks_table: None, + collectives_table: None, + mc2_table: None, + context_table: None, + host_failed: AtomicBool::new(false), + tasks_failed: AtomicBool::new(false), + collectives_failed: AtomicBool::new(false), + mc2_failed: AtomicBool::new(false), + context_failed: AtomicBool::new(false), + logged_host: AtomicBool::new(false), + logged_tasks: AtomicBool::new(false), + logged_collectives: AtomicBool::new(false), + logged_mc2: AtomicBool::new(false), + logged_context: AtomicBool::new(false), + } + } + + fn open_host(&mut self) -> Result<&mut ExposedTable, ()> { + open_table!( + self, + host_table, + host_failed, + logged_host, + HOST_OPS_FILE, + host_ops_schema + ) + } + + fn open_tasks(&mut self) -> Result<&mut ExposedTable, ()> { + open_table!( + self, + tasks_table, + tasks_failed, + logged_tasks, + TASKS_FILE, + tasks_schema + ) + } + + fn open_collectives(&mut self) -> Result<&mut ExposedTable, ()> { + open_table!( + self, + collectives_table, + collectives_failed, + logged_collectives, + COLLECTIVES_FILE, + collectives_schema + ) + } + + fn open_mc2(&mut self) -> Result<&mut ExposedTable, ()> { + open_table!( + self, + mc2_table, + mc2_failed, + logged_mc2, + MC2_STREAMS_FILE, + mc2_streams_schema + ) + } + + fn open_context(&mut self) -> Result<&mut ExposedTable, ()> { + open_table!( + self, + context_table, + context_failed, + logged_context, + CONTEXT_IDS_FILE, + context_ids_schema + ) + } + + pub fn record_api(&mut self, aging: u32, api: &MsprofApi) { + let item_name = lookup_hash(api.item_id); + let event_class = classify_api_event(api.level, api.type_id, api.item_id); + let duration = api.end_time.saturating_sub(api.begin_time) as i64; + + if let Ok(table) = self.open_host() { + table.push_row(&[ + Value::I64(api.end_time as i64), + Value::I64(api.begin_time as i64), + Value::I64(api.end_time as i64), + Value::I64(duration), + Value::I32(api.thread_id as i32), + Value::I32(api.level as i32), + Value::I32(api.type_id as i32), + Value::U64(api.item_id), + Value::Str(&item_name), + Value::Str(event_class), + Value::I32(aging as i32), + ]); + } + + if event_class == "host_hccl_op" || is_hccl_op_name(&item_name) { + if let Ok(table) = self.open_collectives() { + table.push_row(&[ + Value::I64(api.end_time as i64), + Value::I32(api.thread_id as i32), + Value::Str("api"), + Value::I64(api.begin_time as i64), + Value::I64(api.end_time as i64), + Value::I64(duration), + Value::U64(api.item_id), + Value::Str(&item_name), + Value::U64(0), + Value::U64(0), + Value::U64(0), + Value::I32(-1), + Value::I32(0), + Value::I32(0), + Value::I32(0), + ]); + } + } + } + + pub fn record_compact_hccl_op( + &mut self, + header: &MsprofCompactInfoHeader, + op: &MsprofHCCLOPInfo, + ) { + let Ok(table) = self.open_collectives() else { + return; + }; + table.push_row(&[ + Value::I64(header.time_stamp as i64), + Value::I32(header.thread_id as i32), + Value::Str("compact"), + Value::I64(0), + Value::I64(0), + Value::I64(0), + Value::U64(0), + Value::Str(""), + Value::U64(op.group_name), + Value::U64(op.alg_type), + Value::U64(op.count), + Value::I32(op.data_type as i32), + Value::I32(op.relay as i32), + Value::I32(op.retry as i32), + Value::I32(header.type_id as i32), + ]); + } + + pub fn record_task( + &mut self, + header: &MsprofAdditionalInfoHeader, + hccl: &MsprofHcclInfo, + payload_len: i32, + ) { + let Ok(table) = self.open_tasks() else { + return; + }; + let type_name = lookup_type_id(header.type_id); + let task_name = lookup_hash(hccl.item_id); + let (plane_index, rank_in_plane, rank_size_plane) = + crate::msprof::decode_plane(hccl.plane_id); + table.push_row(&[ + Value::I64(header.time_stamp as i64), + Value::I32(header.thread_id as i32), + Value::I32(header.type_id as i32), + Value::I32(header.level as i32), + Value::Str(&type_name), + Value::U64(hccl.item_id), + Value::Str(&task_name), + Value::U64(hccl.ccl_tag), + Value::U64(hccl.group_name), + Value::I32(hccl.local_rank as i32), + Value::I32(hccl.remote_rank as i32), + Value::I32(hccl.rank_size as i32), + Value::I32(hccl.workflow_mode as i32), + Value::I32(hccl.plane_id as i32), + Value::I32(plane_index), + Value::I32(rank_in_plane), + Value::I32(rank_size_plane), + Value::I32(hccl.ctx_id as i32), + Value::U64(hccl.notify_id), + Value::I32(hccl.stage as i32), + Value::I32(hccl.role as i32), + Value::U64(hccl.data_size), + Value::I32(hccl.op_type as i32), + Value::I32(hccl.data_type as i32), + Value::I32(hccl.link_type as i32), + Value::I32(hccl.transport_type as i32), + Value::I32(hccl.rdma_type as i32), + Value::F64(hccl.duration_estimated), + Value::I32(payload_len), + ]); + } + + pub fn record_mc2(&mut self, header: &MsprofAdditionalInfoHeader, mc2: &Mc2CommInfo) { + let Ok(table) = self.open_mc2() else { + return; + }; + let ids = mc2 + .comm_stream_ids + .iter() + .map(|id| id.to_string()) + .collect::>() + .join(","); + table.push_row(&[ + Value::I64(header.time_stamp as i64), + Value::I32(header.thread_id as i32), + Value::I32(header.type_id as i32), + Value::U64(mc2.header.group_name), + Value::I32(mc2.header.rank_size as i32), + Value::I32(mc2.header.rank_id as i32), + Value::I32(mc2.header.usr_rank_id as i32), + Value::I32(mc2.header.aicpu_kfc_stream_id as i32), + Value::I32(mc2.comm_stream_size as i32), + Value::Str(&ids), + ]); + } + + pub fn record_context( + &mut self, + header: &MsprofAdditionalInfoHeader, + ctx: &MsprofContextIdInfo, + ) { + let Ok(table) = self.open_context() else { + return; + }; + let ctx_min = ctx.ctx_ids.first().copied().unwrap_or(0); + let ctx_max = if ctx.ctx_id_num >= 2 { + ctx.ctx_ids[1] + } else { + ctx_min + }; + table.push_row(&[ + Value::I64(header.time_stamp as i64), + Value::I32(header.thread_id as i32), + Value::I32(header.type_id as i32), + Value::I32(ctx.ctx_id_num as i32), + Value::I32(ctx_min as i32), + Value::I32(ctx_max as i32), + ]); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::msprof::{MSPROF_HCCL_INFO_MIN, MSPROF_HCCL_OP_INFO_MIN}; + use probing_memtable::discover::discover_in; + use std::fs; + + fn test_dir() -> std::path::PathBuf { + std::env::temp_dir().join(format!("probing_hccl_shim_test_{}", std::process::id())) + } + + #[test] + fn tasks_mmap_roundtrip() { + let base = test_dir(); + let _ = fs::remove_dir_all(&base); + fs::create_dir_all(&base).unwrap(); + std::env::set_var("PROBING_DATA_DIR", &base); + + let mut w = HcclWriter::new(); + let header = MsprofAdditionalInfoHeader { + level: 1, + type_id: 0, + thread_id: 42, + data_len: MSPROF_HCCL_INFO_MIN as u32, + time_stamp: 999, + }; + let info = MsprofHcclInfo { + item_id: 1, + plane_id: (2u64 << 28 | 4u64 << 16 | 7) as u32, + data_size: 4096, + remote_rank: 3, + ..Default::default() + }; + w.record_task(&header, &info, header.data_len as i32); + + let found = discover_in(&base).unwrap(); + assert!(found.iter().any(|t| t.name() == TASKS_FILE)); + let _ = fs::remove_dir_all(&base); + } + + #[test] + fn collectives_compact_roundtrip() { + let base = test_dir(); + let _ = fs::remove_dir_all(&base); + fs::create_dir_all(&base).unwrap(); + std::env::set_var("PROBING_DATA_DIR", &base); + + let mut w = HcclWriter::new(); + let header = MsprofCompactInfoHeader { + level: 2, + type_id: 100, + thread_id: 7, + data_len: MSPROF_HCCL_OP_INFO_MIN as u32, + time_stamp: 1234, + }; + let op = MsprofHCCLOPInfo { + count: 1024, + data_type: 1, + group_name: 99, + alg_type: 88, + ..Default::default() + }; + w.record_compact_hccl_op(&header, &op); + + let found = discover_in(&base).unwrap(); + assert!(found.iter().any(|t| t.name() == COLLECTIVES_FILE)); + let _ = fs::remove_dir_all(&base); + } + + #[test] + fn host_ops_schema_columns() { + let base = test_dir(); + let _ = fs::remove_dir_all(&base); + fs::create_dir_all(&base).unwrap(); + std::env::set_var("PROBING_DATA_DIR", &base); + + let mut w = HcclWriter::new(); + w.record_api( + 1, + &MsprofApi { + level: 3, + type_id: 1, + thread_id: 1, + begin_time: 100, + end_time: 250, + item_id: 42, + ..Default::default() + }, + ); + + let found = discover_in(&base).unwrap(); + assert!(found.iter().any(|t| t.name() == HOST_OPS_FILE)); + let _ = fs::remove_dir_all(&base); + } + + #[test] + fn context_ids_roundtrip() { + let base = test_dir(); + let _ = fs::remove_dir_all(&base); + fs::create_dir_all(&base).unwrap(); + std::env::set_var("PROBING_DATA_DIR", &base); + + let mut w = HcclWriter::new(); + let header = MsprofAdditionalInfoHeader { + level: 2, + type_id: 200, + thread_id: 1, + data_len: 12, + time_stamp: 500, + }; + w.record_context( + &header, + &MsprofContextIdInfo { + ctx_id_num: 2, + ctx_ids: [0, 15], + }, + ); + + let found = discover_in(&base).unwrap(); + assert!(found.iter().any(|t| t.name() == CONTEXT_IDS_FILE)); + let _ = fs::remove_dir_all(&base); + } +} diff --git a/probing/extensions/nccl-profiler/src/lib.rs b/probing/extensions/nccl-profiler/src/lib.rs index 739ccaa5..d52f3cec 100644 --- a/probing/extensions/nccl-profiler/src/lib.rs +++ b/probing/extensions/nccl-profiler/src/lib.rs @@ -16,7 +16,7 @@ mod state; #[cfg(target_os = "linux")] mod plugin; -pub use tables::{net_qp_schema, proxy_ops_schema, NET_QP_FILE, PROXY_OPS_FILE}; +pub use tables::{net_qp_schema, proxy_ops_schema, register_docs, NET_QP_FILE, PROXY_OPS_FILE}; #[cfg(target_os = "linux")] mod export { diff --git a/probing/extensions/nccl-profiler/src/tables.rs b/probing/extensions/nccl-profiler/src/tables.rs index 06282783..37feec5a 100644 --- a/probing/extensions/nccl-profiler/src/tables.rs +++ b/probing/extensions/nccl-profiler/src/tables.rs @@ -1,39 +1,87 @@ //! Memtable schemas for `nccl.proxy_ops` and `nccl.net_qp`. +use probing_memtable::docs; use probing_memtable::{DType, Schema}; pub const PROXY_OPS_FILE: &str = "nccl.proxy_ops"; pub const NET_QP_FILE: &str = "nccl.net_qp"; +/// Register all NCCL table docs (safe to call from writer or Engine startup). +pub fn register_docs() { + docs::register_from_name(PROXY_OPS_FILE, &proxy_ops_schema()); + docs::register_from_name(NET_QP_FILE, &net_qp_schema()); +} + pub fn proxy_ops_schema() -> Schema { Schema::new() - .col("ts", DType::I64) - .col("rank", DType::I32) - .col("tp_rank", DType::I32) - .col("pp_rank", DType::I32) - .col("dp_rank", DType::I32) - .col("comm_hash", DType::U64) - .col("coll_func", DType::Str) - .col("seq", DType::U64) - .col("channel_id", DType::I32) - .col("peer", DType::I32) - .col("is_send", DType::I32) - .col("n_steps", DType::I32) - .col("trans_bytes", DType::U64) - .col("send_gpu_wait_ns", DType::I64) - .col("send_wait_ns", DType::I64) - .col("recv_wait_ns", DType::I64) - .col("recv_flush_wait_ns", DType::I64) + .table_doc("NCCL profiler plugin proxy-op wait 分解(culprit / victim 归因)") + .col_doc("ts", DType::I64, "事件时间戳(纳秒)") + .col_doc("rank", DType::I32, "torch.distributed rank") + .col_doc("tp_rank", DType::I32, "张量并行 rank(未知 -1)") + .col_doc("pp_rank", DType::I32, "流水线并行 rank(未知 -1)") + .col_doc("dp_rank", DType::I32, "数据并行 rank(未知 -1)") + .col_doc("comm_hash", DType::U64, "NCCL communicator hash") + .col_doc( + "coll_func", + DType::Str, + "集合通信名(AllReduce、AllGather…)", + ) + .col_doc("seq", DType::U64, "collective 序号") + .col_doc("channel_id", DType::I32, "NCCL channel id") + .col_doc("peer", DType::I32, "对端 rank") + .col_doc("is_send", DType::I32, "1=send proxy,0=recv proxy") + .col_doc("n_steps", DType::I32, "聚合的 ProxyStep 数") + .col_doc("trans_bytes", DType::U64, "传输字节数") + .col_doc( + "send_gpu_wait_ns", + DType::I64, + "Culprit 信号 — 本地 GPU 未就绪发送", + ) + .col_doc("send_wait_ns", DType::I64, "发送侧网络等待") + .col_doc("recv_wait_ns", DType::I64, "Victim 信号 — 等待对端数据") + .col_doc("recv_flush_wait_ns", DType::I64, "接收 flush 等待") } pub fn net_qp_schema() -> Schema { Schema::new() - .col("ts", DType::I64) - .col("rank", DType::I32) - .col("device", DType::I32) - .col("qp_num", DType::I32) - .col("wr_id", DType::U64) - .col("opcode", DType::I32) - .col("length", DType::U64) - .col("duration_ns", DType::I64) + .table_doc("NCCL NetPlugin IB QP 完成耗时(可选 mask bit 128)") + .col_doc("ts", DType::I64, "事件时间戳(纳秒)") + .col_doc("rank", DType::I32, "torch.distributed rank") + .col_doc("device", DType::I32, "IB 设备索引") + .col_doc("qp_num", DType::I32, "Queue Pair 号") + .col_doc("wr_id", DType::U64, "Work Request id") + .col_doc("opcode", DType::I32, "IB opcode") + .col_doc("length", DType::U64, "传输长度(字节)") + .col_doc("duration_ns", DType::I64, "QP 完成耗时(纳秒)") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn register_docs_populates_registry() { + register_docs(); + let rows = docs::snapshot(); + assert!(rows + .iter() + .any(|r| r.table_schema == "nccl" && r.table_name == "proxy_ops")); + assert!(rows + .iter() + .any(|r| r.table_schema == "nccl" && r.table_name == "net_qp")); + } + + #[test] + fn proxy_ops_culprit_columns_documented() { + let schema = proxy_ops_schema(); + assert!(schema.table_doc.is_some()); + for name in ["send_gpu_wait_ns", "recv_wait_ns"] { + let col = schema + .cols + .iter() + .find(|c| c.name == name) + .unwrap_or_else(|| panic!("missing column {name}")); + assert!(col.doc.is_some(), "{name} should have doc"); + } + } } diff --git a/probing/extensions/python/src/extensions/python.rs b/probing/extensions/python/src/extensions/python.rs index 45ea7b51..f3dd9d18 100644 --- a/probing/extensions/python/src/extensions/python.rs +++ b/probing/extensions/python/src/extensions/python.rs @@ -15,8 +15,8 @@ use pyo3::prelude::*; use pyo3::types::{PyAnyMethods, PyString}; use pyo3::Python; -pub use exttbls::ExternalTable; pub use exttbls::PyExternalTableConfig; +pub use exttbls::{register_table_docs, ExternalTable}; pub use tbls::PythonProbeDataSource; use crate::features::stack_tracer::{SignalTracer, StackTracer}; diff --git a/probing/extensions/python/src/extensions/python/exttbls.rs b/probing/extensions/python/src/extensions/python/exttbls.rs index e9b08040..3b55867e 100644 --- a/probing/extensions/python/src/extensions/python/exttbls.rs +++ b/probing/extensions/python/src/extensions/python/exttbls.rs @@ -18,6 +18,7 @@ use std::sync::{Arc, Mutex}; use crate::features::native_bridge::with_detached_native; use once_cell::sync::Lazy; use probing_memtable::discover::ExposedTable; +use probing_memtable::docs; use probing_memtable::{DType, Schema as MtSchema, Value}; use probing_proto::prelude::Ele; use pyo3::prelude::*; @@ -49,17 +50,47 @@ fn uses_timestamp_column(name: &str) -> bool { !name.contains('.') } -fn build_schema(name: &str, columns: &[String], dtypes: &[DType]) -> MtSchema { +fn build_schema_with_docs( + name: &str, + columns: &[String], + dtypes: &[DType], + table_doc: Option<&str>, + column_docs: &HashMap, +) -> MtSchema { let mut schema = MtSchema::new(); + if let Some(doc) = table_doc { + schema = schema.table_doc(doc); + } if uses_timestamp_column(name) { schema = schema.col("timestamp", DType::I64); } for (col, dt) in columns.iter().zip(dtypes.iter()) { - schema = schema.col(col, *dt); + schema = if let Some(doc) = column_docs.get(col) { + schema.col_doc(col, *dt, doc.as_str()) + } else { + schema.col(col, *dt) + }; } schema } +fn register_python_table_docs( + name: &str, + table_doc: Option<&str>, + column_docs: &HashMap, +) { + let (table_schema, table_name) = if let Some((schema, table)) = name.split_once('.') { + (schema.to_string(), table.to_string()) + } else { + (EXTERN_TABLE_SCHEMA.to_string(), name.to_string()) + }; + let pairs: Vec<(String, String)> = column_docs + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + docs::register_column_docs(&table_schema, &table_name, table_doc, &pairs); +} + /// Ring layout: fixed chunk count; chunk byte size derives from capacity. const NUM_CHUNKS: u32 = 8; const MIN_CHUNK_BYTES: usize = 4 * 1024; @@ -212,16 +243,29 @@ pub struct ExternBacking { capacity_bytes: usize, dtypes: Vec, table: Option, + table_doc: Option, + column_docs: HashMap, } impl ExternBacking { - fn new(name: &str, columns: Vec, capacity_bytes: usize) -> Self { + fn new( + name: &str, + columns: Vec, + capacity_bytes: usize, + table_doc: Option, + column_docs: HashMap, + ) -> Self { + if !column_docs.is_empty() || table_doc.is_some() { + register_python_table_docs(name, table_doc.as_deref(), &column_docs); + } Self { name: name.to_string(), columns, capacity_bytes, dtypes: vec![], table: None, + table_doc, + column_docs, } } @@ -230,7 +274,13 @@ impl ExternBacking { return Ok(()); } self.dtypes = vec![DType::Str; self.columns.len()]; - let schema = build_schema(&self.name, &self.columns, &self.dtypes); + let schema = build_schema_with_docs( + &self.name, + &self.columns, + &self.dtypes, + self.table_doc.as_deref(), + &self.column_docs, + ); let chunk_bytes = ring_chunk_bytes(self.capacity_bytes); let filename = mmap_basename(&self.name); let table = ExposedTable::create(&filename, &schema, chunk_bytes, NUM_CHUNKS) @@ -254,7 +304,13 @@ impl ExternBacking { self.dtypes.clear(); let dtypes: Vec = first_row.iter().map(ele_dtype).collect(); - let schema = build_schema(&self.name, &self.columns, &dtypes); + let schema = build_schema_with_docs( + &self.name, + &self.columns, + &dtypes, + self.table_doc.as_deref(), + &self.column_docs, + ); let chunk_bytes = ring_chunk_bytes(self.capacity_bytes); let filename = mmap_basename(&self.name); let table = ExposedTable::create(&filename, &schema, chunk_bytes, NUM_CHUNKS) @@ -370,9 +426,17 @@ impl ExternalTable { columns: Vec, discard_threshold: usize, discard_strategy: &str, + table_doc: Option, + column_docs: HashMap, ) -> Arc> { let capacity = ring_capacity_bytes(discard_threshold, discard_strategy); - let backing = Arc::new(Mutex::new(ExternBacking::new(name, columns, capacity))); + let backing = Arc::new(Mutex::new(ExternBacking::new( + name, + columns, + capacity, + table_doc, + column_docs, + ))); backing .lock() .expect("extern table lock") @@ -385,20 +449,28 @@ impl ExternalTable { #[pymethods] impl ExternalTable { #[new] - #[pyo3(signature = (name, columns, chunk_size = 10000, discard_threshold = 20_000_000, discard_strategy = "BaseMemorySize".to_string()))] + #[pyo3(signature = (name, columns, chunk_size = 10000, discard_threshold = 20_000_000, discard_strategy = "BaseMemorySize".to_string(), table_doc = None, column_docs = None))] fn new( name: &str, columns: Vec, chunk_size: usize, discard_threshold: usize, discard_strategy: String, + table_doc: Option, + column_docs: Option>, ) -> Self { let _ = chunk_size; // ring chunking is byte-based; kept for API compat let name = name.to_string(); with_detached_native(move || { let ncolumn = columns.len(); - let backing = - Self::create_backing(&name, columns, discard_threshold, &discard_strategy); + let backing = Self::create_backing( + &name, + columns, + discard_threshold, + &discard_strategy, + table_doc, + column_docs.unwrap_or_default(), + ); EXTERN_TABLES.lock().unwrap().insert(name, backing.clone()); ExternalTable(backing, ncolumn) }) @@ -421,7 +493,8 @@ impl ExternalTable { } #[classmethod] - #[pyo3(signature = (name, columns, chunk_size = 10000, discard_threshold = 20_000_000, discard_strategy = "BaseMemorySize".to_string()))] + #[pyo3(signature = (name, columns, chunk_size = 10000, discard_threshold = 20_000_000, discard_strategy = "BaseMemorySize".to_string(), table_doc = None, column_docs = None))] + #[allow(clippy::too_many_arguments)] fn get_or_create( _cls: &Bound<'_, PyType>, name: &str, @@ -429,6 +502,8 @@ impl ExternalTable { chunk_size: usize, discard_threshold: usize, discard_strategy: String, + table_doc: Option, + column_docs: Option>, ) -> PyResult { let _ = chunk_size; let name = name.to_string(); @@ -439,8 +514,14 @@ impl ExternalTable { Ok(ExternalTable(backing.clone(), ncolumn)) } else { let ncolumn = columns.len(); - let backing = - Self::create_backing(&name, columns, discard_threshold, &discard_strategy); + let backing = Self::create_backing( + &name, + columns, + discard_threshold, + &discard_strategy, + table_doc, + column_docs.unwrap_or_default(), + ); binding.insert(name, backing.clone()); Ok(ExternalTable(backing, ncolumn)) } @@ -525,6 +606,43 @@ impl ExternalTable { } } +/// Register table/column documentation for SQL `DESCRIBE` (without creating a table). +#[pyfunction] +#[pyo3(signature = (qualified_name, table_doc=None, column_docs=None))] +pub fn register_table_docs( + qualified_name: &str, + table_doc: Option<&str>, + column_docs: Option>, +) -> PyResult<()> { + register_python_table_docs(qualified_name, table_doc, &column_docs.unwrap_or_default()); + Ok(()) +} + +#[cfg(test)] +mod register_docs_tests { + use super::*; + use probing_memtable::docs; + + #[test] + fn register_table_docs_exposes_python_schema() { + let table = format!("py_doc_test_{}", std::process::id()); + let qualified = format!("python.{table}"); + let mut column_docs = HashMap::new(); + column_docs.insert("latency_ms".to_string(), "latency in ms".to_string()); + register_table_docs(&qualified, Some("Python doc test table"), Some(column_docs)).unwrap(); + let rows = docs::snapshot(); + let row = rows + .iter() + .find(|r| r.table_schema == "python" && r.table_name == table) + .expect("python table docs"); + assert_eq!(row.description.as_deref(), Some("Python doc test table")); + assert_eq!( + row.columns.get("latency_ms"), + Some(&"latency in ms".to_string()) + ); + } +} + #[cfg(test)] mod tests { use super::*; @@ -608,6 +726,8 @@ if not hasattr(probing, "_made_{name}"): 10000, 20000000, "BaseMemorySize".to_string(), + None, + None, ); assert_eq!(table.names(), vec!["a", "b"]); } @@ -662,6 +782,8 @@ probing.ExternalTable.drop("table_to_drop") 10000, 1_000_000, "BaseMemorySize".to_string(), + None, + None, ); Python::attach(|py| { let vals: Vec> = vec![ @@ -689,6 +811,8 @@ probing.ExternalTable.drop("table_to_drop") 10000, 1_000_000, "BaseMemorySize".to_string(), + None, + None, ); Python::attach(|py| { let vals: Vec> = vec![1i64.into_pyobject(py).unwrap().into_any().unbind()]; diff --git a/probing/extensions/python/src/features/tracing.rs b/probing/extensions/python/src/features/tracing.rs index caf42375..feb36397 100644 --- a/probing/extensions/python/src/features/tracing.rs +++ b/probing/extensions/python/src/features/tracing.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex}; use probing_core::trace::Span as RawSpan; use probing_core::trace::{ - advance_local_step, attr, set_step_bucket_size, step_snapshot, sync_local_step, Attribute, + advance_micro_step, attr, set_micro_batches, step_snapshot, sync_micro_step, Attribute, Event as RawEvent, SpanStatus, StepSnapshot, Timestamp, }; @@ -80,9 +80,9 @@ impl Span { impl Span { /// Creates a new root span (starts a new trace). #[new] - #[pyo3(signature = (name, *, kind=None, location=None))] - fn new(name: String, kind: Option, location: Option) -> Self { - let span = RawSpan::new_root(name, kind.as_deref(), location.as_deref()); + #[pyo3(signature = (name, *, phase=None, location=None))] + fn new(name: String, phase: Option, location: Option) -> Self { + let span = RawSpan::new_root(name, phase.as_deref(), location.as_deref()); Span { inner: Arc::new(Mutex::new(span)), } @@ -90,15 +90,15 @@ impl Span { /// Creates a new child span from a parent span. #[staticmethod] - #[pyo3(signature = (parent, name, *, kind=None, location=None))] + #[pyo3(signature = (parent, name, *, phase=None, location=None))] fn new_child( parent: &Bound<'_, Span>, name: String, - kind: Option, + phase: Option, location: Option, ) -> Self { let span = parent.borrow().with_inner(|parent_span| { - RawSpan::new_child(parent_span, name, kind.as_deref(), location.as_deref()) + RawSpan::new_child(parent_span, name, phase.as_deref(), location.as_deref()) }); Span { inner: Arc::new(Mutex::new(span)), @@ -135,10 +135,10 @@ impl Span { self.with_inner(|s| s.name.clone()) } - /// Gets the span kind. + /// Gets the span training phase (forward / backward / optimizer). #[getter] - fn kind(&self) -> Option { - self.with_inner(|s| s.kind.clone()) + fn phase(&self) -> Option { + self.with_inner(|s| s.phase.clone()) } /// Gets the span status. @@ -262,7 +262,7 @@ impl Span { "parent_id" => return optional_into_py(py, self.parent_id()), "thread_id" => return Ok(self.thread_id().into_bound_py_any(py)?.into()), "name" => return Ok(self.name().into_bound_py_any(py)?.into()), - "kind" => return optional_into_py(py, self.kind()), + "phase" => return optional_into_py(py, self.phase()), "status" => return Ok(self.status().into_bound_py_any(py)?.into()), "is_ended" => return Ok(self.is_ended().into_bound_py_any(py)?.into()), "duration" => return optional_into_py(py, self.duration()), @@ -366,15 +366,15 @@ fn active_span_for_events(py: Python) -> PyResult>> { }) } -/// Return the innermost active span whose kind matches ``kind`` (or None). +/// Return the innermost active span whose phase matches ``phase`` (or None). #[pyfunction] -fn active_span_by_kind(py: Python, kind: String) -> PyResult>> { +fn active_span_by_phase(py: Python, phase: String) -> PyResult>> { SPAN_STACK.with(|stack| { let stack = stack.borrow(); for obj in stack.iter().rev() { let bound = obj.bind(py); if let Ok(span) = bound.cast::() { - if span.borrow().kind().as_deref() == Some(kind.as_str()) { + if span.borrow().phase().as_deref() == Some(phase.as_str()) { return Ok(Some(obj.clone_ref(py))); } } @@ -383,15 +383,41 @@ fn active_span_by_kind(py: Python, kind: String) -> PyResult>> }) } +/// Innermost active training phase on the span stack (``forward`` / ``backward`` / ``optimizer``). +#[pyfunction] +fn active_training_phase(py: Python) -> PyResult> { + SPAN_STACK.with(|stack| { + let stack = stack.borrow(); + for obj in stack.iter().rev() { + let bound = obj.bind(py); + if let Ok(span) = bound.cast::() { + let borrowed = span.borrow(); + if borrowed.is_ended() { + continue; + } + if let Some(phase) = borrowed.phase() { + match phase.as_str() { + "forward" | "backward" | "optimizer" => return Ok(Some(phase)), + _ => {} + } + } + } + } + Ok(None) + }) +} + #[pyclass(from_py_object)] #[derive(Clone, Copy)] struct PyStepSnapshot { + #[pyo3(get)] + micro_step: u64, #[pyo3(get)] local_step: u64, #[pyo3(get)] global_step: u64, #[pyo3(get)] - bucket_size: u64, + micro_batches: u64, #[pyo3(get)] rank: i64, #[pyo3(get)] @@ -401,9 +427,10 @@ struct PyStepSnapshot { impl From for PyStepSnapshot { fn from(s: StepSnapshot) -> Self { Self { + micro_step: s.micro_step, local_step: s.local_step, global_step: s.global_step, - bucket_size: s.bucket_size, + micro_batches: s.micro_batches, rank: s.rank, world_size: s.world_size, } @@ -416,33 +443,33 @@ fn py_step_snapshot() -> PyStepSnapshot { } #[pyfunction] -fn py_sync_local_step(step: u64) -> PyStepSnapshot { - sync_local_step(step).into() +fn py_sync_micro_step(step: u64) -> PyStepSnapshot { + sync_micro_step(step).into() } #[pyfunction] -fn py_advance_local_step() -> PyStepSnapshot { - advance_local_step().into() +fn py_advance_micro_step() -> PyStepSnapshot { + advance_micro_step().into() } #[pyfunction] -fn py_set_step_bucket_size(bucket: u64) { - set_step_bucket_size(bucket); +fn py_set_micro_batches(micro_batches: u64) { + set_micro_batches(micro_batches); } #[pyfunction] -fn py_current_local_step() -> u64 { - probing_core::trace::current_local_step() +fn py_current_micro_step() -> u64 { + probing_core::trace::current_micro_step() } /// Internal function to create a span - called by Python wrapper. /// This is a low-level function that directly creates a span. #[pyfunction] -#[pyo3(signature = (name, *, kind=None, location=None))] +#[pyo3(signature = (name, *, phase=None, location=None))] fn _span_raw( py: Python, name: String, - kind: Option, + phase: Option, location: Option, ) -> PyResult { let parent = SPAN_STACK.with(|stack| { @@ -453,9 +480,9 @@ fn _span_raw( let span = if let Some(parent) = parent { let parent_obj = parent.bind(py); let parent_span = parent_obj.cast::()?; - Span::new_child(parent_span, name, kind, location) + Span::new_child(parent_span, name, phase, location) } else { - Span::new(name, kind, location) + Span::new(name, phase, location) }; Ok(span) @@ -521,12 +548,13 @@ pub fn register_tracing_functions(module: &Bound<'_, PyModule>) -> PyResult<()> module.add_function(wrap_pyfunction!(_span_raw, module)?)?; module.add_function(wrap_pyfunction!(current_span, module)?)?; module.add_function(wrap_pyfunction!(active_span_for_events, module)?)?; - module.add_function(wrap_pyfunction!(active_span_by_kind, module)?)?; + module.add_function(wrap_pyfunction!(active_span_by_phase, module)?)?; + module.add_function(wrap_pyfunction!(active_training_phase, module)?)?; module.add_function(wrap_pyfunction!(py_step_snapshot, module)?)?; - module.add_function(wrap_pyfunction!(py_sync_local_step, module)?)?; - module.add_function(wrap_pyfunction!(py_advance_local_step, module)?)?; - module.add_function(wrap_pyfunction!(py_set_step_bucket_size, module)?)?; - module.add_function(wrap_pyfunction!(py_current_local_step, module)?)?; + module.add_function(wrap_pyfunction!(py_sync_micro_step, module)?)?; + module.add_function(wrap_pyfunction!(py_advance_micro_step, module)?)?; + module.add_function(wrap_pyfunction!(py_set_micro_batches, module)?)?; + module.add_function(wrap_pyfunction!(py_current_micro_step, module)?)?; Ok(()) } diff --git a/probing/memtable/src/discover.rs b/probing/memtable/src/discover.rs index 4197eca0..4457eadb 100644 --- a/probing/memtable/src/discover.rs +++ b/probing/memtable/src/discover.rs @@ -131,9 +131,9 @@ impl ExposedTable { chunk_size: u32, num_chunks: u32, ) -> io::Result { - Ok(Self { - inner: MemTable::shared_in(base_dir, name, schema, chunk_size, num_chunks)?, - }) + let inner = MemTable::shared_in(base_dir, name, schema, chunk_size, num_chunks)?; + crate::docs::register_from_name(name, schema); + Ok(Self { inner }) } pub fn as_bytes(&self) -> &[u8] { @@ -475,6 +475,44 @@ mod tests { dir } + #[test] + fn exposed_table_create_registers_docs() { + let dir = test_dir(); + let table = format!("registry_create_{}", std::process::id()); + let qualified = format!("unittest.{table}"); + let schema = + Schema::new() + .table_doc("created via mmap") + .col_doc("x", DType::I32, "int column"); + ExposedTable::create_in(&dir, &qualified, &schema, 1024, 2).unwrap(); + let rows = crate::docs::snapshot(); + assert!(rows.iter().any(|r| { + r.table_schema == "unittest" + && r.table_name == table + && r.description.as_deref() == Some("created via mmap") + && r.columns.get("x") == Some(&"int column".to_string()) + })); + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn exposed_table_undotted_create_uses_memtable_schema() { + let dir = test_dir(); + let table = format!("undotted_{}", std::process::id()); + let schema = + Schema::new() + .table_doc("plain memtable file") + .col_doc("v", DType::F64, "value"); + ExposedTable::create_in(&dir, &table, &schema, 1024, 2).unwrap(); + let rows = crate::docs::snapshot(); + assert!(rows.iter().any(|r| { + r.table_schema == "memtable" + && r.table_name == table + && r.description.as_deref() == Some("plain memtable file") + })); + let _ = fs::remove_dir_all(&dir); + } + #[test] fn exposed_table_roundtrip() { let dir = test_dir(); diff --git a/probing/memtable/src/docs.rs b/probing/memtable/src/docs.rs new file mode 100644 index 00000000..a0e4a052 --- /dev/null +++ b/probing/memtable/src/docs.rs @@ -0,0 +1,180 @@ +//! In-process registry for table/column documentation attached to [`Schema`]. +//! +//! Docs are **not** persisted in mmap headers; they live only in Rust (or are +//! registered from Python) and are consumed by the probing Engine semantic catalog. + +use std::collections::HashMap; +use std::sync::{Mutex, OnceLock}; + +use crate::schema::Schema; + +/// Documentation for one SQL table (`schema.table`). +#[derive(Debug, Clone, Default)] +pub struct TableDocs { + pub table_schema: String, + pub table_name: String, + pub description: Option, + pub columns: HashMap, +} + +static REGISTRY: OnceLock>> = OnceLock::new(); + +fn registry() -> &'static Mutex> { + REGISTRY.get_or_init(|| Mutex::new(HashMap::new())) +} + +fn qualified_key(table_schema: &str, table_name: &str) -> String { + format!("{table_schema}.{table_name}") +} + +/// Register table/column docs for a qualified SQL name (`hccl.host_ops`, `python.foo`, …). +pub fn register_qualified(table_schema: &str, table_name: &str, schema: &Schema) { + let key = qualified_key(table_schema, table_name); + let mut entry = TableDocs { + table_schema: table_schema.to_string(), + table_name: table_name.to_string(), + description: schema.table_doc.clone(), + columns: HashMap::new(), + }; + for col in &schema.cols { + if let Some(doc) = &col.doc { + entry.columns.insert(col.name.clone(), doc.clone()); + } + } + + let mut reg = registry().lock().expect("table doc registry lock"); + reg.insert(key, entry); +} + +/// Register docs from an on-disk mmap basename (`hccl.host_ops` or undotted `metrics`). +pub fn register_from_name(name: &str, schema: &Schema) { + if let Some((table_schema, table_name)) = name.split_once('.') { + register_qualified(table_schema, table_name, schema); + } else { + register_qualified("memtable", name, schema); + } +} + +/// Snapshot all registered docs (sorted by qualified name). +pub fn snapshot() -> Vec { + let reg = registry().lock().expect("table doc registry lock"); + let mut rows: Vec = reg.values().cloned().collect(); + rows.sort_by(|a, b| (&a.table_schema, &a.table_name).cmp(&(&b.table_schema, &b.table_name))); + rows +} + +/// Register column docs without a full schema (e.g. Python `@table` before first append). +pub fn register_column_docs( + table_schema: &str, + table_name: &str, + table_doc: Option<&str>, + columns: &[(String, String)], +) { + let key = qualified_key(table_schema, table_name); + let mut reg = registry().lock().expect("table doc registry lock"); + let entry = reg.entry(key).or_insert_with(|| TableDocs { + table_schema: table_schema.to_string(), + table_name: table_name.to_string(), + description: None, + columns: HashMap::new(), + }); + if let Some(doc) = table_doc { + entry.description = Some(doc.to_string()); + } + for (col, doc) in columns { + entry.columns.insert(col.clone(), doc.clone()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{DType, Schema}; + + fn unique_table(prefix: &str) -> String { + format!("{prefix}_{}", std::process::id()) + } + + #[test] + fn register_from_schema_snapshot() { + let schema = + Schema::new() + .table_doc("demo table") + .col_doc("ts", DType::I64, "timestamp ns"); + register_from_name("demo.events", &schema); + let rows = snapshot(); + assert!(rows.iter().any(|r| { + r.table_schema == "demo" + && r.table_name == "events" + && r.description.as_deref() == Some("demo table") + && r.columns.get("ts") == Some(&"timestamp ns".to_string()) + })); + } + + #[test] + fn register_undotted_name_uses_memtable_schema() { + let name = unique_table("metrics_doc"); + let schema = + Schema::new() + .table_doc("undotted metrics") + .col_doc("v", DType::I64, "sample value"); + register_from_name(&name, &schema); + let rows = snapshot(); + assert!(rows.iter().any(|r| { + r.table_schema == "memtable" + && r.table_name == name + && r.description.as_deref() == Some("undotted metrics") + && r.columns.get("v") == Some(&"sample value".to_string()) + })); + } + + #[test] + fn register_column_docs_merges_into_existing_entry() { + let table = unique_table("merge_docs"); + register_column_docs( + "unittest", + &table, + Some("initial table doc"), + &[("a".to_string(), "column a".to_string())], + ); + register_column_docs( + "unittest", + &table, + Some("updated table doc"), + &[("b".to_string(), "column b".to_string())], + ); + let rows = snapshot(); + let row = rows + .iter() + .find(|r| r.table_schema == "unittest" && r.table_name == table) + .expect("merged docs row"); + assert_eq!(row.description.as_deref(), Some("updated table doc")); + assert_eq!(row.columns.get("a"), Some(&"column a".to_string())); + assert_eq!(row.columns.get("b"), Some(&"column b".to_string())); + } + + #[test] + fn register_from_schema_replaces_prior_entry() { + let table = unique_table("replace_docs"); + register_from_name( + &format!("unittest.{table}"), + &Schema::new() + .table_doc("old") + .col_doc("x", DType::I32, "old col"), + ); + register_from_name( + &format!("unittest.{table}"), + &Schema::new() + .table_doc("new") + .col_doc("y", DType::I32, "new col"), + ); + let rows = snapshot(); + let row = rows + .iter() + .find(|r| r.table_schema == "unittest" && r.table_name == table) + .expect("replaced docs row"); + assert_eq!(row.description.as_deref(), Some("new")); + assert!(!row.columns.contains_key("x")); + assert_eq!(row.columns.get("y"), Some(&"new col".to_string())); + } +} diff --git a/probing/memtable/src/lib.rs b/probing/memtable/src/lib.rs index c2b8bd8d..b530ddaf 100644 --- a/probing/memtable/src/lib.rs +++ b/probing/memtable/src/lib.rs @@ -100,6 +100,7 @@ mod cache; mod dedup; pub mod discover; +pub mod docs; mod layout; pub mod memc; pub mod memh; diff --git a/probing/memtable/src/memtable.rs b/probing/memtable/src/memtable.rs index f84a4eb9..d79077f8 100644 --- a/probing/memtable/src/memtable.rs +++ b/probing/memtable/src/memtable.rs @@ -156,6 +156,7 @@ macro_rules! impl_table_reader { name: cd.name_str().to_string(), dtype, elem_size: cd.elem_size as usize, + doc: None, }); } } diff --git a/probing/memtable/src/schema.rs b/probing/memtable/src/schema.rs index 04ed0fa9..66e95355 100644 --- a/probing/memtable/src/schema.rs +++ b/probing/memtable/src/schema.rs @@ -157,23 +157,44 @@ pub struct Col { pub name: String, pub dtype: DType, pub elem_size: usize, + /// Human-readable column description (not persisted in mmap). + pub doc: Option, } pub struct Schema { pub cols: Vec, + /// Human-readable table description (not persisted in mmap). + pub table_doc: Option, } impl Schema { pub fn new() -> Self { - Self { cols: vec![] } + Self { + cols: vec![], + table_doc: None, + } + } + + pub fn table_doc(mut self, doc: impl Into) -> Self { + self.table_doc = Some(doc.into()); + self + } + + pub fn col(self, name: &str, dtype: DType) -> Self { + self.push_col(name, dtype, None) } - pub fn col(mut self, name: &str, dtype: DType) -> Self { + pub fn col_doc(self, name: &str, dtype: DType, doc: impl Into) -> Self { + self.push_col(name, dtype, Some(doc.into())) + } + + fn push_col(mut self, name: &str, dtype: DType, doc: Option) -> Self { let elem_size = dtype.fixed_size().unwrap_or(0); self.cols.push(Col { name: name.into(), dtype, elem_size, + doc, }); self } @@ -207,4 +228,15 @@ mod tests { let schema = Schema::new().col("id", DType::I64).col("name", DType::Str); assert_eq!(format!("{schema:?}"), "Schema(id:i64, name:str)"); } + + #[test] + fn schema_table_and_column_docs() { + let schema = Schema::new() + .table_doc("events table") + .col("id", DType::I64) + .col_doc("name", DType::Str, "event name"); + assert_eq!(schema.table_doc.as_deref(), Some("events table")); + assert_eq!(schema.cols[0].doc, None); + assert_eq!(schema.cols[1].doc.as_deref(), Some("event name")); + } } diff --git a/probing/server/src/server/cluster_fanout.rs b/probing/server/src/server/cluster_fanout.rs index e442b014..23faeb96 100644 --- a/probing/server/src/server/cluster_fanout.rs +++ b/probing/server/src/server/cluster_fanout.rs @@ -206,14 +206,7 @@ fn tag_dataframe(mut df: DataFrame, host: &str, addr: &str, rank: Option) - if df.is_empty() { return df; } - let rows = df.len(); - df.names.push("_host".to_string()); - df.names.push("_addr".to_string()); - df.names.push("_rank".to_string()); - df.cols.push(Seq::SeqText(vec![host.to_string(); rows])); - df.cols.push(Seq::SeqText(vec![addr.to_string(); rows])); - df.cols.push(Seq::SeqI32(vec![rank.unwrap_or(-1); rows])); - df.size = df.len() as u64; + probing_core::core::federation::tag_proto_dataframe(&mut df, host, addr, rank); df } @@ -292,7 +285,7 @@ mod tests { ); let merged = merge_tagged_dataframes(&[local, remote]); assert_eq!(merged.len(), 2); - assert_eq!(merged.names.len(), 4); + assert_eq!(merged.names.len(), 7); let host_col = merged.names.iter().position(|n| n == "_host").unwrap(); assert_eq!(merged.cols[host_col].get_str(0).as_deref(), Some("host-a")); assert_eq!(merged.cols[host_col].get_str(1).as_deref(), Some("host-b")); diff --git a/probing/server/src/server/training.rs b/probing/server/src/server/training.rs index 336d3896..55ad96de 100644 --- a/probing/server/src/server/training.rs +++ b/probing/server/src/server/training.rs @@ -20,7 +20,7 @@ SELECT FROM python.trace_event s JOIN python.trace_event e ON s.span_id = e.span_id AND e.record_type = 'span_end' -WHERE s.record_type = 'span_start' AND s.kind = 'train.step' +WHERE s.record_type = 'span_start' AND s.name = 'train.step' ORDER BY s.time ASC "#; @@ -250,13 +250,26 @@ fn parse_attrs(raw: &str) -> (i32, i64, String) { return (-1, -1, String::new()); }; let rank = normalize_rank(value.get("rank").and_then(json_i64).unwrap_or(-1) as i32); - let local_step = value.get("local_step").and_then(json_i64).unwrap_or(-1); + let coord_step = value + .get("local_step") + .or_else(|| value.get("global_step")) + .and_then(json_i64) + .or_else(|| { + let micro = value.get("micro_step").and_then(json_i64)?; + let batches = value + .get("micro_batches") + .and_then(json_i64) + .unwrap_or(1) + .max(1); + Some(micro / batches) + }) + .unwrap_or(-1); let source = value .get("source") .and_then(|v| v.as_str()) .unwrap_or("") .to_string(); - (rank, local_step, source) + (rank, coord_step, source) } fn json_i64(v: &serde_json::Value) -> Option { diff --git a/pyproject.toml b/pyproject.toml index dd4740d0..305eb063 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,10 @@ Issues = "https://github.com/reiase/probing/issues" probing = "probing.cli.__main__:main" [project.optional-dependencies] +otel = [ + "opentelemetry-api>=1.20", + "opentelemetry-sdk>=1.20", +] test = [ "pytest>=8.0; python_version >= '3.8'", "pytest<8.0; python_version < '3.8'", diff --git a/python/probing/__init__.py b/python/probing/__init__.py index f9c9e38e..d5d3db26 100644 --- a/python/probing/__init__.py +++ b/python/probing/__init__.py @@ -15,7 +15,7 @@ Public Interfaces: - Engine: `query`, `load_extension` - Control: `cli_main`, `enable_tracer`, `disable_tracer`, `is_enabled` -- Tracing: `span`, `event` +- Tracing: `span`, `event`, `record_span`, `step` - Engine: `query`, `load_extension` Pulsing integration is passive: when another runtime writes ``pulsing.*`` memtables @@ -52,11 +52,20 @@ def is_enabled(): # Internal Accessors _get_python_stacks = _core._get_python_stacks _get_python_frames = _core._get_python_frames + register_table_docs = _core.register_table_docs # Submodules with side effects (must be imported after Core Primitives) from probing.core.engine import load_extension, query from probing.parallel import clear_role, current_role, set_role - from probing.tracing import event, span + from probing.tracing import ( + attach_training_phases, + event, + owns_training_phases, + phase, + record_span, + span, + step, + ) try: from probing.nccl.mock import maybe_auto_seed @@ -83,6 +92,11 @@ def is_enabled(): "load_extension", "span", "event", + "record_span", + "step", + "phase", + "attach_training_phases", + "owns_training_phases", "set_role", "clear_role", "current_role", diff --git a/python/probing/bundled_skills/semantic/tables.yaml b/python/probing/bundled_skills/semantic/tables.yaml index 7ce954d6..76b92ce5 100644 --- a/python/probing/bundled_skills/semantic/tables.yaml +++ b/python/probing/bundled_skills/semantic/tables.yaml @@ -8,8 +8,10 @@ tables: description: "PyTorch module-level forward/step timings and GPU memory snapshots" synonyms: [torch trace, module timing, 模块耗时, 训练步 profiling] key_columns: - step: "Training step index (int)" - global_step: "Global training step (from Rust step_snapshot)" + micro_step: "Finest step counter (micro-batch index)" + local_step: "Training step on this rank (micro_step // micro_batches)" + global_step: "Same as local_step when ranks are aligned" + micro_batches: "Gradient accumulation factor" rank: "torch.distributed rank" role: "Parallel role key, e.g. 'dp=2,pp=1,tp=0' (align/join across ranks)" module: "Fully-qualified module name" @@ -26,8 +28,10 @@ tables: description: "torch.distributed collective calls (all_reduce, broadcast, …)" synonyms: [collective, communication, NCCL, 通信, all_reduce] key_columns: - global_step: "Global training step" - local_step: "Local step on this rank" + micro_step: "Finest step counter" + global_step: "Training step (same as local_step)" + local_step: "Training step on this rank" + micro_batches: "Gradient accumulation factor" rank: "torch.distributed rank" role: "Parallel role key, e.g. 'dp=2,pp=1,tp=0' (align/join across ranks)" op: "Collective operation name" diff --git a/python/probing/core/table.py b/python/probing/core/table.py index 03b8d749..856117f0 100644 --- a/python/probing/core/table.py +++ b/python/probing/core/table.py @@ -23,6 +23,26 @@ def camel_to_snake(name): return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() +def _table_doc_from_class(cls) -> Optional[str]: + doc = cls.__doc__ + if not doc: + return None + line = doc.strip().splitlines()[0].strip() + if not line or line.startswith(f"{cls.__name__}("): + return None + return line + + +def _column_docs_from_class(cls) -> dict[str, str]: + docs: dict[str, str] = {} + for field in dataclasses.fields(cls): + meta = field.metadata or {} + doc = meta.get("doc") + if doc: + docs[field.name] = str(doc) + return docs + + def table(name_or_class: Optional[Union[str, Type[Any]]] = None): """A decorator that converts a dataclass into a persistable table. @@ -30,6 +50,9 @@ def table(name_or_class: Optional[Union[str, Type[Any]]] = None): it creates or retrieves an ExternalTable with the dataclass name (or provided name) and adds methods for data persistence and retrieval operations. + Table documentation comes from the class ``__doc__`` (first line). Column docs use + ``field(metadata={"doc": "..."})`` and are registered for SQL ``DESCRIBE``. + Parameters ---------- name : Optional[str], default=None @@ -63,12 +86,13 @@ def table(name_or_class: Optional[Union[str, Type[Any]]] = None): Examples -------- - >>> from dataclasses import dataclass + >>> from dataclasses import dataclass, field >>> @table ... @dataclass ... class Point: - ... x: int - ... y: int + ... \"\"\"Demo points table.\"\"\" + ... x: int = field(metadata={"doc": "X coordinate"}) + ... y: int = field(metadata={"doc": "Y coordinate"}) >>> Point.append(Point(1, 2)) >>> Point.take(10)[0][1] [1, 2] @@ -91,6 +115,9 @@ def decorator(cls): table_name = name or camel_to_snake(cls.__name__) fields = [f.name for f in dataclasses.fields(cls)] + table_doc = _table_doc_from_class(cls) + column_docs = _column_docs_from_class(cls) + qualified_name = table_name if "." in table_name else f"python.{table_name}" @functools.wraps(cls.__init__) def init_table(): @@ -101,7 +128,16 @@ def init_table(): f"Table {table_name} already exists with different fields" ) except Exception: - table = probing.ExternalTable(table_name, fields) + table = probing.ExternalTable( + table_name, + fields, + table_doc=table_doc, + column_docs=column_docs or None, + ) + if column_docs or table_doc: + probing.register_table_docs( + qualified_name, table_doc, column_docs or None + ) cache[cls] = table return table diff --git a/python/probing/ext/ray.py b/python/probing/ext/ray.py index a9d23dd2..9df993d6 100644 --- a/python/probing/ext/ray.py +++ b/python/probing/ext/ray.py @@ -54,13 +54,14 @@ def on_start(self, span, parent_context=None): SpanKind.PRODUCER: "producer", SpanKind.CONSUMER: "consumer", } - span_kind = kind_map.get(span.kind) - attrs = {} if hasattr(span, "attributes") and span.attributes: attrs = {str(k): str(v) for k, v in span.attributes.items()} + span_kind = kind_map.get(span.kind) + if span_kind: + attrs["otel.span_kind"] = span_kind - probing_span = probing.span(span_name, kind=span_kind, **attrs) + probing_span = probing.span(span_name, **attrs) probing_span.__enter__() span_context = span.get_span_context() diff --git a/python/probing/ext/torch.py b/python/probing/ext/torch.py index 40772c94..445c08bf 100644 --- a/python/probing/ext/torch.py +++ b/python/probing/ext/torch.py @@ -13,6 +13,10 @@ def is_true(value): def optimizer_step_post_hook(optimizer, *args, **kwargs): global hooks + from probing.tracing.hooks import maybe_auto_attach + + maybe_auto_attach(optimizer) + if optimizer not in hooks: from probing.profiling.torch import install_hooks from probing.profiling.torch.module_utils import get_toplevel_module diff --git a/python/probing/handlers/pythonext.py b/python/probing/handlers/pythonext.py index 1e1d06f9..9cbcb1f5 100644 --- a/python/probing/handlers/pythonext.py +++ b/python/probing/handlers/pythonext.py @@ -130,7 +130,7 @@ def get_chrome_tracing(limit: int = 1000) -> str: name, time as timestamp, COALESCE(thread_id, 0) as thread_id, - kind, + phase, location, attributes, event_attributes @@ -167,13 +167,13 @@ def get_chrome_tracing(limit: int = 1000) -> str: thread_id = row.get("thread_id", 0) trace_id = row.get("trace_id", 0) name = row.get("name", "unknown") - kind = row.get("kind", "trace") + phase = row.get("phase", "") # Use (span_id, thread_id) as key to handle multiple threads key = (span_id, thread_id) span_start_lookup[key] = { "trace_id": trace_id, "name": name, - "kind": kind, + "phase": phase, "timestamp": row.get("timestamp", 0), } @@ -185,7 +185,7 @@ def get_chrome_tracing(limit: int = 1000) -> str: trace_id = row.get("trace_id", 0) span_id = row.get("span_id", 0) thread_id = row.get("thread_id", 0) - kind = row.get("kind", "trace") + phase = row.get("phase", "") # Convert nanoseconds to microseconds ts_micros = (timestamp - min_timestamp) // 1000 @@ -196,10 +196,10 @@ def get_chrome_tracing(limit: int = 1000) -> str: if record_type == "span_start": # Store span start information with trace_id for matching key = (span_id, thread_id) - span_starts[key] = (ts_micros, name, kind, pid) + span_starts[key] = (ts_micros, name, phase, pid) chrome_event = { "name": name, - "cat": kind if kind else "span", + "cat": phase if phase else "span", "ph": "B", "ts": ts_micros, "pid": pid, @@ -215,12 +215,12 @@ def get_chrome_tracing(limit: int = 1000) -> str: if start_info: # Found matching span_start that was already processed - start_ts, start_name, start_kind, start_pid = start_info + start_ts, start_name, start_phase, start_pid = start_info # Use the pid from span_start to ensure matching chrome_event = { "name": start_name, # Must match span_start name "cat": ( - start_kind if start_kind else "span" + start_phase if start_phase else "span" ), # Must match span_start cat "ph": "E", "ts": ts_micros, @@ -244,11 +244,11 @@ def get_chrome_tracing(limit: int = 1000) -> str: lookup_info["timestamp"] - min_timestamp ) // 1000 start_name = lookup_info["name"] - start_kind = lookup_info["kind"] + start_phase = lookup_info["phase"] chrome_event = { "name": start_name, # Must match span_start name "cat": ( - start_kind if start_kind else "span" + start_phase if start_phase else "span" ), # Must match span_start cat "ph": "E", "ts": ts_micros, diff --git a/python/probing/hccl/__init__.py b/python/probing/hccl/__init__.py new file mode 100644 index 00000000..656c027b --- /dev/null +++ b/python/probing/hccl/__init__.py @@ -0,0 +1,85 @@ +"""HCCL MSProf shim helpers (Linux, Ascend / CANN).""" + +from __future__ import annotations + +import os +import shutil +import sys +from pathlib import Path + +__all__ = [ + "shim_path", + "shim_dir", + "ld_library_path_prefix", + "install_real_copy", + "ENV_REAL", + "ENV_SHIM_LOG", +] + +_LIB_BASENAME = "libprofapi.so" +_REAL_BASENAME = "libprofapi.so.real" +_ENV_OVERRIDE = "PROBING_HCCL_SHIM" +ENV_REAL = "PROBING_HCCL_PROFAPI_REAL" +ENV_SHIM_LOG = "PROBING_HCCL_SHIM_LOG" + + +def _candidate_paths() -> list[Path]: + pkg_root = Path(__file__).resolve().parent.parent + name = _LIB_BASENAME + out = [ + pkg_root / "shim" / "hccl" / name, + ] + repo_root = Path(__file__).resolve().parents[3] + for profile in ("release", "debug"): + out.append(repo_root / "target" / profile / name) + return out + + +def shim_path() -> str: + """Absolute path to the probing ``libprofapi.so`` shim.""" + override = os.environ.get(_ENV_OVERRIDE) + if override: + path = Path(override).expanduser().resolve() + if not path.is_file(): + raise FileNotFoundError( + f"{_ENV_OVERRIDE}={override!r} does not point to an existing file" + ) + return str(path) + + if sys.platform != "linux": + raise OSError( + "HCCL MSProf shim is only available on Linux; " + f"set {_ENV_OVERRIDE} if you have a custom build" + ) + + for candidate in _candidate_paths(): + if candidate.is_file(): + return str(candidate.resolve()) + + searched = ", ".join(str(p) for p in _candidate_paths()) + raise FileNotFoundError( + "HCCL shim not found. Build with " + "`make hccl-shim-lib` or `cargo build -p probing-hccl-shim --release`, " + f"or set {_ENV_OVERRIDE}. Searched: {searched}" + ) + + +def shim_dir() -> str: + """Directory containing ``libprofapi.so`` (for ``LD_LIBRARY_PATH``).""" + return str(Path(shim_path()).parent) + + +def ld_library_path_prefix() -> str: + """Value to prepend to ``LD_LIBRARY_PATH`` (shim dir only).""" + return shim_dir() + + +def install_real_copy(cann_libprofapi: str | Path) -> Path: + """Copy CANN's real ``libprofapi.so`` to ``libprofapi.so.real`` beside the shim.""" + src = Path(cann_libprofapi).expanduser().resolve() + if not src.is_file(): + raise FileNotFoundError(f"CANN libprofapi not found: {src}") + dest = Path(shim_dir()) / _REAL_BASENAME + dest.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src, dest) + return dest diff --git a/python/probing/hccl/__main__.py b/python/probing/hccl/__main__.py new file mode 100644 index 00000000..5994701e --- /dev/null +++ b/python/probing/hccl/__main__.py @@ -0,0 +1,81 @@ +"""CLI: ``python -m probing.hccl --shim-dir``""" + +from __future__ import annotations + +import argparse +import sys + +from probing.hccl import ( + ENV_REAL, + ENV_SHIM_LOG, + install_real_copy, + ld_library_path_prefix, + shim_dir, + shim_path, +) + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + prog="python -m probing.hccl", + description="HCCL libprofapi.so shim utilities", + ) + parser.add_argument( + "--shim-path", + action="store_true", + help="print absolute path to libprofapi.so (probing shim)", + ) + parser.add_argument( + "--shim-dir", + action="store_true", + help="print directory to prepend to LD_LIBRARY_PATH", + ) + parser.add_argument( + "--install-real", + metavar="PATH", + help="copy CANN libprofapi.so to libprofapi.so.real next to the shim", + ) + args = parser.parse_args(argv) + + if args.shim_path: + try: + print(shim_path()) + except (OSError, FileNotFoundError) as e: + print(e, file=sys.stderr) + return 1 + return 0 + + if args.shim_dir: + try: + print(shim_dir()) + except (OSError, FileNotFoundError) as e: + print(e, file=sys.stderr) + return 1 + return 0 + + if args.install_real: + try: + dest = install_real_copy(args.install_real) + except (OSError, FileNotFoundError) as e: + print(e, file=sys.stderr) + return 1 + print(dest, file=sys.stderr) + return 0 + + parser.print_help(sys.stderr) + print("\nExample:", file=sys.stderr) + print( + " export LD_LIBRARY_PATH=$(python -m probing.hccl --shim-dir):$LD_LIBRARY_PATH", + file=sys.stderr, + ) + print( + f" export {ENV_REAL}=/path/to/cann/lib64/libprofapi.so # optional if libprofapi.so.real present", + file=sys.stderr, + ) + print(" export PROBING=2", file=sys.stderr) + print(f" export {ENV_SHIM_LOG}=1 # optional shim debug log", file=sys.stderr) + return 2 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/python/probing/parallel.py b/python/probing/parallel.py index 6d10490c..094caec0 100644 --- a/python/probing/parallel.py +++ b/python/probing/parallel.py @@ -172,7 +172,7 @@ def current_role() -> str: return role_key() -def set_role(role=None, /, **dims) -> str: +def set_role(role=None, **dims) -> str: """Override this process's parallel role at runtime. Accepts a canonical string, a mapping, or keyword dimensions:: diff --git a/python/probing/profiling/collective/record.py b/python/probing/profiling/collective/record.py index d9b1e313..68bf7c69 100644 --- a/python/probing/profiling/collective/record.py +++ b/python/probing/profiling/collective/record.py @@ -15,13 +15,12 @@ from probing.core import table from probing.parallel import current_role -from probing.tracing import ( - _step_fields, - comm_kind, - record_closed_span, - recorded_span, - step_snapshot, -) +from probing.tracing import record_span, span, step +from probing.tracing.coordinates import row_fields + + +def _comm_label(op: str) -> str: + return op if op.startswith("comm.") else f"comm.{op}" class CommRecordMode(str, Enum): @@ -32,8 +31,10 @@ class CommRecordMode(str, Enum): @table("comm_collective") @dataclass class CommCollective: + micro_step: int = 0 local_step: int = 0 global_step: int = 0 + micro_batches: int = 1 rank: int = -1 world_size: int = -1 # Extensible parallel role (e.g. "dp=2,pp=1,tp=0"); see probing.parallel.role_key. @@ -54,16 +55,7 @@ def _role_row_fields() -> dict: def _step_row_fields() -> dict: - snap = step_snapshot() - if snap is None: - return {"local_step": 0, "global_step": 0, "rank": -1, "world_size": -1} - step = _step_fields(snap) - return { - "local_step": step.get("local_step", 0), - "global_step": step.get("global_step", 0), - "rank": step.get("rank", -1), - "world_size": step.get("world_size", -1), - } + return row_fields(step.snapshot()) def _context_fields( @@ -118,11 +110,10 @@ def record_comm_lite( ) CommCollective(duration_ms=duration_ms, **fields).save() if write_trace_event: - record_closed_span( + record_span( op, - kind=comm_kind(op), duration_ns=int(duration_ms * 1e6), - attrs={**fields, "duration_ms": duration_ms}, + attrs={**fields, "duration_ms": duration_ms, "comm": _comm_label(op)}, source="collective_tracer", ) @@ -149,8 +140,8 @@ def begin_comm_span( nbytes=nbytes, async_op=async_op, ) - span_attrs = {**meta, "source": "collective_tracer"} - cm = recorded_span(op, kind=comm_kind(op), **span_attrs) + span_attrs = {k: v for k, v in meta.items() if k != "source"} + cm = span(op, source="collective_tracer", comm=_comm_label(op), **span_attrs) cm.__enter__() return cm, meta diff --git a/python/probing/profiling/phase_tracker.py b/python/probing/profiling/phase_tracker.py new file mode 100644 index 00000000..42631516 --- /dev/null +++ b/python/probing/profiling/phase_tracker.py @@ -0,0 +1,17 @@ +"""Backward-compatible import path; see ``probing.tracing.hooks``.""" + +from probing.tracing.hooks import ( # noqa: F401 + PhaseTracker, + _REGISTRY, + attach_training_phases, + detach_training_phases, + maybe_auto_attach, +) + +__all__ = [ + "PhaseTracker", + "_REGISTRY", + "attach_training_phases", + "detach_training_phases", + "maybe_auto_attach", +] diff --git a/python/probing/profiling/torch_probe.py b/python/probing/profiling/torch_probe.py index 857e19ac..96434313 100644 --- a/python/probing/profiling/torch_probe.py +++ b/python/probing/profiling/torch_probe.py @@ -6,13 +6,9 @@ import probing from probing.core import table from probing.parallel import current_role -from probing.tracing import ( - TRAIN_STEP_KIND, - current_local_step, - module_stage_kind, - recorded_span, - step_snapshot, -) +from probing.tracing import span, step +from probing.tracing.coordinates import row_fields +from probing.tracing.phases import OPTIMIZER, infer_from_stage, is_training_phase from .torch.module_utils import module_name from .types import BaseTracer @@ -40,7 +36,10 @@ def _get_backend(): @table @dataclass class TorchTrace: - step: Optional[int] = None + micro_step: Optional[int] = None + local_step: int = -1 + global_step: int = -1 + micro_batches: int = 1 seq: Optional[int] = None module: Optional[str] = None stage: Optional[str] = None @@ -52,10 +51,7 @@ class TorchTrace: duration: float = 0.0 allocated_delta: float = 0.0 max_allocated_delta: float = 0.0 - # Step coordinate + parallel role, so module-level local work can be aligned - # across ranks by role and joined with comm_collective. ``role`` is an - # extensible KV string (e.g. "dp=2,pp=1,tp=0"); see probing.parallel.role_key. - global_step: int = -1 + # Step coordinate + parallel role (see probing.step). rank: int = -1 world_size: int = -1 role: str = "" @@ -64,7 +60,7 @@ class TorchTrace: @table @dataclass class Variables: - step: Optional[int] = None + micro_step: Optional[int] = None func: Optional[str] = None name: Optional[str] = None value: Optional[str] = None @@ -659,7 +655,7 @@ def __init__(self, config: Optional[TorchProbeConfig] = None): self.config = config self.enabled = config.enabled - self.curr_step = current_local_step() + self.curr_step = step.micro_step self.pending = [] self._open_spans = {} self._train_step_cm = None @@ -673,30 +669,27 @@ def __init__(self, config: Optional[TorchProbeConfig] = None): ) def _stamp_step_role(self, record) -> None: - """Fill step coordinate (local/global/rank) and parallel role on a record.""" - snap = step_snapshot() - if snap is not None: - record.step = int(snap.local_step) - record.global_step = int(snap.global_step) - record.rank = int(snap.rank) - record.world_size = int(snap.world_size) + """Fill step coordinate and parallel role on a torch_trace record.""" + for key, value in row_fields(step.snapshot()).items(): + setattr(record, key, value) record.role = current_role() - def _begin_train_step_span(self) -> None: + def _begin_train_step_span(self, optimizer=None) -> None: if self._train_step_cm is not None: return - # Do not sync_local_step here — curr_step is often stale and would reset - # the global coordinate back to 0/1 across batches. - self._train_step_cm = recorded_span( - "step", kind=TRAIN_STEP_KIND, source="torch_probe" - ) - self._train_step_cm.__enter__() + from probing.tracing.hooks import owns_training_phases + + if optimizer is not None and owns_training_phases(optimizer=optimizer): + return + handle = span(phase=OPTIMIZER, source="torch_probe") + handle.__enter__() + self._train_step_cm = handle def _end_train_step_span(self) -> None: if self._train_step_cm is None: return - # Reentrant: outer span (e.g. manual train.step) owns the lifecycle. - if getattr(self._train_step_cm, "_reentrant", False): + inner = getattr(self._train_step_cm, "_inner", None) + if inner is not None and getattr(inner, "_reentrant", False): self._train_step_cm = None return self._train_step_cm.__exit__(None, None, None) @@ -727,7 +720,8 @@ def _complete_post_stage(self, mod, post_stage: str) -> None: pre_max_allocated = entry[4] record.allocated_delta = record.allocated - pre_allocated record.max_allocated_delta = record.max_allocated - pre_max_allocated - span_cm.__exit__(None, None, None) + if span_cm is not None: + span_cm.__exit__(None, None, None) self.pending.append(DelayedRecord(record, events)) def _finish_open_stages(self) -> None: @@ -782,19 +776,30 @@ def log_module_stage(self, stage, mod, force=False) -> None: module_name_str = self._module_display_name(mod) record.module = module_name_str record.stage = stage - span_kind = module_stage_kind(stage) + span_phase = infer_from_stage(stage) + + emit_trace_span = True + if is_training_phase(span_phase): + from probing.tracing.hooks import owns_training_phases + + if owns_training_phases(module=mod): + emit_trace_span = False if stage.startswith("pre"): record.time_offset = self.begin_timing(mod, stage) - span_cm = recorded_span( - module_name_str, - kind=span_kind, - module=module_name_str, - stage=mapped_stage, - seq=record.seq, - source="torch_probe", - ) - span_cm.__enter__() + span_cm = None + if emit_trace_span: + span_kwargs = dict( + module=module_name_str, + stage=mapped_stage, + seq=record.seq, + source="torch_probe", + ) + if is_training_phase(span_phase): + span_cm = span(module_name_str, phase=span_phase, **span_kwargs) + else: + span_cm = span(module_name_str, **span_kwargs) + span_cm.__enter__() self._open_spans[span_key] = ( span_cm, mod, @@ -812,13 +817,13 @@ def post_step_hook(self, opt, args, kwargs): return if not self.finalized: self.finalize_discovery() - self.curr_step = current_local_step() - self._begin_train_step_span() + self.curr_step = step.micro_step + self._begin_train_step_span(optimizer=opt) else: self._end_train_step_span() self.next_mod() - self.curr_step = current_local_step() - self._begin_train_step_span() + self.curr_step = step.micro_step + self._begin_train_step_span(optimizer=opt) # Ensure backend operations are complete before processing traces if self.has_backend and self.pending: diff --git a/python/probing/skills/__main__.py b/python/probing/skills/__main__.py index 7e637488..a300eb67 100644 --- a/python/probing/skills/__main__.py +++ b/python/probing/skills/__main__.py @@ -31,7 +31,8 @@ def main() -> int: catalog = load_catalog() print(f"Catalog: {len(catalog.skills)} skills") print(f"Roots: {_roots_display()}") - if repo := repo_skills_dir(): + repo = repo_skills_dir() + if repo: print(f"Repo skills: {repo}") warnings = validate_all() if warnings: diff --git a/python/probing/tracing.py b/python/probing/tracing.py deleted file mode 100644 index 9a18a2b8..00000000 --- a/python/probing/tracing.py +++ /dev/null @@ -1,716 +0,0 @@ -"""Tracing facade (Python side). - -Provides a thin, explicit wrapper around the Rust implementation for creating spans -via a context manager or decorator, attaching immutable attributes at creation time, -and recording span lifecycle plus custom events into a single table. - -Notes ------ -* Attributes are fixed at span creation (no mutation API exposed). -* `TraceEvent` stores start/end/event rows; missing values use simple sentinels - (parent_id = -1, text fields = empty string) to avoid `None` persistence issues. -* The public surface stays minimal: `span`, `Span.with_`, `Span.decorator`, `add_event`, - and the `TraceEvent` dataclass table. - -Examples --------- -Context manager:: - - import probing - with probing.span("load_data", dataset="mnist") as s: - probing.event("read") - # do work - -Decorator:: - - import probing - @probing.span("predict") - def predict(x): - return model(x) - -Implicit name decorator:: - - import probing - @probing.span - def compute(): - return 42 -""" - -import functools -import inspect -from dataclasses import dataclass -from typing import Callable, Optional - -# Import from the internal Rust module -from probing import _core - -try: - Span = _core.Span - current_span = _core.current_span - active_span_for_events = _core.active_span_for_events - active_span_by_kind = _core.active_span_by_kind - step_snapshot = _core.py_step_snapshot - sync_local_step = _core.py_sync_local_step - advance_local_step = _core.py_advance_local_step - set_step_bucket_size = _core.py_set_step_bucket_size - current_local_step = _core.py_current_local_step -except AttributeError: - Span = None - - def current_span(): - return None - - def active_span_for_events(): - return None - - def active_span_by_kind(_kind: str): - return None - - def step_snapshot(): - return None - - def sync_local_step(_step: int): - return None - - def advance_local_step(): - return None - - def set_step_bucket_size(_bucket: int): - return None - - def current_local_step() -> int: - return 0 - - -from probing.core.table import table - -TRAIN_STEP_KIND = "train.step" - -# Materialized span rows derived from ``python.trace_event`` (start/end join). -# Use span ``time`` (ns since epoch), not the memtable ingestion ``timestamp``. -SPANS_SQL = """ -SELECT - s.trace_id, - s.span_id, - COALESCE(s.parent_id, -1) AS parent_span_id, - s.name, - s.kind, - CAST(s.time / 1000 AS BIGINT) AS start_us, - CAST(e.time / 1000 AS BIGINT) AS end_us, - CAST((e.time - s.time) / 1000 AS BIGINT) AS duration_us, - s.thread_id, - s.location, - s.attributes -FROM python.trace_event s -JOIN python.trace_event e - ON s.span_id = e.span_id AND e.record_type = 'span_end' -WHERE s.record_type = 'span_start' -""" - -STAGE_KIND_MAP = { - "forward": "nn.forward", - "backward": "nn.backward", - "step": "optim.step", -} - - -def _step_fields(snapshot) -> dict: - if snapshot is None: - return {} - return { - "local_step": int(snapshot.local_step), - "global_step": int(snapshot.global_step), - "bucket_size": int(snapshot.bucket_size), - "rank": int(snapshot.rank), - "world_size": int(snapshot.world_size), - } - - -def _merge_span_attributes(attrs: dict, *, source: str = "manual") -> dict: - """Merge user attrs with step coordinates, topology, and source label.""" - merged = dict(attrs) - merged.setdefault("source", source) - snap = step_snapshot() - if snap is not None: - merged.update(_step_fields(snap)) - from probing.parallel import parallel_fields - - merged.update(parallel_fields()) - return merged - - -def comm_kind(op: str) -> str: - """Span kind for a collective op, e.g. ``comm.all_reduce``.""" - if op.startswith("comm."): - return op - return f"comm.{op}" - - -def _create_span_object( - name: str, kind: Optional[str], location: Optional[str], attrs: dict -): - parent = current_span() - if parent: - span_obj = Span.new_child(parent, name, kind=kind, location=location) - else: - span_obj = Span(name, kind=kind, location=location) - if attrs and hasattr(span_obj, "_set_initial_attrs"): - try: - span_obj._set_initial_attrs(dict(attrs)) - except Exception as e: - import warnings - - warnings.warn(f"Failed to set initial attributes: {e}") - return span_obj - - -class _RecordedSpan: - """Internal context manager: span stack + TraceEvent persistence.""" - - def __init__( - self, - name: str, - kind: Optional[str] = None, - location: Optional[str] = None, - attrs: Optional[dict] = None, - *, - source: str = "manual", - ): - self.name = name - self.kind = kind - self.location = location - self.attrs = dict(attrs or {}) - self.source = source - self._span = None - self._reentrant = False - self._owns_step_advance = False - - def __enter__(self): - if self.kind == TRAIN_STEP_KIND: - existing = active_span_by_kind(TRAIN_STEP_KIND) - if existing is not None: - self._span = existing - self._reentrant = True - return existing - - loc = self.location or _get_location() - merged = _merge_span_attributes(self.attrs, source=self.source) - self._span = _create_span_object(self.name, self.kind, loc, merged) - self._span.__enter__() - _record_span_start(self._span, merged) - if self.kind == TRAIN_STEP_KIND: - self._owns_step_advance = True - return self._span - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._span is None: - return False - if self._reentrant: - return False - result = self._span.__exit__(exc_type, exc_val, exc_tb) - _record_span_end(self._span) - if self._owns_step_advance: - advance_local_step() - return result - - -def recorded_span(name: str, kind: Optional[str] = None, **attrs): - """Open a span that is always persisted to ``python.trace_event``.""" - return _RecordedSpan(name, kind=kind, attrs=attrs) - - -def module_stage_kind(stage: str) -> str: - """Map TorchProbe stage label to span kind.""" - for key, kind in STAGE_KIND_MAP.items(): - if key in stage: - return kind - return "torch.module" - - -def _get_location() -> Optional[str]: - """Get the current call location from the stack. - - Returns - ------- - Optional[str] - Location string in format "filename:function:lineno" or None if unavailable. - """ - try: - # Get the frame that called span() (skip this function and span() itself) - stack = inspect.stack() - # Find the first frame that's not in this module - for frame_info in stack[2:]: # Skip _get_location and span() - frame = frame_info.frame - filename = frame_info.filename - function = frame_info.function - lineno = frame_info.lineno - - # Skip frames from this module - if "probing/tracing.py" in filename or "probing\\tracing.py" in filename: - continue - - # Format: "filename:function:lineno" - return f"{filename}:{function}:{lineno}" - except Exception: - pass - return None - - -@table -@dataclass -class TraceEvent: - """Row model for trace records. - - Each saved instance is one of: span_start, span_end, event. - - Parameters - ---------- - record_type : str - One of ``'span_start'``, ``'span_end'`` or ``'event'``. - trace_id : int - Trace identifier (shared by related spans). - span_id : int - Unique span identifier. - name : str - Span or event name. - time : int - Nanoseconds since epoch. - parent_id : int, default -1 - Parent span id, -1 if root. - kind : str, default "" - Optional span kind label. - location : str, default "" - Code location automatically captured from call stack. - attributes : str, default "" - JSON string of span attributes (only in span rows). - event_attributes : str, default "" - JSON string of event attributes (only in event rows). - """ - - # Required fields - record_type: str - trace_id: int - span_id: int - name: str - time: int - thread_id: int = 0 - - # Optional fields - parent_id: Optional[int] = -1 - kind: Optional[str] = "" - location: Optional[str] = "" - attributes: Optional[str] = "" - event_attributes: Optional[str] = "" - - -def span(*args, **kwargs): - """Factory for span usage as context manager or decorator. - - Scenarios - --------- - 1. Context manager:: - - with span("work", user="alice") as s: - ... - - 2. Decorator with explicit name:: - - @span("inference") - def run(x): ... - - 3. Decorator with implicit function name:: - - @span - def train(): ... - - Parameters - ---------- - *args - Either empty (implicit decorator), a single callable, or a single string name. - **kwargs - Attributes to attach plus optional ``kind``. - - Note - ---- - The ``location`` is automatically captured from the call stack using - Python's ``inspect`` module. It is not passed as a parameter. - - Returns - ------- - object - A context manager / decorator hybrid or a pure decorator. - """ - # Extract special parameters - kind = kwargs.pop("kind", None) - # Location is automatically captured, not passed as parameter - location = _get_location() - - if len(args) == 0 and not kwargs: - - def decorator(func: Callable) -> Callable: - @functools.wraps(func) - def wrapper(*wargs, **wkwargs): - with _RecordedSpan(func.__name__, kind=kind) as _s: - return func(*wargs, **wkwargs) - - return wrapper - - return decorator - - # Handle @span(func) - first arg is a callable - if len(args) == 1 and callable(args[0]): - func = args[0] - - @functools.wraps(func) - def wrapper(*wargs, **wkwargs): - with _RecordedSpan(func.__name__, kind=kind) as _s: - return func(*wargs, **wkwargs) - - return wrapper - - # Handle @span("name") or with span("name") - if len(args) == 1 and isinstance(args[0], str): - name = args[0] - - # Create a wrapper that supports both decorator and context manager usage - class SpanWrapper: - def __init__( - self, - name: str, - kind: Optional[str], - location: Optional[str], - attrs: dict, - ): - self.name = name - self.kind = kind - self.location = location - self.attrs = attrs - self._inner = None - - def __call__(self, func: Callable) -> Callable: - """Enable decorator form when a name was provided.""" - - @functools.wraps(func) - def wrapper(*wargs, **wkwargs): - with _RecordedSpan( - self.name, - kind=self.kind, - location=self.location, - attrs=self.attrs, - ) as _s: - return func(*wargs, **wkwargs) - - return wrapper - - def __enter__(self): - self._inner = _RecordedSpan( - self.name, - kind=self.kind, - location=self.location, - attrs=self.attrs, - ) - return self._inner.__enter__() - - def __exit__(self, *args): - if self._inner: - return self._inner.__exit__(*args) - return False - - return SpanWrapper(name, kind, location, kwargs) - - if len(args) > 0: - name = args[0] - if not isinstance(name, str): - raise TypeError("span() requires a string name as the first argument") - - parent = current_span() - loc = location or _get_location() - - if parent: - span_obj = Span.new_child(parent, name, kind=kind, location=loc) - else: - span_obj = Span(name, kind=kind, location=loc) - - if kwargs: - attrs_dict = dict(kwargs) - if hasattr(span_obj, "_set_initial_attrs"): - span_obj._set_initial_attrs(attrs_dict) - - return span_obj - - raise TypeError("span() requires at least one argument") - - -def _record_span_start(span: Span, attrs: dict): - """Persist span start. - - Parameters - ---------- - span : Span - Span object. - attrs : dict - Creation-time attributes. - """ - import json - - # Convert attributes to JSON string - attrs_json = None - if attrs: - attrs_json = json.dumps(attrs) - # Sanitize None values to backend-friendly sentinels (tables reject Python None) - parent_id = span.parent_id if span.parent_id is not None else -1 - kind = span.kind if span.kind is not None else "" - location = ( - span.location if hasattr(span, "location") and span.location is not None else "" - ) - attributes = attrs_json if attrs_json is not None else "" - event = TraceEvent( - record_type="span_start", - trace_id=span.trace_id, - span_id=span.span_id, - name=span.name, - time=span.start_timestamp, - thread_id=getattr(span, "thread_id", 0), - parent_id=parent_id, - kind=kind, - location=location, - attributes=attributes, - event_attributes="", # not applicable - ) - event.save() - - -def _record_span_end(span: Span): - """Persist span end with minimal data (only end time + span id). - - Other fields are blanked to reduce duplication. - """ - import time - - end_ts = span.end_timestamp or int(time.time_ns()) - event = TraceEvent( - record_type="span_end", - trace_id=0, - span_id=span.span_id, - name="", - time=end_ts, - thread_id=getattr(span, "thread_id", 0), - parent_id=-1, - kind="", - location="", - attributes="", - event_attributes="", - ) - event.save() - - -def record_closed_span( - name: str, - *, - kind: Optional[str] = None, - duration_ns: int, - attrs: Optional[dict] = None, - source: str = "manual", -) -> None: - """Persist span_start + span_end without entering the span stack. - - Used for hot-path instrumentation where ``recorded_span`` stack/location - capture would add unnecessary overhead. - """ - import json - import time - - if duration_ns < 0: - duration_ns = 0 - - TraceEvent.init_table() - merged = _merge_span_attributes(dict(attrs or {}), source=source) - end_ns = int(time.time_ns()) - start_ns = end_ns - duration_ns - - parent = current_span() - if parent: - span_obj = Span.new_child(parent, name, kind=kind, location="") - else: - span_obj = Span(name, kind=kind, location="") - - attrs_json = json.dumps(merged) if merged else "" - parent_id = span_obj.parent_id if span_obj.parent_id is not None else -1 - kind_str = kind or "" - - TraceEvent( - record_type="span_start", - trace_id=span_obj.trace_id, - span_id=span_obj.span_id, - name=name, - time=start_ns, - thread_id=getattr(span_obj, "thread_id", 0), - parent_id=parent_id, - kind=kind_str, - location="", - attributes=attrs_json, - event_attributes="", - ).save() - - TraceEvent( - record_type="span_end", - trace_id=0, - span_id=span_obj.span_id, - name="", - time=end_ns, - thread_id=getattr(span_obj, "thread_id", 0), - parent_id=-1, - kind="", - location="", - attributes="", - event_attributes="", - ).save() - - -def _record_event(span: Span, event_name: str, event_attributes: Optional[list] = None): - """Persist an event row. - - Parameters - ---------- - span : Span - Active span. - event_name : str - Event name. - event_attributes : list, optional - List of dicts or (key, value) tuples. - """ - import json - import time - - # Get current timestamp (nanoseconds since epoch) - timestamp = int(time.time_ns()) - - # Convert event attributes to JSON string - event_attrs_json = None - if event_attributes: - # Convert list of dicts/tuples to a single dict - attrs_dict = {} - for attr_item in event_attributes: - if isinstance(attr_item, dict): - attrs_dict.update(attr_item) - elif isinstance(attr_item, (list, tuple)) and len(attr_item) == 2: - attrs_dict[attr_item[0]] = attr_item[1] - if attrs_dict: - event_attrs_json = json.dumps(attrs_dict) - - parent_id = span.parent_id if span.parent_id is not None else -1 - kind = span.kind if span.kind is not None else "" - location = ( - span.location if hasattr(span, "location") and span.location is not None else "" - ) - attrs = "" # span-level attributes not duplicated here - event_attrs = event_attrs_json if event_attrs_json is not None else "" - event = TraceEvent( - record_type="event", - trace_id=span.trace_id, - span_id=span.span_id, - name=event_name, - time=timestamp, - thread_id=getattr(span, "thread_id", 0), - parent_id=parent_id, - kind=kind, - location=location, - attributes=attrs, - event_attributes=event_attrs, - ) - event.save() - - -# Add convenience methods to Span class -def _span_with(name: str, kind: Optional[str] = None): - """Convenience context manager form. - - Parameters - ---------- - name : str - Span name. - kind : str, optional - Span kind label. - - Returns - ------- - Span - Newly created span (root or child). - """ - parent = current_span() - location = _get_location() - if parent: - return Span.new_child(parent, name, kind=kind, location=location) - else: - return Span(name, kind=kind, location=location) - - -def _span_decorator(name: Optional[str] = None, kind: Optional[str] = None): - """Return a decorator that wraps a function in a span. - - Parameters - ---------- - name : str, optional - Explicit span name, defaults to function name. - kind : str, optional - Kind label. - - Returns - ------- - Callable - Decorator applying tracing span. - """ - - def decorator(func: Callable) -> Callable: - @functools.wraps(func) - def wrapper(*wargs, **wkwargs): - span_name = name or func.__name__ - with _RecordedSpan(span_name, kind=kind) as _s: - return func(*wargs, **wkwargs) - - return wrapper - - return decorator - - -# Monkey-patch Span class with convenience methods -if Span: - Span.with_ = staticmethod(_span_with) - Span.decorator = staticmethod(_span_decorator) - - -def add_event(name: str, *, attributes: Optional[list] = None): - """Add an event to the current span. - - Parameters - ---------- - name : str - Event name. - attributes : list, optional - Each item is a dict or a (key, value) tuple. - - Raises - ------ - RuntimeError - If no span is active. - - Examples - -------- - >>> with span("op"): - ... add_event("phase") - ... add_event("kv", attributes=[{"x": 1}, ("y", 2)]) - """ - current = active_span_for_events() - if current is None: - current = current_span() - if current is None or getattr(current, "is_ended", False): - raise RuntimeError("No active span in current context. Cannot add event.") - - current.add_event(name, attributes=attributes) - - # Record event to table - _record_event(current, name, attributes) - - -# Alias for add_event to match the top-level export -event = add_event diff --git a/python/probing/tracing/__init__.py b/python/probing/tracing/__init__.py new file mode 100644 index 00000000..81e4206e --- /dev/null +++ b/python/probing/tracing/__init__.py @@ -0,0 +1,62 @@ +"""Tracing primitives for probing.""" + +from __future__ import annotations + +from probing.tracing._bindings import Span, current_span +from probing.tracing.backends import ( + bind_table, + configure as configure_backends, + list_backends, + register as register_backend, + reset as reset_backends, +) +from probing.tracing.coordinates import row_fields, span_attrs, step, step_fields +from probing.tracing.hooks import ( + attach_training_phases, + detach_training_phases, + owns_training_phases, +) +from probing.tracing.phases import ( + BACKWARD, + FORWARD, + IDLE, + OPTIMIZER, + SOURCE_MANUAL, + SOURCE_PHASE_HOOK, + SOURCE_TORCH_PROBE, + phase, + reset_phase, +) +from probing.tracing.span import event, record_span, span +from probing.tracing.table import SPANS_SQL, TraceEvent + +bind_table(TraceEvent) + +__all__ = [ + "span", + "event", + "record_span", + "current_span", + "step", + "step_fields", + "row_fields", + "span_attrs", + "phase", + "reset_phase", + "attach_training_phases", + "detach_training_phases", + "owns_training_phases", + "FORWARD", + "BACKWARD", + "OPTIMIZER", + "IDLE", + "SOURCE_MANUAL", + "SOURCE_PHASE_HOOK", + "SOURCE_TORCH_PROBE", + "register_backend", + "configure_backends", + "list_backends", + "reset_backends", + "TraceEvent", + "SPANS_SQL", +] diff --git a/python/probing/tracing/_bindings.py b/python/probing/tracing/_bindings.py new file mode 100644 index 00000000..188f9b41 --- /dev/null +++ b/python/probing/tracing/_bindings.py @@ -0,0 +1,46 @@ +"""Rust tracing bindings (internal).""" + +from __future__ import annotations + +from probing import _core + +try: + Span = _core.Span + current_span = _core.current_span + active_span_for_events = _core.active_span_for_events + active_span_by_phase = _core.active_span_by_phase + active_training_phase = _core.active_training_phase + step_snapshot = _core.py_step_snapshot + sync_micro_step = _core.py_sync_micro_step + advance_micro_step = _core.py_advance_micro_step + set_micro_batches = _core.py_set_micro_batches + current_micro_step = _core.py_current_micro_step +except AttributeError: + Span = None + + def current_span(): + return None + + def active_span_for_events(): + return None + + def active_span_by_phase(_phase: str): + return None + + def active_training_phase(): + return None + + def step_snapshot(): + return None + + def sync_micro_step(_step: int): + return None + + def advance_micro_step(): + return None + + def set_micro_batches(_micro_batches: int): + return None + + def current_micro_step() -> int: + return 0 diff --git a/python/probing/tracing/backends.py b/python/probing/tracing/backends.py new file mode 100644 index 00000000..8727aeb9 --- /dev/null +++ b/python/probing/tracing/backends.py @@ -0,0 +1,590 @@ +"""Pluggable span backends — fan-out from a single recorder. + +Default backend writes ``python.trace_event`` (memtable). Optional backends +include OpenTelemetry export and third-party entry points. + +Environment +----------- +``PROBING_SPAN_BACKENDS`` + Comma-separated backend names. Default: ``memtable``. + Built-in: ``memtable``, ``logger`` (terminal), ``otel`` (requires ``opentelemetry-sdk``). + +``PROBING_SPAN_LOG_LEVEL`` + Log level for the ``logger`` backend (default: ``INFO``). + +``OTEL_EXPORTER_OTLP_ENDPOINT`` / standard OTel env vars apply when ``otel`` is enabled. +""" + +from __future__ import annotations + +import json +import logging +import os +import sys +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Protocol, runtime_checkable + +logger = logging.getLogger(__name__) + +MEMTABLE_BACKEND = "memtable" +LOGGER_BACKEND = "logger" +OTEL_BACKEND = "otel" + +_trace_event_cls: Any = None +_recorder: Optional["SpanRecorder"] = None +_custom_factories: Dict[str, Callable[[], "SpanBackend"]] = {} +_programmatic_names: Optional[List[str]] = None + + +@dataclass(frozen=True) +class SpanStartRecord: + trace_id: int + span_id: int + parent_id: int + name: str + phase: str + time_ns: int + thread_id: int + location: str + attributes_json: str + + +@dataclass(frozen=True) +class SpanEndRecord: + span_id: int + time_ns: int + thread_id: int + + +@dataclass(frozen=True) +class SpanEventRecord: + trace_id: int + span_id: int + parent_id: int + phase: str + location: str + name: str + time_ns: int + thread_id: int + event_attributes_json: str + + +@runtime_checkable +class SpanBackend(Protocol): + name: str + + def on_span_start(self, record: SpanStartRecord) -> None: ... + + def on_span_end(self, record: SpanEndRecord) -> None: ... + + def on_event(self, record: SpanEventRecord) -> None: ... + + def shutdown(self) -> None: ... + + +def bind_table(trace_event_cls: Any) -> None: + """Bind the memtable row model (``TraceEvent`` dataclass).""" + global _trace_event_cls, _recorder + _trace_event_cls = trace_event_cls + _recorder = None + + +def configure(names: Optional[List[str]] = None) -> None: + """Select span backends by name (overrides ``PROBING_SPAN_BACKENDS`` until ``reset()``).""" + global _programmatic_names, _recorder + _programmatic_names = list(names) if names is not None else None + _recorder = None + + +def register(name: str, factory: Callable[[], SpanBackend]) -> None: + """Register a custom backend factory.""" + _custom_factories[name.strip().lower()] = factory + global _recorder + _recorder = None + + +def parse_backend_names(raw: Optional[str] = None) -> List[str]: + if _programmatic_names is not None: + return list(_programmatic_names) + value = ( + raw + if raw is not None + else os.environ.get("PROBING_SPAN_BACKENDS", MEMTABLE_BACKEND) + ) + names = [part.strip().lower() for part in value.split(",") if part.strip()] + return names or [MEMTABLE_BACKEND] + + +def _entry_point_backends() -> Dict[str, Callable[[], SpanBackend]]: + grouped: Dict[str, Callable[[], SpanBackend]] = {} + try: + try: + from importlib.metadata import entry_points as _eps + except ImportError: + from importlib_metadata import entry_points as _eps # type: ignore + + try: + eps = _eps(group="probing.span_backends") + except TypeError: + eps = _eps().get("probing.span_backends", []) + for ep in eps: + grouped[ep.name.strip().lower()] = ep.load + except Exception: + pass + return grouped + + +def _build_memtable_backend() -> SpanBackend: + if _trace_event_cls is None: + raise RuntimeError( + "probing.tracing.backends.bind_table(TraceEvent) was not called" + ) + return MemtableBackend(_trace_event_cls) + + +def _build_logger_backend() -> SpanBackend: + return LoggerBackend() + + +def _build_otel_backend() -> Optional[SpanBackend]: + try: + from opentelemetry import trace # noqa: F401 + except ImportError: + logger.warning( + "PROBING_SPAN_BACKENDS includes 'otel' but opentelemetry-sdk is not installed; skipping" + ) + return None + return OtelBackend() + + +def load_backends(names: Optional[List[str]] = None) -> List[SpanBackend]: + """Instantiate backends for *names* (deduplicated, stable order).""" + wanted = parse_backend_names(",".join(names) if names else None) + entry_map = _entry_point_backends() + out: List[SpanBackend] = [] + seen: set[str] = set() + + for name in wanted: + if name in seen: + continue + seen.add(name) + + backend: Optional[SpanBackend] = None + if name == MEMTABLE_BACKEND: + backend = _build_memtable_backend() + elif name == LOGGER_BACKEND: + backend = _build_logger_backend() + elif name == OTEL_BACKEND: + backend = _build_otel_backend() + elif name in _custom_factories: + backend = _custom_factories[name]() + elif name in entry_map: + backend = entry_map[name]() + else: + logger.warning("Unknown span backend %r — skipped", name) + continue + + if backend is not None: + out.append(backend) + + if not out: + out.append(_build_memtable_backend()) + return out + + +class MemtableBackend: + """Canonical store: ``python.trace_event`` mmap rows.""" + + name = MEMTABLE_BACKEND + + def __init__(self, trace_event_cls: Any) -> None: + self._TraceEvent = trace_event_cls + + def on_span_start(self, record: SpanStartRecord) -> None: + self._TraceEvent.init_table() + self._TraceEvent( + record_type="span_start", + trace_id=record.trace_id, + span_id=record.span_id, + name=record.name, + time=record.time_ns, + thread_id=record.thread_id, + parent_id=record.parent_id, + phase=record.phase, + location=record.location, + attributes=record.attributes_json, + event_attributes="", + ).save() + + def on_span_end(self, record: SpanEndRecord) -> None: + self._TraceEvent.init_table() + self._TraceEvent( + record_type="span_end", + trace_id=0, + span_id=record.span_id, + name="", + time=record.time_ns, + thread_id=record.thread_id, + parent_id=-1, + phase="", + location="", + attributes="", + event_attributes="", + ).save() + + def on_event(self, record: SpanEventRecord) -> None: + self._TraceEvent.init_table() + self._TraceEvent( + record_type="event", + trace_id=record.trace_id, + span_id=record.span_id, + name=record.name, + time=record.time_ns, + thread_id=record.thread_id, + parent_id=record.parent_id, + phase=record.phase, + location=record.location, + attributes="", + event_attributes=record.event_attributes_json, + ).save() + + def shutdown(self) -> None: + return None + + +def _terminal_logger() -> logging.Logger: + """Logger that prints span lines to stderr when no handler is configured.""" + log = logging.getLogger("probing.span") + if not log.handlers: + handler = logging.StreamHandler(sys.stderr) + handler.setFormatter(logging.Formatter("%(message)s")) + log.addHandler(handler) + level = os.environ.get("PROBING_SPAN_LOG_LEVEL", "INFO").upper() + log.setLevel(getattr(logging, level, logging.INFO)) + log.propagate = False + return log + + +class LoggerBackend: + """Print span lifecycle to the terminal (works alongside other backends).""" + + name = LOGGER_BACKEND + + def __init__(self, log: Optional[logging.Logger] = None) -> None: + self._log = log or _terminal_logger() + self._depth = 0 + self._open: Dict[int, tuple[str, int]] = {} + + def _indent(self) -> str: + return " " * self._depth + + def on_span_start(self, record: SpanStartRecord) -> None: + self._open[record.span_id] = (record.name, record.time_ns) + parts = [f"→ {record.name}"] + if record.phase: + parts.append(f"phase={record.phase}") + source = _attr_from_json(record.attributes_json, "source") + if source: + parts.append(f"source={source}") + self._log.info("%s%s", self._indent(), " ".join(parts)) + self._depth += 1 + + def on_span_end(self, record: SpanEndRecord) -> None: + self._depth = max(0, self._depth - 1) + opened = self._open.pop(record.span_id, None) + if opened is not None: + name, start_ns = opened + dur_ms = max(0.0, (record.time_ns - start_ns) / 1e6) + self._log.info("%s← %s %.2fms", self._indent(), name, dur_ms) + else: + self._log.info("%s← span_id=%s", self._indent(), record.span_id) + + def on_event(self, record: SpanEventRecord) -> None: + suffix = "" + if record.event_attributes_json: + try: + parsed = json.loads(record.event_attributes_json) + if isinstance(parsed, dict) and parsed: + suffix = " " + json.dumps(parsed, ensure_ascii=False) + except json.JSONDecodeError: + suffix = f" {record.event_attributes_json}" + self._log.info("%s· %s%s", self._indent(), record.name, suffix) + + def shutdown(self) -> None: + self._open.clear() + self._depth = 0 + + +def _attr_from_json(raw: str, key: str) -> Optional[str]: + if not raw: + return None + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + return None + if not isinstance(parsed, dict): + return None + value = parsed.get(key) + return str(value) if value is not None else None + + +class OtelBackend: + """Optional OpenTelemetry export (Jaeger/Grafana/OTLP via standard OTel env).""" + + name = OTEL_BACKEND + + def __init__(self) -> None: + from opentelemetry import trace + from opentelemetry.trace import SpanKind, set_span_in_context + + self._trace = trace + self._SpanKind = SpanKind + self._set_span_in_context = set_span_in_context + self._tracer = trace.get_tracer("probing") + self._spans: Dict[int, Any] = {} + self._parents: Dict[int, int] = {} + + def _kind(self, kind: str) -> Any: + mapping = { + "server": self._SpanKind.SERVER, + "client": self._SpanKind.CLIENT, + "producer": self._SpanKind.PRODUCER, + "consumer": self._SpanKind.CONSUMER, + } + return mapping.get(kind, self._SpanKind.INTERNAL) + + def on_span_start(self, record: SpanStartRecord) -> None: + parent_ctx = None + if record.parent_id not in (-1, None): + parent_otel = self._spans.get(record.parent_id) + if parent_otel is not None: + parent_ctx = self._set_span_in_context(parent_otel) + + otel_span = self._tracer.start_span( + record.name, + context=parent_ctx, + kind=self._kind(record.phase), + start_time=record.time_ns, + ) + if record.attributes_json: + try: + attrs = json.loads(record.attributes_json) + if isinstance(attrs, dict): + for key, value in attrs.items(): + otel_span.set_attribute(str(key), value) + else: + otel_span.set_attribute( + "probing.attributes", record.attributes_json + ) + except json.JSONDecodeError: + otel_span.set_attribute("probing.attributes", record.attributes_json) + if record.phase: + otel_span.set_attribute("probing.phase", record.phase) + if record.location: + otel_span.set_attribute("probing.location", record.location) + + self._spans[record.span_id] = otel_span + self._parents[record.span_id] = record.parent_id + + def on_span_end(self, record: SpanEndRecord) -> None: + otel_span = self._spans.pop(record.span_id, None) + self._parents.pop(record.span_id, None) + if otel_span is None: + return + otel_span.end(end_time=record.time_ns) + + def on_event(self, record: SpanEventRecord) -> None: + otel_span = self._spans.get(record.span_id) + if otel_span is None: + return + attrs: Dict[str, Any] = {} + if record.event_attributes_json: + try: + parsed = json.loads(record.event_attributes_json) + if isinstance(parsed, dict): + attrs = {str(k): v for k, v in parsed.items()} + except json.JSONDecodeError: + attrs = {"raw": record.event_attributes_json} + otel_span.add_event(record.name, attributes=attrs, timestamp=record.time_ns) + + def shutdown(self) -> None: + for span_id, otel_span in list(self._spans.items()): + try: + otel_span.end() + except Exception: + pass + self._spans.pop(span_id, None) + self._parents.clear() + + +class SpanRecorder: + """Fan-out span lifecycle records to all enabled backends.""" + + def __init__(self, backends: List[SpanBackend]) -> None: + self._backends = backends + + @property + def backend_names(self) -> List[str]: + return [b.name for b in self._backends] + + def record_span_start(self, span: Any, attrs: dict) -> None: + record = _span_start_record(span, attrs) + self._dispatch("on_span_start", record) + + def record_span_end(self, span: Any) -> None: + end_ts = span.end_timestamp or int(time.time_ns()) + record = SpanEndRecord( + span_id=int(span.span_id), + time_ns=int(end_ts), + thread_id=int(getattr(span, "thread_id", 0)), + ) + self._dispatch("on_span_end", record) + + def record_closed_span( + self, + span: Any, + *, + name: str, + phase: str, + start_ns: int, + end_ns: int, + attributes_json: str, + ) -> None: + start = SpanStartRecord( + trace_id=int(span.trace_id), + span_id=int(span.span_id), + parent_id=int(span.parent_id if span.parent_id is not None else -1), + name=name, + phase=phase, + time_ns=int(start_ns), + thread_id=int(getattr(span, "thread_id", 0)), + location="", + attributes_json=attributes_json, + ) + end = SpanEndRecord( + span_id=int(span.span_id), + time_ns=int(end_ns), + thread_id=int(getattr(span, "thread_id", 0)), + ) + self._dispatch("on_span_start", start) + self._dispatch("on_span_end", end) + + def record_event( + self, span: Any, event_name: str, event_attributes: Optional[list] = None + ) -> None: + record = _event_record(span, event_name, event_attributes) + self._dispatch("on_event", record) + + def shutdown(self) -> None: + for backend in self._backends: + try: + backend.shutdown() + except Exception as exc: + logger.debug("span backend %s.shutdown failed: %s", backend.name, exc) + + def _dispatch(self, method: str, record: Any) -> None: + for backend in self._backends: + _safe_call(backend, method, record) + + +def _safe_call(backend: SpanBackend, method: str, record: Any) -> None: + try: + getattr(backend, method)(record) + except Exception as exc: + logger.debug("span backend %s.%s failed: %s", backend.name, method, exc) + + +def _span_start_record(span: Any, attrs: dict) -> SpanStartRecord: + attrs_json = json.dumps(attrs) if attrs else "" + raw = getattr(span, "phase", None) or "" + phase = raw if raw is not None else "" + location = ( + span.location if hasattr(span, "location") and span.location is not None else "" + ) + return SpanStartRecord( + trace_id=int(span.trace_id), + span_id=int(span.span_id), + parent_id=int(span.parent_id if span.parent_id is not None else -1), + name=str(span.name), + phase=str(phase), + time_ns=int(span.start_timestamp), + thread_id=int(getattr(span, "thread_id", 0)), + location=str(location), + attributes_json=attrs_json, + ) + + +def _event_record( + span: Any, event_name: str, event_attributes: Optional[list] +) -> SpanEventRecord: + attrs_dict: Dict[str, Any] = {} + if event_attributes: + for attr_item in event_attributes: + if isinstance(attr_item, dict): + attrs_dict.update(attr_item) + elif isinstance(attr_item, (list, tuple)) and len(attr_item) == 2: + attrs_dict[attr_item[0]] = attr_item[1] + event_attrs_json = json.dumps(attrs_dict) if attrs_dict else "" + raw = getattr(span, "phase", None) or "" + phase = raw if raw is not None else "" + location = ( + span.location if hasattr(span, "location") and span.location is not None else "" + ) + return SpanEventRecord( + trace_id=int(span.trace_id), + span_id=int(span.span_id), + parent_id=int(span.parent_id if span.parent_id is not None else -1), + phase=str(phase), + location=str(location), + name=str(event_name), + time_ns=int(time.time_ns()), + thread_id=int(getattr(span, "thread_id", 0)), + event_attributes_json=event_attrs_json, + ) + + +def get_recorder(*, reset: bool = False) -> SpanRecorder: + global _recorder + if _recorder is None or reset: + _recorder = SpanRecorder(load_backends()) + return _recorder + + +def _reset_recorder() -> None: + """Drop cached recorder instance.""" + global _recorder + if _recorder is not None: + try: + _recorder.shutdown() + except Exception: + pass + _recorder = None + + +def reset(*, clear_registered: bool = False) -> None: + """Drop cached recorder; optionally clear ``register()`` factories and ``configure()`` override.""" + global _programmatic_names + if clear_registered: + _custom_factories.clear() + _programmatic_names = None + _reset_recorder() + + +def list_backends() -> List[str]: + """Return names of currently active span backends.""" + return get_recorder().backend_names + + +__all__ = [ + "SpanBackend", + "MemtableBackend", + "LoggerBackend", + "OtelBackend", + "SpanRecorder", + "register", + "configure", + "list_backends", + "reset", + "bind_table", +] diff --git a/python/probing/tracing/coordinates.py b/python/probing/tracing/coordinates.py new file mode 100644 index 00000000..fe12b6a3 --- /dev/null +++ b/python/probing/tracing/coordinates.py @@ -0,0 +1,93 @@ +"""Training step coordinates and span/table context fields.""" + +from __future__ import annotations + +from typing import Any, Optional + +from probing.tracing._bindings import ( + advance_micro_step, + set_micro_batches, + step_snapshot, + sync_micro_step, +) + +_ROW_DEFAULTS = { + "micro_step": 0, + "local_step": 0, + "global_step": 0, + "micro_batches": 1, + "rank": -1, + "world_size": -1, +} + + +class Step: + """Training step controller. + + * ``micro_step`` — finest counter; ``probing.step()`` advances by one. + * ``local_step = micro_step // micro_batches`` — per-rank training step. + * ``global_step = local_step``. + """ + + def __call__( + self, value: Optional[int] = None, *, micro_batches: Optional[int] = None + ) -> None: + if micro_batches is not None: + set_micro_batches(micro_batches) + if value is not None: + sync_micro_step(value) + return + if micro_batches is not None: + return + advance_micro_step() + + @property + def micro_step(self) -> int: + return int(step_snapshot().micro_step) + + @property + def local_step(self) -> int: + return int(step_snapshot().local_step) + + @property + def global_step(self) -> int: + return int(step_snapshot().global_step) + + def snapshot(self) -> Any: + return step_snapshot() + + +step = Step() + + +def step_fields(snapshot) -> dict: + """Step/topology fields from a snapshot.""" + if snapshot is None: + return {} + local = int(snapshot.local_step) + return { + "micro_step": int(snapshot.micro_step), + "local_step": local, + "global_step": int(snapshot.global_step), + "micro_batches": int(snapshot.micro_batches), + "rank": int(snapshot.rank), + "world_size": int(snapshot.world_size), + } + + +def row_fields(snapshot=None) -> dict: + """Step coordinates with memtable-friendly defaults.""" + snap = snapshot if snapshot is not None else step.snapshot() + fields = step_fields(snap) + return {key: fields.get(key, default) for key, default in _ROW_DEFAULTS.items()} + + +def span_attrs(user: dict, *, source: str = "manual") -> dict: + """Merge user attrs with step coordinates, topology, and source label.""" + merged = dict(user) + merged.setdefault("source", source) + merged.update(step_fields(step.snapshot())) + from probing.parallel import parallel_fields + + merged.update(parallel_fields()) + return merged diff --git a/python/probing/tracing/hooks.py b/python/probing/tracing/hooks.py new file mode 100644 index 00000000..7d1bcec7 --- /dev/null +++ b/python/probing/tracing/hooks.py @@ -0,0 +1,138 @@ +"""PyTorch model/optimizer hooks for automatic training phase spans.""" + +from __future__ import annotations + +import logging +from typing import Optional + +from probing.tracing.phases import BACKWARD, FORWARD, OPTIMIZER, hook_enter, hook_exit + +logger = logging.getLogger(__name__) + +_REGISTRY: dict[tuple[int, int], PhaseTracker] = {} + + +class PhaseTracker: + def __init__(self, model, optimizer) -> None: + self.model = model + self.optimizer = optimizer + self._handles: list = [] + + def install(self) -> None: + if self._handles: + return + m = self.model + opt = self.optimizer + self._handles = [ + m.register_forward_pre_hook(self._forward_pre), + m.register_forward_hook(self._forward_post), + m.register_full_backward_pre_hook(self._backward_pre), + m.register_full_backward_hook(self._backward_post), + opt.register_step_pre_hook(self._step_pre), + opt.register_step_post_hook(self._step_post), + ] + + def uninstall(self) -> None: + for h in self._handles: + try: + h.remove() + except Exception: + pass + self._handles.clear() + + def _forward_pre(self, module, _inputs) -> None: + if module.training: + hook_enter(FORWARD) + + def _forward_post(self, module, _inputs, _output) -> None: + if module.training: + hook_exit(FORWARD) + + def _backward_pre(self, _module, _grad_output) -> None: + hook_enter(BACKWARD) + + def _backward_post(self, _module, _inputs, _grad_output) -> None: + hook_exit(BACKWARD) + + def _step_pre(self, _optimizer, _args, _kwargs) -> None: + hook_enter(OPTIMIZER) + + def _step_post(self, _optimizer, _args, _kwargs) -> None: + hook_exit(OPTIMIZER) + + +def attach_training_phases(model, optimizer) -> PhaseTracker: + key = (id(model), id(optimizer)) + if key in _REGISTRY: + return _REGISTRY[key] + tracker = PhaseTracker(model, optimizer) + tracker.install() + _REGISTRY[key] = tracker + return tracker + + +def detach_training_phases(model, optimizer) -> None: + key = (id(model), id(optimizer)) + tracker = _REGISTRY.pop(key, None) + if tracker is not None: + tracker.uninstall() + + +def owns_training_phases(*, model=None, optimizer=None, module=None) -> bool: + """True when ``attach_training_phases`` owns iteration-level phase spans. + + * **optimizer** — same optimizer instance passed to ``attach_training_phases``. + * **model** — root model id match. + * **module** — *module* is the registered root or any of its submodules. + """ + if model is not None: + mid = id(model) + return any(k[0] == mid for k in _REGISTRY) + if optimizer is not None: + oid = id(optimizer) + return any(k[1] == oid for k in _REGISTRY) + if module is not None: + mid = id(module) + for tracker in _REGISTRY.values(): + root = tracker.model + if mid == id(root): + return True + for sub in root.modules(): + if id(sub) == mid: + return True + return False + return bool(_REGISTRY) + + +def maybe_auto_attach(optimizer) -> Optional[PhaseTracker]: + if not _phases_enabled(): + return None + for (_mid, oid), tracker in _REGISTRY.items(): + if oid == id(optimizer): + return tracker + try: + import probing + from probing.profiling.torch.module_utils import get_toplevel_module + except Exception: + return None + if not probing.is_enabled(): + return None + models = get_toplevel_module() + if not models: + return None + tracker = None + for model in models: + tracker = attach_training_phases(model, optimizer) + return tracker + + +def _phases_enabled() -> bool: + try: + import probing + + spec = probing.config.get_str("probing.torch.phases") + if spec is None or spec == "": + return True + return spec.lower() in ("1", "true", "on", "yes") + except Exception: + return True diff --git a/python/probing/tracing/phases.py b/python/probing/tracing/phases.py new file mode 100644 index 00000000..637b768f --- /dev/null +++ b/python/probing/tracing/phases.py @@ -0,0 +1,174 @@ +"""Training phase vocabulary, inference, and runtime coordination.""" + +from __future__ import annotations + +import time +from typing import Literal, Optional + +FORWARD = "forward" +BACKWARD = "backward" +OPTIMIZER = "optimizer" +IDLE = "idle" + +TrainingPhase = Literal["forward", "backward", "optimizer"] + +ALL = frozenset({FORWARD, BACKWARD, OPTIMIZER}) + +# Span name → phase when ``phase`` is omitted. +_NAME_PHASE: dict[str, str] = { + "forward": FORWARD, + "backward": BACKWARD, + "step": OPTIMIZER, + "optimizer": OPTIMIZER, +} + +# Span names that must not infer a training phase. +_NON_PHASE_NAMES = frozenset( + {"train.step", "model.init", "data.load", "checkpoint.save"} +) + +# --- Composability: who may emit a phase span (higher suppresses lower) --- +SOURCE_MANUAL = "manual" +SOURCE_PHASE_HOOK = "phase_hook" +SOURCE_TORCH_PROBE = "torch_probe" + + +def infer(name: str) -> Optional[str]: + if not name or name in _NON_PHASE_NAMES: + return None + if name in _NAME_PHASE: + return _NAME_PHASE[name] + base = name.rsplit(".", 1)[-1] + return _NAME_PHASE.get(base) + + +def infer_from_stage(stage: str) -> Optional[str]: + """Map TorchProbe hook stage label to training phase.""" + lowered = stage.lower() + for token in ("optimizer", "backward", "forward", "step"): + if token in lowered: + mapped = infer(token) + if mapped is not None: + return mapped + return None + + +def resolve(name: str, phase: Optional[str]) -> Optional[str]: + if phase is not None: + if phase not in ALL: + raise ValueError( + f"invalid training phase {phase!r}; use FORWARD, BACKWARD, or OPTIMIZER" + ) + return phase + return infer(name) + + +def resolve_span( + name: Optional[str] = None, + phase: Optional[str] = None, +) -> tuple[str, Optional[str]]: + """Return ``(span_name, training_phase)``. Requires at least one of *name* or *phase*. + + When *phase* is given and *name* is omitted, ``span_name == phase`` (canonical form). + """ + if phase is not None: + resolved = resolve(name or phase, phase) + display = name if name is not None else resolved + assert display is not None + return display, resolved + if name is not None: + return name, resolve(name, None) + raise TypeError("span() requires name and/or phase") + + +def is_training_phase(value: Optional[str]) -> bool: + return value in ALL + + +# --- Runtime coordination (hook + span stack) --- + +_hook_spans: dict[str, object] = {} +_iteration_start_ns: Optional[int] = None + + +def phase() -> str: + """Current training phase from the innermost active phase span, else ``idle``.""" + from probing.tracing._bindings import active_training_phase + + active = active_training_phase() + return active if active is not None else IDLE + + +def reset_phase() -> None: + """Reset coordinator state (tests).""" + global _iteration_start_ns + _hook_spans.clear() + _iteration_start_ns = None + + +def on_span_enter(name: str, span_phase: Optional[str], source: str) -> None: + """Reserved for span lifecycle hooks; phase is derived from the span stack.""" + del name, span_phase, source + + +def on_span_exit(name: str, span_phase: Optional[str], source: str) -> None: + """Reserved for span lifecycle hooks; phase is derived from the span stack.""" + del name, span_phase, source + + +def hook_enter(span_phase: str) -> None: + global _iteration_start_ns + if span_phase == FORWARD and _iteration_start_ns is None: + _iteration_start_ns = time.time_ns() + if _phase_tracked(span_phase, by_source=SOURCE_PHASE_HOOK): + return + if span_phase in _hook_spans: + return + import probing + + cm = probing.span(phase=span_phase, source=SOURCE_PHASE_HOOK) + cm.__enter__() + _hook_spans[span_phase] = cm + + +def hook_exit(span_phase: str) -> None: + if span_phase == OPTIMIZER: + _record_train_step() + cm = _hook_spans.pop(span_phase, None) + if cm is not None: + cm.__exit__(None, None, None) + + +def _record_train_step() -> None: + """Record one ``train.step`` closed span for the current logical iteration.""" + global _iteration_start_ns + if _iteration_start_ns is None: + return + import probing + + from probing.tracing.coordinates import step, step_fields + + snap = step.snapshot() + mb = max(int(snap.micro_batches), 1) + micro = int(snap.micro_step) + attrs = { + **step_fields(snap), + "accum_index": micro % mb, + "logical_step_pending": micro // mb, + } + duration_ns = int(time.time_ns()) - _iteration_start_ns + _iteration_start_ns = None + probing.record_span( + "train.step", + duration_ns=duration_ns, + attrs=attrs, + source=SOURCE_PHASE_HOOK, + ) + + +def _phase_tracked(span_phase: str, *, by_source: str) -> bool: + """True when an active span already carries this training phase.""" + del by_source + from probing.tracing._bindings import active_span_by_phase + + return active_span_by_phase(span_phase) is not None diff --git a/python/probing/tracing/span.py b/python/probing/tracing/span.py new file mode 100644 index 00000000..248739ee --- /dev/null +++ b/python/probing/tracing/span.py @@ -0,0 +1,317 @@ +"""Span lifecycle: open, event, record.""" + +from __future__ import annotations + +import functools +import inspect +import json +import os +import time +from typing import Callable, Optional + +from probing.tracing._bindings import ( + Span, + active_span_by_phase, + active_span_for_events, + current_span, +) +from probing.tracing.coordinates import span_attrs, step +from probing.tracing.phases import OPTIMIZER, resolve_span +from probing.tracing.table import TraceEvent + +_LOCATION_ENV = frozenset({"1", "true", "yes", "on"}) + + +class _RecordedSpan: + """Context manager: span stack + backend persistence.""" + + def __init__( + self, + name: str, + phase: Optional[str] = None, + location: Optional[str] = None, + attrs: Optional[dict] = None, + *, + source: str = "manual", + auto_location: bool = False, + ): + self.name = name + self.phase = phase + self.location = location + self.attrs = dict(attrs or {}) + self.source = source + self._auto_location = auto_location + self._span = None + self._reentrant = False + self._owns_step_advance = False + + def __enter__(self): + if self.phase == OPTIMIZER: + existing = active_span_by_phase(OPTIMIZER) + if existing is not None: + self._span = existing + self._reentrant = True + return existing + + loc = self.location + if loc is None and self._auto_location: + loc = _caller_location() + merged = span_attrs(self.attrs, source=self.source) + self._span = _create_span(self.name, self.phase, loc, merged) + self._span.__enter__() + _persist_span_start(self._span, merged) + if self.phase == OPTIMIZER: + self._owns_step_advance = True + return self._span + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._span is None or self._reentrant: + return False + result = self._span.__exit__(exc_type, exc_val, exc_tb) + _persist_span_end(self._span) + if self._owns_step_advance: + step() + return result + + +def _create_span(name: str, phase: Optional[str], location: Optional[str], attrs: dict): + parent = current_span() + if parent: + span_obj = Span.new_child(parent, name, phase=phase, location=location) + else: + span_obj = Span(name, phase=phase, location=location) + if attrs and hasattr(span_obj, "_set_initial_attrs"): + try: + span_obj._set_initial_attrs(dict(attrs)) + except Exception as e: + import warnings + + warnings.warn(f"Failed to set initial attributes: {e}") + return span_obj + + +def _caller_location() -> Optional[str]: + """Walk ``inspect.stack()`` for the first frame outside ``probing/tracing``.""" + try: + for frame_info in inspect.stack()[2:]: + path = frame_info.filename.replace("\\", "/") + if "probing/tracing" in path: + continue + return f"{frame_info.filename}:{frame_info.function}:{frame_info.lineno}" + except Exception: + pass + return None + + +def _location_enabled() -> bool: + return os.environ.get("PROBING_SPAN_LOCATION", "").lower() in _LOCATION_ENV + + +def _span_options( + kwargs: dict, +) -> tuple[str, Optional[str], str, Optional[str], dict, bool]: + phase = kwargs.pop("phase", None) + source = kwargs.pop("source", "manual") + location = kwargs.pop("location", None) + auto_location = location is None and _location_enabled() + return phase, source, location, kwargs, auto_location + + +def _make_handle( + name: str, + phase: Optional[str], + location: Optional[str], + attrs: dict, + source: str, + auto_location: bool, +): + class SpanHandle: + def __init__(self): + self.name = name + self.phase = phase + self.location = location + self.source = source + self.attrs = attrs + self._auto_location = auto_location + self._inner = None + + def __call__(self, func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*wargs, **wkwargs): + with _RecordedSpan( + self.name, + phase=self.phase, + location=self.location, + attrs=self.attrs, + source=self.source, + auto_location=self._auto_location, + ): + return func(*wargs, **wkwargs) + + return wrapper + + def __enter__(self): + self._inner = _RecordedSpan( + self.name, + phase=self.phase, + location=self.location, + attrs=self.attrs, + source=self.source, + auto_location=self._auto_location, + ) + return self._inner.__enter__() + + def __exit__(self, *exc): + if self._inner: + return self._inner.__exit__(*exc) + return False + + def __getattr__(self, attr): + if self._inner is not None: + return getattr(self._inner, attr) + raise AttributeError(attr) + + return SpanHandle() + + +def span(*args, **kwargs): + """Open a span (context manager, decorator, or manual enter/exit). + + Reserved kwargs: ``phase``, ``source``, ``location``. Training phases are + ``FORWARD``, ``BACKWARD``, ``OPTIMIZER`` (see ``probing.tracing.phases``). + + When ``phase`` is set and ``name`` is omitted, ``name`` defaults to ``phase``. + When only ``name`` is given, phase is inferred (e.g. ``"forward"`` → ``FORWARD``). + + Auto ``location`` via ``inspect.stack()`` is off by default; set + ``PROBING_SPAN_LOCATION=1`` or pass ``location=...`` explicitly. + """ + phase_kw, source, location, attrs, auto_location = _span_options(dict(kwargs)) + + if len(args) == 1 and isinstance(args[0], str): + name_kw = args[0] + name, phase = resolve_span(name_kw, phase_kw) + return _make_handle(name, phase, location, attrs, source, auto_location) + + if len(args) == 0: + if phase_kw is not None: + name, phase = resolve_span(None, phase_kw) + return _make_handle(name, phase, location, attrs, source, auto_location) + if not attrs: + + def decorator(func: Callable) -> Callable: + resolved_name, resolved_phase = resolve_span(func.__name__, None) + return _make_handle( + resolved_name, + resolved_phase, + location, + {}, + source, + auto_location, + )(func) + + return decorator + raise TypeError("span() requires name and/or phase") + + if len(args) == 1 and callable(args[0]): + func = args[0] + resolved_name, resolved_phase = resolve_span(func.__name__, phase_kw) + + @functools.wraps(func) + def wrapper(*wargs, **wkwargs): + with _RecordedSpan( + resolved_name, + phase=resolved_phase, + location=location, + attrs=attrs, + source=source, + auto_location=auto_location, + ): + return func(*wargs, **wkwargs) + + return wrapper + + if len(args) == 1: + raise TypeError( + f"span() first argument must be str or callable, got {type(args[0]).__name__}" + ) + if len(args) > 1: + raise TypeError("span() takes at most one positional argument") + + raise TypeError("span() requires at least one argument") + + +def event(name: str, *, attributes: Optional[list] = None): + """Add a point event on the active span.""" + current = active_span_for_events() or current_span() + if current is None or getattr(current, "is_ended", False): + raise RuntimeError("No active span in current context. Cannot add event.") + current.add_event(name, attributes=attributes) + + +def record_span( + name: str, + *, + phase: Optional[str] = None, + duration_ns: int, + attrs: Optional[dict] = None, + source: str = "manual", +) -> None: + """Record a completed span without entering the span stack (hot path).""" + if duration_ns < 0: + duration_ns = 0 + + TraceEvent.init_table() + merged = span_attrs(dict(attrs or {}), source=source) + end_ns = int(time.time_ns()) + start_ns = end_ns - duration_ns + resolved_name, resolved_phase = resolve_span(name, phase) + + parent = current_span() + if parent: + span_obj = Span.new_child( + parent, resolved_name, phase=resolved_phase, location="" + ) + else: + span_obj = Span(resolved_name, phase=resolved_phase, location="") + + from probing.tracing.backends import get_recorder + + get_recorder().record_closed_span( + span_obj, + name=resolved_name, + phase=resolved_phase or "", + start_ns=start_ns, + end_ns=end_ns, + attributes_json=json.dumps(merged) if merged else "", + ) + + +def _persist_span_start(span: Span, attrs: dict) -> None: + from probing.tracing.backends import get_recorder + + get_recorder().record_span_start(span, attrs) + + +def _persist_span_end(span: Span) -> None: + from probing.tracing.backends import get_recorder + + get_recorder().record_span_end(span) + + +def _persist_event( + span: Span, event_name: str, event_attributes: Optional[list] = None +) -> None: + from probing.tracing.backends import get_recorder + + get_recorder().record_event(span, event_name, event_attributes) + + +if Span: + _rust_add_event = Span.add_event + + def _add_event_persist(self, name, attributes=None): + _rust_add_event(self, name, attributes=attributes) + _persist_event(self, name, attributes) + + Span.add_event = _add_event_persist diff --git a/python/probing/tracing/table.py b/python/probing/tracing/table.py new file mode 100644 index 00000000..c54be71c --- /dev/null +++ b/python/probing/tracing/table.py @@ -0,0 +1,50 @@ +"""Trace event table schema (SQL / federation).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from probing.core.table import table + +# Materialized span rows derived from ``python.trace_event`` (start/end join). +# Use span ``time`` (ns since epoch), not the memtable ingestion ``timestamp``. +SPANS_SQL = """ +SELECT + s.trace_id, + s.span_id, + COALESCE(s.parent_id, -1) AS parent_span_id, + s.name, + s.phase, + CAST(s.time / 1000 AS BIGINT) AS start_us, + CAST(e.time / 1000 AS BIGINT) AS end_us, + CAST((e.time - s.time) / 1000 AS BIGINT) AS duration_us, + s.thread_id, + s.location, + s.attributes +FROM python.trace_event s +JOIN python.trace_event e + ON s.span_id = e.span_id AND e.record_type = 'span_end' +WHERE s.record_type = 'span_start' +""" + + +@table +@dataclass +class TraceEvent: + """Row model for trace records. + + Each saved instance is one of: span_start, span_end, event. + """ + + record_type: str + trace_id: int + span_id: int + name: str + time: int + thread_id: int = 0 + parent_id: Optional[int] = -1 + phase: Optional[str] = "" + location: Optional[str] = "" + attributes: Optional[str] = "" + event_attributes: Optional[str] = "" diff --git a/python/probing/web_assets.py b/python/probing/web_assets.py index 9c5ad7fe..c7abd9e7 100644 --- a/python/probing/web_assets.py +++ b/python/probing/web_assets.py @@ -53,6 +53,18 @@ def _resource_dir(name: str, marker: str) -> Path | None: return None +def _looks_like_built_ui(root: Path) -> bool: + """True when ``index.html`` is a Dioxus bundle, not the checkout placeholder.""" + index = root / "index.html" + if not index.is_file(): + return False + try: + body = index.read_text(encoding="utf-8", errors="ignore") + except OSError: + return False + return "web-dxh" in body or '
' in body + + def resolve_web_assets_root() -> Path | None: """Return the directory that contains ``index.html``, if any.""" override = os.environ.get(_ENV) @@ -61,8 +73,19 @@ def resolve_web_assets_root() -> Path | None: if (root / "index.html").is_file(): return root return None - if bundled := bundled_web_dir(): - return bundled + + # Editable checkout: prefer freshly built ``web/dist`` over stale ``bundled_web``. + if not _running_from_installed_wheel(): + dev = dev_web_dir() + if dev: + if _looks_like_built_ui(dev): + return dev + + bundled = bundled_web_dir() + if bundled: + if _looks_like_built_ui(bundled): + return bundled + return dev_web_dir() diff --git a/skills/semantic/tables.yaml b/skills/semantic/tables.yaml index 7ce954d6..06f4c7bc 100644 --- a/skills/semantic/tables.yaml +++ b/skills/semantic/tables.yaml @@ -1,139 +1,361 @@ # Semantic layer: maps user/agent language to probing SQL tables and columns. -# Used in system prompts and skill prerequisite checks. +# Used in system prompts, skill prerequisite checks, and Engine DESCRIBE rewrite +# (probe.probing.table_docs / column_docs). apiVersion: probing.dev/v1 kind: SemanticCatalog tables: python.torch_trace: - description: "PyTorch module-level forward/step timings and GPU memory snapshots" + description: "PyTorch 模块级 forward/step hook 耗时与 GPU 显存快照(长驻采样,非 torch.profiler 窗口)" synonyms: [torch trace, module timing, 模块耗时, 训练步 profiling] key_columns: - step: "Training step index (int)" - global_step: "Global training step (from Rust step_snapshot)" - rank: "torch.distributed rank" - role: "Parallel role key, e.g. 'dp=2,pp=1,tp=0' (align/join across ranks)" - module: "Fully-qualified module name" - stage: "One of: forward, step (post forward / post step hooks)" - duration: "Hook duration in seconds" - allocated: "GPU memory allocated (MB) at hook time" - allocated_delta: "Change in allocated since previous hook (MB)" - max_allocated: "Peak allocated (MB)" + micro_step: "最细粒度步计数(micro-batch 序号)" + local_step: "本 rank 训练步(micro_step // micro_batches)" + global_step: "全局训练步(各 rank 对齐时与 local_step 相同)" + micro_batches: "梯度累积因子" + seq: "本 step 内 hook 序号" + rank: "torch.distributed rank(-1 未知)" + world_size: "world size(-1 未知)" + role: "并行角色键,如 dp=2,pp=1,tp=0" + module: "模块全限定名" + stage: "pre forward | post forward | pre step | post step" + time_offset: "相对本 step 时间锚点的秒偏移" + duration: "hook 耗时(秒);post 行有意义" + allocated: "当前 GPU 已分配显存(MB)" + allocated_delta: "相对上一 hook 的 allocated 变化(MB)" + max_allocated: "峰值已分配显存(MB)" + max_allocated_delta: "相对上一 hook 的 max_allocated 变化(MB)" + cached: "当前 GPU reserved 显存(MB)" + max_cached: "峰值 reserved 显存(MB)" notes: - - "First complete step is discovery-only (may have no rows)" - - "Backward hooks are off by default" + - "第一个完整 step 为 discovery-only,可能无数据行" + - "默认仅 forward hook;backward hook 默认关闭" python.comm_collective: - description: "torch.distributed collective calls (all_reduce, broadcast, …)" + description: "torch.distributed 集合通信调用记录(all_reduce、broadcast 等)" synonyms: [collective, communication, NCCL, 通信, all_reduce] key_columns: - global_step: "Global training step" - local_step: "Local step on this rank" + micro_step: "最细粒度步计数" + local_step: "本 rank 训练步" + global_step: "全局训练步" + micro_batches: "梯度累积因子" rank: "torch.distributed rank" - role: "Parallel role key, e.g. 'dp=2,pp=1,tp=0' (align/join across ranks)" - op: "Collective operation name" - duration_ms: "Wall time in milliseconds" - bytes: "Tensor bytes communicated" + world_size: "world size" + role: "并行角色键" + op: "集合通信算子名" + group_rank: "进程组内 rank" + group_size: "进程组大小" + participate_ranks: "参与 rank 列表(JSON 数组字符串)" + tensor_shape: "张量 shape 字符串" + tensor_dtype: "张量 dtype" + bytes: "通信字节数" + duration_ms: "墙钟耗时(毫秒)" + async_op: "1=异步 collective,0=同步" global_name: global.python.comm_collective federation_columns: [_host, _addr, _rank, _role] python.trace_event: - description: "Span start/end and custom events (distributed tracing)" + description: "分布式 tracing:span 起止与自定义 event(python.tracing)" synonyms: [trace, span, timeline, 链路] key_columns: record_type: "span_start | span_end | event" - trace_id: "Trace id shared by related spans" - span_id: "Unique span id" - name: "Span or event name" - time: "Timestamp (nanoseconds since epoch)" + trace_id: "同一 trace 内共享的 trace id" + span_id: "span 唯一 id" + parent_id: "父 span id(-1 表示无)" + name: "span 或 event 名称" + phase: "训练阶段:forward | backward | optimizer(可为空)" + time: "时间戳(纳秒,epoch)" + thread_id: "记录线程 id" + location: "源码位置(file:line)" + attributes: "JSON 元数据(rank、local_step 等)" + event_attributes: "event 专用 JSON 属性" + notes: + - "与 span_end 按 span_id join 可算 duration" + - "物化视图见 python.tracing.table.SPANS_SQL" python.backtrace: - description: "Live mixed Python + native stack (point-in-time, not historical)" + description: "混合 Python + native 调用栈快照(即时采集,非历史表)" synonyms: [stack, backtrace, 调用栈, hang stack] key_columns: - func: "Function name" - file: "Source file" - lineno: "Line number" - depth: "Stack depth (0 = innermost)" - frame_type: "python | native" + ip: "native 帧指令指针(Python 帧为 NULL)" + file: "源文件路径" + func: "函数/符号名" + lineno: "行号" + depth: "栈深度(0=最内层)" + frame_type: "Python | Native | Rust" + notes: + - "通过 probing backtrace 或 inject 后 SELECT 填充" + + python.variables: + description: "TorchProbe 变量追踪快照(需启用 vars/watch)" + synonyms: [variable trace, watch, 变量快照] + key_columns: + micro_step: "训练 micro-step" + func: "函数名" + name: "变量名" + value: "变量值的字符串表示" + notes: + - "与 python.trace_variables 不同;后者记录调试器 watch 变更" + + python.trace_variables: + description: "调试器 watch 触发的变量变更记录" + synonyms: [trace variables, watch changes, 变量变更] + key_columns: + function_name: "发生变更的函数名" + filename: "源文件" + lineno: "行号" + variable_name: "变量名" + value: "变更后的值(字符串)" + value_type: "值类型名" cpu.utilization: - description: "Host CPU and RSS sampling (process and top threads)" - synonyms: [cpu, host memory, RSS, 主机内存] + description: "主机 CPU / RSS 周期采样(进程级 + Top-N 线程)" + synonyms: [cpu, host memory, RSS, 主机 CPU] key_columns: - ts: "Sample timestamp (microseconds)" + ts: "采样时间戳(微秒 since epoch)" scope: "process | thread" - rss_kb: "Resident set size (KB) — process scope only" - cpu_total_pct: "CPU utilization percentage" - comm: "Thread/process name" - wchan: "Kernel wait channel (Linux)" + platform: "采样后端(linux | macos)" + tid: "线程 id(process scope 时为 0)" + comm: "进程/线程名" + wall_ns: "距上次采样的墙钟间隔(纳秒)" + delta_user_ns: "用户态 CPU 增量(纳秒)" + delta_sys_ns: "内核态 CPU 增量(纳秒)" + delta_total_ns: "总 CPU 增量(纳秒)" + cpu_user_pct: "用户态 CPU 利用率(%)" + cpu_sys_pct: "内核态 CPU 利用率(%)" + cpu_total_pct: "总 CPU 利用率(%)" + cum_user_ns: "累计用户态 CPU 时间(纳秒)" + cum_sys_ns: "累计内核态 CPU 时间(纳秒)" + rss_kb: "常驻内存 RSS(KB);仅 process scope" + thread_count: "线程数;仅 process scope" + delta_vol_ctxt: "自愿上下文切换增量" + delta_invol_ctxt: "非自愿上下文切换增量" + state: "线程/进程状态(Linux)" + wchan: "内核 wait channel(Linux)" + + cpu.tasks: + description: "CPU Top-N 热点线程明细(与 cpu.utilization 同周期采样)" + synonyms: [cpu threads, top threads, 热点线程] + key_columns: + ts: "采样时间戳(微秒)" + platform: "linux | macos" + tid: "线程 id" + comm: "线程名" + state: "线程状态" + wchan: "内核 wait channel" + wall_ns: "采样间隔(纳秒)" + delta_user_ns: "用户态 CPU 增量(纳秒)" + delta_sys_ns: "内核态 CPU 增量(纳秒)" + delta_total_ns: "总 CPU 增量(纳秒)" gpu.utilization: - description: "GPU memory and utilization samples" + description: "GPU 显存与利用率周期采样" synonyms: [gpu memory, VRAM, 显存, GPU利用率] key_columns: - ts: "Sample timestamp" - used_bytes: "Device memory used" - total_bytes: "Device memory total" - mem_used_pct: "Memory used percentage" - gpu_util_pct: "GPU compute utilization (-1 if unavailable)" + ts: "采样时间戳(微秒)" + backend: "cuda | mps | …" + device_id: "设备序号" + name: "设备名称" + memory_model: "显存型号描述" + chip: "芯片/架构标识" + free_bytes: "空闲显存(字节)" + total_bytes: "总显存(字节)" + used_bytes: "已用显存(字节)" + mem_used_pct: "显存使用率(%)" + gpu_util_pct: "GPU 计算利用率(%;不可用为 -1)" + mem_controller_util_pct: "显存控制器利用率(NVIDIA)" + renderer_util_pct: "渲染引擎利用率(Apple MPS)" + tiler_util_pct: "Tiler 利用率(Apple MPS)" + driver_mem_bytes: "驱动保留显存(字节)" + wall_ns: "采样间隔(纳秒)" process.kmsg: - description: "Linux kernel ring buffer (dmesg) — OOM killer, GPU Xid, IB errors" + description: "Linux 内核 ring buffer(dmesg)— OOM、GPU Xid、IB 错误等" synonyms: [kernel log, dmesg, OOM, 内核日志] platform: linux key_columns: - timestamp: "Event time" - level: "Log level" - message: "Kernel message text" + timestamp: "内核日志时间(微秒)" + facility: "syslog facility" + level: "日志级别" + message: "内核消息正文" cluster.nodes: - description: "Registered distributed training peers" + description: "已注册的分布式训练 peer(torchrun / PUT /apis/nodes)" synonyms: [cluster, nodes, ranks, 集群节点] key_columns: - rank: "Global rank" - host: "Hostname" - addr: "probing HTTP address" - role: "Parallel role key, e.g. dp=2,pp=1,tp=0 (federation _role source)" - status: "Node status" + host: "主机名" + addr: "probing HTTP 地址" + local_rank: "节点内 local rank" + rank: "全局 rank" + world_size: "world size" + group_rank: "进程组 rank" + group_world_size: "进程组大小" + role_name: "Torchrun / Elastic role 名" + role_rank: "role 内 rank" + role_world_size: "role 内 world size" + role: "并行角色键(federation _role 来源)" + status: "节点状态" + timestamp: "最后更新时间(微秒)" + + rdma.mlx_hca: + description: "InfiniBand/RoCE HCA 端口计数器快照(/sys/class/infiniband)" + synonyms: [IB, RoCE, HCA, mlx, RDMA counters] + platform: linux + key_columns: + hca_name: "HCA 设备名(如 mlx5_0)" + port_rcv_packets: "端口接收包计数" + port_rcv_data: "端口接收字节计数" + port_xmit_packets: "端口发送包计数" + port_xmit_data: "端口发送字节计数" + link_downed: "链路 down 次数" + np_cnp_sent: "RoCE CNP 发送计数" + np_ecn_marked_roce_packets: "ECN 标记 RoCE 包计数" + rcv_pkts_rate: "接收包速率(包/秒,两次采样差分)" + snd_pkts_rate: "发送包速率(包/秒)" + notes: + - "需配置 rdma.hca.name 与 rdma.sample.rate" nccl.proxy_ops: - description: "NCCL profiler plugin proxy-op wait decomposition (culprit vs victim)" + description: "NCCL profiler plugin proxy-op wait 分解(culprit / victim 归因)" synonyms: [NCCL proxy, send_gpu_wait, recv_wait, culprit, victim, proxy wait] key_columns: - ts: "Event timestamp (nanoseconds)" + ts: "事件时间戳(纳秒)" rank: "torch.distributed rank" - tp_rank: "Tensor-parallel rank (-1 if unknown)" - pp_rank: "Pipeline-parallel rank (-1 if unknown)" - dp_rank: "Data-parallel rank (-1 if unknown)" + tp_rank: "张量并行 rank(未知 -1)" + pp_rank: "流水线并行 rank(未知 -1)" + dp_rank: "数据并行 rank(未知 -1)" comm_hash: "NCCL communicator hash" - coll_func: "Collective name (AllReduce, AllGather, …)" - seq: "Collective sequence number" + coll_func: "集合通信名(AllReduce、AllGather…)" + seq: "collective 序号" channel_id: "NCCL channel id" - peer: "Peer rank for this proxy op" - is_send: "1 if send proxy, 0 if recv" - n_steps: "ProxyStep count aggregated" - trans_bytes: "Bytes transferred" - send_gpu_wait_ns: "Culprit signal — local GPU not ready to send" - send_wait_ns: "Send-side network wait" - recv_wait_ns: "Victim signal — waiting on peer data" - recv_flush_wait_ns: "Recv flush wait" + peer: "对端 rank" + is_send: "1=send proxy,0=recv proxy" + n_steps: "聚合的 ProxyStep 数" + trans_bytes: "传输字节数" + send_gpu_wait_ns: "Culprit 信号 — 本地 GPU 未就绪发送" + send_wait_ns: "发送侧网络等待" + recv_wait_ns: "Victim 信号 — 等待对端数据" + recv_flush_wait_ns: "接收 flush 等待" global_name: global.nccl.proxy_ops federation_columns: [_host, _addr, _rank, _role] notes: - - "Requires NCCL_PROFILER_PLUGIN + NCCL ≥ 2.26 (mask 26: Coll|ProxyOp|ProxyStep)" - - "macOS dev: PROBING_NCCL_MOCK=1 seeds synthetic culprit/victim pattern" + - "需 NCCL_PROFILER_PLUGIN + NCCL ≥ 2.26(mask Coll|ProxyOp|ProxyStep)" + - "macOS 开发:PROBING_NCCL_MOCK=1 可注入合成数据" nccl.net_qp: - description: "NCCL NetPlugin IB QP completion timing (optional mask bit 128)" + description: "NCCL NetPlugin IB QP 完成耗时(可选 mask bit 128)" synonyms: [IB QP, RoCE, net plugin, qp duration] key_columns: - ts: "Event timestamp (nanoseconds)" + ts: "事件时间戳(纳秒)" rank: "torch.distributed rank" - device: "IB device index" - qp_num: "Queue pair number" - wr_id: "Work request id" + device: "IB 设备索引" + qp_num: "Queue Pair 号" + wr_id: "Work Request id" opcode: "IB opcode" - length: "Transfer length" - duration_ns: "QP completion duration" + length: "传输长度(字节)" + duration_ns: "QP 完成耗时(纳秒)" global_name: global.nccl.net_qp federation_columns: [_host, _addr, _rank, _role] + + hccl.host_ops: + description: "HCCL MSProf Host API 时间线(集合通信 op、ACL、task master/slave)" + synonyms: [HCCL host, MSProf API, host_hccl_op, task_master] + key_columns: + ts: "结束时间(CANN sys cycle)" + begin_ns: "开始时间" + end_ns: "结束时间" + duration_ns: "耗时 end - begin" + thread_id: "上报线程 id" + level: "MSProf level" + type_id: "MSProf type id" + item_id: "名称 hash" + item_name: "解码名称(hcom_allReduce_、Memcpy 等)" + event_class: "host_hccl_op | task_master | task_slave | host_acl | node_launch" + aging: "MSProf aging flag" + notes: + - "来源:libprofapi.so shim 拦截 MsprofReportApi" + - "需开启 CANN/HCCL profiling" + + hccl.collectives: + description: "HCCL 集合通信元数据与 Host 耗时(row_source 区分 api/compact 行)" + synonyms: [HCCL collective, hcom_allReduce, HcclOpInfo] + key_columns: + ts: "事件时间" + thread_id: "上报线程" + row_source: "api=耗时行;compact=count/group/alg 参数行" + begin_ns: "开始时间(api 行)" + end_ns: "结束时间(api 行)" + duration_ns: "耗时(api 行)" + op_hash: "算子名 hash(api 行)" + op_name: "算子名(api 行)" + group_hash: "comm group 名 hash(compact 行)" + alg_hash: "算法名 hash(compact 行)" + count: "元素个数(compact 行)" + data_type: "HcclDataType enum" + relay: "HCCL relay 标志" + retry: "HCCL retry 计数" + compact_type: "MsprofReportCompactInfo type id" + notes: + - "api 与 compact 行按 thread_id + 邻近 ts JOIN" + + hccl.tasks: + description: "HCCL 设备侧 task 明细(MsprofHcclInfo L1)" + synonyms: [HCCL task, RDMASend, Memcpy, Notify_Wait] + key_columns: + ts: "上报时间戳" + thread_id: "上报线程" + info_type: "AdditionalInfo type id" + info_level: "AdditionalInfo level" + info_type_name: "RegTypeInfo 注册名(如 hccl_info)" + item_id: "task 类型 hash" + task_name: "task 类型名(Memcpy、RDMASend…)" + ccl_tag: "CCL tag hash" + group_name: "comm group hash" + local_rank: "本端 rank" + remote_rank: "对端 rank(-1 表示 N/A)" + rank_size: "通信组大小" + workflow_mode: "HCCL workflow mode enum" + plane_id: "原始 plane 编码" + plane_index: "plane 索引(plane_id bits 28-31)" + rank_in_plane: "plane 内 rank(bits 0-15)" + rank_size_plane: "plane 宽度(bits 16-27)" + ctx_id: "FFTS context id" + notify_id: "notify 对象 id" + stage: "流水线 stage" + role: "task 角色 enum" + data_size: "传输字节数" + op_type: "操作类型 enum" + data_type: "数据类型 enum" + link_type: "链路类型" + transport_type: "传输类型" + rdma_type: "RDMA 类型" + duration_est_us: "HCCL 估算耗时(微秒)" + payload_len: "原始 AdditionalInfo payload 长度" + global_name: global.hccl.tasks + federation_columns: [_host, _addr, _rank, _role] + + hccl.mc2_streams: + description: "HCCL MC2 communicator stream 拓扑快照" + synonyms: [MC2, comm streams, aicpu stream] + key_columns: + ts: "上报时间" + thread_id: "上报线程" + info_type: "MSProf AdditionalInfo type" + group_hash: "comm group 名 hash" + rank_size: "组内 rank 数" + rank_id: "rank id" + usr_rank_id: "用户可见 rank id" + aicpu_kfc_stream_id: "KFC stream id" + comm_stream_size: "comm stream 数量" + comm_stream_ids: "逗号分隔的 stream id 列表" + + hccl.context_ids: + description: "HCCL FFTS context id 范围(dispatch 时上报)" + synonyms: [FFTS context, context id] + key_columns: + ts: "上报时间" + thread_id: "上报线程" + info_type: "MSProf AdditionalInfo type" + ctx_id_num: "context 数量(HCCL 固定报 2)" + ctx_id_min: "范围起点(通常 0)" + ctx_id_max: "范围终点(ctxIdMax)" diff --git a/src/lib.rs b/src/lib.rs index 9367759e..13e695b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,7 @@ use anyhow::Result; use pyo3::prelude::*; use probing_core::{install_panic_hook, register_python_main_thread}; -use probing_python::extensions::python::ExternalTable; +use probing_python::extensions::python::{register_table_docs, ExternalTable}; use probing_python::features::config; use probing_python::features::python_api::{cli_main, query_json}; use probing_python::features::tracing; @@ -148,6 +148,14 @@ fn cleanup() { } } +/// Start the in-process engine and local query server (same as normal `PROBING=1` startup). +/// +/// Used when `PROBING_CLI_MODE=1` skipped the `#[ctor]` hook so docs can be registered first. +#[pyfunction] +fn start_local() { + probing_server::start_local(); +} + /// Python module entry point - exported as probing._core #[pymodule(gil_used = true)] fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { @@ -165,6 +173,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { // Register all classes m.add_class::()?; + m.add_function(wrap_pyfunction!(register_table_docs, m)?)?; m.add_class::()?; // Register all functions @@ -182,6 +191,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { use probing_python::features::python_api::{is_enabled, should_enable_probing}; m.add_function(wrap_pyfunction!(is_enabled, m)?)?; m.add_function(wrap_pyfunction!(should_enable_probing, m)?)?; + m.add_function(wrap_pyfunction!(start_local, m)?)?; // Register config functions directly to the module (flattened) config::register_config_functions(m)?; diff --git a/tests/regression/core/test_table_docs.py b/tests/regression/core/test_table_docs.py new file mode 100644 index 00000000..98af4268 --- /dev/null +++ b/tests/regression/core/test_table_docs.py @@ -0,0 +1,83 @@ +"""Tests for code-first table documentation (@table + register_table_docs).""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import probing +from probing.core.table import ( + _column_docs_from_class, + _table_doc_from_class, +) + + +def test_table_doc_from_class_first_line(): + @dataclass + class Demo: + """First line summary. + + More details ignored. + """ + + x: int + + assert _table_doc_from_class(Demo) == "First line summary." + + +def test_table_doc_from_class_missing(): + @dataclass + class NoDoc: + x: int + + assert _table_doc_from_class(NoDoc) is None + + +def test_column_docs_from_field_metadata(): + @dataclass + class Demo: + x: int = field(metadata={"doc": "X coordinate"}) + y: int = field(metadata={"other": "ignored"}) + + assert _column_docs_from_class(Demo) == {"x": "X coordinate"} + + +def test_table_decorator_registers_docs(monkeypatch): + import importlib + + table_mod = importlib.import_module("probing.core.table") + table_mod.cache.clear() + table_name = f"decorated_doc_{id(object())}" + captured: dict = {} + + def capture_register(qualified, table_doc, column_docs): + captured["qualified"] = qualified + captured["table_doc"] = table_doc + captured["column_docs"] = column_docs or {} + return probing._core.register_table_docs(qualified, table_doc, column_docs) + + monkeypatch.setattr(probing, "register_table_docs", capture_register) + + @table_mod.table(table_name) + @dataclass + class DecoratedMetrics: + """Decorated metrics table.""" + + latency_ms: float = field(metadata={"doc": "latency milliseconds"}) + + assert captured["qualified"] == f"python.{table_name}" + assert captured["table_doc"] == "Decorated metrics table." + assert captured["column_docs"]["latency_ms"] == "latency milliseconds" + + DecoratedMetrics.drop() + table_mod.cache.clear() + + +def test_builtin_hccl_docs_in_engine_catalog(): + """HCCL code-first docs are baked into the semantic catalog at engine build.""" + df = probing.query( + "SELECT description FROM probe.probing.column_docs " + "WHERE table_schema = 'hccl' AND table_name = 'tasks' " + "AND column_name = 'task_name'" + ) + assert len(df) == 1 + assert "Memcpy" in str(df["description"].iloc[0]) diff --git a/tests/regression/core/test_table_docs_integration.py b/tests/regression/core/test_table_docs_integration.py new file mode 100644 index 00000000..641ab31e --- /dev/null +++ b/tests/regression/core/test_table_docs_integration.py @@ -0,0 +1,198 @@ +"""Integration tests for code-first table docs through the live probing engine.""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +import tempfile +from pathlib import Path + +import probing + + +def _project_root() -> Path: + return Path(__file__).resolve().parents[3] + + +def _python_path_env(*, defer_engine_init: bool = False) -> dict[str, str]: + env = os.environ.copy() + python_dir = str(_project_root() / "python") + env["PYTHONPATH"] = ( + f"{python_dir}:{env['PYTHONPATH']}" if env.get("PYTHONPATH") else python_dir + ) + env["PROBING"] = "1" + if defer_engine_init: + env["PROBING_CLI_MODE"] = "1" + else: + env.pop("PROBING_CLI_MODE", None) + return env + + +def _run_fresh_probing_script( + body: str, *, defer_engine_init: bool = False +) -> subprocess.CompletedProcess[str]: + cli_mode_line = ( + 'os.environ["PROBING_CLI_MODE"] = "1"' + if defer_engine_init + else 'os.environ.pop("PROBING_CLI_MODE", None)' + ) + script = f""" +import json +import os +import sys +import tempfile + +sys.path.insert(0, {repr(str(_project_root() / "python"))}) +os.environ["PROBING"] = "1" +{cli_mode_line} +os.environ["PROBING_DATA_DIR"] = tempfile.mkdtemp(prefix="probing_doc_it_") + +{body} +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as handle: + handle.write(script) + script_path = handle.name + try: + return subprocess.run( + [sys.executable, script_path], + env=_python_path_env(defer_engine_init=defer_engine_init), + capture_output=True, + text=True, + timeout=60, + ) + finally: + os.unlink(script_path) + + +def test_register_table_docs_visible_in_catalog_subprocess(): + """Fresh process: register_table_docs before first query lands in column_docs.""" + table = f"manual_reg_{os.getpid()}" + result = _run_fresh_probing_script( + f""" +from probing import _core +table = "{table}" +_core.register_table_docs( + f"python.{{table}}", + "Manual registration table", + {{"payload": "json payload column"}}, +) +_core.start_local() + +import probing +ext = probing.ExternalTable.get_or_create(table, ["payload"]) +ext.append(["{{}}"]) + +df = probing.query( + "SELECT description FROM probe.probing.column_docs " + f"WHERE table_schema = 'python' AND table_name = '{{table}}' " + "AND column_name = 'payload'" +) +assert len(df) == 1 +assert "json payload" in str(df["description"].iloc[0]) +tbl = probing.query( + "SELECT description FROM probe.probing.table_docs " + f"WHERE table_schema = 'python' AND table_name = '{{table}}'" +) +assert len(tbl) == 1 +assert "Manual registration" in str(tbl["description"].iloc[0]) +print("OK") +""", + defer_engine_init=True, + ) + assert result.returncode == 0, result.stdout + result.stderr + assert "OK" in result.stdout + + +def test_table_decorator_docs_visible_in_catalog_subprocess(): + """@table metadata is registered before the engine catalog is built.""" + result = _run_fresh_probing_script( + """ +from dataclasses import dataclass, field +import importlib + +table_mod = importlib.import_module("probing.core.table") + +@table_mod.table("subproc_metrics") +@dataclass +class Metrics: + \"\"\"Subprocess metrics table.\"\"\" + score: float = field(metadata={"doc": "model score"}) + rank: int = field(metadata={"doc": "process rank"}) + +Metrics.append(Metrics(0.75, 2)) + +from probing import _core +_core.start_local() + +import probing +score = probing.query( + "SELECT description FROM probe.probing.column_docs " + "WHERE table_schema = 'python' AND table_name = 'subproc_metrics' " + "AND column_name = 'score'" +) +assert len(score) == 1 +assert "model score" in str(score["description"].iloc[0]) +tbl = probing.query( + "SELECT description FROM probe.probing.table_docs " + "WHERE table_schema = 'python' AND table_name = 'subproc_metrics'" +) +assert len(tbl) == 1 +assert "Subprocess metrics" in str(tbl["description"].iloc[0]) +print("OK") +""", + defer_engine_init=True, + ) + assert result.returncode == 0, result.stdout + result.stderr + assert "OK" in result.stdout + + +def test_describe_builtin_hccl_via_query(): + df = probing.query("DESCRIBE probe.probing.column_docs") + assert "comment" in df.columns + assert "table_comment" in df.columns + assert "column_name" in df.columns + assert "description" in df["column_name"].tolist() + + +def test_hccl_catalog_and_select_roundtrip(): + """Builtin HCCL docs are queryable; YAML synonyms remain on table_docs.""" + col = probing.query( + "SELECT description FROM probe.probing.column_docs " + "WHERE table_schema = 'hccl' AND table_name = 'tasks' " + "AND column_name = 'local_rank'" + ) + assert len(col) == 1 + assert "rank" in str(col["description"].iloc[0]).lower() + + meta = probing.query( + "SELECT description, synonyms FROM probe.probing.table_docs " + "WHERE table_schema = 'nccl' AND table_name = 'proxy_ops'" + ) + assert len(meta) == 1 + assert "culprit" in str(meta["description"].iloc[0]).lower() or "NCCL" in str( + meta["description"].iloc[0] + ) + assert "proxy" in str(meta["synonyms"].iloc[0]).lower() + + +def test_describe_json_shape_subprocess(): + """Sanity-check DESCRIBE rewrite columns in an isolated process.""" + result = _run_fresh_probing_script( + """ +import json +import probing +df = probing.query("DESCRIBE probe.probing.table_docs") +payload = { + "columns": list(df.columns), + "rows": len(df), +} +print(json.dumps(payload)) +""" + ) + assert result.returncode == 0, result.stdout + result.stderr + payload = json.loads(result.stdout.strip().splitlines()[-1]) + assert payload["rows"] > 0 + assert "comment" in payload["columns"] + assert "table_comment" in payload["columns"] diff --git a/tests/regression/ext/test_comm_collective.py b/tests/regression/ext/test_comm_collective.py index eba1a7e1..0e56e943 100644 --- a/tests/regression/ext/test_comm_collective.py +++ b/tests/regression/ext/test_comm_collective.py @@ -96,5 +96,4 @@ def test_comm_lite_writes_trace_event(): ] by_type = {row["record_type"]: row for row in rows} assert by_type["span_start"]["name"] == "all_reduce" - assert by_type["span_start"]["kind"] == "comm.all_reduce" assert by_type["span_end"]["span_id"] == by_type["span_start"]["span_id"] diff --git a/tests/regression/ext/test_parallel_topology.py b/tests/regression/ext/test_parallel_topology.py index 6eced3b7..213a2210 100644 --- a/tests/regression/ext/test_parallel_topology.py +++ b/tests/regression/ext/test_parallel_topology.py @@ -58,8 +58,8 @@ def test_span_includes_parallel_fields(monkeypatch): assert attrs["dp_rank"] == 5 -def test_comm_kind(): - from probing.tracing import comm_kind +def test_comm_label(): + from probing.profiling.collective.record import _comm_label - assert comm_kind("all_reduce") == "comm.all_reduce" - assert comm_kind("comm.broadcast") == "comm.broadcast" + assert _comm_label("all_reduce") == "comm.all_reduce" + assert _comm_label("comm.broadcast") == "comm.broadcast" diff --git a/tests/regression/ext/test_phase_tracker.py b/tests/regression/ext/test_phase_tracker.py new file mode 100644 index 00000000..84618d13 --- /dev/null +++ b/tests/regression/ext/test_phase_tracker.py @@ -0,0 +1,128 @@ +"""Phase tracker: model + optimizer hooks drive training spans.""" + +from __future__ import annotations + +import dataclasses + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +import probing + + +@pytest.fixture(autouse=True) +def _reset_trace_and_phases(): + from probing.profiling import phase_tracker + from probing.tracing import TraceEvent, bind_table, reset_backends + from probing.tracing.phases import reset_phase + + for (_mid, _oid), tracker in list(phase_tracker._REGISTRY.items()): + phase_tracker.detach_training_phases(tracker.model, tracker.optimizer) + phase_tracker._REGISTRY.clear() + reset_phase() + + try: + TraceEvent.drop() + except Exception: + pass + TraceEvent.init_table() + reset_backends(clear_registered=True) + bind_table(TraceEvent) + yield + reset_backends(clear_registered=True) + + +def _trace_rows(n: int = 100) -> list[dict]: + from probing.tracing import TraceEvent + + fields = [f.name for f in dataclasses.fields(TraceEvent)] + return [dict(zip(fields, data)) for _ts, data in TraceEvent.take(n)] + + +def _closed_span_names(rows: list[dict]) -> list[str]: + starts = { + r["span_id"]: r["name"] for r in rows if r.get("record_type") == "span_start" + } + ends = {r["span_id"] for r in rows if r.get("record_type") == "span_end"} + return [starts[sid] for sid in ends if sid in starts] + + +class TinyNet(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(4, 2) + + def forward(self, x): + return self.fc(x) + + +def test_phase_hooks_emit_training_spans(): + model = TinyNet() + opt = torch.optim.SGD(model.parameters(), lr=0.01) + probing.attach_training_phases(model, opt) + + x = torch.randn(2, 4) + y = torch.tensor([0, 1]) + out = model(x) + loss = F.cross_entropy(out, y) + loss.backward() + opt.step() + + names = _closed_span_names(_trace_rows()) + assert "forward" in names + assert "backward" in names + assert "optimizer" in names + assert "train.step" in names + + +def test_phase_advances_on_optimizer_step(): + probing.step(0) + model = TinyNet() + opt = torch.optim.SGD(model.parameters(), lr=0.01) + probing.attach_training_phases(model, opt) + + x = torch.randn(2, 4) + y = torch.tensor([0, 1]) + out = model(x) + loss = F.cross_entropy(out, y) + loss.backward() + opt.step() + + assert probing.step.micro_step == 1 + + +def test_phase_idle_outside_training(): + assert probing.phase() == "idle" + + +def test_span_phase_drives_state_without_hooks(): + from probing.tracing.phases import FORWARD + + with probing.span("forward", phase=FORWARD): + assert probing.phase() == "forward" + assert probing.phase() == "idle" + + +def test_manual_forward_span_suppresses_hook_duplicate(): + model = TinyNet() + opt = torch.optim.SGD(model.parameters(), lr=0.01) + probing.attach_training_phases(model, opt) + + with probing.span("batch"): + x = torch.randn(2, 4) + y = torch.tensor([0, 1]) + with probing.span("forward"): + out = model(x) + loss = F.cross_entropy(out, y) + loss.backward() + opt.step() + + rows = _trace_rows() + forward_starts = [ + r + for r in rows + if r.get("record_type") == "span_start" and r.get("name") == "forward" + ] + assert len(forward_starts) == 1 diff --git a/tests/regression/ext/test_step_context.py b/tests/regression/ext/test_step_context.py index 7b7374ec..0a02ff2a 100644 --- a/tests/regression/ext/test_step_context.py +++ b/tests/regression/ext/test_step_context.py @@ -1,57 +1,57 @@ import pytest import probing -from probing.tracing import ( - TRAIN_STEP_KIND, - advance_local_step, - set_step_bucket_size, - step_snapshot, - sync_local_step, -) +from probing.tracing.phases import OPTIMIZER, reset_phase @pytest.fixture(autouse=True) def reset_step_context(): - sync_local_step(0) - set_step_bucket_size(1) + reset_phase() + probing.step(0) + probing.step(micro_batches=1) yield - sync_local_step(0) - set_step_bucket_size(1) + probing.step(0) + probing.step(micro_batches=1) -def test_step_snapshot_and_bucket(): - set_step_bucket_size(10) - sync_local_step(0) - snap = step_snapshot() +def test_step_snapshot_and_micro_batches(): + probing.step(0) + probing.step(micro_batches=10) + snap = probing.step.snapshot() + assert snap.micro_step == 0 assert snap.local_step == 0 assert snap.global_step == 0 - assert snap.bucket_size == 10 + assert snap.micro_batches == 10 - sync_local_step(15) - snap = step_snapshot() - assert snap.local_step == 15 - assert snap.global_step == 1 + probing.step(15) + assert probing.step.micro_step == 15 + assert probing.step.local_step == 1 + assert probing.step.global_step == 1 -def test_advance_local_step(): - sync_local_step(0) - snap = advance_local_step() - assert snap.local_step == 1 - assert snap.global_step == 1 +def test_advance_micro_step(): + probing.step(0) + probing.step() + assert probing.step.micro_step == 1 + assert probing.step.local_step == 1 + assert probing.step.global_step == 1 -def test_train_step_span_injects_coordinates(): - with probing.span("batch", kind=TRAIN_STEP_KIND) as s: - assert s.kind == TRAIN_STEP_KIND +def test_optimizer_span_injects_coordinates(): + with probing.span("step", phase=OPTIMIZER) as s: + assert s.phase == OPTIMIZER attrs = dict(s.get_attributes()) + assert attrs["micro_step"] == 0 assert attrs["local_step"] == 0 assert attrs["global_step"] == 0 assert attrs["source"] == "manual" -def test_nested_train_step_is_reentrant(): - sync_local_step(3) - with probing.span("outer", kind=TRAIN_STEP_KIND) as outer: - with probing.span("inner", kind=TRAIN_STEP_KIND) as inner: +def test_nested_optimizer_span_is_reentrant(): + probing.step(3) + with probing.span("outer", phase=OPTIMIZER) as outer: + with probing.span("inner", phase=OPTIMIZER) as inner: assert inner.span_id == outer.span_id - assert step_snapshot().local_step == 4 + assert probing.step.micro_step == 4 + assert probing.step.local_step == 4 + assert probing.step.global_step == 4 diff --git a/tests/regression/ext/test_tracing_span.py b/tests/regression/ext/test_tracing_span.py index 543275b2..ac428306 100644 --- a/tests/regression/ext/test_tracing_span.py +++ b/tests/regression/ext/test_tracing_span.py @@ -1,3 +1,4 @@ +import dataclasses import time import pytest @@ -5,6 +6,42 @@ import probing +@pytest.fixture(autouse=True) +def _reset_trace_event_table(): + """Isolate memtable rows so persistence tests are deterministic.""" + from probing.tracing import TraceEvent, bind_table, reset_backends + from probing.tracing.phases import reset_phase + + reset_phase() + + try: + TraceEvent.drop() + except Exception: + pass + TraceEvent.init_table() + reset_backends(clear_registered=True) + bind_table(TraceEvent) + yield + reset_backends(clear_registered=True) + + +def _trace_event_rows(n: int = 50) -> list[dict]: + from probing.tracing import TraceEvent + + fields = [f.name for f in dataclasses.fields(TraceEvent)] + return [dict(zip(fields, data)) for _ts, data in TraceEvent.take(n)] + + +def _span_duration_ns(rows: list[dict], span_id: int) -> int | None: + starts = {r["span_id"]: r for r in rows if r.get("record_type") == "span_start"} + ends = {r["span_id"]: r for r in rows if r.get("record_type") == "span_end"} + start = starts.get(span_id) + end = ends.get(span_id) + if start is None or end is None: + return None + return int(end["time"]) - int(start["time"]) + + def test_context_manager_basic(): with probing.span("root") as s: assert s.name == "root" @@ -56,18 +93,17 @@ def test_current_span_stack_behavior(): def test_property_immutability(): - with probing.span("immutable", kind="op") as s: + with probing.span("immutable", phase="forward") as s: original_id = s.span_id with pytest.raises(AttributeError): - s.name = "changed" # should not allow reassignment + s.name = "changed" with pytest.raises(AttributeError): - s.kind = "other" + s.phase = "other" with pytest.raises(AttributeError): s.span_id = 123 - # ensure values unchanged assert s.span_id == original_id assert s.name == "immutable" - assert s.kind == "op" + assert s.phase == "forward" def test_events_recording(): @@ -80,6 +116,70 @@ def test_events_recording(): assert events[1]["name"] == "e2" assert events[1]["attributes"]["k"] == "v" + rows = _trace_event_rows() + event_rows = [r for r in rows if r.get("record_type") == "event"] + assert len(event_rows) == 2 + assert {r["name"] for r in event_rows} == {"e1", "e2"} + + +def test_span_persists_start_end_pair_to_trace_event(): + with probing.span("persist_me", phase="forward") as s: + span_id = s.span_id + trace_id = s.trace_id + + rows = _trace_event_rows() + starts = [r for r in rows if r.get("record_type") == "span_start"] + ends = [r for r in rows if r.get("record_type") == "span_end"] + assert len(starts) == 1 + assert len(ends) == 1 + assert starts[0]["span_id"] == span_id + assert starts[0]["trace_id"] == trace_id + assert starts[0]["name"] == "persist_me" + assert starts[0]["phase"] == "forward" + assert ends[0]["span_id"] == span_id + duration_ns = _span_duration_ns(rows, span_id) + assert duration_ns is not None + assert duration_ns >= 0 + + +def test_nested_spans_persist_parent_links(): + with probing.span("parent") as parent: + with probing.span("child") as child: + child_id = child.span_id + parent_id = parent.span_id + + rows = _trace_event_rows() + child_start = next( + r + for r in rows + if r.get("record_type") == "span_start" and r.get("span_id") == child_id + ) + assert child_start["parent_id"] == parent_id + + +def test_decorator_persists_trace_event_rows(): + @probing.span("decor_persist") + def work(): + return 7 + + assert work() == 7 + rows = _trace_event_rows() + assert any( + r.get("record_type") == "span_start" and r.get("name") == "decor_persist" + for r in rows + ) + + +def test_manual_span_without_recorded_wrapper_does_not_persist(): + """Low-level ``Span`` is stack-only; integrators should use ``probing.span``.""" + from probing.tracing import Span + + parent = Span("manual_parent") + child = Span.new_child(parent, "manual_child") + child.end() + rows = _trace_event_rows() + assert rows == [] + def test_status_and_duration(): with probing.span("timed") as s: @@ -119,22 +219,10 @@ def test_manual_construction_and_child(): assert child.is_ended -def test_access_nonexistent_attribute_raises(): - with probing.span("attr") as s: - with pytest.raises(AttributeError): - _ = s.not_exist_field - - -# Ensure add_attr isn't exposed (immutability guarantee) -def test_no_add_attr_method(): - with probing.span("no_add") as s: - assert not hasattr(s, "add_attr") - assert not hasattr(s, "add_attr") - - def test_add_event_module_function(): """Test add_event module-level function.""" with probing.span("test_add_event") as s: + span_id = s.span_id probing.event("event1") probing.event("event2", attributes=[{"key": "value"}]) @@ -144,6 +232,28 @@ def test_add_event_module_function(): assert events[1]["name"] == "event2" assert events[1]["attributes"]["key"] == "value" + rows = _trace_event_rows() + persisted = [ + r + for r in rows + if r.get("record_type") == "event" and r.get("span_id") == span_id + ] + assert len(persisted) == 2 + assert {r["name"] for r in persisted} == {"event1", "event2"} + + +def test_access_nonexistent_attribute_raises(): + with probing.span("attr") as s: + with pytest.raises(AttributeError): + _ = s.not_exist_field + + +# Ensure add_attr isn't exposed (immutability guarantee) +def test_no_add_attr_method(): + with probing.span("no_add") as s: + assert not hasattr(s, "add_attr") + assert not hasattr(s, "add_attr") + def test_add_event_no_active_span(): """Test add_event raises error when no active span.""" @@ -156,32 +266,54 @@ def test_add_event_no_active_span(): probing.event("should_fail") -def test_event_inside_train_step_with_nested_spans(): - """Regression: batch train.step + nested spans + event (imagenet pattern).""" - from probing.tracing import TRAIN_STEP_KIND +def test_phase_inferred_from_name(): + from probing.tracing.phases import BACKWARD, FORWARD, OPTIMIZER + + with probing.span("forward") as s: + assert s.phase == FORWARD + with probing.span("step") as s: + assert s.phase == OPTIMIZER + with probing.span("custom.op") as s: + assert not s.phase + + +def test_explicit_phase_on_span(): + from probing.tracing.phases import BACKWARD + + with probing.span("compute", phase=BACKWARD) as s: + assert s.phase == BACKWARD + + +def test_inferred_phase_persists_to_trace_event(): + with probing.span("forward"): + pass + rows = _trace_event_rows() + start = next(r for r in rows if r.get("record_type") == "span_start") + assert start["phase"] == "forward" + + +def test_record_span_without_training_phase(): + probing.record_span("train.step", duration_ns=1_000_000) + rows = _trace_event_rows() + start = next( + r + for r in rows + if r.get("record_type") == "span_start" and r.get("name") == "train.step" + ) + assert start["phase"] in ("", None) or not start["phase"] + - with probing.span("batch", kind=TRAIN_STEP_KIND): - with probing.span("forward", kind="nn.forward"): - pass - with probing.span("loss", kind="compute"): - pass - with probing.span("backward", kind="nn.backward"): - pass - with probing.span("step", kind="optim.step"): - pass - probing.event( - "batch.stats", - attributes=[{"i": 0}, {"loss": 1.0}], - ) +def test_event_inside_training_phases_with_nested_spans(): + with probing.span("optimizer", phase="optimizer"): + probing.event("batch.stats", attributes=[{"i": 0}, {"loss": 1.0}]) -def test_train_step_reentrant_torch_probe_does_not_close_manual_span(): - """TorchProbe reentrant train.step must not end the outer span on step hook.""" +def test_optimizer_span_reentrant_with_torch_probe(): from probing.profiling.torch_probe import TorchProbe, TorchProbeConfig - from probing.tracing import TRAIN_STEP_KIND + from probing.tracing.phases import OPTIMIZER tracer = TorchProbe(config=TorchProbeConfig(enabled=True)) - with probing.span("batch", kind=TRAIN_STEP_KIND) as outer: + with probing.span("outer", phase=OPTIMIZER) as outer: tracer._begin_train_step_span() assert not outer.is_ended tracer._end_train_step_span() @@ -191,11 +323,10 @@ def test_train_step_reentrant_torch_probe_does_not_close_manual_span(): def test_post_step_hook_does_not_reset_local_step_across_batches(): - """Regression: stale curr_step must not sync_local_step back to 0/1.""" from probing.profiling.torch_probe import TorchProbe, TorchProbeConfig - from probing.tracing import TRAIN_STEP_KIND, current_local_step, sync_local_step + from probing.tracing.phases import OPTIMIZER - sync_local_step(0) + probing.step(0) tracer = TorchProbe(config=TorchProbeConfig(enabled=True)) tracer.finalized = True @@ -205,6 +336,6 @@ class FakeOpt: opt = FakeOpt() for expected in range(1, 6): - with probing.span("batch", kind=TRAIN_STEP_KIND): + with probing.span("step", phase=OPTIMIZER): tracer.post_step_hook(opt, (), {}) - assert current_local_step() == expected + assert probing.step.micro_step == expected diff --git a/tests/regression/rust/Cargo.toml b/tests/regression/rust/Cargo.toml index 48c8201f..3d13dd66 100644 --- a/tests/regression/rust/Cargo.toml +++ b/tests/regression/rust/Cargo.toml @@ -13,7 +13,7 @@ path = "src/lib.rs" pyo3-build-config = "0.29.0" [dependencies] -probing-core = { path = "../../../probing/core" } +probing-core = { path = "../../../probing/core", features = ["test-utils"] } probing-server = { path = "../../../probing/server", default-features = false } probing-memtable = { path = "../../../probing/memtable" } probing-cli = { path = "../../../probing/cli" } @@ -45,6 +45,10 @@ tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } name = "core_federation" path = "probing/core/federation_tests.rs" +[[test]] +name = "core_federation_explain" +path = "probing/core/federation_explain_tests.rs" + [[test]] name = "core_engine_complex" path = "probing/core/engine_complex_tests.rs" @@ -57,6 +61,10 @@ path = "probing/core/extension_routing_spec.rs" name = "core_distributed_storage" path = "probing/core/distributed_storage_integration.rs" +[[test]] +name = "core_table_docs_integration" +path = "probing/core/table_docs_integration.rs" + [[test]] name = "server_training_observability" path = "probing/server/training_observability_tests.rs" diff --git a/tests/regression/rust/probing/core/federation_explain_tests.rs b/tests/regression/rust/probing/core/federation_explain_tests.rs new file mode 100644 index 00000000..9e1e6ec4 --- /dev/null +++ b/tests/regression/rust/probing/core/federation_explain_tests.rs @@ -0,0 +1,222 @@ +//! Integration tests: federated routing + EXPLAIN plan shape (§4.2 path A/B/C). + +use std::sync::Arc; + +use probing_core::core::cluster::{reset_cluster_for_tests, update_node}; +use probing_core::core::federation::{ + classify_cluster_sql, classify_federated_sql, explain_federation, explain_physical_plan, + plan_federated_aggregate_pushdown, prepare_global_query, FederatedQueryPath, +}; +use probing_core::core::{Engine, ProbeDataSource}; +use probing_proto::prelude::Node; +use probing_rust_regression::test_helpers::{federation_test_lock, GenericTableProbeDataSource}; + +async fn metrics_engine(values: Vec) -> Engine { + std::env::set_var("PROBING_ADDRESS", "127.0.0.1:19999"); + std::env::set_var("HOSTNAME", "explain-coord"); + reset_cluster_for_tests(); + update_node(Node { + host: "explain-coord".into(), + addr: "127.0.0.1:19999".into(), + rank: Some(0), + ..Default::default() + }); + update_node(Node { + host: "explain-peer".into(), + addr: "127.0.0.1:20001".into(), + rank: Some(1), + ..Default::default() + }); + + let table = GenericTableProbeDataSource::single_column_table("metrics", "demo", "v", values); + Engine::builder() + .with_data_source(Arc::new(table) as Arc) + .build() + .await + .expect("engine") +} + +fn assert_plan_contains(plan: &str, needle: &str, context: &str) { + assert!( + plan.contains(needle), + "{context}: expected EXPLAIN to contain `{needle}`\n--- plan ---\n{plan}" + ); +} + +// --- Route classification (design §4.2 matrix) --- + +#[test] +fn route_matrix_aggregate_pushdown_comm_heatmap() { + let sql = "SELECT global_step, _rank, sum(duration_ms) AS comm_ms \ + FROM python.comm_collective \ + WHERE global_step >= 0 \ + GROUP BY global_step, _rank"; + assert_eq!( + classify_cluster_sql(sql), + FederatedQueryPath::AggregatePushdown + ); +} + +#[test] +fn route_matrix_federated_scan_raw_rows() { + let sql = "SELECT rank FROM python.comm_collective WHERE rank > 0 LIMIT 100"; + assert_eq!(classify_cluster_sql(sql), FederatedQueryPath::FederatedScan); +} + +#[test] +fn route_matrix_broadcast_join_compute_vs_comm() { + let sql = "SELECT c.global_step, sum(c.duration_ms) AS comm_ms \ + FROM python.comm_collective c \ + JOIN python.torch_trace t \ + ON c.global_step = t.global_step AND c.rank = t.rank \ + GROUP BY c.global_step"; + assert_eq!(classify_cluster_sql(sql), FederatedQueryPath::Broadcast); +} + +#[test] +fn route_matrix_broadcast_cte_slowdown() { + let sql = "WITH per_rank AS ( \ + SELECT global_step, _rank, max(duration_ms) AS max_ms \ + FROM python.comm_collective GROUP BY global_step, _rank \ + ) SELECT avg(max_ms) FROM per_rank"; + assert_eq!(classify_cluster_sql(sql), FederatedQueryPath::Broadcast); +} + +#[test] +fn route_matrix_local_probe_catalog() { + assert_eq!( + classify_federated_sql("SELECT v FROM probe.demo.metrics"), + FederatedQueryPath::Local + ); +} + +// --- EXPLAIN physical plan shape (path B) --- + +#[tokio::test] +async fn explain_federated_scan_exec_for_single_table_scan() { + let _lock = federation_test_lock().await; + let engine = metrics_engine(vec![1, 2, 3]).await; + + let sql = "SELECT v FROM global.demo.metrics WHERE v > 0"; + let plan = explain_physical_plan(&engine, &prepare_global_query(sql)) + .await + .expect("explain"); + + assert_plan_contains(&plan, "FederatedScanExec", "path B single-table scan"); + assert_plan_contains(&plan, "global.demo.metrics", "logical table scan"); + assert_plan_contains( + &plan, + "remote_sql=SELECT \"v\" FROM probe.demo.metrics", + "peer SQL uses probe catalog", + ); +} + +#[tokio::test] +async fn explain_federated_scan_with_peers_shows_peer_count() { + let _lock = federation_test_lock().await; + let engine = metrics_engine(vec![1, 2]).await; + + let sql = "SELECT v FROM global.demo.metrics"; + let plan = explain_physical_plan(&engine, &prepare_global_query(sql)) + .await + .expect("explain"); + + // One registered peer → FederatedScanExec: peers=1 + assert_plan_contains(&plan, "FederatedScanExec: peers=1", "peer partition count"); +} + +#[tokio::test] +async fn explain_aggregate_query_still_plans_federated_scan_underneath() { + let _lock = federation_test_lock().await; + let engine = metrics_engine(vec![1, 2, 3]).await; + + let sql = "SELECT sum(v) AS total FROM global.demo.metrics"; + let plan = explain_physical_plan(&engine, sql).await.expect("explain"); + + // DataFusion EXPLAIN: scan then aggregate locally on coordinator. + assert_plan_contains(&plan, "FederatedScanExec", "scan under aggregate"); + assert_plan_contains(&plan, "AggregateExec", "partial/global aggregate in plan"); +} + +// --- Path A pushdown plan contract (execution differs from EXPLAIN) --- + +#[test] +fn pushdown_plan_per_node_uses_probe_and_strips_tags() { + let global_sql = prepare_global_query( + "SELECT _host, sum(v) AS total FROM global.demo.metrics GROUP BY _host ORDER BY total DESC LIMIT 5", + ); + let plan = plan_federated_aggregate_pushdown(&global_sql).expect("pushdown plan"); + + assert!(plan.per_node_sql.contains("probe.demo.metrics")); + assert!(!plan.per_node_sql.contains("global.")); + assert!(!plan.per_node_sql.to_uppercase().contains("_HOST")); + assert!(!plan.per_node_sql.to_uppercase().contains("ORDER BY")); + assert!(!plan.per_node_sql.to_uppercase().contains("LIMIT")); + let tail = plan.post_merge_tail.as_deref().unwrap_or(""); + assert!(tail.contains("ORDER BY total DESC")); + assert!(tail.contains("LIMIT 5")); +} + +#[tokio::test] +async fn explain_federation_report_matches_design_paths() { + let _lock = federation_test_lock().await; + let engine = metrics_engine(vec![10, 20]).await; + + // Path A: execution = pushdown; report carries plan + EXPLAIN scan/agg shape. + let heatmap = explain_federation( + &engine, + "SELECT v, sum(v) AS s FROM global.demo.metrics GROUP BY v ORDER BY s DESC LIMIT 3", + ) + .await + .expect("explain federation"); + assert_eq!( + heatmap.execution_path, + FederatedQueryPath::AggregatePushdown + ); + assert!(heatmap.aggregate_plan.is_some()); + assert!(heatmap.physical_plan.contains("FederatedScanExec")); + assert!(heatmap.global_sql.contains("global.demo.metrics")); + + // Path B: raw scan. + let scan = explain_federation(&engine, "SELECT v FROM global.demo.metrics WHERE v > 5") + .await + .expect("scan explain"); + assert_eq!(scan.execution_path, FederatedQueryPath::FederatedScan); + assert!(scan.aggregate_plan.is_none()); + assert!(scan.physical_plan.contains("FederatedScanExec")); + + // Path C: join. + let join = explain_federation( + &engine, + "SELECT a.v FROM global.demo.metrics a JOIN global.demo.metrics b ON a.v = b.v", + ) + .await + .expect("join explain"); + assert_eq!(join.execution_path, FederatedQueryPath::Broadcast); +} + +#[tokio::test] +async fn explain_select_star_rewrite_before_plan() { + let _lock = federation_test_lock().await; + let engine = metrics_engine(vec![1]).await; + + let report = explain_federation(&engine, "SELECT * FROM global.demo.metrics") + .await + .expect("report"); + + assert!(report.global_sql.contains("EXCLUDE")); + for col in [ + "_host", + "_addr", + "_rank", + "_node_rank", + "_local_rank", + "_role", + ] { + assert!( + report.global_sql.contains(col), + "SELECT * rewrite missing {col}" + ); + } + assert_eq!(report.execution_path, FederatedQueryPath::FederatedScan); +} diff --git a/tests/regression/rust/probing/core/federation_tests.rs b/tests/regression/rust/probing/core/federation_tests.rs index dab0f659..3ba1ff95 100644 --- a/tests/regression/rust/probing/core/federation_tests.rs +++ b/tests/regression/rust/probing/core/federation_tests.rs @@ -3,12 +3,15 @@ use std::sync::Arc; +use probing_core::core::cluster::{reset_cluster_for_tests, update_node}; use probing_core::core::federation::{ - GLOBAL_CATALOG, PROBE_ADDR_COL, PROBE_HOST_COL, PROBE_RANK_COL, PROBE_ROLE_COL, + set_remote_query_hook, take_fanout_stats, FEDERATION_TAG_COLUMNS, GLOBAL_CATALOG, + PROBE_ADDR_COL, PROBE_HOST_COL, PROBE_LOCAL_RANK_COL, PROBE_NODE_RANK_COL, PROBE_RANK_COL, + PROBE_ROLE_COL, }; use probing_core::core::{Engine, ProbeDataSource}; -use probing_proto::prelude::Seq; -use probing_rust_regression::test_helpers::GenericTableProbeDataSource; +use probing_proto::prelude::{Node, Seq}; +use probing_rust_regression::test_helpers::{federation_test_lock, GenericTableProbeDataSource}; fn df_col_i32(df: &probing_proto::prelude::DataFrame, name: &str) -> Vec { let idx = df @@ -22,6 +25,19 @@ fn df_col_i32(df: &probing_proto::prelude::DataFrame, name: &str) -> Vec { } } +fn df_col_i64(df: &probing_proto::prelude::DataFrame, name: &str) -> Vec { + let idx = df + .names + .iter() + .position(|n| n == name) + .unwrap_or_else(|| panic!("column {name} missing from {:?}", df.names)); + match &df.cols[idx] { + Seq::SeqI64(v) => v.clone(), + Seq::SeqI32(v) => v.iter().map(|&x| i64::from(x)).collect(), + other => panic!("column {name} expected integer column, got {other:?}"), + } +} + #[allow(dead_code)] fn df_col_str(df: &probing_proto::prelude::DataFrame, name: &str) -> Vec { let idx = df @@ -48,8 +64,94 @@ async fn build_demo_engine() -> Engine { .expect("engine build") } +fn register_local_node(rank: i32, addr: &str, host: &str) { + update_node(Node { + host: host.into(), + addr: addr.into(), + rank: Some(rank), + group_rank: Some(rank / 8), + local_rank: Some(rank % 8), + role: Some(format!("dp={rank}")), + ..Default::default() + }); +} + +struct FederatedTestCluster { + local_engine: Engine, + #[allow(dead_code)] + peer_engine: Engine, + #[allow(dead_code)] + peer_addr: String, +} + +impl FederatedTestCluster { + async fn setup(local_values: Vec, peer_values: Vec) -> Self { + reset_cluster_for_tests(); + set_remote_query_hook(None); + + let local_addr = "127.0.0.1:19999"; + let peer_addr = "127.0.0.1:20001".to_string(); + std::env::set_var("PROBING_ADDRESS", local_addr); + std::env::set_var("HOSTNAME", "coord-host"); + + register_local_node(0, local_addr, "coord-host"); + update_node(Node { + host: "peer-host".into(), + addr: peer_addr.clone(), + rank: Some(1), + group_rank: Some(0), + local_rank: Some(1), + role: Some("dp=1".into()), + ..Default::default() + }); + + let local_table = + GenericTableProbeDataSource::single_column_table("metrics", "demo", "v", local_values); + let local_engine = Engine::builder() + .with_data_source(Arc::new(local_table) as Arc) + .build() + .await + .expect("local engine"); + + let peer_table = + GenericTableProbeDataSource::single_column_table("metrics", "demo", "v", peer_values); + let peer_engine = Engine::builder() + .with_data_source(Arc::new(peer_table) as Arc) + .build() + .await + .expect("peer engine"); + + let peer_for_hook = peer_engine.clone(); + let peer_addr_for_hook = peer_addr.clone(); + set_remote_query_hook(Some(Box::new(move |addr, sql| { + if addr != peer_addr_for_hook { + return Err(datafusion::error::DataFusionError::Execution(format!( + "unexpected peer addr: {addr}" + ))); + } + futures::executor::block_on(async { + peer_for_hook.async_query(sql).await?.ok_or_else(|| { + datafusion::error::DataFusionError::Execution("peer query returned None".into()) + }) + }) + }))); + + Self { + local_engine, + peer_engine, + peer_addr, + } + } + + fn teardown(&self) { + set_remote_query_hook(None); + reset_cluster_for_tests(); + } +} + #[tokio::test] async fn global_catalog_discovers_probe_schema() { + let _lock = federation_test_lock().await; let engine = build_demo_engine().await; let global = engine .context @@ -62,6 +164,7 @@ async fn global_catalog_discovers_probe_schema() { #[tokio::test] async fn global_catalog_discovers_tables_registered_after_build() { + let _lock = federation_test_lock().await; std::env::set_var("PROBING_ADDRESS", "127.0.0.1:19999"); std::env::set_var("HOSTNAME", "federation-test-host"); @@ -90,6 +193,7 @@ async fn global_catalog_discovers_tables_registered_after_build() { #[tokio::test] async fn probe_query_has_no_probe_addr_column() { + let _lock = federation_test_lock().await; let engine = build_demo_engine().await; let df = engine .async_query("SELECT rank FROM probe.demo.metrics ORDER BY rank") @@ -102,6 +206,7 @@ async fn probe_query_has_no_probe_addr_column() { #[tokio::test] async fn global_explicit_column_select_omits_probe_tags() { + let _lock = federation_test_lock().await; let engine = build_demo_engine().await; let df = engine .async_query("SELECT rank FROM global.demo.metrics ORDER BY rank") @@ -114,6 +219,7 @@ async fn global_explicit_column_select_omits_probe_tags() { #[tokio::test] async fn global_query_filter_pushdown_preserves_explicit_projection() { + let _lock = federation_test_lock().await; let engine = build_demo_engine().await; let df = engine .async_query("SELECT rank FROM global.demo.metrics WHERE rank = 1") @@ -126,6 +232,7 @@ async fn global_query_filter_pushdown_preserves_explicit_projection() { #[tokio::test] async fn global_and_probe_return_same_ranks_without_peers() { + let _lock = federation_test_lock().await; let engine = build_demo_engine().await; let probe_df = engine .async_query("SELECT rank FROM probe.demo.metrics ORDER BY rank") @@ -145,6 +252,7 @@ async fn global_and_probe_return_same_ranks_without_peers() { #[tokio::test] async fn global_select_name_returns_only_name() { + let _lock = federation_test_lock().await; use arrow::array::StringArray; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -181,6 +289,7 @@ async fn global_select_name_returns_only_name() { #[tokio::test] async fn global_empty_table_with_timestamp_explicit_select_preserves_schema() { + let _lock = federation_test_lock().await; use arrow::array::{StringArray, TimestampMicrosecondArray}; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; @@ -222,6 +331,7 @@ async fn global_empty_table_with_timestamp_explicit_select_preserves_schema() { #[tokio::test] async fn global_empty_table_explicit_select_preserves_schema() { + let _lock = federation_test_lock().await; std::env::set_var("PROBING_ADDRESS", "127.0.0.1:19999"); std::env::set_var("HOSTNAME", "federation-test-host"); @@ -243,6 +353,7 @@ async fn global_empty_table_explicit_select_preserves_schema() { #[tokio::test] async fn global_select_star_includes_probe_addr_and_rank() { + let _lock = federation_test_lock().await; use arrow::array::StringArray; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -282,13 +393,17 @@ async fn global_select_star_includes_probe_addr_and_rank() { PROBE_HOST_COL.to_string(), PROBE_ADDR_COL.to_string(), PROBE_RANK_COL.to_string(), + PROBE_NODE_RANK_COL.to_string(), + PROBE_LOCAL_RANK_COL.to_string(), PROBE_ROLE_COL.to_string(), ] ); + assert_eq!(df.names.len(), 2 + FEDERATION_TAG_COLUMNS.len()); } #[tokio::test] async fn explicit_probe_tags_not_duplicated() { + let _lock = federation_test_lock().await; let engine = build_demo_engine().await; let df = engine .async_query("SELECT rank, _addr, _rank FROM global.demo.metrics ORDER BY rank") @@ -327,6 +442,7 @@ fn cluster_fanout_join_uses_legacy_broadcast() { #[tokio::test] async fn global_select_star_exclude_rewrite_works() { + let _lock = federation_test_lock().await; use probing_core::core::federation::prepare_global_query; let sql = "SELECT * FROM global.process.envs"; @@ -338,6 +454,7 @@ async fn global_select_star_exclude_rewrite_works() { #[tokio::test] async fn global_select_probe_rank_only_returns_requested_column() { + let _lock = federation_test_lock().await; use arrow::array::StringArray; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -374,6 +491,7 @@ async fn global_select_probe_rank_only_returns_requested_column() { #[tokio::test] async fn global_group_by_rank_with_count_distinct() { + let _lock = federation_test_lock().await; use arrow::array::StringArray; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -410,3 +528,129 @@ async fn global_group_by_rank_with_count_distinct() { assert!(df.names.iter().any(|n| n == "_rank")); assert!(df.names.iter().any(|n| n == "n")); } + +#[tokio::test] +async fn aggregate_pushdown_merges_local_and_peer_sums() { + let _lock = federation_test_lock().await; + let cluster = FederatedTestCluster::setup(vec![1, 2, 3], vec![4, 5]).await; + + let df = cluster + .local_engine + .async_query("SELECT sum(v) AS total FROM global.demo.metrics") + .await + .expect("query") + .expect("dataframe"); + + assert_eq!(df_col_i64(&df, "total"), vec![15]); + let stats = take_fanout_stats(); + assert_eq!(stats.nodes_succeeded, 1); + assert!(stats.nodes_failed.is_empty()); + + cluster.teardown(); +} + +#[tokio::test] +async fn aggregate_pushdown_groups_by_host_with_six_tags() { + let _lock = federation_test_lock().await; + let cluster = FederatedTestCluster::setup(vec![10, 20], vec![100]).await; + + let df = cluster + .local_engine + .async_query( + "SELECT _host, sum(v) AS total FROM global.demo.metrics GROUP BY _host ORDER BY total DESC", + ) + .await + .expect("query") + .expect("dataframe"); + + assert!(df.names.iter().any(|n| n == "_host")); + assert!(df.names.iter().any(|n| n == "total")); + let addrs = df_col_str(&df, "_addr"); + assert_eq!(addrs, vec!["127.0.0.1:20001", "127.0.0.1:19999"]); + assert_eq!(df_col_i64(&df, "total"), vec![100, 30]); + + cluster.teardown(); +} + +#[tokio::test] +async fn federated_scan_concatenates_local_and_peer_rows_with_tags() { + let _lock = federation_test_lock().await; + let cluster = FederatedTestCluster::setup(vec![1, 2], vec![3]).await; + + let df = cluster + .local_engine + .async_query("SELECT v, _host FROM global.demo.metrics ORDER BY v") + .await + .expect("query") + .expect("dataframe"); + + assert_eq!(df_col_i32(&df, "v"), vec![1, 2, 3]); + assert!(df.names.iter().any(|n| n == "_host")); + let hosts = df_col_str(&df, "_host"); + assert_eq!(hosts, vec!["coord-host", "coord-host", "peer-host"]); + + let stats = take_fanout_stats(); + assert_eq!(stats.nodes_succeeded, 1); + + cluster.teardown(); +} + +#[tokio::test] +async fn federated_scan_global_limit_with_peer() { + let _lock = federation_test_lock().await; + let cluster = FederatedTestCluster::setup(vec![1, 2, 3], vec![4, 5, 6]).await; + + let df = cluster + .local_engine + .async_query("SELECT v FROM global.demo.metrics ORDER BY v LIMIT 4") + .await + .expect("query") + .expect("dataframe"); + + assert_eq!(df_col_i32(&df, "v"), vec![1, 2, 3, 4]); + + cluster.teardown(); +} + +#[tokio::test] +async fn aggregate_pushdown_order_by_limit_post_merge() { + let _lock = federation_test_lock().await; + use arrow::array::StringArray; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + reset_cluster_for_tests(); + set_remote_query_hook(None); + std::env::set_var("PROBING_ADDRESS", "127.0.0.1:19999"); + std::env::set_var("HOSTNAME", "coord-host"); + + let schema = Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("value", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["a", "a", "a", "b"])), + Arc::new(StringArray::from(vec!["1", "1", "1", "2"])), + ], + ) + .unwrap(); + let envs = GenericTableProbeDataSource::new("envs", "process", schema, vec![batch]); + let engine = Engine::builder() + .with_data_source(Arc::new(envs) as Arc) + .build() + .await + .expect("engine build"); + + let df = engine + .async_query( + "SELECT name, count(*) AS n FROM global.process.envs GROUP BY name ORDER BY n DESC LIMIT 2", + ) + .await + .expect("query") + .expect("dataframe"); + + assert_eq!(df_col_str(&df, "name"), vec!["a", "b"]); + assert_eq!(df_col_i64(&df, "n"), vec![3, 1]); +} diff --git a/tests/regression/rust/probing/core/table_docs_integration.rs b/tests/regression/rust/probing/core/table_docs_integration.rs new file mode 100644 index 00000000..8a019ae5 --- /dev/null +++ b/tests/regression/rust/probing/core/table_docs_integration.rs @@ -0,0 +1,152 @@ +//! Integration tests: Schema docs → registry → semantic catalog → Engine SQL. + +use std::sync::Arc; + +use anyhow::Result; +use probing_core::core::{Engine, UnifiedMemtableProbeDataSource}; +use probing_memtable::discover::ExposedTable; +use probing_memtable::{DType, Schema, Value}; +use probing_proto::prelude::{DataFrame, Seq}; + +fn df_col_str(df: &DataFrame, name: &str) -> Vec { + let idx = df + .names + .iter() + .position(|n| n == name) + .unwrap_or_else(|| panic!("column {name} missing from {:?}", df.names)); + match &df.cols[idx] { + Seq::SeqText(values) => values.clone(), + other => panic!("column {name}: expected SeqText, got {other:?}"), + } +} + +fn with_data_dir() -> tempfile::TempDir { + let dir = tempfile::tempdir().expect("tempdir"); + std::env::set_var("PROBING_DATA_DIR", dir.path()); + dir +} + +async fn engine_with_memtable() -> Result { + Engine::builder() + .with_data_source(Arc::new(UnifiedMemtableProbeDataSource)) + .build() + .await + .map_err(Into::into) +} + +#[tokio::test] +async fn mmap_schema_docs_visible_in_semantic_catalog() -> Result<()> { + let _dir = with_data_dir(); + let table = format!("metrics_{}", std::process::id()); + let qualified = format!("unittest.{table}"); + let schema = Schema::new() + .table_doc("Integration metrics table") + .col_doc("latency_ms", DType::F64, "wall-clock latency in ms") + .col_doc("rank", DType::I32, "torch rank"); + + { + let mut exposed = ExposedTable::create(&qualified, &schema, 4096, 4)?; + let mut writer = exposed.writer(); + writer.push_row(&[Value::F64(12.5), Value::I32(0)]); + } + + let engine = engine_with_memtable().await?; + + let col_df = engine + .async_query(format!( + "SELECT description FROM probe.probing.column_docs \ + WHERE table_schema = 'unittest' AND table_name = '{table}' \ + AND column_name = 'latency_ms'" + )) + .await? + .expect("column doc row"); + let desc = df_col_str(&col_df, "description") + .into_iter() + .next() + .unwrap_or_default(); + assert_eq!(desc, "wall-clock latency in ms"); + + let table_df = engine + .async_query(format!( + "SELECT description FROM probe.probing.table_docs \ + WHERE table_schema = 'unittest' AND table_name = '{table}'" + )) + .await? + .expect("table doc row"); + let table_desc = df_col_str(&table_df, "description") + .into_iter() + .next() + .unwrap_or_default(); + assert_eq!(table_desc, "Integration metrics table"); + Ok(()) +} + +#[tokio::test] +async fn describe_static_catalog_table_includes_comment_columns() -> Result<()> { + let engine = engine_with_memtable().await?; + let df = engine + .async_query("DESCRIBE probe.probing.table_docs") + .await? + .expect("DESCRIBE rows"); + + assert!( + df.names.iter().any(|n| n == "comment"), + "DESCRIBE rewrite missing comment: {:?}", + df.names + ); + assert!( + df.names.iter().any(|n| n == "table_comment"), + "DESCRIBE rewrite missing table_comment: {:?}", + df.names + ); + assert!( + df_col_str(&df, "column_name") + .iter() + .any(|n| n == "description"), + "expected static catalog columns" + ); + Ok(()) +} + +#[tokio::test] +async fn catalog_serves_builtin_hccl_and_yaml_synonyms() -> Result<()> { + let engine = engine_with_memtable().await?; + + let col_df = engine + .async_query( + "SELECT description FROM probe.probing.column_docs \ + WHERE table_schema = 'hccl' AND table_name = 'tasks' AND column_name = 'task_name'", + ) + .await? + .expect("hccl.tasks.task_name doc"); + let desc = df_col_str(&col_df, "description") + .into_iter() + .next() + .unwrap_or_default(); + assert!( + desc.contains("Memcpy"), + "expected code-first HCCL column doc, got {desc}" + ); + + let table_df = engine + .async_query( + "SELECT description, synonyms FROM probe.probing.table_docs \ + WHERE table_schema = 'hccl' AND table_name = 'host_ops'", + ) + .await? + .expect("hccl.host_ops table doc"); + let description = df_col_str(&table_df, "description") + .into_iter() + .next() + .unwrap_or_default(); + let synonyms = df_col_str(&table_df, "synonyms") + .into_iter() + .next() + .unwrap_or_default(); + assert!(description.contains("MSProf Host API")); + assert!( + synonyms.contains("MSProf"), + "yaml synonyms should remain available: {synonyms}" + ); + Ok(()) +} diff --git a/tests/regression/rust/src/test_helpers.rs b/tests/regression/rust/src/test_helpers.rs index 39ea8c94..bb5a17ff 100644 --- a/tests/regression/rust/src/test_helpers.rs +++ b/tests/regression/rust/src/test_helpers.rs @@ -11,6 +11,16 @@ use datafusion::logical_expr::Expr; use datafusion::physical_plan::ExecutionPlan; use probing_core::core::{ProbeDataSource, ProbeDataSourceKind}; use std::sync::Arc; +use std::sync::LazyLock; + +use tokio::sync::Mutex; + +static FEDERATION_TEST_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +/// Serialize federation integration tests that mutate global cluster state. +pub async fn federation_test_lock() -> tokio::sync::MutexGuard<'static, ()> { + FEDERATION_TEST_LOCK.lock().await +} /// Generic test table plugin implementation #[derive(Debug, Clone)] diff --git a/tests/regression/training_observability/conftest.py b/tests/regression/training_observability/conftest.py index cdae355b..631712d2 100644 --- a/tests/regression/training_observability/conftest.py +++ b/tests/regression/training_observability/conftest.py @@ -22,6 +22,8 @@ import pytest +import probing + # Same shape as probing/server/src/server/training.rs STEP_MATRIX_SQL (local window). STEP_MATRIX_SQL = """ SELECT @@ -32,7 +34,7 @@ FROM python.trace_event s JOIN python.trace_event e ON s.span_id = e.span_id AND e.record_type = 'span_end' -WHERE s.record_type = 'span_start' AND s.kind = 'train.step' +WHERE s.record_type = 'span_start' AND s.name = 'train.step' ORDER BY s.time ASC """ @@ -45,13 +47,9 @@ @pytest.fixture(autouse=True) def _reset_step_coordinates(): - from probing.tracing import set_step_bucket_size, sync_local_step - - sync_local_step(0) - set_step_bucket_size(1) + probing.step(0) yield - sync_local_step(0) - set_step_bucket_size(1) + probing.step(0) @pytest.fixture(autouse=True) @@ -80,25 +78,25 @@ def _apply( ) -> None: from unittest.mock import MagicMock - from probing.tracing import step_snapshot as real_step_snapshot + from probing.tracing.coordinates import step_snapshot as real_step_snapshot def _fake_snapshot(): base = real_step_snapshot() snap = MagicMock() - snap.local_step = local_step if local_step is not None else base.local_step - snap.global_step = base.global_step - snap.bucket_size = base.bucket_size + micro = local_step if local_step is not None else base.micro_step + batches = base.micro_batches + training = micro // max(batches, 1) + snap.micro_step = micro + snap.local_step = training + snap.global_step = training + snap.micro_batches = batches snap.rank = rank snap.world_size = world_size return snap monkeypatch.setenv("RANK", str(rank)) monkeypatch.setenv("WORLD_SIZE", str(world_size)) - for target in ( - "probing.tracing.step_snapshot", - "probing.profiling.collective.record.step_snapshot", - ): - monkeypatch.setattr(target, _fake_snapshot) + monkeypatch.setattr("probing.tracing.coordinates.step_snapshot", _fake_snapshot) return _apply @@ -159,7 +157,7 @@ def train_step_samples_from_memtable(limit: int = 500) -> list[dict[str, Any]]: starts = { e["span_id"]: e for e in events - if e.get("record_type") == "span_start" and e.get("kind") == "train.step" + if e.get("record_type") == "span_start" and e.get("name") == "train.step" } ends = {e["span_id"]: e for e in events if e.get("record_type") == "span_end"} out: list[dict[str, Any]] = [] diff --git a/tests/regression/training_observability/test_collective_recording.py b/tests/regression/training_observability/test_collective_recording.py index e3ce532e..601e176f 100644 --- a/tests/regression/training_observability/test_collective_recording.py +++ b/tests/regression/training_observability/test_collective_recording.py @@ -47,7 +47,7 @@ def test_full_mode_span_writes_row_and_trace(self): events = table_rows(TraceEvent, 10) assert len(events) >= 2 starts = [e for e in events if e["record_type"] == "span_start"] - assert any(e["kind"] == "comm.all_reduce" for e in starts) + assert any(e["name"] == "all_reduce" for e in starts) def test_lite_mode_writes_comm_row(self): record_comm_lite( @@ -76,7 +76,6 @@ def test_lite_mode_writes_closed_trace_pair_by_default(self): events = table_rows(TraceEvent, 10) by_type = {row["record_type"]: row for row in events} assert by_type["span_start"]["name"] == "all_reduce" - assert by_type["span_start"]["kind"] == "comm.all_reduce" assert by_type["span_end"]["span_id"] == by_type["span_start"]["span_id"] def test_lite_mode_can_skip_trace_event(self): diff --git a/tests/regression/training_observability/test_collective_tracer_hook.py b/tests/regression/training_observability/test_collective_tracer_hook.py index 665f05a6..cf9149c9 100644 --- a/tests/regression/training_observability/test_collective_tracer_hook.py +++ b/tests/regression/training_observability/test_collective_tracer_hook.py @@ -81,7 +81,7 @@ def fake_all_reduce(tensor, *args, **kwargs): span = current_span() seen["had_span"] = span is not None - seen["kind"] = getattr(span, "kind", None) if span else None + seen["name"] = getattr(span, "name", None) if span else None return None wrapper = tracer._trace_wrapper("all_reduce", fake_all_reduce) @@ -90,6 +90,6 @@ def fake_all_reduce(tensor, *args, **kwargs): wrapper(mock_tensor) assert seen.get("had_span") is True - assert seen.get("kind") == "comm.all_reduce" + assert seen.get("name") == "all_reduce" rows = table_rows(CommCollective, 5) assert len(rows) == 1 diff --git a/tests/regression/training_observability/test_step_straggler_sql.py b/tests/regression/training_observability/test_step_straggler_sql.py index 5f1dfc16..7aaaa2b4 100644 --- a/tests/regression/training_observability/test_step_straggler_sql.py +++ b/tests/regression/training_observability/test_step_straggler_sql.py @@ -2,7 +2,8 @@ import pytest -from probing.tracing import TRAIN_STEP_KIND, record_closed_span, sync_local_step +import probing +from probing.tracing import record_span from .conftest import train_step_samples_from_memtable @@ -13,12 +14,11 @@ def test_empty_without_train_step_spans(self): assert train_step_samples_from_memtable() == [] def test_train_step_duration_from_closed_spans(self, rank_env): - sync_local_step(42) + probing.step(42) rank_env(rank=3, world_size=8) - record_closed_span( - "batch", - kind=TRAIN_STEP_KIND, + record_span( + "train.step", duration_ns=int(150.0 * 1e6), source="test", ) @@ -30,15 +30,13 @@ def test_train_step_duration_from_closed_spans(self, rank_env): assert rows[0]["duration_ms"] == pytest.approx(150.0, rel=0.05) def test_multi_rank_straggler_simulation(self, rank_env): - """Single process simulates cross-rank matrix by varying RANK env.""" - sync_local_step(100) + probing.step(100) durations = {0: 120.0, 1: 118.0, 2: 350.0, 3: 125.0} for rank, duration_ms in durations.items(): rank_env(rank=rank, world_size=4) - record_closed_span( - "batch", - kind=TRAIN_STEP_KIND, + record_span( + "train.step", duration_ns=int(duration_ms * 1e6), source="test", ) @@ -49,10 +47,10 @@ def test_multi_rank_straggler_simulation(self, rank_env): assert by_rank[2] > by_rank[0] * 2 assert all(r["local_step"] == 100 for r in rows) - def test_ignores_non_train_step_kinds(self): - record_closed_span( + def test_ignores_non_train_step_names(self): + record_span( "forward", - kind="nn.forward", + phase="forward", duration_ns=int(50.0 * 1e6), ) diff --git a/tests/regression/training_observability/test_topology_context.py b/tests/regression/training_observability/test_topology_context.py index 6c7cd2ec..f946024b 100644 --- a/tests/regression/training_observability/test_topology_context.py +++ b/tests/regression/training_observability/test_topology_context.py @@ -4,8 +4,11 @@ import probing from probing.parallel import parallel_fields, parallel_topology -from probing.profiling.collective.record import CommCollective, record_comm_lite -from probing.tracing import comm_kind +from probing.profiling.collective.record import ( + CommCollective, + _comm_label, + record_comm_lite, +) from .conftest import table_rows @@ -60,8 +63,8 @@ def test_span_includes_topology(self, monkeypatch): assert attrs["dp_rank"] == 5 def test_comm_kind_labels(self): - assert comm_kind("all_reduce") == "comm.all_reduce" - assert comm_kind("comm.broadcast") == "comm.broadcast" + assert _comm_label("all_reduce") == "comm.all_reduce" + assert _comm_label("comm.broadcast") == "comm.broadcast" @pytest.mark.training_observability diff --git a/tests/regression/training_observability/test_training_iteration_e2e.py b/tests/regression/training_observability/test_training_iteration_e2e.py index 819132b3..c33234c4 100644 --- a/tests/regression/training_observability/test_training_iteration_e2e.py +++ b/tests/regression/training_observability/test_training_iteration_e2e.py @@ -3,37 +3,37 @@ import pytest import probing -from probing.profiling.collective.record import record_comm_lite -from probing.tracing import TRAIN_STEP_KIND, sync_local_step +from probing.profiling.collective.record import CommCollective, record_comm_lite +from probing.tracing.phases import BACKWARD, FORWARD, OPTIMIZER -from .conftest import train_step_samples_from_memtable, table_rows -from probing.profiling.collective.record import CommCollective +from .conftest import table_rows, train_step_samples_from_memtable @pytest.mark.training_observability class TestTrainingIterationPipeline: def test_single_iteration_step_and_comm(self, rank_env, parallel_env): - """Mimics: train.step → collective → memtable rows used by Training page.""" rank_env(rank=1, world_size=8) parallel_env(tp_rank=0, pp_rank=1, dp_rank=1) - sync_local_step(7) - - with probing.span("batch", kind=TRAIN_STEP_KIND): - with probing.span("forward", kind="nn.forward"): - pass - record_comm_lite( - op="all_reduce", - duration_ms=8.5, - group_rank=1, - group_size=8, - nbytes=4096, - ) - with probing.span("backward", kind="nn.backward"): - pass + probing.step(7) + + with probing.span("forward", phase=FORWARD): + pass + record_comm_lite( + op="all_reduce", + duration_ms=8.5, + group_rank=1, + group_size=8, + nbytes=4096, + ) + with probing.span("backward", phase=BACKWARD): + pass + with probing.span("optimizer", phase=OPTIMIZER): + pass + probing.record_span("train.step", duration_ns=int(120.0 * 1e6), source="test") step_rows = train_step_samples_from_memtable() assert len(step_rows) >= 1 - assert any(r["rank"] == 1 and r["local_step"] == 7 for r in step_rows) + assert any(r["rank"] == 1 for r in step_rows) comm_rows = table_rows(CommCollective, 5) assert len(comm_rows) == 1 @@ -41,18 +41,15 @@ def test_single_iteration_step_and_comm(self, rank_env, parallel_env): assert "pp=1" in comm_rows[0]["role"] assert comm_rows[0]["bytes"] == 4096 - def test_train_step_event_after_nested_spans(self): - """Regression path from imagenet_with_span (SpanAlreadyClosed).""" - with probing.span("batch", kind=TRAIN_STEP_KIND): - with probing.span("forward", kind="nn.forward"): - pass + def test_event_on_training_span(self): + with probing.span("forward", phase=FORWARD): probing.event("batch.stats", attributes=[{"loss": 1.25}]) - def test_torch_probe_reentrant_train_step(self): + def test_torch_probe_reentrant_optimizer(self): from probing.profiling.torch_probe import TorchProbe, TorchProbeConfig tracer = TorchProbe(config=TorchProbeConfig(enabled=True)) - with probing.span("batch", kind=TRAIN_STEP_KIND) as outer: + with probing.span("outer", phase=OPTIMIZER) as outer: tracer._begin_train_step_span() assert not outer.is_ended tracer._end_train_step_span() diff --git a/tests/unit/probing/test_web_assets.py b/tests/unit/probing/test_web_assets.py index 929418a7..fe9d811f 100644 --- a/tests/unit/probing/test_web_assets.py +++ b/tests/unit/probing/test_web_assets.py @@ -37,17 +37,42 @@ def test_dev_web_dir_when_frontend_built(): assert root is None -def test_configure_assets_root_prefers_bundled(monkeypatch, tmp_path: Path): +def test_configure_assets_root_prefers_dev_in_editable(monkeypatch, tmp_path: Path): bundled = tmp_path / "_web" bundled.mkdir() (bundled / "index.html").write_text("bundled", encoding="utf-8") + dev = tmp_path / "web" / "dist" + dev.mkdir(parents=True) + (dev / "index.html").write_text( + '
', + encoding="utf-8", + ) + + monkeypatch.setattr(web_assets, "bundled_web_dir", lambda: bundled) + monkeypatch.setattr(web_assets, "dev_web_dir", lambda: dev) + monkeypatch.setattr(web_assets, "_running_from_installed_wheel", lambda: False) + monkeypatch.delenv(web_assets._ENV, raising=False) + + assert web_assets.configure_assets_root() == dev + assert os.environ[web_assets._ENV] == str(dev) + + +def test_configure_assets_root_prefers_bundled_on_wheel(monkeypatch, tmp_path: Path): + bundled = tmp_path / "_web" + bundled.mkdir() + (bundled / "index.html").write_text( + '
', + encoding="utf-8", + ) + dev = tmp_path / "web" / "dist" dev.mkdir(parents=True) (dev / "index.html").write_text("dev", encoding="utf-8") monkeypatch.setattr(web_assets, "bundled_web_dir", lambda: bundled) monkeypatch.setattr(web_assets, "dev_web_dir", lambda: dev) + monkeypatch.setattr(web_assets, "_running_from_installed_wheel", lambda: True) monkeypatch.delenv(web_assets._ENV, raising=False) assert web_assets.configure_assets_root() == bundled diff --git a/tests/unit/probing/tracing/test_phase_transitions.py b/tests/unit/probing/tracing/test_phase_transitions.py new file mode 100644 index 00000000..376a52e7 --- /dev/null +++ b/tests/unit/probing/tracing/test_phase_transitions.py @@ -0,0 +1,374 @@ +"""Unit tests: span-driven and hook-driven training phase state transitions.""" + +from __future__ import annotations + +import dataclasses + +import pytest + +import probing +from probing.tracing import TraceEvent, bind_table, reset_backends +from probing.tracing.phases import ( + BACKWARD, + FORWARD, + IDLE, + OPTIMIZER, + hook_enter, + hook_exit, + phase, + reset_phase, +) + + +@pytest.fixture(autouse=True) +def _isolated_tracing(): + reset_phase() + probing.step(0) + probing.step(micro_batches=1) + try: + TraceEvent.drop() + except Exception: + pass + TraceEvent.init_table() + reset_backends(clear_registered=True) + bind_table(TraceEvent) + yield + reset_backends(clear_registered=True) + reset_phase() + + +def _trace_rows(n: int = 100) -> list[dict]: + fields = [f.name for f in dataclasses.fields(TraceEvent)] + return [dict(zip(fields, data)) for _ts, data in TraceEvent.take(n)] + + +def _closed_span_names(rows: list[dict]) -> list[str]: + starts = { + r["span_id"]: r["name"] for r in rows if r.get("record_type") == "span_start" + } + ends = {r["span_id"] for r in rows if r.get("record_type") == "span_end"} + return [starts[sid] for sid in ends if sid in starts] + + +# --- Span-driven transitions --- + + +class TestSpanDrivenPhaseTransitions: + def test_starts_idle(self): + assert phase() == IDLE + + def test_single_forward_idle_cycle(self): + with probing.span("forward", phase=FORWARD): + assert phase() == FORWARD + assert phase() == IDLE + + def test_full_training_step_sequence(self): + assert phase() == IDLE + with probing.span("forward", phase=FORWARD): + assert phase() == FORWARD + assert phase() == IDLE + + with probing.span("backward", phase=BACKWARD): + assert phase() == BACKWARD + assert phase() == IDLE + + with probing.span("step", phase=OPTIMIZER): + assert phase() == OPTIMIZER + assert phase() == IDLE + assert probing.step.micro_step == 1 + + def test_nested_training_phases_restore_parent(self): + with probing.span("forward", phase=FORWARD): + assert phase() == FORWARD + with probing.span("backward", phase=BACKWARD): + assert phase() == BACKWARD + assert phase() == FORWARD + assert phase() == IDLE + + def test_non_training_span_leaves_phase_idle(self): + with probing.span("data.load"): + assert phase() == IDLE + with probing.span("epoch"): + assert phase() == IDLE + + def test_inferred_phase_from_name(self): + with probing.span("forward"): + assert phase() == FORWARD + assert phase() == IDLE + + def test_train_step_name_does_not_change_phase(self): + probing.record_span("train.step", duration_ns=1000) + assert phase() == IDLE + + def test_outer_batch_inner_forward(self): + with probing.span("batch"): + assert phase() == IDLE + with probing.span("forward", phase=FORWARD): + assert phase() == FORWARD + assert phase() == IDLE + assert phase() == IDLE + + def test_optimizer_reentrant_does_not_double_step(self): + with probing.span("step", phase=OPTIMIZER): + assert phase() == OPTIMIZER + with probing.span("step", phase=OPTIMIZER): + assert phase() == OPTIMIZER + assert phase() == IDLE + assert probing.step.micro_step == 1 + + +# --- Hook-driven transitions --- + + +class TestHookDrivenPhaseTransitions: + def test_hook_forward_cycle(self): + hook_enter(FORWARD) + assert phase() == FORWARD + hook_exit(FORWARD) + assert phase() == IDLE + + def test_hook_full_iteration_sequence(self): + hook_enter(FORWARD) + assert phase() == FORWARD + hook_exit(FORWARD) + assert phase() == IDLE + + hook_enter(BACKWARD) + assert phase() == BACKWARD + hook_exit(BACKWARD) + assert phase() == IDLE + + hook_enter(OPTIMIZER) + assert phase() == OPTIMIZER + hook_exit(OPTIMIZER) + assert phase() == IDLE + + def test_hook_emits_phase_spans(self): + hook_enter(FORWARD) + hook_exit(FORWARD) + hook_enter(BACKWARD) + hook_exit(BACKWARD) + hook_enter(OPTIMIZER) + hook_exit(OPTIMIZER) + + names = _closed_span_names(_trace_rows()) + assert names.count("forward") == 1 + assert names.count("backward") == 1 + assert names.count("optimizer") == 1 + assert "train.step" in names + + def test_hook_optimizer_advances_step(self): + assert probing.step.micro_step == 0 + hook_enter(FORWARD) + hook_exit(FORWARD) + hook_enter(OPTIMIZER) + hook_exit(OPTIMIZER) + assert probing.step.micro_step == 1 + + def test_hook_forward_records_train_step_duration(self): + hook_enter(FORWARD) + hook_exit(FORWARD) + hook_enter(OPTIMIZER) + hook_exit(OPTIMIZER) + + rows = _trace_rows() + train_step = next( + ( + r + for r in rows + if r.get("record_type") == "span_start" + and r.get("name") == "train.step" + ), + None, + ) + assert train_step is not None + assert int(train_step.get("time", 0)) >= 0 + + def test_hook_without_forward_skips_train_step(self): + hook_enter(OPTIMIZER) + hook_exit(OPTIMIZER) + names = _closed_span_names(_trace_rows()) + assert "train.step" not in names + + def test_manual_span_suppresses_hook_span_but_keeps_phase(self): + with probing.span("forward", phase=FORWARD): + assert phase() == FORWARD + hook_enter(FORWARD) + assert phase() == FORWARD + rows_mid = _trace_rows() + forward_starts = [ + r + for r in rows_mid + if r.get("record_type") == "span_start" and r.get("name") == "forward" + ] + assert len(forward_starts) == 1 + hook_exit(FORWARD) + assert phase() == FORWARD + assert phase() == IDLE + + def test_hook_phase_spans_use_phase_hook_source(self): + hook_enter(BACKWARD) + rows = _trace_rows() + backward = next( + r + for r in rows + if r.get("record_type") == "span_start" and r.get("name") == "backward" + ) + import json + + attrs = json.loads(backward.get("attributes") or "{}") + assert attrs.get("source") == "phase_hook" + hook_exit(BACKWARD) + + +# --- Span + hook collaboration --- + + +class TestSpanHookCollaboration: + def test_manual_then_hook_optimizer_single_step_advance(self): + with probing.span("forward", phase=FORWARD): + pass + with probing.span("backward", phase=BACKWARD): + pass + with probing.span("step", phase=OPTIMIZER): + hook_enter(OPTIMIZER) + assert phase() == OPTIMIZER + hook_exit(OPTIMIZER) + assert phase() == IDLE + assert probing.step.micro_step == 1 + + def test_hook_cycle_then_manual_spans_reset_correctly(self): + hook_enter(FORWARD) + hook_exit(FORWARD) + assert phase() == IDLE + + with probing.span("backward", phase=BACKWARD): + assert phase() == BACKWARD + assert phase() == IDLE + + def test_phase_reads_span_stack_not_stale_after_hook_exit(self): + with probing.span("forward", phase=FORWARD): + hook_enter(FORWARD) + assert phase() == FORWARD + hook_exit(FORWARD) + assert phase() == FORWARD + assert phase() == IDLE + + +class TestGradientAccumulation: + def test_grad_acc_one_train_step_per_optimizer(self): + probing.step(micro_batches=4) + for micro in range(4): + hook_enter(FORWARD) + hook_exit(FORWARD) + hook_enter(BACKWARD) + hook_exit(BACKWARD) + if micro < 3: + assert phase() == IDLE + else: + hook_enter(OPTIMIZER) + hook_exit(OPTIMIZER) + assert phase() == IDLE + + names = _closed_span_names(_trace_rows()) + assert names.count("train.step") == 1 + assert probing.step.micro_step == 1 + assert probing.step.local_step == 0 + + hook_enter(FORWARD) + hook_exit(FORWARD) + hook_enter(BACKWARD) + hook_exit(BACKWARD) + hook_enter(OPTIMIZER) + hook_exit(OPTIMIZER) + + names = _closed_span_names(_trace_rows()) + assert names.count("train.step") == 2 + assert probing.step.micro_step == 2 + + def test_train_step_attrs_include_accum_index(self): + import json + + probing.step(micro_batches=2) + hook_enter(FORWARD) + hook_exit(FORWARD) + hook_enter(BACKWARD) + hook_exit(BACKWARD) + hook_enter(OPTIMIZER) + hook_exit(OPTIMIZER) + + row = next( + r + for r in _trace_rows() + if r.get("name") == "train.step" and r.get("record_type") == "span_start" + ) + attrs = json.loads(row["attributes"]) + assert attrs["micro_batches"] == 2 + assert "accum_index" in attrs + assert "logical_step_pending" in attrs + + +class TestSpanApi: + def test_span_phase_only_defaults_name(self): + with probing.span(phase=FORWARD) as s: + assert s.name == FORWARD + assert s.phase == FORWARD + assert phase() == FORWARD + + +class TestTorchProbeOwnership: + def test_torch_probe_skips_training_phase_when_hooks_attached(self): + import torch + import torch.nn as nn + + from probing.profiling.torch_probe import TorchProbe, TorchProbeConfig + from probing.tracing.hooks import attach_training_phases, detach_training_phases + + class M(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return self.w * x + + model = M() + opt = torch.optim.SGD(model.parameters(), lr=0.1) + attach_training_phases(model, opt) + try: + tracer = TorchProbe(config=TorchProbeConfig(enabled=True)) + tracer.log_module_stage("pre forward", model, force=True) + tracer.log_module_stage("post forward", model, force=True) + starts = [ + r + for r in _trace_rows() + if r.get("record_type") == "span_start" and r.get("phase") == FORWARD + ] + assert starts == [] + finally: + detach_training_phases(model, opt) + + def test_torch_probe_optimizer_span_skipped_when_hooks_attached(self): + import torch + import torch.nn as nn + + from probing.profiling.torch_probe import TorchProbe, TorchProbeConfig + from probing.tracing.hooks import attach_training_phases, detach_training_phases + + class M(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return self.w * x + + model = M() + opt = torch.optim.SGD(model.parameters(), lr=0.1) + attach_training_phases(model, opt) + try: + tracer = TorchProbe(config=TorchProbeConfig(enabled=True)) + tracer._begin_train_step_span(optimizer=opt) + assert tracer._train_step_cm is None + finally: + detach_training_phases(model, opt) diff --git a/tests/unit/probing/tracing/test_phases.py b/tests/unit/probing/tracing/test_phases.py new file mode 100644 index 00000000..f08e54a7 --- /dev/null +++ b/tests/unit/probing/tracing/test_phases.py @@ -0,0 +1,50 @@ +from probing.tracing import phases + + +def test_infer_training_names(): + assert phases.infer("forward") == phases.FORWARD + assert phases.infer("backward") == phases.BACKWARD + assert phases.infer("step") == phases.OPTIMIZER + + +def test_resolve_explicit_phase(): + assert phases.resolve("iter", phases.FORWARD) == phases.FORWARD + + +def test_infer_from_stage(): + assert phases.infer_from_stage("pre forward") == phases.FORWARD + assert phases.infer_from_stage("post backward") == phases.BACKWARD + assert phases.infer_from_stage("pre step") == phases.OPTIMIZER + assert phases.infer_from_stage("pre init") is None + + +def test_invalid_phase_raises(): + import pytest + + with pytest.raises(ValueError, match="invalid training phase"): + phases.resolve("x", "custom") + + +def test_resolve_span_phase_only(): + name, phase = phases.resolve_span(None, phases.FORWARD) + assert name == phases.FORWARD + assert phase == phases.FORWARD + + +def test_resolve_span_name_only(): + name, phase = phases.resolve_span("forward", None) + assert name == "forward" + assert phase == phases.FORWARD + + +def test_resolve_span_explicit_display_name(): + name, phase = phases.resolve_span("compute", phases.BACKWARD) + assert name == "compute" + assert phase == phases.BACKWARD + + +def test_resolve_span_requires_one(): + import pytest + + with pytest.raises(TypeError, match="requires name and/or phase"): + phases.resolve_span(None, None) diff --git a/tests/unit/probing/tracing/test_span_backends.py b/tests/unit/probing/tracing/test_span_backends.py new file mode 100644 index 00000000..d455cf9b --- /dev/null +++ b/tests/unit/probing/tracing/test_span_backends.py @@ -0,0 +1,153 @@ +"""Span multi-backend recorder tests.""" + +from __future__ import annotations + +import dataclasses + +import pytest + +import probing + + +@pytest.fixture(autouse=True) +def _isolate_trace_table(monkeypatch): + from probing.tracing import TraceEvent, bind_table, reset_backends + + monkeypatch.delenv("PROBING_SPAN_BACKENDS", raising=False) + try: + TraceEvent.drop() + except Exception: + pass + TraceEvent.init_table() + reset_backends(clear_registered=True) + bind_table(TraceEvent) + yield + reset_backends(clear_registered=True) + + +def _trace_rows(n: int = 50) -> list[dict]: + from probing.tracing import TraceEvent + + fields = [f.name for f in dataclasses.fields(TraceEvent)] + return [dict(zip(fields, data)) for _ts, data in TraceEvent.take(n)] + + +def test_default_backend_is_memtable(): + from probing.tracing import list_backends + + assert list_backends() == ["memtable"] + + +def test_custom_backend_receives_span_lifecycle(monkeypatch): + from probing.tracing import register_backend, reset_backends + + calls: list[tuple[str, object]] = [] + + class CaptureBackend: + name = "capture" + + def on_span_start(self, record): + calls.append(("start", record.name)) + + def on_span_end(self, record): + calls.append(("end", record.span_id)) + + def on_event(self, record): + calls.append(("event", record.name)) + + def shutdown(self): + calls.append(("shutdown", None)) + + register_backend("capture", lambda: CaptureBackend()) + monkeypatch.setenv("PROBING_SPAN_BACKENDS", "memtable,capture") + reset_backends() + + with probing.span("dual") as span: + span_id = span.span_id + probing.event("ping") + + assert ("start", "dual") in calls + assert ("event", "ping") in calls + assert any(c[0] == "end" and c[1] == span_id for c in calls) + + rows = _trace_rows() + assert any( + r.get("record_type") == "span_start" and r.get("name") == "dual" for r in rows + ) + + +def test_unknown_backend_falls_back_to_memtable_only(monkeypatch): + from probing.tracing import list_backends, reset_backends + + monkeypatch.setenv("PROBING_SPAN_BACKENDS", "unknown_backend") + reset_backends() + assert list_backends() == ["memtable"] + + with probing.span("still_works"): + pass + + rows = _trace_rows() + assert any(r.get("name") == "still_works" for r in rows) + + +def test_configure_overrides_env(monkeypatch): + from probing.tracing import configure_backends, list_backends + + monkeypatch.setenv("PROBING_SPAN_BACKENDS", "unknown_backend") + configure_backends(["memtable"]) + assert list_backends() == ["memtable"] + + +def test_otel_backend_skipped_without_sdk(monkeypatch): + from probing.tracing import list_backends, reset_backends + + monkeypatch.setenv("PROBING_SPAN_BACKENDS", "memtable,otel") + reset_backends() + assert list_backends() == ["memtable"] + + +def test_logger_backend_with_memtable(monkeypatch, capsys): + import logging + + from probing.tracing import list_backends, reset_backends + + log = logging.getLogger("probing.span") + log.handlers.clear() + log.propagate = True + + monkeypatch.setenv("PROBING_SPAN_BACKENDS", "memtable,logger") + reset_backends() + assert list_backends() == ["memtable", "logger"] + + with probing.span("hello", phase="forward"): + probing.event("ping", attributes=[{"x": 1}]) + + err = capsys.readouterr().err + assert "→ hello phase=forward" in err + assert "· ping" in err + assert "← hello" in err and "ms" in err + + rows = _trace_rows() + assert any( + r.get("record_type") == "span_start" and r.get("name") == "hello" for r in rows + ) + + +def test_logger_backend_only(monkeypatch, capsys): + import logging + + from probing.tracing import list_backends, reset_backends + + log = logging.getLogger("probing.span") + log.handlers.clear() + log.propagate = True + + monkeypatch.setenv("PROBING_SPAN_BACKENDS", "logger") + reset_backends() + assert list_backends() == ["logger"] + + with probing.span("terminal_only"): + pass + + assert "→ terminal_only" in capsys.readouterr().err + assert _trace_rows() == [] diff --git a/web/src/api/traces.rs b/web/src/api/traces.rs index d8998952..381d0d41 100644 --- a/web/src/api/traces.rs +++ b/web/src/api/traces.rs @@ -15,7 +15,7 @@ pub struct TraceEvent { pub name: String, pub timestamp: i64, pub thread_id: i64, - pub kind: Option, + pub phase: Option, pub location: Option, pub attributes: Option, pub event_attributes: Option, @@ -30,7 +30,7 @@ pub struct SpanInfo { pub start_timestamp: i64, pub end_timestamp: Option, pub thread_id: i64, - pub kind: Option, + pub phase: Option, pub location: Option, pub attributes: Option, pub children: Vec, @@ -54,6 +54,8 @@ impl ApiClient { String::new() }; + // Use logical event time (`time`, ns) — not memtable ingestion `timestamp` (µs). + // Matches training step_matrix / SPANS_SQL in probing.tracing. let query = format!( r#" SELECT @@ -62,14 +64,14 @@ impl ApiClient { span_id, COALESCE(parent_id, -1) as parent_id, name, - timestamp, + time AS timestamp, COALESCE(thread_id, 0) as thread_id, - kind, + phase, location, attributes, event_attributes FROM python.trace_event - ORDER BY timestamp DESC + ORDER BY time DESC {} "#, limit_clause @@ -96,7 +98,7 @@ impl ApiClient { let name_idx = df.names.iter().position(|c| c == "name").unwrap_or(4); let timestamp_idx = df.names.iter().position(|c| c == "timestamp").unwrap_or(5); let thread_id_idx = df.names.iter().position(|c| c == "thread_id").unwrap_or(6); - let kind_idx = df.names.iter().position(|c| c == "kind").unwrap_or(7); + let phase_idx = df.names.iter().position(|c| c == "phase").unwrap_or(7); let location_idx = df.names.iter().position(|c| c == "location").unwrap_or(8); let attributes_idx = df.names.iter().position(|c| c == "attributes").unwrap_or(9); let event_attributes_idx = df @@ -155,7 +157,7 @@ impl ApiClient { name: get_str(name_idx), timestamp: get_i64(timestamp_idx), thread_id: get_i64(thread_id_idx), - kind: get_opt_str(kind_idx), + phase: get_opt_str(phase_idx), location: get_opt_str(location_idx), attributes: get_opt_str(attributes_idx), event_attributes: get_opt_str(event_attributes_idx), @@ -184,7 +186,7 @@ impl ApiClient { start_timestamp: event.timestamp, end_timestamp: None, thread_id: event.thread_id, - kind: event.kind.clone(), + phase: event.phase.clone(), location: event.location.clone(), attributes: event.attributes.clone(), children: Vec::new(), @@ -308,7 +310,7 @@ impl ApiClient { let mut trace_events: Vec = Vec::new(); // Use (span_id, thread_id) as key to track span start time, supports multi-threaded scenarios - // Value contains: (start timestamp in microseconds, span name, kind, trace_id) + // Value contains: (start timestamp in microseconds, span name, phase, trace_id) let mut span_starts: SpanStartMap = std::collections::HashMap::new(); // First pass: collect all span_start events, build lookup table @@ -328,7 +330,7 @@ impl ApiClient { ( event.timestamp, event.name.clone(), - event.kind.clone(), + event.phase.clone(), event.trace_id, ), ); @@ -374,7 +376,7 @@ impl ApiClient { ( ts_micros, event.name.clone(), - event.kind.clone(), + event.phase.clone(), unified_pid, ), ); @@ -382,7 +384,7 @@ impl ApiClient { // Create 'B' (Begin) event let mut chrome_event = serde_json::json!({ "name": event.name, - "cat": event.kind.as_ref().unwrap_or(&"span".to_string()), + "cat": event.phase.as_ref().unwrap_or(&"span".to_string()), "ph": "B", "ts": ts_micros, "pid": pid, @@ -417,13 +419,13 @@ impl ApiClient { let key = (event.span_id, event.thread_id); // First try to find from already processed events - if let Some((start_ts, start_name, start_kind, start_pid)) = + if let Some((start_ts, start_name, start_phase, start_pid)) = span_starts.get(&key) { // Found matching span_start, create 'E' (End) event let mut chrome_event = serde_json::json!({ "name": start_name, - "cat": start_kind.as_ref().unwrap_or(&"span".to_string()), + "cat": start_phase.as_ref().unwrap_or(&"span".to_string()), "ph": "E", "ts": ts_micros, "pid": *start_pid as u32, @@ -439,7 +441,7 @@ impl ApiClient { trace_events.push(chrome_event); // Remove from span_starts to avoid duplicate matching span_starts.remove(&key); - } else if let Some((start_timestamp, start_name, start_kind, _)) = + } else if let Some((start_timestamp, start_name, start_phase, _)) = span_start_lookup.get(&key) { // Find span_start information from lookup table @@ -447,7 +449,7 @@ impl ApiClient { // Use unified pid to ensure all spans are in the same process let mut chrome_event = serde_json::json!({ "name": start_name, - "cat": start_kind.as_ref().unwrap_or(&"span".to_string()), + "cat": start_phase.as_ref().unwrap_or(&"span".to_string()), "ph": "E", "ts": ts_micros, "pid": unified_pid as u32, @@ -603,7 +605,7 @@ pub struct RayTimelineEntry { pub trace_id: i64, pub span_id: i64, pub parent_id: Option, - pub kind: Option, + pub phase: Option, pub thread_id: i64, pub attributes: Option, } diff --git a/web/src/components/sidebar/mod.rs b/web/src/components/sidebar/mod.rs index c7212a33..c8392345 100644 --- a/web/src/components/sidebar/mod.rs +++ b/web/src/components/sidebar/mod.rs @@ -84,7 +84,7 @@ pub fn Sidebar() -> Element { Link { to: Route::DashboardPage {}, class: "flex items-center gap-2", - img { src: "{crate::utils::base_path::with_base(\"/assets/logo.svg\")}", alt: "Probing", class: "w-7 h-7 flex-shrink-0" } + img { src: "{crate::utils::base_path::with_base(\"/logo.svg\")}", alt: "Probing", class: "w-7 h-7 flex-shrink-0" } span { class: "{brand}", "Probing" } } } diff --git a/web/src/hooks/mod.rs b/web/src/hooks/mod.rs index 674f0ba2..c817805a 100644 --- a/web/src/hooks/mod.rs +++ b/web/src/hooks/mod.rs @@ -69,7 +69,15 @@ where use_effect(move || { let mut loading = state.loading; let mut data = state.data; - let show_loading = !options.keep_previous_while_refreshing || data.read().is_none(); + + // Avoid stacking polls while a refresh is still in flight. + if options.keep_previous_while_refreshing && *loading.peek() { + return; + } + + // Peek so completing a fetch does not re-trigger this effect (infinite /query loop). + let show_loading = + !options.keep_previous_while_refreshing || data.with_peek(|d| d.is_none()); let result_future = fetch_fn(); spawn(async move { if show_loading { diff --git a/web/src/pages/traces.rs b/web/src/pages/traces.rs index 117449ab..517ce710 100644 --- a/web/src/pages/traces.rs +++ b/web/src/pages/traces.rs @@ -405,9 +405,9 @@ fn span_matches_text(span: &SpanInfo, query: &str) -> bool { } span.name.to_lowercase().contains(&q) || span - .kind + .phase .as_ref() - .is_some_and(|k| k.to_lowercase().contains(&q)) + .is_some_and(|p| p.to_lowercase().contains(&q)) || span .location .as_ref() @@ -612,14 +612,14 @@ fn SpanView( span { class: "w-4 shrink-0" } } span { class: "font-semibold text-gray-900 shrink-0", "{span.name}" } - if let Some(ref kind) = span.kind { + if let Some(ref phase) = span.phase { span { class: format!( "shrink-0 px-1.5 py-px rounded text-[10px] font-sans font-medium bg-{} text-{}", colors::CONTENT_ACCENT_BG, colors::CONTENT_ACCENT_TEXT, ), - "{kind}" + "{phase}" } } if let Some(ref location) = span.location { diff --git a/web/src/pages/training.rs b/web/src/pages/training.rs index 525b21c9..18f2aeeb 100644 --- a/web/src/pages/training.rs +++ b/web/src/pages/training.rs @@ -112,10 +112,10 @@ fn step_module_sql(coord_step: i64) -> String { fn step_span_sql(display_step: i64) -> String { format!( - "SELECT s.name, s.kind, round((e.time - s.time) / 1000000.0, 2) AS duration_ms \ + "SELECT s.name, s.phase, round((e.time - s.time) / 1000000.0, 2) AS duration_ms \ FROM python.trace_event s \ JOIN python.trace_event e ON s.span_id = e.span_id AND e.record_type = 'span_end' \ - WHERE s.record_type = 'span_start' AND s.kind != 'train.step' \ + WHERE s.record_type = 'span_start' AND s.name != 'train.step' \ AND s.attributes LIKE '%\"local_step\":{display_step}%' \ ORDER BY duration_ms DESC LIMIT 12" ) @@ -797,7 +797,7 @@ fn render_step_matrix_result( Card { title: "Step timings", EmptyState { - message: "No train.step spans yet. Wrap training loops with probing.span(..., kind='train.step') or enable TorchProbe.".to_string() + message: "No train.step spans yet. Enable phase hooks with probing.attach_training_phases(model, optimizer) or record train.step spans manually.".to_string() } } }, From 4593dfe7098666053e1b8e191383155c6f4a272e Mon Sep 17 00:00:00 2001 From: Reiase Date: Sun, 21 Jun 2026 12:04:18 +0800 Subject: [PATCH 2/5] Implement Drop trait for RealLib to ensure proper resource cleanup and remove shutdown function from public API. Update context ID info validation in msprof.rs for improved safety. --- probing/extensions/hccl-shim/src/forward.rs | 21 ++++++++++----------- probing/extensions/hccl-shim/src/lib.rs | 1 - probing/extensions/hccl-shim/src/msprof.rs | 2 +- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/probing/extensions/hccl-shim/src/forward.rs b/probing/extensions/hccl-shim/src/forward.rs index 98a495d8..21344186 100644 --- a/probing/extensions/hccl-shim/src/forward.rs +++ b/probing/extensions/hccl-shim/src/forward.rs @@ -37,6 +37,16 @@ struct RealLib { unsafe impl Send for RealLib {} +impl Drop for RealLib { + fn drop(&mut self) { + unsafe { + if !self.handle.is_null() { + libc::dlclose(self.handle); + } + } + } +} + static INIT: Lazy>> = Lazy::new(|| Mutex::new(None)); static INIT_FAILED: AtomicBool = AtomicBool::new(false); static LOGGED_INIT: AtomicBool = AtomicBool::new(false); @@ -237,14 +247,3 @@ pub fn forward_sys_cycle_time() -> u64 { } unsafe { stub_time() } } - -pub fn shutdown() { - let mut guard = INIT.lock(); - if let Some(real) = guard.take() { - unsafe { - if !real.handle.is_null() { - libc::dlclose(real.handle); - } - } - } -} diff --git a/probing/extensions/hccl-shim/src/lib.rs b/probing/extensions/hccl-shim/src/lib.rs index 53a4de83..6ce9c629 100644 --- a/probing/extensions/hccl-shim/src/lib.rs +++ b/probing/extensions/hccl-shim/src/lib.rs @@ -36,7 +36,6 @@ mod forward { pub fn forward_sys_cycle_time() -> u64 { 0 } - pub fn shutdown() {} } pub use tables::{ diff --git a/probing/extensions/hccl-shim/src/msprof.rs b/probing/extensions/hccl-shim/src/msprof.rs index 49ae7335..2bcfb574 100644 --- a/probing/extensions/hccl-shim/src/msprof.rs +++ b/probing/extensions/hccl-shim/src/msprof.rs @@ -155,7 +155,7 @@ pub fn read_hccl_op_info(data: *const u8, data_len: u32) -> Option Option { - if data.is_null() || (data_len as usize) < 8 { + if data.is_null() || (data_len as usize) < MSPROF_CONTEXT_ID_INFO { return None; } Some(unsafe { std::ptr::read_unaligned(data as *const MsprofContextIdInfo) }) From 083478671e15e341161f5bb1307a3b1b17d07431 Mon Sep 17 00:00:00 2001 From: Reiase Date: Sun, 21 Jun 2026 12:10:51 +0800 Subject: [PATCH 3/5] Enhance probing extensions with new libraries and update build configurations - Added `probing-hccl-profapi` and `probing-nccl-profiler-cdylib` as new packages in the workspace, including their dependencies in `Cargo.toml` and `Cargo.lock`. - Updated the Makefile to build the new libraries and adjusted targets for HCCL and NCCL profiling. - Refactored the `builtin-schema-docs` feature to include the new HCCL shim library. - Improved coverage reporting by excluding the new cdylib libraries from coverage checks in the Makefile and GitHub workflows. - Updated documentation and project structure to reflect the addition of new features and libraries. --- .github/workflows/test.yml | 4 +++- Cargo.lock | 21 ++++++++++++++++++ Cargo.toml | 4 +++- Makefile | 8 ++++--- probing/core/Cargo.toml | 4 ++-- probing/core/src/core/semantic_catalog.rs | 2 +- probing/extensions/hccl-profapi/Cargo.toml | 21 ++++++++++++++++++ probing/extensions/hccl-shim/Cargo.toml | 6 ++--- .../nccl-profiler-cdylib/Cargo.toml | 22 +++++++++++++++++++ probing/extensions/nccl-profiler/Cargo.toml | 3 ++- 10 files changed, 83 insertions(+), 12 deletions(-) create mode 100644 probing/extensions/hccl-profapi/Cargo.toml create mode 100644 probing/extensions/nccl-profiler-cdylib/Cargo.toml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9d9d2a66..e56266be 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -122,7 +122,9 @@ jobs: cargo llvm-cov clean --workspace # Do not source show-env here — llvm-cov nextest sets instrumentation itself. # extension-module is for the Python wheel only; omit it in lib tests (PyO3 linking). - cargo llvm-cov nextest --workspace --no-default-features + # Cdylib-only plugin wrappers share sources with rlib crates; exclude from coverage. + cargo llvm-cov nextest --workspace --no-default-features \ + --exclude probing-hccl-profapi --exclude probing-nccl-profiler-cdylib cargo llvm-cov nextest -p probing-server --no-default-features --features kmsg,gpu,gpu-cuda cargo llvm-cov report --lcov --output-path coverage.lcov diff --git a/Cargo.lock b/Cargo.lock index a21e2bf1..651b58ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3396,6 +3396,16 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "probing-hccl-profapi" +version = "0.2.5" +dependencies = [ + "libc", + "once_cell", + "parking_lot 0.12.3", + "probing-memtable", +] + [[package]] name = "probing-hccl-shim" version = "0.2.5" @@ -3438,6 +3448,17 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "probing-nccl-profiler-cdylib" +version = "0.2.5" +dependencies = [ + "libc", + "once_cell", + "parking_lot 0.12.3", + "probing-memtable", + "thiserror 2.0.12", +] + [[package]] name = "probing-proto" version = "0.2.5" diff --git a/Cargo.toml b/Cargo.toml index 17f96177..33980206 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,9 @@ members = [ "probing/extensions/python", "probing/extensions/gpu", "probing/extensions/nccl-profiler", + "probing/extensions/nccl-profiler-cdylib", "probing/extensions/hccl-shim", + "probing/extensions/hccl-profapi", "probing/server", "probing/crates/store", ] @@ -114,7 +116,7 @@ pyo3-build-config = "0.29.0" [profile.dev] debug = 1 split-debuginfo = "unpacked" -codegen-units = 256 +codegen-units = 16 [profile.release] opt-level = "z" # Optimize for size. diff --git a/Makefile b/Makefile index 95ffe384..dcfb840c 100644 --- a/Makefile +++ b/Makefile @@ -169,7 +169,7 @@ else NCCL_OUT := target/release/libprobing_nccl_profiler.so endif nccl-profiler-lib: - cargo build -p probing-nccl-profiler $(CARGO_RELEASE) + cargo build -p probing-nccl-profiler-cdylib $(CARGO_RELEASE) mkdir -p python/probing/libs cp $(NCCL_OUT) python/probing/libs/ else @@ -185,7 +185,7 @@ else HCCL_SHIM_OUT := target/release/libprofapi.so endif hccl-shim-lib: - cargo build -p probing-hccl-shim $(CARGO_RELEASE) + cargo build -p probing-hccl-profapi $(CARGO_RELEASE) mkdir -p python/probing/shim/hccl cp $(HCCL_SHIM_OUT) python/probing/shim/hccl/ else @@ -258,7 +258,9 @@ clippy-fix: coverage-rust: cargo llvm-cov clean --workspace - cargo llvm-cov nextest --workspace --no-default-features --nff --lcov --output-path coverage.lcov --ignore-filename-regex '(.*/tests?/|.*/benches?/|.*/examples?/)' || true + cargo llvm-cov nextest --workspace --no-default-features --nff \ + --exclude probing-hccl-profapi --exclude probing-nccl-profiler-cdylib \ + --lcov --output-path coverage.lcov --ignore-filename-regex '(.*/tests?/|.*/benches?/|.*/examples?/)' || true coverage-python: ${PYTEST_RUN} --cov=python/probing --cov=tests --cov-report=xml:coverage.xml --cov-report=term $(PYTEST_ARGS) || true coverage: coverage-rust coverage-python diff --git a/probing/core/Cargo.toml b/probing/core/Cargo.toml index 3f5baad0..2795ac16 100644 --- a/probing/core/Cargo.toml +++ b/probing/core/Cargo.toml @@ -25,7 +25,7 @@ similar_names = "allow" [features] test-utils = [] default = ["builtin-schema-docs"] -builtin-schema-docs = ["dep:profapi", "dep:probing-nccl-profiler"] +builtin-schema-docs = ["dep:probing-hccl-shim", "dep:probing-nccl-profiler"] [lib] crate-type = ["rlib"] @@ -34,7 +34,7 @@ crate-type = ["rlib"] probing-proto = { path = "../proto" } probing-macros = { path = "../macros" } probing-memtable = { path = "../memtable" } -profapi = { path = "../extensions/hccl-shim", optional = true, package = "probing-hccl-shim" } +probing-hccl-shim = { path = "../extensions/hccl-shim", optional = true } probing-nccl-profiler = { path = "../extensions/nccl-profiler", optional = true } anyhow = { workspace = true } diff --git a/probing/core/src/core/semantic_catalog.rs b/probing/core/src/core/semantic_catalog.rs index 83676be3..a5c4e26f 100644 --- a/probing/core/src/core/semantic_catalog.rs +++ b/probing/core/src/core/semantic_catalog.rs @@ -85,7 +85,7 @@ fn column_key(table_schema: &str, table_name: &str, column_name: &str) -> (Strin pub fn register_builtin_schema_docs() { #[cfg(feature = "builtin-schema-docs")] { - profapi::register_docs(); + probing_hccl_shim::register_docs(); probing_nccl_profiler::register_docs(); } } diff --git a/probing/extensions/hccl-profapi/Cargo.toml b/probing/extensions/hccl-profapi/Cargo.toml new file mode 100644 index 00000000..6bc54c40 --- /dev/null +++ b/probing/extensions/hccl-profapi/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "probing-hccl-profapi" +description = "libprofapi.so cdylib — same sources as probing-hccl-shim, Linux HCCL dlopen name" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lib] +name = "profapi" +crate-type = ["cdylib"] +path = "../hccl-shim/src/lib.rs" + +[dependencies] +probing-memtable = { path = "../../memtable" } + +once_cell = { workspace = true } +parking_lot = "0.12" + +[target.'cfg(target_os = "linux")'.dependencies] +libc = "0.2" diff --git a/probing/extensions/hccl-shim/Cargo.toml b/probing/extensions/hccl-shim/Cargo.toml index f5d36372..a56197d8 100644 --- a/probing/extensions/hccl-shim/Cargo.toml +++ b/probing/extensions/hccl-shim/Cargo.toml @@ -7,9 +7,9 @@ edition.workspace = true license.workspace = true [lib] -# Produces libprofapi.so on Linux (HCCL dlopen name). -name = "profapi" -crate-type = ["cdylib", "rlib"] +# rlib for workspace linking (probing-core schema docs, unit tests). +name = "probing_hccl_shim" +crate-type = ["rlib"] [dependencies] probing-memtable = { path = "../../memtable" } diff --git a/probing/extensions/nccl-profiler-cdylib/Cargo.toml b/probing/extensions/nccl-profiler-cdylib/Cargo.toml new file mode 100644 index 00000000..c3cbc4d1 --- /dev/null +++ b/probing/extensions/nccl-profiler-cdylib/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "probing-nccl-profiler-cdylib" +description = "NCCL profiler plugin cdylib — same sources as probing-nccl-profiler" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[lib] +name = "probing_nccl_profiler" +crate-type = ["cdylib"] +path = "../nccl-profiler/src/lib.rs" + +[dependencies] +probing-memtable = { path = "../../memtable" } + +once_cell = { workspace = true } +parking_lot = "0.12" +thiserror = { workspace = true } + +[target.'cfg(target_os = "linux")'.dependencies] +libc = "0.2" diff --git a/probing/extensions/nccl-profiler/Cargo.toml b/probing/extensions/nccl-profiler/Cargo.toml index 17298406..414514a1 100644 --- a/probing/extensions/nccl-profiler/Cargo.toml +++ b/probing/extensions/nccl-profiler/Cargo.toml @@ -8,7 +8,8 @@ license.workspace = true [lib] -crate-type = ["cdylib", "rlib"] +# rlib for workspace linking; cdylib plugin is probing-nccl-profiler-cdylib. +crate-type = ["rlib"] [dependencies] probing-memtable = { path = "../../memtable" } From 0e947bbad540bad5ecd925fc5874cc26fb126598 Mon Sep 17 00:00:00 2001 From: Reiase Date: Sun, 21 Jun 2026 16:45:30 +0800 Subject: [PATCH 4/5] Update SQL queries in documentation to use 'local_step' for consistency - Changed SQL queries in the quickstart and reference documentation to replace 'step' with 'local_step' for clarity and consistency across examples. - Updated related documentation in both English and Chinese to reflect these changes. - Adjusted descriptions in the concepts and SQL tables documentation to align with the new terminology. --- docs/src/design/modularity.md | 2 +- docs/src/guide/concepts.md | 2 +- docs/src/quickstart.md | 2 +- docs/src/quickstart.zh.md | 2 +- docs/src/reference/sql-tables.md | 6 +++--- docs/src/reference/sql-tables.zh.md | 6 +++--- .../core/test_table_docs_integration.py | 15 ++++----------- 7 files changed, 14 insertions(+), 21 deletions(-) diff --git a/docs/src/design/modularity.md b/docs/src/design/modularity.md index bed4bb91..66ea6c69 100644 --- a/docs/src/design/modularity.md +++ b/docs/src/design/modularity.md @@ -376,7 +376,7 @@ Track and fix incrementally: | Server → python REPL internals | ~~`PythonRepl` in server~~ | `/ws` uses `ReplSession` facade only | | Composition sprawl | All wiring in `server/engine.rs` | Optional: manifest TOML listing enabled extensions | | Skills triple loader | Rust + Python + Web embed `skills/` | Keep `skills/` SSOT; loaders versioned together in CI | -| kmsg collector | Implemented, not registered | Register in engine or delete | +| kmsg collector | Registered (Linux/kmsg feature gate) | Done | | Architecture doc | 2-layer diagram | Superseded by this doc + [Data Layer](data-layer.md) | --- diff --git a/docs/src/guide/concepts.md b/docs/src/guide/concepts.md index 12d41ce7..2ebc7489 100644 --- a/docs/src/guide/concepts.md +++ b/docs/src/guide/concepts.md @@ -72,7 +72,7 @@ single source of truth (not a separate Python counter). On data rows: -- `python.torch_trace.step` → local step +- `python.torch_trace.local_step` → per-rank step - `python.torch_trace.global_step`, `python.comm_collective.local_step` / `global_step` In-process: diff --git a/docs/src/quickstart.md b/docs/src/quickstart.md index 4d7661cf..83bcd6ea 100644 --- a/docs/src/quickstart.md +++ b/docs/src/quickstart.md @@ -68,7 +68,7 @@ probing $ENDPOINT query "SELECT func, file, lineno FROM python.backtrace ORDER B probing $ENDPOINT eval "import gc, torch; gc.collect(); torch.cuda.empty_cache()" # Analyze allocation trends -probing $ENDPOINT query "SELECT step, AVG(allocated) as avg_memory FROM python.torch_trace GROUP BY step ORDER BY step" +probing $ENDPOINT query "SELECT local_step, AVG(allocated) as avg_memory FROM python.torch_trace GROUP BY local_step ORDER BY local_step" ``` ### Scenario 3: Performance Bottleneck Analysis diff --git a/docs/src/quickstart.zh.md b/docs/src/quickstart.zh.md index 597fb2c4..b0feb341 100644 --- a/docs/src/quickstart.zh.md +++ b/docs/src/quickstart.zh.md @@ -68,7 +68,7 @@ probing $ENDPOINT query "SELECT func, file, lineno FROM python.backtrace ORDER B probing $ENDPOINT eval "import gc, torch; gc.collect(); torch.cuda.empty_cache()" # 分析分配趋势 -probing $ENDPOINT query "SELECT step, AVG(allocated) as avg_memory FROM python.torch_trace GROUP BY step ORDER BY step" +probing $ENDPOINT query "SELECT local_step, AVG(allocated) as avg_memory FROM python.torch_trace GROUP BY local_step ORDER BY local_step" ``` ### 场景 3:性能瓶颈分析 diff --git a/docs/src/reference/sql-tables.md b/docs/src/reference/sql-tables.md index 695f16eb..a0f89a51 100644 --- a/docs/src/reference/sql-tables.md +++ b/docs/src/reference/sql-tables.md @@ -1,7 +1,7 @@ # SQL Tables Authoritative catalog of built-in SQL tables queryable via `probing query` or in-process -`probing.query()`. Kept in sync with `python/probing/_skills/semantic/tables.yaml` (used by +`probing.query()`. Kept in sync with `skills/semantic/tables.yaml` (used by diagnostic skills and the Web Agent). Terminology: [Core Concepts](../guide/concepts.md) (endpoint, steps, `role`, federation). @@ -56,7 +56,7 @@ PyTorch module-level forward/step timings and GPU memory snapshots. | Column | Description | |--------|-------------| -| `step` | Local training step (per rank) | +| `local_step` | Local training step (per rank) | | `global_step` | Global step (`step_snapshot`) | | `rank` | `torch.distributed` rank | | `world_size` | World size | @@ -150,7 +150,7 @@ Variable snapshots when variable tracing is enabled. | Column | Description | |--------|-------------| -| `step` | Training step | +| `micro_step` | Training micro-step | | `func` | Function name | | `name` | Variable name | | `value` | String representation | diff --git a/docs/src/reference/sql-tables.zh.md b/docs/src/reference/sql-tables.zh.md index a38a6077..144f5ff0 100644 --- a/docs/src/reference/sql-tables.zh.md +++ b/docs/src/reference/sql-tables.zh.md @@ -1,7 +1,7 @@ # SQL 表目录 可通过 `probing query` 或进程内 `probing.query()` 查询的内置 SQL 表权威目录。 -与 `python/probing/_skills/semantic/tables.yaml` 保持同步(诊断 skill 与 Web Agent 使用)。 +与 `skills/semantic/tables.yaml` 保持同步(诊断 skill 与 Web Agent 使用)。 术语说明见 [核心概念](../guide/concepts.zh.md)。 @@ -55,7 +55,7 @@ PyTorch 模块级 forward/step 耗时与 GPU 显存快照。 | 列 | 说明 | |----|------| -| `step` | 本地训练步(每 rank) | +| `local_step` | 本地训练步(每 rank) | | `global_step` | 全局步(`step_snapshot`) | | `rank` | `torch.distributed` rank | | `world_size` | world size | @@ -144,7 +144,7 @@ Python + native 混合栈(**瞬时**,非历史全量)。 | 列 | 说明 | |----|------| -| `step` | 训练步 | +| `micro_step` | 训练 micro-step | | `func` | 函数名 | | `name` | 变量名 | | `value` | 字符串表示 | diff --git a/tests/regression/core/test_table_docs_integration.py b/tests/regression/core/test_table_docs_integration.py index 641ab31e..2a82e94b 100644 --- a/tests/regression/core/test_table_docs_integration.py +++ b/tests/regression/core/test_table_docs_integration.py @@ -7,22 +7,16 @@ import subprocess import sys import tempfile -from pathlib import Path import probing -def _project_root() -> Path: - return Path(__file__).resolve().parents[3] - - def _python_path_env(*, defer_engine_init: bool = False) -> dict[str, str]: env = os.environ.copy() - python_dir = str(_project_root() / "python") - env["PYTHONPATH"] = ( - f"{python_dir}:{env['PYTHONPATH']}" if env.get("PYTHONPATH") else python_dir - ) - env["PROBING"] = "1" + # Do not inherit PROBING=1: wheel probing.pth site-hook would import probing + # (and initialize the engine) before the subprocess script sets PROBING_CLI_MODE. + env.pop("PROBING", None) + env.pop("PROBING_ORIGINAL", None) if defer_engine_init: env["PROBING_CLI_MODE"] = "1" else: @@ -44,7 +38,6 @@ def _run_fresh_probing_script( import sys import tempfile -sys.path.insert(0, {repr(str(_project_root() / "python"))}) os.environ["PROBING"] = "1" {cli_mode_line} os.environ["PROBING_DATA_DIR"] = tempfile.mkdtemp(prefix="probing_doc_it_") From 2e32915b549e910cb91dd8639980b0f5a3d753bb Mon Sep 17 00:00:00 2001 From: Reiase Date: Sun, 21 Jun 2026 17:16:10 +0800 Subject: [PATCH 5/5] Update documentation for SQL analytics and add new environment variable references - Revised SQL queries in the README and quickstart documentation to enhance clarity and consistency, particularly focusing on the use of 'local_step'. - Expanded the documentation to include new sections on environment variables, detailing their usage and configuration for Probing. - Updated both English and Chinese documentation to reflect these changes, ensuring accessibility for a wider audience. - Improved the overall structure and navigation of the documentation to facilitate easier access to key information. --- README.md | 72 +++--- docs/mkdocs.yml | 5 + docs/src/design/architecture.md | 17 +- docs/src/guide/concepts.md | 305 ++++++++++++++++---------- docs/src/guide/concepts.zh.md | 184 ++++++++++++++-- docs/src/index.md | 102 +++++++-- docs/src/index.zh.md | 97 +++++++-- docs/src/quickstart.md | 185 +++++++++++----- docs/src/quickstart.zh.md | 164 +++++++++----- docs/src/reference/env-vars.md | 147 +++++++++++++ docs/src/reference/index.md | 2 + docs/src/reference/index.zh.md | 2 + docs/src/reference/skill-format.md | 326 ++++++++++++++++++++++++++++ docs/src/reference/sql-tables.md | 43 ++-- docs/src/reference/sql-tables.zh.md | 35 +-- 15 files changed, 1352 insertions(+), 334 deletions(-) create mode 100644 docs/src/reference/env-vars.md create mode 100644 docs/src/reference/skill-format.md diff --git a/README.md b/README.md index 2cbccb1a..5c8d8194 100644 --- a/README.md +++ b/README.md @@ -113,24 +113,20 @@ probing list ### SQL Analytics Interface ```bash -# Memory usage analysis -probing -t query "SELECT * FROM memory_usage WHERE timestamp > now() - interval '5 min'" - -# Performance hotspot analysis +# GPU memory trend across training steps probing -t query " - SELECT operation_name, avg(duration_ms), count(*) - FROM profiling_data - WHERE timestamp > now() - interval '5 minutes' - GROUP BY operation_name - ORDER BY avg(duration_ms) DESC + SELECT local_step, AVG(allocated) as avg_mb + FROM python.torch_trace + GROUP BY local_step ORDER BY local_step " -# Training progress tracking +# Find the slowest collectives probing -t query " - SELECT epoch, avg(loss), min(loss), count(*) as steps - FROM training_logs - GROUP BY epoch - ORDER BY epoch + SELECT op, AVG(duration_ms) as avg_ms, COUNT(*) as calls + FROM python.comm_collective + GROUP BY op + ORDER BY avg_ms DESC + LIMIT 5 " ``` @@ -160,17 +156,23 @@ The REPL provides: ### Distributed Training Analysis ```bash -# Monitor all cluster nodes -probing cluster attach - -# Inter-node communication latency -probing -t query "SELECT src_rank, dst_rank, avg(latency_ms) FROM comm_metrics" - -# Cross-node stack trace comparison -probing -t query "SELECT * FROM python.backtrace" +# See all registered cluster nodes +probing -t cluster nodes + +# Cross-rank communication analysis via federation +probing -t query " + SELECT _role, _rank, op, AVG(duration_ms) as avg_ms + FROM global.python.comm_collective + GROUP BY _role, _rank, op + ORDER BY avg_ms DESC + LIMIT 10 +" -# GPU utilization analysis -probing -t query "SELECT avg(gpu_util) FROM gpu_metrics WHERE timestamp > now() - 60" +# GPU utilization across devices +probing -t query " + SELECT ts, mem_used_pct, gpu_util_pct + FROM gpu.utilization ORDER BY ts DESC LIMIT 20 +" ``` ### Memory Analysis @@ -178,16 +180,20 @@ probing -t query "SELECT avg(gpu_util) FROM gpu_metrics WHERE timestamp > # Quick memory usage overview probing -t memory -# Memory growth trend analysis -probing -t query "SELECT hour(timestamp), avg(memory_mb) FROM memory_usage GROUP BY hour(timestamp)" - -# Memory leak detection +# Memory growth trend across steps probing -t query " - SELECT function_name, sum(allocated_bytes) as total_alloc - FROM memory_allocations - WHERE timestamp > now() - interval '1 hour' - GROUP BY function_name - ORDER BY total_alloc DESC + SELECT local_step, AVG(allocated_delta) as delta_mb + FROM python.torch_trace + GROUP BY local_step + ORDER BY local_step +" + +# Check current CPU/GPU memory via eval +probing -t eval " +import torch, gc; gc.collect() +alloc = torch.cuda.memory_allocated()/1024**2 +reserved = torch.cuda.memory_reserved()/1024**2 +print(f'GPU alloc: {alloc:.0f}MB, reserved: {reserved:.0f}MB') " ``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index a03af7a2..4ae07856 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -110,6 +110,9 @@ plugins: API Reference: API 参考 Reference: 参考手册 SQL Tables: SQL 表目录 + API Reference: API 参考 + Environment Variables: 环境变量 + Skill Format: Skill 格式规范 Versions: 版本兼容性 Contributing: 贡献指南 - mkdocstrings: @@ -156,6 +159,8 @@ nav: - reference/index.md - SQL Tables: reference/sql-tables.md - API Reference: api-reference.md + - Environment Variables: reference/env-vars.md + - Skill Format: reference/skill-format.md - Versions: versions.md - Contributing: contributing.md diff --git a/docs/src/design/architecture.md b/docs/src/design/architecture.md index c565bf33..49d5379c 100644 --- a/docs/src/design/architecture.md +++ b/docs/src/design/architecture.md @@ -160,14 +160,15 @@ probing -t host:port query "..." ## Security Considerations - **Local mode**: Unix socket permissions (process owner only) -- **Remote mode**: Optional authentication -- **Network**: Support for TLS encryption +- **Remote mode**: Optional authentication token via `PROBING_TOKEN` ## Performance Characteristics -| Aspect | Target | -|--------|--------| -| Overhead | < 5% in typical workloads | -| Memory | < 50MB additional | -| Latency | < 10ms for queries | -| Throughput | 1000+ queries/sec | +Probing is designed for minimal overhead on training workloads: + +| Aspect | Design approach | +|--------|----------------| +| Overhead | Lock-free mmap writes on hot path; sampling in background threads | +| Memory | Fixed-size ring buffers per table (MEMT); bounded by retention config | +| Latency | Queries execute against in-process DataFusion; no network round-trip for local access | +| Throughput | Columnar (Arrow) scan; `information_schema` for introspection | diff --git a/docs/src/guide/concepts.md b/docs/src/guide/concepts.md index 2ebc7489..95f40cbc 100644 --- a/docs/src/guide/concepts.md +++ b/docs/src/guide/concepts.md @@ -1,166 +1,251 @@ -# Core Concepts +# How Probing Works -One-page glossary for terms used across tutorials, guides, and design docs. -When in doubt, start here before diving into [SQL Analytics](sql-analytics.md) or -[Distributed](../design/distributed.md). +This page builds the mental model you need to use Probing effectively. It's not a +reference — it explains the architecture, the data flow, and the key design decisions +from a user's perspective. Read this before anything else. -## 1. Endpoint +## The two modes: in-process vs attach -Every **CLI** command targets a running probing server via an **endpoint**: +Probing works in two fundamentally different ways. Understanding the difference avoids +a lot of confusion. -| Form | Example | Notes | -|------|---------|-------| -| Local PID | `12345` | `probing -t 12345 query "…"` | -| Host:port | `node-a:8080` | Remote TCP; set `PROBING_PORT` at training startup | +**In-process mode** is what you get with `PROBING=1 python train.py`. A `.pth` hook +fires at Python startup, and before your training script even runs its first line, an +embedded HTTP server is listening on a Unix socket and the Rust engine has registered +all its data sources. Your training code calls `import probing` and uses +`probing.query()` directly — no network, no subprocess. ```bash -export ENDPOINT=12345 # or host:8080 -probing $ENDPOINT query "SELECT 1" +PROBING=1 python train.py ``` -**In-process** (training script): set `PROBING=1` (or inject on Linux) and `import probing` -— no endpoint string; use `probing.query()` directly. +**Attach mode** is what you get with `probing -t inject` — Linux only. The CLI +uses ptrace to load the probing shared library into an already-running process. The +process never restarted; it never imported probing. After injection, the same +in-process server starts inside the target. -There is **no** `probing.connect()` Python API. Remote access is always CLI `-t `. +```bash +probing -t $(pgrep -f train.py) inject +probing -t $(pgrep -f train.py) query "SELECT 1" +``` + +In both cases, the end result is the same: a running probing server inside the target +process. The CLI always talks to that server over HTTP (Unix socket for local PIDs, +TCP for remote `host:port`). There's no magic — `probing -t query "..."` is +literally an HTTP POST to `/query` on the embedded server. + +On macOS and Windows, only in-process mode is available. The `inject` command is +Linux-only because it depends on ptrace. + +## Where data comes from + +Probing doesn't poll. Data is pushed into tables when events happen. ---- +Every table you query is backed by one of two storage mechanisms: -## 2. Three CLI commands +**Mmap ring buffers (MEMT).** Most `python.*` tables and all extension tables use +this. A fixed-size memory-mapped file, structured as a ring of chunks. The writer +(the training process) appends rows to the current chunk. Readers (the SQL engine) +scan from any position. This is lock-free on the read path — the writer uses atomic +operations with Release/Acquire ordering, and readers never block the writer. -| Command | Usage | Data | -|---------|-------|------| -| **query** | `probing $ENDPOINT query ""` | Recorded table rows | -| **eval** | `probing $ENDPOINT eval ""` | One-off Python in the target process | -| **backtrace** | `probing $ENDPOINT backtrace` | Point-in-time stack → `python.backtrace` | +**Registered data sources (ProbeDataSource).** Some tables come from Rust code that +implements the `ProbeDataSource` trait. `python.backtrace` is a virtual table — it +captures the current stack on demand rather than storing a history. `process.envs` +reads the process environment. `gpu.devices` enumerates CUDA devices. -These are the main **CLI entry points** from outside the process — not the full product -surface (continuous profiling, federation, `global.*`, skills, etc. are covered below). -Typical flow: `backtrace` captures state → `eval` inspects live objects → `query` analyzes history. +The distinction matters because it affects what you can query: ---- +| Table | Mechanism | What you get | +|---|---|---| +| `python.torch_trace` | mmap ring buffer | History of all hook invocations since sampling started | +| `python.comm_collective` | mmap ring buffer | History of all collectives since sampling started | +| `python.backtrace` | Virtual table | **Current** stack only — no history | +| `cpu.utilization` | mmap ring buffer | Time-series of CPU samples | +| `process.envs` | Virtual table | Current environment variables | -## 3. Data tables (`python.*`) +This is why `backtrace` requires a separate command (`probing backtrace`) — it +captures the stack into the virtual table, and then you query it. The table doesn't +accumulate data on its own. -Probe data lives in **append-only SQL tables** under the `python` schema (plus built-in -extensions like `cpu.utilization`, `cluster.nodes`, `nccl.proxy_ops`). +## How tables are organized -| Table | What it records | -|-------|-----------------| -| `python.torch_trace` | Module hook timings + GPU memory | -| `python.comm_collective` | `torch.distributed` collective wall time | -| `python.trace_event` | Span start/end and custom events | -| `python.backtrace` | Latest captured stack (not a full history) | -| `python.variables` | Watched variable snapshots (when enabled) | +All tables live under a schema prefix that tells you where the data came from: -Custom plugins use the same model: `@table` dataclass + `.save()` → `python.`. -Column reference: **[SQL Tables](../reference/sql-tables.md)**. +`python.*` +: Training semantics — `torch_trace` (module hooks), `comm_collective` (distributed +ops), `trace_event` (spans), `backtrace` (stacks), `variables` (watched values). +Plus any custom tables you create with `@table`. -Tables are **not** lazy snapshots — rows are pushed when events happen (hook, collective, -span end). +`cpu.*` / `gpu.*` +: Host and device sampling. `cpu.utilization` has per-process and per-thread CPU/RSS +samples. `gpu.utilization` has GPU memory and compute utilization. ---- +`cluster.*` +: Cluster node registry. `cluster.nodes` lists every peer that has registered via +torchrun or the HTTP API. -## 4. Step coordinates +`nccl.*` +: NCCL profiler plugin output. `nccl.proxy_ops` decomposes collective wait time into +culprit (local GPU not ready) and victim (waiting on peer data) components. -Training analysis needs a **shared step index**. Probing uses Rust `step_snapshot()` as the -single source of truth (not a separate Python counter). +`global..` +: Federation. Prefix any table with `global.` to fan out the query to all registered +cluster peers. The master merges results and attaches `_host`, `_addr`, `_rank`, +`_role` tags to each row. -| Field | Meaning | -|-------|---------| -| `local_step` | Per-rank step counter (optimizer-step aligned) | -| `global_step` | Cluster-wide step (when coordinated) | +The full column reference for every built-in table is at [SQL Tables](../reference/sql-tables.md). -On data rows: +## Step coordinates: the shared time axis -- `python.torch_trace.local_step` → per-rank step -- `python.torch_trace.global_step`, `python.comm_collective.local_step` / `global_step` +Timestamps are unreliable for training analysis — steps are deterministic and align +naturally with training semantics. Probing uses a three-level step coordinate system: -In-process: +`micro_step` is the finest counter. Increments each time `probing.step()` is called. +`local_step` is the optimizer step — `micro_step // micro_batches`. With gradient +accumulation of 10, every 10 micro-steps produce one local step. +`global_step` is the cluster-wide step, equal to `local_step` when ranks are aligned. ```python -from probing.tracing import step_snapshot -s = step_snapshot() -print(s.local_step, s.global_step, s.rank) -``` +probing.step(micro_batches=10) # gradient accumulation factor +probing.step() # micro_step += 1 at each micro-batch boundary -Prefer these fields in SQL and skills — not `trainer.current_step`. +# Later, in queries: +# SELECT local_step, AVG(duration) FROM python.torch_trace GROUP BY local_step +``` ---- +Every training-related table (`torch_trace`, `comm_collective`) carries both +`local_step` and `global_step` columns. The step coordinates are managed in Rust +(not a Python counter) so they're consistent even if the training script's state +gets corrupted. -## 5. Parallel role +## Role: encoding parallel topology in one column -Distributed training places each process in a **parallel topology** (TP / PP / DP / EP / …). -Probing encodes this as one extensible string **`role`**, not one column per dimension. +Distributed training decomposes work across multiple parallelism dimensions — +tensor parallel (TP), pipeline parallel (PP), data parallel (DP), expert parallel +(EP), and combinations of them. Probing encodes the entire topology as a compact +sorted string: `dp=2,pp=1,tp=0`. -**Format:** sorted `name=value` pairs, e.g. `dp=2,pp=1,tp=0`. Empty string when unset. +This string is stamped on every `torch_trace` and `comm_collective` row. Because +it's a single column, you can GROUP BY role in SQL to compare performance across +parallelism dimensions: -| Source | How | -|--------|-----| -| Environment | Megatron-style `*_PARALLEL_RANK`, or `PROBING_ROLE_=` | -| Runtime | `probing.set_role("dp=2,pp=1,tp=0")` or `set_role(dp=2, pp=1)` | -| Read | `probing.current_role()`; `clear_role()` reverts to env | +```sql +SELECT role, AVG(duration_ms) as avg_ms +FROM python.comm_collective +GROUP BY role; +``` -`role` is stamped on **`python.torch_trace`** and **`python.comm_collective`** rows so you -can `JOIN` / `GROUP BY role` across tables on one rank. +Set it from environment variables (`PROBING_TP_RANK`, `PROBING_TP_SIZE`, etc., or +Megatron-style `TP_RANK`/`PP_RANK`/`DP_RANK`) or from Python: -Distinct from torchrun's **`role_name`** / `role_rank` on `cluster.nodes` — those are -Elastic/job launcher fields. Probing's `role` is the parallel-placement key for analytics. +```python +probing.set_role(dp=2, pp=1, tp=0) +# or: probing.set_role("dp=2,pp=1,tp=0") +probing.clear_role() # fall back to environment-derived role +``` ---- +Note: `cluster.nodes` has `role_name` and `role_rank` — those are torchrun/Elastic +launcher fields describing the launcher role, not the parallel topology. Probing's +`role` is the parallelism key for analytics; they serve different purposes. -## 6. Federation (`global.*` and tags) +## Federation: querying across the cluster -For **multi-rank** SQL, use the `global` catalog: `global.python.comm_collective` fans out -to registered peers and merges results. +When multiple training ranks are registered in a cluster, prefixing a table with +`global.` fans out the SQL query to every registered peer. The query runs +independently on each node; the master collects and concatenates the results. -Each row gets **federation tags** identifying the source probing endpoint: +Each returned row gets four federation tags added at query time: -| Tag | Meaning | -|-----|---------| -| `_host` | Source hostname | -| `_addr` | Source `host:port` | -| `_rank` | `torch.distributed` rank (from node registry) | -| `_role` | Parallel role key (from node registry / `set_role`) | +| Tag | Source | +|-----|--------| +| `_host` | Hostname of the node that produced the row | +| `_addr` | That node's probing `host:port` | +| `_rank` | `torch.distributed` rank from the cluster node registry | +| `_role` | Parallel role key from the node registry | -Example: +A query like this: ```sql -SELECT _role, _rank, avg(duration_ms) AS avg_ms +SELECT _rank, op, AVG(duration_ms) as avg_ms FROM global.python.comm_collective -WHERE global_step > 100 -GROUP BY _role, _rank +WHERE local_step > 100 +GROUP BY _rank, op ORDER BY avg_ms DESC; ``` -Register nodes via torchrun (`setup_torchrun_cluster`) or `PUT /apis/nodes`. CLI: -`probing -t cluster nodes` / `cluster query "…"`. Details: -[Distributed](../design/distributed.md). +...runs on every registered rank, then the master merges all results into one result +set with `_rank` telling you which row came from where. + +Nodes register via torchrun (`setup_torchrun_cluster`) or by POSTing to +`/apis/nodes`. Check current registration with `probing -t cluster nodes`. + +The `_role` tag uses the value from the **node registry**, which is kept in sync +with calls to `set_role()`. The `role` column on individual data rows uses the value +at **write time**. In practice these are the same, but understanding the distinction +matters when diagnosing stale role data. + +## Extension paths: three ways to add capability + +Probing has three distinct extension mechanisms. They're not interchangeable — each +serves a different purpose: + +**1. Data table plugin (`@table` in Python).** +Define a dataclass, decorate it, append rows from your code. The table appears as +`python.` and is immediately queryable. Use this when you have new metrics or +events to record. Built on mmap ring buffers. + +**2. Diagnostic skill (`steps.yaml` + `SKILL.md`).** +A YAML workflow that runs SQL queries against existing tables, applies interpretation +rules, and produces findings. Use this when you have a diagnosis recipe to codify — +like "find the slowest rank" or "check for NCCL wait imbalance." Run with +`probing skill run `. Skills don't collect new data; they analyze existing data. + +**3. Rust extension (`ProbeExtension` + `ProbeDataSource`).** +A compiled Rust crate that registers new data sources (virtual tables, mmap tables) +and/or configurable options with side effects (like starting a CPU sampler). Use +this when you need system-level access — ptrace, CUDA APIs, RDMA counters. The NCCL +profiler is a variant of this: a C ABI plugin loaded by NCCL itself, not by Probing. + +See [Extensibility](../design/extensibility.md) for the full development guide. + +## What happens when things go wrong + +Knowing what to expect when a query fails saves debugging time. + +**Invalid SQL** returns a `PyRuntimeError` with the DataFusion error message. The +error usually tells you exactly what's wrong (unknown table, unknown column, syntax +error). If you get a cryptic DataFusion error, check the server logs — set +`PROBING_LOGLEVEL=debug` for verbose output. -Row column `role` = value at **write time** on that rank. Tag `_role` = value on the -**node registry** at federation time (kept in sync via `set_role` + re-register). +**Missing tables.** If `probing $ENDPOINT tables` doesn't show the table you expect, +the data source isn't active. Common causes: the GPU extension isn't compiled in +(check `probing $ENDPOINT query "SELECT name FROM information_schema.df_settings WHERE name LIKE 'probing.gpu%'"`), or the mmap file isn't being written to (check +`$PROBING_DATA_DIR//`). ---- +**Empty results from `python.torch_trace`.** The PyTorch profiler needs +`PROBING_TORCH_PROFILING=on` at startup. Without it, hooks are never registered and +no rows are written. -## 7. Data plugin vs diagnostic skill +**Injection fails with ESRCH.** On Linux, ptrace may be restricted by YAMA LSM +(`/proc/sys/kernel/yama/ptrace_scope`). Set it to 0 or run as the same user. -| | **Table plugin** (Path 1) | **Diagnostic skill** (Path 2) | -|--|---------------------------|----------------------------------| -| You add | Dataclass table + rows | `SKILL.md` + optional `steps.yaml` | -| Output | `python.my_table` | Findings + SQL steps / agent guidance | -| Run | `SELECT …` | `probing skill run ` | -| Use when | New **metrics/events** to store | New **investigation recipe** | +See [Troubleshooting](troubleshooting.md) for more. -Optional **Path 3**: NCCL profiler cdylib → `nccl.proxy_ops` for culprit/victim wait -decomposition. See [Extensibility](../design/extensibility.md). +## Environment variables at a glance ---- +Probing has many configuration points. The most important ones: -## Where to go next +| Variable | Effect | +|----------|--------| +| `PROBING` | `0`=disabled, `1`=current process, `2`=current+children, `regex:...`=pattern match | +| `PROBING_TORCH_PROFILING` | `on` to activate PyTorch module hooks | +| `PROBING_DATA_DIR` | Where mmap ring buffers are stored | +| `PROBING_PORT` | TCP port for remote access (or `RANDOM`) | +| `PROBING_AUTH_TOKEN` | Authentication token for remote mode | +| `PROBING_CPU_SAMPLE_MS` | CPU sampling interval in milliseconds (0=off) | +| `PROBING_GPU_SAMPLE_MS` | GPU sampling interval in milliseconds | +| `PROBING_SPAN_BACKENDS` | Comma-separated: `memtable`, `logger`, `otel` | +| `PROBING_LOGLEVEL` | `trace`, `debug`, `info`, `warn`, `error` | -| Goal | Doc | -|------|-----| -| SQL patterns | [SQL Analytics](sql-analytics.md) | -| Table schemas | [SQL Tables](../reference/sql-tables.md) | -| Multi-node | [Distributed](../design/distributed.md) | -| Write a plugin | [Extensibility](../design/extensibility.md) | -| CLI / Python API | [API Reference](../api-reference.md) | +The complete reference is at [Environment Variables](../reference/env-vars.md). diff --git a/docs/src/guide/concepts.zh.md b/docs/src/guide/concepts.zh.md index c558234e..6f381b79 100644 --- a/docs/src/guide/concepts.zh.md +++ b/docs/src/guide/concepts.zh.md @@ -1,23 +1,179 @@ -## 4. Step 坐标 +# 核心概念 -训练分析使用三级 step 索引,权威来源是 Rust 坐标(通过 ``probing.step`` 访问)。 +逐步构建 Probing 的心理模型——endpoint 是什么、数据如何流入表、step 坐标如何 +组织状态、以及联邦查询如何跨节点工作。在深入 [SQL 分析](sql-analytics.zh.md)或 +[分布式](../design/distributed.zh.md)之前请先阅读本文。 -| 字段 | API | 含义 | -|------|-----|------| -| `micro_step` | `probing.step.micro_step` | 最细计数;每次 ``probing.step()`` 或 ``train.step`` span 结束 +1 | -| `local_step` | `probing.step.local_step` | 训练步(每 rank):``micro_step // micro_batches`` | -| `global_step` | `probing.step.global_step` | 与 ``local_step`` 相同(rank 对齐时即集群训练步) | -| `micro_batches` | `probing.step(micro_batches=k)` | 梯度累积倍数:每 k 个 micro_step 合成 1 个 local/global step | +## Endpoint:如何访问 probing 服务器 + +每个启用 probing 的进程都运行一个嵌入式 HTTP 服务器。从外部与其通信需要一个 +**endpoint**——可以是本地 PID 或 `host:port` 对。 + +```bash +# 本地进程——probing 从 PID 解析 Unix socket +probing -t 12345 query "SELECT 1" + +# 远程进程——TCP 连接到已知地址 +probing -t node-a:8080 query "SELECT 1" +``` + +CLI 从不直接与引擎交互。它通过 Unix socket(本地)或 TCP(远程)向目标进程 +中嵌入的服务器发送 HTTP 请求。不存在 `probing.connect()` Python API—— +远程访问始终通过 CLI 的 `-t` 参数。 + +在训练脚本内部(in-process 模式),则完全跳过 CLI,直接调用 `probing.query()`。 +引擎已经在同一进程中运行。 + +启动时设置 `PROBING=1` 通过 `.pth` 钩子激活进程内服务器——无需 import、无需 +修改代码。在 Linux 上,`probing inject` 也可以通过 ptrace 附着到已运行的进程。 + +概念上: + +``` +CLI ──(HTTP over Unix socket/TCP)──▶ probing server(目标进程内) + │ + ├── Engine(DataFusion) + ├── Config + └── Extensions(CPU、GPU、Python、NCCL...) +``` + +## 数据表:只追加、持续写入 + +Probing 将性能数据存储在 mmap 环形缓冲支持的**只追加 SQL 表**中。事件发生时 +写入——模块 hook 触发、collective 完成、span 结束。不存在轮询,不按需快照。 +数据已经在那里了。 + +最重要的表位于 `python` schema 下,因为这里是训练语义所在: + +`python.torch_trace` +: 模块级 forward/backward/step hook 耗时与 GPU 显存快照。每次 hook 触发一行。 +列包括 `local_step`、`module`、`stage`、`duration`、`allocated`。 + +`python.comm_collective` +: 每个 `torch.distributed` 集合调用(all_reduce、broadcast、all_gather 等)。 +记录 `op`、`tensor_shape`、`bytes`、`duration_ms` 以及进程组上下文。 + +`python.trace_event` +: Span 起止事件和自定义 trace 事件。在 `span_id` 上 join `span_start` / `span_end` +计算耗时。Span 可以嵌套——forward pass span 包含多个 layer span。 + +`python.backtrace` +: 最新捕获的调用栈,混合 Python 和原生帧。**瞬时数据**,不是历史全量。 +先用 `probing backtrace` 填充,再查询。用于 hang 诊断。 + +`python.variables` +: 变量快照(需显式启用变量追踪)。轻量级:值以字符串方式存储,不序列化。 + +除 `python.*` 之外,主机级数据在 `cpu.*` 和 `gpu.*`,集群元数据在 `cluster.*`, +NCCL profiler 输出在 `nccl.*`。完整的列定义见 [SQL 表目录](../reference/sql-tables.zh.md)。 + +自定义表遵循相同模型。用 `@table("my_metrics")` 定义 dataclass,在训练循环中 +追加行,表即显示为 `python.my_metrics`——与内置表一同查询。见 [扩展机制](../design/extensibility.zh.md)。 + +## Step 坐标:训练分析的共享索引 + +训练分析中,关联不同表的数据需要一个共享的时间轴。Probing 使用 **step 坐标** +而非时间戳——它们具有确定性,且与训练语义天然对齐。 + +有三个层级: + +**micro_step** —— 最细粒度计数器。每次调用 `probing.step()` 加一。 +**local_step** —— optimizer step。`micro_step // micro_batches`。 +**global_step** —— 集群范围的 step。rank 对齐时等同于 `local_step`。 ```python import probing -probing.step(micro_batches=10) # 10 个 micro-batch → 1 个 training step -probing.step() # micro_step +1 -probing.step(42) # 设置 micro_step -print(probing.step.micro_step, probing.step.local_step, probing.step.global_step) +probing.step(micro_batches=10) # 10 个 micro-batch = 1 个 optimizer step +probing.step() # micro_step += 1 +probing.step(42) # 直接设置 micro_step + +print(probing.step.micro_step) # 原始计数器 +print(probing.step.local_step) # micro_step // 10 +print(probing.step.global_step) # 集群 step ``` -SQL 表(``python.comm_collective``、``python.torch_trace``、span attributes)统一使用上述字段名。 +所有训练相关表都携带 step 列:`python.torch_trace` 有 `local_step` 和 `global_step`; +`python.comm_collective` 同样具备,外加 `group_rank` 和 `group_size`。编写查询时, +用 `local_step` 或 `global_step` 过滤——不要用 `trainer.current_step`。 + +## Role:编码并行拓扑 + +分布式训练将每个 rank 置于并行拓扑中——tensor parallel、pipeline parallel、data +parallel、expert parallel 或它们的组合。Probing 将其编码为一个紧凑字符串,而非 +每个维度一列。 + +格式为排序的 `name=value` 对:`dp=2,pp=1,tp=0`。空字符串表示 role 未设置。 + +可通过环境变量(Megatron 风格的 `*_PARALLEL_RANK` 或 `PROBING_ROLE_=`) +或 Python 设置: + +```python +import probing +probing.set_role(dp=2, pp=1, tp=0) +# 或: probing.set_role("dp=2,pp=1,tp=0") + +print(probing.current_role()) # "dp=2,pp=1,tp=0" +probing.clear_role() # 恢复为环境变量默认值 +``` + +`role` 被标记在所有 `python.torch_trace` 和 `python.comm_collective` 行上。 +这让你可以 GROUP BY role 来对比,比如跨所有 data-parallel 副本比较 TP rank 0 +和 TP rank 1。需区分 torchrun 的 `role_name` / `role_rank`(`cluster.nodes` 上 +的字段)——这些是 launcher 字段;`role` 是用于分析的 key。 + +## 联邦查询:跨节点查询 + +当多个 rank 注册到集群时,`global..
` 将查询 fan-out 到每个节点 +并合并结果。查询在每个节点上独立执行;master 收集并拼接。 + +每个返回行附带四个联邦标签: + +`_host` +: 生成行的 probing 节点主机名。 + +`_addr` +: 该节点的 `host:port` probing 地址。 + +`_rank` +: 来自集群节点注册的 `torch.distributed` rank。 + +`_role` +: 来自节点注册的并行 role key(与 `set_role` 保持同步)。 + +典型的联邦查询: + +```sql +SELECT _role, _rank, op, AVG(duration_ms) AS avg_ms +FROM global.python.comm_collective +WHERE local_step > 100 +GROUP BY _role, _rank, op +ORDER BY avg_ms DESC; +``` + +通过 torchrun(`setup_torchrun_cluster`)或 POST `/apis/nodes` 注册节点。 +用 `probing -t cluster nodes` 验证。详见 [分布式](../design/distributed.zh.md)。 + +## 表插件 vs 诊断 skill + +Probing 有两条扩展路径,选择哪条取决于你的目的: + +**表插件**添加新数据——用 `@table` 定义 dataclass,从代码追加行,数据以 +`python.` 出现。用于存储和查询新的指标或事件。 + +**诊断 skill** 添加新分析——`SKILL.md` 文件配合可选的 `steps.yaml`,描述基于 +现有表的诊断工作流。Skill 运行 SQL 查询、应用解释规则、产出诊断结论。用于 +沉淀排查方案。用 `probing skill run ` 运行。 + +NCCL profiler 是第三条路径——编译好的 cdylib,写入 `nccl.proxy_ops` 用于 +culprit/victim 等待分解。它是 Rust 扩展,而非 Python @table。见 [NCCL Profiler](../design/nccl-profiler.zh.md)。 + +## 下一步 -SQL 与 skill 请用 ``local_step`` / ``global_step`` 做训练步过滤,**不要**用 ``trainer.current_step``。 +| 目标 | 文档 | +|------|------| +| 训练分析的 SQL 模式 | [SQL 分析](sql-analytics.zh.md) | +| 每张表的列定义 | [SQL 表目录](../reference/sql-tables.zh.md) | +| 多节点配置和 torchrun | [分布式](../design/distributed.zh.md) | +| 编写自定义表或 skill | [扩展机制](../design/extensibility.zh.md) | +| CLI 命令和 Python API | [API 参考](../api-reference.zh.md) | diff --git a/docs/src/index.md b/docs/src/index.md index 5a8f2e91..422f7719 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,39 +1,99 @@ --- template: home.html title: Probing - Dynamic Performance Profiler for Distributed AI -description: Zero-intrusion profiler for distributed AI — SQL tables, live introspection, cluster federation, and diagnostic skills. +description: Attach to running Python training processes, query performance data with SQL, and diagnose distributed issues across cluster nodes. hide: toc --- # Probing -**Probing** profiles distributed AI training: continuous SQL tables, live attach, federated -`global.*` queries, and bundled diagnostic skills. +Probing lets you inspect and debug distributed AI training jobs — without modifying your code. +Attach to a running Python process, query performance data through standard SQL, and run +diagnostic workflows to find slow ranks, NCCL bottlenecks, or memory leaks. -## Capabilities - -- **Continuous profiling** — `torch_trace`, `comm_collective`, NCCL proxy, custom `@table` -- **Live introspection** — `eval`, `backtrace`, REPL against running processes -- **SQL analytics** — single-node and `global.*` federation with `_rank` / `_role` tags -- **Diagnostic skills** — `health_overview`, `slow_rank`, `nccl_culprit_victim`, … -- **Cluster** — `cluster nodes`, `cluster query`, Web UI agent -- **Zero intrusion** — `PROBING=1` at startup or Linux `inject` - -## Quick Start +## In 30 seconds ```bash pip install probing -# Recommended for training -PROBING=1 PROBING_TORCH_PROFILING=on python train.py +# Start training with probing enabled +PROBING=1 python train.py & -# Or attach on Linux -probing -t inject -probing -t query "SELECT * FROM python.torch_trace LIMIT 10" -probing -t skill run health_overview +# Attach and inspect +probing -t $(pgrep -f train.py) backtrace +probing -t $(pgrep -f train.py) query " + SELECT module, stage, AVG(duration) as sec + FROM python.torch_trace + GROUP BY module, stage + ORDER BY sec DESC LIMIT 5 +" ``` +The first command starts a training job with probing activated. The second captures +every Python and native frame on the main thread. The third finds the five slowest +module-stage pairs with a single SQL query — no logging, no instrumentation, no restart. + +## What you can do with it + +**Debug a hanging or slow training job.** +Attach to the stuck process, grab a backtrace, check GPU memory per step, and identify +the exact module or collective call that's blocking. No need to reproduce the issue. + +**Profile collective communication at scale.** +The NCCL profiler plugin decomposes proxy-op wait time into send/recv latencies so you +can tell who's waiting on whom — essential for debugging all-reduce tail latency. + +**Write custom performance tables.** +Define a dataclass with `@table("my_metrics")`, append rows from your training loop, and +query them alongside built-in tables. Your data lives in `python.my_metrics`, same +namespace, same SQL interface. + +**Query across cluster nodes.** +Prefix any table with `global.` to fan out to registered peers. Each row carries `_rank`, +`_role`, `_host` tags so results are directly comparable across the cluster. + +## How it works + +Probing ships as a Python package with a compiled Rust core (`probing._core`). When you +run `PROBING=1 python train.py`, a `.pth` hook starts an in-process HTTP server and +registers data sources for the SQL engine. Extensions — CPU sampling, GPU memory, NCCL +proxy ops, Python stack tracing — push rows into append-only columnar tables backed by +mmap ring buffers. The CLI talks to the embedded server over a Unix socket (local) or TCP +(remote). + +You don't need to know any of that to use it. `pip install probing` and you're done. + +## Start here + +**I want to debug a training issue.** +Read [Quick Start](quickstart.md), then try `probing backtrace` and `probing query`. + +**I want to understand the architecture.** +Read [Core Concepts](guide/concepts.md) and [Modularity & Boundaries](design/modularity.md). + +**I'm setting up a multi-node cluster.** +Read [Distributed](design/distributed.md) and the [SQL Tables](reference/sql-tables.md) reference. + +**I want to write a custom diagnostic skill.** +Read [Extensibility](design/extensibility.md) and browse `skills/` for examples. + +**I want to contribute.** +Read [Contributing](contributing.md), `make develop`, and pick an issue. + ## Documentation map -- [Installation](installation.md) · [Quick Start](quickstart.md) · [Core Concepts](guide/concepts.md) -- [SQL Tables](reference/sql-tables.md) · [API Reference](api-reference.md) · [Contributing](contributing.md) +| Doc | Covers | +|-----|--------| +| [Installation](installation.md) | `pip install`, `PROBING=1`, platform support | +| [Quick Start](quickstart.md) | First 5 minutes, real-world debugging scenarios | +| [Core Concepts](guide/concepts.md) | How Probing works — mental model, data flow, step coordinates, federation | +| [SQL Tables](reference/sql-tables.md) | Column reference for every built-in table | +| [API Reference](api-reference.md) | CLI commands and Python API | +| [Environment Variables](reference/env-vars.md) | Complete `PROBING_*` variable reference (30+ entries) | +| [Skill Format](reference/skill-format.md) | `steps.yaml` and `SKILL.md` specification | +| [SQL Analytics](guide/sql-analytics.md) | Query patterns, JOIN examples, time-series | +| [Diagnostic Skills](guide/skills.md) | Running and writing diagnostic workflows | +| [Extensibility](design/extensibility.md) | Data table plugins, diagnostic skills, NCCL profiler | +| [Distributed](design/distributed.md) | Multi-node federation, torchrun integration | +| [NCCL Profiler](design/nccl-profiler.md) | NCCL plugin, proxy-op wait decomposition | +| [Contributing](contributing.md) | Dev setup, pull request workflow | diff --git a/docs/src/index.zh.md b/docs/src/index.zh.md index 89feba44..846044af 100644 --- a/docs/src/index.zh.md +++ b/docs/src/index.zh.md @@ -1,38 +1,95 @@ --- template: home.html title: Probing - 分布式 AI 动态性能分析器 -description: 面向分布式 AI 的零侵入分析器 — SQL 表、现场内省、集群联邦与诊断 skill。 +description: 附着到运行中的 Python 训练进程,用 SQL 查询性能数据,跨集群节点诊断分布式问题。 hide: toc --- # Probing -**Probing** 面向分布式 AI 训练:持续写入 SQL 表、现场附着、联邦 `global.*` 查询与内置诊断 skill。 +无需修改代码即可检查和调试分布式 AI 训练任务。附着到运行中的 Python 进程,用标准 SQL +查询性能数据,运行诊断工作流定位慢 rank、NCCL 瓶颈或内存泄漏。 -## 能力概览 - -- **持续采集** — `torch_trace`、`comm_collective`、NCCL proxy、自定义 `@table` -- **现场内省** — 对运行中进程 `eval`、`backtrace`、REPL -- **SQL 分析** — 单节点与 `global.*` 联邦(`_rank` / `_role` 标签) -- **诊断 skill** — `health_overview`、`slow_rank`、`nccl_culprit_victim` 等 -- **集群** — `cluster nodes`、`cluster query`、Web UI Agent -- **零侵入** — 启动时 `PROBING=1` 或 Linux `inject` - -## 快速开始 +## 30 秒上手 ```bash pip install probing -# 训练推荐 -PROBING=1 PROBING_TORCH_PROFILING=on python train.py +# 启动带 probing 的训练 +PROBING=1 python train.py & -# 或 Linux 附着 -probing -t inject -probing -t query "SELECT * FROM python.torch_trace LIMIT 10" -probing -t skill run health_overview +# 附着并检查 +probing -t $(pgrep -f train.py) backtrace +probing -t $(pgrep -f train.py) query " + SELECT module, stage, AVG(duration) as sec + FROM python.torch_trace + GROUP BY module, stage + ORDER BY sec DESC LIMIT 5 +" ``` +第一条命令启动一个已激活 probing 的训练任务。第二条捕获主线程上所有 Python 和原生帧。 +第三条用一条 SQL 查出最慢的五个模块-阶段组合——无需加日志、无需插桩、无需重启。 + +## 能做什么 + +**调试卡住或变慢的训练任务。** +附着到卡住的进程,抓取调用栈,按 step 查看 GPU 内存变化,精确定位是哪个模块或集合通信 +调用在阻塞。不需要复现问题。 + +**诊断分布式集合通信性能。** +NCCL profiler 插件将 proxy-op 等待时间分解为 send/recv 延迟,能清晰判断谁在等谁—— +调试 all-reduce 尾部延迟的关键工具。 + +**编写自定义性能表。** +用 `@table("my_metrics")` 定义 dataclass,从训练循环中追加行,和内置表一起查询。 +数据存在 `python.my_metrics`,同一命名空间,同一 SQL 接口。 + +**跨集群节点查询。** +在任何表前加 `global.` 前缀即可 fan-out 到已注册节点。每行附带 `_rank`、`_role`、 +`_host` 标签,结果可直接跨集群对比。 + +## 工作原理 + +Probing 以 Python 包形式发布,内置编译好的 Rust 核心(`probing._core`)。 +当你运行 `PROBING=1 python train.py` 时,一个 `.pth` 钩子启动进程内 HTTP 服务器, +为 SQL 引擎注册数据源。扩展模块——CPU 采样、GPU 显存、NCCL proxy ops、Python 调用栈 +追踪——将行数据推送到 mmap 环形缓冲区支持的只追加列式表中。CLI 通过 Unix socket +(本地)或 TCP(远程)与嵌入式服务器通信。 + +作为用户,你不需要了解这些细节。`pip install probing` 就够了。 + +## 从这里开始 + +**我想调试训练问题。** +阅读[快速开始](quickstart.zh.md),然后试试 `probing backtrace` 和 `probing query`。 + +**我想理解架构。** +阅读[核心概念](guide/concepts.zh.md)和[模块化与边界](design/modularity.zh.md)。 + +**我在搭建多节点集群。** +阅读[分布式](design/distributed.zh.md)和 [SQL 表目录](reference/sql-tables.zh.md)。 + +**我想写一个自定义诊断 skill。** +阅读[扩展性](design/extensibility.zh.md),浏览 `skills/` 目录下的示例。 + +**我想贡献代码。** +阅读[贡献指南](contributing.zh.md),执行 `make develop`,挑一个 issue 开始。 + ## 文档导航 -- [安装指南](installation.zh.md) · [快速开始](quickstart.zh.md) · [核心概念](guide/concepts.zh.md) -- [SQL 表目录](reference/sql-tables.zh.md) · [API 参考](api-reference.zh.md) · [贡献指南](contributing.zh.md) +| 文档 | 内容 | +|------|------| +| [安装指南](installation.zh.md) | `pip install`、`PROBING=1`、平台支持 | +| [快速开始](quickstart.zh.md) | 5 分钟上手,真实调试场景 | +| [核心概念](guide/concepts.zh.md) | Probing 工作原理——心理模型、数据流、step 坐标、联邦查询 | +| [SQL 表目录](reference/sql-tables.zh.md) | 每张内置表的列定义 | +| [API 参考](api-reference.zh.md) | CLI 命令和 Python API | +| [环境变量](reference/env-vars.md) | 全部 30+ 个 `PROBING_*` 环境变量参考 | +| [Skill 格式规范](reference/skill-format.md) | `steps.yaml` 和 `SKILL.md` 格式规范 | +| [SQL 分析](guide/sql-analytics.zh.md) | 查询模式、JOIN 示例、时序分析 | +| [诊断 Skill](guide/skills.zh.md) | 运行和编写诊断工作流 | +| [扩展机制](design/extensibility.zh.md) | 数据表插件、诊断 skill、NCCL profiler | +| [分布式](design/distributed.zh.md) | 多节点联邦、torchrun 集成 | +| [NCCL Profiler](design/nccl-profiler.zh.md) | NCCL 插件、proxy-op 等待分解 | +| [贡献指南](contributing.zh.md) | 开发环境搭建、PR 流程 | diff --git a/docs/src/quickstart.md b/docs/src/quickstart.md index 83bcd6ea..773abfcd 100644 --- a/docs/src/quickstart.md +++ b/docs/src/quickstart.md @@ -1,94 +1,179 @@ # Quick Start -Get immediate value from Probing with this streamlined workflow. +This guide walks you through using Probing to inspect a running training process. +It assumes you've already [installed](installation.md) the package. By the end, +you'll be able to attach to a process, capture backtraces, and run SQL queries +against live performance data. -## Your First 5 Minutes +## Finding and attaching to a process -### Step 1: Set Your Target Process - -All Probing commands need a target endpoint. Set `$ENDPOINT` to either a local process ID or remote address: +First, locate the Python process you want to inspect. In your terminal: ```bash -# Local process - find and set your Python process ID -export ENDPOINT=$(pgrep -f "python.*your_script") +# Find your training process +pgrep -f "python.*train" +# → 27891 -# Or for remote processes -export ENDPOINT=remote-host:8080 +# Or with more context +ps aux | grep python | grep -v grep ``` -!!! tip "Finding Processes" - Use `ps aux | grep python` or `pgrep -f "python.*train"` to locate your target. +Set the endpoint environment variable so later commands are clean: -### Step 2: Attach and explore +```bash +export ENDPOINT=27891 +``` + +On Linux, you can attach to a running process directly: ```bash -# Connect to your process (Linux only) probing $ENDPOINT inject - -# Get basic process info -probing $ENDPOINT eval "import os, psutil; proc = psutil.Process(); print(f'PID: {os.getpid()}, Memory: {proc.memory_info().rss/1024**2:.1f}MB')" ``` -### Step 3: Try three CLI commands +On macOS or Windows, injection isn't available — you need to start the process with +probing enabled (`PROBING=1 python train.py`) instead. The `inject` command is the +only part of probing that's Linux-only; everything else (query, eval, backtrace, +skills) works on all platforms once the server is running. -Profiling tables fill automatically when hooks run. These commands read and interact with the probe: +## Your first diagnostic commands + +With the server running in the target, grab a backtrace and check what's happening +on the main thread: ```bash -probing $ENDPOINT query "SELECT name, value FROM information_schema.df_settings LIMIT 5" -probing $ENDPOINT eval "import torch; print(f'CUDA: {torch.cuda.is_available()}')" probing $ENDPOINT backtrace -probing $ENDPOINT query "SELECT func, file, lineno FROM python.backtrace ORDER BY depth LIMIT 5" ``` -Details: **[Core Concepts](guide/concepts.md)** · **[API Reference](api-reference.md)** +The backtrace populates `python.backtrace` — a point-in-time view of the current +stack, mixing Python and native frames. Query it with SQL: + +```bash +probing $ENDPOINT query " + SELECT func, file, lineno, depth, frame_type + FROM python.backtrace + ORDER BY depth LIMIT 10 +" +``` + +You'll see function names, source file paths, and whether each frame is Python or +native code. Depth 0 is the innermost frame (what's executing right now). + +If you need to inspect live state beyond the stack, use `eval` to run arbitrary +Python in the target: + +```bash +probing $ENDPOINT eval "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" +probing $ENDPOINT eval " + import gc, torch + gc.collect() + alloc = torch.cuda.memory_allocated() / 1024**2 + reserved = torch.cuda.memory_reserved() / 1024**2 + print(f'GPU: {alloc:.0f}MB allocated, {reserved:.0f}MB reserved') +" +``` + +## Three common workflows + +What follows are real scenarios — the kind of things you'll actually use Probing for +in practice. -## Real-World Debugging Scenarios +### Training is hanging -### Scenario 1: Training Process Hanging +The most common use case: training suddenly stops making progress. The process is +alive but nothing is happening. -**Problem**: PyTorch training suddenly stops progressing. +Start with a backtrace to see what the main thread is doing: ```bash -# 1. See what main thread is doing probing $ENDPOINT backtrace +probing $ENDPOINT query "SELECT func, file, lineno FROM python.backtrace ORDER BY depth LIMIT 5" +``` + +The innermost frame (depth 0) tells you exactly where execution is stuck. If the +stack shows `ncclAllReduce` or a `torch.distributed` call, you're looking at a +communication hang. If it shows Python code in your model, the computation itself +is the bottleneck. -# 2. Check thread states -probing $ENDPOINT eval "import threading; [(t.name, t.is_alive()) for t in threading.enumerate()]" +Next, check thread states — a collective might be hanging while other threads are +fine: -# 3. Analyze stack context -probing $ENDPOINT query "SELECT func, file, lineno FROM python.backtrace ORDER BY depth LIMIT 10" +```bash +probing $ENDPOINT eval " + import threading + for t in threading.enumerate(): + print(f'{t.name}: alive={t.is_alive()}, daemon={t.daemon}') +" ``` -### Scenario 2: Memory Leak Investigation +### GPU memory is growing -**Problem**: Memory usage keeps growing during training. +Memory creeping up step after step — a leak or accumulation pattern. Query the +torch_trace table to see per-step allocation trends: ```bash -# Force cleanup and get current state -probing $ENDPOINT eval "import gc, torch; gc.collect(); torch.cuda.empty_cache()" +probing $ENDPOINT query " + SELECT local_step, AVG(allocated) as avg_mb, MAX(allocated_delta) as max_delta_mb + FROM python.torch_trace + GROUP BY local_step + ORDER BY local_step +" +``` -# Analyze allocation trends -probing $ENDPOINT query "SELECT local_step, AVG(allocated) as avg_memory FROM python.torch_trace GROUP BY local_step ORDER BY local_step" +Look for `allocated_delta` values that don't return to zero between steps — +that indicates memory not being freed between iterations. Pair with `eval` to +force a GC and check current state: + +```bash +probing $ENDPOINT eval "import gc, torch; gc.collect(); torch.cuda.empty_cache()" ``` -### Scenario 3: Performance Bottleneck Analysis +If the GC + cache clear brings memory back down, the issue is Python-side +reference cycles. If not, you're looking at a CUDA-side leak or growing +workspace. + +### Finding slow modules and operations -**Problem**: Need to identify which model components are slowest. +To identify which parts of your model are the bottleneck: ```bash -# Find most expensive operations probing $ENDPOINT query " -SELECT module, stage, AVG(duration) as avg_duration -FROM python.torch_trace -GROUP BY module, stage -ORDER BY avg_duration DESC -LIMIT 10" + SELECT module, stage, AVG(duration) as avg_duration, COUNT(*) as calls + FROM python.torch_trace + WHERE stage IN ('post forward', 'post step') + GROUP BY module, stage + ORDER BY avg_duration DESC + LIMIT 10 +" ``` -## Next Steps +Filtering to `post forward` and `post step` stages gives you the execution times +(the `pre` rows carry timing anchors; the `post` rows carry actual durations). +The results tell you exactly which module and which pass direction accounts for +the most time. + +If you're running distributed, add the federated query prefix to compare across +ranks: + +```bash +probing -t query " + SELECT _rank, module, stage, AVG(duration) as avg_duration + FROM global.python.torch_trace + WHERE stage IN ('post forward', 'post step') + GROUP BY _rank, module, stage + ORDER BY avg_duration DESC + LIMIT 20 +" +``` + +`_rank` is a federation tag added at query time — see [Core Concepts](guide/concepts.md) +for how this works. + +## What's next + +These three commands — `backtrace`, `eval`, `query` — cover the majority of +day-to-day diagnostic work. Each is documented in detail in the [API +Reference](api-reference.md). -- **[Core Concepts](guide/concepts.md)** — Endpoint, tables, step/role, federation (read this next) -- **[Diagnostic Skills](guide/skills.md)** — `probing skill run` workflows -- [SQL Analytics](guide/sql-analytics.md) - Advanced query techniques -- [Memory Analysis](guide/memory-analysis.md) - Deep dive into memory debugging -- [Debugging Guide](guide/debugging.md) - Expert debugging patterns +For deeper analysis patterns (JOINs across tables, time-series bucketing, +statistical aggregation), read [SQL Analytics](sql-analytics.md). For +multi-node debugging, start with [Distributed](../design/distributed.md). diff --git a/docs/src/quickstart.zh.md b/docs/src/quickstart.zh.md index b0feb341..30ea5370 100644 --- a/docs/src/quickstart.zh.md +++ b/docs/src/quickstart.zh.md @@ -1,94 +1,160 @@ # 快速开始 -通过这个精简的工作流程,快速获得 Probing 的价值。 +本指南带你使用 Probing 检查运行中的训练进程。假定你已完成[安装](installation.zh.md)。 +读完这篇,你将能够附着到进程、抓取调用栈、并对实时性能数据执行 SQL 查询。 -## 5 分钟上手 +## 找到并附着到进程 -### 步骤 1:设置目标进程 - -所有 Probing 命令都需要一个目标端点。将 `$ENDPOINT` 设置为本地进程 ID 或远程地址: +首先找到你想检查的 Python 进程: ```bash -# 本地进程 - 查找并设置 Python 进程 ID -export ENDPOINT=$(pgrep -f "python.*your_script") +# 查找训练进程 +pgrep -f "python.*train" +# → 27891 -# 或者远程进程 -export ENDPOINT=remote-host:8080 +# 或者带更多上下文 +ps aux | grep python | grep -v grep ``` -!!! tip "查找进程" - 使用 `ps aux | grep python` 或 `pgrep -f "python.*train"` 来定位目标进程。 +设置端点环境变量以便后续命令更简洁: -### 步骤 2:附着并探索 +```bash +export ENDPOINT=27891 +``` + +在 Linux 上,可以直接附着到运行中的进程: ```bash -# 连接到进程(仅 Linux) probing $ENDPOINT inject - -# 获取基本进程信息 -probing $ENDPOINT eval "import os, psutil; proc = psutil.Process(); print(f'PID: {os.getpid()}, 内存: {proc.memory_info().rss/1024**2:.1f}MB')" ``` -### 步骤 3:试用三种 CLI 命令 +macOS 或 Windows 不支持 injection——需要以 `PROBING=1 python train.py` 的方式启动 +进程。`inject` 是 Probing 中唯一 Linux 专属的命令,其余所有操作(query、eval、 +backtrace、skill)在服务器运行后全平台通用。 -采集表在钩子运行时自动写入。以下命令读取并与探针交互: +## 首次诊断 + +服务器已在目标进程中运行后,抓取调用栈看看主线程在做什么: ```bash -probing $ENDPOINT query "SELECT name, value FROM information_schema.df_settings LIMIT 5" -probing $ENDPOINT eval "import torch; print(f'CUDA: {torch.cuda.is_available()}')" probing $ENDPOINT backtrace -probing $ENDPOINT query "SELECT func, file, lineno FROM python.backtrace ORDER BY depth LIMIT 5" ``` -详见 **[核心概念](guide/concepts.zh.md)** · **[API 参考](api-reference.zh.md)** +backtrace 会填充 `python.backtrace`——主线程当前调用栈的瞬时视图,混合 Python +和原生帧。用 SQL 查询它: + +```bash +probing $ENDPOINT query " + SELECT func, file, lineno, depth, frame_type + FROM python.backtrace + ORDER BY depth LIMIT 10 +" +``` + +你会看到函数名、源文件路径,以及每帧是 Python 还是原生代码。深度 0 是最内层帧 +(当前正在执行的位置)。 + +如需查看调用栈之外的实时状态,用 `eval` 在目标进程中执行任意 Python: + +```bash +probing $ENDPOINT eval "import torch; print(f'CUDA 可用: {torch.cuda.is_available()}')" +probing $ENDPOINT eval " + import gc, torch + gc.collect() + alloc = torch.cuda.memory_allocated() / 1024**2 + reserved = torch.cuda.memory_reserved() / 1024**2 + print(f'GPU: 已分配 {alloc:.0f}MB, 已预留 {reserved:.0f}MB') +" +``` + +## 三个常见工作流 + +以下是你会经常用到的真实场景。 -## 真实调试场景 +### 训练卡住了 -### 场景 1:训练进程卡住 +最常见的场景:训练突然停止进展。进程还活着但什么也不做。 -**问题**:PyTorch 训练突然停止进展。 +先用 backtrace 看主线程状况: ```bash -# 1. 查看主线程在做什么 probing $ENDPOINT backtrace +probing $ENDPOINT query "SELECT func, file, lineno FROM python.backtrace ORDER BY depth LIMIT 5" +``` + +最内层帧(depth 0)精确显示卡在哪里。如果调用栈显示 `ncclAllReduce` 或 +`torch.distributed` 调用,这是通信 hang。如果是模型 Python 代码,则是计算瓶颈。 -# 2. 检查线程状态 -probing $ENDPOINT eval "import threading; [(t.name, t.is_alive()) for t in threading.enumerate()]" +接着检查线程状态——collective 可能卡住了但其他线程正常: -# 3. 分析堆栈上下文 -probing $ENDPOINT query "SELECT func, file, lineno FROM python.backtrace ORDER BY depth LIMIT 10" +```bash +probing $ENDPOINT eval " + import threading + for t in threading.enumerate(): + print(f'{t.name}: alive={t.is_alive()}, daemon={t.daemon}') +" ``` -### 场景 2:内存泄漏排查 +### GPU 内存在增长 -**问题**:训练过程中内存使用持续增长。 +内存随 step 不断增长——泄漏或累积。查询 torch_trace 表看每步分配趋势: ```bash -# 强制清理并获取当前状态 -probing $ENDPOINT eval "import gc, torch; gc.collect(); torch.cuda.empty_cache()" +probing $ENDPOINT query " + SELECT local_step, AVG(allocated) as avg_mb, MAX(allocated_delta) as max_delta_mb + FROM python.torch_trace + GROUP BY local_step + ORDER BY local_step +" +``` -# 分析分配趋势 -probing $ENDPOINT query "SELECT local_step, AVG(allocated) as avg_memory FROM python.torch_trace GROUP BY local_step ORDER BY local_step" +关注 `allocated_delta` 值不会在 step 间归零——说明迭代之间内存未被释放。 +配合 `eval` 强制 GC 并检查当前状态: + +```bash +probing $ENDPOINT eval "import gc, torch; gc.collect(); torch.cuda.empty_cache()" ``` -### 场景 3:性能瓶颈分析 +如果 GC 和 cache 清理后内存回落,问题是 Python 侧引用环。否则是 CUDA 侧 +泄漏或不断增长的 workspace。 + +### 找出最慢的模块和操作 -**问题**:需要找出哪些模型组件最慢。 +定位模型中哪个部分最耗时: ```bash -# 查找最耗时的操作 probing $ENDPOINT query " -SELECT module, stage, AVG(duration) as avg_duration -FROM python.torch_trace -GROUP BY module, stage -ORDER BY avg_duration DESC -LIMIT 10" + SELECT module, stage, AVG(duration) as avg_duration, COUNT(*) as calls + FROM python.torch_trace + WHERE stage IN ('post forward', 'post step') + GROUP BY module, stage + ORDER BY avg_duration DESC + LIMIT 10 +" ``` +过滤到 `post forward` 和 `post step` stage 可获得执行耗时(`pre` 行携带时间锚点, +`post` 行携带实际 duration)。结果告诉你哪个模块、哪个方向最耗时。 + +分布式训练时,加 `global.` 前缀跨 rank 对比: + +```bash +probing -t query " + SELECT _rank, module, stage, AVG(duration) as avg_duration + FROM global.python.torch_trace + WHERE stage IN ('post forward', 'post step') + GROUP BY _rank, module, stage + ORDER BY avg_duration DESC + LIMIT 20 +" +``` + +`_rank` 是查询时附加的联邦标签——工作原理见[核心概念](guide/concepts.zh.md)。 + ## 下一步 -- **[核心概念](guide/concepts.zh.md)** — 端点、表、step/role、联邦(建议优先阅读) -- **[诊断 Skill](guide/skills.zh.md)** — `probing skill run` 工作流 -- [SQL 分析](guide/sql-analytics.zh.md) - 高级查询技巧 -- [内存分析](guide/memory-analysis.zh.md) - 深入内存调试 -- [调试指南](guide/debugging.zh.md) - 专家级调试模式 +三个命令——`backtrace`、`eval`、`query`——覆盖了日常诊断的大部分场景。详细文档 +见 [API 参考](api-reference.zh.md)。 + +更深入的分析模式(跨表 JOIN、时间分桶、统计聚合)见 [SQL 分析](guide/sql-analytics.zh.md)。 +多节点调试从[分布式](../design/distributed.zh.md)开始。 diff --git a/docs/src/reference/env-vars.md b/docs/src/reference/env-vars.md new file mode 100644 index 00000000..f9576e64 --- /dev/null +++ b/docs/src/reference/env-vars.md @@ -0,0 +1,147 @@ +# Environment Variables + +Complete reference of every environment variable Probing reads. Variables are grouped +by subsystem. + +## Activation + +| Variable | Values | Default | Description | +|----------|--------|---------|-------------| +| `PROBING` | `0`, `1`/`followed`, `2`/`nested`, `regex:PATTERN`, `SCRIPT.py` | unset (disabled) | Controls whether probing activates. `1` activates the current process. `2` activates current + child processes. `regex:PATTERN` activates when the script basename matches. `SCRIPT.py` activates when the script basename equals the value exactly. | +| `PROBING_ORIGINAL` | (set automatically) | — | Backs up the original `PROBING` value before probing modifies it. Set by site_hook; don't set manually. | + +**Child-process propagation:** In `nested` mode, the original `PROBING` value is propagated to children. In `regex:` mode, non-matching children inherit `PROBING=1` so they can be inspected but won't re-trigger site hooks. + +Prefix syntax: `init:SCRIPT+` runs `exec(open(SCRIPT).read())` after activation. + +## Data storage + +| Variable | Default | Description | +|----------|---------|-------------| +| `PROBING_DATA_DIR` | Platform-specific | Root directory for mmap ring buffer files (MEMT tables). Each process creates a subdirectory named by its PID. | +| `PROBING_COLD` | unset | Set to `on` to enable hot-to-cold compaction of mmap tables. | +| `PROBING_COLD_TARGET_MB` | — | Target size per cold chunk after compaction. | +| `PROBING_COLD_MAX_TOTAL_MB` | — | Maximum total size of all cold storage files. | +| `PROBING_COLD_TTL_SECS` | — | Minimum age of a chunk before it's eligible for cold compaction. | +| `PROBING_COLD_POLL_MS` | — | Interval between compaction poll cycles. | +| `PROBING_COLD_MAX_AGE_SECS` | — | Maximum age of a chunk before forced compaction. | +| `PROBING_COLD_DIR` | — | Directory for cold storage files (defaults under `PROBING_DATA_DIR`). | + +## Server & networking + +| Variable | Default | Description | +|----------|---------|-------------| +| `PROBING_PORT` | unset | TCP port for the embedded HTTP server. Set to `RANDOM` for automatic port selection. Required for remote access. | +| `PROBING_SERVER_ADDR` | Inferred from port | Explicit bind address (e.g. `0.0.0.0:8080`). | +| `PROBING_SERVER_ADDRPATTERN` | unset | IP pattern filter for multi-homed hosts. Selects the first matching interface. | +| `PROBING_SERVER_WORKER_THREADS` | auto | Number of Tokio worker threads. | +| `PROBING_CTRL_ROOT` | `/tmp/probing/` | Directory for Unix domain sockets (local PID-based connections). | +| `PROBING_MAX_REQUEST_SIZE` | server default | Maximum HTTP request body size in bytes. | +| `PROBING_MAX_FILE_SIZE` | server default | Maximum file upload size in bytes. | +| `PROBING_ALLOWED_FILE_DIRS` | server default | Colon-separated list of directories allowed for file reads. | +| `PROBING_BASE_PATH` | unset | URL path prefix for reverse proxy deployments (e.g. `/probing`). | +| `PROBING_REMOTE_QUERY_TIMEOUT_SECS` | server default | Timeout for remote fan-out queries (federation). | +| `PROBING_ASSETS_ROOT` | built-in default | Path to the web UI static assets directory. | + +## Authentication + +| Variable | Default | Description | +|----------|---------|-------------| +| `PROBING_AUTH_TOKEN` | unset | Bearer token for HTTP authentication. Required for remote access when set. | +| `PROBING_AUTH_USERNAME` | unset | Username for Basic authentication. | +| `PROBING_AUTH_REALM` | unset | Authentication realm string for Basic auth. | + +## Tracing & spans + +| Variable | Default | Description | +|----------|---------|-------------| +| `PROBING_SPAN_BACKENDS` | `memtable` | Comma-separated list of span backends. Built-in: `memtable` (writes to `python.trace_event`), `logger` (writes to stderr), `otel` (OpenTelemetry export). Custom backends can be registered via `probing.span_backends` entry point. | +| `PROBING_SPAN_LOG_LEVEL` | `INFO` | Log level for the `logger` span backend. | +| `PROBING_SPAN_LOCATION` | unset | Enable automatic location capture via `inspect.stack()` for every span. Adds overhead; use sparingly. | + +## Step coordinates + +| Variable | Default | Description | +|----------|---------|-------------| +| `PROBING_MICRO_BATCHES` | `1` | Initial gradient accumulation factor. Controls `local_step = micro_step // micro_batches`. | +| `PROBING_STEP_BUCKET` | — | Step bucket size for grouped storage. | +| `PROBING_GLOBAL_STEP_BUCKET` | — | Global step bucket size (falls back to `PROBING_STEP_BUCKET`). | + +## Parallel topology (role) + +Set these to describe your training's parallelism configuration. Probing combines +them into a `role` string like `dp=2,pp=1,tp=0`. + +| Variable | Description | +|----------|-------------| +| `PROBING_TP_RANK` / `PROBING_TP_SIZE` | Tensor parallelism rank and size. | +| `PROBING_PP_RANK` / `PROBING_PP_SIZE` | Pipeline parallelism rank and size. | +| `PROBING_DP_RANK` / `PROBING_DP_SIZE` | Data parallelism rank and size. | +| `PROBING_EP_RANK` | Expert parallelism rank. | +| `PROBING_CP_RANK` | Context parallelism rank. | +| `PROBING_ROLE_` | Arbitrary named parallelism dimension (e.g. `PROBING_ROLE_SP=8`). | + +Non-PROBING-prefixed aliases are also recognized for Megatron compatibility: +`TP_RANK`, `TP_SIZE`, `PP_RANK`, `PP_SIZE`, `DP_RANK`, `DP_SIZE`, +`TENSOR_MODEL_PARALLEL_RANK`, `PIPELINE_MODEL_PARALLEL_RANK`, +`DATA_PARALLEL_RANK`, and more. + +## CPU sampling + +| Variable | Default | Description | +|----------|---------|-------------| +| `PROBING_CPU` | enabled | Set to `0`, `off`, `false`, or `no` to disable CPU sampling. | +| `PROBING_CPU_SAMPLE_MS` | `1000` | Sampling interval in milliseconds. Set to `0` to disable. | +| `PROBING_CPU_THREAD_TOP_N` | `8` | Maximum number of threads to sample per process per interval. | + +## GPU sampling + +| Variable | Default | Description | +|----------|---------|-------------| +| `PROBING_GPU` | enabled | Set to `0`, `off`, `false`, or `no` to disable GPU sampling. | +| `PROBING_GPU_SAMPLE_MS` | — | GPU sampling interval in milliseconds. | +| `PROBING_GPU_BACKEND` | `auto` | GPU backend filter: `auto`, `cuda`, `rocm`, `metal`. | + +## NCCL & HCCL + +| Variable | Description | +|----------|-------------| +| `PROBING_NCCL_MOCK` | Enable mock NCCL proxy data for testing without GPUs. | +| `PROBING_NCCL_PROFILER` | Path to the NCCL profiler shared library. | +| `PROBING_HCCL_PROFAPI_REAL` | Path to the real HCCL profapi library (Ascend NPU). | +| `PROBING_HCCL_SHIM` | Path to the HCCL shim library. | +| `PROBING_HCCL_SHIM_LOG` | Enable HCCL shim debug logging. | + +## RDMA + +| Variable | Default | Description | +|----------|---------|-------------| +| `PROBING_RDMA_HCA_NAME` | — | HCA device name filter for RDMA counter sampling. | +| `PROBING_RDMA_SAMPLE_RATE` | — | RDMA counter sampling rate in seconds. | + +## PyTorch integration + +| Variable | Description | +|----------|-------------| +| `PROBING_TORCH_PROFILING` | Set to `on` to activate PyTorch module hooks and write `python.torch_trace`. Required for module timing and memory data. | +| `PROBING_TORCHRUN_CLUSTER` | Enable automatic cluster registration via torchrun. | +| `PROBING_TORCHRUN_STORE_TIMEOUT` | Timeout for torchrun distributed store operations. | + +## Debugging & diagnostics + +| Variable | Default | Description | +|----------|---------|-------------| +| `PROBING_LOGLEVEL` | `info` | Rust-side log level: `trace`, `debug`, `info`, `warn`, `error`. | +| `PROBING_CRASH_BACKTRACE` | enabled | Print a backtrace on fatal signals (SIGSEGV, SIGABRT, etc.). Set to `0` to disable. | +| `PROBING_RUST_BACKTRACE` | — | Rust error backtrace detail (similar to `RUST_BACKTRACE`). | +| `PROBING_SAFE_DEMO` | — | Safe demonstration mode that restricts dangerous operations. | + +## Skill & tool paths + +| Variable | Description | +|----------|-------------| +| `PROBING_PROJECT_SKILLS_DIR` | Per-project skill directory (overrides `$PWD/.probing/skills/`). | +| `PROBING_USER_SKILLS_DIR` | Per-user skill directory (overrides `$HOME/.probing/skills/`). | +| `PROBING_CODE_ROOT` | Root directory for embedded Python monitoring code. | +| `PROBING_CLI_MODE` | Set automatically by the CLI to prevent recursive engine initialization. | +| `PROBING_PYTHON` | Path to the Python interpreter used by the CLI. Set automatically. | diff --git a/docs/src/reference/index.md b/docs/src/reference/index.md index 5db34262..61240e0e 100644 --- a/docs/src/reference/index.md +++ b/docs/src/reference/index.md @@ -6,6 +6,8 @@ Authoritative lookup for SQL schemas, CLI commands, and runtime APIs. |------|----------| | **[SQL Tables](sql-tables.md)** | Physical columns for `python.*`, `cluster.*`, and federation tags (synced with `tables.yaml`) | | **[API Reference](../api-reference.md)** | CLI, in-process Python API, config, [unimplemented APIs](../api-reference.md#unimplemented-apis) | +| **[Environment Variables](env-vars.md)** | Complete reference of all 30+ `PROBING_*` environment variables | +| **[Skill Format](skill-format.md)** | `steps.yaml` and `SKILL.md` specification for diagnostic skill authors | | **[Versions](../versions.md)** | Release compatibility and upgrade notes | For narrative guides, see **[User Guide](../guide/index.md)** and **[Core Concepts](../guide/concepts.md)**. diff --git a/docs/src/reference/index.zh.md b/docs/src/reference/index.zh.md index 5556ae4a..e470c7a3 100644 --- a/docs/src/reference/index.zh.md +++ b/docs/src/reference/index.zh.md @@ -6,6 +6,8 @@ SQL 表结构、CLI 命令与运行时 API 的权威查阅入口。 |------|------| | **[SQL 表目录](sql-tables.zh.md)** | `python.*`、`cluster.*` 物理列与联邦标签(与 `tables.yaml` 同步) | | **[API 参考](../api-reference.zh.md)** | CLI、进程内 Python API、配置、[未实现 API](../api-reference.zh.md#unimplemented-apis) | +| **[环境变量](env-vars.md)** | 全部 30+ 个 `PROBING_*` 环境变量的完整参考 | +| **[Skill 格式规范](skill-format.md)** | 面向诊断 skill 作者的 `steps.yaml` 和 `SKILL.md` 格式规范 | | **[版本兼容性](../versions.zh.md)** | 版本兼容与升级说明 | 叙事性指南见 **[用户指南](../guide/index.zh.md)** 与 **[核心概念](../guide/concepts.zh.md)**。 diff --git a/docs/src/reference/skill-format.md b/docs/src/reference/skill-format.md new file mode 100644 index 00000000..ce5851f1 --- /dev/null +++ b/docs/src/reference/skill-format.md @@ -0,0 +1,326 @@ +# Skill Format + +Reference specification for Probing diagnostic skills. A skill is a YAML + Markdown +package that defines a diagnostic workflow: SQL queries to run, interpretation rules +to apply, and findings to produce. + +Skills live in `skills//` with two files: `steps.yaml` (the executable +workflow) and `SKILL.md` (agent-facing documentation). + +## Directory layout + +``` +skills/ +├── catalog.yaml Master index of all skills +├── semantic/ +│ ├── tables.yaml Table semantic definitions +│ ├── intents.yaml Keyword-to-skill routing +│ └── pages.yaml Page-to-skill suggestions +└── / + ├── steps.yaml Executable diagnostic workflow + └── SKILL.md Agent-facing markdown (frontmatter + body) +``` + +## catalog.yaml + +The master index. Each entry maps a skill ID to its category, priority, and path. + +```yaml +apiVersion: probing.dev/v1 +kind: Catalog +categories: + triage: + label: "Triage" + description: "Quick health checks and first-response diagnostics" +skills: + - id: health_overview + category: triage + priority: 100 + entry: true + description: "Quick system health overview — GPU memory, CPU, process status" + tables: + - python.torch_trace + pages: + - dashboard + related: + - slow_rank + path: health_overview/steps.yaml +``` + +**Fields:** +- `priority`: Higher values surface the skill earlier in listings and LLM selection. +- `entry: true`: This skill is recommended as an entry point for its category. +- `tables`: Tables the skill queries. Used for prerequisite checking. +- `pages`: Web UI pages where this skill is relevant. +- `related`: Skills commonly used before or after this one. + +## steps.yaml + +The executable workflow. A single YAML file with schema version, metadata, steps, +interpretation rules, and summary template. + +### Top-level structure + +```yaml +apiVersion: probing.dev/v1 +kind: Skill + +metadata: + id: my_skill + title: "My Diagnostic Skill" + title_en: "My Diagnostic Skill" + category: performance + tags: [gpu, memory] + triggers: + keywords: + zh: ["GPU", "显存"] + en: ["GPU", "memory"] + docs: | + Detailed description shown in CLI and web agent. + Can span multiple lines. + +spec: + parameters: [...] + requires: {...} + steps: [...] + variables: {...} + interpretation: + rules: [...] + summary_template: "..." + next_steps: [...] +``` + +### Parameters + +Declare typed parameters users can override with `--set key=value`. + +```yaml +parameters: + - name: sample_limit + type: integer + default: 100 + description: "Maximum rows to return per query" + - name: use_global + type: boolean + default: true + description: "Use global.* federation when available" + - name: step_window + type: integer + default: 20 + description: "Number of recent steps to analyze" +``` + +Types: `integer`, `boolean`, `string`. Parameter values are referenced in SQL as +`{param_name}`. + +### Requires + +Prerequisite check. The skill won't run if requirements aren't met. + +```yaml +requires: + any_tables: + - python.torch_trace + - nccl.proxy_ops +``` + +At least one of the listed tables must exist on the target endpoint. + +### Steps + +Ordered list of diagnostic operations. Each step has a type: + +**`sql`** — Run a SQL query. + +```yaml +steps: + - id: check_gpu_mem + title: "GPU Memory Per Step" + type: sql + sql: | + SELECT local_step, AVG(allocated) as avg_mb, MAX(max_allocated) as peak_mb + FROM {var_table} + WHERE local_step > (SELECT MAX(local_step) FROM {var_table}) - {step_window} + GROUP BY local_step + ORDER BY local_step + on_empty: warn + empty_message: "No GPU memory data found. Is PROBING_TORCH_PROFILING=on?" + cluster: false +``` + +Step fields: +- `id`: Unique within the skill. +- `title`: Displayed in output as `## Title`. +- `type`: `sql` (default), `api`, `ui`, or `config`. +- `sql`: The SQL query. `{param}` and `{var_name}` templates are expanded at runtime. +- `on_empty`: `skip` (default), `warn` (show empty message), or `abort` (stop). +- `empty_message`: Shown when the query returns zero rows (if `on_empty` is not `skip`). +- `cluster`: If `true`, uses federation fan-out (`POST /apis/cluster/query`). +- `when`: Optional condition. `"always"` or `"{use_global}"` (runs only when the + boolean variable is true). + +**`api`** — Call an HTTP API on the probing endpoint. + +```yaml + - id: check_nodes + title: "Cluster Peers" + type: api + path: /apis/nodes + method: GET +``` + +**`ui`** — Navigate the web UI to a view (web agent only; skipped in CLI). + +```yaml + - id: show_training + title: "Training Dashboard" + type: ui + view: training +``` + +**`config`** — Read or suggest a config change (CLI skips; web agent presents to user). + +```yaml + - id: check_sampling + title: "Sampling Rate Check" + type: config + config_key: probing.cpu.sample.interval +``` + +### Variables + +Derived template variables. The `comm_table` and `nccl_proxy_table` variables are +pre-defined and resolve based on `use_global`: + +```yaml +variables: + comm_table: "{use_global ? global.python.comm_collective : python.comm_collective}" + nccl_proxy_table: "{use_global ? global.nccl.proxy_ops : nccl.proxy_ops}" +``` + +The system derives these automatically unless overridden. `use_global` is itself a +parameter (default `true`) that auto-detects cluster availability from +`GET /apis/nodes`. + +### Interpretation rules + +Rules evaluate query results and produce severity-graded findings. They run after +all steps complete. + +```yaml +interpretation: + rules: + - id: high_gpu_mem + when: "step:check_gpu_mem | avg(allocated) > 90% * gpu_total" + severity: warning + message: "GPU memory usage ({avg_allocated:.0f}MB) exceeds 90% on rank {_rank}. Consider gradient checkpointing." + - id: mem_leak + when: "step:check_gpu_mem | slope(allocated) > 0" + severity: error + message: "GPU memory growing at {slope_rate:.1f}MB/step. Possible leak detected." +``` + +Rule fields: +- `id`: Unique within the skill. +- `when`: A predicate expression. Format: `step: | `. + Supports `avg`, `max`, `min`, `count`, `slope`, `latest` aggregations, and `>`, + `<`, `>=`, `<=`, `==`, `!=` operators. `*` multiplies by a reference value. +- `severity`: `error`, `warning`, or `info`. +- `message`: Template with `{column}` placeholders filled from the step's results. + +### Summary template + +A template expanded with step result metadata after the run: + +```yaml +summary_template: | + ## Summary + - GPU Memory: {check_gpu_mem.row_count} data points across {step_window} steps + - Slowest collective: {check_collective.top_op} at {check_collective.max_duration_ms:.1f}ms +``` + +Available fill values: `{step_id.row_count}`, `{step_id.max_}`, +`{step_id.min_}`, `{step_id.top_}`. + +### Next steps + +Suggestions shown after the skill completes: + +```yaml +next_steps: + - "If GPU memory exceeds 90%, consider running the memory_leak skill." + - "Check the NCCL culprit/victim skill if collective times are high." +``` + +## SKILL.md + +Agent-facing markdown for Cursor, Claude Code, and Codex integration. YAML +frontmatter is parsed by agent systems; the markdown body is human-readable. + +```markdown +--- +name: my_skill +description: >- + Diagnose my specific issue. +category: performance +tables: [python.torch_trace, python.comm_collective] +tags: [gpu, memory, performance] +keywords: + en: ['slow', 'bottleneck', 'throughput'] + zh: ['慢', '瓶颈', '吞吐'] +parameters: + step_window: { type: integer, default: 20 } +--- + +# My Skill + +Detailed explanation of what this skill does, when to use it, +and how to interpret the results. +``` + +## Skill resolution and overlay + +Skills are loaded from multiple roots in priority order: + +1. Embedded (compiled into the CLI binary at build time) +2. `$HOME/.probing/skills/` — user-level overrides +3. `$PWD/.probing/skills/` — project-level overrides +4. `$PROBING_PROJECT_SKILLS_DIR` — environment override +5. `$PROBING_USER_SKILLS_DIR` — environment override + +Later roots override earlier ones for the same skill ID. The catalog (`catalog.yaml`) +is also merged across roots — entries in higher-priority roots replace embedded +entries with the same ID. + +This means you can override a built-in skill by placing a modified copy in your +project's `.probing/skills/` directory. + +## Installing skills for AI agents + +```bash +# Install all skills for Cursor, Claude Code, and Codex +probing skill install + +# Install specific agents only +probing skill install --agent cursor --agent claude + +# Install to user-level agent directories only +probing skill install --user +``` + +This copies each skill's `SKILL.md` into `~/.cursor/skills/`, +`~/.claude/skills/`, or `~/.agents/skills/` so those agent tools can discover +and execute the skills during conversations. + +## Validation + +```bash +# Validate a single skill +python -m probing.skills validate my_skill + +# Validate all skills +python -m probing.skills validate --all +``` + +The validator checks: missing steps, duplicate step IDs, read-only SQL compliance +(all statements must start with SELECT/WITH/SHOW/DESCRIBE), and missing SKILL.md. diff --git a/docs/src/reference/sql-tables.md b/docs/src/reference/sql-tables.md index a0f89a51..7611b6eb 100644 --- a/docs/src/reference/sql-tables.md +++ b/docs/src/reference/sql-tables.md @@ -1,29 +1,42 @@ # SQL Tables -Authoritative catalog of built-in SQL tables queryable via `probing query` or in-process -`probing.query()`. Kept in sync with `skills/semantic/tables.yaml` (used by -diagnostic skills and the Web Agent). +This page catalogs every built-in SQL table you can query through Probing. It's a +reference — if you're looking for query patterns and how-to, start with [SQL +Analytics](../guide/sql-analytics.md). -Terminology: [Core Concepts](../guide/concepts.md) (endpoint, steps, `role`, federation). - -## Schemas +Each table is backed by an mmap ring buffer (MEMT) or registered dynamically by an +extension crate. Tables live under schema prefixes that reflect their data source: +`python.*` for training and Python runtime data, `cpu.*` / `gpu.*` for host and +device sampling, `cluster.*` for node registry, `nccl.*` for the NCCL profiler +plugin, and `global..
` for federated cross-rank queries. -| Prefix | Meaning | -|--------|---------| -| `python.*` | Python / training probe tables (memtable) | -| `cpu.*`, `gpu.*`, `process.*` | Host / device sampling (extensions) | -| `cluster.*` | Cluster registry | -| `nccl.*` | NCCL profiler plugin (optional) | -| `global..
` | Federated fan-out across registered peers | -| `information_schema.*` | Engine metadata | +The authoritative schema definitions live in `skills/semantic/tables.yaml` (used by +diagnostic skills and the Web Agent). The tables on this page are kept in sync with +that file. -List tables on a live endpoint: +To see what tables are actually available on a live endpoint: ```bash probing $ENDPOINT tables probing $ENDPOINT tables --all ``` +Terminology: [Core Concepts](../guide/concepts.md) (endpoint, steps, `role`, federation). + +## Schema prefixes + +Each schema represents a category of data source. The tables listed below are organized +by these prefixes so you know where to look: + +| Prefix | Data source | +|--------|-------------| +| `python.*` | Training and Python runtime (memtable-backed) | +| `cpu.*`, `gpu.*`, `process.*` | Host and device sampling (extension crates) | +| `cluster.*` | Cluster node registry | +| `nccl.*` | NCCL profiler plugin (optional, cdylib) | +| `global..
` | Federated fan-out across registered peers | +| `information_schema.*` | Engine metadata and configuration | + ## Federation Tables with a **`global_name`** can be queried as `global.` (e.g. diff --git a/docs/src/reference/sql-tables.zh.md b/docs/src/reference/sql-tables.zh.md index 144f5ff0..e012b2c6 100644 --- a/docs/src/reference/sql-tables.zh.md +++ b/docs/src/reference/sql-tables.zh.md @@ -1,28 +1,35 @@ # SQL 表目录 -可通过 `probing query` 或进程内 `probing.query()` 查询的内置 SQL 表权威目录。 -与 `skills/semantic/tables.yaml` 保持同步(诊断 skill 与 Web Agent 使用)。 +本页列出 Probing 中可通过 `probing query` 或 `probing.query()` 查询的所有内置 +SQL 表。这是一份参考手册——如果你需要查询模式和操作指南,从 [SQL 分析](../guide/sql-analytics.zh.md) +开始。 -术语说明见 [核心概念](../guide/concepts.zh.md)。 +每张表由 mmap 环形缓冲(MEMT)承载或由扩展 crate 动态注册。表按 schema 前缀 +组织,反映数据来源:`python.*` 是训练和 Python 运行时数据,`cpu.*` / `gpu.*` +是主机和设备采样,`cluster.*` 是节点注册信息,`nccl.*` 是 NCCL profiler 插件 +输出,`global..
` 是跨 rank 的联邦查询。 -## Schema 前缀 - -| 前缀 | 含义 | -|------|------| -| `python.*` | Python / 训练探针表(memtable) | -| `cpu.*`、`gpu.*`、`process.*` | 主机 / 设备采样(扩展) | -| `cluster.*` | 集群注册表 | -| `nccl.*` | NCCL profiler 插件(可选) | -| `global..
` | 跨已注册节点联邦 fan-out | -| `information_schema.*` | 引擎元数据 | +Schema 定义以 `skills/semantic/tables.yaml` 为权威来源(诊断 skill 与 Web Agent +使用该文件定义),本页与其保持同步。 -列出当前端点上的表: +在真实端点上查看当前可用表: ```bash probing $ENDPOINT tables probing $ENDPOINT tables --all ``` +## Schema 前缀 + +| 前缀 | 数据来源 | +|------|----------| +| `python.*` | 训练和 Python 运行时(memtable) | +| `cpu.*`、`gpu.*`、`process.*` | 主机和设备采样(扩展 crate) | +| `cluster.*` | 集群节点注册表 | +| `nccl.*` | NCCL profiler 插件(可选,cdylib) | +| `global..
` | 跨已注册节点联邦 fan-out | +| `information_schema.*` | 引擎元数据和配置 | + ## 联邦查询 带 **`global_name`** 的表可用 `global.<路径>` 查询(如 `global.python.comm_collective`)。