Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions ogcore/output_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from ogcore.utils import Inequality


INTEREST_RATE_VARS = {"r_gov", "r", "r_p"}


def plot_aggregates(
base_tpi,
base_params,
Expand Down Expand Up @@ -41,7 +44,8 @@ def plot_aggregates(
var_list (list): names of variable to plot
plot_type (string): type of plot, can be:
'pct_diff': plots percentage difference between baseline
and reform ((reform-base)/base)
and reform ((reform-base)/base). For interest rates,
percentage point differences are plotted.
'diff': plots difference between baseline and reform
(reform-base)
'levels': plot variables in model units
Expand Down Expand Up @@ -73,6 +77,8 @@ def plot_aggregates(
# Check that reform included if doing pct_diff or diff plot
if plot_type == "pct_diff" or plot_type == "diff":
assert reform_tpi is not None
only_interest_rates = all(v in INTEREST_RATE_VARS for v in var_list)
has_interest_rates = any(v in INTEREST_RATE_VARS for v in var_list)
fig1, ax1 = plt.subplots()
if not stationarized:
for v in var_list:
Expand All @@ -91,7 +97,12 @@ def plot_aggregates(
plot_var = reform_tpi[v] - base_tpi[v]
else:
plot_var = (reform_tpi[v] - base_tpi[v]) / base_tpi[v]
ylabel = r"Pct. change"
if only_interest_rates:
ylabel = r"Percentage point change"
elif has_interest_rates:
ylabel = r"Pct. or percentage point change"
else:
ylabel = r"Pct. change"
plt.plot(
year_vec,
plot_var[start_index : start_index + num_years_to_plot],
Expand Down Expand Up @@ -119,7 +130,7 @@ def plot_aggregates(
],
label="Reform " + VAR_LABELS[v],
)
ylabel = r"Model Units"
ylabel = r"Rate" if only_interest_rates else r"Model Units"
elif plot_type == "forecast":
# Need reform and baseline to ensure plot makes sense
assert reform_tpi is not None
Expand Down Expand Up @@ -437,8 +448,8 @@ def plot_gdp_ratio(
p (OG-Core Specifications class): parameters object
var_list (list): names of variable to plot
plot_type (string): type of plot, can be:
'diff': plots difference between baseline and reform
(reform-base)
'diff' or 'pct_diff': plots percentage point difference
between baseline and reform ratios to GDP (reform-base)
'levels': plot variables in model units
num_years_to_plot (integer): number of years to include in plot
start_year (integer): year to start plot
Expand All @@ -453,7 +464,7 @@ def plot_gdp_ratio(
assert isinstance(start_year, (int, np.integer))
assert isinstance(num_years_to_plot, int)
assert num_years_to_plot <= base_params.T
if plot_type == "diff":
if plot_type != "levels":
assert reform_tpi is not None
# Make sure both runs cover same time period
if reform_tpi:
Expand Down Expand Up @@ -510,7 +521,10 @@ def plot_gdp_ratio(
plot_var[start_index : start_index + num_years_to_plot],
label=ToGDP_LABELS[v],
)
ylabel = r"Percent of GDP"
if plot_type == "levels":
ylabel = r"Percent of GDP"
else:
ylabel = r"Percentage points of GDP"
# vertical markers at certain years
if vertical_line_years:
for yr in vertical_line_years:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_output_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,38 @@ def test_plot_aggregates_not_a_type(tmpdir):
)


def test_plot_aggregates_interest_rate_pct_diff_label():
fig = output_plots.plot_aggregates(
base_tpi,
base_params,
reform_tpi=reform_tpi,
reform_params=reform_params,
var_list=["r"],
plot_type="pct_diff",
num_years_to_plot=20,
start_year=int(base_params.start_year),
)

assert fig.axes[0].get_ylabel() == "Percentage point change"
plt.close()


def test_plot_aggregates_interest_rate_levels_label():
fig = output_plots.plot_aggregates(
base_tpi,
base_params,
reform_tpi=reform_tpi,
reform_params=reform_params,
var_list=["r"],
plot_type="levels",
num_years_to_plot=20,
start_year=int(base_params.start_year),
)

assert fig.axes[0].get_ylabel() == "Rate"
plt.close()


test_data = [
(base_tpi, base_params, None, None, None, None, "levels"),
(base_tpi, base_params, reform_tpi, reform_params, None, None, "levels"),
Expand Down Expand Up @@ -316,6 +348,20 @@ def test_plot_gdp_ratio_save_fig(tmpdir):
assert isinstance(img, np.ndarray)


def test_plot_gdp_ratio_pct_diff_label():
fig = output_plots.plot_gdp_ratio(
base_tpi,
base_params,
reform_tpi=reform_tpi,
reform_params=reform_params,
start_year=int(base_params.start_year),
plot_type="pct_diff",
)

assert fig.axes[0].get_ylabel() == "Percentage points of GDP"
plt.close()


def test_ability_bar():
fig = output_plots.ability_bar(
base_tpi,
Expand Down