diff --git a/.gitignore b/.gitignore
index 1263f60e0..fa6c98c39 100644
--- a/.gitignore
+++ b/.gitignore
@@ -188,3 +188,6 @@ job_queue.db-wal
db.sqlite
trained_models/
+
+
+test_forecasting_models.ipynb
diff --git a/DashAI/back/api/api_v1/endpoints/datasets.py b/DashAI/back/api/api_v1/endpoints/datasets.py
index 562bea733..60677d462 100644
--- a/DashAI/back/api/api_v1/endpoints/datasets.py
+++ b/DashAI/back/api/api_v1/endpoints/datasets.py
@@ -630,6 +630,189 @@ async def get_info(
return info
+@router.get("/{dataset_id}/temporal-info")
+@inject
+async def get_temporal_info(
+ dataset_id: int,
+ timestamp_column: str,
+ session_factory: "sessionmaker" = Depends(lambda: di["session_factory"]),
+):
+ """Get temporal information about a dataset for forecasting tasks.
+
+ This endpoint analyzes a timestamp column to detect frequency, date range,
+ and other temporal characteristics useful for time series forecasting.
+
+ Parameters
+ ----------
+ dataset_id : int
+ ID of the dataset to analyze.
+ timestamp_column : str
+ Name of the column containing timestamps.
+
+ Returns
+ -------
+ dict
+ Dictionary with temporal information including:
+ - frequency_code: Short code (D, H, M, W, A, T)
+ - frequency_label: Human-readable label
+ - frequency_description: Detailed description
+ - start_date: First timestamp in the series
+ - end_date: Last timestamp in the series
+ - total_periods: Number of data points
+ - detected_gaps: Number of missing periods detected
+ """
+ import os
+
+ import pandas as pd
+ import pyarrow.ipc as ipc
+
+ from DashAI.back.core.enums.status import DatasetStatus
+
+ with session_factory() as db:
+ try:
+ dataset = db.get(Dataset, dataset_id)
+ if not dataset:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Dataset not found",
+ )
+
+ if dataset.status != DatasetStatus.FINISHED:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail="Dataset is not in finished state",
+ )
+
+ # Load the dataset
+ dataset_path = f"{dataset.file_path}/dataset"
+ data_filepath = os.path.join(dataset_path, "data.arrow")
+
+ with pa.OSFile(data_filepath, "rb") as source:
+ reader = ipc.open_file(source)
+ table = reader.read_all()
+
+ data_frame = table.to_pandas()
+
+ if timestamp_column not in data_frame.columns:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"Column '{timestamp_column}' not found in dataset",
+ )
+
+ # Convert to datetime
+ try:
+ timestamps = pd.to_datetime(data_frame[timestamp_column])
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"Cannot parse '{timestamp_column}' as datetime: {str(e)}",
+ ) from e
+
+ # Sort and analyze
+ sorted_ts = timestamps.sort_values()
+ diffs = sorted_ts.diff().dropna()
+
+ if len(diffs) == 0:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Not enough data points to detect frequency",
+ )
+
+ # Get most common difference (mode)
+ mode_diff = (
+ diffs.mode().iloc[0] if len(diffs.mode()) > 0 else diffs.median()
+ )
+
+ # Frequency mapping with detailed info
+ frequency_map = {
+ "T": {
+ "code": "T",
+ "label": "Minutely",
+ "description": "Each row represents one minute",
+ "example": "e.g., 10:00, 10:01, 10:02...",
+ },
+ "H": {
+ "code": "H",
+ "label": "Hourly",
+ "description": "Each row represents one hour",
+ "example": "e.g., 10:00, 11:00, 12:00...",
+ },
+ "D": {
+ "code": "D",
+ "label": "Daily",
+ "description": "Each row represents one day",
+ "example": "e.g., Jan 1, Jan 2, Jan 3...",
+ },
+ "W": {
+ "code": "W",
+ "label": "Weekly",
+ "description": "Each row represents one week",
+ "example": "e.g., Week 1, Week 2, Week 3...",
+ },
+ "M": {
+ "code": "M",
+ "label": "Monthly",
+ "description": "Each row represents one month",
+ "example": "e.g., Jan, Feb, Mar...",
+ },
+ "A": {
+ "code": "A",
+ "label": "Yearly",
+ "description": "Each row represents one year",
+ "example": "e.g., 2022, 2023, 2024...",
+ },
+ }
+
+ # Detect frequency
+ if mode_diff >= pd.Timedelta(days=365):
+ freq_code = "A"
+ elif mode_diff >= pd.Timedelta(days=28):
+ freq_code = "M"
+ elif mode_diff >= pd.Timedelta(days=7):
+ freq_code = "W"
+ elif mode_diff >= pd.Timedelta(days=1):
+ freq_code = "D"
+ elif mode_diff >= pd.Timedelta(hours=1):
+ freq_code = "H"
+ else:
+ freq_code = "T"
+
+ freq_info = frequency_map[freq_code]
+
+ # Calculate average difference in human-readable format
+ avg_diff = diffs.mean()
+ if avg_diff >= pd.Timedelta(days=1):
+ avg_diff_str = f"{avg_diff.days} days"
+ elif avg_diff >= pd.Timedelta(hours=1):
+ avg_diff_str = f"{avg_diff.seconds // 3600} hours"
+ else:
+ avg_diff_str = f"{avg_diff.seconds // 60} minutes"
+
+ # Detect gaps (periods where diff is significantly larger than mode)
+ gap_threshold = mode_diff * 1.5
+ gaps = (diffs > gap_threshold).sum()
+
+ return {
+ "frequency_code": freq_info["code"],
+ "frequency_label": freq_info["label"],
+ "frequency_description": freq_info["description"],
+ "frequency_example": freq_info["example"],
+ "average_interval": avg_diff_str,
+ "start_date": sorted_ts.min().isoformat(),
+ "end_date": sorted_ts.max().isoformat(),
+ "total_periods": len(data_frame),
+ "detected_gaps": int(gaps),
+ "timestamp_column": timestamp_column,
+ }
+
+ except exc.SQLAlchemyError as e:
+ logger.exception(e)
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Internal database error",
+ ) from e
+
+
@router.get("/{dataset_id}/model-sessions-exist")
@inject
async def get_model_sessions_exist(
@@ -1264,6 +1447,9 @@ async def get_dataset_file(
col: sliced_batch[col][j].as_py()
for col in sliced_batch.schema.names
}
+ # Use jsonable_encoder to handle Timestamp and other
+ # non-JSON-serializable types
+ row = jsonable_encoder(row)
rows.append(row)
rows_collected += 1
if rows_collected >= page_size:
diff --git a/DashAI/back/api/api_v1/endpoints/explainers.py b/DashAI/back/api/api_v1/endpoints/explainers.py
index a430ef45d..5ace40601 100755
--- a/DashAI/back/api/api_v1/endpoints/explainers.py
+++ b/DashAI/back/api/api_v1/endpoints/explainers.py
@@ -241,8 +241,24 @@ async def upload_global_explainer(
status_code=status.HTTP_404_NOT_FOUND, detail="Run not found"
)
+ # Check if name exists and append suffix if needed
+ base_name = params.name
+ counter = 1
+ new_name = base_name
+
+ while True:
+ existing = db.scalars(
+ select(GlobalExplainer).where(GlobalExplainer.name == new_name)
+ ).first()
+
+ if not existing:
+ break
+
+ counter += 1
+ new_name = f"{base_name}_{counter}"
+
explainer = GlobalExplainer(
- name=params.name,
+ name=new_name,
run_id=params.run_id,
explainer_name=params.explainer_name,
parameters=params.parameters,
diff --git a/DashAI/back/api/api_v1/endpoints/jobs.py b/DashAI/back/api/api_v1/endpoints/jobs.py
index 715a24783..e39676a1b 100644
--- a/DashAI/back/api/api_v1/endpoints/jobs.py
+++ b/DashAI/back/api/api_v1/endpoints/jobs.py
@@ -164,6 +164,8 @@ async def get_job_details(
@router.post("/", status_code=status.HTTP_201_CREATED)
+@router.post("/start", status_code=status.HTTP_201_CREATED)
+@router.post("/start/", status_code=status.HTTP_201_CREATED)
@inject
async def enqueue_job(
request: Request,
diff --git a/DashAI/back/api/api_v1/endpoints/model_sessions.py b/DashAI/back/api/api_v1/endpoints/model_sessions.py
index ae48f8fdb..6bc3b7d80 100644
--- a/DashAI/back/api/api_v1/endpoints/model_sessions.py
+++ b/DashAI/back/api/api_v1/endpoints/model_sessions.py
@@ -157,6 +157,15 @@ async def validate_columns(
validation_response = {}
try:
+ # For ForecastingTask, validate BEFORE prepare_for_task
+ # (column names change after prepare, so validate with originals)
+ if params.task_name == "ForecastingTask":
+ task.validate_dataset_for_task(
+ dataset=minimal_dataset,
+ dataset_name=dataset.name,
+ input_columns=inputs_names,
+ output_columns=outputs_names,
+ )
task.prepare_for_task(
dataset=minimal_dataset,
input_columns=inputs_names,
diff --git a/DashAI/back/api/api_v1/endpoints/predict.py b/DashAI/back/api/api_v1/endpoints/predict.py
index 188a14b02..b7aafc473 100644
--- a/DashAI/back/api/api_v1/endpoints/predict.py
+++ b/DashAI/back/api/api_v1/endpoints/predict.py
@@ -1,5 +1,6 @@
import json
import logging
+import math
from typing import TYPE_CHECKING, Dict, List
from fastapi import APIRouter, Depends, Query, Request, status
@@ -21,6 +22,20 @@
from DashAI.back.dependencies.registry.component_registry import ComponentRegistry
+
+def sanitize_for_json(value):
+ """Convert NaN/Inf float values to None for JSON serialization."""
+ if isinstance(value, dict):
+ return {k: sanitize_for_json(v) for k, v in value.items()}
+ elif isinstance(value, list):
+ return [sanitize_for_json(item) for item in value]
+ elif isinstance(value, float):
+ if math.isnan(value) or math.isinf(value):
+ return None
+ return value
+ return value
+
+
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
diff --git a/DashAI/back/api/api_v1/endpoints/runs.py b/DashAI/back/api/api_v1/endpoints/runs.py
index 58ecc42e4..24202e1b4 100644
--- a/DashAI/back/api/api_v1/endpoints/runs.py
+++ b/DashAI/back/api/api_v1/endpoints/runs.py
@@ -327,22 +327,62 @@ async def delete_run(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Run not found"
)
+
+ # Delete orphan-prone operations that are not linked by FK constraints.
+ global_explainers = (
+ db.query(GlobalExplainer).filter(GlobalExplainer.run_id == run_id).all()
+ )
+ for explainer in global_explainers:
+ if explainer.plot_path and os.path.exists(explainer.plot_path):
+ remove_path(explainer.plot_path)
+ if explainer.explanation_path and os.path.exists(
+ explainer.explanation_path
+ ):
+ remove_path(explainer.explanation_path)
+ db.delete(explainer)
+
+ local_explainers = (
+ db.query(LocalExplainer).filter(LocalExplainer.run_id == run_id).all()
+ )
+ for explainer in local_explainers:
+ if explainer.plots_path and os.path.exists(explainer.plots_path):
+ remove_path(explainer.plots_path)
+ if explainer.explanation_path and os.path.exists(
+ explainer.explanation_path
+ ):
+ remove_path(explainer.explanation_path)
+ db.delete(explainer)
+
+ for prediction in run.predictions:
+ if prediction.results_path and os.path.exists(prediction.results_path):
+ remove_path(prediction.results_path)
+
+ for path in [
+ run.run_path,
+ run.plot_history_path,
+ run.plot_slice_path,
+ run.plot_contour_path,
+ run.plot_importance_path,
+ ]:
+ if path and os.path.exists(path):
+ remove_path(path)
+
db.delete(run)
- if run.status == RunStatus.FINISHED:
- os.remove(run.run_path)
db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
+ except HTTPException:
+ raise
except exc.SQLAlchemyError as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Internal database error",
+ detail=f"Internal database error: {e}",
) from e
- except OSError as e:
+ except (OSError, ValueError) as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to delete directory",
+ detail=f"Failed to delete run resources: {e}",
) from e
diff --git a/DashAI/back/api/api_v1/schemas/job_params.py b/DashAI/back/api/api_v1/schemas/job_params.py
index 869df132e..a9944922f 100644
--- a/DashAI/back/api/api_v1/schemas/job_params.py
+++ b/DashAI/back/api/api_v1/schemas/job_params.py
@@ -8,6 +8,7 @@ class JobParams(BaseModel):
job_type: Literal[
"ModelJob",
+ "ForecastingJob",
"ExplainerJob",
"PredictJob",
"DatasetJob",
diff --git a/DashAI/back/api/api_v1/schemas/predict_params.py b/DashAI/back/api/api_v1/schemas/predict_params.py
index 92c1a78ac..4c8bc9ad9 100644
--- a/DashAI/back/api/api_v1/schemas/predict_params.py
+++ b/DashAI/back/api/api_v1/schemas/predict_params.py
@@ -1,8 +1,18 @@
-from pydantic import BaseModel
+from typing import Optional
+
+from pydantic import BaseModel, Field
class PredictParams(BaseModel):
run_id: int
+ forecast_periods: Optional[int] = Field(
+ default=None,
+ description="Number of future periods to forecast (ForecastingTask only). "
+ "If provided, timestamps will be generated automatically from the last "
+ "training date.",
+ gt=0,
+ le=1000,
+ )
class RenameRequest(BaseModel):
diff --git a/DashAI/back/converters/simple_converters/extend_time_series_converter.py b/DashAI/back/converters/simple_converters/extend_time_series_converter.py
new file mode 100644
index 000000000..25c5b8956
--- /dev/null
+++ b/DashAI/back/converters/simple_converters/extend_time_series_converter.py
@@ -0,0 +1,466 @@
+"""
+Extend Time Series Converter for DashAI.
+
+This converter extends a time series dataset by adding n future timestamps
+with the same period as the original dataset. This is useful for preparing
+datasets for forecasting predictions.
+"""
+
+from typing import Union
+
+import pandas as pd
+
+from DashAI.back.converters.base_converter import BaseConverter
+from DashAI.back.core.schema_fields import (
+ int_field,
+ schema_field,
+ string_field,
+)
+from DashAI.back.core.schema_fields.base_schema import BaseSchema
+from DashAI.back.dataloaders.classes.dashai_dataset import (
+ DashAIDataset,
+ to_dashai_dataset,
+)
+
+
+class ExtendTimeSeriesConverterSchema(BaseSchema):
+ """Schema for ExtendTimeSeriesConverter parameters."""
+
+ n_steps: schema_field(
+ int_field(ge=1, le=100000),
+ 1,
+ "Number of future time steps to add to the dataset (max: 100,000).",
+ ) # type: ignore
+
+ time_column: schema_field(
+ string_field(),
+ "",
+ (
+ "Name of the timestamp column to extend. "
+ "If empty, the converter will auto-detect datetime columns."
+ ),
+ ) # type: ignore
+
+
+class ExtendTimeSeriesConverter(BaseConverter):
+ """
+ Converter that extends a time series dataset with future timestamps.
+
+ This converter adds n new rows to the dataset with timestamps that continue
+ the sequence from the last timestamp in the dataset. The frequency/period
+ is automatically inferred from the existing timestamps.
+
+ All columns except the timestamp column will be filled with NaN values
+ in the new rows, as these are future values to be predicted.
+
+ Example:
+ --------
+ Original dataset:
+ date | y | exog1
+ 2024-01-01 | 10.5 | 100
+ 2024-01-02 | 11.2 | 105
+ 2024-01-03 | 12.1 | 110
+
+ After extending with n_steps=2:
+ date | y | exog1
+ 2024-01-01 | 10.5 | 100
+ 2024-01-02 | 11.2 | 105
+ 2024-01-03 | 12.1 | 110
+ 2024-01-04 | NaN | NaN
+ 2024-01-05 | NaN | NaN
+ """
+
+ SCHEMA = ExtendTimeSeriesConverterSchema
+ DESCRIPTION = (
+ "Extends a time series dataset by adding n future timestamps with the same "
+ "period as the original data. Other columns are filled with NaN values. "
+ "This is useful for preparing datasets for forecasting predictions."
+ )
+ SHORT_DESCRIPTION = "Extends time series with n future timestamps for forecasting."
+ DISPLAY_NAME = "Extend Time Series Converter"
+
+ # Maximum allowed n_steps to prevent memory issues
+ MAX_N_STEPS = 100000
+
+ def __init__(self, n_steps: int = 1, time_column: str = ""):
+ """Initialize the converter with schema parameters."""
+ super().__init__()
+ self.n_steps = n_steps
+ self.time_column = time_column
+
+ # Internal state
+ self._fitted = False
+ self._time_column_validated = ""
+ self._inferred_freq = None
+
+ def _detect_datetime_columns(self, df: pd.DataFrame) -> list[str]:
+ """
+ Detect columns with datetime or timestamp data types.
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ DataFrame to analyze
+
+ Returns
+ -------
+ list[str]
+ List of column names with datetime/timestamp types
+ """
+ datetime_columns = []
+
+ for col in df.columns:
+ # Check if column dtype is datetime
+ if pd.api.types.is_datetime64_any_dtype(df[col]):
+ datetime_columns.append(col)
+ # Try to parse as datetime if it's object/string type
+ elif df[col].dtype == "object":
+ try:
+ # Try to convert a sample to datetime
+ pd.to_datetime(df[col].dropna().head(10), errors="raise")
+ datetime_columns.append(col)
+ except (ValueError, TypeError):
+ # Not a datetime column
+ pass
+
+ return datetime_columns
+
+ def _infer_frequency(self, time_series: pd.Series) -> pd.DateOffset:
+ """
+ Infer the frequency/period of a datetime series.
+
+ Parameters
+ ----------
+ time_series : pd.Series
+ Series with datetime values
+
+ Returns
+ -------
+ pd.DateOffset
+ The inferred frequency
+
+ Raises
+ ------
+ ValueError
+ If frequency cannot be inferred
+ """
+ # Ensure the series is sorted
+ time_series = time_series.sort_values().reset_index(drop=True)
+
+ # Convert to datetime if not already
+ if not pd.api.types.is_datetime64_any_dtype(time_series):
+ time_series = pd.to_datetime(time_series)
+
+ # Remove NaT values
+ time_series = time_series.dropna()
+
+ # Need at least 2 points to infer frequency
+ if len(time_series) < 2:
+ raise ValueError(
+ "Need at least 2 timestamps to infer frequency. "
+ f"Found {len(time_series)} timestamps."
+ )
+
+ # Check for duplicate timestamps
+ duplicates = time_series.duplicated()
+ if duplicates.any():
+ n_duplicates = duplicates.sum()
+ # Warning: we'll still try to infer, but user should know
+ import warnings
+
+ warnings.warn(
+ f"Found {n_duplicates} duplicate timestamp(s) in the time series. "
+ "This may affect frequency inference.",
+ UserWarning,
+ stacklevel=2,
+ )
+ # Remove duplicates for frequency inference
+ time_series = time_series.drop_duplicates()
+
+ # Try using pandas infer_freq on unique sorted values
+ freq = pd.infer_freq(time_series)
+ if freq is not None:
+ return pd.tseries.frequencies.to_offset(freq)
+
+ # If infer_freq fails, calculate the most common difference
+ diffs = time_series.diff().dropna()
+
+ if len(diffs) == 0:
+ raise ValueError("Cannot infer frequency: no time differences found")
+
+ # Filter out zero differences (duplicates that weren't caught)
+ diffs = diffs[diffs != pd.Timedelta(0)]
+
+ if len(diffs) == 0:
+ raise ValueError(
+ "Cannot infer frequency: all timestamps are identical "
+ "after removing duplicates"
+ )
+
+ # Get the most common difference
+ most_common_diff = diffs.mode()
+
+ if len(most_common_diff) == 0:
+ raise ValueError("Cannot infer frequency: no consistent time difference")
+
+ # Check if the frequency is reasonably consistent
+ # If there's high variance, warn the user
+ diff_std = diffs.std()
+ diff_mean = diffs.mean()
+
+ if diff_std / diff_mean > 0.5: # More than 50% coefficient of variation
+ import warnings
+
+ warnings.warn(
+ f"Timestamps have irregular intervals "
+ f"(std/mean = {diff_std / diff_mean:.2f}). "
+ f"Using most common difference: {most_common_diff.iloc[0]}. "
+ "Results may not be accurate for irregular time series.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ # Return the most common difference as a Timedelta
+ return most_common_diff.iloc[0]
+
+ def fit(
+ self, x: DashAIDataset, y: Union[DashAIDataset, None] = None
+ ) -> "ExtendTimeSeriesConverter":
+ """
+ Fit the converter by validating parameters and detecting time column.
+
+ Parameters
+ ----------
+ x : DashAIDataset
+ Input dataset containing the time series data
+ y : DashAIDataset, optional
+ Not used in this converter
+
+ Returns
+ -------
+ ExtendTimeSeriesConverter
+ The fitted converter instance
+
+ Raises
+ ------
+ ValueError
+ If validation fails (missing time column, invalid parameters, etc.)
+ """
+ # Validate parameters
+ if self.n_steps < 1:
+ raise ValueError("n_steps must be a positive integer")
+
+ if self.n_steps > self.MAX_N_STEPS:
+ raise ValueError(
+ f"n_steps cannot exceed {self.MAX_N_STEPS} to prevent memory issues. "
+ f"Requested: {self.n_steps}"
+ )
+
+ # Convert to pandas for analysis
+ data_frame: pd.DataFrame = x.to_pandas() # type: ignore
+
+ # Validate dataset is not empty
+ if len(data_frame) == 0:
+ raise ValueError(
+ "Cannot extend an empty dataset. "
+ "Please provide a dataset with at least 2 rows."
+ )
+
+ # Detect datetime columns
+ datetime_columns = self._detect_datetime_columns(data_frame)
+
+ if len(datetime_columns) == 0:
+ raise ValueError(
+ "No datetime columns found in the dataset. "
+ "Please ensure your dataset has at least one timestamp column."
+ )
+
+ # Determine which time column to use
+ if self.time_column:
+ # User specified a time column
+ if self.time_column not in data_frame.columns:
+ raise ValueError(
+ f"Specified time column '{self.time_column}' not found in dataset. "
+ f"Available columns: {list(data_frame.columns)}"
+ )
+
+ if self.time_column not in datetime_columns:
+ # Try to convert it to datetime
+ try:
+ data_frame[self.time_column] = pd.to_datetime(
+ data_frame[self.time_column]
+ )
+ self._time_column_validated = self.time_column
+ except (ValueError, TypeError) as e:
+ raise ValueError(
+ f"Column '{self.time_column}' cannot be converted "
+ f"to datetime: {e}"
+ ) from e
+ else:
+ self._time_column_validated = self.time_column
+ else:
+ # Auto-detect time column
+ if len(datetime_columns) > 1:
+ raise ValueError(
+ f"Multiple datetime columns found: {datetime_columns}. "
+ "Please specify which one to use with the 'time_column' parameter."
+ )
+ self._time_column_validated = datetime_columns[0]
+
+ # Infer the frequency
+ time_series = data_frame[self._time_column_validated]
+
+ # Convert to datetime if needed
+ if not pd.api.types.is_datetime64_any_dtype(time_series):
+ time_series = pd.to_datetime(time_series)
+
+ try:
+ self._inferred_freq = self._infer_frequency(time_series)
+ except ValueError as e:
+ raise ValueError(
+ f"Failed to infer frequency for time column "
+ f"'{self._time_column_validated}': {e}"
+ ) from e
+
+ self._fitted = True
+ return self
+
+ def transform(
+ self, x: DashAIDataset, y: Union[DashAIDataset, None] = None
+ ) -> DashAIDataset:
+ """
+ Transform the dataset by adding n future timestamps.
+
+ Parameters
+ ----------
+ x : DashAIDataset
+ Input dataset to transform
+ y : DashAIDataset, optional
+ Not used in this converter
+
+ Returns
+ -------
+ DashAIDataset
+ Extended dataset with n additional rows containing future timestamps
+
+ Raises
+ ------
+ ValueError
+ If converter is not fitted or transformation fails
+ """
+ if not self._fitted:
+ raise ValueError("Converter must be fitted before transform")
+
+ # Convert to pandas
+ data_frame: pd.DataFrame = x.to_pandas() # type: ignore
+
+ # Verify time column still exists
+ if self._time_column_validated not in data_frame.columns:
+ raise ValueError(
+ f"Time column '{self._time_column_validated}' not found "
+ f"in transform dataset"
+ )
+
+ # Convert time column to datetime if needed
+ if not pd.api.types.is_datetime64_any_dtype(
+ data_frame[self._time_column_validated]
+ ):
+ data_frame[self._time_column_validated] = pd.to_datetime(
+ data_frame[self._time_column_validated]
+ )
+
+ # Get the last timestamp
+ last_timestamp = data_frame[self._time_column_validated].max()
+
+ # Validate last_timestamp is not NaT
+ if pd.isna(last_timestamp):
+ raise ValueError(
+ f"Cannot extend time series: all timestamps in column "
+ f"'{self._time_column_validated}' are NaT (Not a Time)"
+ )
+
+ # Generate future timestamps
+ future_timestamps = []
+ current_timestamp = last_timestamp
+
+ try:
+ for _i in range(self.n_steps):
+ current_timestamp = current_timestamp + self._inferred_freq
+ future_timestamps.append(current_timestamp)
+ except (OverflowError, ValueError) as e:
+ raise ValueError(
+ f"Error generating future timestamp at step "
+ f"{_i + 1}/{self.n_steps}: {e}. "
+ "This might be due to timestamp overflow or invalid frequency."
+ ) from e
+
+ # Create new rows with future timestamps
+ future_rows = []
+ for future_ts in future_timestamps:
+ # Create a row with NaN for all columns except timestamp
+ new_row = dict.fromkeys(data_frame.columns)
+ new_row[self._time_column_validated] = future_ts
+ future_rows.append(new_row)
+
+ # Create DataFrame from future rows
+ future_df = pd.DataFrame(future_rows)
+
+ # Ensure the timestamp column has the same dtype
+ future_df[self._time_column_validated] = pd.to_datetime(
+ future_df[self._time_column_validated]
+ )
+
+ # Align column order with original dataframe
+ future_df = future_df[data_frame.columns]
+
+ # Preserve original data types as much as possible
+ # for non-timestamp columns
+ for col in data_frame.columns:
+ if col != self._time_column_validated:
+ # Try to maintain the original dtype
+ # (will be nullable version due to NaN)
+ try:
+ original_dtype = data_frame[col].dtype
+ # For numeric types, pandas will handle
+ # the NaN conversion automatically
+ if pd.api.types.is_numeric_dtype(original_dtype):
+ # Let pandas handle it naturally
+ # (int -> float for NaN compatibility)
+ pass
+ elif pd.api.types.is_datetime64_any_dtype(original_dtype):
+ future_df[col] = pd.to_datetime(future_df[col])
+ except Exception:
+ # If conversion fails, keep as is
+ # (likely already None/NaN)
+ pass
+
+ # Concatenate original data with future data
+ try:
+ extended_df = pd.concat([data_frame, future_df], ignore_index=True)
+ except Exception as e:
+ raise ValueError(
+ f"Error concatenating original and extended data: {e}. "
+ "This might be due to incompatible data types."
+ ) from e
+
+ # Validate the extended dataframe
+ if len(extended_df) != len(data_frame) + self.n_steps:
+ raise ValueError(
+ f"Extended dataset has unexpected number of rows. "
+ f"Expected: {len(data_frame) + self.n_steps}, "
+ f"Got: {len(extended_df)}"
+ )
+
+ # Convert back to DashAIDataset
+ return to_dashai_dataset(extended_df)
+
+ def changes_row_count(self) -> bool:
+ """
+ Indicates that this converter changes the number of rows.
+
+ Returns
+ -------
+ bool
+ True, as new rows with future timestamps are added
+ """
+ return True
diff --git a/DashAI/back/converters/simple_converters/time_series_window_converter.py b/DashAI/back/converters/simple_converters/time_series_window_converter.py
new file mode 100644
index 000000000..3a1dfee40
--- /dev/null
+++ b/DashAI/back/converters/simple_converters/time_series_window_converter.py
@@ -0,0 +1,233 @@
+"""
+Time Series Window Converter for DashAI.
+
+This converter transforms time series data into a tabular regression format
+by creating lag features and target columns with fixed horizons.
+"""
+
+from typing import Union
+
+import pandas as pd
+
+from DashAI.back.converters.base_converter import BaseConverter
+from DashAI.back.core.schema_fields import (
+ int_field,
+ schema_field,
+ string_field,
+)
+from DashAI.back.core.schema_fields.base_schema import BaseSchema
+from DashAI.back.dataloaders.classes.dashai_dataset import (
+ DashAIDataset,
+ to_dashai_dataset,
+)
+
+
+class TimeSeriesWindowConverterSchema(BaseSchema):
+ """Schema for TimeSeriesWindowConverter parameters."""
+
+ window_size: schema_field(
+ int_field(ge=1),
+ 7,
+ "Number of past time steps to use as lag features (window size).",
+ ) # type: ignore
+
+ horizon: schema_field(
+ int_field(ge=1),
+ 1,
+ "Number of time steps into the future to predict (forecasting horizon).",
+ ) # type: ignore
+
+ target_column: schema_field(
+ string_field(),
+ "",
+ "Name of the target column containing the time series values to forecast.",
+ ) # type: ignore
+
+
+class TimeSeriesWindowConverter(BaseConverter):
+ """
+ Converter that transforms time series data into a regression problem.
+
+ This converter creates lag features (lag_1, lag_2, ..., lag_w) from a time series
+ and a target column shifted h steps into the future (y_target_h), where:
+ - w is the window_size parameter
+ - h is the horizon parameter
+
+ The resulting dataset can be used with standard regression models to perform
+ forecasting as a supervised learning problem.
+
+ Example:
+ --------
+ Original time series: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ With window_size=3 and horizon=1:
+
+ lag_3 lag_2 lag_1 y_target_1
+ 1 2 3 4
+ 2 3 4 5
+ 3 4 5 6
+ 4 5 6 7
+ 5 6 7 8
+ 6 7 8 9
+ 7 8 9 10
+ """
+
+ SCHEMA = TimeSeriesWindowConverterSchema
+ DESCRIPTION = (
+ "Transforms time series data into a tabular regression format by creating "
+ "lag features from past values and a target column shifted into the future. "
+ "This enables forecasting using standard regression models."
+ )
+ SHORT_DESCRIPTION = (
+ "Converts time series to regression with lag features and future targets."
+ )
+ DISPLAY_NAME = "Time Series Window Converter"
+
+ def __init__(self, window_size: int = 7, horizon: int = 1, target_column: str = ""):
+ """Initialize the converter with schema parameters."""
+ super().__init__()
+ self.window_size = window_size
+ self.horizon = horizon
+ self.target_column = target_column
+
+ # Internal state
+ self._fitted = False
+ self._target_column_validated = ""
+
+ def fit(
+ self, x: DashAIDataset, y: Union[DashAIDataset, None] = None
+ ) -> "TimeSeriesWindowConverter":
+ """
+ Fit the converter by validating parameters and target column.
+
+ Parameters
+ ----------
+ x : DashAIDataset
+ Input dataset containing the time series data
+ y : DashAIDataset, optional
+ Not used in this converter
+
+ Returns
+ -------
+ TimeSeriesWindowConverter
+ The fitted converter instance
+
+ Raises
+ ------
+ ValueError
+ If validation fails (missing target column, invalid parameters, etc.)
+ """
+ # Validate parameters
+ if self.window_size < 1:
+ raise ValueError("window_size must be a positive integer")
+
+ if self.horizon < 1:
+ raise ValueError("horizon must be a positive integer")
+
+ if not self.target_column:
+ raise ValueError("target_column must be a non-empty string")
+
+ # Check if target column exists in dataset
+ if self.target_column not in x.column_names:
+ raise ValueError(
+ f"Target column '{self.target_column}' not found in dataset. "
+ f"Available columns: {x.column_names}"
+ )
+
+ # Validate that we have enough data points
+ min_required_rows = self.window_size + self.horizon
+ if len(x) < min_required_rows:
+ raise ValueError(
+ f"Dataset has {len(x)} rows but needs at least "
+ f"{min_required_rows} rows (window_size={self.window_size} + "
+ f"horizon={self.horizon})"
+ )
+
+ # Store validated target column name
+ self._target_column_validated = self.target_column
+ self._fitted = True
+
+ return self
+
+ def transform(
+ self, x: DashAIDataset, y: Union[DashAIDataset, None] = None
+ ) -> DashAIDataset:
+ """
+ Transform the dataset by creating lag features and target column.
+
+ Parameters
+ ----------
+ x : DashAIDataset
+ Input dataset to transform
+ y : DashAIDataset, optional
+ Not used in this converter
+
+ Returns
+ -------
+ DashAIDataset
+ Transformed dataset with lag features and target column
+
+ Raises
+ ------
+ ValueError
+ If converter is not fitted or transformation fails
+ """
+ if not self._fitted:
+ raise ValueError("Converter must be fitted before transform")
+
+ # Convert to pandas for easier manipulation
+ data_frame = x.to_pandas()
+
+ # Verify target column still exists
+ if self._target_column_validated not in data_frame.columns:
+ raise ValueError(
+ f"Target column '{self._target_column_validated}' not found "
+ f"in transform dataset"
+ )
+
+ # Create a copy to avoid modifying the original
+ result_df = pd.DataFrame()
+
+ # Create lag features (lag_1, lag_2, ..., lag_w)
+ target_series = data_frame[self._target_column_validated]
+
+ for lag in range(1, self.window_size + 1):
+ lag_column_name = f"lag_{lag}"
+ result_df[lag_column_name] = target_series.shift(lag)
+
+ # Create multiple target columns (y_target_1 to y_target_horizon)
+ for h in range(1, self.horizon + 1):
+ target_column_name = f"y_target_{h}"
+ result_df[target_column_name] = target_series.shift(-h)
+
+ # Include any other columns that are not the target column
+ # This preserves potential date columns or other features
+ other_columns = [
+ col for col in data_frame.columns if col != self._target_column_validated
+ ]
+ for col in other_columns:
+ result_df[col] = data_frame[col]
+
+ # Remove rows with NaN values (caused by shifting)
+ # These occur at the beginning (due to lag) and end (due to future target)
+ result_df = result_df.dropna()
+
+ # Validate that we still have data after removing NaN rows
+ if len(result_df) == 0:
+ raise ValueError(
+ "No valid rows remain after creating lag features and target column. "
+ "Try reducing window_size or horizon, or use a larger dataset."
+ )
+
+ # Convert back to DashAIDataset
+ return to_dashai_dataset(result_df)
+
+ def changes_row_count(self) -> bool:
+ """
+ Indicates that this converter changes the number of rows.
+
+ Returns
+ -------
+ bool
+ True, as rows with NaN values are removed
+ """
+ return True
diff --git a/DashAI/back/dataloaders/classes/dashai_dataset.py b/DashAI/back/dataloaders/classes/dashai_dataset.py
index 320e952e6..4052feb27 100644
--- a/DashAI/back/dataloaders/classes/dashai_dataset.py
+++ b/DashAI/back/dataloaders/classes/dashai_dataset.py
@@ -3,9 +3,12 @@
import logging
import os
+import numpy as np
+import pandas as pd
+import pyarrow as pa
from beartype import beartype
from beartype.typing import Dict, List, Literal, Optional, Tuple, Union
-from datasets import Dataset
+from datasets import Dataset, DatasetDict
from DashAI.back.types.categorical import Categorical
from DashAI.back.types.dashai_data_type import DashAIDataType
@@ -523,7 +526,6 @@ def sample(
Dict
A dictionary with selected samples.
"""
- import numpy as np
if n > len(self):
raise ValueError(
@@ -697,7 +699,6 @@ def transform_dataset_with_schema(
DashAIDataset
- The updated dataset with new type information
"""
- import pyarrow as pa # local import
table = get_arrow_table(dataset)
dai_table = {}
@@ -789,8 +790,6 @@ def save_dataset(
import json
- import pyarrow as pa # local import
-
table = get_arrow_table(dataset)
data_filepath = os.path.join(path, "data.arrow")
with pa.OSFile(data_filepath, "wb") as sink:
@@ -829,8 +828,6 @@ def load_dataset(dataset_path: Union[str, os.PathLike]) -> DashAIDataset:
import json
- import pyarrow as pa # local import
-
data_filepath = os.path.join(dataset_path, "data.arrow")
with pa.OSFile(data_filepath, "rb") as source:
reader = pa.ipc.open_file(source)
@@ -960,7 +957,6 @@ def split_indexes(
# Generate shuffled indexes
if seed is None:
seed = 42
- import numpy as np
from sklearn.model_selection import train_test_split
indexes = np.arange(total_rows)
@@ -1045,7 +1041,6 @@ def split_dataset(
ValueError
Must provide all indexes or none.
"""
- import numpy as np
if all(idx is None for idx in [train_indexes, test_indexes, val_indexes]):
from datasets import DatasetDict
@@ -1279,8 +1274,6 @@ def get_columns_spec(dataset_path: str) -> Dict[str, Dict]:
"""
data_filepath = os.path.join(dataset_path, "data.arrow")
- import pyarrow as pa # local import
-
with pa.OSFile(data_filepath, "rb") as source:
reader = pa.ipc.open_file(source)
schema = reader.schema
@@ -1520,7 +1513,6 @@ def modify_table(
DashAIDataset
The modified dataset with the updated column type.
"""
- import pyarrow as pa
original_table = dataset.arrow_table
updated_columns = {}
@@ -1554,3 +1546,236 @@ def modify_table(
new_types = types if types else dataset.types
return DashAIDataset(new_table, splits=dataset.splits, types=new_types)
+
+
+@beartype
+def split_dataset_temporal(
+ dataset: DashAIDataset,
+ train_size: Union[int, float] = 0.7,
+ val_size: Union[int, float] = 0.15,
+ test_size: Union[int, float] = 0.15,
+ gap: int = 0,
+ timestamp_col: str = "ds",
+ min_train_size: int = 50,
+ min_val_size: int = 10,
+ min_test_size: int = 10,
+) -> DatasetDict:
+ """Time-aware data splitting for forecasting tasks.
+
+ Unlike random splitting, this maintains temporal order:
+ - Training data comes first chronologically
+ - Validation data follows training data
+ - Test data comes last
+ - No data leakage from future to past
+
+ Parameters
+ ----------
+ dataset : DashAIDataset
+ Dataset to split (must be sorted by timestamp)
+ train_size : Union[int, float]
+ Size of training set. If float, interpreted as proportion.
+ val_size : Union[int, float]
+ Size of validation set. If float, interpreted as proportion.
+ test_size : Union[int, float]
+ Size of test set. If float, interpreted as proportion.
+ gap : int
+ Number of periods to skip between splits to avoid data leakage.
+ timestamp_col : str
+ Name of timestamp column for ordering
+ min_train_size : int
+ Minimum number of training samples required
+ min_val_size : int
+ Minimum number of validation samples required
+ min_test_size : int
+ Minimum number of test samples required
+
+ Returns
+ -------
+ DatasetDict
+ Dictionary with 'train', 'validation', 'test' splits
+
+ Raises
+ ------
+ ValueError
+ If insufficient data for splits or validation fails
+ """
+ n_samples = dataset.num_rows
+
+ # Calculate actual split sizes from proportions or absolute values
+ if isinstance(train_size, float):
+ train_size = int(n_samples * train_size)
+ if isinstance(val_size, float):
+ val_size = int(n_samples * val_size)
+ if isinstance(test_size, float):
+ test_size = int(n_samples * test_size)
+
+ # Adjust for gaps
+ total_with_gaps = train_size + val_size + test_size + (2 * gap)
+
+ if total_with_gaps > n_samples:
+ # Proportionally reduce sizes to fit
+ available = n_samples - (2 * gap)
+ scale_factor = available / (train_size + val_size + test_size)
+
+ train_size = max(min_train_size, int(train_size * scale_factor))
+ val_size = max(min_val_size, int(val_size * scale_factor))
+ test_size = max(min_test_size, int(test_size * scale_factor))
+
+ # Validate minimum sizes
+ if train_size < min_train_size:
+ raise ValueError(
+ f"Training set too small: {train_size} < {min_train_size}. "
+ f"Need more data or smaller validation/test sets."
+ )
+
+ if val_size < min_val_size:
+ raise ValueError(
+ f"Validation set too small: {val_size} < {min_val_size}. "
+ f"Need more data or smaller test set."
+ )
+
+ if test_size < min_test_size:
+ raise ValueError(
+ f"Test set too small: {test_size} < {min_test_size}. Need more data."
+ )
+
+ # Ensure dataset is sorted by timestamp
+ df_raw = dataset.to_pandas()
+ if isinstance(df_raw, pd.DataFrame):
+ dataset_df = df_raw
+ else:
+ # Handle iterator case
+ dataset_df = pd.concat(df_raw, ignore_index=True)
+
+ if timestamp_col in dataset_df.columns:
+ dataset_df = dataset_df.sort_values(timestamp_col).reset_index(drop=True)
+
+ # Calculate split indices with gaps
+ train_end = train_size
+ val_start = train_end + gap
+ val_end = val_start + val_size
+ test_start = val_end + gap
+ test_end = test_start + test_size
+
+ if test_end > n_samples:
+ raise ValueError(
+ f"Not enough data for splits with gaps. Need {test_end} samples, "
+ f"have {n_samples}. Try reducing gap or split sizes."
+ )
+
+ # Create splits
+ # Split the dataset
+ train_df = dataset_df.iloc[:train_end]
+ val_df = dataset_df.iloc[val_start:val_end]
+ test_df = dataset_df.iloc[test_start:test_end]
+
+ # Convert back to DashAI datasets
+ train_dataset = to_dashai_dataset(Dataset.from_pandas(train_df))
+ val_dataset = to_dashai_dataset(Dataset.from_pandas(val_df))
+ test_dataset = to_dashai_dataset(Dataset.from_pandas(test_df))
+
+ # Log split information
+ if timestamp_col in dataset_df.columns:
+ log.info("✅ Temporal split completed:")
+ log.info(
+ f" Train: {len(train_df)} samples "
+ f"({dataset_df[timestamp_col].iloc[0]} to "
+ f"{dataset_df[timestamp_col].iloc[train_end - 1]})"
+ )
+ log.info(
+ f" Validation: {len(val_df)} samples "
+ f"({dataset_df[timestamp_col].iloc[val_start]} to "
+ f"{dataset_df[timestamp_col].iloc[val_end - 1]})"
+ )
+ log.info(
+ f" Test: {len(test_df)} samples "
+ f"({dataset_df[timestamp_col].iloc[test_start]} to "
+ f"{dataset_df[timestamp_col].iloc[test_end - 1]})"
+ )
+ if gap > 0:
+ log.info(f" Gap: {gap} periods between splits")
+ else:
+ log.info(
+ f"✅ Temporal split completed: "
+ f"Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}"
+ )
+
+ return DatasetDict(
+ {"train": train_dataset, "validation": val_dataset, "test": test_dataset}
+ )
+
+
+@beartype
+def prepare_for_forecasting_experiment(
+ dataset: DashAIDataset,
+ splits: dict,
+ timestamp_col: str = "ds",
+ output_columns: List[str] = None,
+) -> Tuple[DatasetDict, Dict]:
+ """Prepare dataset for forecasting experiment with temporal splits.
+
+ Parameters
+ ----------
+ dataset : DashAIDataset
+ Dataset to prepare for forecasting
+ splits : dict
+ Split configuration from frontend
+ timestamp_col : str
+ Name of timestamp column
+ output_columns : List[str]
+ Output columns (for compatibility)
+
+ Returns
+ -------
+ Tuple[DatasetDict, Dict]
+ Prepared dataset and split indices
+ """
+ splitType = splits.get("splitType")
+
+ if splitType in {"manual", "predefined"}:
+ # Preserve explicit user-defined splits for compatibility.
+ prepared_dataset, split_indices = prepare_for_model_session(
+ dataset, splits, output_columns or []
+ )
+ train_indexes = split_indices["train_indexes"]
+ test_indexes = split_indices["test_indexes"]
+ val_indexes = split_indices["val_indexes"]
+ else:
+ if splitType != "temporal":
+ log.warning(
+ "ForecastingTask received splitType=%r. "
+ "Falling back to temporal split to avoid data leakage.",
+ splitType,
+ )
+
+ train_size = splits.get("train", 0.7)
+ val_size = splits.get("validation", 0.15)
+ test_size = splits.get("test", 0.15)
+ gap = splits.get("gap", 0)
+
+ prepared_dataset = split_dataset_temporal(
+ dataset,
+ train_size=train_size,
+ val_size=val_size,
+ test_size=test_size,
+ gap=gap,
+ timestamp_col=timestamp_col,
+ )
+
+ # Keep compatibility with the rest of the system by exposing indices
+ # relative to the temporally ordered full dataset.
+ train_len = len(prepared_dataset["train"])
+ val_len = len(prepared_dataset["validation"])
+ test_len = len(prepared_dataset["test"])
+
+ train_indexes = list(range(train_len))
+ val_start = train_len + gap
+ val_indexes = list(range(val_start, val_start + val_len))
+ test_start = val_start + val_len + gap
+ test_indexes = list(range(test_start, test_start + test_len))
+
+ return prepared_dataset, {
+ "train_indexes": train_indexes,
+ "test_indexes": test_indexes,
+ "val_indexes": val_indexes,
+ }
diff --git a/DashAI/back/explainability/explainers/__init__.py b/DashAI/back/explainability/explainers/__init__.py
index e69de29bb..eca68d2ff 100644
--- a/DashAI/back/explainability/explainers/__init__.py
+++ b/DashAI/back/explainability/explainers/__init__.py
@@ -0,0 +1,24 @@
+"""Explainer implementations.
+
+This module contains all explainer implementations organized by task type.
+
+Forecasting explainers are in the `forecasting_explainers` submodule.
+"""
+
+# Import forecasting explainers
+from DashAI.back.explainability.explainers.forecasting_explainers.forecast_feature_importance import (
+ ForecastFeatureImportance,
+)
+from DashAI.back.explainability.explainers.forecasting_explainers.forecast_decomposition import (
+ ForecastDecomposition,
+)
+from DashAI.back.explainability.explainers.forecasting_explainers.forecast_uncertainty import (
+ ForecastUncertainty,
+)
+
+__all__ = [
+ # Forecasting explainers
+ "ForecastFeatureImportance",
+ "ForecastDecomposition",
+ "ForecastUncertainty",
+]
diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/__init__.py b/DashAI/back/explainability/explainers/forecasting_explainers/__init__.py
new file mode 100644
index 000000000..5fb14f082
--- /dev/null
+++ b/DashAI/back/explainability/explainers/forecasting_explainers/__init__.py
@@ -0,0 +1,35 @@
+"""Forecasting explainability module.
+
+Provides specialized explainers for time series forecasting models:
+- Base classes with time series utilities
+- Feature importance for exogenous variables
+- Forecast decomposition
+- Uncertainty analysis
+
+All forecasting explainers inherit from ForecastingGlobalExplainer or
+ForecastingLocalExplainer to leverage common time series functionality.
+"""
+
+from DashAI.back.explainability.explainers.forecasting_explainers.forecasting_global_explainer import (
+ ForecastingGlobalExplainer,
+)
+from DashAI.back.explainability.explainers.forecasting_explainers.forecasting_local_explainer import (
+ ForecastingLocalExplainer,
+)
+from DashAI.back.explainability.explainers.forecasting_explainers.forecast_feature_importance import (
+ ForecastFeatureImportance,
+)
+from DashAI.back.explainability.explainers.forecasting_explainers.forecast_decomposition import (
+ ForecastDecomposition,
+)
+from DashAI.back.explainability.explainers.forecasting_explainers.forecast_uncertainty import (
+ ForecastUncertainty,
+)
+
+__all__ = [
+ "ForecastingGlobalExplainer",
+ "ForecastingLocalExplainer",
+ "ForecastFeatureImportance",
+ "ForecastDecomposition",
+ "ForecastUncertainty",
+]
diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/forecast_decomposition.py b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_decomposition.py
new file mode 100644
index 000000000..b25b1b7bb
--- /dev/null
+++ b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_decomposition.py
@@ -0,0 +1,383 @@
+"""Forecast Decomposition Explainer for time series models.
+
+This explainer decomposes forecasts into interpretable components (trend,
+seasonality, residual) for any forecasting model that implements
+``get_forecast_components()``.
+
+Works with:
+- Prophet → trend, weekly, yearly (native structural decomposition)
+- ARIMA → trend, weekly/yearly, residual (STL on fitted + forecast)
+- SARIMAX → trend, weekly/yearly, residual (STL with explicit period s)
+- SklearnMultiStep → trend, weekly/yearly, residual (STL on history + forecast)
+
+Any future model that implements ``get_forecast_components(horizon)`` and
+returns a DataFrame with at least a ``ds`` column and one component column
+will be automatically supported.
+"""
+
+from typing import List, Tuple
+
+import numpy as np
+import pandas as pd
+import plotly
+import plotly.graph_objects as go
+from datasets import DatasetDict
+from plotly.subplots import make_subplots
+
+from DashAI.back.core.schema_fields import (
+ BaseSchema,
+ bool_field,
+ int_field,
+ schema_field,
+)
+from DashAI.back.explainability.global_explainer import BaseGlobalExplainer
+from DashAI.back.models.base_model import BaseModel
+
+
+class ForecastDecompositionSchema(BaseSchema):
+ """Forecast Decomposition breaks down predictions into interpretable components.
+
+ This helps understand what drives the forecast:
+ - Trend: Long-term direction
+ - Seasonality: Repeating patterns (weekly, yearly, etc.)
+ - External factors: Effect of exogenous variables
+ - Residuals: Unexplained variation
+ """
+
+ horizon: schema_field(
+ int_field(ge=1, le=365),
+ placeholder=30,
+ description="Number of future periods to forecast and decompose. "
+ "Longer horizons show how components evolve over time.",
+ ) # type: ignore
+
+ include_historical: schema_field(
+ bool_field(),
+ placeholder=False,
+ description="If True, includes historical component decomposition "
+ "to show how the model understood past data.",
+ ) # type: ignore
+
+
+class ForecastDecomposition(BaseGlobalExplainer):
+ """Universal forecast decomposition explainer.
+
+ Decomposes time series forecasts into interpretable components,
+ adapting to different model types automatically.
+ """
+
+ COMPATIBLE_COMPONENTS = ["ForecastingTask"]
+ SCHEMA = ForecastDecompositionSchema
+
+ def __init__(
+ self,
+ model: BaseModel,
+ horizon: int = 30,
+ include_historical: bool = False,
+ ):
+ """Initialize ForecastDecomposition explainer.
+
+ Parameters
+ ----------
+ model : BaseModel
+ Trained forecasting model to explain
+ horizon : int
+ Number of periods to forecast (default: 30)
+ include_historical : bool
+ Whether to include historical decomposition (default: False)
+ """
+ super().__init__(model)
+ self.horizon = horizon
+ self.include_historical = include_historical
+
+ def _get_native_components(self) -> pd.DataFrame:
+ """Extract components from any model that implements get_forecast_components().
+
+ This covers Prophet, ARIMA, SARIMAX, and SklearnMultiStepForecaster.
+ """
+ if not hasattr(self.model, "get_forecast_components"):
+ raise AttributeError(
+ f"{type(self.model).__name__} must implement "
+ "get_forecast_components(horizon) to use this path."
+ )
+ return self.model.get_forecast_components(self.horizon)
+
+ def _get_generic_components(self, dataset: DatasetDict) -> pd.DataFrame:
+ """Fallback for models without native decomposition.
+
+ Uses simple predictions as "trend" component.
+ """
+ x, y = dataset
+
+ # Construct history dataframe from dataset (x and y)
+ # This allows the model to predict continuing from this dataset
+ try:
+ # Convert to pandas with error handling
+ try:
+ x_df = x.to_pandas() if hasattr(x, "to_pandas") else pd.DataFrame(x)
+ except Exception as e:
+ print(f"Warning: Failed to convert x to DataFrame: {e}")
+ x_df = None
+
+ try:
+ y_df = y.to_pandas() if hasattr(y, "to_pandas") else pd.DataFrame(y)
+ except Exception as e:
+ print(f"Warning: Failed to convert y to DataFrame: {e}")
+ y_df = None
+
+ # Combine if possible
+ if x_df is not None and y_df is not None and len(x_df) == len(y_df):
+ history_df = x_df.copy()
+ for col in y_df.columns:
+ history_df[col] = y_df[col].to_numpy()
+
+ # Get predictions using history context
+ predictions = self.model.predict(
+ x_pred=history_df, periods=self.horizon
+ )
+ else:
+ if x_df is not None and y_df is not None:
+ print(f"Warning: lengths differ (x={len(x_df)}, y={len(y_df)}).")
+ history_df = x_df.copy()
+ predictions = self.model.predict(
+ x_pred=history_df, periods=self.horizon
+ )
+ elif x_df is not None:
+ print("Warning: Only x dataset available. Using x as history.")
+ history_df = x_df.copy()
+ predictions = self.model.predict(
+ x_pred=history_df, periods=self.horizon
+ )
+ else:
+ print("Warning: Could not create history. Using standard predict.")
+ predictions = self.model.predict(periods=self.horizon)
+
+ except Exception as e:
+ print(f"Warning: Could not use dataset as history context: {e}")
+ # Fallback to standard prediction
+ predictions = self.model.predict(periods=self.horizon)
+
+ # Handle case where model returns fewer predictions than requested
+ # (e.g. SklearnMultiStepForecaster with direct strategy)
+ actual_horizon = len(predictions)
+
+ # Determine start date
+ start_date = pd.Timestamp.now()
+ if (
+ hasattr(self.model, "last_timestamp")
+ and self.model.last_timestamp is not None
+ ):
+ start_date = self.model.last_timestamp
+ elif hasattr(self.model, "last_ds") and self.model.last_ds is not None:
+ start_date = self.model.last_ds
+
+ # Determine frequency
+ freq = "D"
+ if hasattr(self.model, "frequency") and self.model.frequency:
+ freq = self.model.frequency
+
+ # Generate dates (start from next period after last timestamp)
+ dates = pd.date_range(start=start_date, periods=actual_horizon + 1, freq=freq)[
+ 1:
+ ]
+
+ # Create simple dataframe with predictions as "trend"
+ components_df = pd.DataFrame(
+ {
+ "ds": dates,
+ "trend": predictions
+ if isinstance(predictions, np.ndarray)
+ else predictions.to_numpy(),
+ "seasonal": np.zeros(actual_horizon),
+ "residual": np.zeros(actual_horizon),
+ }
+ )
+
+ return components_df
+
+ def explain(self, dataset: Tuple[DatasetDict, DatasetDict]) -> dict:
+ """Generate component decomposition explanation.
+
+ Parameters
+ ----------
+ dataset : Tuple[DatasetDict, DatasetDict]
+ Tuple with (input_samples, targets) used for context
+
+ Returns
+ -------
+ dict
+ Dictionary with:
+ - ds: Timestamps
+ - trend: Trend component
+ - seasonal: Seasonal component (if applicable)
+ - weekly/yearly: Specific seasonality (if applicable)
+ - exog_*: External regressor effects (if applicable)
+ - model_type: Type of model decomposed
+ """
+ # Detect model type and extract components
+ model_name = type(self.model).__name__
+
+ # Friendly display names for known model classes
+ _display_names = {
+ "ProphetModel": "Prophet",
+ "StatsmodelsARIMAModel": "ARIMA",
+ "StatsmodelsSARIMAXModel": "SARIMAX",
+ "SklearnMultiStepForecaster": "Sklearn MultiStep",
+ }
+
+ try:
+ if hasattr(self.model, "get_forecast_components"):
+ # Prophet, ARIMA, SARIMAX, SklearnMultiStepForecaster
+ components_df = self._get_native_components()
+ model_type = _display_names.get(model_name, model_name)
+
+ else:
+ # Generic fallback for unknown model types
+ components_df = self._get_generic_components(dataset)
+ model_type = "Generic"
+
+ except Exception as e:
+ raise RuntimeError(
+ f"Failed to extract components from {model_name}: {str(e)}"
+ ) from e
+
+ # Convert to serializable format
+ explanation = {
+ "model_type": model_type,
+ "horizon": self.horizon,
+ "ds": components_df["ds"].dt.strftime("%Y-%m-%d %H:%M:%S").tolist()
+ if "ds" in components_df.columns
+ else list(range(len(components_df))),
+ }
+
+ # Add all available components
+ for col in components_df.columns:
+ if col != "ds":
+ explanation[col] = np.round(components_df[col].values, 3).tolist()
+
+ return explanation
+
+ def _create_decomposition_plot(self, explanation: dict) -> go.Figure:
+ """Create multi-panel decomposition plot."""
+
+ # Identify available components
+ component_cols = [
+ k for k in explanation if k not in ["ds", "model_type", "horizon"]
+ ]
+
+ # Prioritize component order for better visualization
+ priority_order = ["trend", "seasonal", "yearly", "weekly", "daily"]
+ ordered_components = []
+
+ for comp in priority_order:
+ if comp in component_cols:
+ ordered_components.append(comp)
+
+ # Add remaining components (e.g., exog_*)
+ for comp in component_cols:
+ if comp not in ordered_components:
+ ordered_components.append(comp)
+
+ n_components = len(ordered_components)
+
+ # Create subplots
+ fig = make_subplots(
+ rows=n_components,
+ cols=1,
+ subplot_titles=[
+ comp.replace("_", " ").title() for comp in ordered_components
+ ],
+ vertical_spacing=0.05,
+ )
+
+ # Add trace for each component
+ for i, component in enumerate(ordered_components, 1):
+ fig.add_trace(
+ go.Scatter(
+ x=explanation["ds"],
+ y=explanation[component],
+ name=component.replace("_", " ").title(),
+ line={"width": 2},
+ mode="lines",
+ ),
+ row=i,
+ col=1,
+ )
+
+ # Update layout
+ fig.update_layout(
+ height=250 * n_components,
+ title_text=f"Forecast Decomposition ({explanation['model_type']} Model)",
+ showlegend=False,
+ hovermode="x unified",
+ )
+
+ fig.update_xaxes(title_text="Date", row=n_components, col=1)
+
+ return fig
+
+ def _create_stacked_plot(self, explanation: dict) -> go.Figure:
+ """Create stacked area plot showing component contributions."""
+
+ explanation_df = pd.DataFrame(explanation)
+
+ # Components to stack (exclude residuals/noise)
+ stack_components = [
+ col
+ for col in explanation_df.columns
+ if col not in ["ds", "model_type", "horizon", "residual", "noise"]
+ and not col.startswith("yhat")
+ ]
+
+ fig = go.Figure()
+
+ for component in stack_components:
+ fig.add_trace(
+ go.Scatter(
+ x=explanation_df["ds"],
+ y=explanation_df[component],
+ name=component.replace("_", " ").title(),
+ mode="lines",
+ stackgroup="one",
+ fillcolor="rgba(0,0,0,0.1)",
+ )
+ )
+
+ fig.update_layout(
+ title="Component Contribution Over Time",
+ xaxis_title="Date",
+ yaxis_title="Contribution",
+ hovermode="x unified",
+ )
+
+ return fig
+
+ def plot(self, explanation: dict) -> List[dict]:
+ """Create visualization plots.
+
+ Parameters
+ ----------
+ explanation : dict
+ Explanation dictionary from explain()
+
+ Returns
+ -------
+ List[dict]
+ List of plotly JSON figures
+ """
+ plots = []
+
+ # Main decomposition plot
+ decomp_fig = self._create_decomposition_plot(explanation)
+ plots.append(plotly.io.to_json(decomp_fig))
+
+ # Stacked contribution plot (if multiple components)
+ component_cols = [
+ k for k in explanation if k not in ["ds", "model_type", "horizon"]
+ ]
+
+ if len(component_cols) > 1:
+ stacked_fig = self._create_stacked_plot(explanation)
+ plots.append(plotly.io.to_json(stacked_fig))
+
+ return plots
diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/forecast_feature_importance.py b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_feature_importance.py
new file mode 100644
index 000000000..cf8e02378
--- /dev/null
+++ b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_feature_importance.py
@@ -0,0 +1,446 @@
+"""Feature Importance Explainer for Forecasting Models.
+
+Evaluates the importance of exogenous variables (regressors) in forecasting models
+by measuring how model performance degrades when each feature is permuted.
+
+This is the forecasting adaptation of Permutation Feature Importance, using
+time series specific metrics (MAE, RMSE, MAPE) instead of classification metrics.
+
+Works with any forecasting model that implements ForecastingModel interface,
+which provides get_exogenous_columns() to list external features in their original
+format (model-agnostic).
+
+Compatible models:
+- Prophet with add_regressor()
+- ARIMA/SARIMAX with exog
+- Any model inheriting from ForecastingModel
+"""
+
+from typing import List, Tuple
+
+import numpy as np
+import pandas as pd
+import plotly
+import plotly.express as px
+from datasets import DatasetDict
+
+from DashAI.back.core.schema_fields import (
+ BaseSchema,
+ enum_field,
+ int_field,
+ schema_field,
+)
+from DashAI.back.explainability.explainers.forecasting_explainers.forecasting_global_explainer import ( # noqa: E501
+ ForecastingGlobalExplainer,
+)
+from DashAI.back.models.base_model import BaseModel
+
+
+class ForecastFeatureImportanceSchema(BaseSchema):
+ """Feature Importance for forecasting models with exogenous variables.
+
+ Measures how much each external variable (weather, holidays, promotions, etc.)
+ contributes to forecast accuracy by randomly shuffling each feature and
+ measuring performance degradation.
+ """
+
+ scoring: schema_field(
+ enum_field(enum=["mae", "rmse", "mape"]),
+ placeholder="mae",
+ description="Metric to evaluate performance degradation. "
+ "MAE (Mean Absolute Error) is most interpretable, "
+ "RMSE (Root Mean Squared Error) penalizes large errors, "
+ "MAPE (Mean Absolute Percentage Error) shows relative error.",
+ ) # type: ignore
+
+ n_repeats: schema_field(
+ int_field(ge=1, le=50),
+ placeholder=10,
+ description="Number of times to permute each feature. "
+ "More repeats give more stable importance estimates but take longer.",
+ ) # type: ignore
+
+ random_state: schema_field(
+ int_field(ge=0),
+ placeholder=42,
+ description="Seed for random number generator to ensure reproducible results.",
+ ) # type: ignore
+
+
+class ForecastFeatureImportance(ForecastingGlobalExplainer):
+ """Feature importance explainer for forecasting models with exogenous variables.
+
+ Identifies which external variables (regressors) are most important for
+ accurate forecasts by measuring performance degradation when each is permuted.
+ """
+
+ COMPATIBLE_COMPONENTS = ["ForecastingTask"]
+ SCHEMA = ForecastFeatureImportanceSchema
+
+ def __init__(
+ self,
+ model: BaseModel,
+ scoring: str = "mae",
+ n_repeats: int = 10,
+ random_state: int = 42,
+ ):
+ """Initialize ForecastFeatureImportance explainer.
+
+ Parameters
+ ----------
+ model : BaseModel
+ Trained forecasting model to explain
+ scoring : str
+ Metric to use: 'mae', 'rmse', or 'mape' (default: 'mae')
+ n_repeats : int
+ Number of permutation repeats (default: 10)
+ random_state : int
+ Random seed for reproducibility (default: 42)
+ """
+ super().__init__(model)
+
+ # Define scoring functions
+ self.scoring_functions = {
+ "mae": self._mean_absolute_error,
+ "rmse": self._root_mean_squared_error,
+ "mape": self._mean_absolute_percentage_error,
+ }
+
+ if scoring not in self.scoring_functions:
+ raise ValueError(
+ f"Unknown scoring metric: {scoring}. "
+ f"Choose from: {list(self.scoring_functions.keys())}"
+ )
+
+ self.scoring = scoring
+ self.score_func = self.scoring_functions[scoring]
+ self.n_repeats = n_repeats
+ self.random_state = random_state
+
+ def _mean_absolute_error(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+ """Calculate Mean Absolute Error."""
+ return np.mean(np.abs(y_true - y_pred))
+
+ def _root_mean_squared_error(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+ """Calculate Root Mean Squared Error."""
+ return np.sqrt(np.mean((y_true - y_pred) ** 2))
+
+ def _mean_absolute_percentage_error(
+ self, y_true: np.ndarray, y_pred: np.ndarray
+ ) -> float:
+ """Calculate Mean Absolute Percentage Error."""
+ # Avoid division by zero
+ mask = y_true != 0
+ if not np.any(mask):
+ return np.inf
+ return np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
+
+ def explain(self, dataset: Tuple[DatasetDict, DatasetDict]) -> dict:
+ """Calculate feature importance using permutation.
+
+ Parameters
+ ----------
+ dataset : Tuple[DatasetDict, DatasetDict]
+ Tuple with (full_prepared_dataset, targets)
+ For forecasting: first element contains timestamp + exog vars + target
+
+ Returns
+ -------
+ dict
+ Dictionary with:
+ - features: List of feature names
+ - importances_mean: Average importance for each feature
+ - importances_std: Standard deviation of importance
+ - baseline_score: Model performance without permutation
+ - scoring_metric: Metric used (mae, rmse, mape)
+ """
+ x, y = dataset
+
+ # Use test set for evaluation
+ x_test = x["test"]
+ y_test = y["test"]
+
+ # Get exogenous features from model (using base class method)
+ exog_features = self._get_exogenous_columns()
+
+ if len(exog_features) == 0:
+ return {
+ "error": "No exogenous features found",
+ "message": (
+ "This model does not use exogenous variables. "
+ "Feature importance is only available for models with "
+ "external regressors."
+ ),
+ "features": [],
+ "importances_mean": [],
+ "importances_std": [],
+ }
+
+ print(
+ "[ForecastFeatureImportance] Evaluating "
+ f"{len(exog_features)} exogenous variables"
+ )
+
+ # Convert to pandas - x_test now has ALL columns including timestamp
+ x_df = x_test.to_pandas()
+ y_true = y_test.to_pandas().to_numpy().ravel()
+
+ timestamp_col = self._get_timestamp_column()
+ target_col = self._get_target_column()
+
+ print(
+ f"[ForecastFeatureImportance] Using timestamp: {timestamp_col}, "
+ f"target: {target_col}"
+ )
+ print(f"[ForecastFeatureImportance] Test set size: {len(x_df)} rows")
+ print(f"[ForecastFeatureImportance] Columns in x_df: {x_df.columns.tolist()}")
+
+ # Calculate baseline score (without permutation)
+ try:
+ # Use ForecastingModel interface: predict(x_pred=DataFrame)
+ # The model's predict() method expects a DataFrame with:
+ # - Timestamp column (original name)
+ # - Exogenous variables (original names)
+ # This is model-agnostic - works for Prophet, ARIMA, LSTM, etc.
+
+ # Prepare input DataFrame for model's predict method
+ input_df = x_df.copy()
+
+ # Call the model's predict method using ForecastingModel interface
+ predictions = self.model.predict(x_pred=input_df) # type: ignore
+
+ # Extract predictions from result
+ # ForecastingModel.predict() returns DataFrame with target column
+ if isinstance(predictions, pd.DataFrame):
+ # Try to get the target column
+ if target_col and target_col in predictions.columns:
+ y_pred_baseline = predictions[target_col].to_numpy()
+ elif "yhat" in predictions.columns:
+ # Fallback to 'yhat' (Prophet style)
+ y_pred_baseline = predictions["yhat"].to_numpy()
+ else:
+ # Take first numeric column
+ numeric_cols = predictions.select_dtypes(
+ include=[np.number]
+ ).columns
+ if len(numeric_cols) > 0:
+ y_pred_baseline = predictions[numeric_cols[0]].to_numpy()
+ else:
+ raise ValueError("No numeric columns found in predictions")
+ elif isinstance(predictions, np.ndarray):
+ y_pred_baseline = predictions
+ else:
+ y_pred_baseline = np.array(predictions)
+
+ baseline_score = self.score_func(y_true, y_pred_baseline)
+ print(
+ "[ForecastFeatureImportance] Baseline score "
+ f"({self.scoring}): {baseline_score:.4f}"
+ )
+ except Exception as e:
+ print("[ForecastFeatureImportance] ERROR in baseline prediction:")
+ print(f" - x_df shape: {x_df.shape}")
+ print(f" - x_df columns: {x_df.columns.tolist()}")
+ print(f" - Error: {str(e)}")
+ raise RuntimeError(f"Failed to get baseline predictions: {str(e)}") from e
+
+ # Calculate importance for each feature
+ importances = {feature: [] for feature in exog_features}
+
+ rng = np.random.RandomState(self.random_state)
+
+ for feature in exog_features:
+ print(f"[ForecastFeatureImportance] Permuting feature: {feature}")
+ for repeat in range(self.n_repeats):
+ # Copy dataframe and permute the feature
+ x_permuted = x_df.copy()
+ x_permuted[feature] = rng.permutation(x_permuted[feature].to_numpy())
+
+ # Get predictions with permuted feature using ForecastingModel interface
+ try:
+ # Use same interface as baseline
+ predictions_perm = self.model.predict(x_pred=x_permuted) # type: ignore
+
+ # Extract predictions (same logic as baseline)
+ if isinstance(predictions_perm, pd.DataFrame):
+ if target_col and target_col in predictions_perm.columns:
+ y_pred_permuted = predictions_perm[target_col].to_numpy()
+ elif "yhat" in predictions_perm.columns:
+ y_pred_permuted = predictions_perm["yhat"].to_numpy()
+ else:
+ numeric_cols = predictions_perm.select_dtypes(
+ include=[np.number]
+ ).columns
+ if len(numeric_cols) > 0:
+ y_pred_permuted = predictions_perm[
+ numeric_cols[0]
+ ].to_numpy()
+ else:
+ raise ValueError(
+ "No numeric columns in permuted predictions"
+ )
+ elif isinstance(predictions_perm, np.ndarray):
+ y_pred_permuted = predictions_perm
+ else:
+ y_pred_permuted = np.array(predictions_perm)
+
+ permuted_score = self.score_func(y_true, y_pred_permuted)
+
+ # For error metrics (lower is better), importance is positive
+ # when permutation increases error
+ importance = permuted_score - baseline_score
+ importances[feature].append(importance)
+
+ except Exception as e:
+ print(
+ f" Warning: Failed repeat {repeat + 1} for {feature}: {str(e)}"
+ )
+ importances[feature].append(0.0)
+
+ # Calculate statistics
+ features = list(importances.keys())
+ importances_mean = [np.mean(importances[f]) for f in features]
+ importances_std = [np.std(importances[f]) for f in features]
+
+ return {
+ "features": features,
+ "importances_mean": np.round(importances_mean, 4).tolist(),
+ "importances_std": np.round(importances_std, 4).tolist(),
+ "baseline_score": round(baseline_score, 4),
+ "scoring_metric": self.scoring,
+ }
+
+ def _create_plot(
+ self, data: pd.DataFrame, explanation: dict
+ ) -> plotly.graph_objs.Figure:
+ """Create horizontal bar plot showing feature importances.
+
+ Parameters
+ ----------
+ data : pd.DataFrame
+ Dataframe with features and importances
+ explanation : dict
+ Full explanation dictionary
+
+ Returns
+ -------
+ plotly.graph_objs.Figure
+ Interactive bar chart
+ """
+ # Sort by importance
+ data = data.sort_values(by="importances_mean", ascending=True)
+
+ fig = px.bar(
+ data,
+ x="importances_mean",
+ y="features",
+ error_x="importances_std",
+ orientation="h",
+ title=f"Feature Importance ({explanation['scoring_metric'].upper()})",
+ labels={
+ "importances_mean": (
+ f"Importance (Δ{explanation['scoring_metric'].upper()})"
+ ),
+ "features": "Feature",
+ },
+ )
+
+ # Add baseline info
+ baseline_text = (
+ f"Baseline {explanation['scoring_metric'].upper()}: "
+ f"{explanation['baseline_score']:.4f}"
+ )
+
+ fig.add_annotation(
+ text=baseline_text,
+ xref="paper",
+ yref="paper",
+ x=0.98,
+ y=0.98,
+ showarrow=False,
+ bgcolor="rgba(255,255,255,0.8)",
+ bordercolor="black",
+ borderwidth=1,
+ )
+
+ # Add explanation note
+ note_text = (
+ f"Higher values = more important feature
"
+ f"Measured as increase in {explanation['scoring_metric'].upper()} "
+ f"when feature is randomly shuffled"
+ )
+
+ fig.add_annotation(
+ text=note_text,
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=-0.15,
+ showarrow=False,
+ font={"size": 10},
+ xanchor="center",
+ )
+
+ fig.update_layout(
+ height=max(400, len(data) * 40),
+ margin={"b": 100},
+ )
+
+ return fig
+
+ def plot(self, explanation: dict) -> List[dict]:
+ """Create visualization of feature importances.
+
+ Parameters
+ ----------
+ explanation : dict
+ Explanation dictionary from explain()
+
+ Returns
+ -------
+ List[dict]
+ List with single plotly JSON figure
+ """
+ # Check for errors
+ if "error" in explanation:
+ # Return empty plot with error message
+ import plotly.graph_objects as go
+
+ fig = go.Figure()
+ fig.add_annotation(
+ text=explanation["message"],
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=0.5,
+ showarrow=False,
+ font={"size": 14},
+ )
+ fig.update_layout(
+ title="Feature Importance - No Exogenous Variables",
+ xaxis={"visible": False},
+ yaxis={"visible": False},
+ )
+ return [plotly.io.to_json(fig)]
+
+ # Create dataframe
+ data = pd.DataFrame(
+ {
+ "features": explanation["features"],
+ "importances_mean": explanation["importances_mean"],
+ "importances_std": explanation["importances_std"],
+ }
+ )
+
+ # Clean feature names for display
+ # Remove 'exog_' prefix if present, then format for readability
+ data["features"] = (
+ data["features"]
+ .str.replace("exog_", "", regex=False)
+ .str.replace("_", " ")
+ .str.title()
+ )
+
+ fig = self._create_plot(data, explanation)
+
+ return [plotly.io.to_json(fig)]
diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/forecast_uncertainty.py b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_uncertainty.py
new file mode 100644
index 000000000..1be43245f
--- /dev/null
+++ b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_uncertainty.py
@@ -0,0 +1,487 @@
+"""Forecast Uncertainty Analysis Explainer.
+
+Analyzes and visualizes prediction uncertainty (confidence/prediction intervals)
+across the forecast horizon. Essential for risk management and decision-making.
+
+Shows how confidence in predictions degrades over time and helps users understand
+the reliability of forecasts at different time horizons.
+
+Works with all forecasting models via ``get_forecast_uncertainty()``:
+- Prophet → native intervals from Prophet's uncertainty sampling
+- ARIMA → parametric intervals from statsmodels (analytical CIs)
+- SARIMAX → parametric intervals from statsmodels (analytical CIs)
+- SklearnMultiStep → empirical intervals: residual std × sqrt(horizon step)
+- Unknown models → fallback placeholder (±10% of forecast value)
+"""
+
+from typing import List, Tuple
+
+import numpy as np
+import pandas as pd
+import plotly
+import plotly.graph_objects as go
+from datasets import DatasetDict
+from plotly.subplots import make_subplots
+
+from DashAI.back.core.schema_fields import (
+ BaseSchema,
+ float_field,
+ int_field,
+ schema_field,
+)
+from DashAI.back.explainability.global_explainer import BaseGlobalExplainer
+from DashAI.back.models.base_model import BaseModel
+
+
+class ForecastUncertaintySchema(BaseSchema):
+ """Forecast Uncertainty Analysis shows prediction confidence intervals.
+
+ Helps answer:
+ - How confident is the model in its predictions?
+ - How does uncertainty grow with forecast horizon?
+ - What's the best/worst case scenario?
+
+ Critical for inventory planning, capacity planning, and risk management.
+ """
+
+ horizon: schema_field(
+ int_field(ge=1, le=365),
+ placeholder=30,
+ description="Number of future periods to forecast. "
+ "Longer horizons typically show increasing uncertainty.",
+ ) # type: ignore
+
+ confidence_level: schema_field(
+ float_field(ge=0.5, le=0.99),
+ placeholder=0.80,
+ description="Confidence level for prediction intervals (e.g., 0.80 = 80%). "
+ "Higher values give wider intervals.",
+ ) # type: ignore
+
+
+class ForecastUncertainty(BaseGlobalExplainer):
+ """Analyzes forecast uncertainty and prediction intervals.
+
+ Visualizes how prediction confidence changes across the forecast horizon,
+ helping users understand forecast reliability and plan for uncertainty.
+ """
+
+ COMPATIBLE_COMPONENTS = ["ForecastingTask"]
+ SCHEMA = ForecastUncertaintySchema
+
+ def __init__(
+ self,
+ model: BaseModel,
+ horizon: int = 30,
+ confidence_level: float = 0.80,
+ ):
+ """Initialize ForecastUncertainty explainer.
+
+ Parameters
+ ----------
+ model : BaseModel
+ Trained forecasting model
+ horizon : int
+ Number of periods to forecast (default: 30)
+ confidence_level : float
+ Confidence level for intervals (default: 0.80 = 80%)
+ """
+ super().__init__(model)
+ self.horizon = horizon
+ self.confidence_level = confidence_level
+
+ def _get_native_uncertainty(self) -> pd.DataFrame:
+ """Get uncertainty from any model that implements get_forecast_uncertainty().
+
+ This covers Prophet, ARIMA, SARIMAX, and SklearnMultiStepForecaster.
+ """
+ if not hasattr(self.model, "get_forecast_uncertainty"):
+ raise AttributeError(
+ f"{type(self.model).__name__} must implement "
+ "get_forecast_uncertainty(horizon, confidence_level) to use this path."
+ )
+ return self.model.get_forecast_uncertainty(self.horizon, self.confidence_level)
+
+ def _get_generic_uncertainty(
+ self, dataset: Tuple[DatasetDict, DatasetDict]
+ ) -> pd.DataFrame:
+ """Fallback for models without native uncertainty quantification.
+
+ Returns point predictions with placeholder intervals.
+ """
+ x, y = dataset
+
+ # Construct history dataframe from dataset (x and y)
+ try:
+ # Convert to pandas
+ x_df = x.to_pandas() if hasattr(x, "to_pandas") else pd.DataFrame(x)
+
+ y_df = y.to_pandas() if hasattr(y, "to_pandas") else pd.DataFrame(y)
+
+ # Combine
+ if len(x_df) == len(y_df):
+ history_df = x_df.copy()
+ for col in y_df.columns:
+ history_df[col] = y_df[col].to_numpy()
+ else:
+ print(f"Warning: lengths differ (x={len(x_df)}, y={len(y_df)}).")
+ history_df = x_df.copy()
+
+ # Get point predictions using history context
+ predictions = self.model.predict(x_pred=history_df, periods=self.horizon)
+
+ except Exception as e:
+ print(f"Warning: Could not use dataset as history context: {e}")
+ # Fallback
+ predictions = self.model.predict(periods=self.horizon)
+
+ if hasattr(predictions, "to_numpy"):
+ y_pred = predictions.to_numpy()
+ elif isinstance(predictions, np.ndarray):
+ y_pred = predictions
+ else:
+ y_pred = np.array(predictions)
+
+ # Handle case where model returns fewer predictions than requested
+ actual_horizon = len(y_pred)
+
+ # Create placeholder intervals (±10% of prediction)
+ uncertainty_pct = 0.10
+
+ # Determine start date
+ start_date = pd.Timestamp.now()
+ if (
+ hasattr(self.model, "last_timestamp")
+ and self.model.last_timestamp is not None
+ ):
+ start_date = self.model.last_timestamp
+ elif hasattr(self.model, "last_ds") and self.model.last_ds is not None:
+ start_date = self.model.last_ds
+
+ # Determine frequency
+ freq = "D"
+ if hasattr(self.model, "frequency") and self.model.frequency:
+ freq = self.model.frequency
+
+ # Generate dates (start from next period after last timestamp)
+ dates = pd.date_range(start=start_date, periods=actual_horizon + 1, freq=freq)[
+ 1:
+ ]
+
+ uncertainty_df = pd.DataFrame(
+ {
+ "ds": dates,
+ "yhat": y_pred,
+ "yhat_lower": y_pred * (1 - uncertainty_pct),
+ "yhat_upper": y_pred * (1 + uncertainty_pct),
+ "estimated_intervals": True, # Flag that these are not native
+ }
+ )
+
+ return uncertainty_df
+
+ def explain(self, dataset: Tuple[DatasetDict, DatasetDict]) -> dict:
+ """Generate uncertainty analysis explanation.
+
+ Parameters
+ ----------
+ dataset : Tuple[DatasetDict, DatasetDict]
+ Tuple with (input_features, targets) for context
+
+ Returns
+ -------
+ dict
+ Dictionary with:
+ - ds: Timestamps
+ - yhat: Point predictions
+ - yhat_lower: Lower bound of prediction interval
+ - yhat_upper: Upper bound of prediction interval
+ - uncertainty: Interval width (yhat_upper - yhat_lower)
+ - uncertainty_pct: Uncertainty as % of prediction
+ - confidence_level: Configured confidence level
+ - model_type: Model that generated intervals
+ """
+ model_name = type(self.model).__name__
+
+ # Friendly display names for known model classes
+ _display_names = {
+ "ProphetModel": "Prophet",
+ "StatsmodelsARIMAModel": "ARIMA",
+ "StatsmodelsSARIMAXModel": "SARIMAX",
+ "SklearnMultiStepForecaster": "Sklearn MultiStep",
+ }
+
+ # Models with true parametric / native intervals
+ _parametric_models = {
+ "ProphetModel",
+ "StatsmodelsARIMAModel",
+ "StatsmodelsSARIMAXModel",
+ }
+
+ # Human-readable description of the interval source
+ _interval_sources = {
+ "ProphetModel": "Native (Prophet uncertainty sampling)",
+ "StatsmodelsARIMAModel": "Parametric (ARIMA analytical CI)",
+ "StatsmodelsSARIMAXModel": "Parametric (SARIMAX analytical CI)",
+ "SklearnMultiStepForecaster": "Empirical (residual std × √horizon)",
+ }
+
+ try:
+ if hasattr(self.model, "get_forecast_uncertainty"):
+ # Prophet, ARIMA, SARIMAX, SklearnMultiStepForecaster
+ forecast_df = self._get_native_uncertainty()
+ model_type = _display_names.get(model_name, model_name)
+ has_native_intervals = model_name in _parametric_models
+ interval_source = _interval_sources.get(
+ model_name, "Native (model-specific)"
+ )
+
+ else:
+ # Generic fallback for unknown model types
+ forecast_df = self._get_generic_uncertainty(dataset)
+ model_type = "Generic"
+ has_native_intervals = False
+ interval_source = "Placeholder (±10% of forecast)"
+
+ except Exception as e:
+ raise RuntimeError(
+ f"Failed to get uncertainty estimates from {model_name}: {str(e)}"
+ ) from e
+
+ # Calculate uncertainty metrics
+ forecast_df["uncertainty"] = (
+ forecast_df["yhat_upper"] - forecast_df["yhat_lower"]
+ )
+
+ # Avoid division by zero
+ safe_yhat = np.where(forecast_df["yhat"] == 0, 1e-10, forecast_df["yhat"])
+ forecast_df["uncertainty_pct"] = (
+ forecast_df["uncertainty"] / np.abs(safe_yhat) * 100
+ )
+
+ # Build explanation
+ explanation = {
+ "model_type": model_type,
+ "confidence_level": self.confidence_level,
+ "horizon": self.horizon,
+ "has_native_intervals": has_native_intervals,
+ "interval_source": interval_source,
+ "ds": forecast_df["ds"].dt.strftime("%Y-%m-%d %H:%M:%S").tolist(),
+ "yhat": np.round(forecast_df["yhat"].to_numpy(), 3).tolist(),
+ "yhat_lower": np.round(forecast_df["yhat_lower"].to_numpy(), 3).tolist(),
+ "yhat_upper": np.round(forecast_df["yhat_upper"].to_numpy(), 3).tolist(),
+ "uncertainty": np.round(forecast_df["uncertainty"].to_numpy(), 3).tolist(),
+ "uncertainty_pct": np.round(
+ forecast_df["uncertainty_pct"].to_numpy(), 2
+ ).tolist(),
+ }
+
+ # Add summary statistics
+ explanation["summary"] = {
+ "mean_uncertainty": round(forecast_df["uncertainty"].mean(), 3),
+ "max_uncertainty": round(forecast_df["uncertainty"].max(), 3),
+ "mean_uncertainty_pct": round(forecast_df["uncertainty_pct"].mean(), 2),
+ "uncertainty_growth": round(
+ forecast_df["uncertainty"].iloc[-1] / forecast_df["uncertainty"].iloc[0]
+ if forecast_df["uncertainty"].iloc[0] != 0
+ else 0,
+ 2,
+ ),
+ }
+
+ return explanation
+
+ def _create_forecast_plot(self, explanation: dict) -> go.Figure:
+ """Create main forecast plot with confidence intervals."""
+
+ forecast_plot_df = pd.DataFrame(
+ {
+ "ds": pd.to_datetime(explanation["ds"]),
+ "yhat": explanation["yhat"],
+ "yhat_lower": explanation["yhat_lower"],
+ "yhat_upper": explanation["yhat_upper"],
+ }
+ )
+
+ fig = go.Figure()
+
+ # Add confidence interval band
+ fig.add_trace(
+ go.Scatter(
+ x=forecast_plot_df["ds"],
+ y=forecast_plot_df["yhat_upper"],
+ mode="lines",
+ line={"width": 0},
+ showlegend=False,
+ hoverinfo="skip",
+ )
+ )
+
+ fig.add_trace(
+ go.Scatter(
+ x=forecast_plot_df["ds"],
+ y=forecast_plot_df["yhat_lower"],
+ mode="lines",
+ line={"width": 0},
+ fillcolor="rgba(68, 68, 68, 0.2)",
+ fill="tonexty",
+ name=(
+ f"{int(explanation['confidence_level'] * 100)}% Confidence Interval"
+ ),
+ )
+ )
+
+ # Add point forecast
+ fig.add_trace(
+ go.Scatter(
+ x=forecast_plot_df["ds"],
+ y=forecast_plot_df["yhat"],
+ mode="lines",
+ name="Forecast",
+ line={"color": "blue", "width": 2},
+ )
+ )
+
+ # Title
+ title = (
+ f"Forecast with {int(explanation['confidence_level'] * 100)}% "
+ "Confidence Interval"
+ )
+ interval_source = explanation.get("interval_source", "")
+ if interval_source:
+ title += f"
{interval_source}"
+
+ fig.update_layout(
+ title=title,
+ xaxis_title="Date",
+ yaxis_title="Predicted Value",
+ hovermode="x unified",
+ height=500,
+ )
+
+ return fig
+
+ def _create_uncertainty_growth_plot(self, explanation: dict) -> go.Figure:
+ """Create plot showing how uncertainty grows over horizon."""
+
+ growth_plot_df = pd.DataFrame(
+ {
+ "ds": pd.to_datetime(explanation["ds"]),
+ "uncertainty": explanation["uncertainty"],
+ "uncertainty_pct": explanation["uncertainty_pct"],
+ }
+ )
+
+ # Create subplot with absolute and relative uncertainty
+ fig = make_subplots(
+ rows=2,
+ cols=1,
+ subplot_titles=(
+ "Absolute Uncertainty (Interval Width)",
+ "Relative Uncertainty (% of Forecast)",
+ ),
+ vertical_spacing=0.12,
+ )
+
+ # Absolute uncertainty
+ fig.add_trace(
+ go.Scatter(
+ x=growth_plot_df["ds"],
+ y=growth_plot_df["uncertainty"],
+ mode="lines+markers",
+ name="Uncertainty",
+ line={"color": "red", "width": 2},
+ marker={"size": 4},
+ ),
+ row=1,
+ col=1,
+ )
+
+ # Relative uncertainty
+ fig.add_trace(
+ go.Scatter(
+ x=growth_plot_df["ds"],
+ y=growth_plot_df["uncertainty_pct"],
+ mode="lines+markers",
+ name="Uncertainty %",
+ line={"color": "orange", "width": 2},
+ marker={"size": 4},
+ ),
+ row=2,
+ col=1,
+ )
+
+ fig.update_xaxes(title_text="Date", row=2, col=1)
+ fig.update_yaxes(title_text="Interval Width", row=1, col=1)
+ fig.update_yaxes(title_text="Uncertainty (%)", row=2, col=1)
+
+ fig.update_layout(
+ title="Uncertainty Growth Over Forecast Horizon",
+ height=600,
+ showlegend=False,
+ )
+
+ return fig
+
+ def plot(self, explanation: dict) -> List[dict]:
+ """Create visualization plots.
+
+ Parameters
+ ----------
+ explanation : dict
+ Explanation dictionary from explain()
+
+ Returns
+ -------
+ List[dict]
+ List of plotly JSON figures
+ """
+ plots = []
+
+ # Main forecast with intervals
+ forecast_fig = self._create_forecast_plot(explanation)
+ plots.append(plotly.io.to_json(forecast_fig))
+
+ # Uncertainty growth analysis
+ uncertainty_fig = self._create_uncertainty_growth_plot(explanation)
+ plots.append(plotly.io.to_json(uncertainty_fig))
+
+ # Add summary statistics as annotation figure
+ summary = explanation["summary"]
+
+ import plotly.graph_objects as go
+
+ summary_fig = go.Figure()
+
+ summary_text = (
+ f"Uncertainty Summary
"
+ f"Mean Uncertainty: {summary['mean_uncertainty']}
"
+ f"Max Uncertainty: {summary['max_uncertainty']}
"
+ f"Mean Uncertainty %: {summary['mean_uncertainty_pct']:.1f}%
"
+ f"Uncertainty Growth: {summary['uncertainty_growth']:.2f}x"
+ )
+
+ summary_fig.add_annotation(
+ text=summary_text,
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=0.5,
+ showarrow=False,
+ font={"size": 14},
+ align="left",
+ bgcolor="rgba(255,255,255,0.9)",
+ bordercolor="black",
+ borderwidth=2,
+ )
+
+ summary_fig.update_layout(
+ title="Summary Statistics",
+ xaxis={"visible": False},
+ yaxis={"visible": False},
+ height=300,
+ )
+
+ plots.append(plotly.io.to_json(summary_fig))
+
+ return plots
diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_global_explainer.py b/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_global_explainer.py
new file mode 100644
index 000000000..42c96390f
--- /dev/null
+++ b/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_global_explainer.py
@@ -0,0 +1,294 @@
+"""Base class for global explainers specialized for forecasting tasks.
+
+Provides common functionality for explaining forecasting models:
+- Timestamp column detection and handling
+- Frequency inference and validation
+- Exogenous variable management
+- Time series data preparation
+
+All global explainers for forecasting tasks should inherit from this class.
+"""
+
+from abc import abstractmethod
+from typing import List, Optional, Tuple
+
+import pandas as pd
+from datasets import DatasetDict
+
+from DashAI.back.explainability.global_explainer import BaseGlobalExplainer
+from DashAI.back.models.base_model import BaseModel
+
+
+class ForecastingGlobalExplainer(BaseGlobalExplainer):
+ """Base class for global explainers specialized for forecasting.
+
+ Provides common utilities for handling time series data:
+ - Timestamp column detection
+ - Frequency inference
+ - Exogenous variable extraction
+ - Data validation for forecasting
+
+ Subclasses must implement:
+ - explain(): Generate the explanation
+ - plot(): Create visualizations
+ """
+
+ # All forecasting explainers are compatible with ForecastingTask
+ COMPATIBLE_COMPONENTS = ["ForecastingTask"]
+
+ def __init__(self, model: BaseModel, **kwargs):
+ """Initialize forecasting global explainer.
+
+ Parameters
+ ----------
+ model : BaseModel
+ Trained forecasting model to explain
+ **kwargs : dict
+ Additional parameters passed to parent class
+ """
+ super().__init__(model, **kwargs)
+
+ # Cache for model metadata
+ self._timestamp_col: Optional[str] = None
+ self._target_col: Optional[str] = None
+ self._exog_cols: Optional[List[str]] = None
+ self._frequency: Optional[str] = None
+
+ def _get_timestamp_column(self) -> Optional[str]:
+ """Get timestamp column name from model.
+
+ Returns
+ -------
+ str or None
+ Name of timestamp column, or None if not available
+ """
+ if self._timestamp_col is not None:
+ return self._timestamp_col
+
+ # Try to get from model
+ if hasattr(self.model, "timestamp_col"):
+ self._timestamp_col = getattr(self.model, "timestamp_col", None)
+ elif hasattr(self.model, "get_column_names"):
+ try:
+ col_names = self.model.get_column_names() # type: ignore
+ self._timestamp_col = col_names.get("timestamp")
+ except Exception:
+ pass
+
+ return self._timestamp_col
+
+ def _get_target_column(self) -> Optional[str]:
+ """Get target column name from model.
+
+ Returns
+ -------
+ str or None
+ Name of target column, or None if not available
+ """
+ if self._target_col is not None:
+ return self._target_col
+
+ # Try to get from model
+ if hasattr(self.model, "target_col"):
+ self._target_col = getattr(self.model, "target_col", None)
+ elif hasattr(self.model, "get_column_names"):
+ try:
+ col_names = self.model.get_column_names() # type: ignore
+ self._target_col = col_names.get("target")
+ except Exception:
+ pass
+
+ return self._target_col
+
+ def _get_exogenous_columns(self) -> List[str]:
+ """Get exogenous variable names from model.
+
+ Uses model's interface to get exogenous columns in original format.
+
+ Returns
+ -------
+ List[str]
+ List of exogenous variable names
+ """
+ if self._exog_cols is not None:
+ return self._exog_cols
+
+ # Try to get from model using ForecastingModel interface
+ if hasattr(self.model, "get_exogenous_columns"):
+ try:
+ self._exog_cols = self.model.get_exogenous_columns() # type: ignore
+ return self._exog_cols or []
+ except Exception:
+ pass
+
+ # Fallback: check exog_cols attribute
+ if hasattr(self.model, "exog_cols"):
+ self._exog_cols = getattr(self.model, "exog_cols", [])
+ return self._exog_cols or []
+
+ return []
+
+ def _get_frequency(self) -> Optional[str]:
+ """Get time series frequency from model.
+
+ Returns
+ -------
+ str or None
+ Frequency string (e.g., 'D', 'H', 'M'), or None if not available
+ """
+ if self._frequency is not None:
+ return self._frequency
+
+ # Try to get from model
+ if hasattr(self.model, "frequency"):
+ self._frequency = getattr(self.model, "frequency", None)
+
+ return self._frequency
+
+ def _extract_timestamps(
+ self, dataset: DatasetDict, split: str = "test"
+ ) -> pd.Series:
+ """Extract timestamp column from dataset.
+
+ Parameters
+ ----------
+ dataset : DatasetDict
+ Dataset containing time series data
+ split : str
+ Which split to extract from (default: "test")
+
+ Returns
+ -------
+ pd.Series
+ Series with timestamps as datetime
+
+ Raises
+ ------
+ ValueError
+ If timestamp column not found or cannot be converted
+ """
+ timestamp_col = self._get_timestamp_column()
+
+ if timestamp_col is None:
+ raise ValueError(
+ "Cannot determine timestamp column. "
+ "Model must store timestamp_col or implement get_column_names()"
+ )
+
+ if split not in dataset:
+ raise ValueError(f"Split '{split}' not found in dataset")
+
+ ds = dataset[split]
+
+ if timestamp_col not in ds.column_names:
+ raise ValueError(
+ f"Timestamp column '{timestamp_col}' not found in dataset. "
+ f"Available columns: {ds.column_names}"
+ )
+
+ # Convert to pandas Series with datetime
+ timestamps = pd.to_datetime(ds.to_pandas()[timestamp_col])
+
+ return timestamps
+
+ def _prepare_dataset_with_timestamps(
+ self, dataset: DatasetDict, split: str = "test"
+ ) -> pd.DataFrame:
+ """Prepare dataset as DataFrame with all required columns.
+
+ Includes timestamp column, exogenous variables, and target (if available).
+ This is useful when explainers need the full context for predictions.
+
+ Parameters
+ ----------
+ dataset : DatasetDict
+ Dataset to prepare
+ split : str
+ Which split to use (default: "test")
+
+ Returns
+ -------
+ pd.DataFrame
+ DataFrame with timestamps, exogenous variables, and target
+ """
+ if split not in dataset:
+ raise ValueError(f"Split '{split}' not found in dataset")
+
+ split_df = dataset[split].to_pandas()
+
+ # Ensure timestamp column is datetime
+ timestamp_col = self._get_timestamp_column()
+ if timestamp_col and timestamp_col in split_df.columns:
+ split_df[timestamp_col] = pd.to_datetime(split_df[timestamp_col])
+
+ return split_df
+
+ def _validate_has_exogenous_variables(self) -> bool:
+ """Check if model uses exogenous variables.
+
+ Returns
+ -------
+ bool
+ True if model has exogenous variables
+ """
+ if hasattr(self.model, "has_exogenous_variables"):
+ try:
+ return self.model.has_exogenous_variables() # type: ignore
+ except Exception:
+ pass
+
+ # Fallback: check if exog_cols is non-empty
+ exog_cols = self._get_exogenous_columns()
+ return len(exog_cols) > 0
+
+ def _infer_frequency(self, timestamps: pd.Series) -> Optional[str]:
+ """Infer frequency from timestamp series.
+
+ Parameters
+ ----------
+ timestamps : pd.Series
+ Series of datetime values
+
+ Returns
+ -------
+ str or None
+ Inferred frequency (e.g., 'D', 'H'), or None if cannot infer
+ """
+ try:
+ # Use pandas infer_freq
+ freq = pd.infer_freq(timestamps)
+ return freq
+ except Exception:
+ # Try getting from model
+ return self._get_frequency()
+
+ @abstractmethod
+ def explain(self, dataset: Tuple[DatasetDict, DatasetDict]) -> dict:
+ """Generate explanation for the forecasting model.
+
+ Parameters
+ ----------
+ dataset : Tuple[DatasetDict, DatasetDict]
+ Tuple with (input_features, targets)
+ Note: For forecasting, input_features may need timestamp column
+
+ Returns
+ -------
+ dict
+ Explanation results
+ """
+
+ @abstractmethod
+ def plot(self, explanation: dict) -> List[dict]:
+ """Create visualizations for the explanation.
+
+ Parameters
+ ----------
+ explanation : dict
+ Explanation dictionary from explain()
+
+ Returns
+ -------
+ List[dict]
+ List of plotly JSON figures
+ """
diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_local_explainer.py b/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_local_explainer.py
new file mode 100644
index 000000000..217e50c6c
--- /dev/null
+++ b/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_local_explainer.py
@@ -0,0 +1,336 @@
+"""Base class for local explainers specialized for forecasting tasks.
+
+Provides common functionality for explaining individual forecasts:
+- Instance selection from time series
+- Window extraction for point-in-time explanations
+- Temporal context management
+- Per-forecast explanation generation
+
+All local explainers for forecasting tasks should inherit from this class.
+"""
+
+from abc import abstractmethod
+from typing import List, Optional, Tuple
+
+import pandas as pd
+from datasets import DatasetDict
+
+from DashAI.back.explainability.local_explainer import BaseLocalExplainer
+from DashAI.back.models.base_model import BaseModel
+
+
+class ForecastingLocalExplainer(BaseLocalExplainer):
+ """Base class for local explainers specialized for forecasting.
+
+ Provides common utilities for explaining individual forecasts:
+ - Timestamp handling for specific forecast points
+ - Window extraction (e.g., last N days before forecast)
+ - Exogenous variable context
+ - Per-instance explanation generation
+
+ Subclasses must implement:
+ - fit(): Prepare explainer with training data
+ - explain_instance(): Generate explanation for specific forecast
+ - plot(): Create visualizations
+ """
+
+ # All forecasting local explainers are compatible with ForecastingTask
+ COMPATIBLE_COMPONENTS = ["ForecastingTask"]
+
+ def __init__(self, model: BaseModel, **kwargs):
+ """Initialize forecasting local explainer.
+
+ Parameters
+ ----------
+ model : BaseModel
+ Trained forecasting model to explain
+ **kwargs : dict
+ Additional parameters passed to parent class
+ """
+ super().__init__(model, **kwargs)
+
+ # Cache for model metadata
+ self._timestamp_col: Optional[str] = None
+ self._target_col: Optional[str] = None
+ self._exog_cols: Optional[List[str]] = None
+ self._frequency: Optional[str] = None
+
+ def _get_timestamp_column(self) -> Optional[str]:
+ """Get timestamp column name from model.
+
+ Returns
+ -------
+ str or None
+ Name of timestamp column, or None if not available
+ """
+ if self._timestamp_col is not None:
+ return self._timestamp_col
+
+ # Try to get from model
+ if hasattr(self.model, "timestamp_col"):
+ self._timestamp_col = getattr(self.model, "timestamp_col", None)
+ elif hasattr(self.model, "get_column_names"):
+ try:
+ col_names = self.model.get_column_names() # type: ignore
+ self._timestamp_col = col_names.get("timestamp")
+ except Exception:
+ pass
+
+ return self._timestamp_col
+
+ def _get_target_column(self) -> Optional[str]:
+ """Get target column name from model.
+
+ Returns
+ -------
+ str or None
+ Name of target column, or None if not available
+ """
+ if self._target_col is not None:
+ return self._target_col
+
+ # Try to get from model
+ if hasattr(self.model, "target_col"):
+ self._target_col = getattr(self.model, "target_col", None)
+ elif hasattr(self.model, "get_column_names"):
+ try:
+ col_names = self.model.get_column_names() # type: ignore
+ self._target_col = col_names.get("target")
+ except Exception:
+ pass
+
+ return self._target_col
+
+ def _get_exogenous_columns(self) -> List[str]:
+ """Get exogenous variable names from model.
+
+ Uses model's interface to get exogenous columns in original format.
+
+ Returns
+ -------
+ List[str]
+ List of exogenous variable names
+ """
+ if self._exog_cols is not None:
+ return self._exog_cols
+
+ # Try to get from model using ForecastingModel interface
+ if hasattr(self.model, "get_exogenous_columns"):
+ try:
+ self._exog_cols = self.model.get_exogenous_columns() # type: ignore
+ return self._exog_cols or []
+ except Exception:
+ pass
+
+ # Fallback: check exog_cols attribute
+ if hasattr(self.model, "exog_cols"):
+ self._exog_cols = getattr(self.model, "exog_cols", [])
+ return self._exog_cols or []
+
+ return []
+
+ def _get_frequency(self) -> Optional[str]:
+ """Get time series frequency from model.
+
+ Returns
+ -------
+ str or None
+ Frequency string (e.g., 'D', 'H', 'M'), or None if not available
+ """
+ if self._frequency is not None:
+ return self._frequency
+
+ # Try to get from model
+ if hasattr(self.model, "frequency"):
+ self._frequency = getattr(self.model, "frequency", None)
+
+ return self._frequency
+
+ def _extract_window(
+ self,
+ dataset: DatasetDict,
+ split: str = "test",
+ window_size: Optional[int] = None,
+ end_index: Optional[int] = None,
+ ) -> pd.DataFrame:
+ """Extract a window of data for local explanation.
+
+ Useful for explaining a specific forecast by showing the context
+ (e.g., last 30 days before the forecast point).
+
+ Parameters
+ ----------
+ dataset : DatasetDict
+ Dataset containing time series data
+ split : str
+ Which split to use (default: "test")
+ window_size : int, optional
+ Number of time points to include in window
+ If None, returns all data up to end_index
+ end_index : int, optional
+ Last index to include (exclusive)
+ If None, uses all available data
+
+ Returns
+ -------
+ pd.DataFrame
+ DataFrame with windowed data
+ """
+ if split not in dataset:
+ raise ValueError(f"Split '{split}' not found in dataset")
+
+ split_df = dataset[split].to_pandas()
+
+ # Apply end index
+ if end_index is not None:
+ split_df = split_df.iloc[:end_index]
+
+ # Apply window size
+ if window_size is not None and len(split_df) > window_size:
+ split_df = split_df.iloc[-window_size:]
+
+ # Ensure timestamp column is datetime
+ timestamp_col = self._get_timestamp_column()
+ if timestamp_col and timestamp_col in split_df.columns:
+ split_df[timestamp_col] = pd.to_datetime(split_df[timestamp_col])
+
+ return split_df
+
+ def _select_instance_by_timestamp(
+ self, dataset: DatasetDict, timestamp: pd.Timestamp, split: str = "test"
+ ) -> pd.Series:
+ """Select a specific instance by timestamp.
+
+ Parameters
+ ----------
+ dataset : DatasetDict
+ Dataset containing time series data
+ timestamp : pd.Timestamp
+ Timestamp of instance to select
+ split : str
+ Which split to use (default: "test")
+
+ Returns
+ -------
+ pd.Series
+ Single row as Series
+
+ Raises
+ ------
+ ValueError
+ If timestamp not found in dataset
+ """
+ timestamp_col = self._get_timestamp_column()
+
+ if timestamp_col is None:
+ raise ValueError(
+ "Cannot select by timestamp: timestamp column not available"
+ )
+
+ split_df = dataset[split].to_pandas()
+ split_df[timestamp_col] = pd.to_datetime(split_df[timestamp_col])
+
+ mask = split_df[timestamp_col] == timestamp
+
+ if not mask.any():
+ raise ValueError(f"Timestamp {timestamp} not found in {split} split")
+
+ return split_df[mask].iloc[0]
+
+ def _prepare_dataset_with_timestamps(
+ self, dataset: DatasetDict, split: str = "test"
+ ) -> pd.DataFrame:
+ """Prepare dataset as DataFrame with all required columns.
+
+ Includes timestamp column, exogenous variables, and target (if available).
+
+ Parameters
+ ----------
+ dataset : DatasetDict
+ Dataset to prepare
+ split : str
+ Which split to use (default: "test")
+
+ Returns
+ -------
+ pd.DataFrame
+ DataFrame with timestamps, exogenous variables, and target
+ """
+ if split not in dataset:
+ raise ValueError(f"Split '{split}' not found in dataset")
+
+ split_df = dataset[split].to_pandas()
+
+ # Ensure timestamp column is datetime
+ timestamp_col = self._get_timestamp_column()
+ if timestamp_col and timestamp_col in split_df.columns:
+ split_df[timestamp_col] = pd.to_datetime(split_df[timestamp_col])
+
+ return split_df
+
+ def _validate_has_exogenous_variables(self) -> bool:
+ """Check if model uses exogenous variables.
+
+ Returns
+ -------
+ bool
+ True if model has exogenous variables
+ """
+ if hasattr(self.model, "has_exogenous_variables"):
+ try:
+ return self.model.has_exogenous_variables() # type: ignore
+ except Exception:
+ pass
+
+ # Fallback: check if exog_cols is non-empty
+ exog_cols = self._get_exogenous_columns()
+ return len(exog_cols) > 0
+
+ @abstractmethod
+ def fit(
+ self, dataset: Tuple[DatasetDict, DatasetDict], **fit_params
+ ) -> "ForecastingLocalExplainer":
+ """Fit the explainer on training data.
+
+ Parameters
+ ----------
+ dataset : Tuple[DatasetDict, DatasetDict]
+ Tuple with (input_features, targets)
+ **fit_params : dict
+ Additional fitting parameters
+
+ Returns
+ -------
+ ForecastingLocalExplainer
+ Self
+ """
+
+ @abstractmethod
+ def explain_instance(self, instance: DatasetDict) -> dict:
+ """Generate explanation for a specific forecast instance.
+
+ Parameters
+ ----------
+ instance : DatasetDict
+ Single instance or small window to explain
+
+ Returns
+ -------
+ dict
+ Explanation for this specific instance
+ """
+
+ @abstractmethod
+ def plot(self, explanation: dict) -> List[dict]:
+ """Create visualizations for the local explanation.
+
+ Parameters
+ ----------
+ explanation : dict
+ Explanation dictionary from explain_instance()
+
+ Returns
+ -------
+ List[dict]
+ List of plotly JSON figures
+ """
diff --git a/DashAI/back/initial_components.py b/DashAI/back/initial_components.py
index f8b9f006b..1e2d27091 100644
--- a/DashAI/back/initial_components.py
+++ b/DashAI/back/initial_components.py
@@ -60,13 +60,32 @@
CharacterReplacer,
)
from DashAI.back.converters.simple_converters.column_remover import ColumnRemover
+
+# Forecasting converters
+from DashAI.back.converters.simple_converters.extend_time_series_converter import (
+ ExtendTimeSeriesConverter,
+)
from DashAI.back.converters.simple_converters.nan_remover import NanRemover
+from DashAI.back.converters.simple_converters.time_series_window_converter import (
+ TimeSeriesWindowConverter,
+)
# DataLoaders
from DashAI.back.dataloaders.classes.csv_dataloader import CSVDataLoader
from DashAI.back.dataloaders.classes.excel_dataloader import ExcelDataLoader
from DashAI.back.dataloaders.classes.json_dataloader import JSONDataLoader
+# Forecasting explainers
+from DashAI.back.explainability.explainers.forecasting_explainers.forecast_decomposition import ( # noqa: E501
+ ForecastDecomposition,
+)
+from DashAI.back.explainability.explainers.forecasting_explainers.forecast_feature_importance import ( # noqa: E501
+ ForecastFeatureImportance,
+)
+from DashAI.back.explainability.explainers.forecasting_explainers.forecast_uncertainty import ( # noqa: E501
+ ForecastUncertainty,
+)
+
# Explainers
from DashAI.back.explainability.explainers.kernel_shap import KernelShap
from DashAI.back.explainability.explainers.partial_dependence import PartialDependence
@@ -99,6 +118,9 @@
from DashAI.back.job.dataset_job import DatasetJob
from DashAI.back.job.explainer_job import ExplainerJob
from DashAI.back.job.explorer_job import ExplorerJob
+
+# Forecasting job
+from DashAI.back.job.forecasting_job import ForecastingJob
from DashAI.back.job.generative_job import GenerativeJob
from DashAI.back.job.model_job import ModelJob
from DashAI.back.job.pipeline_job import PipelineJob
@@ -113,6 +135,10 @@
from DashAI.back.metrics.classification.precision import Precision
from DashAI.back.metrics.classification.recall import Recall
from DashAI.back.metrics.classification.roc_auc import ROCAUC
+
+# Forecasting metrics
+from DashAI.back.metrics.forecasting.mape import MAPE
+from DashAI.back.metrics.forecasting.smape import SMAPE
from DashAI.back.metrics.regression.explained_variance import ExplainedVariance
from DashAI.back.metrics.regression.mae import MAE
from DashAI.back.metrics.regression.median_absolute_error import MedianAbsoluteError
@@ -123,6 +149,18 @@
from DashAI.back.metrics.translation.chrf import Chrf
from DashAI.back.metrics.translation.ter import Ter
+# Forecasting models
+from DashAI.back.models.forecasting.prophet_model import ProphetModel
+from DashAI.back.models.forecasting.sklearn_multistep_forecaster import (
+ SklearnMultiStepForecaster,
+)
+from DashAI.back.models.forecasting.statsmodels_arima_model import (
+ StatsmodelsARIMAModel,
+)
+from DashAI.back.models.forecasting.statsmodels_sarimax_model import (
+ StatsmodelsSARIMAXModel,
+)
+
# Models
from DashAI.back.models.hugging_face.distilbert_transformer import DistilBertTransformer
from DashAI.back.models.hugging_face.llama_model import LlamaModel
@@ -178,6 +216,9 @@
from DashAI.back.models.scikit_learn.linearSVR import LinearSVR
from DashAI.back.models.scikit_learn.logistic_regression import LogisticRegression
from DashAI.back.models.scikit_learn.mlp_regression import MLPRegression
+from DashAI.back.models.scikit_learn.multi_output_regression import (
+ MultiOutputRegression,
+)
from DashAI.back.models.scikit_learn.random_forest_classifier import (
RandomForestClassifier,
)
@@ -203,6 +244,9 @@
# Tasks
from DashAI.back.tasks.controlnet_task import ControlNetTask
+
+# Forecasting task
+from DashAI.back.tasks.forecasting_task import ForecastingTask
from DashAI.back.tasks.regression_task import RegressionTask
from DashAI.back.tasks.tabular_classification_task import TabularClassificationTask
from DashAI.back.tasks.text_classification_task import TextClassificationTask
@@ -231,6 +275,7 @@ def get_initial_components():
TextClassificationTask,
TranslationTask,
RegressionTask,
+ ForecastingTask,
TextToImageGenerationTask,
TextToTextGenerationTask,
ControlNetTask,
@@ -259,6 +304,11 @@ def get_initial_components():
SDXLCannyControlNetModel,
LogisticRegression,
MLPRegression,
+ MultiOutputRegression,
+ ProphetModel,
+ SklearnMultiStepForecaster,
+ StatsmodelsARIMAModel,
+ StatsmodelsSARIMAXModel,
RandomForestClassifier,
RandomForestRegression,
DistilBertTransformer,
@@ -281,6 +331,8 @@ def get_initial_components():
Chrf,
MSE,
RMSE,
+ MAPE,
+ SMAPE,
MAE,
R2,
MedianAbsoluteError,
@@ -296,6 +348,7 @@ def get_initial_components():
ExplainerJob,
ModelJob,
ExplorerJob,
+ ForecastingJob,
PredictJob,
ConverterListJob,
DatasetJob,
@@ -305,6 +358,9 @@ def get_initial_components():
KernelShap,
PartialDependence,
PermutationFeatureImportance,
+ ForecastDecomposition,
+ ForecastFeatureImportance,
+ ForecastUncertainty,
# Explorers
DescribeExplorer,
ScatterPlotExplorer,
@@ -364,6 +420,8 @@ def get_initial_components():
SMOTEConverter,
SMOTEENNConverter,
RandomUnderSamplerConverter,
+ TimeSeriesWindowConverter,
+ ExtendTimeSeriesConverter,
]
# Obtener plugins instalados
diff --git a/DashAI/back/job/explainer_job.py b/DashAI/back/job/explainer_job.py
index 4863a59f0..e93b1f9c9 100644
--- a/DashAI/back/job/explainer_job.py
+++ b/DashAI/back/job/explainer_job.py
@@ -416,43 +416,66 @@ def run(
) from e
try:
splits = json.loads(run.split_indexes)
- loaded_dataset = split_dataset(
- loaded_dataset,
- train_indexes=splits["train_indexes"],
- test_indexes=splits["test_indexes"],
- val_indexes=splits["val_indexes"],
- )
- prepared_dataset = task.prepare_for_task(
- dataset=loaded_dataset,
- input_columns=self.input_columns,
- output_columns=self.output_columns,
- )
- data = select_columns(
- prepared_dataset,
- self.input_columns,
- self.output_columns,
- )
+ if model_session.task_name == "ForecastingTask":
+ # For forecasting: prepare full dataset BEFORE splitting
+ # (preserves all data points for temporal integrity)
+ prepared_dataset = task.prepare_for_task(
+ dataset=loaded_dataset,
+ input_columns=self.input_columns,
+ output_columns=self.output_columns,
+ )
- data_x = split_dataset(
- data[0],
- train_indexes=splits["train_indexes"],
- test_indexes=splits["test_indexes"],
- val_indexes=splits["val_indexes"],
- )
- data_y = split_dataset(
- data[1],
- train_indexes=splits["train_indexes"],
- test_indexes=splits["test_indexes"],
- val_indexes=splits["val_indexes"],
- )
- for split_name in data_x:
- data_x[split_name] = trained_model.prepare_dataset(
- data_x[split_name], is_fit=False
+ data = select_columns(
+ prepared_dataset,
+ self.input_columns,
+ self.output_columns,
+ )
+
+ data_x = split_dataset(
+ data[0],
+ train_indexes=splits["train_indexes"],
+ test_indexes=splits["test_indexes"],
+ val_indexes=splits["val_indexes"],
+ )
+ data_y = split_dataset(
+ data[1],
+ train_indexes=splits["train_indexes"],
+ test_indexes=splits["test_indexes"],
+ val_indexes=splits["val_indexes"],
+ )
+ else:
+ # For other tasks: standard flow
+ prepared_dataset = task.prepare_for_task(
+ dataset=loaded_dataset,
+ input_columns=self.input_columns,
+ output_columns=self.output_columns,
+ )
+ data = select_columns(
+ prepared_dataset,
+ self.input_columns,
+ self.output_columns,
+ )
+
+ data_x = split_dataset(
+ data[0],
+ train_indexes=splits["train_indexes"],
+ test_indexes=splits["test_indexes"],
+ val_indexes=splits["val_indexes"],
)
- data_y[split_name] = trained_model.prepare_output(
- data_y[split_name], is_fit=False
+ data_y = split_dataset(
+ data[1],
+ train_indexes=splits["train_indexes"],
+ test_indexes=splits["test_indexes"],
+ val_indexes=splits["val_indexes"],
)
+ for split_name in data_x:
+ data_x[split_name] = trained_model.prepare_dataset(
+ data_x[split_name], is_fit=False
+ )
+ data_y[split_name] = trained_model.prepare_output(
+ data_y[split_name], is_fit=False
+ )
except Exception as e:
log.exception(e)
diff --git a/DashAI/back/job/forecasting_job.py b/DashAI/back/job/forecasting_job.py
new file mode 100644
index 000000000..7359c64bc
--- /dev/null
+++ b/DashAI/back/job/forecasting_job.py
@@ -0,0 +1,447 @@
+"""Forecasting-specific job for time series model training."""
+
+import gc
+import json
+import logging
+import os
+import pickle
+from typing import List
+
+from kink import inject
+from sqlalchemy import exc
+from sqlalchemy.orm import sessionmaker
+
+from DashAI.back.core.enums.metrics import LevelEnum, SplitEnum
+from DashAI.back.dataloaders.classes.dashai_dataset import (
+ DashAIDataset,
+ load_dataset,
+ prepare_for_forecasting_experiment,
+ select_columns,
+ split_dataset,
+)
+from DashAI.back.dependencies.database.models import Dataset, ModelSession, Run
+from DashAI.back.job.base_job import BaseJob, JobError
+from DashAI.back.metrics.base_metric import BaseMetric
+from DashAI.back.models.base_model import BaseModel
+from DashAI.back.models.model_factory import ModelFactory
+from DashAI.back.optimizers.base_optimizer import BaseOptimizer
+from DashAI.back.tasks.base_task import BaseTask
+
+logging.basicConfig(level=logging.DEBUG)
+log = logging.getLogger(__name__)
+
+
+class ForecastingJob(BaseJob):
+ """ForecastingJob class for time series model training with temporal splitting."""
+
+ @inject
+ def set_status_as_delivered(
+ self, session_factory: sessionmaker = lambda di: di["session_factory"]
+ ) -> None:
+ """Set the status of the job as delivered."""
+ run_id: int = self.kwargs["run_id"]
+
+ with session_factory() as db:
+ run: Run = db.get(Run, run_id)
+ if not run:
+ raise JobError(f"Run {run_id} does not exist in DB.")
+ try:
+ run.set_status_as_delivered()
+ db.commit()
+ except exc.SQLAlchemyError as e:
+ log.exception(e)
+ raise JobError(
+ "Internal database error",
+ ) from e
+
+ @inject
+ def set_status_as_error(
+ self, session_factory: sessionmaker = lambda di: di["session_factory"]
+ ) -> None:
+ """Set the status of the job as error."""
+ run_id: int = self.kwargs.get("run_id")
+ if run_id is None:
+ return
+
+ with session_factory() as db:
+ run: Run = db.get(Run, run_id)
+ if not run:
+ return
+ try:
+ run.set_status_as_error()
+ db.commit()
+ except exc.SQLAlchemyError as e:
+ log.exception(e)
+
+ @inject
+ def get_job_name(self) -> str:
+ """Get a descriptive name for the job."""
+ run_id = self.kwargs.get("run_id")
+ if not run_id:
+ return "Forecasting Training"
+
+ from kink import di
+
+ session_factory = di["session_factory"]
+
+ try:
+ with session_factory() as db:
+ run: Run = db.get(Run, run_id)
+ if run and run.name:
+ return f"Forecast: {run.name}"
+ except Exception:
+ pass
+
+ return f"Forecasting Training ({run_id})"
+
+ @inject
+ def run(self) -> None:
+ from kink import di
+
+ component_registry = di["component_registry"]
+ session_factory = di["session_factory"]
+ config = di["config"]
+
+ # Get the necessary parameters
+ run_id: int = self.kwargs["run_id"]
+
+ with session_factory() as db:
+ run: Run = db.get(Run, run_id)
+ run.huey_id = self.kwargs.get("huey_id", None)
+ db.commit()
+ try:
+ # Get the model session, dataset, task, metrics and splits
+ model_session: ModelSession = db.get(ModelSession, run.model_session_id)
+ if not model_session:
+ raise JobError(
+ f"Model session {run.model_session_id} does not exist in DB."
+ )
+ dataset: Dataset = db.get(Dataset, model_session.dataset_id)
+ if not dataset:
+ raise JobError(
+ f"Dataset {model_session.dataset_id} does not exist in DB."
+ )
+
+ try:
+ loaded_dataset: DashAIDataset = load_dataset(
+ f"{dataset.file_path}/dataset"
+ )
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ f"Can not load dataset from path {dataset.file_path}",
+ ) from e
+
+ try:
+ task: BaseTask = component_registry[model_session.task_name][
+ "class"
+ ]()
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ (
+ f"Unable to find Task with name {model_session.task_name} "
+ "in registry"
+ ),
+ ) from e
+
+ # Validate this is a forecasting task
+ if model_session.task_name != "ForecastingTask":
+ raise JobError(
+ f"ForecastingJob can only be used with ForecastingTask, "
+ f"got {model_session.task_name}"
+ )
+
+ try:
+ # Get metrics selected in the model session
+ train_metrics: List[BaseMetric] = [
+ component_registry[m]["class"]
+ for m in model_session.train_metrics
+ ]
+ validation_metrics: List[BaseMetric] = [
+ component_registry[m]["class"]
+ for m in model_session.validation_metrics
+ ]
+ test_metrics: List[BaseMetric] = [
+ component_registry[m]["class"]
+ for m in model_session.test_metrics
+ ]
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ "Unable to find metrics associated with"
+ f"Task {model_session.task_name} in registry",
+ ) from e
+
+ try:
+ # Prepare dataset for forecasting task with auto-detection
+ prepared_dataset = task.prepare_for_task(
+ loaded_dataset,
+ outputs_columns=model_session.output_columns,
+ inputs_columns=model_session.input_columns,
+ # Optional: Override auto-detection if specified
+ timestamp_column=getattr(
+ model_session, "timestamp_column", None
+ ),
+ frequency=getattr(model_session, "frequency", "auto"),
+ )
+
+ # Get temporal metadata for logging
+ temporal_metadata = task.get_temporal_metadata()
+ log.info(f"Temporal metadata: {temporal_metadata}")
+
+ splits = json.loads(model_session.splits)
+
+ # Use forecasting-specific preparation with temporal splitting
+ prepared_dataset, splits = prepare_for_forecasting_experiment(
+ dataset=prepared_dataset,
+ splits=splits,
+ timestamp_col=temporal_metadata.get("timestamp_col", "ds"),
+ output_columns=model_session.output_columns,
+ )
+
+ run.split_indexes = json.dumps(
+ {
+ "train_indexes": splits["train_indexes"],
+ "test_indexes": splits["test_indexes"],
+ "val_indexes": splits["val_indexes"],
+ }
+ )
+
+ x, y = select_columns(
+ prepared_dataset,
+ model_session.input_columns,
+ model_session.output_columns,
+ )
+
+ x = split_dataset(x)
+ y = split_dataset(y)
+
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ f"""Can not prepare Dataset {dataset.id}
+ for ForecastingTask {model_session.task_name}""",
+ ) from e
+
+ try:
+ run_model_class = component_registry[run.model_name]["class"]
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ f"Unable to find Model with name {run.model_name} in registry.",
+ ) from e
+
+ # Validate model is compatible with forecasting.
+ compatible_tasks = getattr(run_model_class, "_compatible_tasks", None)
+ if compatible_tasks is None:
+ compatible_tasks = getattr(
+ run_model_class, "COMPATIBLE_COMPONENTS", None
+ )
+
+ if compatible_tasks is None:
+ log.warning(
+ f"Model {run.model_name} does not specify task compatibility"
+ )
+ elif "ForecastingTask" not in compatible_tasks:
+ raise JobError(
+ f"Model {run.model_name} is not compatible with ForecastingTask"
+ )
+
+ try:
+ factory = ModelFactory(
+ run_model_class,
+ run.parameters,
+ run_id,
+ x,
+ y,
+ train_metrics,
+ validation_metrics,
+ test_metrics,
+ # No n_labels for forecasting tasks
+ n_labels=None,
+ )
+ model: BaseModel = factory.model
+ run_optimizable_parameters = factory.optimizable_parameters
+
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ f"Unable to instantiate forecasting model using run {run_id}",
+ ) from e
+
+ # Handle hyperparameter optimization for forecasting
+ if run_optimizable_parameters:
+ try:
+ # Optimizer configuration
+ run_optimizer_class = component_registry[run.optimizer_name][
+ "class"
+ ]
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ f"Unable to find Optimizer with name "
+ f"{run.optimizer_name} in registry.",
+ ) from e
+
+ if run.goal_metric != "":
+ try:
+ goal_metric = component_registry[run.goal_metric]
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ "Metric is not compatible with the ForecastingTask",
+ ) from e
+ try:
+ optimizer: BaseOptimizer = run_optimizer_class(
+ **run.optimizer_parameters
+ )
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ (
+ "Optimizer parameters not compatible "
+ "with the optimizer"
+ ),
+ ) from e
+
+ try:
+ run.set_status_as_started()
+ db.commit()
+ except exc.SQLAlchemyError as e:
+ log.exception(e)
+ raise JobError(
+ "Connection with the database failed",
+ ) from e
+
+ try:
+ # Forecasting model training
+ plot_paths = []
+ if not run_optimizable_parameters:
+ # Simple fit with forecasting-specific parameters
+ # Pass temporal metadata to model for column information
+ if hasattr(model, "fit") and hasattr(model, "_task_type"):
+ model.fit(
+ x["train"],
+ y["train"],
+ temporal_metadata=temporal_metadata,
+ )
+ else:
+ model.fit(x["train"], y["train"])
+ else:
+ # Hyperparameter optimization for forecasting
+ optimizer.optimize(
+ model,
+ x,
+ y,
+ run_optimizable_parameters,
+ goal_metric,
+ model_session.task_name,
+ )
+ model = optimizer.get_model()
+ # Generate hyperparameter plot
+ trials = optimizer.get_trials_values()
+ plot_filenames, plots = optimizer.create_plots(
+ trials, run_id, n_params=len(run_optimizable_parameters)
+ )
+ for filename, plot in zip(plot_filenames, plots, strict=True):
+ plot_path = os.path.join(config["RUNS_PATH"], filename)
+ with open(plot_path, "wb") as file:
+ pickle.dump(plot, file)
+ plot_paths.append(plot_path)
+
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ "Forecasting model training failed",
+ ) from e
+
+ try:
+ paths = plot_paths + [None] * (4 - len(plot_paths))
+ (
+ run.plot_history_path,
+ run.plot_slice_path,
+ run.plot_contour_path,
+ run.plot_importance_path,
+ ) = paths[:4]
+ db.commit()
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ "Hyperparameter plot path saving failed",
+ ) from e
+
+ try:
+ run.set_status_as_finished()
+ db.commit()
+ except exc.SQLAlchemyError as e:
+ log.exception(e)
+ raise JobError(
+ "Connection with the database failed",
+ ) from e
+
+ try:
+ if train_metrics:
+ model.calculate_metrics(
+ split=SplitEnum.TRAIN,
+ level=LevelEnum.LAST,
+ )
+ if validation_metrics:
+ model.calculate_metrics(
+ split=SplitEnum.VALIDATION,
+ level=LevelEnum.LAST,
+ )
+ if test_metrics:
+ model.calculate_metrics(
+ split=SplitEnum.TEST,
+ level=LevelEnum.LAST,
+ )
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ "Forecasting metrics calculation failed",
+ ) from e
+
+ try:
+ run_path = os.path.join(config["RUNS_PATH"], str(run.id))
+ model.save(run_path)
+
+ # Save forecasting-specific artifacts
+ if hasattr(model, "get_forecast_components"):
+ try:
+ # Save forecast components for interpretation
+ components = model.get_forecast_components(horizon=30)
+ components_path = os.path.join(
+ config["RUNS_PATH"],
+ f"{run.id}_forecast_components.csv",
+ )
+ components.to_csv(components_path, index=False)
+ log.info(f"Saved forecast components to {components_path}")
+ except Exception as e:
+ log.warning(f"Could not save forecast components: {e}")
+
+ except Exception as e:
+ log.exception(e)
+ raise JobError(
+ "Forecasting model saving failed",
+ ) from e
+
+ try:
+ run.run_path = run_path
+ db.commit()
+ log.info(
+ f"✅ ForecastingJob completed successfully for run {run_id}"
+ )
+ except exc.SQLAlchemyError as e:
+ log.exception(e)
+ run.set_status_as_error()
+ db.commit()
+ raise JobError(
+ "Connection with the database failed",
+ ) from e
+ except Exception as e:
+ run.set_status_as_error()
+ db.commit()
+ raise e
+ finally:
+ gc.collect()
diff --git a/DashAI/back/job/model_job.py b/DashAI/back/job/model_job.py
index 3aaff04bc..7cd0b9d1f 100644
--- a/DashAI/back/job/model_job.py
+++ b/DashAI/back/job/model_job.py
@@ -183,6 +183,11 @@ def run(
prepared_dataset, model_session.output_columns[0]
)
+ # Get temporal metadata for forecasting tasks
+ temporal_metadata = None
+ if model_session.task_name == "ForecastingTask":
+ temporal_metadata = task.get_temporal_metadata()
+
splits = json.loads(model_session.splits)
prepared_dataset, splits = prepare_for_model_session(
dataset=prepared_dataset,
@@ -275,9 +280,20 @@ def run(
# Hyperparameter Tunning
plot_paths = []
if not run_optimizable_parameters:
- model.train(
- x["train"], y["train"], x["validation"], y["validation"]
- )
+ # Forecasting models use fit() with temporal_metadata
+ if model_session.task_name == "ForecastingTask":
+ model.fit(
+ x["train"],
+ y["train"],
+ temporal_metadata=temporal_metadata,
+ )
+ else:
+ model.train(
+ x["train"],
+ y["train"],
+ x["validation"],
+ y["validation"],
+ )
else:
optimizer.optimize(
model,
diff --git a/DashAI/back/job/predict_job.py b/DashAI/back/job/predict_job.py
index 54ca2ac31..2a41511dd 100644
--- a/DashAI/back/job/predict_job.py
+++ b/DashAI/back/job/predict_job.py
@@ -1,7 +1,11 @@
import logging
+import math
+import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+import numpy as np
+import pandas as pd
from fastapi import status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import HTTPException
@@ -18,6 +22,35 @@
from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset
+
+def sanitize_for_json(value):
+ """Convert NaN/Inf float values to None for JSON serialization.
+
+ Parameters
+ ----------
+ value : Any
+ Value to sanitize (can be list, dict, float, etc.)
+
+ Returns
+ -------
+ Any
+ Sanitized value with NaN/Inf replaced by None
+ """
+ if isinstance(value, dict):
+ return {k: sanitize_for_json(v) for k, v in value.items()}
+ elif isinstance(value, list):
+ return [sanitize_for_json(item) for item in value]
+ elif isinstance(value, float):
+ if math.isnan(value) or math.isinf(value):
+ return None
+ return value
+ elif isinstance(value, np.floating):
+ if np.isnan(value) or np.isinf(value):
+ return None
+ return float(value)
+ return value
+
+
logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)
@@ -255,14 +288,185 @@ def get_job_name(self) -> str:
return f"Prediction (Prediction:{prediction_id}, Dataset:{dataset_id})"
+ def _validate_forecasting_dataset(
+ self,
+ dataset: "DashAIDataset",
+ model_session,
+ trained_model: Any,
+ train_dataset: "DashAIDataset" = None,
+ ) -> str:
+ """Validate dataset for forecasting prediction.
+
+ Parameters
+ ----------
+ dataset : DashAIDataset
+ The prediction dataset to validate.
+ model_session : ModelSession
+ The model session associated with the prediction.
+ trained_model : Any
+ The loaded trained model instance.
+ train_dataset : DashAIDataset, optional
+ The training dataset (used for backcasting validation).
+
+ Returns
+ -------
+ str
+ The name of the detected timestamp column
+
+ Raises
+ ------
+ HTTPException
+ If dataset is invalid for forecasting
+ """
+ pred_df = dataset.to_pandas()
+
+ # Auto-detect timestamp column (try 'ds' first for compatibility)
+ timestamp_col = None
+ if "ds" in pred_df.columns:
+ timestamp_col = "ds"
+ else:
+ for col in pred_df.columns:
+ try:
+ pd.to_datetime(pred_df[col])
+ timestamp_col = col
+ log.info(f"Auto-detected timestamp column: '{col}'")
+ break
+ except Exception:
+ continue
+
+ if timestamp_col is None:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail="Forecasting prediction requires a timestamp column "
+ f"(datetime). Available columns: {list(pred_df.columns)}",
+ )
+
+ # Parse and validate timestamps
+ try:
+ ds_series = pd.to_datetime(pred_df[timestamp_col])
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail=(f"Cannot parse '{timestamp_col}' column as datetime: {str(e)}"),
+ ) from e
+
+ # Check for duplicates
+ if ds_series.duplicated().any():
+ duplicates = ds_series[ds_series.duplicated()].unique()[:5].tolist()
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail=(
+ f"Duplicate timestamps found in '{timestamp_col}' column: "
+ f"{duplicates}"
+ ),
+ )
+
+ # Check monotonicity (strictly increasing)
+ if not ds_series.is_monotonic_increasing:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail=(
+ f"Timestamps in '{timestamp_col}' column must be strictly "
+ "increasing (sorted)."
+ ),
+ )
+
+ # Get training metadata from model
+ train_frequency = getattr(trained_model, "frequency", None)
+ train_last_ds = getattr(trained_model, "last_ds", None)
+ exog_cols = getattr(trained_model, "exog_cols", [])
+
+ log.info(
+ f"Training metadata - frequency: {train_frequency}, "
+ f"last_ds: {train_last_ds}, exog_cols: {exog_cols}"
+ )
+
+ # Validate frequency consistency (if available)
+ if train_frequency and len(ds_series) >= 2:
+ try:
+ inferred_freq = pd.infer_freq(ds_series)
+ if inferred_freq and inferred_freq != train_frequency:
+ log.warning(
+ f"Frequency mismatch: training={train_frequency}, "
+ f"prediction={inferred_freq}"
+ )
+ except Exception:
+ log.warning("Could not infer frequency from prediction dataset")
+
+ # Check for backcasting (dates before training start)
+ if train_last_ds and train_dataset is not None:
+ try:
+ train_df = train_dataset.to_pandas()
+
+ # Auto-detect timestamp in training data
+ train_timestamp_col = None
+ if "ds" in train_df.columns:
+ train_timestamp_col = "ds"
+ else:
+ for col in train_df.columns:
+ try:
+ pd.to_datetime(train_df[col])
+ train_timestamp_col = col
+ break
+ except Exception:
+ continue
+
+ if train_timestamp_col:
+ train_ds_series = pd.to_datetime(train_df[train_timestamp_col])
+ train_start = train_ds_series.min()
+
+ min_pred_ds = ds_series.min()
+ if min_pred_ds < train_start:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail=(
+ f"Requested timestamps precede the training "
+ f"window start (train_start = {train_start}). "
+ f"Retrain the model including those dates or "
+ f"submit only in-sample/future dates."
+ ),
+ )
+ except HTTPException:
+ raise
+ except Exception as e:
+ log.warning(f"Could not validate backcasting: {e}")
+
+ # Validate exogenous regressors
+ if exog_cols:
+ missing_exog = [col for col in exog_cols if col not in pred_df.columns]
+ if missing_exog:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail=(
+ f"Missing required exogenous columns for prediction: "
+ f"{missing_exog}. The model was trained with these "
+ f"regressors and requires values for all prediction "
+ f"timestamps."
+ ),
+ )
+
+ # Check for NaN values in exogenous columns
+ for col in exog_cols:
+ if pred_df[col].isna().any():
+ nan_count = pred_df[col].isna().sum()
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail=(
+ f"Exogenous column '{col}' contains {nan_count} "
+ f"missing values. All exogenous regressors must have "
+ f"values for every timestamp."
+ ),
+ )
+
+ log.info(f"Forecasting validation passed for {len(ds_series)} timestamps")
+ return timestamp_col
+
@inject
def run(
self,
- ) -> List[Any]:
- import uuid
- from pathlib import Path
-
+ ) -> None:
from DashAI.back.dataloaders.classes.dashai_dataset import (
+ get_arrow_table,
load_dataset,
save_dataset,
to_dashai_dataset,
@@ -273,7 +477,8 @@ def run(
config = di["config"]
prediction_id: int = self.kwargs["prediction_id"]
- manual_input_data: List[dict] = self.kwargs.get("manual_input_data", [])
+ manual_input_data: list = self.kwargs.get("manual_input_data", [])
+ forecast_periods = self.kwargs.get("forecast_periods")
with session_factory() as db:
try:
@@ -292,12 +497,17 @@ def run(
dataset_id = prediction.dataset_id
- # Validate input data
- if not manual_input_data and not dataset_id:
+ # Validate input data (forecast_periods also valid for forecasting)
+ if (
+ not manual_input_data
+ and not dataset_id
+ and forecast_periods is None
+ ):
prediction.set_status_as_error()
db.commit()
raise JobError(
- "Either dataset_id or manual_input_data must be provided."
+ "Either dataset_id, manual_input_data, or "
+ "forecast_periods must be provided."
)
# Retrieve Model Session
@@ -370,7 +580,7 @@ def run(
) from e
try:
- trained_model: BaseModel = model.load(prediction.run.run_path)
+ trained_model: BaseModel = model().load(prediction.run.run_path)
except Exception as e:
prediction.set_status_as_error()
db.commit()
@@ -393,70 +603,318 @@ def run(
f"{dataset_trained.file_path}/dataset/"
) from e
- try:
- # Load or create prediction dataset
- if dataset_id:
- loaded_dataset: "DashAIDataset" = load_dataset(
- str(Path(f"{dataset.file_path}/dataset/"))
+ # Determine if this is a forecasting task
+ is_forecasting = model_session.task_name == "ForecastingTask"
+ timestamp_col = "ds" # default; overwritten inside forecasting block
+
+ if is_forecasting:
+ # ============ FORECASTING PREDICTION ============
+ try:
+ if forecast_periods is not None:
+ # --- Auto-generate future timestamps ---
+ log.info(
+ f"Auto-generating {forecast_periods} future timestamps"
+ )
+
+ timestamp_col = "ds"
+ frequency = getattr(trained_model, "frequency", "D")
+ if frequency is None:
+ frequency = "D"
+
+ # Get last training date from the FULL dataset.
+ # trained_model.last_ds only stores the end of the
+ # training split, not the end of the full dataset.
+ last_ds = None
+ try:
+ train_df = get_arrow_table(train_dataset).to_pandas()
+ if "ds" in train_df.columns:
+ last_ds = pd.to_datetime(train_df["ds"]).max()
+ timestamp_col = "ds"
+ else:
+ for col in train_df.columns:
+ try:
+ ds_series = pd.to_datetime(train_df[col])
+ last_ds = ds_series.max()
+ timestamp_col = col
+ break
+ except Exception:
+ continue
+ except Exception as e:
+ log.warning(f"Could not read training dataset: {e}")
+
+ # Fall back to model attribute if dataset read fails
+ if last_ds is None:
+ last_ds = getattr(trained_model, "last_ds", None)
+ if last_ds is None:
+ last_ds = getattr(trained_model, "last_timestamp", None)
+
+ if last_ds is None:
+ raise JobError(
+ "Cannot auto-generate timestamps: Unable to "
+ "determine the last training date. "
+ "Please use a dataset instead."
+ )
+
+ last_training_date = pd.to_datetime(last_ds)
+
+ # Check exogenous regressors
+ exog_cols = getattr(trained_model, "exog_cols", [])
+ if exog_cols:
+ raise HTTPException(
+ status_code=(status.HTTP_422_UNPROCESSABLE_ENTITY),
+ detail=(
+ f"Cannot auto-generate predictions for "
+ f"models with exogenous variables "
+ f"({exog_cols}). Please upload a dataset "
+ f"with timestamps and exogenous values."
+ ),
+ )
+
+ # Generate future timestamps using DateOffset
+ freq_offset_map = {
+ "D": pd.DateOffset(days=1),
+ "H": pd.DateOffset(hours=1),
+ "W": pd.DateOffset(weeks=1),
+ "M": pd.DateOffset(months=1),
+ "MS": pd.DateOffset(months=1),
+ "ME": pd.DateOffset(months=1),
+ "Y": pd.DateOffset(years=1),
+ "YS": pd.DateOffset(years=1),
+ "YE": pd.DateOffset(years=1),
+ "A": pd.DateOffset(years=1),
+ "AS": pd.DateOffset(years=1),
+ "Q": pd.DateOffset(months=3),
+ "QS": pd.DateOffset(months=3),
+ "QE": pd.DateOffset(months=3),
+ }
+
+ first_offset = freq_offset_map.get(frequency)
+ if first_offset is None:
+ try:
+ first_offset = pd.Timedelta(1, unit=frequency[0])
+ except ValueError:
+ log.warning(
+ "Unknown frequency '%s', defaulting to 1 day",
+ frequency,
+ )
+ first_offset = pd.DateOffset(days=1)
+
+ start_date = last_training_date + first_offset
+
+ freq_alias_map = {
+ "M": "MS",
+ "Y": "YS",
+ "A": "YS",
+ "Q": "QS",
+ "ME": "MS",
+ "YE": "YS",
+ "QE": "QS",
+ }
+ safe_freq = freq_alias_map.get(frequency, frequency)
+
+ future_dates = pd.date_range(
+ start=start_date,
+ periods=forecast_periods,
+ freq=safe_freq,
+ )
+ future_df = pd.DataFrame({timestamp_col: future_dates})
+
+ log.info(
+ f"Generated timestamps from {future_dates[0]} "
+ f"to {future_dates[-1]}"
+ )
+
+ else:
+ # --- Use uploaded dataset for forecasting ---
+ if dataset_id:
+ loaded_dataset = load_dataset(
+ str(Path(f"{dataset.file_path}/dataset/"))
+ )
+ elif manual_input_data:
+ dataset_trained_path = str(
+ Path(f"{dataset_trained.file_path}/dataset/")
+ )
+ loaded_dataset = task.process_manual_input(
+ manual_input_data, dataset_trained_path
+ )
+ else:
+ raise JobError(
+ "Either dataset_id, manual_input_data, or "
+ "forecast_periods must be provided for "
+ "forecasting."
+ )
+
+ # Validate forecasting dataset
+ timestamp_col = self._validate_forecasting_dataset(
+ loaded_dataset,
+ model_session,
+ trained_model,
+ train_dataset,
+ )
+
+ pred_df = loaded_dataset.to_pandas()
+ exog_cols = getattr(trained_model, "exog_cols", [])
+ future_cols = [timestamp_col] + exog_cols
+ available_cols = [
+ col for col in future_cols if col in pred_df.columns
+ ]
+
+ if timestamp_col not in available_cols:
+ raise JobError(
+ f"Forecasting prediction requires "
+ f"'{timestamp_col}' column in dataset"
+ )
+
+ future_df = pred_df[available_cols].copy()
+ future_df[timestamp_col] = pd.to_datetime(
+ future_df[timestamp_col]
+ )
+
+ # Call model.predict with the future_df
+ log.info(f"Predicting on {len(future_df)} timestamps")
+ predictions = trained_model.predict(future_df)
+
+ # Handle different prediction formats
+ if hasattr(predictions, "yhat"):
+ y_pred = predictions["yhat"].to_numpy()
+ elif isinstance(predictions, np.ndarray):
+ y_pred = predictions
+ else:
+ y_pred = np.array(predictions)
+
+ # Build result dataset: timestamp + prediction
+ output_col = (
+ model_session.output_columns[0]
+ if model_session.output_columns
+ else "prediction"
)
- else:
- dataset_trained_path = str(
- Path(f"{dataset_trained.file_path}/dataset/")
+ result_df = future_df[[timestamp_col]].copy()
+ result_df[output_col] = y_pred
+
+ # Add confidence intervals if available
+ if hasattr(predictions, "columns"):
+ if "yhat_lower" in predictions.columns:
+ result_df["yhat_lower"] = predictions[
+ "yhat_lower"
+ ].to_numpy()
+ if "yhat_upper" in predictions.columns:
+ result_df["yhat_upper"] = predictions[
+ "yhat_upper"
+ ].to_numpy()
+
+ # Convert timestamp column to string so DashAI stores it
+ # consistently (Arrow datetime types lack DashAI metadata).
+ if pd.api.types.is_datetime64_any_dtype(result_df[timestamp_col]):
+ result_df[timestamp_col] = result_df[timestamp_col].dt.strftime(
+ "%Y-%m-%d"
+ )
+
+ dataset_with_prediction = to_dashai_dataset(result_df)
+
+ except HTTPException:
+ prediction.set_status_as_error()
+ db.commit()
+ raise
+ except ValueError as ve:
+ prediction.set_status_as_error()
+ db.commit()
+ log.error(f"Validation Error: {ve}")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid input data: {str(ve)}",
+ ) from ve
+ except Exception as e:
+ prediction.set_status_as_error()
+ db.commit()
+ log.error(e)
+ raise JobError(
+ "Forecasting prediction failed",
+ ) from e
+
+ else:
+ # ============ STANDARD PREDICTION ============
+ try:
+ # Load or create prediction dataset
+ if dataset_id:
+ loaded_dataset: DashAIDataset = load_dataset(
+ str(Path(f"{dataset.file_path}/dataset/"))
+ )
+ else:
+ dataset_trained_path = str(
+ Path(f"{dataset_trained.file_path}/dataset/")
+ )
+ loaded_dataset = task.process_manual_input(
+ manual_input_data, dataset_trained_path
+ )
+
+ # Select input columns and make prediction
+ prepared_dataset = loaded_dataset.select_columns(
+ model_session.input_columns
)
- loaded_dataset = task.process_manual_input(
- manual_input_data, dataset_trained_path
+ y_pred_proba = np.array(trained_model.predict(prepared_dataset))
+
+ # Process predictions (convert to labels for classification)
+ y_pred = task.process_predictions(
+ train_dataset,
+ y_pred_proba,
+ model_session.output_columns[0],
)
- prepared_dataset, y_pred = _run_prediction_pipeline(
- task=task,
- trained_model=trained_model,
- train_dataset=train_dataset,
- loaded_dataset=loaded_dataset,
- model_session=model_session,
- )
+ # Build dataset with predictions
+ dataset_with_prediction = to_dashai_dataset(
+ prepared_dataset.add_column(
+ model_session.output_columns[0], y_pred
+ )
+ )
- except ValueError as ve:
- prediction.set_status_as_error()
- db.commit()
- log.error(f"Validation Error: {ve}")
- raise HTTPException(
- status_code=400,
- detail=f"Invalid input data: {str(ve)}",
- ) from ve
- except TypeError as te:
- log.error(f"Type Error: {te}")
- raise HTTPException(
- status_code=400,
- detail=f"Type validation failed: {str(te)}",
- ) from te
- except Exception as e:
- prediction.set_status_as_error()
- db.commit()
- log.error(e)
- raise JobError(
- "Model prediction failed",
- ) from e
+ except ValueError as ve:
+ prediction.set_status_as_error()
+ db.commit()
+ log.error(f"Validation Error: {ve}")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid input data: {str(ve)}",
+ ) from ve
+ except TypeError as te:
+ log.error(f"Type Error: {te}")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Type validation failed: {str(te)}",
+ ) from te
+ except Exception as e:
+ prediction.set_status_as_error()
+ db.commit()
+ log.error(e)
+ raise JobError(
+ "Model prediction failed",
+ ) from e
- # Save Predictions to Arrow file
+ # ============ SAVE PREDICTIONS TO ARROW ============
try:
- # Create unique folder for predictions
path = str(Path(f"{config['DATASETS_PATH']}/predictions/"))
folder_name = str(uuid.uuid4())
full_path = Path(path) / folder_name
full_path.mkdir(parents=True, exist_ok=True)
- # Add predictions to loaded dataset
- dataset_with_prediction = to_dashai_dataset(
- prepared_dataset.add_column(model_session.output_columns[0], y_pred)
- )
-
- # Filter schema from trained dataset
- trained_schema = train_dataset.types
- filtered_schema = {
- key: value.to_string()
- for key, value in trained_schema.items()
- if key in model_session.input_columns + model_session.output_columns
- }
+ # Build schema for the saved dataset
+ if is_forecasting:
+ # Build a proper schema so the Arrow file has DashAI type
+ # metadata. Empty schema {} causes transform_dataset_with_schema
+ # to produce an empty table with no metadata → 404 on read.
+ filtered_schema = {
+ col: {"dtype": "string"}
+ if col == timestamp_col
+ else {"dtype": "float64"}
+ for col in dataset_with_prediction.column_names
+ }
+ else:
+ trained_schema = train_dataset.types
+ filtered_schema = {
+ key: value.to_string()
+ for key, value in trained_schema.items()
+ if key
+ in model_session.input_columns + model_session.output_columns
+ }
# Store num of rows, columns, and column names
dataset_with_prediction.compute_base_metadata()
@@ -472,10 +930,11 @@ def run(
prediction.results_path = str(full_path)
prediction.set_status_as_finished()
db.commit()
+
except Exception as e:
prediction.set_status_as_error()
db.commit()
log.exception(e)
raise JobError(
- "Can not save prediction to json file",
+ "Cannot save prediction results",
) from e
diff --git a/DashAI/back/metrics/base_metric.py b/DashAI/back/metrics/base_metric.py
index f43290a4f..a99836de5 100644
--- a/DashAI/back/metrics/base_metric.py
+++ b/DashAI/back/metrics/base_metric.py
@@ -9,7 +9,18 @@
class BaseMetric:
- """Abstract class of all metrics."""
+ """Abstract class of all metrics.
+
+ Attributes
+ ----------
+ HIGHER_IS_BETTER : bool
+ Indicates the optimization direction for this metric.
+ - True: Higher values are better (e.g., Accuracy, F1)
+ - False: Lower values are better (e.g., MAE, RMSE, SMAPE)
+
+ This attribute is used by hyperparameter optimizers to determine
+ whether to maximize or minimize the metric during optimization.
+ """
TYPE: Final[str] = "Metric"
MAXIMIZE: Final[bool] = False
diff --git a/DashAI/back/metrics/classification/accuracy.py b/DashAI/back/metrics/classification/accuracy.py
index 05dcf78eb..b669a3862 100644
--- a/DashAI/back/metrics/classification/accuracy.py
+++ b/DashAI/back/metrics/classification/accuracy.py
@@ -12,7 +12,12 @@
class Accuracy(ClassificationMetric):
- """Accuracy metric to classification tasks."""
+ """Accuracy metric to classification tasks.
+
+ Higher accuracy values are better (range: 0.0 to 1.0).
+ """
+
+ HIGHER_IS_BETTER = True
DESCRIPTION: str = (
"Proportion of correct predictions over all samples, "
diff --git a/DashAI/back/metrics/classification/f1.py b/DashAI/back/metrics/classification/f1.py
index 9cf4102f2..6935ca601 100644
--- a/DashAI/back/metrics/classification/f1.py
+++ b/DashAI/back/metrics/classification/f1.py
@@ -12,7 +12,12 @@
class F1(ClassificationMetric):
- """F1 score to classification tasks."""
+ """F1 score to classification tasks.
+
+ Higher F1 values are better (range: 0.0 to 1.0).
+ """
+
+ HIGHER_IS_BETTER = True
DESCRIPTION: str = (
"Harmonic mean of precision and recall, "
diff --git a/DashAI/back/metrics/classification/precision.py b/DashAI/back/metrics/classification/precision.py
index e7f5be067..2d70f7fd4 100644
--- a/DashAI/back/metrics/classification/precision.py
+++ b/DashAI/back/metrics/classification/precision.py
@@ -12,7 +12,12 @@
class Precision(ClassificationMetric):
- """Precision metric to classification tasks."""
+ """Precision metric to classification tasks.
+
+ Higher precision values are better (range: 0.0 to 1.0).
+ """
+
+ HIGHER_IS_BETTER = True
DESCRIPTION: str = (
"Fraction of predicted positives that are correct, "
diff --git a/DashAI/back/metrics/classification/recall.py b/DashAI/back/metrics/classification/recall.py
index a472909ba..0081eb267 100644
--- a/DashAI/back/metrics/classification/recall.py
+++ b/DashAI/back/metrics/classification/recall.py
@@ -12,7 +12,12 @@
class Recall(ClassificationMetric):
- """Recall metric to classification tasks."""
+ """Recall metric to classification tasks.
+
+ Higher recall values are better (range: 0.0 to 1.0).
+ """
+
+ HIGHER_IS_BETTER = True
DESCRIPTION: str = (
"Fraction of actual positives correctly identified, "
diff --git a/DashAI/back/metrics/forecasting/__init__.py b/DashAI/back/metrics/forecasting/__init__.py
new file mode 100644
index 000000000..d6dbff035
--- /dev/null
+++ b/DashAI/back/metrics/forecasting/__init__.py
@@ -0,0 +1,9 @@
+"""Forecasting metrics for time series evaluation."""
+
+from .mape import MAPE
+from .smape import SMAPE
+
+__all__ = [
+ "MAPE",
+ "SMAPE",
+]
diff --git a/DashAI/back/metrics/forecasting/mape.py b/DashAI/back/metrics/forecasting/mape.py
new file mode 100644
index 000000000..062af18ee
--- /dev/null
+++ b/DashAI/back/metrics/forecasting/mape.py
@@ -0,0 +1,55 @@
+"""Mean Absolute Percentage Error (MAPE) metric for forecasting."""
+
+import numpy as np
+
+from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset
+from DashAI.back.metrics.regression_metric import RegressionMetric, prepare_to_metric
+
+
+class MAPE(RegressionMetric):
+ """Mean Absolute Percentage Error metric for forecasting tasks.
+
+ MAPE measures the average absolute percentage difference between
+ predicted and actual values. It's scale-independent and easy to interpret.
+
+ MAPE = (1/n) * Σ|((y_true - y_pred) / y_true)| * 100
+
+ Note: MAPE can be problematic when true values are close to zero.
+ """
+
+ COMPATIBLE_COMPONENTS = [
+ "RegressionTask",
+ "ForecastingTask",
+ ]
+
+ @staticmethod
+ def score(true_values: DashAIDataset, predicted_values: np.ndarray) -> float:
+ """Calculate MAPE between true values and predicted values.
+
+ Parameters
+ ----------
+ true_values : DashAIDataset
+ A DashAI dataset with true values.
+ predicted_values : np.ndarray
+ Array with the predicted values for each instance.
+
+ Returns
+ -------
+ float
+ MAPE score as percentage (0-100, lower is better)
+ """
+ true_values, pred_values = prepare_to_metric(true_values, predicted_values)
+
+ # Handle zero values in denominator
+ mask = np.abs(true_values) > 1e-8 # Avoid division by very small numbers
+
+ if not np.any(mask):
+ # All true values are essentially zero
+ return 0.0 if np.allclose(pred_values, 0) else 100.0
+
+ # Calculate MAPE only for non-zero true values
+ mape_values = np.abs(
+ (true_values[mask] - pred_values[mask]) / true_values[mask]
+ )
+
+ return float(np.mean(mape_values) * 100)
diff --git a/DashAI/back/metrics/forecasting/smape.py b/DashAI/back/metrics/forecasting/smape.py
new file mode 100644
index 000000000..414b76a10
--- /dev/null
+++ b/DashAI/back/metrics/forecasting/smape.py
@@ -0,0 +1,56 @@
+"""Symmetric Mean Absolute Percentage Error (sMAPE) metric for forecasting."""
+
+import numpy as np
+
+from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset
+from DashAI.back.metrics.regression_metric import RegressionMetric, prepare_to_metric
+
+
+class SMAPE(RegressionMetric):
+ """Symmetric Mean Absolute Percentage Error metric for forecasting tasks.
+
+ sMAPE is a more stable version of MAPE that handles zero values better
+ by using the average of actual and predicted values in the denominator.
+
+ sMAPE = (2/n) * Σ|(y_true - y_pred)| / (|y_true| + |y_pred|) * 100
+
+ sMAPE is bounded between 0% and 200%, making it more stable than MAPE.
+ """
+
+ COMPATIBLE_COMPONENTS = [
+ "RegressionTask",
+ "ForecastingTask",
+ ]
+
+ @staticmethod
+ def score(true_values: DashAIDataset, predicted_values: np.ndarray) -> float:
+ """Calculate sMAPE between true values and predicted values.
+
+ Parameters
+ ----------
+ true_values : DashAIDataset
+ A DashAI dataset with true values.
+ predicted_values : np.ndarray
+ Array with the predicted values for each instance.
+
+ Returns
+ -------
+ float
+ sMAPE score as percentage (0-200, lower is better)
+ """
+ true_values, pred_values = prepare_to_metric(true_values, predicted_values)
+
+ # Calculate symmetric denominator
+ denominator = np.abs(true_values) + np.abs(pred_values)
+
+ # Handle zero denominator (both actual and predicted are zero)
+ mask = denominator > 1e-8
+
+ if not np.any(mask):
+ # All values are essentially zero
+ return 0.0
+
+ # Calculate sMAPE
+ smape_values = np.abs(true_values[mask] - pred_values[mask]) / denominator[mask]
+
+ return float(np.mean(smape_values) * 200)
diff --git a/DashAI/back/metrics/regression_metric.py b/DashAI/back/metrics/regression_metric.py
index 08353d6a5..0eb872125 100644
--- a/DashAI/back/metrics/regression_metric.py
+++ b/DashAI/back/metrics/regression_metric.py
@@ -39,4 +39,14 @@ def prepare_to_metric(
true_values = np.array(y.to_pandas().to_numpy().flatten())
pred_values = np.array(y_pred).flatten()
+ # Filter out NaN values (common in forecasting with lag features)
+ valid_mask = ~(np.isnan(true_values) | np.isnan(pred_values))
+ n_nan = np.sum(~valid_mask)
+ if n_nan > 0:
+ true_values = true_values[valid_mask]
+ pred_values = pred_values[valid_mask]
+
+ if len(true_values) == 0:
+ raise ValueError("All values are NaN after filtering. Cannot compute metrics.")
+
return true_values, pred_values
diff --git a/DashAI/back/models/forecasting/__init__.py b/DashAI/back/models/forecasting/__init__.py
new file mode 100644
index 000000000..a1e826587
--- /dev/null
+++ b/DashAI/back/models/forecasting/__init__.py
@@ -0,0 +1,15 @@
+"""Forecasting models for time series prediction."""
+
+from .base_forecasting_model import ForecastingModel
+from .prophet_model import ProphetModel
+from .sklearn_multistep_forecaster import SklearnMultiStepForecaster
+from .statsmodels_arima_model import StatsmodelsARIMAModel
+from .statsmodels_sarimax_model import StatsmodelsSARIMAXModel
+
+__all__ = [
+ "ForecastingModel",
+ "ProphetModel",
+ "SklearnMultiStepForecaster",
+ "StatsmodelsARIMAModel",
+ "StatsmodelsSARIMAXModel",
+]
diff --git a/DashAI/back/models/forecasting/base_forecasting_model.py b/DashAI/back/models/forecasting/base_forecasting_model.py
new file mode 100644
index 000000000..bfa5d5af9
--- /dev/null
+++ b/DashAI/back/models/forecasting/base_forecasting_model.py
@@ -0,0 +1,329 @@
+"""Forecasting Model abstract class.
+
+This module defines the common interface for all forecasting models in DashAI.
+It ensures model-agnostic handling of time series data and exogenous variables.
+"""
+
+import warnings
+from abc import abstractmethod
+from typing import List, Optional
+
+import pandas as pd
+
+from DashAI.back.models.base_model import BaseModel
+
+
+class ForecastingModel(BaseModel):
+ """Abstract class for all forecasting models.
+
+ This class defines the common interface that all forecasting models must implement.
+ It handles:
+ - Exogenous variables (external regressors) in a model-agnostic way
+ - Timestamp and target column management
+ - Prediction interface for both in-sample and out-of-sample forecasting
+
+ Key Attributes
+ --------------
+ exog_cols : List[str]
+ List of exogenous variable column names used during training.
+ These are stored in their ORIGINAL names from the dataset,
+ not in any model-specific format.
+
+ timestamp_col : Optional[str]
+ Name of the timestamp/datetime column in the original dataset.
+
+ target_col : Optional[str]
+ Name of the target variable column in the original dataset.
+
+ Philosophy
+ ----------
+ This class maintains column names in their ORIGINAL format from the user's
+ dataset. Each specific model implementation (Prophet, ARIMA, etc.) is
+ responsible for:
+ 1. Internally converting column names to its required format
+ (e.g., Prophet needs 'ds'/'y')
+ 2. Converting predictions back to match original column names
+ 3. Handling model-specific requirements transparently
+
+ This ensures the system is agnostic to each model's internal conventions.
+
+ Note
+ ----
+ This class inherits TYPE = "Model" from BaseModel. The name "ForecastingModel"
+ (without "Base" prefix) avoids conflicts with the component registry system
+ which looks for classes with "Base" in their name.
+ """
+
+ _compatible_tasks = ["ForecastingTask"]
+
+ def __init__(self, **kwargs):
+ """Initialize forecasting model.
+
+ Sets up common attributes that all forecasting models should maintain.
+
+ Parameters
+ ----------
+ **kwargs
+ Additional arguments passed to BaseModel.__init__
+ """
+ super().__init__(**kwargs)
+
+ # Store exogenous variable names in ORIGINAL format
+ self.exog_cols: List[str] = []
+
+ # Store column names for reference
+ self.timestamp_col: Optional[str] = None
+ self.target_col: Optional[str] = None
+
+ @abstractmethod
+ def fit(self, x: pd.DataFrame, y: pd.DataFrame, **kwargs) -> "ForecastingModel":
+ """Train the forecasting model.
+
+ Parameters
+ ----------
+ x : pd.DataFrame
+ Training features including:
+ - Timestamp column (datetime)
+ - Exogenous variables (optional)
+ May also include the target column (will be used from there if present)
+
+ y : pd.DataFrame
+ Target variable values.
+ Single column with the variable to forecast.
+
+ **kwargs
+ Additional model-specific parameters.
+
+ Returns
+ -------
+ self : ForecastingModel
+ Returns self for method chaining.
+
+ Notes
+ -----
+ Implementations should:
+ 1. Auto-detect timestamp column (try pd.to_datetime on columns)
+ 2. Filter exogenous variables (numeric only, exclude timestamp/target)
+ 3. Store original column names in self.exog_cols, self.timestamp_col,
+ self.target_col
+ 4. Internally convert to model-specific format if needed
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def predict(
+ self,
+ x_pred: Optional[pd.DataFrame] = None,
+ periods: Optional[int] = None,
+ exog_future: Optional[pd.DataFrame] = None,
+ **kwargs,
+ ) -> pd.DataFrame:
+ """Generate forecasts.
+
+ Supports two prediction modes:
+ 1. In-sample: Provide x_pred with timestamps and exogenous values
+ 2. Out-of-sample: Provide periods and exog_future for future forecasting
+
+ Parameters
+ ----------
+ x_pred : pd.DataFrame, optional
+ Input data for in-sample predictions containing:
+ - Timestamp column
+ - Exogenous variables (if model uses them)
+
+ periods : int, optional
+ Number of future periods to forecast (out-of-sample mode).
+
+ exog_future : pd.DataFrame, optional
+ Future values of exogenous variables for out-of-sample forecasting.
+ Must contain all columns in self.exog_cols.
+
+ **kwargs
+ Additional model-specific parameters.
+
+ Returns
+ -------
+ pd.DataFrame or np.ndarray
+ Predictions with columns using ORIGINAL names:
+ - Timestamp column (same name as training data)
+ - Target column (same name as training data)
+ - Optionally: prediction intervals, components, etc.
+
+ Notes
+ -----
+ Implementations MUST support both prediction modes:
+ 1. In-sample predictions (x_pred provided): For calculating metrics on
+ train/validation/test splits
+ 2. Out-of-sample predictions (periods provided): For future forecasting
+
+ Implementations should:
+ 1. Auto-detect timestamp column in x_pred (handle both original name and 'ds')
+ 2. Validate exogenous variables are present if model requires them
+ 3. Return predictions with ORIGINAL column names (not model-specific names)
+
+ IMPORTANT: Do NOT raise NotImplementedError for in-sample predictions.
+ Model evaluation (metrics calculation) requires in-sample predictions.
+ """
+ raise NotImplementedError
+
+ def train(
+ self,
+ x_train,
+ y_train,
+ x_validation=None,
+ y_validation=None,
+ **kwargs,
+ ) -> "ForecastingModel":
+ """Compatibility wrapper for the generic DashAI model contract.
+
+ Forecasting jobs train models via ``fit()`` so they can pass
+ ``temporal_metadata`` and other forecasting-specific arguments. This
+ wrapper keeps forecasting models instantiable through ``ModelFactory``,
+ which still expects every model to provide a concrete ``train()`` method.
+ """
+ if x_validation is not None or y_validation is not None:
+ warnings.warn(
+ "ForecastingModel.train() ignores validation datasets. "
+ "Forecasting models should be trained via fit() with the "
+ "appropriate temporal metadata.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ return self.fit(x_train, y_train, **kwargs)
+
+ def get_exogenous_columns(self) -> List[str]:
+ """Get list of exogenous variable names in original format.
+
+ Returns
+ -------
+ List[str]
+ List of exogenous variable column names as they appear in the
+ original dataset (not in model-specific format).
+
+ Examples
+ --------
+ >>> model.fit(x_train, y_train)
+ >>> model.get_exogenous_columns()
+ ['temperature', 'humidity', 'wind_speed']
+ # NOT ['exog_temperature', 'exog_humidity', 'exog_wind_speed']
+ # NOT ['extra_regressor_1', 'extra_regressor_2', 'extra_regressor_3']
+ """
+ return self.exog_cols.copy()
+
+ def has_exogenous_variables(self) -> bool:
+ """Check if model uses exogenous variables.
+
+ Returns
+ -------
+ bool
+ True if model was trained with exogenous variables, False otherwise.
+ """
+ return len(self.exog_cols) > 0
+
+ def get_column_names(self) -> dict:
+ """Get all relevant column names in original format.
+
+ Returns
+ -------
+ dict
+ Dictionary with keys:
+ - 'timestamp': Timestamp column name
+ - 'target': Target column name
+ - 'exogenous': List of exogenous variable names
+ """
+ return {
+ "timestamp": self.timestamp_col,
+ "target": self.target_col,
+ "exogenous": self.exog_cols.copy(),
+ }
+
+ def _get_seasonal_period(self) -> int:
+ """Infer seasonal period from the model's stored frequency.
+
+ Returns the number of observations per seasonal cycle, used to
+ configure STL decomposition in ``get_forecast_components()``.
+
+ Returns
+ -------
+ int
+ Seasonal period (e.g., 7 for daily data → weekly cycle).
+ """
+ if not self.frequency:
+ return 7 # default: weekly cycle for daily data
+
+ freq = self.frequency.upper().strip()
+
+ if freq.startswith(("T", "MIN")):
+ return 60 # minutely → hourly seasonality
+ if freq.startswith("H"):
+ return 24 # hourly → daily seasonality
+ if freq.startswith("D"):
+ return 7 # daily → weekly seasonality
+ if freq.startswith("W"):
+ return 52 # weekly → yearly seasonality
+ if freq.startswith(("M", "ME", "MS")):
+ return 12 # monthly → yearly seasonality
+ if freq.startswith(("Q", "QE", "QS")):
+ return 4 # quarterly → yearly seasonality
+ if freq.startswith(("A", "Y", "AE", "AS", "YS", "YE")):
+ return 1 # yearly → no sub-annual seasonality
+
+ return 7 # fallback
+
+ def _period_to_seasonality_name(self, period: int) -> str:
+ """Map a seasonal period integer to a human-readable component name.
+
+ Parameters
+ ----------
+ period : int
+ Number of observations per seasonal cycle.
+
+ Returns
+ -------
+ str
+ Name for the seasonal component column (e.g., 'weekly', 'yearly').
+ """
+ mapping = {
+ 60: "hourly",
+ 24: "daily",
+ 7: "weekly",
+ 52: "yearly",
+ 12: "yearly",
+ 4: "yearly",
+ 365: "yearly",
+ }
+ return mapping.get(period, "seasonal")
+
+ def _validate_predict_implementation(self) -> None:
+ """Validate that subclass implements predict() correctly.
+
+ This method can be called in tests to ensure implementations support
+ both in-sample and out-of-sample predictions.
+
+ Raises
+ ------
+ NotImplementedError
+ If predict() raises NotImplementedError for in-sample predictions
+ ValueError
+ If predict() doesn't handle both prediction modes
+
+ Notes
+ -----
+ This is a helper for testing - not called automatically during runtime.
+ Developers should call this in unit tests for their forecasting models.
+
+ Example
+ -------
+ >>> # In test_my_model.py
+ >>> model = MyForecastingModel()
+ >>> model.fit(x_train, y_train)
+ >>> model._validate_predict_implementation() # Ensures correct implementation
+ """
+ warnings.warn(
+ "ForecastingModel.predict() must support both in-sample (x_pred) "
+ "and out-of-sample (periods) prediction modes. "
+ "In-sample predictions are required for metrics calculation.",
+ UserWarning,
+ stacklevel=2,
+ )
diff --git a/DashAI/back/models/forecasting/prophet_model.py b/DashAI/back/models/forecasting/prophet_model.py
new file mode 100644
index 000000000..94b9a7c9b
--- /dev/null
+++ b/DashAI/back/models/forecasting/prophet_model.py
@@ -0,0 +1,1029 @@
+"""Prophet model wrapper for DashAI forecasting.
+
+This model wraps Facebook Prophet for native time series forecasting
+with automatic seasonality detection and holiday effects.
+"""
+
+import os
+import pickle
+from typing import Any, Optional, Union
+
+import numpy as np
+import pandas as pd
+
+from DashAI.back.core.schema_fields import (
+ BaseSchema,
+ enum_field,
+ float_field,
+ int_field,
+ schema_field,
+)
+from DashAI.back.dataloaders.classes.dashai_dataset import (
+ DashAIDataset,
+ to_dashai_dataset,
+)
+from DashAI.back.models.forecasting.base_forecasting_model import ForecastingModel
+
+
+def _patch_prophet_regressor_column_matrix():
+ """Patch Prophet for compatibility with newer pandas versions."""
+ from prophet import Prophet
+
+ if getattr(Prophet, "_dashai_pandas_compat_patch", False):
+ return Prophet
+
+ @staticmethod
+ def _dashai_fourier_series(dates, period, series_order):
+ """Prophet expects nanosecond timestamps, but newer pandas can
+ hand back ``datetime64[us]`` arrays. Force nanosecond precision so
+ weekly/yearly seasonal features keep the correct period.
+ """
+ if not (series_order >= 1):
+ raise ValueError("series_order must be >= 1")
+
+ ns_dates = (
+ pd.to_datetime(dates).to_numpy(dtype="datetime64[ns]").astype(np.int64)
+ )
+ t = ns_dates // 1_000_000_000 / (3600 * 24.0)
+
+ x_T = t * np.pi * 2
+ fourier_components = np.empty((dates.shape[0], 2 * series_order))
+ for i in range(series_order):
+ c = x_T * (i + 1) / period
+ fourier_components[:, 2 * i] = np.sin(c)
+ fourier_components[:, (2 * i) + 1] = np.cos(c)
+ return fourier_components
+
+ def _dashai_regressor_column_matrix(self, seasonal_features, modes):
+ components = pd.DataFrame(
+ {
+ "col": np.arange(seasonal_features.shape[1]),
+ "component": [x.split("_delim_")[0] for x in seasonal_features.columns],
+ }
+ )
+
+ if self.train_holiday_names is not None:
+ components = self.add_group_component(
+ components, "holidays", self.train_holiday_names.unique()
+ )
+
+ for mode in ["additive", "multiplicative"]:
+ components = self.add_group_component(
+ components, mode + "_terms", modes[mode]
+ )
+ regressors_by_mode = [
+ r for r, props in self.extra_regressors.items() if props["mode"] == mode
+ ]
+ components = self.add_group_component(
+ components,
+ "extra_regressors_" + mode,
+ regressors_by_mode,
+ )
+ modes[mode].append(mode + "_terms")
+ modes[mode].append("extra_regressors_" + mode)
+
+ modes[self.holidays_mode].append("holidays")
+
+ clean_components = components.reset_index(drop=True)
+ component_cols = pd.crosstab(
+ pd.Series(clean_components["col"].to_numpy(), name="col"),
+ pd.Series(clean_components["component"].to_numpy(), name="component"),
+ ).sort_index(level="col")
+
+ for name in ["additive_terms", "multiplicative_terms"]:
+ if name not in component_cols:
+ component_cols[name] = 0
+
+ component_cols = component_cols.drop("zeros", axis=1, errors="ignore")
+
+ if (
+ max(
+ component_cols["additive_terms"]
+ + component_cols["multiplicative_terms"]
+ )
+ > 1
+ ):
+ raise Exception("A bug occurred in seasonal components.")
+
+ if self.train_component_cols is not None:
+ component_cols = component_cols[self.train_component_cols.columns]
+ if not component_cols.equals(self.train_component_cols):
+ raise Exception("A bug occurred in constructing regressors.")
+
+ return component_cols, modes
+
+ Prophet.fourier_series = _dashai_fourier_series
+ Prophet.regressor_column_matrix = _dashai_regressor_column_matrix
+ Prophet._dashai_pandas_compat_patch = True
+ return Prophet
+
+
+class ProphetModelSchema(BaseSchema):
+ """Schema for Prophet model configuration.
+
+ Prophet is a forecasting procedure designed for business time series data.
+ It works best with time series that have strong seasonal effects and several
+ seasons of historical data. Prophet is robust to missing data and shifts in
+ the trend, and typically handles outliers well.
+ """
+
+ seasonality_mode: schema_field(
+ enum_field(enum=["additive", "multiplicative"]),
+ placeholder="additive",
+ description="Type of seasonality. 'additive' assumes seasonal effects are "
+ "added to the trend. 'multiplicative' assumes seasonal effects are "
+ "multiplied by the trend.",
+ ) = "additive" # type: ignore
+
+ yearly_seasonality: schema_field(
+ enum_field(enum=["auto", "true", "false"]),
+ placeholder="auto",
+ description="Yearly seasonality. 'auto' detects automatically, "
+ "'true' forces yearly seasonality, 'false' disables it.",
+ ) = "auto" # type: ignore
+
+ weekly_seasonality: schema_field(
+ enum_field(enum=["auto", "true", "false"]),
+ placeholder="auto",
+ description="Weekly seasonality. 'auto' detects automatically, "
+ "'true' forces weekly seasonality, 'false' disables it.",
+ ) = "auto" # type: ignore
+
+ daily_seasonality: schema_field(
+ enum_field(enum=["auto", "true", "false"]),
+ placeholder="auto",
+ description="Daily seasonality. 'auto' detects automatically, "
+ "'true' forces daily seasonality, 'false' disables it.",
+ ) = "auto" # type: ignore
+
+ growth: schema_field(
+ enum_field(enum=["linear", "logistic", "flat"]),
+ placeholder="linear",
+ description="Growth model. 'linear' for unlimited growth, "
+ "'logistic' for growth that saturates at a carrying capacity "
+ "(requires cap_multiplier), 'flat' for no trend.",
+ ) = "linear" # type: ignore
+
+ cap_multiplier: schema_field(
+ float_field(ge=1.0, le=10.0),
+ placeholder=1.5,
+ description="For logistic growth: multiplier applied to max(y) to set "
+ "the carrying capacity. E.g., 1.5 means cap = 1.5 * max(y). "
+ "Only used when growth='logistic'.",
+ ) = 1.5 # type: ignore
+
+ floor_ratio: schema_field(
+ float_field(ge=0.0, le=1.0),
+ placeholder=0.0,
+ description="For logistic growth: floor as ratio of min(y). "
+ "E.g., 0.5 means floor = 0.5 * min(y). "
+ "Only used when growth='logistic'.",
+ ) = 0.0 # type: ignore
+
+ changepoint_prior_scale: schema_field(
+ float_field(ge=0.001, le=1.0),
+ placeholder=0.05,
+ description="Controls flexibility of automatic changepoint selection. "
+ "Higher values allow more changepoints (more flexible trend). "
+ "Lower values result in fewer changepoints (more conservative trend).",
+ ) = 0.05 # type: ignore
+
+ seasonality_prior_scale: schema_field(
+ float_field(ge=0.01, le=100.0),
+ placeholder=10.0,
+ description="Controls flexibility of seasonality. Higher values allow "
+ "more seasonal variation. Lower values result in smoother seasonality.",
+ ) = 10.0 # type: ignore
+
+ holidays_prior_scale: schema_field(
+ float_field(ge=0.01, le=100.0),
+ placeholder=10.0,
+ description="Controls flexibility of holiday effects. Higher values "
+ "allow larger holiday effects.",
+ ) = 10.0 # type: ignore
+
+ interval_width: schema_field(
+ float_field(ge=0.5, le=0.99),
+ placeholder=0.8,
+ description="Width of prediction intervals. 0.8 means 80% confidence "
+ "intervals. Prophet will generate yhat_lower and yhat_upper bounds.",
+ ) = 0.8 # type: ignore
+
+ uncertainty_samples: schema_field(
+ int_field(ge=100, le=10000),
+ placeholder=1000,
+ description="Number of samples to draw for uncertainty estimation. "
+ "More samples give smoother intervals but slower prediction.",
+ ) = 1000 # type: ignore
+
+
+class ProphetModel(ForecastingModel):
+ """Prophet forecasting model wrapper for DashAI.
+
+ This model implements the ForecastingModel interface, handling all
+ column name conversions internally. It maintains exogenous variables in
+ their original format and converts to Prophet's 'ds'/'y' convention only
+ during internal operations.
+ """
+
+ SCHEMA = ProphetModelSchema
+ COMPATIBLE_COMPONENTS = ["ForecastingTask"]
+ _task_type = "ForecastingTask"
+
+ def __init__(
+ self,
+ seasonality_mode: str = "additive",
+ yearly_seasonality: str = "auto",
+ weekly_seasonality: str = "auto",
+ daily_seasonality: str = "auto",
+ growth: str = "linear",
+ cap_multiplier: float = 1.5,
+ floor_ratio: float = 0.0,
+ changepoint_prior_scale: float = 0.05,
+ seasonality_prior_scale: float = 10.0,
+ holidays_prior_scale: float = 10.0,
+ interval_width: float = 0.8,
+ uncertainty_samples: int = 1000,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs) # Pass kwargs to ForecastingModel
+
+ self.seasonality_mode = seasonality_mode
+ self.yearly_seasonality = self._parse_bool_setting(yearly_seasonality)
+ self.weekly_seasonality = self._parse_bool_setting(weekly_seasonality)
+ self.daily_seasonality = self._parse_bool_setting(daily_seasonality)
+ self.growth = growth
+ self.cap_multiplier = cap_multiplier
+ self.floor_ratio = floor_ratio
+ self.changepoint_prior_scale = changepoint_prior_scale
+ self.seasonality_prior_scale = seasonality_prior_scale
+ self.holidays_prior_scale = holidays_prior_scale
+ self.interval_width = interval_width
+ self.uncertainty_samples = uncertainty_samples
+
+ # Store cap/floor for predictions when using logistic growth
+ self._cap_value: Optional[float] = None
+ self._floor_value: Optional[float] = None
+
+ self.model = None
+ # exog_cols, timestamp_col, target_col are inherited from ForecastingModel
+ self.last_ds: Optional[pd.Timestamp] = None
+ self.frequency: Optional[str] = None
+
+ def _parse_bool_setting(self, setting: str) -> Union[bool, str]:
+ if setting.lower() == "true":
+ return True
+ if setting.lower() == "false":
+ return False
+ return "auto"
+
+ def _validate_forecasting_data(self, x: DashAIDataset, y: DashAIDataset) -> None:
+ """Validate that data is suitable for Prophet.
+
+ Parameters
+ ----------
+ X : DashAIDataset
+ Input features (must contain a timestamp column)
+ y : DashAIDataset
+ Target values (must contain a numeric column)
+
+ Raises
+ ------
+ ValueError
+ If data is not suitable for Prophet
+ """
+ x_cols = set(x.column_names)
+ y_cols = set(y.column_names)
+
+ if len(x_cols) == 0:
+ raise ValueError(
+ "Prophet requires at least one input column (timestamp). "
+ "Received empty dataset."
+ )
+
+ if len(y_cols) != 1:
+ raise ValueError(
+ f"Prophet requires exactly one target column. "
+ f"Received {len(y_cols)} columns: {list(y_cols)}"
+ )
+
+ def fit(
+ self,
+ x_train: DashAIDataset,
+ y: DashAIDataset,
+ temporal_metadata: dict = None,
+ **fit_params,
+ ) -> "ProphetModel":
+ """Train Prophet forecasting model.
+
+ Implements ForecastingModel.fit() interface. Handles all column name
+ conversions internally - stores original names in base class attributes,
+ converts to Prophet's 'ds'/'y' convention only for internal use.
+
+ Parameters
+ ----------
+ x_train : DashAIDataset
+ Input features containing timestamp and optional exogenous variables
+ y : DashAIDataset
+ Target time series (single column)
+ temporal_metadata : dict, optional
+ Metadata from ForecastingTask containing:
+ - timestamp_col: name of timestamp column
+ - target_col: name of target column
+ - exog_cols: list of exogenous variable column names
+ - frequency: time series frequency
+ If not provided, will attempt auto-detection (legacy behavior)
+ **fit_params
+ Additional fitting parameters
+
+ Returns
+ -------
+ ProphetModel
+ Fitted model instance
+ """
+ try:
+ Prophet = _patch_prophet_regressor_column_matrix()
+ except ImportError as e:
+ raise ImportError(
+ "Prophet is required for ProphetModel. "
+ "Install with: pip install prophet"
+ ) from e
+
+ # Validate data format
+ self._validate_forecasting_data(x_train, y)
+
+ # Convert to pandas DataFrames
+ x_df = x_train.to_pandas()
+ y_df = y.to_pandas()
+
+ # Get column information from metadata (task-agnostic approach)
+ if temporal_metadata:
+ timestamp_col = temporal_metadata.get("timestamp_col")
+ target_col = temporal_metadata.get("target_col")
+ exog_cols_from_task = temporal_metadata.get("exog_cols", [])
+ frequency = temporal_metadata.get("frequency", "D")
+
+ print("[ProphetModel] Using temporal metadata from task:")
+ print(f" - Timestamp: '{timestamp_col}'")
+ print(f" - Target: '{target_col}'")
+ print(f" - Frequency: {frequency}")
+ if exog_cols_from_task:
+ print(f" - Exogenous variables: {exog_cols_from_task}")
+ else:
+ # Legacy: auto-detection if no metadata provided
+ print(
+ "[ProphetModel] ⚠️ No temporal_metadata provided, using auto-detection"
+ )
+
+ # Get target column name (should be single column)
+ target_col = y_df.columns[0]
+
+ # Auto-detect timestamp column in x_df
+ timestamp_col = None
+ for col in x_df.columns:
+ try:
+ pd.to_datetime(x_df[col])
+ timestamp_col = col
+ print(f"[ProphetModel] Detected timestamp column: '{col}'")
+ break
+ except Exception:
+ continue
+
+ if timestamp_col is None:
+ raise ValueError(
+ f"No timestamp column found in input data. "
+ f"Available columns: {list(x_df.columns)}"
+ )
+
+ exog_cols_from_task = []
+ frequency = fit_params.get("frequency", "D")
+
+ # Store original column names in base class attributes
+ self.timestamp_col = timestamp_col
+ self.target_col = target_col
+ self.frequency = frequency
+
+ # Build Prophet dataframe (internal conversion to 'ds'/'y')
+ prophet_df = pd.DataFrame(
+ {
+ "ds": pd.to_datetime(x_df[timestamp_col]).to_numpy(),
+ }
+ )
+
+ # Check if target column is in x_train (user might have included it by mistake)
+ target_in_inputs = target_col in x_df.columns
+
+ if target_in_inputs:
+ # Target is in inputs - use it from there for consistency
+ print(
+ "[ProphetModel] ℹ️ Target '{}' found in inputs - using it "
+ "from there".format(target_col)
+ )
+ prophet_df["y"] = pd.to_numeric(
+ x_df[target_col], errors="coerce"
+ ).to_numpy()
+ else:
+ # Target is only in y - normal case
+ prophet_df["y"] = pd.to_numeric(
+ y_df[target_col], errors="coerce"
+ ).to_numpy()
+
+ # Add exogenous variables (columns that are not timestamp and are numeric)
+ # Exclude timestamp and target columns, and only include numeric columns
+ # Store in ORIGINAL format (as per BaseForecastingModel contract)
+ self.exog_cols = []
+ if temporal_metadata:
+ candidate_exog_cols = [col for col in exog_cols_from_task if col in x_df]
+ missing_exog_cols = [col for col in exog_cols_from_task if col not in x_df]
+ if missing_exog_cols:
+ print(
+ "[ProphetModel] ⚠️ Ignoring missing exogenous columns from task: "
+ f"{missing_exog_cols}"
+ )
+ else:
+ candidate_exog_cols = [
+ col for col in x_df.columns if col not in {timestamp_col, target_col}
+ ]
+
+ for col in candidate_exog_cols:
+ if col == target_col and target_in_inputs:
+ print(
+ "[ProphetModel] ℹ️ Excluding target '{}' from exogenous "
+ "variables".format(col)
+ )
+ continue
+
+ # Only add numeric columns
+ if pd.api.types.is_numeric_dtype(x_df[col]):
+ self.exog_cols.append(col) # Store ORIGINAL name
+ prophet_df[col] = x_df[col].to_numpy()
+ else:
+ print(
+ "[ProphetModel] ⚠️ Skipping non-numeric column: '{}' "
+ "(type: {})".format(col, x_df[col].dtype)
+ )
+
+ # Handle logistic growth - requires 'cap' (and optionally 'floor') columns
+ if self.growth == "logistic":
+ y_max = prophet_df["y"].max()
+ y_min = prophet_df["y"].min()
+
+ # Calculate cap and floor based on multipliers
+ self._cap_value = y_max * self.cap_multiplier
+ self._floor_value = y_min * self.floor_ratio
+
+ # Add cap column (required for logistic growth)
+ prophet_df["cap"] = self._cap_value
+
+ # Add floor column if floor_ratio > 0
+ if self.floor_ratio > 0:
+ prophet_df["floor"] = self._floor_value
+
+ print(
+ f"[ProphetModel] Logistic growth: cap={self._cap_value:.2f} "
+ f"(max*{self.cap_multiplier}), floor={self._floor_value:.2f}"
+ )
+
+ # Store additional metadata
+ self.last_ds = prophet_df["ds"].max()
+
+ print(f"[ProphetModel] Training with {len(prophet_df)} data points")
+ print(
+ f"[ProphetModel] Date range: {prophet_df['ds'].min()} to "
+ f"{prophet_df['ds'].max()}"
+ )
+ if self.exog_cols:
+ print(f"[ProphetModel] Exogenous variables: {self.exog_cols}")
+
+ # Initialize Prophet model
+ self.model = Prophet(
+ seasonality_mode=self.seasonality_mode,
+ yearly_seasonality=self.yearly_seasonality,
+ weekly_seasonality=self.weekly_seasonality,
+ daily_seasonality=self.daily_seasonality,
+ growth=self.growth,
+ changepoint_prior_scale=self.changepoint_prior_scale,
+ seasonality_prior_scale=self.seasonality_prior_scale,
+ holidays_prior_scale=self.holidays_prior_scale,
+ interval_width=self.interval_width,
+ uncertainty_samples=self.uncertainty_samples,
+ )
+
+ # Add exogenous regressors to Prophet (using original names)
+ for col in self.exog_cols:
+ self.model.add_regressor(col)
+
+ self.model.fit(prophet_df)
+
+ print("✅ Prophet model training completed")
+ return self
+
+ def _add_cap_floor_columns(self, dataframe: pd.DataFrame) -> pd.DataFrame:
+ """Add cap and floor columns for logistic growth predictions.
+
+ Parameters
+ ----------
+ dataframe : pd.DataFrame
+ DataFrame to add cap/floor columns to
+
+ Returns
+ -------
+ pd.DataFrame
+ DataFrame with cap (and optionally floor) columns added
+ """
+ if self.growth != "logistic":
+ return dataframe
+
+ result_df = dataframe.copy()
+
+ if self._cap_value is not None:
+ result_df["cap"] = self._cap_value
+
+ if self._floor_value is not None and self.floor_ratio > 0:
+ result_df["floor"] = self._floor_value
+
+ return result_df
+
+ def predict(
+ self,
+ x_pred: Optional[Any] = None,
+ periods: Optional[int] = None,
+ horizon: Optional[int] = None,
+ exog_future: Optional[pd.DataFrame] = None,
+ return_components: bool = False,
+ **kwargs,
+ ) -> Union[np.ndarray, pd.DataFrame]:
+ if self.model is None:
+ raise ValueError("Prophet model is not fitted yet. Call fit() first.")
+
+ def _extract_predictions(
+ forecast_df: pd.DataFrame, requested_ds: pd.Series
+ ) -> Union[np.ndarray, pd.DataFrame]:
+ """Extract predictions for requested timestamps.
+
+ For timestamps that don't exist in Prophet's forecast (gaps in data),
+ returns NaN. These will be filtered out by prepare_to_metric().
+ """
+ # Normalize both forecast and requested timestamps to ensure matching
+ # Prophet internally normalizes dates, so we need to do the same
+ forecast_df = forecast_df.copy()
+ forecast_df["ds"] = pd.to_datetime(forecast_df["ds"]).dt.normalize()
+ requested_ds_normalized = pd.to_datetime(requested_ds).dt.normalize()
+
+ # Debug: Show sample of dates
+ print(
+ f"[ProphetModel] _extract_predictions: "
+ f"forecast has {len(forecast_df)} rows, "
+ f"requested {len(requested_ds_normalized)} timestamps"
+ )
+ print(
+ f"[ProphetModel] Forecast dates range: "
+ f"{forecast_df['ds'].min()} to {forecast_df['ds'].max()}"
+ )
+ print(
+ f"[ProphetModel] Requested dates range: "
+ f"{requested_ds_normalized.min()} to {requested_ds_normalized.max()}"
+ )
+
+ aligned = forecast_df.set_index("ds").reindex(requested_ds_normalized)
+
+ # Check for missing predictions
+ missing_mask = aligned["yhat"].isna()
+ if missing_mask.any():
+ missing_count = missing_mask.sum()
+ total_count = len(requested_ds)
+ print(
+ f"[ProphetModel] ⚠️ {missing_count}/{total_count} timestamps "
+ f"have no predictions (gaps in data). These will be excluded "
+ f"from metrics calculation."
+ )
+ # Debug: Show which dates are missing
+ if missing_count <= 10:
+ missing_dates = requested_ds_normalized[missing_mask.to_numpy()]
+ print(f"[ProphetModel] Missing dates: {list(missing_dates)}")
+ else:
+ missing_dates = requested_ds_normalized[missing_mask.to_numpy()]
+ print(
+ f"[ProphetModel] First 5 missing dates: "
+ f"{list(missing_dates[:5])}"
+ )
+
+ if return_components:
+ return aligned.reset_index()
+ return aligned["yhat"].to_numpy()
+
+ if x_pred is not None:
+ if isinstance(x_pred, (int, np.integer)):
+ periods = int(x_pred)
+ else:
+ if isinstance(x_pred, pd.DataFrame):
+ input_df = x_pred.copy()
+ else:
+ input_df = to_dashai_dataset(x_pred).to_pandas()
+
+ # Auto-detect timestamp column (try 'ds' first for compatibility)
+ timestamp_col = None
+ if "ds" in input_df.columns:
+ timestamp_col = "ds"
+ else:
+ # Try to find timestamp column
+ for col in input_df.columns:
+ try:
+ pd.to_datetime(input_df[col])
+ timestamp_col = col
+ break
+ except Exception:
+ continue
+
+ if timestamp_col is None:
+ raise ValueError(
+ "Prophet predict requires a timestamp column. "
+ f"Available columns: {list(input_df.columns)}"
+ )
+
+ input_df = input_df.copy()
+
+ # Rename to 'ds' for Prophet
+ if timestamp_col != "ds":
+ input_df = input_df.rename(columns={timestamp_col: "ds"})
+
+ # Normalize timestamps to ensure consistent comparison
+ input_df["ds"] = pd.to_datetime(input_df["ds"]).dt.normalize()
+ input_df = input_df.sort_values("ds").reset_index(drop=True)
+
+ # Check if we need in-sample predictions (for explainability)
+ # If any requested date is <= last training date, we need to include
+ # historical dates in the prediction
+ # Use Prophet's internal history dataframe to get training date range
+ if not hasattr(self.model, "history_dates"):
+ raise ValueError(
+ "Prophet model has no training history. "
+ "Ensure the model was fitted before prediction."
+ )
+ # Normalize history dates for consistent comparison
+ history_dates = pd.Series(self.model.history_dates)
+ history_dates_normalized = history_dates.dt.normalize()
+ last_train_date = history_dates_normalized.max()
+ has_historical = (input_df["ds"] <= last_train_date).any()
+
+ if has_historical:
+ # For in-sample predictions (explainability use case):
+ # Include both historical and future dates
+ # Create a complete dataframe from first training date to last
+ # requested date. This ensures Prophet generates predictions for
+ # all dates including historical ones
+ max_requested_date = input_df["ds"].max()
+
+ # Use make_future_dataframe but include historical dates
+ future_df = self.model.make_future_dataframe(
+ periods=0, # Don't extend beyond training
+ freq=self.frequency or "D",
+ include_history=True, # Include training dates
+ )
+
+ # Add any future dates beyond training if needed
+ if max_requested_date > last_train_date:
+ additional_periods = pd.date_range(
+ start=last_train_date + pd.Timedelta(days=1),
+ end=max_requested_date,
+ freq=self.frequency or "D",
+ )
+ additional_df = pd.DataFrame({"ds": additional_periods})
+ future_df = pd.concat(
+ [future_df, additional_df], ignore_index=True
+ )
+
+ # Add exogenous variables if present
+ if self.exog_cols:
+ missing_cols = [
+ col for col in self.exog_cols if col not in input_df.columns
+ ]
+ if missing_cols:
+ raise ValueError(
+ "Missing exogenous columns for prediction: "
+ f"{missing_cols}."
+ )
+
+ # Merge exogenous data from input_df with future_df
+ # For historical dates, use the provided values
+ future_df = future_df.merge(
+ input_df[["ds"] + self.exog_cols], on="ds", how="left"
+ )
+
+ # Check if there are missing exogenous values
+ if future_df[self.exog_cols].isna().any().any():
+ raise ValueError(
+ "Missing exogenous values for some dates. "
+ "All dates in prediction range must have "
+ "exogenous data."
+ )
+ else:
+ # Normal future forecasting (original behavior)
+ future_df = input_df[["ds"]].copy()
+
+ if self.exog_cols:
+ missing_cols = [
+ col for col in self.exog_cols if col not in input_df.columns
+ ]
+ if missing_cols:
+ raise ValueError(
+ "Missing exogenous columns for prediction: "
+ f"{missing_cols}."
+ )
+ future_df = pd.concat(
+ [
+ future_df,
+ input_df[self.exog_cols].reset_index(drop=True),
+ ],
+ axis=1,
+ )
+
+ # Add cap/floor for logistic growth
+ future_df = self._add_cap_floor_columns(future_df)
+
+ # Debug: Log what we're predicting
+ print(
+ f"[ProphetModel] Predicting for {len(future_df)} dates: "
+ f"{future_df['ds'].min()} to {future_df['ds'].max()}"
+ )
+ print(f"[ProphetModel] has_historical={has_historical}")
+
+ forecast = self.model.predict(future_df)
+
+ # Debug: Log what Prophet returned
+ print(
+ f"[ProphetModel] Prophet returned {len(forecast)} predictions: "
+ f"{forecast['ds'].min()} to {forecast['ds'].max()}"
+ )
+
+ return _extract_predictions(forecast, input_df["ds"])
+
+ # Handle periods/horizon compatibility
+ if periods is None and horizon is not None:
+ periods = horizon
+
+ if periods is None:
+ raise ValueError(
+ "Prophet predict requires either 'x_pred' data or a 'periods' value."
+ )
+ if periods <= 0:
+ raise ValueError("Prediction horizon must be a positive integer.")
+
+ frequency = self.frequency or "D"
+
+ # If x_pred is provided with periods, use it to determine start date
+ start_date = None
+ if x_pred is not None:
+ if isinstance(x_pred, pd.DataFrame):
+ input_df = x_pred.copy()
+ else:
+ input_df = to_dashai_dataset(x_pred).to_pandas()
+
+ # Find timestamp column
+ ts_col = None
+ if "ds" in input_df.columns:
+ ts_col = "ds"
+ elif self.timestamp_col in input_df.columns:
+ ts_col = self.timestamp_col
+
+ if ts_col:
+ start_date = pd.to_datetime(input_df[ts_col]).max()
+ print(f"[ProphetModel] Using input as start date: {start_date}")
+
+ # Also update last_ds for explainers
+ self.last_ds = start_date
+
+ if start_date:
+ # Generate future dataframe starting after start_date
+ future_dates = pd.date_range(
+ start=start_date, periods=periods + 1, freq=frequency
+ )[1:]
+ future_df = pd.DataFrame({"ds": future_dates})
+ else:
+ # Standard behavior (continue from training)
+ future_df = self.model.make_future_dataframe(
+ periods=periods, freq=frequency
+ )
+
+ if self.exog_cols and exog_future is not None:
+ missing_cols = [
+ col for col in self.exog_cols if col not in exog_future.columns
+ ]
+ if missing_cols:
+ raise ValueError(
+ f"Missing exogenous columns for future prediction: {missing_cols}."
+ )
+ if len(exog_future) != periods:
+ raise ValueError(
+ "Missing exogenous values must match the prediction horizon length."
+ )
+ for col in self.exog_cols:
+ future_df[col] = exog_future[col].to_numpy()
+ elif self.exog_cols:
+ raise ValueError(
+ f"Future exogenous values required for columns: {self.exog_cols}."
+ )
+
+ # Add cap/floor for logistic growth
+ future_df = self._add_cap_floor_columns(future_df)
+ forecast = self.model.predict(future_df)
+ print(f"[ProphetModel] Generated forecast for {periods} periods")
+ print(
+ "[ProphetModel] Forecast range: "
+ f"{forecast['ds'].iloc[-periods:].min()} to "
+ f"{forecast['ds'].iloc[-periods:].max()}"
+ )
+
+ if return_components:
+ return forecast.tail(periods)
+ return forecast["yhat"].tail(periods).to_numpy()
+
+ def get_forecast_uncertainty(
+ self, horizon: int, confidence_level: float = 0.80
+ ) -> pd.DataFrame:
+ """Get forecast with native Prophet prediction intervals.
+
+ Parameters
+ ----------
+ horizon : int
+ Number of future periods to forecast.
+ confidence_level : float
+ Desired confidence level. Note: Prophet's intervals are controlled
+ by ``interval_width`` set at model creation. This parameter is
+ accepted for interface uniformity but may not match exactly if the
+ model was initialized with a different ``interval_width``.
+
+ Returns
+ -------
+ pd.DataFrame
+ Columns: ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper``.
+ Intervals come from Prophet's own uncertainty sampling
+ (``uncertainty_samples`` parameter).
+
+ Raises
+ ------
+ ValueError
+ If the model was trained with exogenous variables.
+ """
+ if self.model is None:
+ raise ValueError("Model must be fitted before getting uncertainty.")
+
+ if self.exog_cols:
+ raise ValueError(
+ f"Cannot generate forecast uncertainty: model was trained with "
+ f"exogenous variables {self.exog_cols}. Future exogenous values "
+ f"are required but not available. "
+ f"Use ForecastFeatureImportance instead."
+ )
+
+ freq = self.frequency or "D"
+ future_df = self.model.make_future_dataframe(periods=horizon, freq=freq)
+ future_df = self._add_cap_floor_columns(future_df)
+ forecast = self.model.predict(future_df)
+
+ fc = forecast.tail(horizon)
+ return pd.DataFrame(
+ {
+ "ds": fc["ds"].to_numpy(),
+ "yhat": fc["yhat"].to_numpy(),
+ "yhat_lower": fc["yhat_lower"].to_numpy(),
+ "yhat_upper": fc["yhat_upper"].to_numpy(),
+ }
+ )
+
+ def get_forecast_components(self, horizon: int) -> pd.DataFrame:
+ """Get forecast decomposition (trend, seasonality, etc.).
+
+ Note: This method requires making future predictions. If the model
+ was trained with exogenous variables, this will fail unless future
+ values for those variables are provided.
+
+ Parameters
+ ----------
+ horizon : int
+ Number of periods to forecast
+
+ Returns
+ -------
+ pd.DataFrame
+ Forecast components (trend, seasonal, etc.)
+
+ Raises
+ ------
+ ValueError
+ If model was trained with exogenous variables (cannot forecast
+ without future exogenous values)
+ """
+ if self.model is None:
+ raise ValueError("Model must be fitted before getting components")
+
+ if self.exog_cols:
+ # Model uses exogenous variables - cannot make valid forecast
+ raise ValueError(
+ f"Cannot generate forecast components: model was trained with "
+ f"exogenous variables {self.exog_cols}.\n"
+ f"Future forecasting requires known future values for these variables, "
+ f"which are not available.\n"
+ f"Recommendation: For models with exogenous variables, use "
+ f"ForecastFeatureImportance explainer instead."
+ )
+
+ # No exogenous variables - can make simple forecast
+ future_df = self.model.make_future_dataframe(
+ periods=horizon, freq=self.frequency or "D"
+ )
+ # Add cap/floor for logistic growth
+ future_df = self._add_cap_floor_columns(future_df)
+ forecast = self.model.predict(future_df)
+
+ # Return components for the forecast period
+ component_cols = ["ds", "trend", "seasonal", "weekly", "yearly"]
+ available_cols = [col for col in component_cols if col in forecast.columns]
+ return forecast[available_cols].iloc[-horizon:]
+
+ def save(self, filename: str) -> None:
+ """Save Prophet model to file.
+
+ Parameters
+ ----------
+ filename : str
+ Path to save the model
+ """
+ model_state = {
+ "model": self.model,
+ # Base class attributes (original column names)
+ "exog_cols": self.exog_cols,
+ "timestamp_col": self.timestamp_col,
+ "target_col": self.target_col,
+ # Prophet-specific metadata
+ "last_ds": self.last_ds,
+ "frequency": self.frequency,
+ # Logistic growth parameters
+ "_cap_value": self._cap_value,
+ "_floor_value": self._floor_value,
+ "cap_multiplier": self.cap_multiplier,
+ "floor_ratio": self.floor_ratio,
+ "config": {
+ "seasonality_mode": self.seasonality_mode,
+ "yearly_seasonality": self.yearly_seasonality,
+ "weekly_seasonality": self.weekly_seasonality,
+ "daily_seasonality": self.daily_seasonality,
+ "growth": self.growth,
+ "changepoint_prior_scale": self.changepoint_prior_scale,
+ "seasonality_prior_scale": self.seasonality_prior_scale,
+ "holidays_prior_scale": self.holidays_prior_scale,
+ "interval_width": self.interval_width,
+ "uncertainty_samples": self.uncertainty_samples,
+ },
+ }
+
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ with open(filename, "wb") as f:
+ pickle.dump(model_state, f)
+
+ print(f"✅ Prophet model saved to {filename}")
+
+ def load(self, filename: str) -> "ProphetModel":
+ """Load Prophet model from file.
+
+ Parameters
+ ----------
+ filename : str
+ Path to load the model from
+
+ Returns
+ -------
+ ProphetModel
+ Loaded model instance
+ """
+ _patch_prophet_regressor_column_matrix()
+
+ with open(filename, "rb") as f:
+ model_state = pickle.load(f)
+
+ self.model = model_state["model"]
+
+ # Restore base class attributes (original column names)
+ self.exog_cols = model_state["exog_cols"]
+ self.timestamp_col = model_state.get(
+ "timestamp_col"
+ ) # May not exist in old models
+ self.target_col = model_state.get("target_col") # May not exist in old models
+
+ # Restore Prophet-specific metadata
+ self.last_ds = model_state["last_ds"]
+ self.frequency = model_state["frequency"]
+
+ # Restore logistic growth parameters (may not exist in old models)
+ self._cap_value = model_state.get("_cap_value")
+ self._floor_value = model_state.get("_floor_value")
+ self.cap_multiplier = model_state.get("cap_multiplier", 1.5)
+ self.floor_ratio = model_state.get("floor_ratio", 0.0)
+
+ # Restore configuration
+ config = model_state["config"]
+ for key, value in config.items():
+ setattr(self, key, value)
+
+ print(f"✅ Prophet model loaded from {filename}")
+ return self
diff --git a/DashAI/back/models/forecasting/sklearn_multistep_forecaster.py b/DashAI/back/models/forecasting/sklearn_multistep_forecaster.py
new file mode 100644
index 000000000..8f1928905
--- /dev/null
+++ b/DashAI/back/models/forecasting/sklearn_multistep_forecaster.py
@@ -0,0 +1,928 @@
+"""Sklearn-based multi-step forecasting model for DashAI.
+
+This model uses sklearn regressors with a sliding window approach to perform
+multi-step-ahead forecasting. It internally creates lag features and can
+predict multiple steps into the future.
+"""
+
+import os
+import pickle
+from typing import Any, List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from sklearn.ensemble import RandomForestRegressor as _RandomForestRegressor
+from sklearn.linear_model import LinearRegression as _LinearRegression
+from sklearn.linear_model import Ridge as _Ridge
+from sklearn.multioutput import MultiOutputRegressor
+
+from DashAI.back.core.schema_fields import (
+ BaseSchema,
+ enum_field,
+ int_field,
+ schema_field,
+)
+from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset
+from DashAI.back.models.forecasting.base_forecasting_model import ForecastingModel
+
+
+class SklearnMultiStepForecasterSchema(BaseSchema):
+ """Schema for SklearnMultiStepForecaster configuration."""
+
+ base_estimator: schema_field(
+ enum_field(enum=["linear", "ridge", "random_forest"]),
+ placeholder="linear",
+ description=(
+ "Base estimator for forecasting. "
+ "'linear': Fast linear regression (best for linear trends). "
+ "'ridge': Linear regression with L2 regularization "
+ "(prevents overfitting). "
+ "'random_forest': Tree-based ensemble (handles non-linear patterns)."
+ ),
+ ) = "linear" # type: ignore
+
+ window_size: schema_field(
+ int_field(ge=1, le=365),
+ placeholder=3,
+ description=(
+ "Number of past time steps (lags) to use as features. "
+ "Smaller values work better for small datasets. "
+ "Will be auto-adjusted if dataset is too small."
+ ),
+ ) = 3 # type: ignore
+
+ forecast_strategy: schema_field(
+ enum_field(enum=["direct", "recursive"]),
+ placeholder="direct",
+ description=(
+ "Multi-step forecasting strategy. "
+ "'direct': Train separate model for each horizon "
+ "(more accurate, slower). "
+ "'recursive': Use predictions as inputs for next step "
+ "(faster, error compounds)."
+ ),
+ ) = "direct" # type: ignore
+
+
+class SklearnMultiStepForecaster(ForecastingModel):
+ """Sklearn-based multi-step forecasting model.
+
+ This model transforms time series forecasting into a supervised learning problem
+ by creating lag features automatically. It supports:
+ - Multiple sklearn base estimators (linear, ridge, random_forest)
+ - Direct multi-step strategy (separate model per horizon)
+ - Recursive strategy (iterative predictions)
+ - Exogenous variables
+
+ Example usage in ForecastingTask:
+ 1. User uploads time series with columns: [timestamp, value, exog1, exog2]
+ 2. Task identifies timestamp, target, exogenous variables
+ 3. Model internally creates lags and trains
+ 4. Prediction works exactly like Prophet/ARIMA
+
+ The key advantage is that users can leverage sklearn's powerful regression
+ models for forecasting without manually creating lag features.
+ """
+
+ SCHEMA = SklearnMultiStepForecasterSchema
+ COMPATIBLE_COMPONENTS = ["ForecastingTask"]
+ _task_type = "ForecastingTask"
+
+ def __init__(
+ self,
+ base_estimator: str = "linear",
+ window_size: int = 3,
+ forecast_strategy: str = "direct",
+ **kwargs,
+ ) -> None:
+ """Initialize SklearnMultiStepForecaster.
+
+ Parameters
+ ----------
+ base_estimator : str
+ Base sklearn estimator to use ('linear', 'ridge', 'random_forest')
+ window_size : int
+ Number of past time steps to use as lag features
+ forecast_strategy : str
+ Strategy for multi-step forecasting ('direct' or 'recursive')
+ **kwargs
+ Additional arguments passed to ForecastingModel
+ """
+ super().__init__(**kwargs)
+
+ self.base_estimator = base_estimator
+ self.window_size = window_size
+ self.forecast_strategy = forecast_strategy
+
+ # Internal state
+ self.models: List[Any] = []
+ self.training_history: Optional[pd.Series] = None
+ self.training_exog_history: Optional[pd.DataFrame] = None
+ self.training_full_series: Optional[pd.Series] = None
+ self.training_full_exog: Optional[pd.DataFrame] = None
+ self.training_full_exog: Optional[pd.DataFrame] = None
+ self.max_horizon: int = 1
+ self.last_timestamp: Optional[pd.Timestamp] = None
+
+ def _get_base_estimator(self):
+ """Get instance of base estimator."""
+ estimators = {
+ "linear": _LinearRegression,
+ "ridge": _Ridge,
+ "random_forest": _RandomForestRegressor,
+ }
+
+ if self.base_estimator not in estimators:
+ raise ValueError(
+ f"Unknown base_estimator '{self.base_estimator}'. "
+ f"Supported: {list(estimators.keys())}"
+ )
+
+ return estimators[self.base_estimator]()
+
+ def _create_lag_features(
+ self, series: pd.Series, exog_df: Optional[pd.DataFrame] = None
+ ) -> pd.DataFrame:
+ """Create lag features from time series.
+
+ Parameters
+ ----------
+ series : pd.Series
+ Time series values
+ exog_df : pd.DataFrame, optional
+ Exogenous variables (must have same index as series)
+
+ Returns
+ -------
+ pd.DataFrame
+ DataFrame with lag features and optional exogenous variables
+ """
+ result = pd.DataFrame(index=series.index)
+
+ # Create lags (lag_1 is t-1, lag_2 is t-2, etc.)
+ for lag in range(1, self.window_size + 1):
+ result[f"lag_{lag}"] = series.shift(lag)
+
+ # Add exogenous variables if present
+ if exog_df is not None:
+ for col in exog_df.columns:
+ result[col] = exog_df[col]
+
+ return result
+
+ def fit(
+ self,
+ x_train: DashAIDataset,
+ y: DashAIDataset,
+ temporal_metadata: dict = None,
+ **fit_params,
+ ) -> "SklearnMultiStepForecaster":
+ """Train the multi-step forecasting model.
+
+ Parameters
+ ----------
+ x_train : DashAIDataset
+ Input features (timestamp + optional exogenous variables)
+ y : DashAIDataset
+ Target time series
+ temporal_metadata : dict
+ Metadata from ForecastingTask with timestamp_col, target_col, etc.
+ **fit_params
+ Additional fitting parameters (can include 'horizon')
+
+ Returns
+ -------
+ SklearnMultiStepForecaster
+ Fitted model
+ """
+ if temporal_metadata is None:
+ raise ValueError(
+ "temporal_metadata is required for SklearnMultiStepForecaster"
+ )
+
+ # Get metadata
+ self.timestamp_col = temporal_metadata.get("timestamp_col")
+ self.target_col = temporal_metadata.get("target_col")
+ self.exog_cols = temporal_metadata.get("exog_cols", [])
+ self.frequency = temporal_metadata.get("frequency", "D")
+
+ print("[SklearnMultiStepForecaster] Using temporal metadata from task:")
+ print(f" - Timestamp: '{self.timestamp_col}'")
+ print(f" - Target: '{self.target_col}'")
+ print(f" - Exogenous: {self.exog_cols}")
+ print(f" - Frequency: {self.frequency}")
+
+ # Convert to pandas
+ x_df = x_train.to_pandas()
+ y_df = y.to_pandas()
+
+ # Get horizon from fit_params (default to 1)
+ horizon = fit_params.get("horizon", 1)
+ self.max_horizon = horizon
+
+ # Store last timestamp and full date sequence for future predictions
+ if self.timestamp_col in x_df.columns:
+ parsed_dates = pd.to_datetime(x_df[self.timestamp_col])
+ self.last_timestamp = parsed_dates.max()
+ self.training_dates = parsed_dates.to_numpy()
+ print(f"[SklearnMultiStepForecaster] Last timestamp: {self.last_timestamp}")
+ else:
+ self.last_timestamp = pd.Timestamp.now()
+ self.training_dates = None
+ print("[SklearnMultiStepForecaster] ⚠️ No timestamp col, default to now()")
+
+ # Get target series
+ target_in_inputs = self.target_col in x_df.columns
+ if target_in_inputs:
+ print(
+ f"[SklearnMultiStepForecaster] ℹ️ Target '{self.target_col}' "
+ "found in inputs - using it from there"
+ )
+ target_series = x_df[self.target_col]
+ else:
+ target_series = y_df[self.target_col]
+
+ # Extract exogenous variables if present
+ exog_df = None
+ if self.exog_cols:
+ exog_df = x_df[self.exog_cols]
+ print(f"[SklearnMultiStepForecaster] Exogenous variables: {self.exog_cols}")
+
+ n_target_samples = len(target_series)
+
+ # Auto-adjust window_size for small datasets
+ # Need: window_size lags + horizon shifts + at least 2 samples to train
+ min_required = self.window_size + horizon + 2
+
+ if n_target_samples < min_required:
+ # Try to fit within constraints by reducing window size
+ # The available space for window is samples minus horizon minus margin
+ available_for_window = n_target_samples - horizon - 2
+
+ if available_for_window < 1:
+ # Even with window_size=1, we can't fit. Reduce horizon too.
+ # Minimum setup: window=1, horizon=1, need at least 4 samples
+ if n_target_samples >= 4:
+ self.window_size = 1
+ horizon = max(1, n_target_samples - 3)
+ print(
+ f"[SklearnMultiStepForecaster] ⚠️ Very small dataset "
+ f"({n_target_samples} samples). "
+ f"Forced window_size=1, horizon={horizon}"
+ )
+ else:
+ raise ValueError(
+ f"Dataset too small for forecasting. Need at least 4 samples, "
+ f"got {n_target_samples}. Please use more training data."
+ )
+ else:
+ old_window = self.window_size
+ self.window_size = max(1, available_for_window)
+ print(
+ f"[SklearnMultiStepForecaster] ⚠️ Auto-adjusted window_size: "
+ f"{old_window} → {self.window_size} "
+ f"(target series has {n_target_samples} samples)"
+ )
+
+ self.max_horizon = horizon
+
+ print(f"[SklearnMultiStepForecaster] Training for horizon: {horizon}")
+ print(f"[SklearnMultiStepForecaster] Window size: {self.window_size}")
+ print(f"[SklearnMultiStepForecaster] Strategy: {self.forecast_strategy}")
+
+ # Create lag features
+ X_with_lags = self._create_lag_features(target_series, exog_df)
+
+ # For direct strategy: train one model per horizon
+ if self.forecast_strategy == "direct":
+ self.models = []
+ for h in range(1, horizon + 1):
+ # Create target: y shifted h steps ahead
+ y_h = target_series.shift(-h)
+
+ # Remove NaN rows
+ mask = X_with_lags.notna().all(axis=1) & y_h.notna()
+ X_clean = X_with_lags[mask]
+ y_clean = y_h[mask]
+
+ if len(X_clean) == 0:
+ raise ValueError(
+ f"No valid samples after creating lags and horizon {h}. "
+ f"Try reducing window_size or using more data."
+ )
+
+ # Train model for this horizon
+ model = MultiOutputRegressor(self._get_base_estimator())
+ model.fit(X_clean.to_numpy(), y_clean.to_numpy().reshape(-1, 1))
+ self.models.append(model)
+
+ print(
+ f"[SklearnMultiStepForecaster] Trained {len(self.models)} models "
+ "(direct strategy)"
+ )
+
+ # For recursive strategy: train single model for 1-step ahead
+ else: # recursive
+ y_1 = target_series.shift(-1)
+ mask = X_with_lags.notna().all(axis=1) & y_1.notna()
+ X_clean = X_with_lags[mask]
+ y_clean = y_1[mask]
+
+ if len(X_clean) == 0:
+ raise ValueError(
+ "No valid samples after creating lags. "
+ "Try reducing window_size or using more data."
+ )
+
+ model = MultiOutputRegressor(self._get_base_estimator())
+ model.fit(X_clean.to_numpy(), y_clean.to_numpy().reshape(-1, 1))
+ self.models = [model]
+
+ print("[SklearnMultiStepForecaster] Trained 1 model (recursive strategy)")
+
+ # Store FULL training series and exog for in-sample predictions
+ # We need the complete history to create lags for any subset
+ self.training_full_series = target_series.copy()
+ self.training_history = target_series.iloc[-self.window_size :]
+ if self.exog_cols and exog_df is not None:
+ self.training_full_exog = exog_df.copy()
+ self.training_exog_history = exog_df.iloc[-self.window_size :]
+
+ print("[SklearnMultiStepForecaster] ✅ Training completed")
+
+ return self
+
+ def predict(
+ self,
+ x_pred: Optional[Any] = None,
+ periods: Optional[int] = None,
+ exog_future: Optional[pd.DataFrame] = None,
+ **kwargs,
+ ) -> Union[np.ndarray, pd.DataFrame]:
+ """Generate forecasts.
+
+ Parameters
+ ----------
+ x_pred : Any, optional
+ Input data for in-sample predictions containing timestamp and
+ optional exogenous variables. Can also be an integer for compatibility.
+ periods : int, optional
+ Number of steps to forecast into the future
+ exog_future : pd.DataFrame, optional
+ Future exogenous variable values
+ **kwargs
+ Additional parameters (can include 'horizon' as alias for 'periods')
+
+ Returns
+ -------
+ np.ndarray
+ Predictions array
+ """
+ if not self.models:
+ raise ValueError("Model not fitted. Call fit() first.")
+
+ # Handle horizon alias
+ if periods is None and "horizon" in kwargs:
+ periods = kwargs["horizon"]
+
+ # Handle different input types (compatibility with ForecastingTask)
+ if x_pred is not None and isinstance(x_pred, (int, np.integer)):
+ periods = int(x_pred)
+ x_pred = None
+
+ # Note: If x_pred is provided with periods, use x_pred as history context
+
+ # In-sample predictions (for metrics calculation)
+ if x_pred is not None and periods is None:
+ from DashAI.back.dataloaders.classes.dashai_dataset import (
+ to_dashai_dataset,
+ )
+
+ if isinstance(x_pred, pd.DataFrame):
+ input_df = x_pred.copy()
+ else:
+ input_df = to_dashai_dataset(x_pred).to_pandas()
+
+ print(
+ f"[SklearnMultiStepForecaster] In-sample prediction for "
+ f"{len(input_df)} time steps"
+ )
+
+ # For in-sample predictions, we need the full training data
+ # because we create lags from historical values
+ if not hasattr(self, "training_full_series"):
+ raise ValueError(
+ "No training history available. Model may not be fitted properly."
+ )
+
+ # Detect whether timestamps are within training range or beyond.
+ # val/test splits reset their pandas index to 0-based, so index lookup
+ # in training lag features would return wrong rows. Use timestamp
+ # comparison to choose the correct prediction strategy.
+ is_within_training = True
+ if self.timestamp_col and self.timestamp_col in input_df.columns:
+ input_ts = pd.to_datetime(input_df[self.timestamp_col])
+ is_within_training = input_ts.max() <= self.last_timestamp
+
+ if is_within_training:
+ # True in-sample: build lag features from training data and look
+ # up the matching row positions.
+ exog_df = None
+ if self.exog_cols:
+ missing_cols = [
+ col for col in self.exog_cols if col not in input_df.columns
+ ]
+ if missing_cols:
+ raise ValueError(
+ f"Missing exogenous columns for prediction: {missing_cols}"
+ )
+ exog_df = input_df[self.exog_cols]
+
+ target_series = self.training_full_series
+ full_exog_df = (
+ self.training_full_exog
+ if hasattr(self, "training_full_exog")
+ else None
+ )
+
+ X_with_lags = self._create_lag_features(target_series, full_exog_df)
+
+ if self.exog_cols and exog_df is not None:
+ for col in self.exog_cols:
+ X_with_lags.loc[input_df.index, col] = exog_df[col].to_numpy()
+
+ X_subset = X_with_lags.loc[input_df.index]
+ mask = X_subset.notna().all(axis=1)
+ X_clean = X_subset[mask]
+
+ if len(X_clean) == 0:
+ print(
+ f"[SklearnMultiStepForecaster] ⚠️ No valid samples for "
+ f"in-sample prediction (need {self.window_size} historical "
+ f"values). Returning NaN predictions for "
+ f"{len(input_df)} points."
+ )
+ return np.full(len(input_df), np.nan)
+
+ predictions_full = np.full(len(input_df), np.nan)
+ predictions = self.models[0].predict(X_clean.to_numpy())
+ predictions_full[mask] = predictions.flatten()
+
+ print(
+ f"[SklearnMultiStepForecaster] Generated {mask.sum()} in-sample "
+ f"predictions (first {(~mask).sum()} skipped due to lag window)"
+ )
+ return predictions_full
+
+ else:
+ # Out-of-training (val/test): recursive 1-step-ahead forecast
+ # seeded from the last window_size training values.
+ n_steps = len(input_df)
+ history = list(self.training_history.to_numpy())
+ results = []
+ for _ in range(n_steps):
+ features = np.array(history[-self.window_size :]).reshape(1, -1)
+ pred = float(self.models[0].predict(features).flatten()[0])
+ results.append(pred)
+ history.append(pred)
+
+ print(
+ f"[SklearnMultiStepForecaster] Generated {n_steps} recursive "
+ f"out-of-training predictions"
+ )
+ return np.array(results)
+
+ # Out-of-sample forecast
+ if periods is not None:
+ if periods <= 0:
+ raise ValueError("Prediction horizon must be a positive integer.")
+
+ # Validate exogenous variables if needed
+ if self.exog_cols:
+ if exog_future is None:
+ raise ValueError(
+ f"Future exogenous values required for columns: "
+ f"{self.exog_cols}"
+ )
+
+ missing_cols = [
+ col for col in self.exog_cols if col not in exog_future.columns
+ ]
+ if missing_cols:
+ raise ValueError(
+ f"Missing exogenous columns for prediction: {missing_cols}"
+ )
+
+ if len(exog_future) < periods:
+ raise ValueError(
+ f"Exogenous data length ({len(exog_future)}) must be at "
+ f"least {periods} for the requested forecast horizon."
+ )
+
+ # Prepare history for prediction
+ # If x_pred is provided, use it as history (context)
+ # Otherwise, use training history
+ history_series = self.training_history
+
+ if x_pred is not None:
+ # Convert x_pred to pandas if needed
+ if isinstance(x_pred, pd.DataFrame):
+ input_df = x_pred.copy()
+ else:
+ from DashAI.back.dataloaders.classes.dashai_dataset import (
+ to_dashai_dataset,
+ )
+
+ input_df = to_dashai_dataset(x_pred).to_pandas()
+
+ # Check if target column is present
+ if self.target_col in input_df.columns:
+ print(
+ f"[SklearnMultiStepForecaster] Using input as context "
+ f"({len(input_df)} rows)"
+ )
+ history_series = input_df[self.target_col]
+
+ # Also update last_timestamp if available
+ if self.timestamp_col in input_df.columns:
+ self.last_timestamp = pd.to_datetime(
+ input_df[self.timestamp_col]
+ ).max()
+ else:
+ print(
+ f"[SklearnMultiStepForecaster] ⚠️ No target col "
+ f"'{self.target_col}', using training history"
+ )
+
+ # Ensure we have enough history
+ if history_series is None or len(history_series) < self.window_size:
+ history_len = 0 if history_series is None else len(history_series)
+ print(
+ f"[SklearnMultiStepForecaster] ⚠️ History length ({history_len}) "
+ f"is less than window size ({self.window_size}). "
+ f"Returning NaN predictions for {periods} periods."
+ )
+ return np.full(periods, np.nan)
+
+ # Direct strategy: use pre-trained models
+ if self.forecast_strategy == "direct":
+ predictions = []
+
+ # We need to maintain current_window for recursive fallback
+ # Initialize with history
+ current_window = list(history_series.to_numpy())
+
+ # Determine how many steps we can predict directly
+ max_direct_horizon = len(self.models)
+
+ for h in range(periods):
+ # Step h is 0-indexed (0 = 1st step, 1 = 2nd step, etc.)
+
+ # Case 1: Within direct horizon - use specific model
+ if h < max_direct_horizon:
+ # Create features from history
+ lags = history_series.iloc[-self.window_size :].to_numpy()
+
+ # Add exog if needed
+ if self.exog_cols and exog_future is not None:
+ exog_h = exog_future.iloc[h][self.exog_cols].to_numpy()
+ features = np.concatenate([lags, exog_h])
+ else:
+ features = lags
+
+ # Predict using the specific model for this horizon
+ pred = self.models[h].predict(features.reshape(1, -1))[0, 0]
+
+ # Case 2: Beyond direct horizon - fallback to recursive
+ else:
+ # Use the first model (1-step ahead) recursively
+ # Create features from CURRENT window (updated with predictions)
+ lags = np.array(current_window[-self.window_size :])
+
+ # Add exog if needed
+ if self.exog_cols and exog_future is not None:
+ exog_h = exog_future.iloc[h][self.exog_cols].to_numpy()
+ features = np.concatenate([lags, exog_h])
+ else:
+ features = lags
+
+ # Predict next step using model[0]
+ pred = self.models[0].predict(features.reshape(1, -1))[0, 0]
+
+ predictions.append(pred)
+ current_window.append(pred)
+
+ return np.array(predictions)
+
+ # Recursive strategy: iterative predictions
+ else:
+ predictions = []
+ current_window = list(history_series.to_numpy())
+
+ for h in range(periods):
+ # Create features
+ lags = np.array(current_window[-self.window_size :])
+
+ # Add exog if needed
+ if self.exog_cols and exog_future is not None:
+ exog_h = exog_future.iloc[h][self.exog_cols].to_numpy()
+ features = np.concatenate([lags, exog_h])
+ else:
+ features = lags
+
+ # Predict next step
+ pred = self.models[0].predict(features.reshape(1, -1))[0, 0]
+ predictions.append(pred)
+
+ # Update window with prediction
+ current_window.append(pred)
+
+ return np.array(predictions)
+
+ raise ValueError(
+ "Either x_pred or periods parameter must be provided for prediction."
+ )
+
+ def get_forecast_uncertainty(
+ self, horizon: int, confidence_level: float = 0.80
+ ) -> pd.DataFrame:
+ """Get forecast with residual-based prediction intervals.
+
+ Because sklearn regression models have no parametric error distribution,
+ this method estimates prediction uncertainty empirically:
+
+ 1. Compute in-sample residuals on the training data using the 1-step
+ ahead model (``models[0]``).
+ 2. Use the residual standard deviation as the base prediction error.
+ 3. Scale the half-interval by ``sqrt(h)`` for horizon step ``h`` to
+ simulate how uncertainty accumulates over time.
+ 4. Apply a z-score corresponding to the requested confidence level.
+
+ The resulting intervals are wider for longer horizons and reflect the
+ actual in-sample accuracy of the model.
+
+ Parameters
+ ----------
+ horizon : int
+ Number of future periods to forecast.
+ confidence_level : float
+ Confidence level (e.g., 0.80 for 80% intervals).
+
+ Returns
+ -------
+ pd.DataFrame
+ Columns: ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper``.
+
+ Raises
+ ------
+ ValueError
+ If the model was trained with exogenous variables.
+ """
+ if not self.models:
+ raise ValueError("Model must be fitted before getting uncertainty.")
+
+ if self.exog_cols:
+ raise ValueError(
+ f"Cannot generate forecast uncertainty: model was trained with "
+ f"exogenous variables {self.exog_cols}. Future exogenous values "
+ f"are required but not available. "
+ f"Use ForecastFeatureImportance instead."
+ )
+
+ # --- Estimate residual std from in-sample 1-step predictions ---
+ X_with_lags = self._create_lag_features(self.training_full_series)
+ mask = X_with_lags.notna().all(axis=1)
+ X_clean = X_with_lags[mask]
+ actual_clean = self.training_full_series[mask].to_numpy()
+
+ if len(X_clean) > 0:
+ in_sample_preds = self.models[0].predict(X_clean.to_numpy()).flatten()
+ residuals = actual_clean - in_sample_preds
+ residual_std = float(np.std(residuals))
+ else:
+ residual_std = 0.0
+
+ # Guard against zero or near-zero std (perfect in-sample fit)
+ if residual_std < 1e-10:
+ # Fall back to 5% of the mean absolute value of the training series
+ residual_std = float(
+ np.abs(self.training_full_series.to_numpy()).mean() * 0.05
+ )
+ residual_std = max(residual_std, 1e-6)
+
+ # --- z-score for the requested confidence level ---
+ try:
+ from scipy.stats import norm as _norm
+
+ z = float(_norm.ppf(0.5 + confidence_level / 2.0))
+ except ImportError:
+ # Hardcoded fallback for common levels
+ _z_table = {
+ 0.80: 1.282,
+ 0.85: 1.440,
+ 0.90: 1.645,
+ 0.95: 1.960,
+ 0.99: 2.576,
+ }
+ z = _z_table.get(round(confidence_level, 2), 1.645)
+
+ # --- Point forecast ---
+ predictions = self.predict(periods=horizon)
+
+ # --- Growing intervals: half-width = z * std * sqrt(h) ---
+ horizon_steps = np.arange(1, horizon + 1)
+ half_width = z * residual_std * np.sqrt(horizon_steps)
+
+ freq = self.frequency or "D"
+ future_dates = pd.date_range(
+ start=self.last_timestamp, periods=horizon + 1, freq=freq
+ )[1:]
+
+ return pd.DataFrame(
+ {
+ "ds": future_dates,
+ "yhat": predictions,
+ "yhat_lower": predictions - half_width,
+ "yhat_upper": predictions + half_width,
+ }
+ )
+
+ def get_forecast_components(self, horizon: int) -> pd.DataFrame:
+ """Decompose forecast into trend, seasonal, and residual components.
+
+ Because SklearnMultiStepForecaster is a regression-based model with
+ no intrinsic structural decomposition, this method applies STL
+ (Seasonal-Trend decomposition using LOESS) to the concatenation of
+ the historical training series and the out-of-sample forecast.
+
+ The resulting trend, seasonal, and residual components describe the
+ statistical structure of the full series (history + forecast horizon),
+ and only the forecast portion is returned.
+
+ Parameters
+ ----------
+ horizon : int
+ Number of future periods to forecast and decompose.
+
+ Returns
+ -------
+ pd.DataFrame
+ Columns: ``ds``, ``trend``, ````, ``residual``.
+
+ Raises
+ ------
+ ValueError
+ If the model was trained with exogenous variables.
+ """
+ if not self.models:
+ raise ValueError("Model must be fitted before getting components.")
+
+ if self.exog_cols:
+ raise ValueError(
+ f"Cannot generate forecast components: model was trained with "
+ f"exogenous variables {self.exog_cols}. Future exogenous values "
+ f"are required but not available for decomposition. "
+ f"Use ForecastFeatureImportance instead."
+ )
+
+ try:
+ from statsmodels.tsa.seasonal import STL
+ except ImportError as exc:
+ raise ImportError(
+ "statsmodels is required for STL decomposition. "
+ "Install with: pip install statsmodels"
+ ) from exc
+
+ # Build historical series with a proper DatetimeIndex
+ freq = self.frequency or "D"
+ historical_values = self.training_full_series.to_numpy()
+ n_hist = len(historical_values)
+
+ if self.training_dates is not None:
+ historical_index = pd.DatetimeIndex(self.training_dates)
+ else:
+ # Reconstruct dates ending at last_timestamp
+ historical_index = pd.date_range(
+ end=self.last_timestamp, periods=n_hist, freq=freq
+ )
+
+ historical_series = pd.Series(historical_values, index=historical_index)
+
+ # Out-of-sample forecast
+ predictions = self.predict(periods=horizon)
+ future_dates = pd.date_range(
+ start=self.last_timestamp, periods=horizon + 1, freq=freq
+ )[1:]
+ future_series = pd.Series(predictions, index=future_dates)
+
+ # Combine history + forecast
+ combined = pd.concat([historical_series, future_series])
+
+ # Determine period and run STL decomposition
+ period = self._get_seasonal_period()
+ component_name = self._period_to_seasonality_name(period)
+
+ n = len(combined)
+ if period >= 2 and n >= 2 * period:
+ try:
+ stl = STL(combined, period=period, robust=True)
+ result = stl.fit()
+ trend_vals = result.trend
+ seasonal_vals = result.seasonal
+ residual_vals = result.resid
+ except Exception:
+ window = min(period, max(2, n // 2))
+ trend_vals = combined.rolling(
+ window=window, center=True, min_periods=1
+ ).mean()
+ seasonal_vals = pd.Series(np.zeros(n), index=combined.index)
+ residual_vals = combined - trend_vals
+ else:
+ window = max(2, min(period, n // 2))
+ trend_vals = combined.rolling(
+ window=window, center=True, min_periods=1
+ ).mean()
+ seasonal_vals = pd.Series(np.zeros(n), index=combined.index)
+ residual_vals = combined - trend_vals
+
+ return pd.DataFrame(
+ {
+ "ds": combined.index[-horizon:],
+ "trend": trend_vals.to_numpy()[-horizon:],
+ component_name: seasonal_vals.to_numpy()[-horizon:],
+ "residual": residual_vals.to_numpy()[-horizon:],
+ }
+ )
+
+ def save(self, filename: str) -> None:
+ """Save model to file.
+
+ Parameters
+ ----------
+ filename : str
+ Path to save the model
+ """
+ if not self.models:
+ raise ValueError("Cannot save model before fitting.")
+
+ model_state = {
+ "models": self.models,
+ "training_history": self.training_history,
+ "training_exog_history": self.training_exog_history,
+ "training_full_series": self.training_full_series,
+ "training_full_exog": self.training_full_exog,
+ "training_dates": getattr(self, "training_dates", None),
+ "exog_cols": self.exog_cols,
+ "timestamp_col": self.timestamp_col,
+ "target_col": self.target_col,
+ "frequency": self.frequency,
+ "max_horizon": self.max_horizon,
+ "last_timestamp": self.last_timestamp,
+ "config": {
+ "base_estimator": self.base_estimator,
+ "window_size": self.window_size,
+ "forecast_strategy": self.forecast_strategy,
+ },
+ }
+
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ with open(filename, "wb") as f:
+ pickle.dump(model_state, f)
+
+ print(f"✅ SklearnMultiStepForecaster saved to {filename}")
+
+ def load(self, filename: str) -> "SklearnMultiStepForecaster":
+ """Load model from file.
+
+ Parameters
+ ----------
+ filename : str
+ Path to load the model from
+
+ Returns
+ -------
+ SklearnMultiStepForecaster
+ Loaded model instance
+ """
+ with open(filename, "rb") as f:
+ model_state = pickle.load(f)
+
+ self.models = model_state["models"]
+ self.training_history = model_state["training_history"]
+ self.training_exog_history = model_state.get("training_exog_history")
+ self.training_full_series = model_state.get("training_full_series")
+ self.training_full_exog = model_state.get("training_full_exog")
+ self.training_dates = model_state.get("training_dates")
+ self.exog_cols = model_state["exog_cols"]
+ self.timestamp_col = model_state.get("timestamp_col")
+ self.target_col = model_state.get("target_col")
+ self.frequency = model_state.get("frequency")
+ self.max_horizon = model_state.get("max_horizon", 1)
+ self.last_timestamp = model_state.get("last_timestamp")
+
+ config = model_state["config"]
+ for key, value in config.items():
+ setattr(self, key, value)
+
+ print(f"✅ SklearnMultiStepForecaster loaded from {filename}")
+ return self
diff --git a/DashAI/back/models/forecasting/statsmodels_arima_model.py b/DashAI/back/models/forecasting/statsmodels_arima_model.py
new file mode 100644
index 000000000..8b861ec0d
--- /dev/null
+++ b/DashAI/back/models/forecasting/statsmodels_arima_model.py
@@ -0,0 +1,626 @@
+"""Statsmodels ARIMA model wrapper for DashAI forecasting.
+
+This model wraps statsmodels ARIMA for time series forecasting with
+autoregressive integrated moving average modeling.
+"""
+
+import os
+import pickle
+from typing import Any, Optional, Union
+
+import numpy as np
+import pandas as pd
+
+from DashAI.back.core.schema_fields import (
+ BaseSchema,
+ enum_field,
+ int_field,
+ schema_field,
+)
+from DashAI.back.dataloaders.classes.dashai_dataset import (
+ DashAIDataset,
+ to_dashai_dataset,
+)
+from DashAI.back.models.forecasting.base_forecasting_model import ForecastingModel
+
+
+class StatsmodelsARIMAModelSchema(BaseSchema):
+ """Schema for Statsmodels ARIMA model configuration.
+
+ ARIMA (AutoRegressive Integrated Moving Average) is a forecasting method
+ that captures different aspects of time series:
+ - AR (p): Autoregression - uses past values
+ - I (d): Integration - differencing to make series stationary
+ - MA (q): Moving Average - uses past forecast errors
+ """
+
+ p: schema_field(
+ int_field(ge=0, le=10),
+ placeholder=1,
+ description="Order of autoregressive (AR) component. Number of lag "
+ "observations included in the model (how many past values to use).",
+ ) = 1 # type: ignore
+
+ d: schema_field(
+ int_field(ge=0, le=3),
+ placeholder=1,
+ description="Degree of differencing (I component). Number of times "
+ "to difference the data to make it stationary. 0=stationary, "
+ "1=first difference, 2=second difference.",
+ ) = 1 # type: ignore
+
+ q: schema_field(
+ int_field(ge=0, le=10),
+ placeholder=1,
+ description="Order of moving average (MA) component. Size of the "
+ "moving average window (how many past forecast errors to use).",
+ ) = 1 # type: ignore
+
+ trend: schema_field(
+ enum_field(enum=["n", "c", "t", "ct"]),
+ placeholder="n",
+ description=(
+ "Deterministic trend to include. 'n'=no trend, 'c'=constant "
+ "(level), 't'=linear trend, 'ct'=constant and linear trend. "
+ "Note: When d>0, 'c' is not allowed (use 't' instead). "
+ "When d>1, neither 'c' nor 't' are allowed."
+ ),
+ ) = "n" # type: ignore
+
+
+class StatsmodelsARIMAModel(ForecastingModel):
+ """Statsmodels ARIMA forecasting model wrapper for DashAI.
+
+ This model implements the ForecastingModel interface using statsmodels ARIMA.
+ It handles column name conversions internally and supports exogenous variables.
+ """
+
+ SCHEMA = StatsmodelsARIMAModelSchema
+ COMPATIBLE_COMPONENTS = ["ForecastingTask"]
+ _task_type = "ForecastingTask"
+
+ def __init__(
+ self,
+ p: int = 1,
+ d: int = 1,
+ q: int = 1,
+ trend: str = "n",
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.p = p
+ self.d = d
+ self.q = q
+ self.trend = trend
+ self.order = (p, d, q)
+
+ self.model = None
+ self.model_fit = None
+ self.frequency: Optional[str] = None
+
+ def _validate_forecasting_data(self, x: DashAIDataset, y: DashAIDataset) -> None:
+ """Validate that data is suitable for ARIMA.
+
+ Parameters
+ ----------
+ x : DashAIDataset
+ Input features (must contain a timestamp column)
+ y : DashAIDataset
+ Target values (must contain a numeric column)
+
+ Raises
+ ------
+ ValueError
+ If data is not suitable for ARIMA
+ """
+ x_cols = set(x.column_names)
+ y_cols = set(y.column_names)
+
+ if len(x_cols) == 0:
+ raise ValueError(
+ "ARIMA requires at least one input column (timestamp). "
+ "Received empty dataset."
+ )
+
+ if len(y_cols) != 1:
+ raise ValueError(
+ f"ARIMA requires exactly one target column. "
+ f"Received {len(y_cols)} columns: {list(y_cols)}"
+ )
+
+ def fit(
+ self,
+ x_train: DashAIDataset,
+ y: DashAIDataset,
+ temporal_metadata: dict = None,
+ **fit_params,
+ ) -> "StatsmodelsARIMAModel":
+ """Train ARIMA forecasting model.
+
+ Parameters
+ ----------
+ x_train : DashAIDataset
+ Input features containing timestamp and optional exogenous variables
+ y : DashAIDataset
+ Target time series (single column)
+ temporal_metadata : dict, optional
+ Metadata from ForecastingTask containing:
+ - timestamp_col: name of timestamp column
+ - target_col: name of target column
+ - exog_cols: list of exogenous variable column names
+ - frequency: time series frequency
+ **fit_params
+ Additional fitting parameters
+
+ Returns
+ -------
+ StatsmodelsARIMAModel
+ Fitted model instance
+ """
+ try:
+ from statsmodels.tsa.arima.model import ARIMA
+ except ImportError as e:
+ raise ImportError(
+ "Statsmodels is required for StatsmodelsARIMAModel. "
+ "Install with: pip install statsmodels"
+ ) from e
+
+ # Validate data format
+ self._validate_forecasting_data(x_train, y)
+
+ # Convert to pandas DataFrames
+ x_df = x_train.to_pandas()
+ y_df = y.to_pandas()
+
+ # Get column information from metadata
+ if temporal_metadata:
+ timestamp_col = temporal_metadata.get("timestamp_col")
+ target_col = temporal_metadata.get("target_col")
+ exog_cols_from_task = temporal_metadata.get("exog_cols", [])
+ frequency = temporal_metadata.get("frequency")
+
+ print("[StatsmodelsARIMAModel] Using temporal metadata from task:")
+ print(f" - Timestamp: '{timestamp_col}'")
+ print(f" - Target: '{target_col}'")
+ print(f" - Frequency: {frequency}")
+ if exog_cols_from_task:
+ print(f" - Exogenous variables: {exog_cols_from_task}")
+ else:
+ # Auto-detection if no metadata provided
+ print(
+ "[StatsmodelsARIMAModel] ⚠️ No temporal_metadata provided, "
+ "using auto-detection"
+ )
+
+ target_col = y_df.columns[0]
+
+ # Auto-detect timestamp column
+ timestamp_col = None
+ for col in x_df.columns:
+ try:
+ pd.to_datetime(x_df[col])
+ timestamp_col = col
+ print(f"[StatsmodelsARIMAModel] Detected timestamp column: '{col}'")
+ break
+ except Exception:
+ continue
+
+ if timestamp_col is None:
+ raise ValueError(
+ f"No timestamp column found in input data. "
+ f"Available columns: {list(x_df.columns)}"
+ )
+
+ exog_cols_from_task = []
+ frequency = fit_params.get("frequency")
+
+ # Store original column names
+ self.timestamp_col = timestamp_col
+ self.target_col = target_col
+ self.frequency = frequency
+
+ # Prepare data for ARIMA
+ # Create datetime index
+ dates = pd.to_datetime(x_df[timestamp_col])
+
+ # Store last training date for forecast generation
+ self.last_ds = dates.max()
+
+ # Get target series
+ target_in_inputs = target_col in x_df.columns
+ if target_in_inputs:
+ print(
+ "[StatsmodelsARIMAModel] ℹ️ Target '{}' found in inputs - "
+ "using it from there".format(target_col)
+ )
+ endog = x_df[target_col].to_numpy()
+ else:
+ endog = y_df[target_col].to_numpy()
+
+ # Create time series with datetime index
+ endog_series = pd.Series(endog, index=dates)
+
+ # Prepare exogenous variables
+ self.exog_cols = []
+ exog = None
+
+ for col in x_df.columns:
+ if col == timestamp_col:
+ continue
+ if col == target_col:
+ if target_in_inputs:
+ print(
+ "[StatsmodelsARIMAModel] ℹ️ Excluding target '{}' from "
+ "exogenous variables".format(col)
+ )
+ continue
+
+ # Only add numeric columns
+ if pd.api.types.is_numeric_dtype(x_df[col]):
+ self.exog_cols.append(col)
+ else:
+ print(
+ "[StatsmodelsARIMAModel] ⚠️ Skipping non-numeric column: '{}' "
+ "(type: {})".format(col, x_df[col].dtype)
+ )
+
+ if self.exog_cols:
+ exog = x_df[self.exog_cols].to_numpy()
+ print(f"[StatsmodelsARIMAModel] Exogenous variables: {self.exog_cols}")
+
+ print(f"[StatsmodelsARIMAModel] Training ARIMA{self.order} model")
+ print(f"[StatsmodelsARIMAModel] Training with {len(endog_series)} data points")
+ print(f"[StatsmodelsARIMAModel] Date range: {dates.min()} to {dates.max()}")
+
+ # Fit ARIMA model
+ self.model = ARIMA(
+ endog=endog_series,
+ exog=exog,
+ order=self.order,
+ trend=self.trend,
+ )
+
+ self.model_fit = self.model.fit()
+
+ print("✅ ARIMA model training completed")
+ print(f"[StatsmodelsARIMAModel] AIC: {self.model_fit.aic:.2f}")
+ print(f"[StatsmodelsARIMAModel] BIC: {self.model_fit.bic:.2f}")
+
+ return self
+
+ def predict(
+ self,
+ x_pred: Optional[Any] = None,
+ periods: Optional[int] = None,
+ exog_future: Optional[pd.DataFrame] = None,
+ **kwargs,
+ ) -> Union[np.ndarray, pd.DataFrame]:
+ """Generate forecasts using ARIMA model.
+
+ Parameters
+ ----------
+ x_pred : pd.DataFrame, optional
+ Input data for in-sample predictions containing timestamp and
+ exogenous variables (if model uses them)
+ periods : int, optional
+ Number of future periods to forecast (out-of-sample mode)
+ exog_future : pd.DataFrame, optional
+ Future values of exogenous variables for out-of-sample forecasting
+ **kwargs
+ Additional parameters
+
+ Returns
+ -------
+ np.ndarray or pd.DataFrame
+ Predictions array
+ """
+ if self.model_fit is None:
+ raise ValueError("ARIMA model is not fitted yet. Call fit() first.")
+
+ # Handle different input types
+ if x_pred is not None and isinstance(x_pred, (int, np.integer)):
+ periods = int(x_pred)
+ x_pred = None
+
+ # Out-of-sample forecasting
+ if periods is not None and x_pred is None:
+ if periods <= 0:
+ raise ValueError("Prediction horizon must be a positive integer.")
+
+ # Prepare exogenous variables for forecast
+ exog = None
+ if self.exog_cols:
+ if exog_future is None:
+ raise ValueError(
+ f"Future exogenous values required for columns: "
+ f"{self.exog_cols}."
+ )
+
+ missing_cols = [
+ col for col in self.exog_cols if col not in exog_future.columns
+ ]
+ if missing_cols:
+ raise ValueError(
+ f"Missing exogenous columns for prediction: {missing_cols}."
+ )
+
+ if len(exog_future) != periods:
+ raise ValueError(
+ f"Exogenous data length ({len(exog_future)}) must match "
+ f"prediction horizon ({periods})."
+ )
+
+ exog = exog_future[self.exog_cols].to_numpy()
+
+ # Generate forecast
+ forecast = self.model_fit.forecast(steps=periods, exog=exog)
+
+ print(f"[StatsmodelsARIMAModel] Generated forecast for {periods} periods")
+ return forecast.to_numpy()
+
+ # In-sample predictions
+ if x_pred is not None:
+ if isinstance(x_pred, pd.DataFrame):
+ input_df = x_pred.copy()
+ else:
+ input_df = to_dashai_dataset(x_pred).to_pandas()
+
+ # Auto-detect timestamp column
+ timestamp_col = None
+ for col in input_df.columns:
+ try:
+ pd.to_datetime(input_df[col])
+ timestamp_col = col
+ break
+ except Exception:
+ continue
+
+ if timestamp_col is None:
+ raise ValueError(
+ "ARIMA predict requires a timestamp column. "
+ f"Available columns: {list(input_df.columns)}"
+ )
+
+ dates = pd.to_datetime(input_df[timestamp_col])
+
+ # Prepare exogenous variables
+ exog = None
+ if self.exog_cols:
+ missing_cols = [
+ col for col in self.exog_cols if col not in input_df.columns
+ ]
+ if missing_cols:
+ raise ValueError(
+ f"Missing exogenous columns for prediction: {missing_cols}."
+ )
+ exog = input_df[self.exog_cols].to_numpy()
+
+ # Use actual dates so statsmodels predicts the correct period
+ # (works for both in-sample and out-of-sample dates)
+ predictions = self.model_fit.predict(
+ start=dates.iloc[0], end=dates.iloc[-1], exog=exog
+ )
+
+ return predictions.to_numpy()
+
+ raise ValueError(
+ "ARIMA predict requires either 'x_pred' data or a 'periods' value."
+ )
+
+ def get_forecast_uncertainty(
+ self, horizon: int, confidence_level: float = 0.80
+ ) -> pd.DataFrame:
+ """Get forecast with parametric confidence intervals from ARIMA.
+
+ Uses statsmodels ``get_forecast().summary_frame()`` to compute
+ analytical confidence intervals derived from the model's error
+ distribution. These are true parametric intervals, not estimates.
+
+ Parameters
+ ----------
+ horizon : int
+ Number of future periods to forecast.
+ confidence_level : float
+ Confidence level (e.g., 0.80 for 80% intervals).
+
+ Returns
+ -------
+ pd.DataFrame
+ Columns: ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper``.
+
+ Raises
+ ------
+ ValueError
+ If the model was trained with exogenous variables.
+ """
+ if self.model_fit is None:
+ raise ValueError("Model must be fitted before getting uncertainty.")
+
+ if self.exog_cols:
+ raise ValueError(
+ f"Cannot generate forecast uncertainty: model was trained with "
+ f"exogenous variables {self.exog_cols}. Future exogenous values "
+ f"are required but not available. "
+ f"Use ForecastFeatureImportance instead."
+ )
+
+ alpha = 1.0 - confidence_level
+ forecast_obj = self.model_fit.get_forecast(steps=horizon)
+ summary = forecast_obj.summary_frame(alpha=alpha)
+ # summary columns: mean, mean_se, mean_ci_lower, mean_ci_upper
+
+ freq = self.frequency or "D"
+ future_dates = pd.date_range(
+ start=self.last_ds, periods=horizon + 1, freq=freq
+ )[1:]
+
+ return pd.DataFrame(
+ {
+ "ds": future_dates,
+ "yhat": summary["mean"].to_numpy(),
+ "yhat_lower": summary["mean_ci_lower"].to_numpy(),
+ "yhat_upper": summary["mean_ci_upper"].to_numpy(),
+ }
+ )
+
+ def get_forecast_components(self, horizon: int) -> pd.DataFrame:
+ """Decompose forecast into trend, seasonal, and residual components.
+
+ Applies STL (Seasonal-Trend decomposition using LOESS) to the
+ combination of in-sample fitted values and out-of-sample forecast.
+ When there is insufficient data for STL, falls back to a centered
+ moving-average trend.
+
+ ARIMA does not model seasonality explicitly, so the seasonal component
+ reflects the cyclical pattern extracted from the data by STL.
+
+ Parameters
+ ----------
+ horizon : int
+ Number of future periods to forecast and decompose.
+
+ Returns
+ -------
+ pd.DataFrame
+ Columns: ``ds``, ``trend``, ````, ``residual``.
+ The seasonality column name is derived from the stored frequency
+ (e.g. ``weekly`` for daily data, ``yearly`` for monthly data).
+
+ Raises
+ ------
+ ValueError
+ If the model was trained with exogenous variables (future values
+ would be required but are unavailable here).
+ """
+ if self.model_fit is None:
+ raise ValueError("Model must be fitted before getting components.")
+
+ if self.exog_cols:
+ raise ValueError(
+ f"Cannot generate forecast components: model was trained with "
+ f"exogenous variables {self.exog_cols}. Future exogenous values "
+ f"are required but not available for decomposition. "
+ f"Use ForecastFeatureImportance instead."
+ )
+
+ try:
+ from statsmodels.tsa.seasonal import STL
+ except ImportError as exc:
+ raise ImportError("statsmodels is required for STL decomposition.") from exc
+
+ # In-sample fitted values (DatetimeIndex from training)
+ fitted = self.model_fit.fittedvalues.dropna()
+
+ # Out-of-sample forecast
+ freq = self.frequency or "D"
+ forecast_result = self.model_fit.forecast(steps=horizon)
+ future_dates = pd.date_range(
+ start=self.last_ds, periods=horizon + 1, freq=freq
+ )[1:]
+ future_series = pd.Series(forecast_result.to_numpy(), index=future_dates)
+
+ # Combine history + forecast into one series
+ combined = pd.concat([fitted, future_series])
+
+ # Determine period for STL
+ period = self._get_seasonal_period()
+ component_name = self._period_to_seasonality_name(period)
+
+ n = len(combined)
+ if period >= 2 and n >= 2 * period:
+ try:
+ stl = STL(combined, period=period, robust=True)
+ result = stl.fit()
+ trend_vals = result.trend
+ seasonal_vals = result.seasonal
+ residual_vals = result.resid
+ except Exception:
+ # Fallback to moving-average trend if STL fails
+ window = min(period, max(2, n // 2))
+ trend_vals = combined.rolling(
+ window=window, center=True, min_periods=1
+ ).mean()
+ seasonal_vals = pd.Series(np.zeros(n), index=combined.index)
+ residual_vals = combined - trend_vals
+ else:
+ # Not enough data for STL — use simple moving-average trend
+ window = max(2, min(period, n // 2))
+ trend_vals = combined.rolling(
+ window=window, center=True, min_periods=1
+ ).mean()
+ seasonal_vals = pd.Series(np.zeros(n), index=combined.index)
+ residual_vals = combined - trend_vals
+
+ # Return only the forecast horizon (the future portion)
+ return pd.DataFrame(
+ {
+ "ds": combined.index[-horizon:],
+ "trend": trend_vals.to_numpy()[-horizon:],
+ component_name: seasonal_vals.to_numpy()[-horizon:],
+ "residual": residual_vals.to_numpy()[-horizon:],
+ }
+ )
+
+ def save(self, filename: str) -> None:
+ """Save ARIMA model to file.
+
+ Parameters
+ ----------
+ filename : str
+ Path to save the model
+ """
+ if self.model_fit is None:
+ raise ValueError("Cannot save model before fitting.")
+
+ model_state = {
+ "model_fit": self.model_fit,
+ "exog_cols": self.exog_cols,
+ "timestamp_col": self.timestamp_col,
+ "target_col": self.target_col,
+ "frequency": self.frequency,
+ "last_ds": getattr(self, "last_ds", None),
+ "config": {
+ "p": self.p,
+ "d": self.d,
+ "q": self.q,
+ "trend": self.trend,
+ "order": self.order,
+ },
+ }
+
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ with open(filename, "wb") as f:
+ pickle.dump(model_state, f)
+
+ print(f"✅ ARIMA model saved to {filename}")
+
+ def load(self, filename: str) -> "StatsmodelsARIMAModel":
+ """Load ARIMA model from file.
+
+ Parameters
+ ----------
+ filename : str
+ Path to load the model from
+
+ Returns
+ -------
+ StatsmodelsARIMAModel
+ Loaded model instance
+ """
+ with open(filename, "rb") as f:
+ model_state = pickle.load(f)
+
+ self.model_fit = model_state["model_fit"]
+ self.exog_cols = model_state["exog_cols"]
+ self.timestamp_col = model_state.get("timestamp_col")
+ self.target_col = model_state.get("target_col")
+ self.frequency = model_state.get("frequency")
+ self.last_ds = model_state.get("last_ds")
+
+ config = model_state["config"]
+ for key, value in config.items():
+ setattr(self, key, value)
+
+ print(f"✅ ARIMA model loaded from {filename}")
+ return self
diff --git a/DashAI/back/models/forecasting/statsmodels_sarimax_model.py b/DashAI/back/models/forecasting/statsmodels_sarimax_model.py
new file mode 100644
index 000000000..c78749fa6
--- /dev/null
+++ b/DashAI/back/models/forecasting/statsmodels_sarimax_model.py
@@ -0,0 +1,813 @@
+"""Statsmodels SARIMAX model wrapper for DashAI forecasting.
+
+This model wraps statsmodels SARIMAX for seasonal time series forecasting with
+autoregressive integrated moving average modeling and exogenous variables.
+"""
+
+import os
+import pickle
+from typing import Any, Optional, Union
+
+import numpy as np
+import pandas as pd
+
+from DashAI.back.core.schema_fields import (
+ BaseSchema,
+ bool_field,
+ enum_field,
+ int_field,
+ schema_field,
+)
+from DashAI.back.dataloaders.classes.dashai_dataset import (
+ DashAIDataset,
+ to_dashai_dataset,
+)
+from DashAI.back.models.forecasting.base_forecasting_model import ForecastingModel
+
+
+class StatsmodelsSARIMAXModelSchema(BaseSchema):
+ """Schema for Statsmodels SARIMAX model configuration.
+
+ SARIMAX (Seasonal AutoRegressive Integrated Moving Average with eXogenous
+ regressors) extends ARIMA with seasonal components and external variables:
+ - (p,d,q): Non-seasonal AR, differencing, MA orders
+ - (P,D,Q,s): Seasonal AR, differencing, MA orders and periodicity
+ - Exogenous variables: External predictors
+ """
+
+ p: schema_field(
+ int_field(ge=0, le=10),
+ placeholder=1,
+ description="Order of non-seasonal autoregressive (AR) component. "
+ "Number of lag observations included in the model.",
+ ) = 1 # type: ignore
+
+ d: schema_field(
+ int_field(ge=0, le=3),
+ placeholder=1,
+ description="Degree of non-seasonal differencing. Number of times "
+ "to difference the data to make it stationary.",
+ ) = 1 # type: ignore
+
+ q: schema_field(
+ int_field(ge=0, le=10),
+ placeholder=1,
+ description="Order of non-seasonal moving average (MA) component. "
+ "Size of the moving average window.",
+ ) = 1 # type: ignore
+
+ P: schema_field(
+ int_field(ge=0, le=5),
+ placeholder=0,
+ description="Order of seasonal autoregressive component. "
+ "Seasonal lag observations. Set to 0 to disable seasonality.",
+ ) = 0 # type: ignore
+
+ D: schema_field(
+ int_field(ge=0, le=2),
+ placeholder=0,
+ description="Degree of seasonal differencing. Seasonal differencing order. "
+ "Set to 0 to disable seasonal differencing.",
+ ) = 0 # type: ignore
+
+ Q: schema_field(
+ int_field(ge=0, le=5),
+ placeholder=0,
+ description="Order of seasonal moving average component. "
+ "Seasonal moving average window. Set to 0 to disable.",
+ ) = 0 # type: ignore
+
+ s: schema_field(
+ int_field(ge=1, le=365),
+ placeholder=1,
+ description="Seasonal period (observations per cycle). "
+ "12=monthly, 4=quarterly, 7=weekly. Set to 1 to disable seasonality.",
+ ) = 1 # type: ignore
+
+ trend: schema_field(
+ enum_field(enum=["n", "c", "t", "ct"]),
+ placeholder="n",
+ description=(
+ "Deterministic trend to include. 'n'=no trend, 'c'=constant, "
+ "'t'=linear trend, 'ct'=constant and linear trend. "
+ "Note: When d>0 or D>0, 'c' is not allowed (use 't' instead). "
+ "When d+D>1, neither 'c' nor 't' are allowed."
+ ),
+ ) = "n" # type: ignore
+
+ enforce_stationarity: schema_field(
+ bool_field(),
+ placeholder=True,
+ description=(
+ "Whether to enforce stationarity of the autoregressive parameters."
+ ),
+ ) = True # type: ignore
+
+ enforce_invertibility: schema_field(
+ bool_field(),
+ placeholder=True,
+ description=(
+ "Whether to enforce invertibility of the moving average parameters."
+ ),
+ ) = True # type: ignore
+
+
+class StatsmodelsSARIMAXModel(ForecastingModel):
+ """Statsmodels SARIMAX forecasting model wrapper for DashAI.
+
+ This model implements the ForecastingModel interface using statsmodels SARIMAX.
+ It handles seasonal patterns, exogenous variables, and column name conversions.
+ """
+
+ SCHEMA = StatsmodelsSARIMAXModelSchema
+ COMPATIBLE_COMPONENTS = ["ForecastingTask"]
+ _task_type = "ForecastingTask"
+
+ def __init__(
+ self,
+ p: int = 1,
+ d: int = 1,
+ q: int = 1,
+ P: int = 0, # noqa: N803
+ D: int = 0, # noqa: N803
+ Q: int = 0, # noqa: N803
+ s: int = 1,
+ trend: str = "n",
+ enforce_stationarity: bool = True,
+ enforce_invertibility: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.p = p
+ self.d = d
+ self.q = q
+ self.P = P
+ self.D = D
+ self.Q = Q
+ self.s = s
+ self.trend = trend
+ self.enforce_stationarity = enforce_stationarity
+ self.enforce_invertibility = enforce_invertibility
+
+ self.order = (p, d, q)
+ self.seasonal_order = (P, D, Q, s)
+
+ self.model = None
+ self.model_fit = None
+ self.frequency: Optional[str] = None
+
+ def _validate_forecasting_data(self, x: DashAIDataset, y: DashAIDataset) -> None:
+ """Validate that data is suitable for SARIMAX.
+
+ Parameters
+ ----------
+ x : DashAIDataset
+ Input features (must contain a timestamp column)
+ y : DashAIDataset
+ Target values (must contain a numeric column)
+
+ Raises
+ ------
+ ValueError
+ If data is not suitable for SARIMAX
+ """
+ x_cols = set(x.column_names)
+ y_cols = set(y.column_names)
+
+ if len(x_cols) == 0:
+ raise ValueError(
+ "SARIMAX requires at least one input column (timestamp). "
+ "Received empty dataset."
+ )
+
+ if len(y_cols) != 1:
+ raise ValueError(
+ f"SARIMAX requires exactly one target column. "
+ f"Received {len(y_cols)} columns: {list(y_cols)}"
+ )
+
+ def fit(
+ self,
+ x_train: DashAIDataset,
+ y: DashAIDataset,
+ temporal_metadata: dict = None,
+ **fit_params,
+ ) -> "StatsmodelsSARIMAXModel":
+ """Train SARIMAX forecasting model.
+
+ Parameters
+ ----------
+ x_train : DashAIDataset
+ Input features containing timestamp and optional exogenous variables
+ y : DashAIDataset
+ Target time series (single column)
+ temporal_metadata : dict, optional
+ Metadata from ForecastingTask containing:
+ - timestamp_col: name of timestamp column
+ - target_col: name of target column
+ - exog_cols: list of exogenous variable column names
+ - frequency: time series frequency
+ **fit_params
+ Additional fitting parameters
+
+ Returns
+ -------
+ StatsmodelsSARIMAXModel
+ Fitted model instance
+ """
+ try:
+ from statsmodels.tsa.statespace.sarimax import SARIMAX
+ except ImportError as e:
+ raise ImportError(
+ "Statsmodels is required for StatsmodelsSARIMAXModel. "
+ "Install with: pip install statsmodels"
+ ) from e
+
+ # Validate data format
+ self._validate_forecasting_data(x_train, y)
+
+ # Convert to pandas DataFrames
+ x_df = x_train.to_pandas()
+ y_df = y.to_pandas()
+
+ # Get column information from metadata
+ if temporal_metadata:
+ timestamp_col = temporal_metadata.get("timestamp_col")
+ target_col = temporal_metadata.get("target_col")
+ exog_cols_from_task = temporal_metadata.get("exog_cols", [])
+ frequency = temporal_metadata.get("frequency")
+
+ print("[StatsmodelsSARIMAXModel] Using temporal metadata from task:")
+ print(f" - Timestamp: '{timestamp_col}'")
+ print(f" - Target: '{target_col}'")
+ print(f" - Frequency: {frequency}")
+ if exog_cols_from_task:
+ print(f" - Exogenous variables: {exog_cols_from_task}")
+ else:
+ # Auto-detection if no metadata provided
+ print(
+ "[StatsmodelsSARIMAXModel] ⚠️ No temporal_metadata provided, "
+ "using auto-detection"
+ )
+
+ target_col = y_df.columns[0]
+
+ # Auto-detect timestamp column
+ timestamp_col = None
+ for col in x_df.columns:
+ try:
+ pd.to_datetime(x_df[col])
+ timestamp_col = col
+ print(
+ f"[StatsmodelsSARIMAXModel] Detected timestamp column: '{col}'"
+ )
+ break
+ except Exception:
+ continue
+
+ if timestamp_col is None:
+ raise ValueError(
+ f"No timestamp column found in input data. "
+ f"Available columns: {list(x_df.columns)}"
+ )
+
+ exog_cols_from_task = []
+ frequency = fit_params.get("frequency")
+
+ # Store original column names
+ self.timestamp_col = timestamp_col
+ self.target_col = target_col
+ self.frequency = frequency
+
+ # Prepare data for SARIMAX
+ # Create datetime index
+ dates = pd.to_datetime(x_df[timestamp_col])
+
+ # Store last training date for forecast generation
+ self.last_ds = dates.max()
+
+ # Get target series
+ target_in_inputs = target_col in x_df.columns
+ if target_in_inputs:
+ print(
+ "[StatsmodelsSARIMAXModel] ℹ️ Target '{}' found in inputs - "
+ "using it from there".format(target_col)
+ )
+ endog = x_df[target_col].to_numpy()
+ else:
+ endog = y_df[target_col].to_numpy()
+
+ # Create time series with datetime index
+ endog_series = pd.Series(endog, index=dates)
+
+ # Prepare exogenous variables
+ self.exog_cols = []
+ exog = None
+
+ for col in x_df.columns:
+ if col == timestamp_col:
+ continue
+ if col == target_col:
+ if target_in_inputs:
+ print(
+ "[StatsmodelsSARIMAXModel] ℹ️ Excluding target '{}' from "
+ "exogenous variables".format(col)
+ )
+ continue
+
+ # Only add numeric columns
+ if pd.api.types.is_numeric_dtype(x_df[col]):
+ self.exog_cols.append(col)
+ else:
+ print(
+ "[StatsmodelsSARIMAXModel] ⚠️ Skipping non-numeric column: '{}' "
+ "(type: {})".format(col, x_df[col].dtype)
+ )
+
+ if self.exog_cols:
+ exog = x_df[self.exog_cols].to_numpy()
+ print(f"[StatsmodelsSARIMAXModel] Exogenous variables: {self.exog_cols}")
+
+ print(
+ f"[StatsmodelsSARIMAXModel] Training "
+ f"SARIMAX{self.order}x{self.seasonal_order} model"
+ )
+ print(
+ f"[StatsmodelsSARIMAXModel] Training with {len(endog_series)} data points"
+ )
+ print(f"[StatsmodelsSARIMAXModel] Date range: {dates.min()} to {dates.max()}")
+
+ # Auto-adjust parameters for small datasets
+ n_samples = len(endog_series)
+
+ # Check if seasonality is enabled (s>1 and any seasonal param > 0)
+ has_seasonality = self.s > 1 and (self.P > 0 or self.D > 0 or self.Q > 0)
+
+ if has_seasonality:
+ # SARIMAX needs: s + d + D + max(p, P) + max(q, Q) samples
+ min_required = (
+ self.s + self.d + self.D + max(self.p, self.P) + max(self.q, self.Q) + 2
+ )
+
+ if n_samples < min_required:
+ print(
+ f"[StatsmodelsSARIMAXModel] ⚠️ Dataset too small "
+ f"({n_samples} samples) for seasonal params "
+ f"(need {min_required}). Disabling seasonality..."
+ )
+ # Disable seasonality entirely for small datasets
+ self.P = 0
+ self.D = 0
+ self.Q = 0
+ self.s = 1
+ has_seasonality = False
+
+ # For non-seasonal ARIMA, check basic requirements
+ min_arima_samples = self.p + self.d + self.q + 3
+ if n_samples < min_arima_samples:
+ print("[StatsmodelsSARIMAXModel] ⚠️ Adjusting ARIMA orders...")
+ # Reduce AR/MA orders if needed
+ max_order = max(1, (n_samples - self.d - 2) // 2)
+ if self.p > max_order:
+ old_p = self.p
+ self.p = max(0, max_order)
+ print(f" - Reduced p (AR order): {old_p} → {self.p}")
+ if self.q > max_order:
+ old_q = self.q
+ self.q = max(0, max_order)
+ print(f" - Reduced q (MA order): {old_q} → {self.q}")
+ if self.d > 1 and n_samples < 10:
+ old_d = self.d
+ self.d = min(1, self.d)
+ print(f" - Reduced d (differencing): {old_d} → {self.d}")
+
+ self.order = (self.p, self.d, self.q)
+
+ # Set seasonal_order: None if no seasonality, otherwise tuple
+ if has_seasonality:
+ self.seasonal_order = (self.P, self.D, self.Q, self.s)
+ print(
+ f"[StatsmodelsSARIMAXModel] Final parameters: "
+ f"SARIMAX{self.order}x{self.seasonal_order}"
+ )
+ else:
+ self.seasonal_order = (0, 0, 0, 0) # Disable seasonality
+ print(
+ f"[StatsmodelsSARIMAXModel] Final parameters: "
+ f"ARIMA{self.order} (no seasonality)"
+ )
+
+ # Fit SARIMAX model (use ARIMA when no seasonality for stability)
+ try:
+ if has_seasonality:
+ self.model = SARIMAX(
+ endog=endog_series,
+ exog=exog,
+ order=self.order,
+ seasonal_order=self.seasonal_order,
+ trend=self.trend,
+ enforce_stationarity=self.enforce_stationarity,
+ enforce_invertibility=self.enforce_invertibility,
+ )
+ self.model_fit = self.model.fit()
+ else:
+ # Use ARIMA (no seasonal component) for small datasets or when s <= 1
+ # SARIMAX doesn't accept seasonal_order with s <= 1
+ from statsmodels.tsa.arima.model import ARIMA
+
+ print("[StatsmodelsSARIMAXModel] Using ARIMA (no seasonal component)")
+ self.model = ARIMA(
+ endog=endog_series,
+ exog=exog,
+ order=self.order,
+ trend=self.trend,
+ enforce_stationarity=self.enforce_stationarity,
+ enforce_invertibility=self.enforce_invertibility,
+ )
+ self.model_fit = self.model.fit()
+
+ print("✅ SARIMAX model training completed")
+ print(f"[StatsmodelsSARIMAXModel] AIC: {self.model_fit.aic:.2f}")
+ print(f"[StatsmodelsSARIMAXModel] BIC: {self.model_fit.bic:.2f}")
+ except ValueError as e:
+ # Fallback: if SARIMAX fails due to seasonality issues, try ARIMA
+ if "Seasonal periodicity" in str(e) or "seasonal" in str(e).lower():
+ print(f"[StatsmodelsSARIMAXModel] ⚠️ SARIMAX seasonality error: {e}")
+ print("[StatsmodelsSARIMAXModel] Falling back to ARIMA")
+ from statsmodels.tsa.arima.model import ARIMA
+
+ self.model = ARIMA(
+ endog=endog_series,
+ exog=exog,
+ order=self.order,
+ trend=self.trend,
+ enforce_stationarity=self.enforce_stationarity,
+ enforce_invertibility=self.enforce_invertibility,
+ )
+ self.model_fit = self.model.fit()
+ self.seasonal_order = (0, 0, 0, 0)
+ print("✅ ARIMA model training completed (fallback)")
+ print(f"[StatsmodelsSARIMAXModel] AIC: {self.model_fit.aic:.2f}")
+ print(f"[StatsmodelsSARIMAXModel] BIC: {self.model_fit.bic:.2f}")
+ else:
+ print(f"[StatsmodelsSARIMAXModel] ❌ Training failed: {e}")
+ raise
+ except Exception as e:
+ print(f"[StatsmodelsSARIMAXModel] ❌ Training failed: {e}")
+ raise
+
+ return self
+
+ def predict(
+ self,
+ x_pred: Optional[Any] = None,
+ periods: Optional[int] = None,
+ exog_future: Optional[pd.DataFrame] = None,
+ **kwargs,
+ ) -> Union[np.ndarray, pd.DataFrame]:
+ """Generate forecasts using SARIMAX model.
+
+ Parameters
+ ----------
+ x_pred : pd.DataFrame, optional
+ Input data for in-sample predictions containing timestamp and
+ exogenous variables (if model uses them)
+ periods : int, optional
+ Number of future periods to forecast (out-of-sample mode)
+ exog_future : pd.DataFrame, optional
+ Future values of exogenous variables for out-of-sample forecasting
+ **kwargs
+ Additional parameters
+
+ Returns
+ -------
+ np.ndarray or pd.DataFrame
+ Predictions array
+ """
+ if self.model_fit is None:
+ raise ValueError("SARIMAX model is not fitted yet. Call fit() first.")
+
+ # Handle different input types
+ if x_pred is not None and isinstance(x_pred, (int, np.integer)):
+ periods = int(x_pred)
+ x_pred = None
+
+ # Out-of-sample forecasting
+ if periods is not None and x_pred is None:
+ if periods <= 0:
+ raise ValueError("Prediction horizon must be a positive integer.")
+
+ # Prepare exogenous variables for forecast
+ exog = None
+ if self.exog_cols:
+ if exog_future is None:
+ raise ValueError(
+ f"Future exogenous values required for columns: "
+ f"{self.exog_cols}."
+ )
+
+ missing_cols = [
+ col for col in self.exog_cols if col not in exog_future.columns
+ ]
+ if missing_cols:
+ raise ValueError(
+ f"Missing exogenous columns for prediction: {missing_cols}."
+ )
+
+ if len(exog_future) != periods:
+ raise ValueError(
+ f"Exogenous data length ({len(exog_future)}) must match "
+ f"prediction horizon ({periods})."
+ )
+
+ exog = exog_future[self.exog_cols].to_numpy()
+
+ # Generate forecast
+ forecast = self.model_fit.forecast(steps=periods, exog=exog)
+
+ print(f"[StatsmodelsSARIMAXModel] Generated forecast for {periods} periods")
+ return forecast.to_numpy()
+
+ # In-sample predictions
+ if x_pred is not None:
+ if isinstance(x_pred, pd.DataFrame):
+ input_df = x_pred.copy()
+ else:
+ input_df = to_dashai_dataset(x_pred).to_pandas()
+
+ # Auto-detect timestamp column
+ timestamp_col = None
+ for col in input_df.columns:
+ try:
+ pd.to_datetime(input_df[col])
+ timestamp_col = col
+ break
+ except Exception:
+ continue
+
+ if timestamp_col is None:
+ raise ValueError(
+ "SARIMAX predict requires a timestamp column. "
+ f"Available columns: {list(input_df.columns)}"
+ )
+
+ dates = pd.to_datetime(input_df[timestamp_col])
+
+ # Prepare exogenous variables
+ exog = None
+ if self.exog_cols:
+ missing_cols = [
+ col for col in self.exog_cols if col not in input_df.columns
+ ]
+ if missing_cols:
+ raise ValueError(
+ f"Missing exogenous columns for prediction: {missing_cols}."
+ )
+ exog = input_df[self.exog_cols].to_numpy()
+
+ # Use actual dates so statsmodels predicts the correct period
+ # (works for both in-sample and out-of-sample dates)
+ print(
+ f"[StatsmodelsSARIMAXModel] In-sample prediction: {len(dates)} points "
+ f"({dates.min()} to {dates.max()})"
+ )
+
+ predictions = self.model_fit.predict(
+ start=dates.iloc[0], end=dates.iloc[-1], exog=exog
+ )
+
+ print(f"[StatsmodelsSARIMAXModel] Generated {len(predictions)} predictions")
+
+ return predictions.to_numpy()
+
+ raise ValueError(
+ "SARIMAX predict requires either 'x_pred' data or a 'periods' value."
+ )
+
+ def get_forecast_uncertainty(
+ self, horizon: int, confidence_level: float = 0.80
+ ) -> pd.DataFrame:
+ """Get forecast with parametric confidence intervals from SARIMAX.
+
+ Uses statsmodels ``get_forecast().summary_frame()`` to compute
+ analytical confidence intervals derived from the model's error
+ distribution. These are true parametric intervals that reflect both
+ the non-seasonal and seasonal uncertainty of the model.
+
+ Parameters
+ ----------
+ horizon : int
+ Number of future periods to forecast.
+ confidence_level : float
+ Confidence level (e.g., 0.80 for 80% intervals).
+
+ Returns
+ -------
+ pd.DataFrame
+ Columns: ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper``.
+
+ Raises
+ ------
+ ValueError
+ If the model was trained with exogenous variables.
+ """
+ if self.model_fit is None:
+ raise ValueError("Model must be fitted before getting uncertainty.")
+
+ if self.exog_cols:
+ raise ValueError(
+ f"Cannot generate forecast uncertainty: model was trained with "
+ f"exogenous variables {self.exog_cols}. Future exogenous values "
+ f"are required but not available. "
+ f"Use ForecastFeatureImportance instead."
+ )
+
+ alpha = 1.0 - confidence_level
+ forecast_obj = self.model_fit.get_forecast(steps=horizon)
+ summary = forecast_obj.summary_frame(alpha=alpha)
+ # summary columns: mean, mean_se, mean_ci_lower, mean_ci_upper
+
+ freq = self.frequency or "D"
+ future_dates = pd.date_range(
+ start=self.last_ds, periods=horizon + 1, freq=freq
+ )[1:]
+
+ return pd.DataFrame(
+ {
+ "ds": future_dates,
+ "yhat": summary["mean"].to_numpy(),
+ "yhat_lower": summary["mean_ci_lower"].to_numpy(),
+ "yhat_upper": summary["mean_ci_upper"].to_numpy(),
+ }
+ )
+
+ def get_forecast_components(self, horizon: int) -> pd.DataFrame:
+ """Decompose forecast into trend, seasonal, and residual components.
+
+ Applies STL (Seasonal-Trend decomposition using LOESS) to the
+ combination of in-sample fitted values and out-of-sample forecast.
+
+ When a seasonal order was configured (``s > 1``), the explicit
+ seasonal period ``s`` is used for STL so the decomposition reflects
+ the model's actual seasonal structure. Otherwise the period is
+ inferred from the stored frequency.
+
+ Parameters
+ ----------
+ horizon : int
+ Number of future periods to forecast and decompose.
+
+ Returns
+ -------
+ pd.DataFrame
+ Columns: ``ds``, ``trend``, ````, ``residual``.
+ The seasonality column name depends on the period (e.g. ``weekly``
+ for s=7, ``yearly`` for s=12).
+
+ Raises
+ ------
+ ValueError
+ If the model was trained with exogenous variables.
+ """
+ if self.model_fit is None:
+ raise ValueError("Model must be fitted before getting components.")
+
+ if self.exog_cols:
+ raise ValueError(
+ f"Cannot generate forecast components: model was trained with "
+ f"exogenous variables {self.exog_cols}. Future exogenous values "
+ f"are required but not available for decomposition. "
+ f"Use ForecastFeatureImportance instead."
+ )
+
+ try:
+ from statsmodels.tsa.seasonal import STL
+ except ImportError as exc:
+ raise ImportError("statsmodels is required for STL decomposition.") from exc
+
+ # In-sample fitted values (DatetimeIndex from training)
+ fitted = self.model_fit.fittedvalues.dropna()
+
+ # Out-of-sample forecast
+ freq = self.frequency or "D"
+ forecast_result = self.model_fit.forecast(steps=horizon)
+ future_dates = pd.date_range(
+ start=self.last_ds, periods=horizon + 1, freq=freq
+ )[1:]
+ future_series = pd.Series(forecast_result.to_numpy(), index=future_dates)
+
+ # Combine history + forecast
+ combined = pd.concat([fitted, future_series])
+
+ # Use explicit seasonal period s when seasonality is active;
+ # otherwise infer from frequency
+ explicit_s = (
+ self.seasonal_order[3]
+ if hasattr(self, "seasonal_order") and self.seasonal_order[3] > 1
+ else None
+ )
+ period = explicit_s if explicit_s else self._get_seasonal_period()
+ component_name = self._period_to_seasonality_name(period)
+
+ n = len(combined)
+ if period >= 2 and n >= 2 * period:
+ try:
+ stl = STL(combined, period=period, robust=True)
+ result = stl.fit()
+ trend_vals = result.trend
+ seasonal_vals = result.seasonal
+ residual_vals = result.resid
+ except Exception:
+ window = min(period, max(2, n // 2))
+ trend_vals = combined.rolling(
+ window=window, center=True, min_periods=1
+ ).mean()
+ seasonal_vals = pd.Series(np.zeros(n), index=combined.index)
+ residual_vals = combined - trend_vals
+ else:
+ window = max(2, min(period, n // 2))
+ trend_vals = combined.rolling(
+ window=window, center=True, min_periods=1
+ ).mean()
+ seasonal_vals = pd.Series(np.zeros(n), index=combined.index)
+ residual_vals = combined - trend_vals
+
+ return pd.DataFrame(
+ {
+ "ds": combined.index[-horizon:],
+ "trend": trend_vals.to_numpy()[-horizon:],
+ component_name: seasonal_vals.to_numpy()[-horizon:],
+ "residual": residual_vals.to_numpy()[-horizon:],
+ }
+ )
+
+ def save(self, filename: str) -> None:
+ """Save SARIMAX model to file.
+
+ Parameters
+ ----------
+ filename : str
+ Path to save the model
+ """
+ if self.model_fit is None:
+ raise ValueError("Cannot save model before fitting.")
+
+ model_state = {
+ "model_fit": self.model_fit,
+ "exog_cols": self.exog_cols,
+ "timestamp_col": self.timestamp_col,
+ "target_col": self.target_col,
+ "frequency": self.frequency,
+ "last_ds": getattr(self, "last_ds", None),
+ "config": {
+ "p": self.p,
+ "d": self.d,
+ "q": self.q,
+ "P": self.P,
+ "D": self.D,
+ "Q": self.Q,
+ "s": self.s,
+ "trend": self.trend,
+ "order": self.order,
+ "seasonal_order": self.seasonal_order,
+ "enforce_stationarity": self.enforce_stationarity,
+ "enforce_invertibility": self.enforce_invertibility,
+ },
+ }
+
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ with open(filename, "wb") as f:
+ pickle.dump(model_state, f)
+
+ print(f"✅ SARIMAX model saved to {filename}")
+
+ def load(self, filename: str) -> "StatsmodelsSARIMAXModel":
+ """Load SARIMAX model from file.
+
+ Parameters
+ ----------
+ filename : str
+ Path to load the model from
+
+ Returns
+ -------
+ StatsmodelsSARIMAXModel
+ Loaded model instance
+ """
+ with open(filename, "rb") as f:
+ model_state = pickle.load(f)
+
+ self.model_fit = model_state["model_fit"]
+ self.exog_cols = model_state["exog_cols"]
+ self.timestamp_col = model_state.get("timestamp_col")
+ self.target_col = model_state.get("target_col")
+ self.frequency = model_state.get("frequency")
+ self.last_ds = model_state.get("last_ds")
+
+ config = model_state["config"]
+ for key, value in config.items():
+ setattr(self, key, value)
+
+ print(f"✅ SARIMAX model loaded from {filename}")
+ return self
diff --git a/DashAI/back/models/model_factory.py b/DashAI/back/models/model_factory.py
index b9e77765f..c777c0ef2 100644
--- a/DashAI/back/models/model_factory.py
+++ b/DashAI/back/models/model_factory.py
@@ -1,9 +1,24 @@
+import math
+
from kink import di
from DashAI.back.metrics.base_metric import BaseMetric
from DashAI.back.metrics.classification_metric import ClassificationMetric
+def sanitize_metric_value(value):
+ """Convert non-JSON-serializable float values to None.
+
+ JSON doesn't support NaN, Infinity, or -Infinity, so we convert
+ these to None which becomes null in JSON.
+ """
+ if value is None:
+ return None
+ if isinstance(value, float) and (math.isnan(value) or math.isinf(value)):
+ return None
+ return value
+
+
class ModelFactory:
"""
A factory class for creating and configuring models.
@@ -257,29 +272,60 @@ def evaluate(self, x, y, metrics):
results = {}
for split in ["train", "validation", "test"]:
split_results = {}
+
+ # Skip empty splits
if x[split].shape[0] == 0:
split_results = {metric.__name__: None for metric in metrics}
results[split] = split_results
continue
- predictions = self.model.predict(x[split])
+
+ # Predict — wrapped in try/except for forecasting models
+ # where a split may be too small for the window_size
+ try:
+ predictions = self.model.predict(x[split])
+ except Exception as e:
+ print(
+ f"[ModelFactory] ⚠️ Prediction failed for {split} split: {e}. "
+ f"Setting all metrics to null."
+ )
+ for metric in metrics:
+ split_results[metric.__name__] = None
+ results[split] = split_results
+ continue
+
+ # Transform y using model's own method (upstream convention)
if hasattr(self.model, "prepare_output"):
transformed_y = self.model.prepare_output(y[split])
else:
transformed_y = self.model.prepare_dataset(y[split])
+
for metric in metrics:
- if (
- isinstance(metric, type)
- and issubclass(metric, ClassificationMetric)
- and "multiclass" in metric.score.__code__.co_varnames
- and multiclass is not None
- ):
- score = metric.score(
- transformed_y, predictions, multiclass=multiclass
- )
- else:
- score = metric.score(transformed_y, predictions)
-
- split_results[metric.__name__] = score
+ try:
+ if (
+ isinstance(metric, type)
+ and issubclass(metric, ClassificationMetric)
+ and "multiclass" in metric.score.__code__.co_varnames
+ and multiclass is not None
+ ):
+ score = metric.score(
+ transformed_y, predictions, multiclass=multiclass
+ )
+ else:
+ score = metric.score(transformed_y, predictions)
+
+ # Sanitize score to ensure JSON compatibility
+ split_results[metric.__name__] = sanitize_metric_value(score)
+ except ValueError as e:
+ # Handle case where all predictions are NaN
+ if "All values are NaN" in str(e):
+ print(
+ f"[ModelFactory] ⚠️ {split}/{metric.__name__}: "
+ f"All predictions are NaN (split too small?). "
+ f"Setting metric to null."
+ )
+ split_results[metric.__name__] = None
+ else:
+ raise
results[split] = split_results
diff --git a/DashAI/back/models/parameters/models_schemas/MultiOutputRegression.json b/DashAI/back/models/parameters/models_schemas/MultiOutputRegression.json
new file mode 100644
index 000000000..731ae6f03
--- /dev/null
+++ b/DashAI/back/models/parameters/models_schemas/MultiOutputRegression.json
@@ -0,0 +1,49 @@
+{
+ "additionalProperties": false,
+ "error_msg": "The parameters for MultiOutputRegression must include one or more of ['fit_intercept', 'copy_X', 'n_jobs', 'positive'].",
+ "description": "MultiOutputRegression trains an independent regressor for each output. By default, it uses LinearRegression for each output.",
+ "properties": {
+ "fit_intercept": {
+ "oneOf": [
+ {
+ "error_msg": "The 'fit_intercept' parameter must be of type boolean.",
+ "description": "Determines whether to calculate the intercept for this model.",
+ "type": "boolean",
+ "default": true
+ }
+ ]
+ },
+ "copy_X": {
+ "oneOf": [
+ {
+ "error_msg": "The 'copy_X' parameter must be of type boolean.",
+ "description": "Determines whether to copy the input matrix X.",
+ "type": "boolean",
+ "default": true
+ }
+ ]
+ },
+ "n_jobs": {
+ "oneOf": [
+ {
+ "error_msg": "The 'n_jobs' parameter must be an integer or null.",
+ "description": "The number of jobs to use for computation. -1 means using all processors.",
+ "type": ["integer", "null"],
+ "default": null,
+ "minimum": -1
+ }
+ ]
+ },
+ "positive": {
+ "oneOf": [
+ {
+ "error_msg": "The 'positive' parameter must be of type boolean.",
+ "description": "When set to True, forces the coefficients to be positive.",
+ "type": "boolean",
+ "default": false
+ }
+ ]
+ }
+ },
+ "type": "object"
+}
diff --git a/DashAI/back/models/scikit_learn/__init__.py b/DashAI/back/models/scikit_learn/__init__.py
index e69de29bb..9c358d4ab 100644
--- a/DashAI/back/models/scikit_learn/__init__.py
+++ b/DashAI/back/models/scikit_learn/__init__.py
@@ -0,0 +1,7 @@
+"""Scikit-learn based models."""
+
+from .multi_output_regression import MultiOutputRegression
+
+__all__ = [
+ "MultiOutputRegression",
+]
diff --git a/DashAI/back/models/scikit_learn/multi_output_regression.py b/DashAI/back/models/scikit_learn/multi_output_regression.py
new file mode 100644
index 000000000..72ac850a1
--- /dev/null
+++ b/DashAI/back/models/scikit_learn/multi_output_regression.py
@@ -0,0 +1,154 @@
+"""
+MultiOutput regression model for DashAI.
+
+This model is a wrapper around sklearn.multioutput.MultiOutputRegressor.
+By default it uses LinearRegression as base estimator but you can select
+other sklearn regressors by passing the `base_estimator` parameter and,
+optionally, `base_params` (a dict with kwargs for the base estimator).
+"""
+
+from typing import Any, Dict, Optional
+
+from sklearn.ensemble import RandomForestRegressor as _RandomForestRegressor
+from sklearn.linear_model import LinearRegression as _LinearRegression
+from sklearn.linear_model import Ridge as _Ridge
+from sklearn.multioutput import MultiOutputRegressor
+
+from DashAI.back.core.schema_fields import (
+ BaseSchema,
+ enum_field,
+ schema_field,
+)
+from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset
+from DashAI.back.models.regression_model import RegressionModel
+from DashAI.back.models.scikit_learn.sklearn_like_regressor import SklearnLikeRegressor
+
+
+class MultiOutputRegressionSchema(BaseSchema):
+ """Multi-output regression using sklearn's MultiOutputRegressor.
+
+ This meta-estimator fits one regressor per target variable, allowing you to
+ predict multiple continuous outputs simultaneously. Choose from different base
+ estimators depending on your needs: linear models for interpretability,
+ tree-based models for non-linear relationships.
+ """
+
+ base_estimator: schema_field(
+ enum_field(enum=["linear", "ridge", "random_forest"]),
+ placeholder="linear",
+ description="Base estimator to use for each output target. "
+ "'linear': Fast linear regression (no regularization). "
+ "'ridge': Linear regression with L2 regularization (prevents overfitting). "
+ "'random_forest': Tree-based ensemble (handles non-linear relationships).",
+ ) = "linear" # type: ignore
+
+
+class MultiOutputRegression(RegressionModel, SklearnLikeRegressor):
+ """Meta-model using sklearn's MultiOutputRegressor."""
+
+ SCHEMA = MultiOutputRegressionSchema
+
+ COMPATIBLE_COMPONENTS = ["RegressionTask"]
+
+ def __init__(
+ self,
+ base_estimator: str = "linear",
+ base_params: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Parameters
+ ----------
+ base_estimator : str
+ Identifier of the base estimator. Supported: "linear", "ridge",
+ "random_forest"
+ base_params : dict, optional
+ Keyword args to forward to the base estimator constructor.
+ kwargs : dict
+ Extra args (kept for compatibility with existing infrastructure).
+ """
+ super().__init__(**kwargs)
+
+ if base_params is None:
+ base_params = {}
+
+ # Map string identifiers to sklearn estimators (you can extend with more).
+ estimators = {
+ "linear": _LinearRegression,
+ "ridge": _Ridge,
+ "random_forest": _RandomForestRegressor,
+ }
+
+ if base_estimator not in estimators:
+ raise ValueError(
+ f"Unknown base_estimator '{base_estimator}'. "
+ f"Supported: {list(estimators.keys())}"
+ )
+
+ base_cls = estimators[base_estimator]
+ base_instance = base_cls(**base_params)
+
+ # The actual sklearn model we will fit/predict with
+ self.sklearn_model = MultiOutputRegressor(base_instance)
+
+ # If SklearnLikeRegressor expects certain attributes/methods, adapt accordingly.
+ # We implement fit/predict here to be explicit.
+
+ def fit(self, x_train: DashAIDataset, y_train: DashAIDataset, **fit_params):
+ """
+ Fit the multioutput regressor.
+ x_train: DashAIDataset with input features
+ y_train: DashAIDataset with output targets
+ """
+ import numpy as np
+
+ # CRITICAL: Convert DashAI datasets to pandas first
+ x_pandas = x_train.to_pandas()
+ y_pandas = y_train.to_pandas()
+
+ # Convert pandas to numpy arrays
+ X = np.asarray(x_pandas)
+ y = np.asarray(y_pandas)
+
+ # KEY FIX: Ensure y is 2D for MultiOutputRegressor
+ # sklearn's MultiOutputRegressor requires y to have at least 2 dimensions
+ if y.ndim == 1:
+ print(
+ f"[MultiOutputRegression] Converting 1D y (shape {y.shape}) "
+ f"to 2D for multi-output regression"
+ )
+ y = y.reshape(-1, 1)
+
+ print(
+ f"[MultiOutputRegression] Training with X shape: {X.shape}, "
+ f"y shape: {y.shape}"
+ )
+ print(f"[MultiOutputRegression] X columns: {list(x_pandas.columns)}")
+ print(f"[MultiOutputRegression] y columns: {list(y_pandas.columns)}")
+
+ # Now this will work with both 1D and 2D y arrays
+ self.sklearn_model.fit(X, y, **fit_params)
+ return self
+
+ def predict(self, x_pred: DashAIDataset):
+ """
+ Predict multi-output targets.
+ x_pred: DashAIDataset with input features
+ Returns array shape (n_samples, n_outputs)
+ """
+ import numpy as np
+
+ # CRITICAL: Convert DashAI dataset to pandas first (same as fit method)
+ x_pandas = x_pred.to_pandas()
+
+ # Convert pandas to numpy array
+ X = np.asarray(x_pandas)
+
+ print(f"[MultiOutputRegression] Predicting with X shape: {X.shape}")
+
+ # Now this will work with clean numpy array
+ return self.sklearn_model.predict(X)
+
+ # If DashAI base classes expect `save` and `load`, SklearnLikeRegressor
+ # if not, you should rely on the SklearnLikeRegressor implementations. If necessary,
+ # override save/load following the project's conventions.
diff --git a/DashAI/back/optimizers/base_optimizer.py b/DashAI/back/optimizers/base_optimizer.py
index 7390c791b..15582d391 100644
--- a/DashAI/back/optimizers/base_optimizer.py
+++ b/DashAI/back/optimizers/base_optimizer.py
@@ -357,7 +357,7 @@ def importance_plot(self, trials, goal_metric):
)
sorted_items = sorted(importances.items(), key=lambda item: item[1])
- param_names, importance_values = zip(*sorted_items)
+ param_names, importance_values = zip(*sorted_items, strict=False)
fig = go.Figure(
data=[
go.Bar(
diff --git a/DashAI/back/tasks/forecasting_task.py b/DashAI/back/tasks/forecasting_task.py
new file mode 100644
index 000000000..3b000a819
--- /dev/null
+++ b/DashAI/back/tasks/forecasting_task.py
@@ -0,0 +1,597 @@
+"""Forecasting Task for time series prediction in DashAI.
+
+This task enables native time series forecasting with models like Prophet,
+as well as tabular approaches using TimeSeriesWindowConverter.
+"""
+
+from typing import Any, Dict, List, Optional, Union
+
+import pandas as pd
+from datasets import DatasetDict, Value
+
+from DashAI.back.dataloaders.classes.dashai_dataset import (
+ DashAIDataset,
+ to_dashai_dataset,
+)
+from DashAI.back.tasks.base_task import BaseTask
+
+
+class ForecastingTask(BaseTask):
+ """Task for time series forecasting.
+
+ This task handles two main forecasting approaches:
+ 1. Native forecasting (Prophet, ARIMA): Uses ds (datetime) + y + optional exogenous
+ 2. Tabular forecasting: Uses TimeSeriesWindowConverter + regression models
+
+ Key differences from RegressionTask:
+ - Requires temporal column (ds) and proper time ordering
+ - Uses causal splits (no shuffle) to respect temporal causality
+ - Supports forecasting-specific metrics (MAPE, sMAPE, MASE)
+ - Native models can predict variable horizons
+ """
+
+ DESCRIPTION: str = """
+ Time series forecasting predicts future values based on historical patterns.
+ Supports both native forecasting models (Prophet) that work directly with
+ timestamps and target values, and tabular approaches that convert time series
+ into supervised learning problems using lag features and future windows.
+ """
+
+ metadata = {
+ "inputs_types": [Value], # ds (datetime) + optional exogenous variables
+ "outputs_types": [Value], # y (target time series)
+ "inputs_cardinality": "n", # ds + optional exogenous features
+ "outputs_cardinality": 1, # Single target series
+ }
+
+ def __init__(self):
+ """Initialize ForecastingTask."""
+ super().__init__()
+ self._temporal_metadata: Optional[Dict[str, Any]] = None
+
+ def validate_dataset_for_task(
+ self,
+ dataset: DashAIDataset,
+ dataset_name: str,
+ input_columns: List[str],
+ output_columns: List[str],
+ ) -> None:
+ """Validate a dataset for forecasting task."""
+
+ print("\n🔍 VALIDATE_DATASET_FOR_TASK INICIO")
+ print(f"📄 Dataset: {dataset_name}")
+ print(f"📥 Input columns: {input_columns}")
+ print(f"📤 Output columns: {output_columns}")
+
+ metadata = self.metadata
+ allowed_input_types = tuple(metadata["inputs_types"])
+ allowed_output_types = tuple(metadata["outputs_types"])
+
+ # 🔍 DEBUG: Print full metadata
+ print("\n📐 Metadata:")
+ print(f" - allowed_input_types: {allowed_input_types}")
+ print(f" - allowed_output_types: {allowed_output_types}")
+ print(f" - input_cardinality: {metadata.get('inputs_cardinality')}")
+ print(f" - output_cardinality: {metadata.get('outputs_cardinality')}")
+
+ # Validate cardinality
+ if len(input_columns) < 1:
+ raise ValueError(
+ "ForecastingTask requires at least 1 input column.\n"
+ "Include a timestamp and optional exogenous variables."
+ )
+
+ if len(output_columns) != 1:
+ raise ValueError(
+ "ForecastingTask requires exactly 1 output column "
+ f"(target to forecast). Got: {len(output_columns)} outputs."
+ )
+
+ dataset_df = dataset.to_pandas()
+ if not isinstance(dataset_df, pd.DataFrame):
+ dataset_df = pd.concat(dataset_df, ignore_index=True)
+
+ # 🔬 Revisar tipos de columnas en dataset.features
+ print("\n🧪 DEBUG: Column types from dataset.features")
+ for col_name, col_type in dataset.features.items():
+ print(f" - {col_name}: {col_type} ({type(col_type)})")
+
+ # Validate all input columns exist and have correct types
+ timestamp_found = False
+ detected_timestamp = None
+
+ for input_col in input_columns:
+ if input_col not in dataset.features:
+ raise ValueError(
+ f"Input column '{input_col}' not found in dataset. "
+ f"Available columns: {list(dataset.features.keys())}"
+ )
+
+ input_col_type = dataset.features[input_col]
+
+ # Print individual type check
+ print(
+ f"🔍 Checking input '{input_col}' type: {input_col_type} "
+ f"({type(input_col_type)})"
+ )
+
+ if not isinstance(input_col_type, allowed_input_types):
+ print("❌ Input column type mismatch")
+ raise TypeError(
+ f"Input column '{input_col}' has type "
+ f"{type(input_col_type).__name__}, but expected one of: "
+ f"{allowed_input_types}."
+ )
+
+ # Try to detect if it's the timestamp
+ if not timestamp_found:
+ try:
+ pd.to_datetime(dataset_df[input_col])
+ timestamp_found = True
+ detected_timestamp = input_col
+ print(f"✅ Detected timestamp column: '{input_col}'")
+ except Exception:
+ pass
+
+ if not timestamp_found:
+ raise ValueError(
+ "No timestamp column detected in input columns. "
+ "ForecastingTask requires a datetime column for temporal ordering. "
+ f"Checked columns: {input_columns}"
+ )
+
+ # OUTPUT VALIDATION
+ output_col = output_columns[0]
+ if output_col not in dataset.features:
+ raise ValueError(
+ f"Output column '{output_col}' not found in dataset. "
+ f"Available: {list(dataset.features.keys())}"
+ )
+
+ output_col_type = dataset.features[output_col]
+ print(
+ f"\n🔍 Checking output '{output_col}' type: {output_col_type} "
+ f"({type(output_col_type)})"
+ )
+
+ if not isinstance(output_col_type, allowed_output_types):
+ print("❌ Output column type mismatch")
+ raise TypeError(
+ f"Output column '{output_col}' has type "
+ f"{type(output_col_type).__name__}, but expected one of: "
+ f"{allowed_output_types}."
+ )
+
+ # Validate target column
+ try:
+ pd.to_numeric(dataset_df[output_col])
+ print(f"✅ Target column '{output_col}' is numeric")
+ except Exception as e:
+ raise TypeError(
+ f"Output column '{output_col}' cannot be converted to numeric: {e}"
+ ) from e
+
+ # Check minimum data points
+ if len(dataset) < 5:
+ raise ValueError(
+ f"Dataset '{dataset_name}' has only {len(dataset)} rows. "
+ "Minimum 5 rows required for forecasting."
+ )
+
+ # ✅ VALIDATION PASSED
+ print("✅ ForecastingTask validation PASSED:")
+ print(f" - Inputs: {input_columns} (timestamp: {detected_timestamp})")
+ print(f" - Output: {output_col}")
+ print(f" - Total rows: {len(dataset)}\n")
+
+ @property
+ def schema(self) -> Dict[str, Any]:
+ """Get the schema for ForecastingTask."""
+ return {
+ "type": "object",
+ "properties": {
+ "timestamp_column": {
+ "type": "string",
+ "description": (
+ "Name of the datetime column (will be renamed to 'ds')"
+ ),
+ },
+ "target_column": {
+ "type": "string",
+ "description": (
+ "Name of the target time series column (will be renamed to 'y')"
+ ),
+ },
+ "exogenous_columns": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": (
+ "Optional exogenous variables (holidays, weather, etc.)"
+ ),
+ "default": [],
+ },
+ "frequency": {
+ "type": "string",
+ "description": (
+ "Time series frequency (D, H, M, etc.). "
+ "Auto-detected if not specified"
+ ),
+ "default": "auto",
+ },
+ },
+ "required": ["timestamp_column", "target_column"],
+ }
+
+ def validate_temporal_data(
+ self,
+ dataset: DashAIDataset,
+ timestamp_col: str,
+ target_col: str,
+ exog_cols: Optional[List[str]] = None,
+ ) -> None:
+ """Validate that the dataset is suitable for forecasting.
+
+ Parameters
+ ----------
+ dataset : DashAIDataset
+ Dataset to validate
+ timestamp_col : str
+ Name of timestamp column
+ target_col : str
+ Name of target column
+ exog_cols : Optional[List[str]]
+ Names of exogenous columns
+
+ Raises
+ ------
+ ValueError
+ If dataset is not suitable for forecasting
+ """
+ if exog_cols is None:
+ exog_cols = []
+
+ # Check required columns exist
+ available_cols = set(dataset.column_names)
+
+ if timestamp_col not in available_cols:
+ raise ValueError(
+ f"Timestamp column '{timestamp_col}' not found in dataset. "
+ f"Available columns: {list(available_cols)}"
+ )
+
+ if target_col not in available_cols:
+ raise ValueError(
+ f"Target column '{target_col}' not found in dataset. "
+ f"Available columns: {list(available_cols)}"
+ )
+
+ missing_exog = set(exog_cols) - available_cols
+ if missing_exog:
+ raise ValueError(
+ f"Exogenous columns not found: {list(missing_exog)}. "
+ f"Available columns: {list(available_cols)}"
+ )
+
+ # Convert to pandas for validation
+ dataset_df = dataset.to_pandas() # type: ignore
+ if not isinstance(dataset_df, pd.DataFrame):
+ dataset_df = pd.concat(dataset_df, ignore_index=True)
+
+ # Validate timestamp column can be converted to datetime
+ try:
+ timestamp_series = pd.to_datetime(dataset_df[timestamp_col])
+ except Exception as e:
+ raise ValueError(
+ f"Cannot convert timestamp column '{timestamp_col}' to datetime: {e}"
+ ) from e
+
+ # Check for duplicate timestamps
+ if timestamp_series.duplicated().any():
+ duplicates = timestamp_series[timestamp_series.duplicated()].unique()
+ raise ValueError(
+ f"Found duplicate timestamps in '{timestamp_col}': "
+ f"{duplicates[:5].tolist()}{'...' if len(duplicates) > 5 else ''}"
+ )
+
+ # Validate target is numeric
+ try:
+ target_series = pd.to_numeric(dataset_df[target_col])
+ except Exception as e:
+ raise ValueError(
+ f"Target column '{target_col}' must be numeric: {e}"
+ ) from e
+
+ # Check for too many missing values in target
+ missing_pct = target_series.isna().mean()
+ if missing_pct > 0.5:
+ raise ValueError(
+ f"Target column '{target_col}' has {missing_pct:.1%} missing values. "
+ "Maximum allowed is 50%."
+ )
+
+ # Minimum data points check
+ if len(dataset_df) < 5:
+ raise ValueError(
+ f"Dataset has only {len(dataset_df)} rows. "
+ "Minimum 5 data points required for forecasting."
+ )
+
+ print(
+ f"✅ Validation passed: {len(dataset_df)} data points, "
+ f"timestamp range: {timestamp_series.min()} to {timestamp_series.max()}"
+ )
+
+ def detect_frequency(self, timestamp_series: pd.Series) -> str:
+ """Auto-detect time series frequency.
+
+ Parameters
+ ----------
+ timestamp_series : pd.Series
+ Datetime series
+
+ Returns
+ -------
+ str
+ Detected frequency code (D, H, M, etc.)
+ """
+ try:
+ # Sort timestamps and calculate differences
+ sorted_ts = timestamp_series.sort_values()
+ diffs = sorted_ts.diff().dropna()
+
+ # Get most common difference
+ mode_diff = (
+ diffs.mode().iloc[0] if len(diffs.mode()) > 0 else diffs.median()
+ )
+
+ # Map to pandas frequency codes
+ if mode_diff >= pd.Timedelta(days=365): # type: ignore
+ return "A" # Annual
+ elif mode_diff >= pd.Timedelta(days=30): # type: ignore
+ return "M" # Monthly
+ elif mode_diff >= pd.Timedelta(days=7): # type: ignore
+ return "W" # Weekly
+ elif mode_diff >= pd.Timedelta(days=1): # type: ignore
+ return "D" # Daily
+ elif mode_diff >= pd.Timedelta(hours=1): # type: ignore
+ return "H" # Hourly
+ else:
+ return "T" # Minute
+
+ except Exception:
+ # Fallback to daily
+ return "D"
+
+ def detect_timestamp_column(
+ self, dataset: DashAIDataset, candidate_columns: List[str]
+ ) -> Optional[str]:
+ """Auto-detect which column is the timestamp from a list of candidates.
+
+ Parameters
+ ----------
+ dataset : DashAIDataset
+ Dataset to analyze
+ candidate_columns : List[str]
+ List of column names to check
+
+ Returns
+ -------
+ Optional[str]
+ Name of detected timestamp column, or None if not found
+ """
+ # Convert to pandas for analysis
+ dataset_df = dataset.to_pandas() # type: ignore
+ if not isinstance(dataset_df, pd.DataFrame):
+ dataset_df = pd.concat(dataset_df, ignore_index=True)
+
+ # Strategy 1: Check by column name
+ for col in candidate_columns:
+ col_lower = col.lower()
+ if any(
+ keyword in col_lower
+ for keyword in [
+ "date",
+ "time",
+ "timestamp",
+ "ds",
+ "datetime",
+ "fecha",
+ ]
+ ):
+ # Verify it can be converted to datetime
+ try:
+ pd.to_datetime(dataset_df[col])
+ return col
+ except Exception:
+ continue
+
+ # Strategy 2: Try to convert each column to datetime
+ for col in candidate_columns:
+ try:
+ pd.to_datetime(dataset_df[col])
+ return col
+ except Exception:
+ continue
+
+ return None
+
+ def prepare_for_task(
+ self,
+ dataset: Optional[Union[DatasetDict, DashAIDataset]] = None,
+ outputs_columns: Optional[List[str]] = None,
+ inputs_columns: Optional[List[str]] = None,
+ **kwargs,
+ ) -> DashAIDataset:
+ """Prepare dataset for forecasting task.
+
+ Cambios mínimos:
+ - Acepta `datasetdict` (alias que usa experiments.py).
+ - Si no vienen `inputs_columns` ni `timestamp_column`, intenta
+ detectar el timestamp usando todos los nombres de columnas.
+ """
+ # --- Soporte para alias `datasetdict` usado por experiments.py ---
+ if dataset is None and "datasetdict" in kwargs:
+ dataset = kwargs.pop("datasetdict")
+ if inputs_columns is None and "input_columns" in kwargs:
+ inputs_columns = kwargs.pop("input_columns")
+ if outputs_columns is None and "output_columns" in kwargs:
+ outputs_columns = kwargs.pop("output_columns")
+
+ # Convertir a DashAIDataset si viene como DatasetDict
+ if isinstance(dataset, DatasetDict):
+ split_name = "train" if "train" in dataset else list(dataset.keys())[0]
+ dashai_dataset = to_dashai_dataset(dataset[split_name])
+ elif dataset is not None:
+ dashai_dataset = dataset
+ else:
+ raise ValueError("dataset parameter is required for prepare_for_task")
+
+ # Validaciones básicas de parámetros
+ if not outputs_columns or len(outputs_columns) != 1:
+ raise ValueError(
+ "ForecastingTask requires exactly 1 output column (target variable). "
+ f"Got {len(outputs_columns) if outputs_columns else 0} columns."
+ )
+ target_col = outputs_columns[0]
+
+ # Obtener o detectar columna timestamp
+ timestamp_col = kwargs.get("timestamp_column")
+ if not timestamp_col:
+ # Si no nos dan inputs_columns, intentamos con TODAS las columnas
+ candidate_inputs = (
+ inputs_columns if inputs_columns else list(dashai_dataset.column_names)
+ )
+ timestamp_col = self.detect_timestamp_column(
+ dashai_dataset, candidate_inputs
+ )
+ if not timestamp_col:
+ raise ValueError(
+ "Could not auto-detect timestamp column. "
+ "Provide `timestamp_column` o incluya una columna con fecha/tiempo "
+ "('date', 'timestamp', 'ds', 'datetime', etc.)."
+ )
+ print(f"🔍 Auto-detected timestamp column: '{timestamp_col}'")
+
+ # Exógenas: si no vienen inputs, por defecto ninguna
+ if inputs_columns:
+ exog_cols = [c for c in inputs_columns if c != timestamp_col]
+ else:
+ exog_cols = kwargs.get("exogenous_columns", [])
+
+ frequency = kwargs.get("frequency", "auto")
+
+ # Validar datos
+ self.validate_temporal_data(
+ dashai_dataset, timestamp_col, target_col, exog_cols
+ )
+
+ # Procesamiento pandas
+ dataset_df = dashai_dataset.to_pandas() # type: ignore
+ if not isinstance(dataset_df, pd.DataFrame):
+ dataset_df = pd.concat(dataset_df, ignore_index=True)
+
+ # NO renombrar columnas - mantener nombres originales
+ # El modelo (ej: Prophet) hará el renombramiento si lo necesita
+
+ # Orden temporal
+ # If the timestamp column is numeric (int/float), treat values as
+ # sequential time-step indices rather than nanosecond epoch offsets.
+ if pd.api.types.is_integer_dtype(
+ dataset_df[timestamp_col]
+ ) or pd.api.types.is_float_dtype(dataset_df[timestamp_col]):
+ base_date = pd.Timestamp("2000-01-01")
+ step_vals = dataset_df[timestamp_col]
+ min_val = step_vals.min()
+ dataset_df[timestamp_col] = base_date + pd.to_timedelta(
+ (step_vals - min_val).astype(int), unit="D"
+ )
+ print(
+ f"ℹ️ Column '{timestamp_col}' contains numeric values — "
+ f"converted to day offsets starting from {base_date.date()}"
+ )
+ else:
+ dataset_df[timestamp_col] = pd.to_datetime(dataset_df[timestamp_col])
+ dataset_df = dataset_df.sort_values(timestamp_col).reset_index(drop=True)
+
+ # Frecuencia
+ if frequency == "auto":
+ frequency = self.detect_frequency(dataset_df[timestamp_col])
+
+ # Guardar metadatos con nombres ORIGINALES
+ self._temporal_metadata = {
+ "timestamp_col": timestamp_col,
+ "target_col": target_col,
+ "exog_cols": exog_cols,
+ "frequency": frequency,
+ "start_date": dataset_df[timestamp_col].min(),
+ "end_date": dataset_df[timestamp_col].max(),
+ "n_periods": len(dataset_df),
+ }
+
+ print("✅ Prepared forecasting dataset:")
+ print(f" - Timestamp: {timestamp_col}")
+ print(f" - Target: {target_col}")
+ print(f" - Frequency: {frequency}")
+ print(f" - Periods: {len(dataset_df)}")
+ if exog_cols:
+ print(f" - Exogenous vars: {', '.join(exog_cols)}")
+
+ # Volver a DashAIDataset
+ from datasets import Dataset
+
+ hf_dataset = Dataset.from_pandas(dataset_df)
+ return to_dashai_dataset(hf_dataset)
+
+ def process_predictions(
+ self, dataset: DashAIDataset, predictions: Any, target_column: str
+ ) -> Any:
+ """Process forecasting predictions.
+
+ For forecasting, predictions can be:
+ - Simple array of values (point forecasts)
+ - DataFrame with ds, yhat, yhat_lower, yhat_upper (Prophet style)
+ - Dictionary with forecasts and confidence intervals
+
+ Parameters
+ ----------
+ dataset : DashAIDataset
+ Original dataset
+ predictions : Any
+ Model predictions
+ target_column : str
+ Target column name
+
+ Returns
+ -------
+ Any
+ Processed predictions
+ """
+ # If predictions is a DataFrame (Prophet style), extract yhat
+ if hasattr(predictions, "yhat"):
+ return predictions["yhat"].to_numpy()
+
+ # If it's already an array, return as-is
+ if hasattr(predictions, "shape"):
+ return predictions
+
+ # Handle list/tuple
+ if isinstance(predictions, (list, tuple)):
+ import numpy as np
+
+ return np.array(predictions)
+
+ return predictions
+
+ def num_labels(self, dataset: DashAIDataset, output_column: str) -> None:
+ """Return None — forecasting predicts continuous values, not discrete labels."""
+ return None
+
+ def get_temporal_metadata(self) -> Optional[Dict[str, Any]]:
+ """Get temporal metadata from the last prepare_for_task call.
+
+ Returns
+ -------
+ Optional[Dict[str, Any]]
+ Temporal metadata including frequency, date range, etc.
+ """
+ return self._temporal_metadata
diff --git a/DashAI/back/tasks/multi_output_regression_task.py b/DashAI/back/tasks/multi_output_regression_task.py
new file mode 100644
index 000000000..005cffda3
--- /dev/null
+++ b/DashAI/back/tasks/multi_output_regression_task.py
@@ -0,0 +1,87 @@
+from typing import List
+
+from datasets import DatasetDict, Value
+
+from DashAI.back.dataloaders.classes.dashai_dataset import (
+ DashAIDataset,
+ to_dashai_dataset,
+)
+from DashAI.back.tasks.base_task import BaseTask
+
+
+class MultiOutputRegressionTask(BaseTask):
+ """Task for handling multi-output regression problems.
+
+ Multi-output regression involves predicting multiple continuous outputs
+ for each input sample. This task sets up the necessary metadata and
+ processing functions to support training models that generate multiple
+ outputs per sample.
+ """
+
+ DESCRIPTION: str = """
+ Multi-output regression extends standard regression by predicting more
+ than one continuous value per input instance. Each output dimension is
+ treated as a separate regression target, and models can be trained to
+ jointly predict all outputs, capturing correlations between them.
+ """
+
+ metadata = {
+ "inputs_types": [Value],
+ "outputs_types": [Value],
+ "inputs_cardinality": "n",
+ "outputs_cardinality": "n",
+ }
+
+ def prepare_for_task(
+ self, datasetdict: DatasetDict, outputs_columns: List[str]
+ ) -> DashAIDataset:
+ """Change the column types to suit the multi-output regression task.
+
+ Parameters
+ ----------
+ datasetdict : DatasetDict
+ Dataset to be changed
+ outputs_columns : List[str]
+ Output columns for the task
+
+ Returns
+ -------
+ DashAIDataset
+ Dataset with the new types
+ """
+ return to_dashai_dataset(datasetdict)
+
+ def process_predictions(self, dataset, predictions, output_column):
+ """
+ Process predictions for multi-output regression.
+
+ For multi-output regression, we return the predictions as-is since they
+ are already in the correct format (n_samples, n_outputs) from sklearn.
+
+ Parameters
+ ----------
+ dataset : DashAIDataset
+ The original dataset.
+ predictions : np.ndarray
+ Array 2D with predictions. Shape: (n_samples, n_outputs)
+ output_column : str
+ Not used directly for multi-output regression.
+
+ Returns
+ -------
+ np.ndarray
+ The predictions array as-is for compatibility with DashAI
+ prediction pipeline.
+ """
+ # For multi-output, predictions are already in correct format
+ # Shape should be (n_samples, n_outputs)
+ print(
+ f"[MultiOutputRegressionTask] Processing predictions with "
+ f"shape: {predictions.shape}"
+ )
+
+ # Ensure predictions are 2D (which they should be from MultiOutputRegressor)
+ if predictions.ndim == 1:
+ predictions = predictions.reshape(-1, 1)
+
+ return predictions
diff --git a/DashAI/front/src/api/datasets.ts b/DashAI/front/src/api/datasets.ts
index 281ac1627..a0b78c601 100644
--- a/DashAI/front/src/api/datasets.ts
+++ b/DashAI/front/src/api/datasets.ts
@@ -101,6 +101,27 @@ export const deleteDataset = async (id: string): Promise