feat(profiler): drive torch.profiler around the training loop#1750
feat(profiler): drive torch.profiler around the training loop#1750dyurk-lila wants to merge 5 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a unified, backend-agnostic TorchProfilerConfig and Profiler class to replace the previous Megatron-specific profiler configuration, enabling profiling across both FSDP and Megatron backends. The profiler is now driven by the trainer via start, step, and stop RPCs dispatched to workers, supported by comprehensive validation and unit tests. The review feedback recommends several robustness enhancements: tracking the profiler's running state to prevent PyTorch RuntimeErrors during stop operations, using getattr with fallbacks for backward compatibility with older PyTorch versions, adding null-checks to configuration validation to prevent TypeErrors from YAML inputs, and adding a unit test to verify safe stop behavior.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| self._last_pairs: list = [] | ||
| self._window_count: int = 0 | ||
| if not config.enable: | ||
| return |
There was a problem hiding this comment.
Initialize a _running state variable to track whether the profiler has been successfully started. This ensures we can safely guard calls to stop() and prevent PyTorch from raising a RuntimeError if stop() is called on a non-running profiler.
| self._last_pairs: list = [] | |
| self._window_count: int = 0 | |
| if not config.enable: | |
| return | |
| self._last_pairs: list = [] | |
| self._window_count: int = 0 | |
| self._running = False | |
| if not config.enable: | |
| return |
| try: | ||
| # ``self_device_time_total`` is torch 2.11's field (the older | ||
| # ``self_cuda_time_total`` was removed). Microseconds, self time. | ||
| self._last_pairs = [(str(e.key), float(e.self_device_time_total)) for e in prof.key_averages()] |
There was a problem hiding this comment.
Using e.self_device_time_total directly can raise an AttributeError on older PyTorch versions (where it was named self_cuda_time_total or self_cpu_time_total). Using getattr with fallbacks makes the kernel summary extraction robust across different PyTorch versions and profiling activities.
self._last_pairs = [
(
str(e.key),
float(
getattr(
e,
"self_device_time_total",
getattr(e, "self_cuda_time_total", getattr(e, "self_cpu_time_total", 0.0)),
)
),
)
for e in prof.key_averages()
]| def start(self) -> None: | ||
| if self.check(): | ||
| logger.info(f"[Profiler] stopped for rank {self.rank}") | ||
| self.prof.stop() | ||
|
|
||
| def save(self): | ||
| if self.prof is not None and not self.saved: | ||
| if not os.path.exists(self.save_path): | ||
| os.makedirs(self.save_path) | ||
| save_file_name = f"/prof_rank_{self.rank}.json" | ||
| logger.info(f"[Profiler] Saving trace to {self.save_path + save_file_name}") | ||
| self.prof.export_chrome_trace(self.save_path + save_file_name) | ||
| self.enable = False | ||
| self.saved = True | ||
| try: | ||
| logger.info(f"[Profiler] started for rank {self.rank}") | ||
| self.prof.start() | ||
| except Exception as e: | ||
| self._disable("start", e) |
There was a problem hiding this comment.
Set self._running = True upon successful start of the profiler.
| def start(self) -> None: | |
| if self.check(): | |
| logger.info(f"[Profiler] stopped for rank {self.rank}") | |
| self.prof.stop() | |
| def save(self): | |
| if self.prof is not None and not self.saved: | |
| if not os.path.exists(self.save_path): | |
| os.makedirs(self.save_path) | |
| save_file_name = f"/prof_rank_{self.rank}.json" | |
| logger.info(f"[Profiler] Saving trace to {self.save_path + save_file_name}") | |
| self.prof.export_chrome_trace(self.save_path + save_file_name) | |
| self.enable = False | |
| self.saved = True | |
| try: | |
| logger.info(f"[Profiler] started for rank {self.rank}") | |
| self.prof.start() | |
| except Exception as e: | |
| self._disable("start", e) | |
| def start(self) -> None: | |
| if self.check(): | |
| try: | |
| logger.info(f"[Profiler] started for rank {self.rank}") | |
| self.prof.start() | |
| self._running = True | |
| except Exception as e: | |
| self._disable("start", e) |
| def stop(self) -> None: | ||
| if self.check(): | ||
| logger.info(f"[Profiler] Trace stopped for rank {self.rank}") | ||
| self.enable = False | ||
| try: | ||
| logger.info(f"[Profiler] stopped for rank {self.rank}") | ||
| self.prof.stop() | ||
| except Exception as e: | ||
| self._disable("stop", e) |
There was a problem hiding this comment.
Only call self.prof.stop() if the profiler is currently running (self._running is True), and reset the running state in a finally block. This prevents PyTorch from raising a RuntimeError if stop() is called on a non-running profiler (e.g., if start() failed or was never called).
| def stop(self) -> None: | |
| if self.check(): | |
| logger.info(f"[Profiler] Trace stopped for rank {self.rank}") | |
| self.enable = False | |
| try: | |
| logger.info(f"[Profiler] stopped for rank {self.rank}") | |
| self.prof.stop() | |
| except Exception as e: | |
| self._disable("stop", e) | |
| def stop(self) -> None: | |
| if self.check() and getattr(self, "_running", False): | |
| try: | |
| logger.info(f"[Profiler] stopped for rank {self.rank}") | |
| self.prof.stop() | |
| except Exception as e: | |
| self._disable("stop", e) | |
| finally: | |
| self._running = False |
| bad_activities = [a for a in self.activities if a.lower() not in TORCH_PROFILER_ACTIVITIES] | ||
| if bad_activities: | ||
| raise ValueError( | ||
| f"invalid `torch_profiler_config.activities` entries {bad_activities}. " | ||
| f"Each must be one of {list(TORCH_PROFILER_ACTIVITIES)}." | ||
| ) |
There was a problem hiding this comment.
Add a check to ensure self.activities is not None before iterating over it. If a user sets activities: null in their YAML configuration, this prevents a TypeError during validation.
if self.activities is None:
raise ValueError("`torch_profiler_config.activities` cannot be None.")
bad_activities = [a for a in self.activities if a.lower() not in TORCH_PROFILER_ACTIVITIES]
if bad_activities:
raise ValueError(
f"invalid `torch_profiler_config.activities` entries {bad_activities}. "
f"Each must be one of {list(TORCH_PROFILER_ACTIVITIES)}."
)| for name in ("skip_first", "wait", "warmup", "repeat"): | ||
| value = getattr(self, name) | ||
| if value < 0: | ||
| raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.") | ||
| if self.active < 1: | ||
| raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.") |
There was a problem hiding this comment.
Add checks to ensure that the numeric schedule fields are not None before performing comparison operations. This prevents a TypeError if any of these fields are set to null in the YAML configuration.
| for name in ("skip_first", "wait", "warmup", "repeat"): | |
| value = getattr(self, name) | |
| if value < 0: | |
| raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.") | |
| if self.active < 1: | |
| raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.") | |
| for name in ("skip_first", "wait", "warmup", "repeat"): | |
| value = getattr(self, name) | |
| if value is None: | |
| raise ValueError(f"`torch_profiler_config.{name}` cannot be None.") | |
| if value < 0: | |
| raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.") | |
| if self.active is None: | |
| raise ValueError("`torch_profiler_config.active` cannot be None.") | |
| if self.active < 1: | |
| raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.") |
| def test_dispatch_to_policy_when_enabled(self): | ||
| trainer, calls = self._trainer(enable=True) | ||
| trainer._profiler_start() | ||
| trainer._profiler_step() | ||
| trainer._profiler_stop() | ||
| assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")] |
There was a problem hiding this comment.
Add a unit test to verify that calling stop() on a profiler that was never started is a safe no-op and does not raise any exceptions.
| def test_dispatch_to_policy_when_enabled(self): | |
| trainer, calls = self._trainer(enable=True) | |
| trainer._profiler_start() | |
| trainer._profiler_step() | |
| trainer._profiler_stop() | |
| assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")] | |
| def test_dispatch_to_policy_when_enabled(self): | |
| trainer, calls = self._trainer(enable=True) | |
| trainer._profiler_start() | |
| trainer._profiler_step() | |
| trainer._profiler_stop() | |
| assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")] | |
| def test_profiler_stop_without_start_is_noop(tmp_path): | |
| prof = Profiler(_ProfCfg(), default_save_path=str(tmp_path)) | |
| # Calling stop() before start() should not raise any error and should be a safe no-op. | |
| prof.stop() | |
| assert getattr(prof, "_running", False) is False |
SumanthRH
left a comment
There was a problem hiding this comment.
Can you also make the changes to the full context trainer?
|
Done — bracketed |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a backend-agnostic, highly configurable torch.profiler integration for both FSDP and Megatron backends across RL and SFT training loops. It replaces the previous Megatron-specific profiler config with a unified TorchProfilerConfig and manages the profiler lifecycle within the trainers. Feedback on these changes highlights several robustness and compatibility improvements: wrapping trace exports in try...except blocks to prevent I/O or PyTorch version compatibility crashes, handling cloud storage URIs gracefully by falling back to local paths, and adding stricter validation checks to prevent TypeError exceptions when configuration values are parsed as None or empty.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| save_path: Optional[str] = None | ||
| """Trace output dir. Defaults to ``{ckpt_path}/profiler_traces`` when None.""" |
There was a problem hiding this comment.
Can you remove this default ? The fallback in the case where the ckpt path is a cloud path is non-obvious.
I would rather have this be an explicit parameter that users set
There was a problem hiding this comment.
Done, the config now requires an explicit save path when profiling is enabled
| f"[Profiler] cloud save_path {self.save_path!r} is not writable by torch.profiler; " | ||
| f"falling back to local './profiler_traces'." | ||
| ) | ||
| self.save_path = "./profiler_traces" |
There was a problem hiding this comment.
The directory from which ray workers execute is the working directory in /tmp/ray/session_latest/runtime_resources/working_dir_files/...
using this relative path would mean that the profiler traces get saved in the working directory in /tmp/ray , which is pretty bad! I would recommend making save path for the torch profiler traces explicit
There was a problem hiding this comment.
Done, the config now requires an explicit save path when profiling is enabled
|
@dyurk-lila did you test this with colocated training for profiling more than one global step? I'm noticing an issue with the profiler with FSDP backend where it breaks due to offload to CPU. ray.exceptions.RayTaskError(RuntimeError): �[36mray::skyrl_entrypoint()�[39m (pid=864994, ip=10.1.151.135)
File "/home/ray/default/dyurk-lila/examples/train_scripts/full_context/main_full_ctx.py", line 53, in skyrl_entrypoint
exp.run()
File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/train/entrypoints/main_base.py", line 403, in run
asyncio.run(trainer.train())
File "/home/ray/anaconda3/lib/python3.12/asyncio/runners.py", line 195, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.12/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/examples/train_scripts/full_context/trainer_full_ctx.py", line 70, in train
training_input = self.fwd_logprobs_values_reward(training_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/.cache/uv/builds-v0/.tmpy8tZf1/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/train/trainer.py", line 1170, in fwd_logprobs_values_reward
ref_output = self.dispatch.forward("ref", data_fwd_pass)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker_dispatch.py", line 227, in forward
self._ensure_on_gpu(model, need_optimizer=False, need_model=True)
File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker_dispatch.py", line 153, in _ensure_on_gpu
self._actor_groups[other].offload_to_cpu()
File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker.py", line 711, in offload_to_cpu
return ray.get(refs)
^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(RuntimeError): �[36mray::FSDPPolicyWorkerBase.offload_to_cpu()�[39m (pid=873281, ip=10.1.151.135, actor_id=a1ff557caaa2233dd4b5a86498000000, repr=<skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker.FSDPPolicyWorkerBase object at 0x7e79b2286fc0>)
File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/utils/__init__.py", line 107, in swap_tensors
torch._C._swap_tensor_impl(t1, t2)
RuntimeError: Expected no weakrefs to t1's Tensor object but got 8
The above exception was the direct cause of the following exception:
�[36mray::FSDPPolicyWorkerBase.offload_to_cpu()�[39m (pid=873281, ip=10.1.151.135, actor_id=a1ff557caaa2233dd4b5a86498000000, repr=<skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker.FSDPPolicyWorkerBase object at 0x7e79b2286fc0>)
File "/home/ray/anaconda3/lib/python3.12/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/workers/worker.py", line 332, in offload_to_cpu
self.strategy.offload_to_cpu(
File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/distributed/fsdp_strategy.py", line 123, in offload_to_cpu
offload_fsdp2_model_to_cpu(model, empty_cache=True)
File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/tmp/ray/session_2026-06-06_05-43-45_413839_979/runtime_resources/working_dir_files/_ray_pkg_796f8bbc5193f5b4/skyrl/backends/skyrl_train/distributed/fsdp_utils.py", line 76, in offload_fsdp2_model_to_cpu
model.to("cpu", non_blocking=True)
File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/transformers/modeling_utils.py", line 3650, in to
return super().to(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1384, in to
return self._apply(convert)
^^^^^^^^^^^^^^^^^^^^
File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py", line 626, in _apply
ret = super()._apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 934, in _apply
module._apply(fn)
File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 934, in _apply
module._apply(fn)
File "/home/ray/.cache/uv/builds-v0/.tmpaPoZgZ/lib/python3.12/site-packages/torch/nn/modules/module.py", line 991, in _apply
raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Embedding.weight |
SkyRL constructed a Profiler object on the Megatron policy worker but never
drove it -- start()/step()/stop() were called nowhere, so torch_profiler_config
was dead code. This wires torch.profiler end to end for both Megatron and FSDP,
both RL and SFT, with the full torch.profiler surface exposed as config.
- start_profile / profile_step / stop_profile RPCs on the shared Worker base,
dispatched via pass_through thin wrappers in WorkerDispatch.
- Trainers bracket the loop: start before, one profile_step per global step,
stop after -- all gated on torch_profiler_config.enable so non-profiling runs
dispatch zero extra RPCs.
- TorchProfilerConfig hoisted to PolicyConfig (also wired through the SFT config
bridge), exposing schedule (skip_first/wait/warmup/active/repeat) and capture
(activities/record_shapes/profile_memory/with_stack/with_flops/with_modules/
export_type) knobs. Defaults reproduce prior effective behavior; enable=false
by default. Traces written by tensorboard_trace_handler as HTA/Kineto-friendly
*.pt.trace.json under {ckpt_path}/profiler_traces by default.
- Profiles only the policy model's training step (fwd/bwd + optimizer), not the
critic/ref models and not generation/inference.
- All profiler paths are exception-isolated: a fault disables profiling for the
rest of the run rather than crashing it.
- Removes the redundant Megatron-only torch_profiler_config.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Bracket FullCtxTrainer.train() with _profiler_start / _profiler_step / _profiler_stop, matching the wiring in RayPPOTrainer.train(). One profile_step per dummy global step; stop runs in a finally so the open kineto trace window isn't leaked if a step raises. No-op unless torch_profiler_config.enable is set, so non-profiling runs pay nothing. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ect empty activities
Two robustness fixes from PR review:
- Profiler.__init__ detects a cloud save_path (s3://, gs://, gcs:// via the
existing is_cloud_path helper) and falls back to ./profiler_traces with a
warning. save_path commonly defaults to {ckpt_path}/profiler_traces, and
ckpt_path can be a cloud URI -- which torch.profiler can't write to, so the
trace would otherwise be silently lost.
- TorchProfilerConfig.validate rejects an empty `activities` list. It passed the
membership check vacuously, but torch.profiler.profile(activities=[]) records
nothing -- now it fails fast at startup.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
0e73e00 to
cd63131
Compare
…derived default
Per review: remove the non-obvious {ckpt_path}/profiler_traces default and the
silent cloud-URI -> ./profiler_traces fallback. The relative fallback would land
traces in the Ray runtime working dir under /tmp/ray, and the ckpt-derived
default was surprising. save_path is now a required, explicit local path,
validated fail-fast at startup (non-empty + not a cloud URI). Updated docs,
config defaults comment, and tests accordingly.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
I had not tried this but was able to reproduce the issue, seems to be specifically localized to this path (torch profiling + FSDP + colocated + manual CPU offload). There doesn't seem to be a great workaround for this, so I'll just add a config verifier to fail loudly and descriptively in this case. |
… CPU offload
The FSDP2 manual CPU-offload path moves models with model.to("cpu") ->
nn.Module._apply -> torch.utils.swap_tensors, which raises
"RuntimeError: _apply(): Couldn't swap <param>" while torch.profiler holds
weakrefs to those params during an active window (reproduced on CPU with
torch 2.12). That offload only fires mid-loop under colocation.
Map of the crash precondition (all required):
- strategy == "fsdp" (Megatron offloads via flat-buffer/.data reassignment,
never swap_tensors -> immune)
- fsdp_config.cpu_offload == False (the default "manual" path; cpu_offload=True
uses FSDP2-native offload, no manual swap)
- colocate_all OR colocate_policy_ref (otherwise no in-loop offload happens)
SFT is single-model and hardcodes colocate_all=False -> never at risk.
Enforce via TorchProfilerConfig.validate(): the RL validator (validate_cfg) now
passes strategy/colocation/cpu_offload context so an incompatible config fails
fast at startup with an actionable message (set cpu_offload=true, disable
colocation, or use Megatron). SFT calls validate() with no context so the
cross-field check is skipped. Added unit + end-to-end tests and documented the
restriction in config.mdx.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
What
SkyRL constructs a
Profilerobject on the Megatron policy worker but never drives it —.start()/.step()/.stop()are called nowhere in the repo, sotorch_profiler_configwas effectively dead code. This PR wirestorch.profilerup end to end for both Megatron and FSDP, both RL and SFT, with the fulltorch.profilersurface exposed as config (no hardcodedactive=1/ single step).After this, profiling reduces to setting a couple of config flags and reading the SkyRL-written traces — no worker subclass, no trainer overrides.
How it's driven
start_profile/profile_step/stop_profileRPCs on the sharedWorkerbase (worker.py) → on the Ray actor method table of bothPolicyWorkers automatically (same pattern asoptim_step,set_lr,save_memory_snapshot). No subclass, noray.remotere-wrap. Dispatched viapass_throughthin wrappers inWorkerDispatch.start_profilebefore, oneprofile_stepper global step,stop_profileafter (in afinally, so an open trace window is never leaked) — all gated ontorch_profiler_config.enableso non-profiling runs dispatch zero extra RPCs.sft_trainer.pytrain()/ dummy-train loop + oneprofile_stepintrain_step.trainer.pytrain()loop (and the async / fully-async trainers) + oneprofile_stepper global step, so a torchactivewindow spans the whole step (not a single minibatch).Config — full
torch.profilersurface, sane defaultsTorchProfilerConfig(hoisted toPolicyConfig, also wired through the SFT config bridge):skip_first, wait, warmup, active, repeat→torch.profiler.schedule. This is the "profile N steps, at an interval, repeating M times" knob (repeat=0= whole run).activities,record_shapes,profile_memory,with_stack,with_flops,with_modules,export_type.enable=false(default) = unchanged from before.tensorboard_trace_handleras HTA/Kineto-friendly*.pt.trace.json(one per active window per rank).save_pathdefaults to{ckpt_path}/profiler_traces.TorchProfilerConfig.validate()rejects unusable settings up front (called from both the RL and SFT entrypoint validators) so an enabled run fails fast instead of silently degrading.Scope
Profiles only the policy model's training step (forward/backward + optimizer). In an RL run it does not profile the critic or reference models, and it does not profile generation/inference — only policy training compute on the configured
ranks.Per-kernel summary data path
For downstream attribution tooling that wants per-kernel self-time without re-parsing the trace:
on_trace_readyalso stashes a pickle-safe per-kernel self-device-time summary for the just-closed window (exact — no cross-stream overlap double-counting).Profiler.get_kernel_summary()→Worker.dump_profiler_summary()RPC →WorkerDispatch.dump_profiler_summary(model)(returns a per-rank list). SkyRL's own trainers do not call this; the trace files remain the primary deliverable.Safety
All profiler paths are exception-isolated: a profiler fault disables profiling for the rest of the run rather than crashing it.
Cleanup
Removes the now-redundant Megatron-only
torch_profiler_config(it was dead; replaced by the backend-agnostic policy-level one).Testing
tests/backends/skyrl_train/utils/test_profiler.py): schedule-driven trace-file counts (single window,repeat,skip_firstdeferral), disabled/rank-not-selected no-ops,save_pathresolution, activities threading, exception isolation, the kernel-summary path, and the Worker / WorkerDispatch / trainer RPC plumbing.tests/train/test_config.pyandtests/train/test_sft_config.py:TorchProfilerConfig.validate()rejects bad configs on both the RL and SFT paths, andtorch_profiler_configbridges throughbuild_skyrl_config_for_sft.🤖 Generated with Claude Code