From 78aa7432b232b61e67b45212da147a0606e4ae98 Mon Sep 17 00:00:00 2001 From: GHX5T-SOL <200635707+GHX5T-SOL@users.noreply.github.com> Date: Tue, 19 May 2026 17:42:56 +0200 Subject: [PATCH] fix: clarify output plot unit labels --- ogcore/output_plots.py | 28 +++++++++++++++++------ tests/test_output_plots.py | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/ogcore/output_plots.py b/ogcore/output_plots.py index d3bc2836b..7877d80d6 100644 --- a/ogcore/output_plots.py +++ b/ogcore/output_plots.py @@ -12,6 +12,9 @@ from ogcore.utils import Inequality +INTEREST_RATE_VARS = {"r_gov", "r", "r_p"} + + def plot_aggregates( base_tpi, base_params, @@ -41,7 +44,8 @@ def plot_aggregates( var_list (list): names of variable to plot plot_type (string): type of plot, can be: 'pct_diff': plots percentage difference between baseline - and reform ((reform-base)/base) + and reform ((reform-base)/base). For interest rates, + percentage point differences are plotted. 'diff': plots difference between baseline and reform (reform-base) 'levels': plot variables in model units @@ -73,6 +77,8 @@ def plot_aggregates( # Check that reform included if doing pct_diff or diff plot if plot_type == "pct_diff" or plot_type == "diff": assert reform_tpi is not None + only_interest_rates = all(v in INTEREST_RATE_VARS for v in var_list) + has_interest_rates = any(v in INTEREST_RATE_VARS for v in var_list) fig1, ax1 = plt.subplots() if not stationarized: for v in var_list: @@ -91,7 +97,12 @@ def plot_aggregates( plot_var = reform_tpi[v] - base_tpi[v] else: plot_var = (reform_tpi[v] - base_tpi[v]) / base_tpi[v] - ylabel = r"Pct. change" + if only_interest_rates: + ylabel = r"Percentage point change" + elif has_interest_rates: + ylabel = r"Pct. or percentage point change" + else: + ylabel = r"Pct. change" plt.plot( year_vec, plot_var[start_index : start_index + num_years_to_plot], @@ -119,7 +130,7 @@ def plot_aggregates( ], label="Reform " + VAR_LABELS[v], ) - ylabel = r"Model Units" + ylabel = r"Rate" if only_interest_rates else r"Model Units" elif plot_type == "forecast": # Need reform and baseline to ensure plot makes sense assert reform_tpi is not None @@ -437,8 +448,8 @@ def plot_gdp_ratio( p (OG-Core Specifications class): parameters object var_list (list): names of variable to plot plot_type (string): type of plot, can be: - 'diff': plots difference between baseline and reform - (reform-base) + 'diff' or 'pct_diff': plots percentage point difference + between baseline and reform ratios to GDP (reform-base) 'levels': plot variables in model units num_years_to_plot (integer): number of years to include in plot start_year (integer): year to start plot @@ -453,7 +464,7 @@ def plot_gdp_ratio( assert isinstance(start_year, (int, np.integer)) assert isinstance(num_years_to_plot, int) assert num_years_to_plot <= base_params.T - if plot_type == "diff": + if plot_type != "levels": assert reform_tpi is not None # Make sure both runs cover same time period if reform_tpi: @@ -510,7 +521,10 @@ def plot_gdp_ratio( plot_var[start_index : start_index + num_years_to_plot], label=ToGDP_LABELS[v], ) - ylabel = r"Percent of GDP" + if plot_type == "levels": + ylabel = r"Percent of GDP" + else: + ylabel = r"Percentage points of GDP" # vertical markers at certain years if vertical_line_years: for yr in vertical_line_years: diff --git a/tests/test_output_plots.py b/tests/test_output_plots.py index 35c09e0b6..4be822ada 100644 --- a/tests/test_output_plots.py +++ b/tests/test_output_plots.py @@ -249,6 +249,38 @@ def test_plot_aggregates_not_a_type(tmpdir): ) +def test_plot_aggregates_interest_rate_pct_diff_label(): + fig = output_plots.plot_aggregates( + base_tpi, + base_params, + reform_tpi=reform_tpi, + reform_params=reform_params, + var_list=["r"], + plot_type="pct_diff", + num_years_to_plot=20, + start_year=int(base_params.start_year), + ) + + assert fig.axes[0].get_ylabel() == "Percentage point change" + plt.close() + + +def test_plot_aggregates_interest_rate_levels_label(): + fig = output_plots.plot_aggregates( + base_tpi, + base_params, + reform_tpi=reform_tpi, + reform_params=reform_params, + var_list=["r"], + plot_type="levels", + num_years_to_plot=20, + start_year=int(base_params.start_year), + ) + + assert fig.axes[0].get_ylabel() == "Rate" + plt.close() + + test_data = [ (base_tpi, base_params, None, None, None, None, "levels"), (base_tpi, base_params, reform_tpi, reform_params, None, None, "levels"), @@ -316,6 +348,20 @@ def test_plot_gdp_ratio_save_fig(tmpdir): assert isinstance(img, np.ndarray) +def test_plot_gdp_ratio_pct_diff_label(): + fig = output_plots.plot_gdp_ratio( + base_tpi, + base_params, + reform_tpi=reform_tpi, + reform_params=reform_params, + start_year=int(base_params.start_year), + plot_type="pct_diff", + ) + + assert fig.axes[0].get_ylabel() == "Percentage points of GDP" + plt.close() + + def test_ability_bar(): fig = output_plots.ability_bar( base_tpi,