oadamharoon/stcrlrta


RL-STC + RTA — Supplementary Code

Reinforcement learning of self-triggered control (STC) policies with run-time assurance (RTA) for safety-critical control. This folder contains every script used to produce the experimental results in the paper.

Requirements

A Conda environment specification is provided in requirements.yml:

conda env create -f requirements.yml
conda activate stcrlrta

Top-level dependencies (see requirements.yml for pinned versions):

  • Python 3.11
  • PyTorch
  • Stable-Baselines3 (DQN, PPO, SAC)
  • Gymnasium
  • MLflow (training metrics + run tracking)
  • NumPy, SciPy, Matplotlib, tqdm

GPU is supported (device="auto") but not required.

Code organisation

The codebase is organised into six groups. Where a table has a Paper section column, it indicates where that script's output appears in the paper.

Environments — discrete action (DQN, PPO, Pref-DQN, Lagrangian-DQN)

File                          Environment                   State dim / discrete actions
stc_pendulum_gym_env.py       Inverted pendulum             2 / 168
stc_cartpole_gym_env.py       CartPole-v1                   4 / 200
stc_quadrotor_gym_env.py      Planar (2-D) quadrotor hover  6 / 792
stc_quadrotor3d_gym_env.py    6-DOF (3-D) quadrotor hover   12 / 5000
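Each discrete STC action has to fix both a control input and the next hold time τ. One common way to do that with a single flat Discrete action space is a Cartesian product of a control grid and a τ grid. A minimal sketch under that assumption follows; U_LEVELS, TAU_LEVELS, decode_action and encode_action are hypothetical names, and the 12 × 14 split is chosen only because it gives the pendulum's 168 actions, not because it is the repository's actual grid.

```python
import numpy as np

# HYPOTHETICAL grids for illustration only; the real grids live in the env files.
U_LEVELS = np.linspace(-2.0, 2.0, 12)       # candidate control inputs
TAU_LEVELS = np.linspace(0.05, 0.70, 14)    # candidate hold times tau (s)
N_ACTIONS = len(U_LEVELS) * len(TAU_LEVELS)  # 12 * 14 = 168


def decode_action(a: int) -> tuple[float, float]:
    """Map a flat Discrete index to a (u, tau) pair on the Cartesian grid."""
    if not 0 <= a < N_ACTIONS:
        raise ValueError(f"action {a} outside [0, {N_ACTIONS})")
    i_u, i_tau = np.unravel_index(a, (len(U_LEVELS), len(TAU_LEVELS)))
    return float(U_LEVELS[i_u]), float(TAU_LEVELS[i_tau])


def encode_action(i_u: int, i_tau: int) -> int:
    """Inverse of decode_action, given the two grid indices."""
    return int(np.ravel_multi_index((i_u, i_tau), (len(U_LEVELS), len(TAU_LEVELS))))
```

With NumPy's default C ordering, consecutive indices step through τ first, so index 0 is (−2.0, 0.05 s) and index 14 moves on to the second control level.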

Each environment implements the STC + RTA formulation described in the paper (Sec. Method): the per-step reward combines a stability term +1 − V/V_scale with a communication penalty −w_c/τ (longer hold times τ are penalised less), and a binary RTA shield overrides the agent's action with the LQR backup u = −Kx whenever the linearised one-step prediction violates the safety set.
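A minimal sketch of the reward and shield described above, under stated assumptions: V(x) = xᵀPx is taken as the Lyapunov function, the linearised one-step prediction uses an exact zero-order-hold discretisation of hypothetical linear dynamics (A, B) over the chosen hold time τ, and the safety set is replaced by a hypothetical box |x_i| ≤ x_max_i. Function names are illustrative; the actual logic lives in the stc_*_gym_env.py files.

```python
import numpy as np
from scipy.linalg import expm


def zoh_discretize(A, B, tau):
    """Exact zero-order-hold discretisation of dx/dt = Ax + Bu with u held for tau seconds."""
    n, m = B.shape
    M = np.zeros((n + m, n + m))
    M[:n, :n] = A
    M[:n, n:] = B
    Md = expm(M * tau)  # block structure: [[Ad, Bd], [0, I]]
    return Md[:n, :n], Md[:n, n:]


def stc_reward(x, tau, P, V_scale, w_c):
    """Stability term 1 - V/V_scale plus communication term -w_c/tau.

    ASSUMES V(x) = x^T P x; a longer hold time tau earns a smaller penalty.
    """
    V = float(x @ P @ x)
    return 1.0 - V / V_scale - w_c / tau


def rta_shield(x, u_agent, tau, A, B, K, x_max):
    """Binary shield: keep the agent's input unless the predicted state leaves the safe set.

    The safe set here is a HYPOTHETICAL box |x_i| <= x_max_i standing in for the paper's set.
    Returns (applied input, whether the LQR backup u = -Kx was used).
    """
    Ad, Bd = zoh_discretize(A, B, tau)
    x_pred = Ad @ x + Bd @ u_agent  # linearised one-step prediction over the hold time
    if np.all(np.abs(x_pred) <= x_max):
        return u_agent, False
    return -K @ x, True
```

For a double integrator and τ = 0.1 s, zoh_discretize returns Ad = [[1, 0.1], [0, 1]] and Bd = [[0.005], [0.1]], the familiar closed-form result.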

Environments — continuous action (SAC, Pref-SAC, Lagrangian-SAC)

File                             Environment         Action shape
stc_pendulum_gym_env_cont.py     Pendulum (cont)     (1,)
stc_cartpole_gym_env_cont.py     CartPole (cont)     (1,)
stc_quadrotor_gym_env_cont.py    Planar quadrotor    (2,)
stc_quadrotor3d_gym_env_cont.py  6-DOF quadrotor     (4,)

Continuous variants subclass the discrete envs and replace the discrete action grid with a Box. Physics, reward, RTA, and observation are identical.

Training scripts

File                      Algo                                 Paper section
train_dqn.py              DQN                                  Main results + Ablations
train_ppo.py              PPO                                  Algorithm comparison (discrete)
train_sac.py              SAC                                  Algorithm comparison (continuous) + Quadrotor3D main results
train_dqn_pref.py         Preference-conditioned DQN           Preference-conditioned policies
train_sac_pref.py         Preference-conditioned SAC           Preference-conditioned policies
train_dqn_lagrangian.py   Lagrangian-DQN (CMDP soft penalty)   Lagrangian baseline
train_sac_lagrangian.py   Lagrangian-SAC (CMDP soft penalty)   Lagrangian baseline

All training scripts log per-episode metrics (ep_reward, ep_msi, ep_rta_pct, …) to ./mlruns/. To view them, run mlflow ui and open http://localhost:5000.

Evaluation scripts

File                 Purpose                                                            Paper section
eval.py              Single-policy diagnostic plots (states, τ, reward) + optional MP4  Qualitative trajectories
eval_baselines.py    Fixed-rate LQR (τ_min, τ_match) and classical Lyapunov-STC         Baselines table
eval_pref.py         Sweep a preference-conditioned model over the w_c grid             Pareto front
eval_lagrangian.py   Lagrangian model evaluation (constraint + hard violation rates)    Lagrangian baseline
eval_disturbance.py  Robustness to constant / periodic / impulse disturbances           Disturbance robustness table

Helper / orchestration scripts

File                    Purpose
multi_seed_runs.py      End-to-end orchestrator: trains 3 seeds × 11 w_c × 3 envs + ablations + Lagrangian, then aggregates mean ± std (ddof=1) for the paper tables
compute_theory_vals.py  Computes r*, verifies Assumption 1 (RTA backup safety) for each env
compute_wc_bounds.py    Computes the principled lower bound on w_c from terminal-state Lyapunov values
utils.py                Trajectory norms (P1–P4) and τ_k reconstruction from MSI series
preference_wrapper.py   Gym wrapper appending a normalised w_c to the observation
lagrangian_wrapper.py   Gym wrapper + SB3 callback implementing the CMDP primal-dual update
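The mean ± std (ddof=1) reported by multi_seed_runs.py uses the sample standard deviation, dividing by n − 1 rather than n, which matters with only three seeds. A minimal sketch of that aggregation; mean_std and format_cell are hypothetical helper names, not functions from the repository.

```python
import numpy as np


def mean_std(values):
    """Mean and sample standard deviation (ddof=1) across seeds."""
    arr = np.asarray(values, dtype=float)
    return float(arr.mean()), float(arr.std(ddof=1))


def format_cell(values, digits: int = 3) -> str:
    """Render one table cell as 'mean ± std'."""
    m, s = mean_std(values)
    return f"{m:.{digits}f} ± {s:.{digits}f}"
```

For three seeds with values 1, 2 and 3 this gives 2.000 ± 1.000. With a single seed, ddof=1 yields nan (NumPy also emits a RuntimeWarning), so at least two seeds are needed.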

Shell sweep scripts

File                     Purpose
run_sweep.sh             Launch one training session per w_c in parallel screen sessions
run_sweep_lagrangian.sh  Same as run_sweep.sh but for train_dqn_lagrangian.py
eval_sweep.sh            After a sweep finishes, evaluate every checkpoint (final + best) and emit report.md + raw_output.txt
eval_ppo_sweep.sh        Targeted PPO Pendulum sweep evaluator (timestamp-indexed checkpoints)

Reproducing the paper

Every training and eval script defaults to --seed 0 (Python, NumPy, PyTorch, the env, and SB3 are all reseeded). All single-seed commands below use this default; multi-seed runs are produced by passing --seed 0, --seed 1, and --seed 2 (or, for the DQN tables, by multi_seed_runs.py).
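The reseeding described above can be approximated as follows. This is a sketch, not the scripts' code: seed_everything is a hypothetical name, PyTorch is treated as optional so the snippet runs anywhere, and the env argument is assumed to follow the Gymnasium reset(seed=...) convention.

```python
import random

import numpy as np

try:
    import torch
except ImportError:  # keep the sketch usable without PyTorch installed
    torch = None


def seed_everything(seed: int, env=None) -> None:
    """Reseed Python, NumPy, PyTorch and (optionally) a Gymnasium env."""
    random.seed(seed)
    np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)  # also seeds all CUDA devices
    if env is not None:
        env.reset(seed=seed)  # Gymnasium-style env seeding
        env.action_space.seed(seed)  # makes action_space.sample() reproducible
```

Stable-Baselines3 models additionally take a seed constructor argument (e.g. DQN("MlpPolicy", env, seed=0)), which covers the algorithm's own use of randomness.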

The single command that reproduces all DQN-based tables (main results, no-RTA ablation, Lagrangian baseline, and multi-seed aggregation over seeds 0, 1, and 2) is:

python multi_seed_runs.py

This launches 3 seeds × 11 w_c values × {Pendulum, CartPole, Planar Quadrotor} for the DQN main sweep, plus the No-RTA and Lagrangian-DQN ablations at the canonical w_c, then evaluates each checkpoint over 100 episodes and prints aggregated tables. It does not cover PPO, SAC, Quadrotor3D, the Fixed-τ ablation, domain randomisation, preference-conditioned policies, or SAC-Lagrangian; multi-seed runs for those experiments were produced by re-running the single-seed commands below with --seed 0, --seed 1, and --seed 2. To preview the commands without executing them, run python multi_seed_runs.py --dry_run.
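For a sense of scale, the main DQN sweep is 3 envs × 11 w_c × 3 seeds = 99 training runs. The sketch below enumerates those runs as command lines using only flags shown elsewhere in this README and the standard 11-point w_c grid given under Preference-conditioned policies. It illustrates the grid and is not the orchestrator's code; it also assumes that both ablations are run for each of the three seeds (18 further runs).

```python
from itertools import product

SEEDS = [0, 1, 2]
WC_GRID = [0.25, 0.5, 1, 2, 4, 6, 8, 10, 12, 14, 16]  # standard 11-point grid
CANONICAL_WC = {"pendulum_gym": 8.0, "cartpole_gym": 16.0, "quadrotor_gym": 16.0}
ENVS = list(CANONICAL_WC)


def main_sweep_commands() -> list[str]:
    """One train_dqn.py run per (env, w_c, seed): 3 x 11 x 3 = 99 runs."""
    return [
        f"python train_dqn.py --env {env} --comm_weight {float(wc)} --seed {seed}"
        for env, wc, seed in product(ENVS, WC_GRID, SEEDS)
    ]


def ablation_commands() -> list[str]:
    """No-RTA and Lagrangian-DQN at the canonical w_c (ASSUMED: every seed)."""
    cmds = []
    for env, seed in product(ENVS, SEEDS):
        wc = CANONICAL_WC[env]
        cmds.append(f"python train_dqn.py --env {env} --comm_weight {wc} --no_rta --seed {seed}")
        cmds.append(f"python train_dqn_lagrangian.py --env {env} --comm_weight {wc} --seed {seed}")
    return cmds
```

Under these assumptions the orchestrator trains 99 + 18 = 117 models before evaluation.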

Below are the commands for each individual experiment.

Main results — discrete DQN (Pendulum, CartPole, planar Quadrotor)

# Train across the 11-point w_c grid in parallel screens
./run_sweep.sh --algo dqn --env pendulum_gym
./run_sweep.sh --algo dqn --env cartpole_gym
./run_sweep.sh --algo dqn --env quadrotor_gym

# Evaluate the sweep (final + best for each w_c, 100 eval episodes each)
./eval_sweep.sh --algo dqn --env pendulum_gym
./eval_sweep.sh --algo dqn --env cartpole_gym
./eval_sweep.sh --algo dqn --env quadrotor_gym

Algorithm comparison — PPO and SAC

PPO (Pendulum only):

./run_sweep.sh --algo ppo --env pendulum_gym

./eval_sweep.sh --algo ppo --env pendulum_gym

SAC (Pendulum, CartPole, planar Quadrotor):

./run_sweep.sh --algo sac --env pendulum_gym
./run_sweep.sh --algo sac --env cartpole_gym
./run_sweep.sh --algo sac --env quadrotor_gym

./eval_sweep.sh --algo sac --env pendulum_gym
./eval_sweep.sh --algo sac --env cartpole_gym
./eval_sweep.sh --algo sac --env quadrotor_gym

Quadrotor3D — SAC main results (2 M timesteps, 256-128-128 net)

The 6-DOF quadrotor needs more samples and capacity than the smaller envs. train_sac.py automatically uses a [256, 128, 128] actor/critic when --env quadrotor3d_gym is selected; the default budget is bumped to 2 M:

./run_sweep.sh --algo sac --env quadrotor3d_gym \
    --extra_args "--total_timesteps 2000000"

./eval_sweep.sh --algo sac --env quadrotor3d_gym
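In Stable-Baselines3, a custom SAC network size is passed through policy_kwargs, with separate pi (actor) and qf (critic) layer lists. A sketch of how the [256, 128, 128] choice could be expressed; sac_kwargs is a hypothetical helper, and leaving the other envs at SB3's default architecture is an assumption about train_sac.py.

```python
def sac_kwargs(env_id: str) -> dict:
    """Extra Stable-Baselines3 SAC constructor kwargs for a given env id (sketch)."""
    kwargs: dict = {}
    if env_id == "quadrotor3d_gym":
        # Separate actor (pi) and critic (qf) MLPs, three hidden layers each.
        kwargs["policy_kwargs"] = {
            "net_arch": {"pi": [256, 128, 128], "qf": [256, 128, 128]},
        }
    return kwargs
```

The result would be unpacked into the constructor, e.g. SAC("MlpPolicy", env, **sac_kwargs("quadrotor3d_gym")), followed by model.learn(total_timesteps=2_000_000).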

Ablations

Canonical w_c values per environment (where each best-model checkpoint achieves its highest MSI with lowest RTA activation): pendulum_gym = 8, cartpole_gym = 16, quadrotor_gym = 16, quadrotor3d_gym = 48.

τ_match values per environment (mean MSI of the seed-0 best-model checkpoint at the canonical w_c): pendulum_gym = 0.396 s, cartpole_gym = 0.317 s, quadrotor_gym = 0.259 s, quadrotor3d_gym = 0.302 s.

No-RTA ablation (RTA shield disabled at the canonical w_c per env):

python train_dqn.py --env pendulum_gym  --comm_weight 8.0  --no_rta
python train_dqn.py --env cartpole_gym  --comm_weight 16.0 --no_rta
python train_dqn.py --env quadrotor_gym --comm_weight 16.0 --no_rta
python train_sac.py --env quadrotor3d_gym --comm_weight 48.0 --no_rta \
    --total_timesteps 2000000

python eval.py --env pendulum_gym    --model <ckpt> --no_rta            --n_eval 100
python eval.py --env cartpole_gym    --model <ckpt> --no_rta            --n_eval 100
python eval.py --env quadrotor_gym   --model <ckpt> --no_rta            --n_eval 100
python eval.py --env quadrotor3d_gym --model <ckpt> --algo sac --no_rta --n_eval 100

Fixed-τ ablation (control-only, τ pinned to the τ_match value above):

python train_dqn.py --env pendulum_gym  --comm_weight 8.0  --tau_fixed 0.396
python train_dqn.py --env cartpole_gym  --comm_weight 16.0 --tau_fixed 0.317
python train_dqn.py --env quadrotor_gym --comm_weight 16.0 --tau_fixed 0.259
python train_sac.py --env quadrotor3d_gym --comm_weight 48.0 --tau_fixed 0.302 \
    --total_timesteps 2000000

python eval.py --env pendulum_gym    --model <ckpt> --tau_fixed 0.396           --n_eval 100
python eval.py --env cartpole_gym    --model <ckpt> --tau_fixed 0.317           --n_eval 100
python eval.py --env quadrotor_gym   --model <ckpt> --tau_fixed 0.259           --n_eval 100
python eval.py --env quadrotor3d_gym --model <ckpt> --algo sac --tau_fixed 0.302 --n_eval 100

Domain-randomisation / dynamics-mismatch study (mass ~ U[0.6, 1.4] × nominal; K, P, RTA thresholds remain at nominal):

python train_dqn.py --env pendulum_gym  --comm_weight 8.0  --randomize_mass
python train_dqn.py --env cartpole_gym  --comm_weight 16.0 --randomize_mass
python train_dqn.py --env quadrotor_gym --comm_weight 16.0 --randomize_mass
python train_sac.py --env quadrotor3d_gym --comm_weight 48.0 --randomize_mass \
    --total_timesteps 2000000

# Deploy at off-nominal (0.7×, 1.3×) and nominal (1.0×) mass scales
for scale in 0.7 1.0 1.3; do
  python eval.py --env pendulum_gym    --model <ckpt> --dynamics_scale $scale --n_eval 100
  python eval.py --env cartpole_gym    --model <ckpt> --dynamics_scale $scale --n_eval 100
  python eval.py --env quadrotor_gym   --model <ckpt> --dynamics_scale $scale --n_eval 100
  python eval.py --env quadrotor3d_gym --model <ckpt> --algo sac --dynamics_scale $scale --n_eval 100
done
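The domain-randomisation setting above can be sketched as two small functions: one draws a mass multiplier from U[0.6, 1.4] during training (we assume once per episode reset, which the README does not state), and the other applies a fixed multiplier such as --dynamics_scale 0.7 at evaluation. The key point is that only the simulated plant changes; K, P and the RTA thresholds stay at their nominal values. Function names are hypothetical.

```python
import numpy as np


def sample_mass_scale(rng: np.random.Generator, low: float = 0.6, high: float = 1.4) -> float:
    """Training: draw a fresh mass multiplier ~ U[low, high) (ASSUMED: at every reset)."""
    return float(rng.uniform(low, high))


def perturbed_mass(nominal_mass: float, scale: float) -> float:
    """Plant mass used by the simulator.

    Controller-side quantities (K, P, RTA thresholds) are deliberately NOT recomputed,
    so they remain tuned to the nominal mass; this is what creates the mismatch.
    """
    return nominal_mass * scale
```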

Baselines table — fixed-rate LQR + classical Lyapunov-STC

τ_match is the mean MSI of the seed-0 best-model checkpoint at the canonical w_c.

python eval_baselines.py --env pendulum_gym    --n_eval 100 --tau_match 0.396
python eval_baselines.py --env cartpole_gym    --n_eval 100 --tau_match 0.317
python eval_baselines.py --env quadrotor_gym   --n_eval 100 --tau_match 0.259
python eval_baselines.py --env quadrotor3d_gym --n_eval 100 --tau_match 0.302
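The classical Lyapunov-STC baseline in eval_baselines.py is not spelled out here. For orientation, a minimal sketch of one common form of Lyapunov-based self-triggering: hold u = −Kx and pick the largest τ on a grid for which the predicted V(x) = xᵀPx still decays at least like e^{−λτ}. The decay rate λ (decay), the grid, the zero-order-hold prediction and all names are assumptions, not necessarily what the script does.

```python
import numpy as np
from scipy.linalg import expm


def zoh_discretize(A, B, tau):
    """Exact zero-order-hold discretisation of dx/dt = Ax + Bu over a hold time tau."""
    n, m = B.shape
    M = np.zeros((n + m, n + m))
    M[:n, :n] = A
    M[:n, n:] = B
    Md = expm(M * tau)
    return Md[:n, :n], Md[:n, n:]


def lyapunov_stc_tau(x, A, B, K, P, tau_grid, decay=1.0):
    """Largest tau (scanning the grid upward) with V(x(tau)) <= exp(-decay * tau) * V(x).

    u = -Kx is held for the whole interval and V(x) = x^T P x. The scan stops at the
    first grid point that fails and falls back to the smallest tau if none pass.
    Only grid points are checked, not the continuous interval between them.
    """
    u = -K @ x
    V0 = float(x @ P @ x)
    taus = sorted(tau_grid)
    chosen = taus[0]
    for tau in taus:
        Ad, Bd = zoh_discretize(A, B, tau)
        x_tau = Ad @ x + Bd @ u
        if float(x_tau @ P @ x_tau) <= np.exp(-decay * tau) * V0:
            chosen = tau
        else:
            break
    return chosen
```

For the scalar plant dx/dt = 0.5x + u with K = 2, P = 1 and decay = 1, the held-input solution is x(τ) = (4 − 3e^{0.5τ})x(0), and the rule selects τ = 0.8 s from the grid 0.1, 0.2, …, 1.0 s.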

Lagrangian baseline (CMDP soft penalty)

Lagrangian-DQN on the discrete envs and Lagrangian-SAC on Quadrotor3D, both at the canonical w_c values:

python train_dqn_lagrangian.py --env pendulum_gym  --comm_weight 8.0
python train_dqn_lagrangian.py --env cartpole_gym  --comm_weight 16.0
python train_dqn_lagrangian.py --env quadrotor_gym --comm_weight 16.0
python train_sac_lagrangian.py --env quadrotor3d_gym --comm_weight 48.0 \
    --total_timesteps 2000000

# Evaluate (constraint + hard violation rates)
python eval_lagrangian.py --env pendulum_gym \
    --model checkpoints/pendulum_gym_lagrangian_wc8.0_seed0__<ts>/best_model    --n_eval 100
python eval_lagrangian.py --env cartpole_gym \
    --model checkpoints/cartpole_gym_lagrangian_wc16.0_seed0__<ts>/best_model   --n_eval 100
python eval_lagrangian.py --env quadrotor_gym \
    --model checkpoints/quadrotor_gym_lagrangian_wc16.0_seed0__<ts>/best_model  --n_eval 100
python eval_lagrangian.py --env quadrotor3d_gym \
    --model checkpoints/quadrotor3d_gym_lagrangian_wc48.0_seed0__<ts>/best_model --n_eval 100
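lagrangian_wrapper.py implements a CMDP primal-dual update. A textbook version of that scheme is sketched below: the agent maximises the penalised reward r − λc, and λ is updated by projected gradient ascent on the constraint violation, λ ← max(0, λ + η(Ĵ_c − d)), where Ĵ_c is the average observed cost and d is the cost limit. What counts as the cost c, the step size η (lr), the limit d (cost_limit), and how often the update runs are all assumptions here; the class name is hypothetical.

```python
class LagrangeMultiplier:
    """Projected dual ascent on lambda for a CMDP constraint E[cost] <= cost_limit."""

    def __init__(self, lr: float = 0.1, cost_limit: float = 0.05, lam_init: float = 0.0):
        self.lr = lr
        self.cost_limit = cost_limit
        self.lam = lam_init

    def penalised_reward(self, reward: float, cost: float) -> float:
        """Primal side: the RL agent is trained on r - lambda * c."""
        return reward - self.lam * cost

    def update(self, costs) -> float:
        """Dual side: raise lambda while the average cost exceeds the limit, never below 0."""
        costs = list(costs)
        if not costs:
            return self.lam
        mean_cost = sum(costs) / len(costs)
        self.lam = max(0.0, self.lam + self.lr * (mean_cost - self.cost_limit))
        return self.lam
```

Because the penalty only grows while the constraint is violated on average, this acts as a soft penalty: unlike the RTA shield, it does not prevent individual violations, which is why the evaluation reports hard violation rates separately.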

Preference-conditioned policies

Lower-dim envs (Pendulum, CartPole, planar Quadrotor) — 2 M timesteps over the standard 11-point w_c grid [0.25, 0.5, 1, 2, 4, 6, 8, 10, 12, 14, 16]:

python train_dqn_pref.py --env pendulum_gym  --wc_sampling discrete \
    --total_timesteps 2000000
python train_dqn_pref.py --env cartpole_gym  --wc_sampling discrete \
    --total_timesteps 2000000
python train_dqn_pref.py --env quadrotor_gym --wc_sampling discrete \
    --total_timesteps 2000000
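Preference conditioning, as described for preference_wrapper.py, appends a normalised w_c to the observation so that a single policy can serve every w_c. A minimal NumPy sketch, assuming w_c is drawn uniformly from the grid at each episode start (our reading of --wc_sampling discrete) and normalised by wc_max (16 for the standard grid, 64 for the Quadrotor3D grid). Both choices and the function names are assumptions.

```python
import numpy as np

WC_GRID = [0.25, 0.5, 1, 2, 4, 6, 8, 10, 12, 14, 16]  # standard 11-point grid


def sample_wc(rng: np.random.Generator, grid=WC_GRID) -> float:
    """Draw one w_c uniformly from the discrete grid (ASSUMED: at each episode reset)."""
    return float(rng.choice(grid))


def augment_obs(obs, w_c: float, wc_max: float = 16.0) -> np.ndarray:
    """Append w_c / wc_max to the observation so the policy sees its preference."""
    obs = np.asarray(obs, dtype=np.float32)
    return np.concatenate([obs, np.array([w_c / wc_max], dtype=np.float32)])
```

At evaluation time, a sweep such as eval_pref.py --sweep would then hold w_c fixed per run instead of sampling it.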

python eval_pref.py --env pendulum_gym  \
    --model dqn_pref_pendulum_gym__<ts>  \
    --best_model checkpoints/pref_pendulum_gym__<ts>/best_model  \
    --sweep --save_report --n_eval 100
python eval_pref.py --env cartpole_gym  \
    --model dqn_pref_cartpole_gym__<ts>  \
    --best_model checkpoints/pref_cartpole_gym__<ts>/best_model  \
    --sweep --save_report --n_eval 100
python eval_pref.py --env quadrotor_gym \
    --model dqn_pref_quadrotor_gym__<ts> \
    --best_model checkpoints/pref_quadrotor_gym__<ts>/best_model \
    --sweep --save_report --n_eval 100

6-DOF Quadrotor3D — 4 M timesteps over the high-dim grid [0.25, 1, 4, 8, 16, 24, 32, 40, 48, 56, 64] (eval_pref.py auto-detects quadrotor3d_gym, loads the SAC model, and bumps the sweep grid to wc_max=64):

python train_sac_pref.py --env quadrotor3d_gym --wc_sampling discrete \
    --wc_max 64.0 --total_timesteps 4000000

python eval_pref.py --env quadrotor3d_gym \
    --model sac_pref_quadrotor3d_gym__<ts> \
    --best_model checkpoints/pref_quadrotor3d_gym__<ts>/best_model \
    --sweep --save_report --n_eval 100

Disturbance robustness

Each --*_ckpt is the seed-0 best-model checkpoint at the canonical w_c.

python eval_disturbance.py \
    --pendulum_ckpt    checkpoints/pendulum_gym__<ts>/best_model.zip \
    --cartpole_ckpt    checkpoints/cartpole_gym__<ts>/best_model.zip \
    --quadrotor_ckpt   checkpoints/quadrotor_gym__<ts>/best_model.zip \
    --quadrotor3d_ckpt checkpoints/quadrotor3d_gym__<ts>/best_model.zip \
    --n_eval 100
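The three disturbance types named for eval_disturbance.py can be modelled as simple functions of time d(t), for example added to the applied input or to the dynamics. The magnitudes, frequency, pulse timing, and the way d(t) enters the plant are all hypothetical here; the sketch only illustrates the three shapes.

```python
import math


def constant_disturbance(t: float, magnitude: float = 0.5) -> float:
    """Same offset at every time step."""
    return magnitude


def periodic_disturbance(t: float, amplitude: float = 0.5, freq_hz: float = 0.5) -> float:
    """Sinusoidal disturbance."""
    return amplitude * math.sin(2.0 * math.pi * freq_hz * t)


def impulse_disturbance(t: float, magnitude: float = 2.0, t0: float = 1.0, width: float = 0.05) -> float:
    """Short rectangular pulse of length `width` starting at t0."""
    return magnitude if t0 <= t < t0 + width else 0.0
```

With STC, an impulse that arrives in the middle of a long hold interval τ is only seen at the next update, which is what makes this case interesting for a learned triggering policy.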

Theory verification

python compute_theory_vals.py    # r*, Assumption 1 for each env
python compute_wc_bounds.py      # principled w_c lower bound from V_term

Outputs

Training writes:

  • ./mlruns/ — MLflow tracking (run with mlflow ui to inspect).
  • ./checkpoints/{run_tag}__{timestamp}/ — periodic checkpoints + best_model.zip.
  • ./{algo}_stc_{env}__{timestamp}.zip — final model.
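Several commands above need the <ts> part of a checkpoint path. Because checkpoint directories follow ./checkpoints/{run_tag}__{timestamp}/, a small helper can pick the newest one for a run tag. A sketch: latest_run_dir is a hypothetical name, and choosing by modification time rather than parsing the timestamp is a design choice made here.

```python
from pathlib import Path


def latest_run_dir(run_tag: str, root: str = "checkpoints") -> Path:
    """Return the most recently modified {root}/{run_tag}__{timestamp}/ directory."""
    candidates = [p for p in Path(root).glob(f"{run_tag}__*") if p.is_dir()]
    if not candidates:
        raise FileNotFoundError(f"no '{run_tag}__*' directories under {root!r}")
    # Newest by modification time, so the timestamp format never has to be parsed.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

For example, latest_run_dir("pendulum_gym") / "best_model.zip" gives a value for the --pendulum_ckpt argument of eval_disturbance.py. The double underscore in the pattern keeps pendulum_gym from also matching pendulum_gym_lagrangian_wc8.0_seed0__<ts>.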

eval_sweep.sh writes report.md + raw_output.txt next to each sweep's checkpoint directory; eval.py --render writes per-episode MP4s under ./eval_output/.

About

We use RL to learn both a system's control law and its optimal update times, using self-triggered control (STC) with run-time assurance (RTA).
