Skip to content

feat(profiler): drive torch.profiler around the training loop#1750

Open
dyurk-lila wants to merge 5 commits into
NovaSky-AI:mainfrom
dyurk-lila:upstream-profiler-driving
Open

feat(profiler): drive torch.profiler around the training loop#1750
dyurk-lila wants to merge 5 commits into
NovaSky-AI:mainfrom
dyurk-lila:upstream-profiler-driving

Conversation

@dyurk-lila

Copy link
Copy Markdown

What

SkyRL constructs a Profiler object on the Megatron policy worker but never drives it.start()/.step()/.stop() are called nowhere in the repo, so torch_profiler_config was effectively dead code. This PR wires torch.profiler up end to end for both Megatron and FSDP, both RL and SFT, with the full torch.profiler surface exposed as config (no hardcoded active=1 / single step).

After this, profiling reduces to setting a couple of config flags and reading the SkyRL-written traces — no worker subclass, no trainer overrides.

How it's driven

  • start_profile / profile_step / stop_profile RPCs on the shared Worker base (worker.py) → on the Ray actor method table of both PolicyWorkers automatically (same pattern as optim_step, set_lr, save_memory_snapshot). No subclass, no ray.remote re-wrap. Dispatched via pass_through thin wrappers in WorkerDispatch.
  • Trainers bracket the loop: start_profile before, one profile_step per global step, stop_profile after (in a finally, so an open trace window is never leaked) — all gated on torch_profiler_config.enable so non-profiling runs dispatch zero extra RPCs.
    • SFT: sft_trainer.py train() / dummy-train loop + one profile_step in train_step.
    • RL: trainer.py train() loop (and the async / fully-async trainers) + one profile_step per global step, so a torch active window spans the whole step (not a single minibatch).

Config — full torch.profiler surface, sane defaults

TorchProfilerConfig (hoisted to PolicyConfig, also wired through the SFT config bridge):

  • Schedule: skip_first, wait, warmup, active, repeattorch.profiler.schedule. This is the "profile N steps, at an interval, repeating M times" knob (repeat=0 = whole run).
  • Capture: activities, record_shapes, profile_memory, with_stack, with_flops, with_modules, export_type.
  • Defaults reproduce the prior effective behavior (CPU+CUDA, shapes+stack). enable=false (default) = unchanged from before.
  • Traces written by tensorboard_trace_handler as HTA/Kineto-friendly *.pt.trace.json (one per active window per rank). save_path defaults to {ckpt_path}/profiler_traces.
  • TorchProfilerConfig.validate() rejects unusable settings up front (called from both the RL and SFT entrypoint validators) so an enabled run fails fast instead of silently degrading.

Scope

Profiles only the policy model's training step (forward/backward + optimizer). In an RL run it does not profile the critic or reference models, and it does not profile generation/inference — only policy training compute on the configured ranks.

Per-kernel summary data path

For downstream attribution tooling that wants per-kernel self-time without re-parsing the trace:

  • on_trace_ready also stashes a pickle-safe per-kernel self-device-time summary for the just-closed window (exact — no cross-stream overlap double-counting).
  • Exposed via Profiler.get_kernel_summary()Worker.dump_profiler_summary() RPC → WorkerDispatch.dump_profiler_summary(model) (returns a per-rank list). SkyRL's own trainers do not call this; the trace files remain the primary deliverable.

Safety

All profiler paths are exception-isolated: a profiler fault disables profiling for the rest of the run rather than crashing it.

Cleanup

Removes the now-redundant Megatron-only torch_profiler_config (it was dead; replaced by the backend-agnostic policy-level one).

Testing

  • New CPU unit tests (tests/backends/skyrl_train/utils/test_profiler.py): schedule-driven trace-file counts (single window, repeat, skip_first deferral), disabled/rank-not-selected no-ops, save_path resolution, activities threading, exception isolation, the kernel-summary path, and the Worker / WorkerDispatch / trainer RPC plumbing.
  • tests/train/test_config.py and tests/train/test_sft_config.py: TorchProfilerConfig.validate() rejects bad configs on both the RL and SFT paths, and torch_profiler_config bridges through build_skyrl_config_for_sft.
  • Locally: ruff/black clean on all touched files; the new profiler tests pass on CPU. GPU/e2e paths left to CI.

🤖 Generated with Claude Code

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a unified, backend-agnostic TorchProfilerConfig and Profiler class to replace the previous Megatron-specific profiler configuration, enabling profiling across both FSDP and Megatron backends. The profiler is now driven by the trainer via start, step, and stop RPCs dispatched to workers, supported by comprehensive validation and unit tests. The review feedback recommends several robustness enhancements: tracking the profiler's running state to prevent PyTorch RuntimeErrors during stop operations, using getattr with fallbacks for backward compatibility with older PyTorch versions, adding null-checks to configuration validation to prevent TypeErrors from YAML inputs, and adding a unit test to verify safe stop behavior.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +58 to 61
self._last_pairs: list = []
self._window_count: int = 0
if not config.enable:
return

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Initialize a _running state variable to track whether the profiler has been successfully started. This ensures we can safely guard calls to stop() and prevent PyTorch from raising a RuntimeError if stop() is called on a non-running profiler.

Suggested change
self._last_pairs: list = []
self._window_count: int = 0
if not config.enable:
return
self._last_pairs: list = []
self._window_count: int = 0
self._running = False
if not config.enable:
return

try:
# ``self_device_time_total`` is torch 2.11's field (the older
# ``self_cuda_time_total`` was removed). Microseconds, self time.
self._last_pairs = [(str(e.key), float(e.self_device_time_total)) for e in prof.key_averages()]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using e.self_device_time_total directly can raise an AttributeError on older PyTorch versions (where it was named self_cuda_time_total or self_cpu_time_total). Using getattr with fallbacks makes the kernel summary extraction robust across different PyTorch versions and profiling activities.

            self._last_pairs = [
                (
                    str(e.key),
                    float(
                        getattr(
                            e,
                            "self_device_time_total",
                            getattr(e, "self_cuda_time_total", getattr(e, "self_cpu_time_total", 0.0)),
                        )
                    ),
                )
                for e in prof.key_averages()
            ]

Comment on lines +155 to +161
def start(self) -> None:
if self.check():
logger.info(f"[Profiler] stopped for rank {self.rank}")
self.prof.stop()

def save(self):
if self.prof is not None and not self.saved:
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
save_file_name = f"/prof_rank_{self.rank}.json"
logger.info(f"[Profiler] Saving trace to {self.save_path + save_file_name}")
self.prof.export_chrome_trace(self.save_path + save_file_name)
self.enable = False
self.saved = True
try:
logger.info(f"[Profiler] started for rank {self.rank}")
self.prof.start()
except Exception as e:
self._disable("start", e)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Set self._running = True upon successful start of the profiler.

Suggested change
def start(self) -> None:
if self.check():
logger.info(f"[Profiler] stopped for rank {self.rank}")
self.prof.stop()
def save(self):
if self.prof is not None and not self.saved:
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
save_file_name = f"/prof_rank_{self.rank}.json"
logger.info(f"[Profiler] Saving trace to {self.save_path + save_file_name}")
self.prof.export_chrome_trace(self.save_path + save_file_name)
self.enable = False
self.saved = True
try:
logger.info(f"[Profiler] started for rank {self.rank}")
self.prof.start()
except Exception as e:
self._disable("start", e)
def start(self) -> None:
if self.check():
try:
logger.info(f"[Profiler] started for rank {self.rank}")
self.prof.start()
self._running = True
except Exception as e:
self._disable("start", e)

Comment on lines +170 to +176
def stop(self) -> None:
if self.check():
logger.info(f"[Profiler] Trace stopped for rank {self.rank}")
self.enable = False
try:
logger.info(f"[Profiler] stopped for rank {self.rank}")
self.prof.stop()
except Exception as e:
self._disable("stop", e)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Only call self.prof.stop() if the profiler is currently running (self._running is True), and reset the running state in a finally block. This prevents PyTorch from raising a RuntimeError if stop() is called on a non-running profiler (e.g., if start() failed or was never called).

Suggested change
def stop(self) -> None:
if self.check():
logger.info(f"[Profiler] Trace stopped for rank {self.rank}")
self.enable = False
try:
logger.info(f"[Profiler] stopped for rank {self.rank}")
self.prof.stop()
except Exception as e:
self._disable("stop", e)
def stop(self) -> None:
if self.check() and getattr(self, "_running", False):
try:
logger.info(f"[Profiler] stopped for rank {self.rank}")
self.prof.stop()
except Exception as e:
self._disable("stop", e)
finally:
self._running = False

Comment on lines +191 to +196
bad_activities = [a for a in self.activities if a.lower() not in TORCH_PROFILER_ACTIVITIES]
if bad_activities:
raise ValueError(
f"invalid `torch_profiler_config.activities` entries {bad_activities}. "
f"Each must be one of {list(TORCH_PROFILER_ACTIVITIES)}."
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Add a check to ensure self.activities is not None before iterating over it. If a user sets activities: null in their YAML configuration, this prevents a TypeError during validation.

        if self.activities is None:
            raise ValueError("`torch_profiler_config.activities` cannot be None.")
        bad_activities = [a for a in self.activities if a.lower() not in TORCH_PROFILER_ACTIVITIES]
        if bad_activities:
            raise ValueError(
                f"invalid `torch_profiler_config.activities` entries {bad_activities}. "
                f"Each must be one of {list(TORCH_PROFILER_ACTIVITIES)}."
            )

Comment on lines +207 to +212
for name in ("skip_first", "wait", "warmup", "repeat"):
value = getattr(self, name)
if value < 0:
raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.")
if self.active < 1:
raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Add checks to ensure that the numeric schedule fields are not None before performing comparison operations. This prevents a TypeError if any of these fields are set to null in the YAML configuration.

Suggested change
for name in ("skip_first", "wait", "warmup", "repeat"):
value = getattr(self, name)
if value < 0:
raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.")
if self.active < 1:
raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.")
for name in ("skip_first", "wait", "warmup", "repeat"):
value = getattr(self, name)
if value is None:
raise ValueError(f"`torch_profiler_config.{name}` cannot be None.")
if value < 0:
raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.")
if self.active is None:
raise ValueError("`torch_profiler_config.active` cannot be None.")
if self.active < 1:
raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.")

Comment on lines +389 to +394
def test_dispatch_to_policy_when_enabled(self):
trainer, calls = self._trainer(enable=True)
trainer._profiler_start()
trainer._profiler_step()
trainer._profiler_stop()
assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Add a unit test to verify that calling stop() on a profiler that was never started is a safe no-op and does not raise any exceptions.

Suggested change
def test_dispatch_to_policy_when_enabled(self):
trainer, calls = self._trainer(enable=True)
trainer._profiler_start()
trainer._profiler_step()
trainer._profiler_stop()
assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")]
def test_dispatch_to_policy_when_enabled(self):
trainer, calls = self._trainer(enable=True)
trainer._profiler_start()
trainer._profiler_step()
trainer._profiler_stop()
assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")]
def test_profiler_stop_without_start_is_noop(tmp_path):
prof = Profiler(_ProfCfg(), default_save_path=str(tmp_path))
# Calling stop() before start() should not raise any error and should be a safe no-op.
prof.stop()
assert getattr(prof, "_running", False) is False

@SumanthRH SumanthRH self-assigned this Jun 4, 2026

@SumanthRH SumanthRH left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dyurk-lila

Copy link
Copy Markdown
Author

Done — bracketed FullCtxTrainer.train() with _profiler_start / _profiler_step / _profiler_stop, matching the wiring in RayPPOTrainer.train(). One profile_step per dummy global step, and stop runs in a finally so the kineto trace window isn't leaked if a step raises. No-op unless torch_profiler_config.enable is set. (e662baf)

@dyurk-lila dyurk-lila requested a review from SumanthRH June 8, 2026 19:50
@SumanthRH

Copy link
Copy Markdown
Member

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a backend-agnostic, highly configurable torch.profiler integration for both FSDP and Megatron backends across RL and SFT training loops. It replaces the previous Megatron-specific profiler config with a unified TorchProfilerConfig and manages the profiler lifecycle within the trainers. Feedback on these changes highlights several robustness and compatibility improvements: wrapping trace exports in try...except blocks to prevent I/O or PyTorch version compatibility crashes, handling cloud storage URIs gracefully by falling back to local paths, and adding stricter validation checks to prevent TypeError exceptions when configuration values are parsed as None or empty.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread skyrl/backends/skyrl_train/utils/profiler.py
Comment thread skyrl/backends/skyrl_train/utils/profiler.py Outdated
Comment thread skyrl/train/config/config.py
Comment thread skyrl/train/config/config.py Outdated
Comment on lines +154 to +155
save_path: Optional[str] = None
"""Trace output dir. Defaults to ``{ckpt_path}/profiler_traces`` when None."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove this default ? The fallback in the case where the ckpt path is a cloud path is non-obvious.

I would rather have this be an explicit parameter that users set

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, the config now requires an explicit save path when profiling is enabled

f"[Profiler] cloud save_path {self.save_path!r} is not writable by torch.profiler; "
f"falling back to local './profiler_traces'."
)
self.save_path = "./profiler_traces"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The directory from which ray workers execute is the working directory in /tmp/ray/session_latest/runtime_resources/working_dir_files/...

using this relative path would mean that the profiler traces get saved in the working directory in /tmp/ray , which is pretty bad! I would recommend making save path for the torch profiler traces explicit

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, the config now requires an explicit save path when profiling is enabled

@SumanthRH

Copy link
Copy Markdown
Member

@dyurk-lila did you test this with colocated training for profiling more than one global step? I'm noticing an issue with the profiler with FSDP backend where it breaks due to offload to CPU.

ray.exceptions.RayTaskError(RuntimeError): �[36mray::skyrl_entrypoint()�[39m (pid=864994, ip=10.1.151.135)
  File "/home/ray/default/dyurk-lila/examples/train_scripts/full_context/main_full_ctx.py", line 53, in skyrl_entrypoint
    exp.run()
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/train/entrypoints/main_base.py", line 403, in run
    asyncio.run(trainer.train())
  File "/home/ray/anaconda3/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/examples/train_scripts/full_context/trainer_full_ctx.py", line 70, in train
    training_input = self.fwd_logprobs_values_reward(training_input)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/.cache/uv/builds-v0/.tmpy8tZf1/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/train/trainer.py", line 1170, in fwd_logprobs_values_reward
    ref_output = self.dispatch.forward("ref", data_fwd_pass)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker_dispatch.py", line 227, in forward
    self._ensure_on_gpu(model, need_optimizer=False, need_model=True)
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker_dispatch.py", line 153, in _ensure_on_gpu
    self._actor_groups[other].offload_to_cpu()
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker.py", line 711, in offload_to_cpu
    return ray.get(refs)
           ^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^
                                  ^^^^^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(RuntimeError): �[36mray::FSDPPolicyWorkerBase.offload_to_cpu()�[39m (pid=873281, ip=10.1.151.135, actor_id=a1ff557caaa2233dd4b5a86498000000, repr=<skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker.FSDPPolicyWorkerBase object at 0x7e79b2286fc0>)
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/utils/__init__.py", line 107, in swap_tensors
    torch._C._swap_tensor_impl(t1, t2)
RuntimeError: Expected no weakrefs to t1's Tensor object but got  8

The above exception was the direct cause of the following exception:

�[36mray::FSDPPolicyWorkerBase.offload_to_cpu()�[39m (pid=873281, ip=10.1.151.135, actor_id=a1ff557caaa2233dd4b5a86498000000, repr=<skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker.FSDPPolicyWorkerBase object at 0x7e79b2286fc0>)
  File "/home/ray/anaconda3/lib/python3.12/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
           ^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker.py", line 332, in offload_to_cpu
    self.strategy.offload_to_cpu(
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/distributed/fsdp_strategy.py", line 123, in offload_to_cpu
    offload_fsdp2_model_to_cpu(model, empty_cache=True)
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/distributed/fsdp_utils.py", line 76, in offload_fsdp2_model_to_cpu
    model.to("cpu", non_blocking=True)
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/transformers/modeling_utils.py", line 3650, in to
    return super().to(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1384, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py", line 626, in _apply
    ret = super()._apply(*args, **kwargs)  # type: ignore[misc]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 934, in _apply
    module._apply(fn)
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 934, in _apply
    module._apply(fn)
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 991, in _apply
    raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Embedding.weight

dyurk-lila and others added 3 commits June 15, 2026 12:28
SkyRL constructed a Profiler object on the Megatron policy worker but never
drove it -- start()/step()/stop() were called nowhere, so torch_profiler_config
was dead code. This wires torch.profiler end to end for both Megatron and FSDP,
both RL and SFT, with the full torch.profiler surface exposed as config.

- start_profile / profile_step / stop_profile RPCs on the shared Worker base,
  dispatched via pass_through thin wrappers in WorkerDispatch.
- Trainers bracket the loop: start before, one profile_step per global step,
  stop after -- all gated on torch_profiler_config.enable so non-profiling runs
  dispatch zero extra RPCs.
- TorchProfilerConfig hoisted to PolicyConfig (also wired through the SFT config
  bridge), exposing schedule (skip_first/wait/warmup/active/repeat) and capture
  (activities/record_shapes/profile_memory/with_stack/with_flops/with_modules/
  export_type) knobs. Defaults reproduce prior effective behavior; enable=false
  by default. Traces written by tensorboard_trace_handler as HTA/Kineto-friendly
  *.pt.trace.json under {ckpt_path}/profiler_traces by default.
- Profiles only the policy model's training step (fwd/bwd + optimizer), not the
  critic/ref models and not generation/inference.
- All profiler paths are exception-isolated: a fault disables profiling for the
  rest of the run rather than crashing it.
- Removes the redundant Megatron-only torch_profiler_config.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Bracket FullCtxTrainer.train() with _profiler_start / _profiler_step /
_profiler_stop, matching the wiring in RayPPOTrainer.train(). One
profile_step per dummy global step; stop runs in a finally so the open
kineto trace window isn't leaked if a step raises. No-op unless
torch_profiler_config.enable is set, so non-profiling runs pay nothing.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ect empty activities

Two robustness fixes from PR review:

- Profiler.__init__ detects a cloud save_path (s3://, gs://, gcs:// via the
  existing is_cloud_path helper) and falls back to ./profiler_traces with a
  warning. save_path commonly defaults to {ckpt_path}/profiler_traces, and
  ckpt_path can be a cloud URI -- which torch.profiler can't write to, so the
  trace would otherwise be silently lost.
- TorchProfilerConfig.validate rejects an empty `activities` list. It passed the
  membership check vacuously, but torch.profiler.profile(activities=[]) records
  nothing -- now it fails fast at startup.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@dyurk-lila dyurk-lila force-pushed the upstream-profiler-driving branch from 0e73e00 to cd63131 Compare June 15, 2026 17:28
…derived default

Per review: remove the non-obvious {ckpt_path}/profiler_traces default and the
silent cloud-URI -> ./profiler_traces fallback. The relative fallback would land
traces in the Ray runtime working dir under /tmp/ray, and the ckpt-derived
default was surprising. save_path is now a required, explicit local path,
validated fail-fast at startup (non-empty + not a cloud URI). Updated docs,
config defaults comment, and tests accordingly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@dyurk-lila

Copy link
Copy Markdown
Author

@dyurk-lila did you test this with colocated training for profiling more than one global step? I'm noticing an issue with the profiler with FSDP backend where it breaks due to offload to CPU.

ray.exceptions.RayTaskError(RuntimeError): �[36mray::skyrl_entrypoint()�[39m (pid=864994, ip=10.1.151.135)
  File "/home/ray/default/dyurk-lila/examples/train_scripts/full_context/main_full_ctx.py", line 53, in skyrl_entrypoint
    exp.run()
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/train/entrypoints/main_base.py", line 403, in run
    asyncio.run(trainer.train())
  File "/home/ray/anaconda3/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/examples/train_scripts/full_context/trainer_full_ctx.py", line 70, in train
    training_input = self.fwd_logprobs_values_reward(training_input)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/.cache/uv/builds-v0/.tmpy8tZf1/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/train/trainer.py", line 1170, in fwd_logprobs_values_reward
    ref_output = self.dispatch.forward("ref", data_fwd_pass)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker_dispatch.py", line 227, in forward
    self._ensure_on_gpu(model, need_optimizer=False, need_model=True)
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker_dispatch.py", line 153, in _ensure_on_gpu
    self._actor_groups[other].offload_to_cpu()
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker.py", line 711, in offload_to_cpu
    return ray.get(refs)
           ^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^
                                  ^^^^^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(RuntimeError): �[36mray::FSDPPolicyWorkerBase.offload_to_cpu()�[39m (pid=873281, ip=10.1.151.135, actor_id=a1ff557caaa2233dd4b5a86498000000, repr=<skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker.FSDPPolicyWorkerBase object at 0x7e79b2286fc0>)
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/utils/__init__.py", line 107, in swap_tensors
    torch._C._swap_tensor_impl(t1, t2)
RuntimeError: Expected no weakrefs to t1's Tensor object but got  8

The above exception was the direct cause of the following exception:

�[36mray::FSDPPolicyWorkerBase.offload_to_cpu()�[39m (pid=873281, ip=10.1.151.135, actor_id=a1ff557caaa2233dd4b5a86498000000, repr=<skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker.FSDPPolicyWorkerBase object at 0x7e79b2286fc0>)
  File "/home/ray/anaconda3/lib/python3.12/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
           ^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker.py", line 332, in offload_to_cpu
    self.strategy.offload_to_cpu(
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/distributed/fsdp_strategy.py", line 123, in offload_to_cpu
    offload_fsdp2_model_to_cpu(model, empty_cache=True)
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/distributed/fsdp_utils.py", line 76, in offload_fsdp2_model_to_cpu
    model.to("cpu", non_blocking=True)
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/transformers/modeling_utils.py", line 3650, in to
    return super().to(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1384, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py", line 626, in _apply
    ret = super()._apply(*args, **kwargs)  # type: ignore[misc]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 934, in _apply
    module._apply(fn)
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 934, in _apply
    module._apply(fn)
  File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 991, in _apply
    raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Embedding.weight

I had not tried this but was able to reproduce the issue, seems to be specifically localized to this path (torch profiling + FSDP + colocated + manual CPU offload). There doesn't seem to be a great workaround for this, so I'll just add a config verifier to fail loudly and descriptively in this case.

… CPU offload

The FSDP2 manual CPU-offload path moves models with model.to("cpu") ->
nn.Module._apply -> torch.utils.swap_tensors, which raises
"RuntimeError: _apply(): Couldn't swap <param>" while torch.profiler holds
weakrefs to those params during an active window (reproduced on CPU with
torch 2.12). That offload only fires mid-loop under colocation.

Map of the crash precondition (all required):
  - strategy == "fsdp"  (Megatron offloads via flat-buffer/.data reassignment,
    never swap_tensors -> immune)
  - fsdp_config.cpu_offload == False  (the default "manual" path; cpu_offload=True
    uses FSDP2-native offload, no manual swap)
  - colocate_all OR colocate_policy_ref  (otherwise no in-loop offload happens)
SFT is single-model and hardcodes colocate_all=False -> never at risk.

Enforce via TorchProfilerConfig.validate(): the RL validator (validate_cfg) now
passes strategy/colocation/cpu_offload context so an incompatible config fails
fast at startup with an actionable message (set cpu_offload=true, disable
colocation, or use Megatron). SFT calls validate() with no context so the
cross-field check is skipped. Added unit + end-to-end tests and documented the
restriction in config.mdx.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@dyurk-lila dyurk-lila requested a review from SumanthRH June 15, 2026 18:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants