Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions .github/workflows/_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python.version }}
- name: Install the project
run: uv sync -p ${{ matrix.python.version }} -U --no-dev
- name: Install the project (include dev deps for tests)
run: uv sync -p ${{ matrix.python.version }} -U
- name: Run pytest
run: >
uv run --with pytest --with pytest-cov
pytest --cov -m "not system"
run: uv run pytest -vv --cov -m "not system"
2 changes: 1 addition & 1 deletion src/axtreme/eval/qoi_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def plot_distribution(
for samples_i in samples_list:
_ = ax.hist(samples_i, density=True, alpha=0.3, bins=len(samples_i) // 5 + 1)
_ = ax.set_xlabel("QOI value")
_ = ax.set_title("QOi estimator distibutions")
_ = ax.set_title("QOI estimator distributions")

if brute_force:
_ = ax.axvline(brute_force, c="black", label=f"Brute force ({brute_force:.2f})")
Expand Down
2 changes: 1 addition & 1 deletion src/axtreme/plotting/doe.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def plot_qoi_estimates_from_experiment(
x,
qoi_means - 1.96 * qoi_sems,
qoi_means + 1.96 * qoi_sems,
label=f"90% Confidence Bound {name}",
label=f"95% Confidence Bound {name}",
alpha=0.3,
Comment thread
am-kaiser marked this conversation as resolved.
**kwargs,
)
Expand Down
8 changes: 4 additions & 4 deletions src/axtreme/plotting/gp_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def plot_surface_over_2d_search_space(
colors: A list of colors to use for each function. If None, will use default Plotly colors.
num_points: The number of points in each dimension to evaluate the functions at.
"""
# Extract the parameter names and ranges from the search space
# Extract the two parameters from the search space; only their ranges are used below.
assert len(search_space.parameters) == 2, "Only 2D search spaces are supported for now." # noqa: PLR2004

(x1_name, x1_param), (x2_name, x2_param) = list(search_space.parameters.items())
(_, x1_param), (_, x2_param) = list(search_space.parameters.items())

if not (isinstance(x1_param, RangeParameter) and isinstance(x2_param, RangeParameter)):
msg = f"""Expect search_space.parameters to all be of type RangeParameter.
Expand Down Expand Up @@ -371,11 +371,11 @@ def plot_1d_model(model: SingleTaskGP, X: torch.Tensor | None = None, ax: None |
mean = posterior.mean[:, target_idx]
var = posterior.variance[:, target_idx]
_ = ax.fill_between(X.flatten(), mean - 1.95 * var**0.5, mean + 1.95 * var**0.5, alpha=0.3, color=c)
_ = ax.plot(X, mean, color=c, label=f"gp target {target_idx}")
_ = ax.plot(X, mean, color=c, label=f"GP target {target_idx}")
_ = ax.scatter(train_x.flatten(), train_y, color=c)
_ = ax.errorbar(train_x.flatten(), train_y, 1.95 * train_var**0.5, fmt="o", color=c)

_ = ax.set_title("Gp prediction")
_ = ax.set_title("GP prediction")
return ax


Expand Down
Loading