diff --git a/.gitignore b/.gitignore index 1263f60e0..fa6c98c39 100644 --- a/.gitignore +++ b/.gitignore @@ -188,3 +188,6 @@ job_queue.db-wal db.sqlite trained_models/ + + +test_forecasting_models.ipynb diff --git a/DashAI/back/api/api_v1/endpoints/datasets.py b/DashAI/back/api/api_v1/endpoints/datasets.py index 562bea733..60677d462 100644 --- a/DashAI/back/api/api_v1/endpoints/datasets.py +++ b/DashAI/back/api/api_v1/endpoints/datasets.py @@ -630,6 +630,189 @@ async def get_info( return info +@router.get("/{dataset_id}/temporal-info") +@inject +async def get_temporal_info( + dataset_id: int, + timestamp_column: str, + session_factory: "sessionmaker" = Depends(lambda: di["session_factory"]), +): + """Get temporal information about a dataset for forecasting tasks. + + This endpoint analyzes a timestamp column to detect frequency, date range, + and other temporal characteristics useful for time series forecasting. + + Parameters + ---------- + dataset_id : int + ID of the dataset to analyze. + timestamp_column : str + Name of the column containing timestamps. + + Returns + ------- + dict + Dictionary with temporal information including: + - frequency_code: Short code (D, H, M, W, A, T) + - frequency_label: Human-readable label + - frequency_description: Detailed description + - start_date: First timestamp in the series + - end_date: Last timestamp in the series + - total_periods: Number of data points + - detected_gaps: Number of missing periods detected + """ + import os + + import pandas as pd + import pyarrow.ipc as ipc + + from DashAI.back.core.enums.status import DatasetStatus + + with session_factory() as db: + try: + dataset = db.get(Dataset, dataset_id) + if not dataset: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Dataset not found", + ) + + if dataset.status != DatasetStatus.FINISHED: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Dataset is not in finished state", + ) + + # Load the dataset + dataset_path = f"{dataset.file_path}/dataset" + data_filepath = os.path.join(dataset_path, "data.arrow") + + with pa.OSFile(data_filepath, "rb") as source: + reader = ipc.open_file(source) + table = reader.read_all() + + data_frame = table.to_pandas() + + if timestamp_column not in data_frame.columns: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Column '{timestamp_column}' not found in dataset", + ) + + # Convert to datetime + try: + timestamps = pd.to_datetime(data_frame[timestamp_column]) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Cannot parse '{timestamp_column}' as datetime: {str(e)}", + ) from e + + # Sort and analyze + sorted_ts = timestamps.sort_values() + diffs = sorted_ts.diff().dropna() + + if len(diffs) == 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Not enough data points to detect frequency", + ) + + # Get most common difference (mode) + mode_diff = ( + diffs.mode().iloc[0] if len(diffs.mode()) > 0 else diffs.median() + ) + + # Frequency mapping with detailed info + frequency_map = { + "T": { + "code": "T", + "label": "Minutely", + "description": "Each row represents one minute", + "example": "e.g., 10:00, 10:01, 10:02...", + }, + "H": { + "code": "H", + "label": "Hourly", + "description": "Each row represents one hour", + "example": "e.g., 10:00, 11:00, 12:00...", + }, + "D": { + "code": "D", + "label": "Daily", + "description": "Each row represents one day", + "example": "e.g., Jan 1, Jan 2, Jan 3...", + }, + "W": { + "code": "W", + "label": "Weekly", + "description": "Each row represents one week", + "example": "e.g., Week 1, Week 2, Week 3...", + }, + "M": { + "code": "M", + "label": "Monthly", + "description": "Each row represents one month", + "example": "e.g., Jan, Feb, Mar...", + }, + "A": { + "code": "A", + "label": "Yearly", + "description": "Each row represents one year", + "example": "e.g., 2022, 2023, 2024...", + }, + } + + # Detect frequency + if mode_diff >= pd.Timedelta(days=365): + freq_code = "A" + elif mode_diff >= pd.Timedelta(days=28): + freq_code = "M" + elif mode_diff >= pd.Timedelta(days=7): + freq_code = "W" + elif mode_diff >= pd.Timedelta(days=1): + freq_code = "D" + elif mode_diff >= pd.Timedelta(hours=1): + freq_code = "H" + else: + freq_code = "T" + + freq_info = frequency_map[freq_code] + + # Calculate average difference in human-readable format + avg_diff = diffs.mean() + if avg_diff >= pd.Timedelta(days=1): + avg_diff_str = f"{avg_diff.days} days" + elif avg_diff >= pd.Timedelta(hours=1): + avg_diff_str = f"{avg_diff.seconds // 3600} hours" + else: + avg_diff_str = f"{avg_diff.seconds // 60} minutes" + + # Detect gaps (periods where diff is significantly larger than mode) + gap_threshold = mode_diff * 1.5 + gaps = (diffs > gap_threshold).sum() + + return { + "frequency_code": freq_info["code"], + "frequency_label": freq_info["label"], + "frequency_description": freq_info["description"], + "frequency_example": freq_info["example"], + "average_interval": avg_diff_str, + "start_date": sorted_ts.min().isoformat(), + "end_date": sorted_ts.max().isoformat(), + "total_periods": len(data_frame), + "detected_gaps": int(gaps), + "timestamp_column": timestamp_column, + } + + except exc.SQLAlchemyError as e: + logger.exception(e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal database error", + ) from e + + @router.get("/{dataset_id}/model-sessions-exist") @inject async def get_model_sessions_exist( @@ -1264,6 +1447,9 @@ async def get_dataset_file( col: sliced_batch[col][j].as_py() for col in sliced_batch.schema.names } + # Use jsonable_encoder to handle Timestamp and other + # non-JSON-serializable types + row = jsonable_encoder(row) rows.append(row) rows_collected += 1 if rows_collected >= page_size: diff --git a/DashAI/back/api/api_v1/endpoints/explainers.py b/DashAI/back/api/api_v1/endpoints/explainers.py index a430ef45d..5ace40601 100755 --- a/DashAI/back/api/api_v1/endpoints/explainers.py +++ b/DashAI/back/api/api_v1/endpoints/explainers.py @@ -241,8 +241,24 @@ async def upload_global_explainer( status_code=status.HTTP_404_NOT_FOUND, detail="Run not found" ) + # Check if name exists and append suffix if needed + base_name = params.name + counter = 1 + new_name = base_name + + while True: + existing = db.scalars( + select(GlobalExplainer).where(GlobalExplainer.name == new_name) + ).first() + + if not existing: + break + + counter += 1 + new_name = f"{base_name}_{counter}" + explainer = GlobalExplainer( - name=params.name, + name=new_name, run_id=params.run_id, explainer_name=params.explainer_name, parameters=params.parameters, diff --git a/DashAI/back/api/api_v1/endpoints/jobs.py b/DashAI/back/api/api_v1/endpoints/jobs.py index 715a24783..e39676a1b 100644 --- a/DashAI/back/api/api_v1/endpoints/jobs.py +++ b/DashAI/back/api/api_v1/endpoints/jobs.py @@ -164,6 +164,8 @@ async def get_job_details( @router.post("/", status_code=status.HTTP_201_CREATED) +@router.post("/start", status_code=status.HTTP_201_CREATED) +@router.post("/start/", status_code=status.HTTP_201_CREATED) @inject async def enqueue_job( request: Request, diff --git a/DashAI/back/api/api_v1/endpoints/model_sessions.py b/DashAI/back/api/api_v1/endpoints/model_sessions.py index ae48f8fdb..6bc3b7d80 100644 --- a/DashAI/back/api/api_v1/endpoints/model_sessions.py +++ b/DashAI/back/api/api_v1/endpoints/model_sessions.py @@ -157,6 +157,15 @@ async def validate_columns( validation_response = {} try: + # For ForecastingTask, validate BEFORE prepare_for_task + # (column names change after prepare, so validate with originals) + if params.task_name == "ForecastingTask": + task.validate_dataset_for_task( + dataset=minimal_dataset, + dataset_name=dataset.name, + input_columns=inputs_names, + output_columns=outputs_names, + ) task.prepare_for_task( dataset=minimal_dataset, input_columns=inputs_names, diff --git a/DashAI/back/api/api_v1/endpoints/predict.py b/DashAI/back/api/api_v1/endpoints/predict.py index 188a14b02..b7aafc473 100644 --- a/DashAI/back/api/api_v1/endpoints/predict.py +++ b/DashAI/back/api/api_v1/endpoints/predict.py @@ -1,5 +1,6 @@ import json import logging +import math from typing import TYPE_CHECKING, Dict, List from fastapi import APIRouter, Depends, Query, Request, status @@ -21,6 +22,20 @@ from DashAI.back.dependencies.registry.component_registry import ComponentRegistry + +def sanitize_for_json(value): + """Convert NaN/Inf float values to None for JSON serialization.""" + if isinstance(value, dict): + return {k: sanitize_for_json(v) for k, v in value.items()} + elif isinstance(value, list): + return [sanitize_for_json(item) for item in value] + elif isinstance(value, float): + if math.isnan(value) or math.isinf(value): + return None + return value + return value + + logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) diff --git a/DashAI/back/api/api_v1/endpoints/runs.py b/DashAI/back/api/api_v1/endpoints/runs.py index 58ecc42e4..24202e1b4 100644 --- a/DashAI/back/api/api_v1/endpoints/runs.py +++ b/DashAI/back/api/api_v1/endpoints/runs.py @@ -327,22 +327,62 @@ async def delete_run( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Run not found" ) + + # Delete orphan-prone operations that are not linked by FK constraints. + global_explainers = ( + db.query(GlobalExplainer).filter(GlobalExplainer.run_id == run_id).all() + ) + for explainer in global_explainers: + if explainer.plot_path and os.path.exists(explainer.plot_path): + remove_path(explainer.plot_path) + if explainer.explanation_path and os.path.exists( + explainer.explanation_path + ): + remove_path(explainer.explanation_path) + db.delete(explainer) + + local_explainers = ( + db.query(LocalExplainer).filter(LocalExplainer.run_id == run_id).all() + ) + for explainer in local_explainers: + if explainer.plots_path and os.path.exists(explainer.plots_path): + remove_path(explainer.plots_path) + if explainer.explanation_path and os.path.exists( + explainer.explanation_path + ): + remove_path(explainer.explanation_path) + db.delete(explainer) + + for prediction in run.predictions: + if prediction.results_path and os.path.exists(prediction.results_path): + remove_path(prediction.results_path) + + for path in [ + run.run_path, + run.plot_history_path, + run.plot_slice_path, + run.plot_contour_path, + run.plot_importance_path, + ]: + if path and os.path.exists(path): + remove_path(path) + db.delete(run) - if run.status == RunStatus.FINISHED: - os.remove(run.run_path) db.commit() return Response(status_code=status.HTTP_204_NO_CONTENT) + except HTTPException: + raise except exc.SQLAlchemyError as e: log.exception(e) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Internal database error", + detail=f"Internal database error: {e}", ) from e - except OSError as e: + except (OSError, ValueError) as e: log.exception(e) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to delete directory", + detail=f"Failed to delete run resources: {e}", ) from e diff --git a/DashAI/back/api/api_v1/schemas/job_params.py b/DashAI/back/api/api_v1/schemas/job_params.py index 869df132e..a9944922f 100644 --- a/DashAI/back/api/api_v1/schemas/job_params.py +++ b/DashAI/back/api/api_v1/schemas/job_params.py @@ -8,6 +8,7 @@ class JobParams(BaseModel): job_type: Literal[ "ModelJob", + "ForecastingJob", "ExplainerJob", "PredictJob", "DatasetJob", diff --git a/DashAI/back/api/api_v1/schemas/predict_params.py b/DashAI/back/api/api_v1/schemas/predict_params.py index 92c1a78ac..4c8bc9ad9 100644 --- a/DashAI/back/api/api_v1/schemas/predict_params.py +++ b/DashAI/back/api/api_v1/schemas/predict_params.py @@ -1,8 +1,18 @@ -from pydantic import BaseModel +from typing import Optional + +from pydantic import BaseModel, Field class PredictParams(BaseModel): run_id: int + forecast_periods: Optional[int] = Field( + default=None, + description="Number of future periods to forecast (ForecastingTask only). " + "If provided, timestamps will be generated automatically from the last " + "training date.", + gt=0, + le=1000, + ) class RenameRequest(BaseModel): diff --git a/DashAI/back/converters/simple_converters/extend_time_series_converter.py b/DashAI/back/converters/simple_converters/extend_time_series_converter.py new file mode 100644 index 000000000..25c5b8956 --- /dev/null +++ b/DashAI/back/converters/simple_converters/extend_time_series_converter.py @@ -0,0 +1,466 @@ +""" +Extend Time Series Converter for DashAI. + +This converter extends a time series dataset by adding n future timestamps +with the same period as the original dataset. This is useful for preparing +datasets for forecasting predictions. +""" + +from typing import Union + +import pandas as pd + +from DashAI.back.converters.base_converter import BaseConverter +from DashAI.back.core.schema_fields import ( + int_field, + schema_field, + string_field, +) +from DashAI.back.core.schema_fields.base_schema import BaseSchema +from DashAI.back.dataloaders.classes.dashai_dataset import ( + DashAIDataset, + to_dashai_dataset, +) + + +class ExtendTimeSeriesConverterSchema(BaseSchema): + """Schema for ExtendTimeSeriesConverter parameters.""" + + n_steps: schema_field( + int_field(ge=1, le=100000), + 1, + "Number of future time steps to add to the dataset (max: 100,000).", + ) # type: ignore + + time_column: schema_field( + string_field(), + "", + ( + "Name of the timestamp column to extend. " + "If empty, the converter will auto-detect datetime columns." + ), + ) # type: ignore + + +class ExtendTimeSeriesConverter(BaseConverter): + """ + Converter that extends a time series dataset with future timestamps. + + This converter adds n new rows to the dataset with timestamps that continue + the sequence from the last timestamp in the dataset. The frequency/period + is automatically inferred from the existing timestamps. + + All columns except the timestamp column will be filled with NaN values + in the new rows, as these are future values to be predicted. + + Example: + -------- + Original dataset: + date | y | exog1 + 2024-01-01 | 10.5 | 100 + 2024-01-02 | 11.2 | 105 + 2024-01-03 | 12.1 | 110 + + After extending with n_steps=2: + date | y | exog1 + 2024-01-01 | 10.5 | 100 + 2024-01-02 | 11.2 | 105 + 2024-01-03 | 12.1 | 110 + 2024-01-04 | NaN | NaN + 2024-01-05 | NaN | NaN + """ + + SCHEMA = ExtendTimeSeriesConverterSchema + DESCRIPTION = ( + "Extends a time series dataset by adding n future timestamps with the same " + "period as the original data. Other columns are filled with NaN values. " + "This is useful for preparing datasets for forecasting predictions." + ) + SHORT_DESCRIPTION = "Extends time series with n future timestamps for forecasting." + DISPLAY_NAME = "Extend Time Series Converter" + + # Maximum allowed n_steps to prevent memory issues + MAX_N_STEPS = 100000 + + def __init__(self, n_steps: int = 1, time_column: str = ""): + """Initialize the converter with schema parameters.""" + super().__init__() + self.n_steps = n_steps + self.time_column = time_column + + # Internal state + self._fitted = False + self._time_column_validated = "" + self._inferred_freq = None + + def _detect_datetime_columns(self, df: pd.DataFrame) -> list[str]: + """ + Detect columns with datetime or timestamp data types. + + Parameters + ---------- + df : pd.DataFrame + DataFrame to analyze + + Returns + ------- + list[str] + List of column names with datetime/timestamp types + """ + datetime_columns = [] + + for col in df.columns: + # Check if column dtype is datetime + if pd.api.types.is_datetime64_any_dtype(df[col]): + datetime_columns.append(col) + # Try to parse as datetime if it's object/string type + elif df[col].dtype == "object": + try: + # Try to convert a sample to datetime + pd.to_datetime(df[col].dropna().head(10), errors="raise") + datetime_columns.append(col) + except (ValueError, TypeError): + # Not a datetime column + pass + + return datetime_columns + + def _infer_frequency(self, time_series: pd.Series) -> pd.DateOffset: + """ + Infer the frequency/period of a datetime series. + + Parameters + ---------- + time_series : pd.Series + Series with datetime values + + Returns + ------- + pd.DateOffset + The inferred frequency + + Raises + ------ + ValueError + If frequency cannot be inferred + """ + # Ensure the series is sorted + time_series = time_series.sort_values().reset_index(drop=True) + + # Convert to datetime if not already + if not pd.api.types.is_datetime64_any_dtype(time_series): + time_series = pd.to_datetime(time_series) + + # Remove NaT values + time_series = time_series.dropna() + + # Need at least 2 points to infer frequency + if len(time_series) < 2: + raise ValueError( + "Need at least 2 timestamps to infer frequency. " + f"Found {len(time_series)} timestamps." + ) + + # Check for duplicate timestamps + duplicates = time_series.duplicated() + if duplicates.any(): + n_duplicates = duplicates.sum() + # Warning: we'll still try to infer, but user should know + import warnings + + warnings.warn( + f"Found {n_duplicates} duplicate timestamp(s) in the time series. " + "This may affect frequency inference.", + UserWarning, + stacklevel=2, + ) + # Remove duplicates for frequency inference + time_series = time_series.drop_duplicates() + + # Try using pandas infer_freq on unique sorted values + freq = pd.infer_freq(time_series) + if freq is not None: + return pd.tseries.frequencies.to_offset(freq) + + # If infer_freq fails, calculate the most common difference + diffs = time_series.diff().dropna() + + if len(diffs) == 0: + raise ValueError("Cannot infer frequency: no time differences found") + + # Filter out zero differences (duplicates that weren't caught) + diffs = diffs[diffs != pd.Timedelta(0)] + + if len(diffs) == 0: + raise ValueError( + "Cannot infer frequency: all timestamps are identical " + "after removing duplicates" + ) + + # Get the most common difference + most_common_diff = diffs.mode() + + if len(most_common_diff) == 0: + raise ValueError("Cannot infer frequency: no consistent time difference") + + # Check if the frequency is reasonably consistent + # If there's high variance, warn the user + diff_std = diffs.std() + diff_mean = diffs.mean() + + if diff_std / diff_mean > 0.5: # More than 50% coefficient of variation + import warnings + + warnings.warn( + f"Timestamps have irregular intervals " + f"(std/mean = {diff_std / diff_mean:.2f}). " + f"Using most common difference: {most_common_diff.iloc[0]}. " + "Results may not be accurate for irregular time series.", + UserWarning, + stacklevel=2, + ) + + # Return the most common difference as a Timedelta + return most_common_diff.iloc[0] + + def fit( + self, x: DashAIDataset, y: Union[DashAIDataset, None] = None + ) -> "ExtendTimeSeriesConverter": + """ + Fit the converter by validating parameters and detecting time column. + + Parameters + ---------- + x : DashAIDataset + Input dataset containing the time series data + y : DashAIDataset, optional + Not used in this converter + + Returns + ------- + ExtendTimeSeriesConverter + The fitted converter instance + + Raises + ------ + ValueError + If validation fails (missing time column, invalid parameters, etc.) + """ + # Validate parameters + if self.n_steps < 1: + raise ValueError("n_steps must be a positive integer") + + if self.n_steps > self.MAX_N_STEPS: + raise ValueError( + f"n_steps cannot exceed {self.MAX_N_STEPS} to prevent memory issues. " + f"Requested: {self.n_steps}" + ) + + # Convert to pandas for analysis + data_frame: pd.DataFrame = x.to_pandas() # type: ignore + + # Validate dataset is not empty + if len(data_frame) == 0: + raise ValueError( + "Cannot extend an empty dataset. " + "Please provide a dataset with at least 2 rows." + ) + + # Detect datetime columns + datetime_columns = self._detect_datetime_columns(data_frame) + + if len(datetime_columns) == 0: + raise ValueError( + "No datetime columns found in the dataset. " + "Please ensure your dataset has at least one timestamp column." + ) + + # Determine which time column to use + if self.time_column: + # User specified a time column + if self.time_column not in data_frame.columns: + raise ValueError( + f"Specified time column '{self.time_column}' not found in dataset. " + f"Available columns: {list(data_frame.columns)}" + ) + + if self.time_column not in datetime_columns: + # Try to convert it to datetime + try: + data_frame[self.time_column] = pd.to_datetime( + data_frame[self.time_column] + ) + self._time_column_validated = self.time_column + except (ValueError, TypeError) as e: + raise ValueError( + f"Column '{self.time_column}' cannot be converted " + f"to datetime: {e}" + ) from e + else: + self._time_column_validated = self.time_column + else: + # Auto-detect time column + if len(datetime_columns) > 1: + raise ValueError( + f"Multiple datetime columns found: {datetime_columns}. " + "Please specify which one to use with the 'time_column' parameter." + ) + self._time_column_validated = datetime_columns[0] + + # Infer the frequency + time_series = data_frame[self._time_column_validated] + + # Convert to datetime if needed + if not pd.api.types.is_datetime64_any_dtype(time_series): + time_series = pd.to_datetime(time_series) + + try: + self._inferred_freq = self._infer_frequency(time_series) + except ValueError as e: + raise ValueError( + f"Failed to infer frequency for time column " + f"'{self._time_column_validated}': {e}" + ) from e + + self._fitted = True + return self + + def transform( + self, x: DashAIDataset, y: Union[DashAIDataset, None] = None + ) -> DashAIDataset: + """ + Transform the dataset by adding n future timestamps. + + Parameters + ---------- + x : DashAIDataset + Input dataset to transform + y : DashAIDataset, optional + Not used in this converter + + Returns + ------- + DashAIDataset + Extended dataset with n additional rows containing future timestamps + + Raises + ------ + ValueError + If converter is not fitted or transformation fails + """ + if not self._fitted: + raise ValueError("Converter must be fitted before transform") + + # Convert to pandas + data_frame: pd.DataFrame = x.to_pandas() # type: ignore + + # Verify time column still exists + if self._time_column_validated not in data_frame.columns: + raise ValueError( + f"Time column '{self._time_column_validated}' not found " + f"in transform dataset" + ) + + # Convert time column to datetime if needed + if not pd.api.types.is_datetime64_any_dtype( + data_frame[self._time_column_validated] + ): + data_frame[self._time_column_validated] = pd.to_datetime( + data_frame[self._time_column_validated] + ) + + # Get the last timestamp + last_timestamp = data_frame[self._time_column_validated].max() + + # Validate last_timestamp is not NaT + if pd.isna(last_timestamp): + raise ValueError( + f"Cannot extend time series: all timestamps in column " + f"'{self._time_column_validated}' are NaT (Not a Time)" + ) + + # Generate future timestamps + future_timestamps = [] + current_timestamp = last_timestamp + + try: + for _i in range(self.n_steps): + current_timestamp = current_timestamp + self._inferred_freq + future_timestamps.append(current_timestamp) + except (OverflowError, ValueError) as e: + raise ValueError( + f"Error generating future timestamp at step " + f"{_i + 1}/{self.n_steps}: {e}. " + "This might be due to timestamp overflow or invalid frequency." + ) from e + + # Create new rows with future timestamps + future_rows = [] + for future_ts in future_timestamps: + # Create a row with NaN for all columns except timestamp + new_row = dict.fromkeys(data_frame.columns) + new_row[self._time_column_validated] = future_ts + future_rows.append(new_row) + + # Create DataFrame from future rows + future_df = pd.DataFrame(future_rows) + + # Ensure the timestamp column has the same dtype + future_df[self._time_column_validated] = pd.to_datetime( + future_df[self._time_column_validated] + ) + + # Align column order with original dataframe + future_df = future_df[data_frame.columns] + + # Preserve original data types as much as possible + # for non-timestamp columns + for col in data_frame.columns: + if col != self._time_column_validated: + # Try to maintain the original dtype + # (will be nullable version due to NaN) + try: + original_dtype = data_frame[col].dtype + # For numeric types, pandas will handle + # the NaN conversion automatically + if pd.api.types.is_numeric_dtype(original_dtype): + # Let pandas handle it naturally + # (int -> float for NaN compatibility) + pass + elif pd.api.types.is_datetime64_any_dtype(original_dtype): + future_df[col] = pd.to_datetime(future_df[col]) + except Exception: + # If conversion fails, keep as is + # (likely already None/NaN) + pass + + # Concatenate original data with future data + try: + extended_df = pd.concat([data_frame, future_df], ignore_index=True) + except Exception as e: + raise ValueError( + f"Error concatenating original and extended data: {e}. " + "This might be due to incompatible data types." + ) from e + + # Validate the extended dataframe + if len(extended_df) != len(data_frame) + self.n_steps: + raise ValueError( + f"Extended dataset has unexpected number of rows. " + f"Expected: {len(data_frame) + self.n_steps}, " + f"Got: {len(extended_df)}" + ) + + # Convert back to DashAIDataset + return to_dashai_dataset(extended_df) + + def changes_row_count(self) -> bool: + """ + Indicates that this converter changes the number of rows. + + Returns + ------- + bool + True, as new rows with future timestamps are added + """ + return True diff --git a/DashAI/back/converters/simple_converters/time_series_window_converter.py b/DashAI/back/converters/simple_converters/time_series_window_converter.py new file mode 100644 index 000000000..3a1dfee40 --- /dev/null +++ b/DashAI/back/converters/simple_converters/time_series_window_converter.py @@ -0,0 +1,233 @@ +""" +Time Series Window Converter for DashAI. + +This converter transforms time series data into a tabular regression format +by creating lag features and target columns with fixed horizons. +""" + +from typing import Union + +import pandas as pd + +from DashAI.back.converters.base_converter import BaseConverter +from DashAI.back.core.schema_fields import ( + int_field, + schema_field, + string_field, +) +from DashAI.back.core.schema_fields.base_schema import BaseSchema +from DashAI.back.dataloaders.classes.dashai_dataset import ( + DashAIDataset, + to_dashai_dataset, +) + + +class TimeSeriesWindowConverterSchema(BaseSchema): + """Schema for TimeSeriesWindowConverter parameters.""" + + window_size: schema_field( + int_field(ge=1), + 7, + "Number of past time steps to use as lag features (window size).", + ) # type: ignore + + horizon: schema_field( + int_field(ge=1), + 1, + "Number of time steps into the future to predict (forecasting horizon).", + ) # type: ignore + + target_column: schema_field( + string_field(), + "", + "Name of the target column containing the time series values to forecast.", + ) # type: ignore + + +class TimeSeriesWindowConverter(BaseConverter): + """ + Converter that transforms time series data into a regression problem. + + This converter creates lag features (lag_1, lag_2, ..., lag_w) from a time series + and a target column shifted h steps into the future (y_target_h), where: + - w is the window_size parameter + - h is the horizon parameter + + The resulting dataset can be used with standard regression models to perform + forecasting as a supervised learning problem. + + Example: + -------- + Original time series: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + With window_size=3 and horizon=1: + + lag_3 lag_2 lag_1 y_target_1 + 1 2 3 4 + 2 3 4 5 + 3 4 5 6 + 4 5 6 7 + 5 6 7 8 + 6 7 8 9 + 7 8 9 10 + """ + + SCHEMA = TimeSeriesWindowConverterSchema + DESCRIPTION = ( + "Transforms time series data into a tabular regression format by creating " + "lag features from past values and a target column shifted into the future. " + "This enables forecasting using standard regression models." + ) + SHORT_DESCRIPTION = ( + "Converts time series to regression with lag features and future targets." + ) + DISPLAY_NAME = "Time Series Window Converter" + + def __init__(self, window_size: int = 7, horizon: int = 1, target_column: str = ""): + """Initialize the converter with schema parameters.""" + super().__init__() + self.window_size = window_size + self.horizon = horizon + self.target_column = target_column + + # Internal state + self._fitted = False + self._target_column_validated = "" + + def fit( + self, x: DashAIDataset, y: Union[DashAIDataset, None] = None + ) -> "TimeSeriesWindowConverter": + """ + Fit the converter by validating parameters and target column. + + Parameters + ---------- + x : DashAIDataset + Input dataset containing the time series data + y : DashAIDataset, optional + Not used in this converter + + Returns + ------- + TimeSeriesWindowConverter + The fitted converter instance + + Raises + ------ + ValueError + If validation fails (missing target column, invalid parameters, etc.) + """ + # Validate parameters + if self.window_size < 1: + raise ValueError("window_size must be a positive integer") + + if self.horizon < 1: + raise ValueError("horizon must be a positive integer") + + if not self.target_column: + raise ValueError("target_column must be a non-empty string") + + # Check if target column exists in dataset + if self.target_column not in x.column_names: + raise ValueError( + f"Target column '{self.target_column}' not found in dataset. " + f"Available columns: {x.column_names}" + ) + + # Validate that we have enough data points + min_required_rows = self.window_size + self.horizon + if len(x) < min_required_rows: + raise ValueError( + f"Dataset has {len(x)} rows but needs at least " + f"{min_required_rows} rows (window_size={self.window_size} + " + f"horizon={self.horizon})" + ) + + # Store validated target column name + self._target_column_validated = self.target_column + self._fitted = True + + return self + + def transform( + self, x: DashAIDataset, y: Union[DashAIDataset, None] = None + ) -> DashAIDataset: + """ + Transform the dataset by creating lag features and target column. + + Parameters + ---------- + x : DashAIDataset + Input dataset to transform + y : DashAIDataset, optional + Not used in this converter + + Returns + ------- + DashAIDataset + Transformed dataset with lag features and target column + + Raises + ------ + ValueError + If converter is not fitted or transformation fails + """ + if not self._fitted: + raise ValueError("Converter must be fitted before transform") + + # Convert to pandas for easier manipulation + data_frame = x.to_pandas() + + # Verify target column still exists + if self._target_column_validated not in data_frame.columns: + raise ValueError( + f"Target column '{self._target_column_validated}' not found " + f"in transform dataset" + ) + + # Create a copy to avoid modifying the original + result_df = pd.DataFrame() + + # Create lag features (lag_1, lag_2, ..., lag_w) + target_series = data_frame[self._target_column_validated] + + for lag in range(1, self.window_size + 1): + lag_column_name = f"lag_{lag}" + result_df[lag_column_name] = target_series.shift(lag) + + # Create multiple target columns (y_target_1 to y_target_horizon) + for h in range(1, self.horizon + 1): + target_column_name = f"y_target_{h}" + result_df[target_column_name] = target_series.shift(-h) + + # Include any other columns that are not the target column + # This preserves potential date columns or other features + other_columns = [ + col for col in data_frame.columns if col != self._target_column_validated + ] + for col in other_columns: + result_df[col] = data_frame[col] + + # Remove rows with NaN values (caused by shifting) + # These occur at the beginning (due to lag) and end (due to future target) + result_df = result_df.dropna() + + # Validate that we still have data after removing NaN rows + if len(result_df) == 0: + raise ValueError( + "No valid rows remain after creating lag features and target column. " + "Try reducing window_size or horizon, or use a larger dataset." + ) + + # Convert back to DashAIDataset + return to_dashai_dataset(result_df) + + def changes_row_count(self) -> bool: + """ + Indicates that this converter changes the number of rows. + + Returns + ------- + bool + True, as rows with NaN values are removed + """ + return True diff --git a/DashAI/back/dataloaders/classes/dashai_dataset.py b/DashAI/back/dataloaders/classes/dashai_dataset.py index 320e952e6..4052feb27 100644 --- a/DashAI/back/dataloaders/classes/dashai_dataset.py +++ b/DashAI/back/dataloaders/classes/dashai_dataset.py @@ -3,9 +3,12 @@ import logging import os +import numpy as np +import pandas as pd +import pyarrow as pa from beartype import beartype from beartype.typing import Dict, List, Literal, Optional, Tuple, Union -from datasets import Dataset +from datasets import Dataset, DatasetDict from DashAI.back.types.categorical import Categorical from DashAI.back.types.dashai_data_type import DashAIDataType @@ -523,7 +526,6 @@ def sample( Dict A dictionary with selected samples. """ - import numpy as np if n > len(self): raise ValueError( @@ -697,7 +699,6 @@ def transform_dataset_with_schema( DashAIDataset - The updated dataset with new type information """ - import pyarrow as pa # local import table = get_arrow_table(dataset) dai_table = {} @@ -789,8 +790,6 @@ def save_dataset( import json - import pyarrow as pa # local import - table = get_arrow_table(dataset) data_filepath = os.path.join(path, "data.arrow") with pa.OSFile(data_filepath, "wb") as sink: @@ -829,8 +828,6 @@ def load_dataset(dataset_path: Union[str, os.PathLike]) -> DashAIDataset: import json - import pyarrow as pa # local import - data_filepath = os.path.join(dataset_path, "data.arrow") with pa.OSFile(data_filepath, "rb") as source: reader = pa.ipc.open_file(source) @@ -960,7 +957,6 @@ def split_indexes( # Generate shuffled indexes if seed is None: seed = 42 - import numpy as np from sklearn.model_selection import train_test_split indexes = np.arange(total_rows) @@ -1045,7 +1041,6 @@ def split_dataset( ValueError Must provide all indexes or none. """ - import numpy as np if all(idx is None for idx in [train_indexes, test_indexes, val_indexes]): from datasets import DatasetDict @@ -1279,8 +1274,6 @@ def get_columns_spec(dataset_path: str) -> Dict[str, Dict]: """ data_filepath = os.path.join(dataset_path, "data.arrow") - import pyarrow as pa # local import - with pa.OSFile(data_filepath, "rb") as source: reader = pa.ipc.open_file(source) schema = reader.schema @@ -1520,7 +1513,6 @@ def modify_table( DashAIDataset The modified dataset with the updated column type. """ - import pyarrow as pa original_table = dataset.arrow_table updated_columns = {} @@ -1554,3 +1546,236 @@ def modify_table( new_types = types if types else dataset.types return DashAIDataset(new_table, splits=dataset.splits, types=new_types) + + +@beartype +def split_dataset_temporal( + dataset: DashAIDataset, + train_size: Union[int, float] = 0.7, + val_size: Union[int, float] = 0.15, + test_size: Union[int, float] = 0.15, + gap: int = 0, + timestamp_col: str = "ds", + min_train_size: int = 50, + min_val_size: int = 10, + min_test_size: int = 10, +) -> DatasetDict: + """Time-aware data splitting for forecasting tasks. + + Unlike random splitting, this maintains temporal order: + - Training data comes first chronologically + - Validation data follows training data + - Test data comes last + - No data leakage from future to past + + Parameters + ---------- + dataset : DashAIDataset + Dataset to split (must be sorted by timestamp) + train_size : Union[int, float] + Size of training set. If float, interpreted as proportion. + val_size : Union[int, float] + Size of validation set. If float, interpreted as proportion. + test_size : Union[int, float] + Size of test set. If float, interpreted as proportion. + gap : int + Number of periods to skip between splits to avoid data leakage. + timestamp_col : str + Name of timestamp column for ordering + min_train_size : int + Minimum number of training samples required + min_val_size : int + Minimum number of validation samples required + min_test_size : int + Minimum number of test samples required + + Returns + ------- + DatasetDict + Dictionary with 'train', 'validation', 'test' splits + + Raises + ------ + ValueError + If insufficient data for splits or validation fails + """ + n_samples = dataset.num_rows + + # Calculate actual split sizes from proportions or absolute values + if isinstance(train_size, float): + train_size = int(n_samples * train_size) + if isinstance(val_size, float): + val_size = int(n_samples * val_size) + if isinstance(test_size, float): + test_size = int(n_samples * test_size) + + # Adjust for gaps + total_with_gaps = train_size + val_size + test_size + (2 * gap) + + if total_with_gaps > n_samples: + # Proportionally reduce sizes to fit + available = n_samples - (2 * gap) + scale_factor = available / (train_size + val_size + test_size) + + train_size = max(min_train_size, int(train_size * scale_factor)) + val_size = max(min_val_size, int(val_size * scale_factor)) + test_size = max(min_test_size, int(test_size * scale_factor)) + + # Validate minimum sizes + if train_size < min_train_size: + raise ValueError( + f"Training set too small: {train_size} < {min_train_size}. " + f"Need more data or smaller validation/test sets." + ) + + if val_size < min_val_size: + raise ValueError( + f"Validation set too small: {val_size} < {min_val_size}. " + f"Need more data or smaller test set." + ) + + if test_size < min_test_size: + raise ValueError( + f"Test set too small: {test_size} < {min_test_size}. Need more data." + ) + + # Ensure dataset is sorted by timestamp + df_raw = dataset.to_pandas() + if isinstance(df_raw, pd.DataFrame): + dataset_df = df_raw + else: + # Handle iterator case + dataset_df = pd.concat(df_raw, ignore_index=True) + + if timestamp_col in dataset_df.columns: + dataset_df = dataset_df.sort_values(timestamp_col).reset_index(drop=True) + + # Calculate split indices with gaps + train_end = train_size + val_start = train_end + gap + val_end = val_start + val_size + test_start = val_end + gap + test_end = test_start + test_size + + if test_end > n_samples: + raise ValueError( + f"Not enough data for splits with gaps. Need {test_end} samples, " + f"have {n_samples}. Try reducing gap or split sizes." + ) + + # Create splits + # Split the dataset + train_df = dataset_df.iloc[:train_end] + val_df = dataset_df.iloc[val_start:val_end] + test_df = dataset_df.iloc[test_start:test_end] + + # Convert back to DashAI datasets + train_dataset = to_dashai_dataset(Dataset.from_pandas(train_df)) + val_dataset = to_dashai_dataset(Dataset.from_pandas(val_df)) + test_dataset = to_dashai_dataset(Dataset.from_pandas(test_df)) + + # Log split information + if timestamp_col in dataset_df.columns: + log.info("✅ Temporal split completed:") + log.info( + f" Train: {len(train_df)} samples " + f"({dataset_df[timestamp_col].iloc[0]} to " + f"{dataset_df[timestamp_col].iloc[train_end - 1]})" + ) + log.info( + f" Validation: {len(val_df)} samples " + f"({dataset_df[timestamp_col].iloc[val_start]} to " + f"{dataset_df[timestamp_col].iloc[val_end - 1]})" + ) + log.info( + f" Test: {len(test_df)} samples " + f"({dataset_df[timestamp_col].iloc[test_start]} to " + f"{dataset_df[timestamp_col].iloc[test_end - 1]})" + ) + if gap > 0: + log.info(f" Gap: {gap} periods between splits") + else: + log.info( + f"✅ Temporal split completed: " + f"Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}" + ) + + return DatasetDict( + {"train": train_dataset, "validation": val_dataset, "test": test_dataset} + ) + + +@beartype +def prepare_for_forecasting_experiment( + dataset: DashAIDataset, + splits: dict, + timestamp_col: str = "ds", + output_columns: List[str] = None, +) -> Tuple[DatasetDict, Dict]: + """Prepare dataset for forecasting experiment with temporal splits. + + Parameters + ---------- + dataset : DashAIDataset + Dataset to prepare for forecasting + splits : dict + Split configuration from frontend + timestamp_col : str + Name of timestamp column + output_columns : List[str] + Output columns (for compatibility) + + Returns + ------- + Tuple[DatasetDict, Dict] + Prepared dataset and split indices + """ + splitType = splits.get("splitType") + + if splitType in {"manual", "predefined"}: + # Preserve explicit user-defined splits for compatibility. + prepared_dataset, split_indices = prepare_for_model_session( + dataset, splits, output_columns or [] + ) + train_indexes = split_indices["train_indexes"] + test_indexes = split_indices["test_indexes"] + val_indexes = split_indices["val_indexes"] + else: + if splitType != "temporal": + log.warning( + "ForecastingTask received splitType=%r. " + "Falling back to temporal split to avoid data leakage.", + splitType, + ) + + train_size = splits.get("train", 0.7) + val_size = splits.get("validation", 0.15) + test_size = splits.get("test", 0.15) + gap = splits.get("gap", 0) + + prepared_dataset = split_dataset_temporal( + dataset, + train_size=train_size, + val_size=val_size, + test_size=test_size, + gap=gap, + timestamp_col=timestamp_col, + ) + + # Keep compatibility with the rest of the system by exposing indices + # relative to the temporally ordered full dataset. + train_len = len(prepared_dataset["train"]) + val_len = len(prepared_dataset["validation"]) + test_len = len(prepared_dataset["test"]) + + train_indexes = list(range(train_len)) + val_start = train_len + gap + val_indexes = list(range(val_start, val_start + val_len)) + test_start = val_start + val_len + gap + test_indexes = list(range(test_start, test_start + test_len)) + + return prepared_dataset, { + "train_indexes": train_indexes, + "test_indexes": test_indexes, + "val_indexes": val_indexes, + } diff --git a/DashAI/back/explainability/explainers/__init__.py b/DashAI/back/explainability/explainers/__init__.py index e69de29bb..eca68d2ff 100644 --- a/DashAI/back/explainability/explainers/__init__.py +++ b/DashAI/back/explainability/explainers/__init__.py @@ -0,0 +1,24 @@ +"""Explainer implementations. + +This module contains all explainer implementations organized by task type. + +Forecasting explainers are in the `forecasting_explainers` submodule. +""" + +# Import forecasting explainers +from DashAI.back.explainability.explainers.forecasting_explainers.forecast_feature_importance import ( + ForecastFeatureImportance, +) +from DashAI.back.explainability.explainers.forecasting_explainers.forecast_decomposition import ( + ForecastDecomposition, +) +from DashAI.back.explainability.explainers.forecasting_explainers.forecast_uncertainty import ( + ForecastUncertainty, +) + +__all__ = [ + # Forecasting explainers + "ForecastFeatureImportance", + "ForecastDecomposition", + "ForecastUncertainty", +] diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/__init__.py b/DashAI/back/explainability/explainers/forecasting_explainers/__init__.py new file mode 100644 index 000000000..5fb14f082 --- /dev/null +++ b/DashAI/back/explainability/explainers/forecasting_explainers/__init__.py @@ -0,0 +1,35 @@ +"""Forecasting explainability module. + +Provides specialized explainers for time series forecasting models: +- Base classes with time series utilities +- Feature importance for exogenous variables +- Forecast decomposition +- Uncertainty analysis + +All forecasting explainers inherit from ForecastingGlobalExplainer or +ForecastingLocalExplainer to leverage common time series functionality. +""" + +from DashAI.back.explainability.explainers.forecasting_explainers.forecasting_global_explainer import ( + ForecastingGlobalExplainer, +) +from DashAI.back.explainability.explainers.forecasting_explainers.forecasting_local_explainer import ( + ForecastingLocalExplainer, +) +from DashAI.back.explainability.explainers.forecasting_explainers.forecast_feature_importance import ( + ForecastFeatureImportance, +) +from DashAI.back.explainability.explainers.forecasting_explainers.forecast_decomposition import ( + ForecastDecomposition, +) +from DashAI.back.explainability.explainers.forecasting_explainers.forecast_uncertainty import ( + ForecastUncertainty, +) + +__all__ = [ + "ForecastingGlobalExplainer", + "ForecastingLocalExplainer", + "ForecastFeatureImportance", + "ForecastDecomposition", + "ForecastUncertainty", +] diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/forecast_decomposition.py b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_decomposition.py new file mode 100644 index 000000000..b25b1b7bb --- /dev/null +++ b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_decomposition.py @@ -0,0 +1,383 @@ +"""Forecast Decomposition Explainer for time series models. + +This explainer decomposes forecasts into interpretable components (trend, +seasonality, residual) for any forecasting model that implements +``get_forecast_components()``. + +Works with: +- Prophet → trend, weekly, yearly (native structural decomposition) +- ARIMA → trend, weekly/yearly, residual (STL on fitted + forecast) +- SARIMAX → trend, weekly/yearly, residual (STL with explicit period s) +- SklearnMultiStep → trend, weekly/yearly, residual (STL on history + forecast) + +Any future model that implements ``get_forecast_components(horizon)`` and +returns a DataFrame with at least a ``ds`` column and one component column +will be automatically supported. +""" + +from typing import List, Tuple + +import numpy as np +import pandas as pd +import plotly +import plotly.graph_objects as go +from datasets import DatasetDict +from plotly.subplots import make_subplots + +from DashAI.back.core.schema_fields import ( + BaseSchema, + bool_field, + int_field, + schema_field, +) +from DashAI.back.explainability.global_explainer import BaseGlobalExplainer +from DashAI.back.models.base_model import BaseModel + + +class ForecastDecompositionSchema(BaseSchema): + """Forecast Decomposition breaks down predictions into interpretable components. + + This helps understand what drives the forecast: + - Trend: Long-term direction + - Seasonality: Repeating patterns (weekly, yearly, etc.) + - External factors: Effect of exogenous variables + - Residuals: Unexplained variation + """ + + horizon: schema_field( + int_field(ge=1, le=365), + placeholder=30, + description="Number of future periods to forecast and decompose. " + "Longer horizons show how components evolve over time.", + ) # type: ignore + + include_historical: schema_field( + bool_field(), + placeholder=False, + description="If True, includes historical component decomposition " + "to show how the model understood past data.", + ) # type: ignore + + +class ForecastDecomposition(BaseGlobalExplainer): + """Universal forecast decomposition explainer. + + Decomposes time series forecasts into interpretable components, + adapting to different model types automatically. + """ + + COMPATIBLE_COMPONENTS = ["ForecastingTask"] + SCHEMA = ForecastDecompositionSchema + + def __init__( + self, + model: BaseModel, + horizon: int = 30, + include_historical: bool = False, + ): + """Initialize ForecastDecomposition explainer. + + Parameters + ---------- + model : BaseModel + Trained forecasting model to explain + horizon : int + Number of periods to forecast (default: 30) + include_historical : bool + Whether to include historical decomposition (default: False) + """ + super().__init__(model) + self.horizon = horizon + self.include_historical = include_historical + + def _get_native_components(self) -> pd.DataFrame: + """Extract components from any model that implements get_forecast_components(). + + This covers Prophet, ARIMA, SARIMAX, and SklearnMultiStepForecaster. + """ + if not hasattr(self.model, "get_forecast_components"): + raise AttributeError( + f"{type(self.model).__name__} must implement " + "get_forecast_components(horizon) to use this path." + ) + return self.model.get_forecast_components(self.horizon) + + def _get_generic_components(self, dataset: DatasetDict) -> pd.DataFrame: + """Fallback for models without native decomposition. + + Uses simple predictions as "trend" component. + """ + x, y = dataset + + # Construct history dataframe from dataset (x and y) + # This allows the model to predict continuing from this dataset + try: + # Convert to pandas with error handling + try: + x_df = x.to_pandas() if hasattr(x, "to_pandas") else pd.DataFrame(x) + except Exception as e: + print(f"Warning: Failed to convert x to DataFrame: {e}") + x_df = None + + try: + y_df = y.to_pandas() if hasattr(y, "to_pandas") else pd.DataFrame(y) + except Exception as e: + print(f"Warning: Failed to convert y to DataFrame: {e}") + y_df = None + + # Combine if possible + if x_df is not None and y_df is not None and len(x_df) == len(y_df): + history_df = x_df.copy() + for col in y_df.columns: + history_df[col] = y_df[col].to_numpy() + + # Get predictions using history context + predictions = self.model.predict( + x_pred=history_df, periods=self.horizon + ) + else: + if x_df is not None and y_df is not None: + print(f"Warning: lengths differ (x={len(x_df)}, y={len(y_df)}).") + history_df = x_df.copy() + predictions = self.model.predict( + x_pred=history_df, periods=self.horizon + ) + elif x_df is not None: + print("Warning: Only x dataset available. Using x as history.") + history_df = x_df.copy() + predictions = self.model.predict( + x_pred=history_df, periods=self.horizon + ) + else: + print("Warning: Could not create history. Using standard predict.") + predictions = self.model.predict(periods=self.horizon) + + except Exception as e: + print(f"Warning: Could not use dataset as history context: {e}") + # Fallback to standard prediction + predictions = self.model.predict(periods=self.horizon) + + # Handle case where model returns fewer predictions than requested + # (e.g. SklearnMultiStepForecaster with direct strategy) + actual_horizon = len(predictions) + + # Determine start date + start_date = pd.Timestamp.now() + if ( + hasattr(self.model, "last_timestamp") + and self.model.last_timestamp is not None + ): + start_date = self.model.last_timestamp + elif hasattr(self.model, "last_ds") and self.model.last_ds is not None: + start_date = self.model.last_ds + + # Determine frequency + freq = "D" + if hasattr(self.model, "frequency") and self.model.frequency: + freq = self.model.frequency + + # Generate dates (start from next period after last timestamp) + dates = pd.date_range(start=start_date, periods=actual_horizon + 1, freq=freq)[ + 1: + ] + + # Create simple dataframe with predictions as "trend" + components_df = pd.DataFrame( + { + "ds": dates, + "trend": predictions + if isinstance(predictions, np.ndarray) + else predictions.to_numpy(), + "seasonal": np.zeros(actual_horizon), + "residual": np.zeros(actual_horizon), + } + ) + + return components_df + + def explain(self, dataset: Tuple[DatasetDict, DatasetDict]) -> dict: + """Generate component decomposition explanation. + + Parameters + ---------- + dataset : Tuple[DatasetDict, DatasetDict] + Tuple with (input_samples, targets) used for context + + Returns + ------- + dict + Dictionary with: + - ds: Timestamps + - trend: Trend component + - seasonal: Seasonal component (if applicable) + - weekly/yearly: Specific seasonality (if applicable) + - exog_*: External regressor effects (if applicable) + - model_type: Type of model decomposed + """ + # Detect model type and extract components + model_name = type(self.model).__name__ + + # Friendly display names for known model classes + _display_names = { + "ProphetModel": "Prophet", + "StatsmodelsARIMAModel": "ARIMA", + "StatsmodelsSARIMAXModel": "SARIMAX", + "SklearnMultiStepForecaster": "Sklearn MultiStep", + } + + try: + if hasattr(self.model, "get_forecast_components"): + # Prophet, ARIMA, SARIMAX, SklearnMultiStepForecaster + components_df = self._get_native_components() + model_type = _display_names.get(model_name, model_name) + + else: + # Generic fallback for unknown model types + components_df = self._get_generic_components(dataset) + model_type = "Generic" + + except Exception as e: + raise RuntimeError( + f"Failed to extract components from {model_name}: {str(e)}" + ) from e + + # Convert to serializable format + explanation = { + "model_type": model_type, + "horizon": self.horizon, + "ds": components_df["ds"].dt.strftime("%Y-%m-%d %H:%M:%S").tolist() + if "ds" in components_df.columns + else list(range(len(components_df))), + } + + # Add all available components + for col in components_df.columns: + if col != "ds": + explanation[col] = np.round(components_df[col].values, 3).tolist() + + return explanation + + def _create_decomposition_plot(self, explanation: dict) -> go.Figure: + """Create multi-panel decomposition plot.""" + + # Identify available components + component_cols = [ + k for k in explanation if k not in ["ds", "model_type", "horizon"] + ] + + # Prioritize component order for better visualization + priority_order = ["trend", "seasonal", "yearly", "weekly", "daily"] + ordered_components = [] + + for comp in priority_order: + if comp in component_cols: + ordered_components.append(comp) + + # Add remaining components (e.g., exog_*) + for comp in component_cols: + if comp not in ordered_components: + ordered_components.append(comp) + + n_components = len(ordered_components) + + # Create subplots + fig = make_subplots( + rows=n_components, + cols=1, + subplot_titles=[ + comp.replace("_", " ").title() for comp in ordered_components + ], + vertical_spacing=0.05, + ) + + # Add trace for each component + for i, component in enumerate(ordered_components, 1): + fig.add_trace( + go.Scatter( + x=explanation["ds"], + y=explanation[component], + name=component.replace("_", " ").title(), + line={"width": 2}, + mode="lines", + ), + row=i, + col=1, + ) + + # Update layout + fig.update_layout( + height=250 * n_components, + title_text=f"Forecast Decomposition ({explanation['model_type']} Model)", + showlegend=False, + hovermode="x unified", + ) + + fig.update_xaxes(title_text="Date", row=n_components, col=1) + + return fig + + def _create_stacked_plot(self, explanation: dict) -> go.Figure: + """Create stacked area plot showing component contributions.""" + + explanation_df = pd.DataFrame(explanation) + + # Components to stack (exclude residuals/noise) + stack_components = [ + col + for col in explanation_df.columns + if col not in ["ds", "model_type", "horizon", "residual", "noise"] + and not col.startswith("yhat") + ] + + fig = go.Figure() + + for component in stack_components: + fig.add_trace( + go.Scatter( + x=explanation_df["ds"], + y=explanation_df[component], + name=component.replace("_", " ").title(), + mode="lines", + stackgroup="one", + fillcolor="rgba(0,0,0,0.1)", + ) + ) + + fig.update_layout( + title="Component Contribution Over Time", + xaxis_title="Date", + yaxis_title="Contribution", + hovermode="x unified", + ) + + return fig + + def plot(self, explanation: dict) -> List[dict]: + """Create visualization plots. + + Parameters + ---------- + explanation : dict + Explanation dictionary from explain() + + Returns + ------- + List[dict] + List of plotly JSON figures + """ + plots = [] + + # Main decomposition plot + decomp_fig = self._create_decomposition_plot(explanation) + plots.append(plotly.io.to_json(decomp_fig)) + + # Stacked contribution plot (if multiple components) + component_cols = [ + k for k in explanation if k not in ["ds", "model_type", "horizon"] + ] + + if len(component_cols) > 1: + stacked_fig = self._create_stacked_plot(explanation) + plots.append(plotly.io.to_json(stacked_fig)) + + return plots diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/forecast_feature_importance.py b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_feature_importance.py new file mode 100644 index 000000000..cf8e02378 --- /dev/null +++ b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_feature_importance.py @@ -0,0 +1,446 @@ +"""Feature Importance Explainer for Forecasting Models. + +Evaluates the importance of exogenous variables (regressors) in forecasting models +by measuring how model performance degrades when each feature is permuted. + +This is the forecasting adaptation of Permutation Feature Importance, using +time series specific metrics (MAE, RMSE, MAPE) instead of classification metrics. + +Works with any forecasting model that implements ForecastingModel interface, +which provides get_exogenous_columns() to list external features in their original +format (model-agnostic). + +Compatible models: +- Prophet with add_regressor() +- ARIMA/SARIMAX with exog +- Any model inheriting from ForecastingModel +""" + +from typing import List, Tuple + +import numpy as np +import pandas as pd +import plotly +import plotly.express as px +from datasets import DatasetDict + +from DashAI.back.core.schema_fields import ( + BaseSchema, + enum_field, + int_field, + schema_field, +) +from DashAI.back.explainability.explainers.forecasting_explainers.forecasting_global_explainer import ( # noqa: E501 + ForecastingGlobalExplainer, +) +from DashAI.back.models.base_model import BaseModel + + +class ForecastFeatureImportanceSchema(BaseSchema): + """Feature Importance for forecasting models with exogenous variables. + + Measures how much each external variable (weather, holidays, promotions, etc.) + contributes to forecast accuracy by randomly shuffling each feature and + measuring performance degradation. + """ + + scoring: schema_field( + enum_field(enum=["mae", "rmse", "mape"]), + placeholder="mae", + description="Metric to evaluate performance degradation. " + "MAE (Mean Absolute Error) is most interpretable, " + "RMSE (Root Mean Squared Error) penalizes large errors, " + "MAPE (Mean Absolute Percentage Error) shows relative error.", + ) # type: ignore + + n_repeats: schema_field( + int_field(ge=1, le=50), + placeholder=10, + description="Number of times to permute each feature. " + "More repeats give more stable importance estimates but take longer.", + ) # type: ignore + + random_state: schema_field( + int_field(ge=0), + placeholder=42, + description="Seed for random number generator to ensure reproducible results.", + ) # type: ignore + + +class ForecastFeatureImportance(ForecastingGlobalExplainer): + """Feature importance explainer for forecasting models with exogenous variables. + + Identifies which external variables (regressors) are most important for + accurate forecasts by measuring performance degradation when each is permuted. + """ + + COMPATIBLE_COMPONENTS = ["ForecastingTask"] + SCHEMA = ForecastFeatureImportanceSchema + + def __init__( + self, + model: BaseModel, + scoring: str = "mae", + n_repeats: int = 10, + random_state: int = 42, + ): + """Initialize ForecastFeatureImportance explainer. + + Parameters + ---------- + model : BaseModel + Trained forecasting model to explain + scoring : str + Metric to use: 'mae', 'rmse', or 'mape' (default: 'mae') + n_repeats : int + Number of permutation repeats (default: 10) + random_state : int + Random seed for reproducibility (default: 42) + """ + super().__init__(model) + + # Define scoring functions + self.scoring_functions = { + "mae": self._mean_absolute_error, + "rmse": self._root_mean_squared_error, + "mape": self._mean_absolute_percentage_error, + } + + if scoring not in self.scoring_functions: + raise ValueError( + f"Unknown scoring metric: {scoring}. " + f"Choose from: {list(self.scoring_functions.keys())}" + ) + + self.scoring = scoring + self.score_func = self.scoring_functions[scoring] + self.n_repeats = n_repeats + self.random_state = random_state + + def _mean_absolute_error(self, y_true: np.ndarray, y_pred: np.ndarray) -> float: + """Calculate Mean Absolute Error.""" + return np.mean(np.abs(y_true - y_pred)) + + def _root_mean_squared_error(self, y_true: np.ndarray, y_pred: np.ndarray) -> float: + """Calculate Root Mean Squared Error.""" + return np.sqrt(np.mean((y_true - y_pred) ** 2)) + + def _mean_absolute_percentage_error( + self, y_true: np.ndarray, y_pred: np.ndarray + ) -> float: + """Calculate Mean Absolute Percentage Error.""" + # Avoid division by zero + mask = y_true != 0 + if not np.any(mask): + return np.inf + return np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100 + + def explain(self, dataset: Tuple[DatasetDict, DatasetDict]) -> dict: + """Calculate feature importance using permutation. + + Parameters + ---------- + dataset : Tuple[DatasetDict, DatasetDict] + Tuple with (full_prepared_dataset, targets) + For forecasting: first element contains timestamp + exog vars + target + + Returns + ------- + dict + Dictionary with: + - features: List of feature names + - importances_mean: Average importance for each feature + - importances_std: Standard deviation of importance + - baseline_score: Model performance without permutation + - scoring_metric: Metric used (mae, rmse, mape) + """ + x, y = dataset + + # Use test set for evaluation + x_test = x["test"] + y_test = y["test"] + + # Get exogenous features from model (using base class method) + exog_features = self._get_exogenous_columns() + + if len(exog_features) == 0: + return { + "error": "No exogenous features found", + "message": ( + "This model does not use exogenous variables. " + "Feature importance is only available for models with " + "external regressors." + ), + "features": [], + "importances_mean": [], + "importances_std": [], + } + + print( + "[ForecastFeatureImportance] Evaluating " + f"{len(exog_features)} exogenous variables" + ) + + # Convert to pandas - x_test now has ALL columns including timestamp + x_df = x_test.to_pandas() + y_true = y_test.to_pandas().to_numpy().ravel() + + timestamp_col = self._get_timestamp_column() + target_col = self._get_target_column() + + print( + f"[ForecastFeatureImportance] Using timestamp: {timestamp_col}, " + f"target: {target_col}" + ) + print(f"[ForecastFeatureImportance] Test set size: {len(x_df)} rows") + print(f"[ForecastFeatureImportance] Columns in x_df: {x_df.columns.tolist()}") + + # Calculate baseline score (without permutation) + try: + # Use ForecastingModel interface: predict(x_pred=DataFrame) + # The model's predict() method expects a DataFrame with: + # - Timestamp column (original name) + # - Exogenous variables (original names) + # This is model-agnostic - works for Prophet, ARIMA, LSTM, etc. + + # Prepare input DataFrame for model's predict method + input_df = x_df.copy() + + # Call the model's predict method using ForecastingModel interface + predictions = self.model.predict(x_pred=input_df) # type: ignore + + # Extract predictions from result + # ForecastingModel.predict() returns DataFrame with target column + if isinstance(predictions, pd.DataFrame): + # Try to get the target column + if target_col and target_col in predictions.columns: + y_pred_baseline = predictions[target_col].to_numpy() + elif "yhat" in predictions.columns: + # Fallback to 'yhat' (Prophet style) + y_pred_baseline = predictions["yhat"].to_numpy() + else: + # Take first numeric column + numeric_cols = predictions.select_dtypes( + include=[np.number] + ).columns + if len(numeric_cols) > 0: + y_pred_baseline = predictions[numeric_cols[0]].to_numpy() + else: + raise ValueError("No numeric columns found in predictions") + elif isinstance(predictions, np.ndarray): + y_pred_baseline = predictions + else: + y_pred_baseline = np.array(predictions) + + baseline_score = self.score_func(y_true, y_pred_baseline) + print( + "[ForecastFeatureImportance] Baseline score " + f"({self.scoring}): {baseline_score:.4f}" + ) + except Exception as e: + print("[ForecastFeatureImportance] ERROR in baseline prediction:") + print(f" - x_df shape: {x_df.shape}") + print(f" - x_df columns: {x_df.columns.tolist()}") + print(f" - Error: {str(e)}") + raise RuntimeError(f"Failed to get baseline predictions: {str(e)}") from e + + # Calculate importance for each feature + importances = {feature: [] for feature in exog_features} + + rng = np.random.RandomState(self.random_state) + + for feature in exog_features: + print(f"[ForecastFeatureImportance] Permuting feature: {feature}") + for repeat in range(self.n_repeats): + # Copy dataframe and permute the feature + x_permuted = x_df.copy() + x_permuted[feature] = rng.permutation(x_permuted[feature].to_numpy()) + + # Get predictions with permuted feature using ForecastingModel interface + try: + # Use same interface as baseline + predictions_perm = self.model.predict(x_pred=x_permuted) # type: ignore + + # Extract predictions (same logic as baseline) + if isinstance(predictions_perm, pd.DataFrame): + if target_col and target_col in predictions_perm.columns: + y_pred_permuted = predictions_perm[target_col].to_numpy() + elif "yhat" in predictions_perm.columns: + y_pred_permuted = predictions_perm["yhat"].to_numpy() + else: + numeric_cols = predictions_perm.select_dtypes( + include=[np.number] + ).columns + if len(numeric_cols) > 0: + y_pred_permuted = predictions_perm[ + numeric_cols[0] + ].to_numpy() + else: + raise ValueError( + "No numeric columns in permuted predictions" + ) + elif isinstance(predictions_perm, np.ndarray): + y_pred_permuted = predictions_perm + else: + y_pred_permuted = np.array(predictions_perm) + + permuted_score = self.score_func(y_true, y_pred_permuted) + + # For error metrics (lower is better), importance is positive + # when permutation increases error + importance = permuted_score - baseline_score + importances[feature].append(importance) + + except Exception as e: + print( + f" Warning: Failed repeat {repeat + 1} for {feature}: {str(e)}" + ) + importances[feature].append(0.0) + + # Calculate statistics + features = list(importances.keys()) + importances_mean = [np.mean(importances[f]) for f in features] + importances_std = [np.std(importances[f]) for f in features] + + return { + "features": features, + "importances_mean": np.round(importances_mean, 4).tolist(), + "importances_std": np.round(importances_std, 4).tolist(), + "baseline_score": round(baseline_score, 4), + "scoring_metric": self.scoring, + } + + def _create_plot( + self, data: pd.DataFrame, explanation: dict + ) -> plotly.graph_objs.Figure: + """Create horizontal bar plot showing feature importances. + + Parameters + ---------- + data : pd.DataFrame + Dataframe with features and importances + explanation : dict + Full explanation dictionary + + Returns + ------- + plotly.graph_objs.Figure + Interactive bar chart + """ + # Sort by importance + data = data.sort_values(by="importances_mean", ascending=True) + + fig = px.bar( + data, + x="importances_mean", + y="features", + error_x="importances_std", + orientation="h", + title=f"Feature Importance ({explanation['scoring_metric'].upper()})", + labels={ + "importances_mean": ( + f"Importance (Δ{explanation['scoring_metric'].upper()})" + ), + "features": "Feature", + }, + ) + + # Add baseline info + baseline_text = ( + f"Baseline {explanation['scoring_metric'].upper()}: " + f"{explanation['baseline_score']:.4f}" + ) + + fig.add_annotation( + text=baseline_text, + xref="paper", + yref="paper", + x=0.98, + y=0.98, + showarrow=False, + bgcolor="rgba(255,255,255,0.8)", + bordercolor="black", + borderwidth=1, + ) + + # Add explanation note + note_text = ( + f"Higher values = more important feature
" + f"Measured as increase in {explanation['scoring_metric'].upper()} " + f"when feature is randomly shuffled" + ) + + fig.add_annotation( + text=note_text, + xref="paper", + yref="paper", + x=0.5, + y=-0.15, + showarrow=False, + font={"size": 10}, + xanchor="center", + ) + + fig.update_layout( + height=max(400, len(data) * 40), + margin={"b": 100}, + ) + + return fig + + def plot(self, explanation: dict) -> List[dict]: + """Create visualization of feature importances. + + Parameters + ---------- + explanation : dict + Explanation dictionary from explain() + + Returns + ------- + List[dict] + List with single plotly JSON figure + """ + # Check for errors + if "error" in explanation: + # Return empty plot with error message + import plotly.graph_objects as go + + fig = go.Figure() + fig.add_annotation( + text=explanation["message"], + xref="paper", + yref="paper", + x=0.5, + y=0.5, + showarrow=False, + font={"size": 14}, + ) + fig.update_layout( + title="Feature Importance - No Exogenous Variables", + xaxis={"visible": False}, + yaxis={"visible": False}, + ) + return [plotly.io.to_json(fig)] + + # Create dataframe + data = pd.DataFrame( + { + "features": explanation["features"], + "importances_mean": explanation["importances_mean"], + "importances_std": explanation["importances_std"], + } + ) + + # Clean feature names for display + # Remove 'exog_' prefix if present, then format for readability + data["features"] = ( + data["features"] + .str.replace("exog_", "", regex=False) + .str.replace("_", " ") + .str.title() + ) + + fig = self._create_plot(data, explanation) + + return [plotly.io.to_json(fig)] diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/forecast_uncertainty.py b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_uncertainty.py new file mode 100644 index 000000000..1be43245f --- /dev/null +++ b/DashAI/back/explainability/explainers/forecasting_explainers/forecast_uncertainty.py @@ -0,0 +1,487 @@ +"""Forecast Uncertainty Analysis Explainer. + +Analyzes and visualizes prediction uncertainty (confidence/prediction intervals) +across the forecast horizon. Essential for risk management and decision-making. + +Shows how confidence in predictions degrades over time and helps users understand +the reliability of forecasts at different time horizons. + +Works with all forecasting models via ``get_forecast_uncertainty()``: +- Prophet → native intervals from Prophet's uncertainty sampling +- ARIMA → parametric intervals from statsmodels (analytical CIs) +- SARIMAX → parametric intervals from statsmodels (analytical CIs) +- SklearnMultiStep → empirical intervals: residual std × sqrt(horizon step) +- Unknown models → fallback placeholder (±10% of forecast value) +""" + +from typing import List, Tuple + +import numpy as np +import pandas as pd +import plotly +import plotly.graph_objects as go +from datasets import DatasetDict +from plotly.subplots import make_subplots + +from DashAI.back.core.schema_fields import ( + BaseSchema, + float_field, + int_field, + schema_field, +) +from DashAI.back.explainability.global_explainer import BaseGlobalExplainer +from DashAI.back.models.base_model import BaseModel + + +class ForecastUncertaintySchema(BaseSchema): + """Forecast Uncertainty Analysis shows prediction confidence intervals. + + Helps answer: + - How confident is the model in its predictions? + - How does uncertainty grow with forecast horizon? + - What's the best/worst case scenario? + + Critical for inventory planning, capacity planning, and risk management. + """ + + horizon: schema_field( + int_field(ge=1, le=365), + placeholder=30, + description="Number of future periods to forecast. " + "Longer horizons typically show increasing uncertainty.", + ) # type: ignore + + confidence_level: schema_field( + float_field(ge=0.5, le=0.99), + placeholder=0.80, + description="Confidence level for prediction intervals (e.g., 0.80 = 80%). " + "Higher values give wider intervals.", + ) # type: ignore + + +class ForecastUncertainty(BaseGlobalExplainer): + """Analyzes forecast uncertainty and prediction intervals. + + Visualizes how prediction confidence changes across the forecast horizon, + helping users understand forecast reliability and plan for uncertainty. + """ + + COMPATIBLE_COMPONENTS = ["ForecastingTask"] + SCHEMA = ForecastUncertaintySchema + + def __init__( + self, + model: BaseModel, + horizon: int = 30, + confidence_level: float = 0.80, + ): + """Initialize ForecastUncertainty explainer. + + Parameters + ---------- + model : BaseModel + Trained forecasting model + horizon : int + Number of periods to forecast (default: 30) + confidence_level : float + Confidence level for intervals (default: 0.80 = 80%) + """ + super().__init__(model) + self.horizon = horizon + self.confidence_level = confidence_level + + def _get_native_uncertainty(self) -> pd.DataFrame: + """Get uncertainty from any model that implements get_forecast_uncertainty(). + + This covers Prophet, ARIMA, SARIMAX, and SklearnMultiStepForecaster. + """ + if not hasattr(self.model, "get_forecast_uncertainty"): + raise AttributeError( + f"{type(self.model).__name__} must implement " + "get_forecast_uncertainty(horizon, confidence_level) to use this path." + ) + return self.model.get_forecast_uncertainty(self.horizon, self.confidence_level) + + def _get_generic_uncertainty( + self, dataset: Tuple[DatasetDict, DatasetDict] + ) -> pd.DataFrame: + """Fallback for models without native uncertainty quantification. + + Returns point predictions with placeholder intervals. + """ + x, y = dataset + + # Construct history dataframe from dataset (x and y) + try: + # Convert to pandas + x_df = x.to_pandas() if hasattr(x, "to_pandas") else pd.DataFrame(x) + + y_df = y.to_pandas() if hasattr(y, "to_pandas") else pd.DataFrame(y) + + # Combine + if len(x_df) == len(y_df): + history_df = x_df.copy() + for col in y_df.columns: + history_df[col] = y_df[col].to_numpy() + else: + print(f"Warning: lengths differ (x={len(x_df)}, y={len(y_df)}).") + history_df = x_df.copy() + + # Get point predictions using history context + predictions = self.model.predict(x_pred=history_df, periods=self.horizon) + + except Exception as e: + print(f"Warning: Could not use dataset as history context: {e}") + # Fallback + predictions = self.model.predict(periods=self.horizon) + + if hasattr(predictions, "to_numpy"): + y_pred = predictions.to_numpy() + elif isinstance(predictions, np.ndarray): + y_pred = predictions + else: + y_pred = np.array(predictions) + + # Handle case where model returns fewer predictions than requested + actual_horizon = len(y_pred) + + # Create placeholder intervals (±10% of prediction) + uncertainty_pct = 0.10 + + # Determine start date + start_date = pd.Timestamp.now() + if ( + hasattr(self.model, "last_timestamp") + and self.model.last_timestamp is not None + ): + start_date = self.model.last_timestamp + elif hasattr(self.model, "last_ds") and self.model.last_ds is not None: + start_date = self.model.last_ds + + # Determine frequency + freq = "D" + if hasattr(self.model, "frequency") and self.model.frequency: + freq = self.model.frequency + + # Generate dates (start from next period after last timestamp) + dates = pd.date_range(start=start_date, periods=actual_horizon + 1, freq=freq)[ + 1: + ] + + uncertainty_df = pd.DataFrame( + { + "ds": dates, + "yhat": y_pred, + "yhat_lower": y_pred * (1 - uncertainty_pct), + "yhat_upper": y_pred * (1 + uncertainty_pct), + "estimated_intervals": True, # Flag that these are not native + } + ) + + return uncertainty_df + + def explain(self, dataset: Tuple[DatasetDict, DatasetDict]) -> dict: + """Generate uncertainty analysis explanation. + + Parameters + ---------- + dataset : Tuple[DatasetDict, DatasetDict] + Tuple with (input_features, targets) for context + + Returns + ------- + dict + Dictionary with: + - ds: Timestamps + - yhat: Point predictions + - yhat_lower: Lower bound of prediction interval + - yhat_upper: Upper bound of prediction interval + - uncertainty: Interval width (yhat_upper - yhat_lower) + - uncertainty_pct: Uncertainty as % of prediction + - confidence_level: Configured confidence level + - model_type: Model that generated intervals + """ + model_name = type(self.model).__name__ + + # Friendly display names for known model classes + _display_names = { + "ProphetModel": "Prophet", + "StatsmodelsARIMAModel": "ARIMA", + "StatsmodelsSARIMAXModel": "SARIMAX", + "SklearnMultiStepForecaster": "Sklearn MultiStep", + } + + # Models with true parametric / native intervals + _parametric_models = { + "ProphetModel", + "StatsmodelsARIMAModel", + "StatsmodelsSARIMAXModel", + } + + # Human-readable description of the interval source + _interval_sources = { + "ProphetModel": "Native (Prophet uncertainty sampling)", + "StatsmodelsARIMAModel": "Parametric (ARIMA analytical CI)", + "StatsmodelsSARIMAXModel": "Parametric (SARIMAX analytical CI)", + "SklearnMultiStepForecaster": "Empirical (residual std × √horizon)", + } + + try: + if hasattr(self.model, "get_forecast_uncertainty"): + # Prophet, ARIMA, SARIMAX, SklearnMultiStepForecaster + forecast_df = self._get_native_uncertainty() + model_type = _display_names.get(model_name, model_name) + has_native_intervals = model_name in _parametric_models + interval_source = _interval_sources.get( + model_name, "Native (model-specific)" + ) + + else: + # Generic fallback for unknown model types + forecast_df = self._get_generic_uncertainty(dataset) + model_type = "Generic" + has_native_intervals = False + interval_source = "Placeholder (±10% of forecast)" + + except Exception as e: + raise RuntimeError( + f"Failed to get uncertainty estimates from {model_name}: {str(e)}" + ) from e + + # Calculate uncertainty metrics + forecast_df["uncertainty"] = ( + forecast_df["yhat_upper"] - forecast_df["yhat_lower"] + ) + + # Avoid division by zero + safe_yhat = np.where(forecast_df["yhat"] == 0, 1e-10, forecast_df["yhat"]) + forecast_df["uncertainty_pct"] = ( + forecast_df["uncertainty"] / np.abs(safe_yhat) * 100 + ) + + # Build explanation + explanation = { + "model_type": model_type, + "confidence_level": self.confidence_level, + "horizon": self.horizon, + "has_native_intervals": has_native_intervals, + "interval_source": interval_source, + "ds": forecast_df["ds"].dt.strftime("%Y-%m-%d %H:%M:%S").tolist(), + "yhat": np.round(forecast_df["yhat"].to_numpy(), 3).tolist(), + "yhat_lower": np.round(forecast_df["yhat_lower"].to_numpy(), 3).tolist(), + "yhat_upper": np.round(forecast_df["yhat_upper"].to_numpy(), 3).tolist(), + "uncertainty": np.round(forecast_df["uncertainty"].to_numpy(), 3).tolist(), + "uncertainty_pct": np.round( + forecast_df["uncertainty_pct"].to_numpy(), 2 + ).tolist(), + } + + # Add summary statistics + explanation["summary"] = { + "mean_uncertainty": round(forecast_df["uncertainty"].mean(), 3), + "max_uncertainty": round(forecast_df["uncertainty"].max(), 3), + "mean_uncertainty_pct": round(forecast_df["uncertainty_pct"].mean(), 2), + "uncertainty_growth": round( + forecast_df["uncertainty"].iloc[-1] / forecast_df["uncertainty"].iloc[0] + if forecast_df["uncertainty"].iloc[0] != 0 + else 0, + 2, + ), + } + + return explanation + + def _create_forecast_plot(self, explanation: dict) -> go.Figure: + """Create main forecast plot with confidence intervals.""" + + forecast_plot_df = pd.DataFrame( + { + "ds": pd.to_datetime(explanation["ds"]), + "yhat": explanation["yhat"], + "yhat_lower": explanation["yhat_lower"], + "yhat_upper": explanation["yhat_upper"], + } + ) + + fig = go.Figure() + + # Add confidence interval band + fig.add_trace( + go.Scatter( + x=forecast_plot_df["ds"], + y=forecast_plot_df["yhat_upper"], + mode="lines", + line={"width": 0}, + showlegend=False, + hoverinfo="skip", + ) + ) + + fig.add_trace( + go.Scatter( + x=forecast_plot_df["ds"], + y=forecast_plot_df["yhat_lower"], + mode="lines", + line={"width": 0}, + fillcolor="rgba(68, 68, 68, 0.2)", + fill="tonexty", + name=( + f"{int(explanation['confidence_level'] * 100)}% Confidence Interval" + ), + ) + ) + + # Add point forecast + fig.add_trace( + go.Scatter( + x=forecast_plot_df["ds"], + y=forecast_plot_df["yhat"], + mode="lines", + name="Forecast", + line={"color": "blue", "width": 2}, + ) + ) + + # Title + title = ( + f"Forecast with {int(explanation['confidence_level'] * 100)}% " + "Confidence Interval" + ) + interval_source = explanation.get("interval_source", "") + if interval_source: + title += f"
{interval_source}" + + fig.update_layout( + title=title, + xaxis_title="Date", + yaxis_title="Predicted Value", + hovermode="x unified", + height=500, + ) + + return fig + + def _create_uncertainty_growth_plot(self, explanation: dict) -> go.Figure: + """Create plot showing how uncertainty grows over horizon.""" + + growth_plot_df = pd.DataFrame( + { + "ds": pd.to_datetime(explanation["ds"]), + "uncertainty": explanation["uncertainty"], + "uncertainty_pct": explanation["uncertainty_pct"], + } + ) + + # Create subplot with absolute and relative uncertainty + fig = make_subplots( + rows=2, + cols=1, + subplot_titles=( + "Absolute Uncertainty (Interval Width)", + "Relative Uncertainty (% of Forecast)", + ), + vertical_spacing=0.12, + ) + + # Absolute uncertainty + fig.add_trace( + go.Scatter( + x=growth_plot_df["ds"], + y=growth_plot_df["uncertainty"], + mode="lines+markers", + name="Uncertainty", + line={"color": "red", "width": 2}, + marker={"size": 4}, + ), + row=1, + col=1, + ) + + # Relative uncertainty + fig.add_trace( + go.Scatter( + x=growth_plot_df["ds"], + y=growth_plot_df["uncertainty_pct"], + mode="lines+markers", + name="Uncertainty %", + line={"color": "orange", "width": 2}, + marker={"size": 4}, + ), + row=2, + col=1, + ) + + fig.update_xaxes(title_text="Date", row=2, col=1) + fig.update_yaxes(title_text="Interval Width", row=1, col=1) + fig.update_yaxes(title_text="Uncertainty (%)", row=2, col=1) + + fig.update_layout( + title="Uncertainty Growth Over Forecast Horizon", + height=600, + showlegend=False, + ) + + return fig + + def plot(self, explanation: dict) -> List[dict]: + """Create visualization plots. + + Parameters + ---------- + explanation : dict + Explanation dictionary from explain() + + Returns + ------- + List[dict] + List of plotly JSON figures + """ + plots = [] + + # Main forecast with intervals + forecast_fig = self._create_forecast_plot(explanation) + plots.append(plotly.io.to_json(forecast_fig)) + + # Uncertainty growth analysis + uncertainty_fig = self._create_uncertainty_growth_plot(explanation) + plots.append(plotly.io.to_json(uncertainty_fig)) + + # Add summary statistics as annotation figure + summary = explanation["summary"] + + import plotly.graph_objects as go + + summary_fig = go.Figure() + + summary_text = ( + f"Uncertainty Summary

" + f"Mean Uncertainty: {summary['mean_uncertainty']}
" + f"Max Uncertainty: {summary['max_uncertainty']}
" + f"Mean Uncertainty %: {summary['mean_uncertainty_pct']:.1f}%
" + f"Uncertainty Growth: {summary['uncertainty_growth']:.2f}x" + ) + + summary_fig.add_annotation( + text=summary_text, + xref="paper", + yref="paper", + x=0.5, + y=0.5, + showarrow=False, + font={"size": 14}, + align="left", + bgcolor="rgba(255,255,255,0.9)", + bordercolor="black", + borderwidth=2, + ) + + summary_fig.update_layout( + title="Summary Statistics", + xaxis={"visible": False}, + yaxis={"visible": False}, + height=300, + ) + + plots.append(plotly.io.to_json(summary_fig)) + + return plots diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_global_explainer.py b/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_global_explainer.py new file mode 100644 index 000000000..42c96390f --- /dev/null +++ b/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_global_explainer.py @@ -0,0 +1,294 @@ +"""Base class for global explainers specialized for forecasting tasks. + +Provides common functionality for explaining forecasting models: +- Timestamp column detection and handling +- Frequency inference and validation +- Exogenous variable management +- Time series data preparation + +All global explainers for forecasting tasks should inherit from this class. +""" + +from abc import abstractmethod +from typing import List, Optional, Tuple + +import pandas as pd +from datasets import DatasetDict + +from DashAI.back.explainability.global_explainer import BaseGlobalExplainer +from DashAI.back.models.base_model import BaseModel + + +class ForecastingGlobalExplainer(BaseGlobalExplainer): + """Base class for global explainers specialized for forecasting. + + Provides common utilities for handling time series data: + - Timestamp column detection + - Frequency inference + - Exogenous variable extraction + - Data validation for forecasting + + Subclasses must implement: + - explain(): Generate the explanation + - plot(): Create visualizations + """ + + # All forecasting explainers are compatible with ForecastingTask + COMPATIBLE_COMPONENTS = ["ForecastingTask"] + + def __init__(self, model: BaseModel, **kwargs): + """Initialize forecasting global explainer. + + Parameters + ---------- + model : BaseModel + Trained forecasting model to explain + **kwargs : dict + Additional parameters passed to parent class + """ + super().__init__(model, **kwargs) + + # Cache for model metadata + self._timestamp_col: Optional[str] = None + self._target_col: Optional[str] = None + self._exog_cols: Optional[List[str]] = None + self._frequency: Optional[str] = None + + def _get_timestamp_column(self) -> Optional[str]: + """Get timestamp column name from model. + + Returns + ------- + str or None + Name of timestamp column, or None if not available + """ + if self._timestamp_col is not None: + return self._timestamp_col + + # Try to get from model + if hasattr(self.model, "timestamp_col"): + self._timestamp_col = getattr(self.model, "timestamp_col", None) + elif hasattr(self.model, "get_column_names"): + try: + col_names = self.model.get_column_names() # type: ignore + self._timestamp_col = col_names.get("timestamp") + except Exception: + pass + + return self._timestamp_col + + def _get_target_column(self) -> Optional[str]: + """Get target column name from model. + + Returns + ------- + str or None + Name of target column, or None if not available + """ + if self._target_col is not None: + return self._target_col + + # Try to get from model + if hasattr(self.model, "target_col"): + self._target_col = getattr(self.model, "target_col", None) + elif hasattr(self.model, "get_column_names"): + try: + col_names = self.model.get_column_names() # type: ignore + self._target_col = col_names.get("target") + except Exception: + pass + + return self._target_col + + def _get_exogenous_columns(self) -> List[str]: + """Get exogenous variable names from model. + + Uses model's interface to get exogenous columns in original format. + + Returns + ------- + List[str] + List of exogenous variable names + """ + if self._exog_cols is not None: + return self._exog_cols + + # Try to get from model using ForecastingModel interface + if hasattr(self.model, "get_exogenous_columns"): + try: + self._exog_cols = self.model.get_exogenous_columns() # type: ignore + return self._exog_cols or [] + except Exception: + pass + + # Fallback: check exog_cols attribute + if hasattr(self.model, "exog_cols"): + self._exog_cols = getattr(self.model, "exog_cols", []) + return self._exog_cols or [] + + return [] + + def _get_frequency(self) -> Optional[str]: + """Get time series frequency from model. + + Returns + ------- + str or None + Frequency string (e.g., 'D', 'H', 'M'), or None if not available + """ + if self._frequency is not None: + return self._frequency + + # Try to get from model + if hasattr(self.model, "frequency"): + self._frequency = getattr(self.model, "frequency", None) + + return self._frequency + + def _extract_timestamps( + self, dataset: DatasetDict, split: str = "test" + ) -> pd.Series: + """Extract timestamp column from dataset. + + Parameters + ---------- + dataset : DatasetDict + Dataset containing time series data + split : str + Which split to extract from (default: "test") + + Returns + ------- + pd.Series + Series with timestamps as datetime + + Raises + ------ + ValueError + If timestamp column not found or cannot be converted + """ + timestamp_col = self._get_timestamp_column() + + if timestamp_col is None: + raise ValueError( + "Cannot determine timestamp column. " + "Model must store timestamp_col or implement get_column_names()" + ) + + if split not in dataset: + raise ValueError(f"Split '{split}' not found in dataset") + + ds = dataset[split] + + if timestamp_col not in ds.column_names: + raise ValueError( + f"Timestamp column '{timestamp_col}' not found in dataset. " + f"Available columns: {ds.column_names}" + ) + + # Convert to pandas Series with datetime + timestamps = pd.to_datetime(ds.to_pandas()[timestamp_col]) + + return timestamps + + def _prepare_dataset_with_timestamps( + self, dataset: DatasetDict, split: str = "test" + ) -> pd.DataFrame: + """Prepare dataset as DataFrame with all required columns. + + Includes timestamp column, exogenous variables, and target (if available). + This is useful when explainers need the full context for predictions. + + Parameters + ---------- + dataset : DatasetDict + Dataset to prepare + split : str + Which split to use (default: "test") + + Returns + ------- + pd.DataFrame + DataFrame with timestamps, exogenous variables, and target + """ + if split not in dataset: + raise ValueError(f"Split '{split}' not found in dataset") + + split_df = dataset[split].to_pandas() + + # Ensure timestamp column is datetime + timestamp_col = self._get_timestamp_column() + if timestamp_col and timestamp_col in split_df.columns: + split_df[timestamp_col] = pd.to_datetime(split_df[timestamp_col]) + + return split_df + + def _validate_has_exogenous_variables(self) -> bool: + """Check if model uses exogenous variables. + + Returns + ------- + bool + True if model has exogenous variables + """ + if hasattr(self.model, "has_exogenous_variables"): + try: + return self.model.has_exogenous_variables() # type: ignore + except Exception: + pass + + # Fallback: check if exog_cols is non-empty + exog_cols = self._get_exogenous_columns() + return len(exog_cols) > 0 + + def _infer_frequency(self, timestamps: pd.Series) -> Optional[str]: + """Infer frequency from timestamp series. + + Parameters + ---------- + timestamps : pd.Series + Series of datetime values + + Returns + ------- + str or None + Inferred frequency (e.g., 'D', 'H'), or None if cannot infer + """ + try: + # Use pandas infer_freq + freq = pd.infer_freq(timestamps) + return freq + except Exception: + # Try getting from model + return self._get_frequency() + + @abstractmethod + def explain(self, dataset: Tuple[DatasetDict, DatasetDict]) -> dict: + """Generate explanation for the forecasting model. + + Parameters + ---------- + dataset : Tuple[DatasetDict, DatasetDict] + Tuple with (input_features, targets) + Note: For forecasting, input_features may need timestamp column + + Returns + ------- + dict + Explanation results + """ + + @abstractmethod + def plot(self, explanation: dict) -> List[dict]: + """Create visualizations for the explanation. + + Parameters + ---------- + explanation : dict + Explanation dictionary from explain() + + Returns + ------- + List[dict] + List of plotly JSON figures + """ diff --git a/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_local_explainer.py b/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_local_explainer.py new file mode 100644 index 000000000..217e50c6c --- /dev/null +++ b/DashAI/back/explainability/explainers/forecasting_explainers/forecasting_local_explainer.py @@ -0,0 +1,336 @@ +"""Base class for local explainers specialized for forecasting tasks. + +Provides common functionality for explaining individual forecasts: +- Instance selection from time series +- Window extraction for point-in-time explanations +- Temporal context management +- Per-forecast explanation generation + +All local explainers for forecasting tasks should inherit from this class. +""" + +from abc import abstractmethod +from typing import List, Optional, Tuple + +import pandas as pd +from datasets import DatasetDict + +from DashAI.back.explainability.local_explainer import BaseLocalExplainer +from DashAI.back.models.base_model import BaseModel + + +class ForecastingLocalExplainer(BaseLocalExplainer): + """Base class for local explainers specialized for forecasting. + + Provides common utilities for explaining individual forecasts: + - Timestamp handling for specific forecast points + - Window extraction (e.g., last N days before forecast) + - Exogenous variable context + - Per-instance explanation generation + + Subclasses must implement: + - fit(): Prepare explainer with training data + - explain_instance(): Generate explanation for specific forecast + - plot(): Create visualizations + """ + + # All forecasting local explainers are compatible with ForecastingTask + COMPATIBLE_COMPONENTS = ["ForecastingTask"] + + def __init__(self, model: BaseModel, **kwargs): + """Initialize forecasting local explainer. + + Parameters + ---------- + model : BaseModel + Trained forecasting model to explain + **kwargs : dict + Additional parameters passed to parent class + """ + super().__init__(model, **kwargs) + + # Cache for model metadata + self._timestamp_col: Optional[str] = None + self._target_col: Optional[str] = None + self._exog_cols: Optional[List[str]] = None + self._frequency: Optional[str] = None + + def _get_timestamp_column(self) -> Optional[str]: + """Get timestamp column name from model. + + Returns + ------- + str or None + Name of timestamp column, or None if not available + """ + if self._timestamp_col is not None: + return self._timestamp_col + + # Try to get from model + if hasattr(self.model, "timestamp_col"): + self._timestamp_col = getattr(self.model, "timestamp_col", None) + elif hasattr(self.model, "get_column_names"): + try: + col_names = self.model.get_column_names() # type: ignore + self._timestamp_col = col_names.get("timestamp") + except Exception: + pass + + return self._timestamp_col + + def _get_target_column(self) -> Optional[str]: + """Get target column name from model. + + Returns + ------- + str or None + Name of target column, or None if not available + """ + if self._target_col is not None: + return self._target_col + + # Try to get from model + if hasattr(self.model, "target_col"): + self._target_col = getattr(self.model, "target_col", None) + elif hasattr(self.model, "get_column_names"): + try: + col_names = self.model.get_column_names() # type: ignore + self._target_col = col_names.get("target") + except Exception: + pass + + return self._target_col + + def _get_exogenous_columns(self) -> List[str]: + """Get exogenous variable names from model. + + Uses model's interface to get exogenous columns in original format. + + Returns + ------- + List[str] + List of exogenous variable names + """ + if self._exog_cols is not None: + return self._exog_cols + + # Try to get from model using ForecastingModel interface + if hasattr(self.model, "get_exogenous_columns"): + try: + self._exog_cols = self.model.get_exogenous_columns() # type: ignore + return self._exog_cols or [] + except Exception: + pass + + # Fallback: check exog_cols attribute + if hasattr(self.model, "exog_cols"): + self._exog_cols = getattr(self.model, "exog_cols", []) + return self._exog_cols or [] + + return [] + + def _get_frequency(self) -> Optional[str]: + """Get time series frequency from model. + + Returns + ------- + str or None + Frequency string (e.g., 'D', 'H', 'M'), or None if not available + """ + if self._frequency is not None: + return self._frequency + + # Try to get from model + if hasattr(self.model, "frequency"): + self._frequency = getattr(self.model, "frequency", None) + + return self._frequency + + def _extract_window( + self, + dataset: DatasetDict, + split: str = "test", + window_size: Optional[int] = None, + end_index: Optional[int] = None, + ) -> pd.DataFrame: + """Extract a window of data for local explanation. + + Useful for explaining a specific forecast by showing the context + (e.g., last 30 days before the forecast point). + + Parameters + ---------- + dataset : DatasetDict + Dataset containing time series data + split : str + Which split to use (default: "test") + window_size : int, optional + Number of time points to include in window + If None, returns all data up to end_index + end_index : int, optional + Last index to include (exclusive) + If None, uses all available data + + Returns + ------- + pd.DataFrame + DataFrame with windowed data + """ + if split not in dataset: + raise ValueError(f"Split '{split}' not found in dataset") + + split_df = dataset[split].to_pandas() + + # Apply end index + if end_index is not None: + split_df = split_df.iloc[:end_index] + + # Apply window size + if window_size is not None and len(split_df) > window_size: + split_df = split_df.iloc[-window_size:] + + # Ensure timestamp column is datetime + timestamp_col = self._get_timestamp_column() + if timestamp_col and timestamp_col in split_df.columns: + split_df[timestamp_col] = pd.to_datetime(split_df[timestamp_col]) + + return split_df + + def _select_instance_by_timestamp( + self, dataset: DatasetDict, timestamp: pd.Timestamp, split: str = "test" + ) -> pd.Series: + """Select a specific instance by timestamp. + + Parameters + ---------- + dataset : DatasetDict + Dataset containing time series data + timestamp : pd.Timestamp + Timestamp of instance to select + split : str + Which split to use (default: "test") + + Returns + ------- + pd.Series + Single row as Series + + Raises + ------ + ValueError + If timestamp not found in dataset + """ + timestamp_col = self._get_timestamp_column() + + if timestamp_col is None: + raise ValueError( + "Cannot select by timestamp: timestamp column not available" + ) + + split_df = dataset[split].to_pandas() + split_df[timestamp_col] = pd.to_datetime(split_df[timestamp_col]) + + mask = split_df[timestamp_col] == timestamp + + if not mask.any(): + raise ValueError(f"Timestamp {timestamp} not found in {split} split") + + return split_df[mask].iloc[0] + + def _prepare_dataset_with_timestamps( + self, dataset: DatasetDict, split: str = "test" + ) -> pd.DataFrame: + """Prepare dataset as DataFrame with all required columns. + + Includes timestamp column, exogenous variables, and target (if available). + + Parameters + ---------- + dataset : DatasetDict + Dataset to prepare + split : str + Which split to use (default: "test") + + Returns + ------- + pd.DataFrame + DataFrame with timestamps, exogenous variables, and target + """ + if split not in dataset: + raise ValueError(f"Split '{split}' not found in dataset") + + split_df = dataset[split].to_pandas() + + # Ensure timestamp column is datetime + timestamp_col = self._get_timestamp_column() + if timestamp_col and timestamp_col in split_df.columns: + split_df[timestamp_col] = pd.to_datetime(split_df[timestamp_col]) + + return split_df + + def _validate_has_exogenous_variables(self) -> bool: + """Check if model uses exogenous variables. + + Returns + ------- + bool + True if model has exogenous variables + """ + if hasattr(self.model, "has_exogenous_variables"): + try: + return self.model.has_exogenous_variables() # type: ignore + except Exception: + pass + + # Fallback: check if exog_cols is non-empty + exog_cols = self._get_exogenous_columns() + return len(exog_cols) > 0 + + @abstractmethod + def fit( + self, dataset: Tuple[DatasetDict, DatasetDict], **fit_params + ) -> "ForecastingLocalExplainer": + """Fit the explainer on training data. + + Parameters + ---------- + dataset : Tuple[DatasetDict, DatasetDict] + Tuple with (input_features, targets) + **fit_params : dict + Additional fitting parameters + + Returns + ------- + ForecastingLocalExplainer + Self + """ + + @abstractmethod + def explain_instance(self, instance: DatasetDict) -> dict: + """Generate explanation for a specific forecast instance. + + Parameters + ---------- + instance : DatasetDict + Single instance or small window to explain + + Returns + ------- + dict + Explanation for this specific instance + """ + + @abstractmethod + def plot(self, explanation: dict) -> List[dict]: + """Create visualizations for the local explanation. + + Parameters + ---------- + explanation : dict + Explanation dictionary from explain_instance() + + Returns + ------- + List[dict] + List of plotly JSON figures + """ diff --git a/DashAI/back/initial_components.py b/DashAI/back/initial_components.py index f8b9f006b..1e2d27091 100644 --- a/DashAI/back/initial_components.py +++ b/DashAI/back/initial_components.py @@ -60,13 +60,32 @@ CharacterReplacer, ) from DashAI.back.converters.simple_converters.column_remover import ColumnRemover + +# Forecasting converters +from DashAI.back.converters.simple_converters.extend_time_series_converter import ( + ExtendTimeSeriesConverter, +) from DashAI.back.converters.simple_converters.nan_remover import NanRemover +from DashAI.back.converters.simple_converters.time_series_window_converter import ( + TimeSeriesWindowConverter, +) # DataLoaders from DashAI.back.dataloaders.classes.csv_dataloader import CSVDataLoader from DashAI.back.dataloaders.classes.excel_dataloader import ExcelDataLoader from DashAI.back.dataloaders.classes.json_dataloader import JSONDataLoader +# Forecasting explainers +from DashAI.back.explainability.explainers.forecasting_explainers.forecast_decomposition import ( # noqa: E501 + ForecastDecomposition, +) +from DashAI.back.explainability.explainers.forecasting_explainers.forecast_feature_importance import ( # noqa: E501 + ForecastFeatureImportance, +) +from DashAI.back.explainability.explainers.forecasting_explainers.forecast_uncertainty import ( # noqa: E501 + ForecastUncertainty, +) + # Explainers from DashAI.back.explainability.explainers.kernel_shap import KernelShap from DashAI.back.explainability.explainers.partial_dependence import PartialDependence @@ -99,6 +118,9 @@ from DashAI.back.job.dataset_job import DatasetJob from DashAI.back.job.explainer_job import ExplainerJob from DashAI.back.job.explorer_job import ExplorerJob + +# Forecasting job +from DashAI.back.job.forecasting_job import ForecastingJob from DashAI.back.job.generative_job import GenerativeJob from DashAI.back.job.model_job import ModelJob from DashAI.back.job.pipeline_job import PipelineJob @@ -113,6 +135,10 @@ from DashAI.back.metrics.classification.precision import Precision from DashAI.back.metrics.classification.recall import Recall from DashAI.back.metrics.classification.roc_auc import ROCAUC + +# Forecasting metrics +from DashAI.back.metrics.forecasting.mape import MAPE +from DashAI.back.metrics.forecasting.smape import SMAPE from DashAI.back.metrics.regression.explained_variance import ExplainedVariance from DashAI.back.metrics.regression.mae import MAE from DashAI.back.metrics.regression.median_absolute_error import MedianAbsoluteError @@ -123,6 +149,18 @@ from DashAI.back.metrics.translation.chrf import Chrf from DashAI.back.metrics.translation.ter import Ter +# Forecasting models +from DashAI.back.models.forecasting.prophet_model import ProphetModel +from DashAI.back.models.forecasting.sklearn_multistep_forecaster import ( + SklearnMultiStepForecaster, +) +from DashAI.back.models.forecasting.statsmodels_arima_model import ( + StatsmodelsARIMAModel, +) +from DashAI.back.models.forecasting.statsmodels_sarimax_model import ( + StatsmodelsSARIMAXModel, +) + # Models from DashAI.back.models.hugging_face.distilbert_transformer import DistilBertTransformer from DashAI.back.models.hugging_face.llama_model import LlamaModel @@ -178,6 +216,9 @@ from DashAI.back.models.scikit_learn.linearSVR import LinearSVR from DashAI.back.models.scikit_learn.logistic_regression import LogisticRegression from DashAI.back.models.scikit_learn.mlp_regression import MLPRegression +from DashAI.back.models.scikit_learn.multi_output_regression import ( + MultiOutputRegression, +) from DashAI.back.models.scikit_learn.random_forest_classifier import ( RandomForestClassifier, ) @@ -203,6 +244,9 @@ # Tasks from DashAI.back.tasks.controlnet_task import ControlNetTask + +# Forecasting task +from DashAI.back.tasks.forecasting_task import ForecastingTask from DashAI.back.tasks.regression_task import RegressionTask from DashAI.back.tasks.tabular_classification_task import TabularClassificationTask from DashAI.back.tasks.text_classification_task import TextClassificationTask @@ -231,6 +275,7 @@ def get_initial_components(): TextClassificationTask, TranslationTask, RegressionTask, + ForecastingTask, TextToImageGenerationTask, TextToTextGenerationTask, ControlNetTask, @@ -259,6 +304,11 @@ def get_initial_components(): SDXLCannyControlNetModel, LogisticRegression, MLPRegression, + MultiOutputRegression, + ProphetModel, + SklearnMultiStepForecaster, + StatsmodelsARIMAModel, + StatsmodelsSARIMAXModel, RandomForestClassifier, RandomForestRegression, DistilBertTransformer, @@ -281,6 +331,8 @@ def get_initial_components(): Chrf, MSE, RMSE, + MAPE, + SMAPE, MAE, R2, MedianAbsoluteError, @@ -296,6 +348,7 @@ def get_initial_components(): ExplainerJob, ModelJob, ExplorerJob, + ForecastingJob, PredictJob, ConverterListJob, DatasetJob, @@ -305,6 +358,9 @@ def get_initial_components(): KernelShap, PartialDependence, PermutationFeatureImportance, + ForecastDecomposition, + ForecastFeatureImportance, + ForecastUncertainty, # Explorers DescribeExplorer, ScatterPlotExplorer, @@ -364,6 +420,8 @@ def get_initial_components(): SMOTEConverter, SMOTEENNConverter, RandomUnderSamplerConverter, + TimeSeriesWindowConverter, + ExtendTimeSeriesConverter, ] # Obtener plugins instalados diff --git a/DashAI/back/job/explainer_job.py b/DashAI/back/job/explainer_job.py index 4863a59f0..e93b1f9c9 100644 --- a/DashAI/back/job/explainer_job.py +++ b/DashAI/back/job/explainer_job.py @@ -416,43 +416,66 @@ def run( ) from e try: splits = json.loads(run.split_indexes) - loaded_dataset = split_dataset( - loaded_dataset, - train_indexes=splits["train_indexes"], - test_indexes=splits["test_indexes"], - val_indexes=splits["val_indexes"], - ) - prepared_dataset = task.prepare_for_task( - dataset=loaded_dataset, - input_columns=self.input_columns, - output_columns=self.output_columns, - ) - data = select_columns( - prepared_dataset, - self.input_columns, - self.output_columns, - ) + if model_session.task_name == "ForecastingTask": + # For forecasting: prepare full dataset BEFORE splitting + # (preserves all data points for temporal integrity) + prepared_dataset = task.prepare_for_task( + dataset=loaded_dataset, + input_columns=self.input_columns, + output_columns=self.output_columns, + ) - data_x = split_dataset( - data[0], - train_indexes=splits["train_indexes"], - test_indexes=splits["test_indexes"], - val_indexes=splits["val_indexes"], - ) - data_y = split_dataset( - data[1], - train_indexes=splits["train_indexes"], - test_indexes=splits["test_indexes"], - val_indexes=splits["val_indexes"], - ) - for split_name in data_x: - data_x[split_name] = trained_model.prepare_dataset( - data_x[split_name], is_fit=False + data = select_columns( + prepared_dataset, + self.input_columns, + self.output_columns, + ) + + data_x = split_dataset( + data[0], + train_indexes=splits["train_indexes"], + test_indexes=splits["test_indexes"], + val_indexes=splits["val_indexes"], + ) + data_y = split_dataset( + data[1], + train_indexes=splits["train_indexes"], + test_indexes=splits["test_indexes"], + val_indexes=splits["val_indexes"], + ) + else: + # For other tasks: standard flow + prepared_dataset = task.prepare_for_task( + dataset=loaded_dataset, + input_columns=self.input_columns, + output_columns=self.output_columns, + ) + data = select_columns( + prepared_dataset, + self.input_columns, + self.output_columns, + ) + + data_x = split_dataset( + data[0], + train_indexes=splits["train_indexes"], + test_indexes=splits["test_indexes"], + val_indexes=splits["val_indexes"], ) - data_y[split_name] = trained_model.prepare_output( - data_y[split_name], is_fit=False + data_y = split_dataset( + data[1], + train_indexes=splits["train_indexes"], + test_indexes=splits["test_indexes"], + val_indexes=splits["val_indexes"], ) + for split_name in data_x: + data_x[split_name] = trained_model.prepare_dataset( + data_x[split_name], is_fit=False + ) + data_y[split_name] = trained_model.prepare_output( + data_y[split_name], is_fit=False + ) except Exception as e: log.exception(e) diff --git a/DashAI/back/job/forecasting_job.py b/DashAI/back/job/forecasting_job.py new file mode 100644 index 000000000..7359c64bc --- /dev/null +++ b/DashAI/back/job/forecasting_job.py @@ -0,0 +1,447 @@ +"""Forecasting-specific job for time series model training.""" + +import gc +import json +import logging +import os +import pickle +from typing import List + +from kink import inject +from sqlalchemy import exc +from sqlalchemy.orm import sessionmaker + +from DashAI.back.core.enums.metrics import LevelEnum, SplitEnum +from DashAI.back.dataloaders.classes.dashai_dataset import ( + DashAIDataset, + load_dataset, + prepare_for_forecasting_experiment, + select_columns, + split_dataset, +) +from DashAI.back.dependencies.database.models import Dataset, ModelSession, Run +from DashAI.back.job.base_job import BaseJob, JobError +from DashAI.back.metrics.base_metric import BaseMetric +from DashAI.back.models.base_model import BaseModel +from DashAI.back.models.model_factory import ModelFactory +from DashAI.back.optimizers.base_optimizer import BaseOptimizer +from DashAI.back.tasks.base_task import BaseTask + +logging.basicConfig(level=logging.DEBUG) +log = logging.getLogger(__name__) + + +class ForecastingJob(BaseJob): + """ForecastingJob class for time series model training with temporal splitting.""" + + @inject + def set_status_as_delivered( + self, session_factory: sessionmaker = lambda di: di["session_factory"] + ) -> None: + """Set the status of the job as delivered.""" + run_id: int = self.kwargs["run_id"] + + with session_factory() as db: + run: Run = db.get(Run, run_id) + if not run: + raise JobError(f"Run {run_id} does not exist in DB.") + try: + run.set_status_as_delivered() + db.commit() + except exc.SQLAlchemyError as e: + log.exception(e) + raise JobError( + "Internal database error", + ) from e + + @inject + def set_status_as_error( + self, session_factory: sessionmaker = lambda di: di["session_factory"] + ) -> None: + """Set the status of the job as error.""" + run_id: int = self.kwargs.get("run_id") + if run_id is None: + return + + with session_factory() as db: + run: Run = db.get(Run, run_id) + if not run: + return + try: + run.set_status_as_error() + db.commit() + except exc.SQLAlchemyError as e: + log.exception(e) + + @inject + def get_job_name(self) -> str: + """Get a descriptive name for the job.""" + run_id = self.kwargs.get("run_id") + if not run_id: + return "Forecasting Training" + + from kink import di + + session_factory = di["session_factory"] + + try: + with session_factory() as db: + run: Run = db.get(Run, run_id) + if run and run.name: + return f"Forecast: {run.name}" + except Exception: + pass + + return f"Forecasting Training ({run_id})" + + @inject + def run(self) -> None: + from kink import di + + component_registry = di["component_registry"] + session_factory = di["session_factory"] + config = di["config"] + + # Get the necessary parameters + run_id: int = self.kwargs["run_id"] + + with session_factory() as db: + run: Run = db.get(Run, run_id) + run.huey_id = self.kwargs.get("huey_id", None) + db.commit() + try: + # Get the model session, dataset, task, metrics and splits + model_session: ModelSession = db.get(ModelSession, run.model_session_id) + if not model_session: + raise JobError( + f"Model session {run.model_session_id} does not exist in DB." + ) + dataset: Dataset = db.get(Dataset, model_session.dataset_id) + if not dataset: + raise JobError( + f"Dataset {model_session.dataset_id} does not exist in DB." + ) + + try: + loaded_dataset: DashAIDataset = load_dataset( + f"{dataset.file_path}/dataset" + ) + except Exception as e: + log.exception(e) + raise JobError( + f"Can not load dataset from path {dataset.file_path}", + ) from e + + try: + task: BaseTask = component_registry[model_session.task_name][ + "class" + ]() + except Exception as e: + log.exception(e) + raise JobError( + ( + f"Unable to find Task with name {model_session.task_name} " + "in registry" + ), + ) from e + + # Validate this is a forecasting task + if model_session.task_name != "ForecastingTask": + raise JobError( + f"ForecastingJob can only be used with ForecastingTask, " + f"got {model_session.task_name}" + ) + + try: + # Get metrics selected in the model session + train_metrics: List[BaseMetric] = [ + component_registry[m]["class"] + for m in model_session.train_metrics + ] + validation_metrics: List[BaseMetric] = [ + component_registry[m]["class"] + for m in model_session.validation_metrics + ] + test_metrics: List[BaseMetric] = [ + component_registry[m]["class"] + for m in model_session.test_metrics + ] + except Exception as e: + log.exception(e) + raise JobError( + "Unable to find metrics associated with" + f"Task {model_session.task_name} in registry", + ) from e + + try: + # Prepare dataset for forecasting task with auto-detection + prepared_dataset = task.prepare_for_task( + loaded_dataset, + outputs_columns=model_session.output_columns, + inputs_columns=model_session.input_columns, + # Optional: Override auto-detection if specified + timestamp_column=getattr( + model_session, "timestamp_column", None + ), + frequency=getattr(model_session, "frequency", "auto"), + ) + + # Get temporal metadata for logging + temporal_metadata = task.get_temporal_metadata() + log.info(f"Temporal metadata: {temporal_metadata}") + + splits = json.loads(model_session.splits) + + # Use forecasting-specific preparation with temporal splitting + prepared_dataset, splits = prepare_for_forecasting_experiment( + dataset=prepared_dataset, + splits=splits, + timestamp_col=temporal_metadata.get("timestamp_col", "ds"), + output_columns=model_session.output_columns, + ) + + run.split_indexes = json.dumps( + { + "train_indexes": splits["train_indexes"], + "test_indexes": splits["test_indexes"], + "val_indexes": splits["val_indexes"], + } + ) + + x, y = select_columns( + prepared_dataset, + model_session.input_columns, + model_session.output_columns, + ) + + x = split_dataset(x) + y = split_dataset(y) + + except Exception as e: + log.exception(e) + raise JobError( + f"""Can not prepare Dataset {dataset.id} + for ForecastingTask {model_session.task_name}""", + ) from e + + try: + run_model_class = component_registry[run.model_name]["class"] + except Exception as e: + log.exception(e) + raise JobError( + f"Unable to find Model with name {run.model_name} in registry.", + ) from e + + # Validate model is compatible with forecasting. + compatible_tasks = getattr(run_model_class, "_compatible_tasks", None) + if compatible_tasks is None: + compatible_tasks = getattr( + run_model_class, "COMPATIBLE_COMPONENTS", None + ) + + if compatible_tasks is None: + log.warning( + f"Model {run.model_name} does not specify task compatibility" + ) + elif "ForecastingTask" not in compatible_tasks: + raise JobError( + f"Model {run.model_name} is not compatible with ForecastingTask" + ) + + try: + factory = ModelFactory( + run_model_class, + run.parameters, + run_id, + x, + y, + train_metrics, + validation_metrics, + test_metrics, + # No n_labels for forecasting tasks + n_labels=None, + ) + model: BaseModel = factory.model + run_optimizable_parameters = factory.optimizable_parameters + + except Exception as e: + log.exception(e) + raise JobError( + f"Unable to instantiate forecasting model using run {run_id}", + ) from e + + # Handle hyperparameter optimization for forecasting + if run_optimizable_parameters: + try: + # Optimizer configuration + run_optimizer_class = component_registry[run.optimizer_name][ + "class" + ] + except Exception as e: + log.exception(e) + raise JobError( + f"Unable to find Optimizer with name " + f"{run.optimizer_name} in registry.", + ) from e + + if run.goal_metric != "": + try: + goal_metric = component_registry[run.goal_metric] + except Exception as e: + log.exception(e) + raise JobError( + "Metric is not compatible with the ForecastingTask", + ) from e + try: + optimizer: BaseOptimizer = run_optimizer_class( + **run.optimizer_parameters + ) + except Exception as e: + log.exception(e) + raise JobError( + ( + "Optimizer parameters not compatible " + "with the optimizer" + ), + ) from e + + try: + run.set_status_as_started() + db.commit() + except exc.SQLAlchemyError as e: + log.exception(e) + raise JobError( + "Connection with the database failed", + ) from e + + try: + # Forecasting model training + plot_paths = [] + if not run_optimizable_parameters: + # Simple fit with forecasting-specific parameters + # Pass temporal metadata to model for column information + if hasattr(model, "fit") and hasattr(model, "_task_type"): + model.fit( + x["train"], + y["train"], + temporal_metadata=temporal_metadata, + ) + else: + model.fit(x["train"], y["train"]) + else: + # Hyperparameter optimization for forecasting + optimizer.optimize( + model, + x, + y, + run_optimizable_parameters, + goal_metric, + model_session.task_name, + ) + model = optimizer.get_model() + # Generate hyperparameter plot + trials = optimizer.get_trials_values() + plot_filenames, plots = optimizer.create_plots( + trials, run_id, n_params=len(run_optimizable_parameters) + ) + for filename, plot in zip(plot_filenames, plots, strict=True): + plot_path = os.path.join(config["RUNS_PATH"], filename) + with open(plot_path, "wb") as file: + pickle.dump(plot, file) + plot_paths.append(plot_path) + + except Exception as e: + log.exception(e) + raise JobError( + "Forecasting model training failed", + ) from e + + try: + paths = plot_paths + [None] * (4 - len(plot_paths)) + ( + run.plot_history_path, + run.plot_slice_path, + run.plot_contour_path, + run.plot_importance_path, + ) = paths[:4] + db.commit() + except Exception as e: + log.exception(e) + raise JobError( + "Hyperparameter plot path saving failed", + ) from e + + try: + run.set_status_as_finished() + db.commit() + except exc.SQLAlchemyError as e: + log.exception(e) + raise JobError( + "Connection with the database failed", + ) from e + + try: + if train_metrics: + model.calculate_metrics( + split=SplitEnum.TRAIN, + level=LevelEnum.LAST, + ) + if validation_metrics: + model.calculate_metrics( + split=SplitEnum.VALIDATION, + level=LevelEnum.LAST, + ) + if test_metrics: + model.calculate_metrics( + split=SplitEnum.TEST, + level=LevelEnum.LAST, + ) + except Exception as e: + log.exception(e) + raise JobError( + "Forecasting metrics calculation failed", + ) from e + + try: + run_path = os.path.join(config["RUNS_PATH"], str(run.id)) + model.save(run_path) + + # Save forecasting-specific artifacts + if hasattr(model, "get_forecast_components"): + try: + # Save forecast components for interpretation + components = model.get_forecast_components(horizon=30) + components_path = os.path.join( + config["RUNS_PATH"], + f"{run.id}_forecast_components.csv", + ) + components.to_csv(components_path, index=False) + log.info(f"Saved forecast components to {components_path}") + except Exception as e: + log.warning(f"Could not save forecast components: {e}") + + except Exception as e: + log.exception(e) + raise JobError( + "Forecasting model saving failed", + ) from e + + try: + run.run_path = run_path + db.commit() + log.info( + f"✅ ForecastingJob completed successfully for run {run_id}" + ) + except exc.SQLAlchemyError as e: + log.exception(e) + run.set_status_as_error() + db.commit() + raise JobError( + "Connection with the database failed", + ) from e + except Exception as e: + run.set_status_as_error() + db.commit() + raise e + finally: + gc.collect() diff --git a/DashAI/back/job/model_job.py b/DashAI/back/job/model_job.py index 3aaff04bc..7cd0b9d1f 100644 --- a/DashAI/back/job/model_job.py +++ b/DashAI/back/job/model_job.py @@ -183,6 +183,11 @@ def run( prepared_dataset, model_session.output_columns[0] ) + # Get temporal metadata for forecasting tasks + temporal_metadata = None + if model_session.task_name == "ForecastingTask": + temporal_metadata = task.get_temporal_metadata() + splits = json.loads(model_session.splits) prepared_dataset, splits = prepare_for_model_session( dataset=prepared_dataset, @@ -275,9 +280,20 @@ def run( # Hyperparameter Tunning plot_paths = [] if not run_optimizable_parameters: - model.train( - x["train"], y["train"], x["validation"], y["validation"] - ) + # Forecasting models use fit() with temporal_metadata + if model_session.task_name == "ForecastingTask": + model.fit( + x["train"], + y["train"], + temporal_metadata=temporal_metadata, + ) + else: + model.train( + x["train"], + y["train"], + x["validation"], + y["validation"], + ) else: optimizer.optimize( model, diff --git a/DashAI/back/job/predict_job.py b/DashAI/back/job/predict_job.py index 54ca2ac31..2a41511dd 100644 --- a/DashAI/back/job/predict_job.py +++ b/DashAI/back/job/predict_job.py @@ -1,7 +1,11 @@ import logging +import math +import uuid from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Tuple +import numpy as np +import pandas as pd from fastapi import status from fastapi.encoders import jsonable_encoder from fastapi.exceptions import HTTPException @@ -18,6 +22,35 @@ from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset + +def sanitize_for_json(value): + """Convert NaN/Inf float values to None for JSON serialization. + + Parameters + ---------- + value : Any + Value to sanitize (can be list, dict, float, etc.) + + Returns + ------- + Any + Sanitized value with NaN/Inf replaced by None + """ + if isinstance(value, dict): + return {k: sanitize_for_json(v) for k, v in value.items()} + elif isinstance(value, list): + return [sanitize_for_json(item) for item in value] + elif isinstance(value, float): + if math.isnan(value) or math.isinf(value): + return None + return value + elif isinstance(value, np.floating): + if np.isnan(value) or np.isinf(value): + return None + return float(value) + return value + + logging.basicConfig(level=logging.DEBUG) log = logging.getLogger(__name__) @@ -255,14 +288,185 @@ def get_job_name(self) -> str: return f"Prediction (Prediction:{prediction_id}, Dataset:{dataset_id})" + def _validate_forecasting_dataset( + self, + dataset: "DashAIDataset", + model_session, + trained_model: Any, + train_dataset: "DashAIDataset" = None, + ) -> str: + """Validate dataset for forecasting prediction. + + Parameters + ---------- + dataset : DashAIDataset + The prediction dataset to validate. + model_session : ModelSession + The model session associated with the prediction. + trained_model : Any + The loaded trained model instance. + train_dataset : DashAIDataset, optional + The training dataset (used for backcasting validation). + + Returns + ------- + str + The name of the detected timestamp column + + Raises + ------ + HTTPException + If dataset is invalid for forecasting + """ + pred_df = dataset.to_pandas() + + # Auto-detect timestamp column (try 'ds' first for compatibility) + timestamp_col = None + if "ds" in pred_df.columns: + timestamp_col = "ds" + else: + for col in pred_df.columns: + try: + pd.to_datetime(pred_df[col]) + timestamp_col = col + log.info(f"Auto-detected timestamp column: '{col}'") + break + except Exception: + continue + + if timestamp_col is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Forecasting prediction requires a timestamp column " + f"(datetime). Available columns: {list(pred_df.columns)}", + ) + + # Parse and validate timestamps + try: + ds_series = pd.to_datetime(pred_df[timestamp_col]) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=(f"Cannot parse '{timestamp_col}' column as datetime: {str(e)}"), + ) from e + + # Check for duplicates + if ds_series.duplicated().any(): + duplicates = ds_series[ds_series.duplicated()].unique()[:5].tolist() + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=( + f"Duplicate timestamps found in '{timestamp_col}' column: " + f"{duplicates}" + ), + ) + + # Check monotonicity (strictly increasing) + if not ds_series.is_monotonic_increasing: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=( + f"Timestamps in '{timestamp_col}' column must be strictly " + "increasing (sorted)." + ), + ) + + # Get training metadata from model + train_frequency = getattr(trained_model, "frequency", None) + train_last_ds = getattr(trained_model, "last_ds", None) + exog_cols = getattr(trained_model, "exog_cols", []) + + log.info( + f"Training metadata - frequency: {train_frequency}, " + f"last_ds: {train_last_ds}, exog_cols: {exog_cols}" + ) + + # Validate frequency consistency (if available) + if train_frequency and len(ds_series) >= 2: + try: + inferred_freq = pd.infer_freq(ds_series) + if inferred_freq and inferred_freq != train_frequency: + log.warning( + f"Frequency mismatch: training={train_frequency}, " + f"prediction={inferred_freq}" + ) + except Exception: + log.warning("Could not infer frequency from prediction dataset") + + # Check for backcasting (dates before training start) + if train_last_ds and train_dataset is not None: + try: + train_df = train_dataset.to_pandas() + + # Auto-detect timestamp in training data + train_timestamp_col = None + if "ds" in train_df.columns: + train_timestamp_col = "ds" + else: + for col in train_df.columns: + try: + pd.to_datetime(train_df[col]) + train_timestamp_col = col + break + except Exception: + continue + + if train_timestamp_col: + train_ds_series = pd.to_datetime(train_df[train_timestamp_col]) + train_start = train_ds_series.min() + + min_pred_ds = ds_series.min() + if min_pred_ds < train_start: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=( + f"Requested timestamps precede the training " + f"window start (train_start = {train_start}). " + f"Retrain the model including those dates or " + f"submit only in-sample/future dates." + ), + ) + except HTTPException: + raise + except Exception as e: + log.warning(f"Could not validate backcasting: {e}") + + # Validate exogenous regressors + if exog_cols: + missing_exog = [col for col in exog_cols if col not in pred_df.columns] + if missing_exog: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=( + f"Missing required exogenous columns for prediction: " + f"{missing_exog}. The model was trained with these " + f"regressors and requires values for all prediction " + f"timestamps." + ), + ) + + # Check for NaN values in exogenous columns + for col in exog_cols: + if pred_df[col].isna().any(): + nan_count = pred_df[col].isna().sum() + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=( + f"Exogenous column '{col}' contains {nan_count} " + f"missing values. All exogenous regressors must have " + f"values for every timestamp." + ), + ) + + log.info(f"Forecasting validation passed for {len(ds_series)} timestamps") + return timestamp_col + @inject def run( self, - ) -> List[Any]: - import uuid - from pathlib import Path - + ) -> None: from DashAI.back.dataloaders.classes.dashai_dataset import ( + get_arrow_table, load_dataset, save_dataset, to_dashai_dataset, @@ -273,7 +477,8 @@ def run( config = di["config"] prediction_id: int = self.kwargs["prediction_id"] - manual_input_data: List[dict] = self.kwargs.get("manual_input_data", []) + manual_input_data: list = self.kwargs.get("manual_input_data", []) + forecast_periods = self.kwargs.get("forecast_periods") with session_factory() as db: try: @@ -292,12 +497,17 @@ def run( dataset_id = prediction.dataset_id - # Validate input data - if not manual_input_data and not dataset_id: + # Validate input data (forecast_periods also valid for forecasting) + if ( + not manual_input_data + and not dataset_id + and forecast_periods is None + ): prediction.set_status_as_error() db.commit() raise JobError( - "Either dataset_id or manual_input_data must be provided." + "Either dataset_id, manual_input_data, or " + "forecast_periods must be provided." ) # Retrieve Model Session @@ -370,7 +580,7 @@ def run( ) from e try: - trained_model: BaseModel = model.load(prediction.run.run_path) + trained_model: BaseModel = model().load(prediction.run.run_path) except Exception as e: prediction.set_status_as_error() db.commit() @@ -393,70 +603,318 @@ def run( f"{dataset_trained.file_path}/dataset/" ) from e - try: - # Load or create prediction dataset - if dataset_id: - loaded_dataset: "DashAIDataset" = load_dataset( - str(Path(f"{dataset.file_path}/dataset/")) + # Determine if this is a forecasting task + is_forecasting = model_session.task_name == "ForecastingTask" + timestamp_col = "ds" # default; overwritten inside forecasting block + + if is_forecasting: + # ============ FORECASTING PREDICTION ============ + try: + if forecast_periods is not None: + # --- Auto-generate future timestamps --- + log.info( + f"Auto-generating {forecast_periods} future timestamps" + ) + + timestamp_col = "ds" + frequency = getattr(trained_model, "frequency", "D") + if frequency is None: + frequency = "D" + + # Get last training date from the FULL dataset. + # trained_model.last_ds only stores the end of the + # training split, not the end of the full dataset. + last_ds = None + try: + train_df = get_arrow_table(train_dataset).to_pandas() + if "ds" in train_df.columns: + last_ds = pd.to_datetime(train_df["ds"]).max() + timestamp_col = "ds" + else: + for col in train_df.columns: + try: + ds_series = pd.to_datetime(train_df[col]) + last_ds = ds_series.max() + timestamp_col = col + break + except Exception: + continue + except Exception as e: + log.warning(f"Could not read training dataset: {e}") + + # Fall back to model attribute if dataset read fails + if last_ds is None: + last_ds = getattr(trained_model, "last_ds", None) + if last_ds is None: + last_ds = getattr(trained_model, "last_timestamp", None) + + if last_ds is None: + raise JobError( + "Cannot auto-generate timestamps: Unable to " + "determine the last training date. " + "Please use a dataset instead." + ) + + last_training_date = pd.to_datetime(last_ds) + + # Check exogenous regressors + exog_cols = getattr(trained_model, "exog_cols", []) + if exog_cols: + raise HTTPException( + status_code=(status.HTTP_422_UNPROCESSABLE_ENTITY), + detail=( + f"Cannot auto-generate predictions for " + f"models with exogenous variables " + f"({exog_cols}). Please upload a dataset " + f"with timestamps and exogenous values." + ), + ) + + # Generate future timestamps using DateOffset + freq_offset_map = { + "D": pd.DateOffset(days=1), + "H": pd.DateOffset(hours=1), + "W": pd.DateOffset(weeks=1), + "M": pd.DateOffset(months=1), + "MS": pd.DateOffset(months=1), + "ME": pd.DateOffset(months=1), + "Y": pd.DateOffset(years=1), + "YS": pd.DateOffset(years=1), + "YE": pd.DateOffset(years=1), + "A": pd.DateOffset(years=1), + "AS": pd.DateOffset(years=1), + "Q": pd.DateOffset(months=3), + "QS": pd.DateOffset(months=3), + "QE": pd.DateOffset(months=3), + } + + first_offset = freq_offset_map.get(frequency) + if first_offset is None: + try: + first_offset = pd.Timedelta(1, unit=frequency[0]) + except ValueError: + log.warning( + "Unknown frequency '%s', defaulting to 1 day", + frequency, + ) + first_offset = pd.DateOffset(days=1) + + start_date = last_training_date + first_offset + + freq_alias_map = { + "M": "MS", + "Y": "YS", + "A": "YS", + "Q": "QS", + "ME": "MS", + "YE": "YS", + "QE": "QS", + } + safe_freq = freq_alias_map.get(frequency, frequency) + + future_dates = pd.date_range( + start=start_date, + periods=forecast_periods, + freq=safe_freq, + ) + future_df = pd.DataFrame({timestamp_col: future_dates}) + + log.info( + f"Generated timestamps from {future_dates[0]} " + f"to {future_dates[-1]}" + ) + + else: + # --- Use uploaded dataset for forecasting --- + if dataset_id: + loaded_dataset = load_dataset( + str(Path(f"{dataset.file_path}/dataset/")) + ) + elif manual_input_data: + dataset_trained_path = str( + Path(f"{dataset_trained.file_path}/dataset/") + ) + loaded_dataset = task.process_manual_input( + manual_input_data, dataset_trained_path + ) + else: + raise JobError( + "Either dataset_id, manual_input_data, or " + "forecast_periods must be provided for " + "forecasting." + ) + + # Validate forecasting dataset + timestamp_col = self._validate_forecasting_dataset( + loaded_dataset, + model_session, + trained_model, + train_dataset, + ) + + pred_df = loaded_dataset.to_pandas() + exog_cols = getattr(trained_model, "exog_cols", []) + future_cols = [timestamp_col] + exog_cols + available_cols = [ + col for col in future_cols if col in pred_df.columns + ] + + if timestamp_col not in available_cols: + raise JobError( + f"Forecasting prediction requires " + f"'{timestamp_col}' column in dataset" + ) + + future_df = pred_df[available_cols].copy() + future_df[timestamp_col] = pd.to_datetime( + future_df[timestamp_col] + ) + + # Call model.predict with the future_df + log.info(f"Predicting on {len(future_df)} timestamps") + predictions = trained_model.predict(future_df) + + # Handle different prediction formats + if hasattr(predictions, "yhat"): + y_pred = predictions["yhat"].to_numpy() + elif isinstance(predictions, np.ndarray): + y_pred = predictions + else: + y_pred = np.array(predictions) + + # Build result dataset: timestamp + prediction + output_col = ( + model_session.output_columns[0] + if model_session.output_columns + else "prediction" ) - else: - dataset_trained_path = str( - Path(f"{dataset_trained.file_path}/dataset/") + result_df = future_df[[timestamp_col]].copy() + result_df[output_col] = y_pred + + # Add confidence intervals if available + if hasattr(predictions, "columns"): + if "yhat_lower" in predictions.columns: + result_df["yhat_lower"] = predictions[ + "yhat_lower" + ].to_numpy() + if "yhat_upper" in predictions.columns: + result_df["yhat_upper"] = predictions[ + "yhat_upper" + ].to_numpy() + + # Convert timestamp column to string so DashAI stores it + # consistently (Arrow datetime types lack DashAI metadata). + if pd.api.types.is_datetime64_any_dtype(result_df[timestamp_col]): + result_df[timestamp_col] = result_df[timestamp_col].dt.strftime( + "%Y-%m-%d" + ) + + dataset_with_prediction = to_dashai_dataset(result_df) + + except HTTPException: + prediction.set_status_as_error() + db.commit() + raise + except ValueError as ve: + prediction.set_status_as_error() + db.commit() + log.error(f"Validation Error: {ve}") + raise HTTPException( + status_code=400, + detail=f"Invalid input data: {str(ve)}", + ) from ve + except Exception as e: + prediction.set_status_as_error() + db.commit() + log.error(e) + raise JobError( + "Forecasting prediction failed", + ) from e + + else: + # ============ STANDARD PREDICTION ============ + try: + # Load or create prediction dataset + if dataset_id: + loaded_dataset: DashAIDataset = load_dataset( + str(Path(f"{dataset.file_path}/dataset/")) + ) + else: + dataset_trained_path = str( + Path(f"{dataset_trained.file_path}/dataset/") + ) + loaded_dataset = task.process_manual_input( + manual_input_data, dataset_trained_path + ) + + # Select input columns and make prediction + prepared_dataset = loaded_dataset.select_columns( + model_session.input_columns ) - loaded_dataset = task.process_manual_input( - manual_input_data, dataset_trained_path + y_pred_proba = np.array(trained_model.predict(prepared_dataset)) + + # Process predictions (convert to labels for classification) + y_pred = task.process_predictions( + train_dataset, + y_pred_proba, + model_session.output_columns[0], ) - prepared_dataset, y_pred = _run_prediction_pipeline( - task=task, - trained_model=trained_model, - train_dataset=train_dataset, - loaded_dataset=loaded_dataset, - model_session=model_session, - ) + # Build dataset with predictions + dataset_with_prediction = to_dashai_dataset( + prepared_dataset.add_column( + model_session.output_columns[0], y_pred + ) + ) - except ValueError as ve: - prediction.set_status_as_error() - db.commit() - log.error(f"Validation Error: {ve}") - raise HTTPException( - status_code=400, - detail=f"Invalid input data: {str(ve)}", - ) from ve - except TypeError as te: - log.error(f"Type Error: {te}") - raise HTTPException( - status_code=400, - detail=f"Type validation failed: {str(te)}", - ) from te - except Exception as e: - prediction.set_status_as_error() - db.commit() - log.error(e) - raise JobError( - "Model prediction failed", - ) from e + except ValueError as ve: + prediction.set_status_as_error() + db.commit() + log.error(f"Validation Error: {ve}") + raise HTTPException( + status_code=400, + detail=f"Invalid input data: {str(ve)}", + ) from ve + except TypeError as te: + log.error(f"Type Error: {te}") + raise HTTPException( + status_code=400, + detail=f"Type validation failed: {str(te)}", + ) from te + except Exception as e: + prediction.set_status_as_error() + db.commit() + log.error(e) + raise JobError( + "Model prediction failed", + ) from e - # Save Predictions to Arrow file + # ============ SAVE PREDICTIONS TO ARROW ============ try: - # Create unique folder for predictions path = str(Path(f"{config['DATASETS_PATH']}/predictions/")) folder_name = str(uuid.uuid4()) full_path = Path(path) / folder_name full_path.mkdir(parents=True, exist_ok=True) - # Add predictions to loaded dataset - dataset_with_prediction = to_dashai_dataset( - prepared_dataset.add_column(model_session.output_columns[0], y_pred) - ) - - # Filter schema from trained dataset - trained_schema = train_dataset.types - filtered_schema = { - key: value.to_string() - for key, value in trained_schema.items() - if key in model_session.input_columns + model_session.output_columns - } + # Build schema for the saved dataset + if is_forecasting: + # Build a proper schema so the Arrow file has DashAI type + # metadata. Empty schema {} causes transform_dataset_with_schema + # to produce an empty table with no metadata → 404 on read. + filtered_schema = { + col: {"dtype": "string"} + if col == timestamp_col + else {"dtype": "float64"} + for col in dataset_with_prediction.column_names + } + else: + trained_schema = train_dataset.types + filtered_schema = { + key: value.to_string() + for key, value in trained_schema.items() + if key + in model_session.input_columns + model_session.output_columns + } # Store num of rows, columns, and column names dataset_with_prediction.compute_base_metadata() @@ -472,10 +930,11 @@ def run( prediction.results_path = str(full_path) prediction.set_status_as_finished() db.commit() + except Exception as e: prediction.set_status_as_error() db.commit() log.exception(e) raise JobError( - "Can not save prediction to json file", + "Cannot save prediction results", ) from e diff --git a/DashAI/back/metrics/base_metric.py b/DashAI/back/metrics/base_metric.py index f43290a4f..a99836de5 100644 --- a/DashAI/back/metrics/base_metric.py +++ b/DashAI/back/metrics/base_metric.py @@ -9,7 +9,18 @@ class BaseMetric: - """Abstract class of all metrics.""" + """Abstract class of all metrics. + + Attributes + ---------- + HIGHER_IS_BETTER : bool + Indicates the optimization direction for this metric. + - True: Higher values are better (e.g., Accuracy, F1) + - False: Lower values are better (e.g., MAE, RMSE, SMAPE) + + This attribute is used by hyperparameter optimizers to determine + whether to maximize or minimize the metric during optimization. + """ TYPE: Final[str] = "Metric" MAXIMIZE: Final[bool] = False diff --git a/DashAI/back/metrics/classification/accuracy.py b/DashAI/back/metrics/classification/accuracy.py index 05dcf78eb..b669a3862 100644 --- a/DashAI/back/metrics/classification/accuracy.py +++ b/DashAI/back/metrics/classification/accuracy.py @@ -12,7 +12,12 @@ class Accuracy(ClassificationMetric): - """Accuracy metric to classification tasks.""" + """Accuracy metric to classification tasks. + + Higher accuracy values are better (range: 0.0 to 1.0). + """ + + HIGHER_IS_BETTER = True DESCRIPTION: str = ( "Proportion of correct predictions over all samples, " diff --git a/DashAI/back/metrics/classification/f1.py b/DashAI/back/metrics/classification/f1.py index 9cf4102f2..6935ca601 100644 --- a/DashAI/back/metrics/classification/f1.py +++ b/DashAI/back/metrics/classification/f1.py @@ -12,7 +12,12 @@ class F1(ClassificationMetric): - """F1 score to classification tasks.""" + """F1 score to classification tasks. + + Higher F1 values are better (range: 0.0 to 1.0). + """ + + HIGHER_IS_BETTER = True DESCRIPTION: str = ( "Harmonic mean of precision and recall, " diff --git a/DashAI/back/metrics/classification/precision.py b/DashAI/back/metrics/classification/precision.py index e7f5be067..2d70f7fd4 100644 --- a/DashAI/back/metrics/classification/precision.py +++ b/DashAI/back/metrics/classification/precision.py @@ -12,7 +12,12 @@ class Precision(ClassificationMetric): - """Precision metric to classification tasks.""" + """Precision metric to classification tasks. + + Higher precision values are better (range: 0.0 to 1.0). + """ + + HIGHER_IS_BETTER = True DESCRIPTION: str = ( "Fraction of predicted positives that are correct, " diff --git a/DashAI/back/metrics/classification/recall.py b/DashAI/back/metrics/classification/recall.py index a472909ba..0081eb267 100644 --- a/DashAI/back/metrics/classification/recall.py +++ b/DashAI/back/metrics/classification/recall.py @@ -12,7 +12,12 @@ class Recall(ClassificationMetric): - """Recall metric to classification tasks.""" + """Recall metric to classification tasks. + + Higher recall values are better (range: 0.0 to 1.0). + """ + + HIGHER_IS_BETTER = True DESCRIPTION: str = ( "Fraction of actual positives correctly identified, " diff --git a/DashAI/back/metrics/forecasting/__init__.py b/DashAI/back/metrics/forecasting/__init__.py new file mode 100644 index 000000000..d6dbff035 --- /dev/null +++ b/DashAI/back/metrics/forecasting/__init__.py @@ -0,0 +1,9 @@ +"""Forecasting metrics for time series evaluation.""" + +from .mape import MAPE +from .smape import SMAPE + +__all__ = [ + "MAPE", + "SMAPE", +] diff --git a/DashAI/back/metrics/forecasting/mape.py b/DashAI/back/metrics/forecasting/mape.py new file mode 100644 index 000000000..062af18ee --- /dev/null +++ b/DashAI/back/metrics/forecasting/mape.py @@ -0,0 +1,55 @@ +"""Mean Absolute Percentage Error (MAPE) metric for forecasting.""" + +import numpy as np + +from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset +from DashAI.back.metrics.regression_metric import RegressionMetric, prepare_to_metric + + +class MAPE(RegressionMetric): + """Mean Absolute Percentage Error metric for forecasting tasks. + + MAPE measures the average absolute percentage difference between + predicted and actual values. It's scale-independent and easy to interpret. + + MAPE = (1/n) * Σ|((y_true - y_pred) / y_true)| * 100 + + Note: MAPE can be problematic when true values are close to zero. + """ + + COMPATIBLE_COMPONENTS = [ + "RegressionTask", + "ForecastingTask", + ] + + @staticmethod + def score(true_values: DashAIDataset, predicted_values: np.ndarray) -> float: + """Calculate MAPE between true values and predicted values. + + Parameters + ---------- + true_values : DashAIDataset + A DashAI dataset with true values. + predicted_values : np.ndarray + Array with the predicted values for each instance. + + Returns + ------- + float + MAPE score as percentage (0-100, lower is better) + """ + true_values, pred_values = prepare_to_metric(true_values, predicted_values) + + # Handle zero values in denominator + mask = np.abs(true_values) > 1e-8 # Avoid division by very small numbers + + if not np.any(mask): + # All true values are essentially zero + return 0.0 if np.allclose(pred_values, 0) else 100.0 + + # Calculate MAPE only for non-zero true values + mape_values = np.abs( + (true_values[mask] - pred_values[mask]) / true_values[mask] + ) + + return float(np.mean(mape_values) * 100) diff --git a/DashAI/back/metrics/forecasting/smape.py b/DashAI/back/metrics/forecasting/smape.py new file mode 100644 index 000000000..414b76a10 --- /dev/null +++ b/DashAI/back/metrics/forecasting/smape.py @@ -0,0 +1,56 @@ +"""Symmetric Mean Absolute Percentage Error (sMAPE) metric for forecasting.""" + +import numpy as np + +from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset +from DashAI.back.metrics.regression_metric import RegressionMetric, prepare_to_metric + + +class SMAPE(RegressionMetric): + """Symmetric Mean Absolute Percentage Error metric for forecasting tasks. + + sMAPE is a more stable version of MAPE that handles zero values better + by using the average of actual and predicted values in the denominator. + + sMAPE = (2/n) * Σ|(y_true - y_pred)| / (|y_true| + |y_pred|) * 100 + + sMAPE is bounded between 0% and 200%, making it more stable than MAPE. + """ + + COMPATIBLE_COMPONENTS = [ + "RegressionTask", + "ForecastingTask", + ] + + @staticmethod + def score(true_values: DashAIDataset, predicted_values: np.ndarray) -> float: + """Calculate sMAPE between true values and predicted values. + + Parameters + ---------- + true_values : DashAIDataset + A DashAI dataset with true values. + predicted_values : np.ndarray + Array with the predicted values for each instance. + + Returns + ------- + float + sMAPE score as percentage (0-200, lower is better) + """ + true_values, pred_values = prepare_to_metric(true_values, predicted_values) + + # Calculate symmetric denominator + denominator = np.abs(true_values) + np.abs(pred_values) + + # Handle zero denominator (both actual and predicted are zero) + mask = denominator > 1e-8 + + if not np.any(mask): + # All values are essentially zero + return 0.0 + + # Calculate sMAPE + smape_values = np.abs(true_values[mask] - pred_values[mask]) / denominator[mask] + + return float(np.mean(smape_values) * 200) diff --git a/DashAI/back/metrics/regression_metric.py b/DashAI/back/metrics/regression_metric.py index 08353d6a5..0eb872125 100644 --- a/DashAI/back/metrics/regression_metric.py +++ b/DashAI/back/metrics/regression_metric.py @@ -39,4 +39,14 @@ def prepare_to_metric( true_values = np.array(y.to_pandas().to_numpy().flatten()) pred_values = np.array(y_pred).flatten() + # Filter out NaN values (common in forecasting with lag features) + valid_mask = ~(np.isnan(true_values) | np.isnan(pred_values)) + n_nan = np.sum(~valid_mask) + if n_nan > 0: + true_values = true_values[valid_mask] + pred_values = pred_values[valid_mask] + + if len(true_values) == 0: + raise ValueError("All values are NaN after filtering. Cannot compute metrics.") + return true_values, pred_values diff --git a/DashAI/back/models/forecasting/__init__.py b/DashAI/back/models/forecasting/__init__.py new file mode 100644 index 000000000..a1e826587 --- /dev/null +++ b/DashAI/back/models/forecasting/__init__.py @@ -0,0 +1,15 @@ +"""Forecasting models for time series prediction.""" + +from .base_forecasting_model import ForecastingModel +from .prophet_model import ProphetModel +from .sklearn_multistep_forecaster import SklearnMultiStepForecaster +from .statsmodels_arima_model import StatsmodelsARIMAModel +from .statsmodels_sarimax_model import StatsmodelsSARIMAXModel + +__all__ = [ + "ForecastingModel", + "ProphetModel", + "SklearnMultiStepForecaster", + "StatsmodelsARIMAModel", + "StatsmodelsSARIMAXModel", +] diff --git a/DashAI/back/models/forecasting/base_forecasting_model.py b/DashAI/back/models/forecasting/base_forecasting_model.py new file mode 100644 index 000000000..bfa5d5af9 --- /dev/null +++ b/DashAI/back/models/forecasting/base_forecasting_model.py @@ -0,0 +1,329 @@ +"""Forecasting Model abstract class. + +This module defines the common interface for all forecasting models in DashAI. +It ensures model-agnostic handling of time series data and exogenous variables. +""" + +import warnings +from abc import abstractmethod +from typing import List, Optional + +import pandas as pd + +from DashAI.back.models.base_model import BaseModel + + +class ForecastingModel(BaseModel): + """Abstract class for all forecasting models. + + This class defines the common interface that all forecasting models must implement. + It handles: + - Exogenous variables (external regressors) in a model-agnostic way + - Timestamp and target column management + - Prediction interface for both in-sample and out-of-sample forecasting + + Key Attributes + -------------- + exog_cols : List[str] + List of exogenous variable column names used during training. + These are stored in their ORIGINAL names from the dataset, + not in any model-specific format. + + timestamp_col : Optional[str] + Name of the timestamp/datetime column in the original dataset. + + target_col : Optional[str] + Name of the target variable column in the original dataset. + + Philosophy + ---------- + This class maintains column names in their ORIGINAL format from the user's + dataset. Each specific model implementation (Prophet, ARIMA, etc.) is + responsible for: + 1. Internally converting column names to its required format + (e.g., Prophet needs 'ds'/'y') + 2. Converting predictions back to match original column names + 3. Handling model-specific requirements transparently + + This ensures the system is agnostic to each model's internal conventions. + + Note + ---- + This class inherits TYPE = "Model" from BaseModel. The name "ForecastingModel" + (without "Base" prefix) avoids conflicts with the component registry system + which looks for classes with "Base" in their name. + """ + + _compatible_tasks = ["ForecastingTask"] + + def __init__(self, **kwargs): + """Initialize forecasting model. + + Sets up common attributes that all forecasting models should maintain. + + Parameters + ---------- + **kwargs + Additional arguments passed to BaseModel.__init__ + """ + super().__init__(**kwargs) + + # Store exogenous variable names in ORIGINAL format + self.exog_cols: List[str] = [] + + # Store column names for reference + self.timestamp_col: Optional[str] = None + self.target_col: Optional[str] = None + + @abstractmethod + def fit(self, x: pd.DataFrame, y: pd.DataFrame, **kwargs) -> "ForecastingModel": + """Train the forecasting model. + + Parameters + ---------- + x : pd.DataFrame + Training features including: + - Timestamp column (datetime) + - Exogenous variables (optional) + May also include the target column (will be used from there if present) + + y : pd.DataFrame + Target variable values. + Single column with the variable to forecast. + + **kwargs + Additional model-specific parameters. + + Returns + ------- + self : ForecastingModel + Returns self for method chaining. + + Notes + ----- + Implementations should: + 1. Auto-detect timestamp column (try pd.to_datetime on columns) + 2. Filter exogenous variables (numeric only, exclude timestamp/target) + 3. Store original column names in self.exog_cols, self.timestamp_col, + self.target_col + 4. Internally convert to model-specific format if needed + """ + raise NotImplementedError + + @abstractmethod + def predict( + self, + x_pred: Optional[pd.DataFrame] = None, + periods: Optional[int] = None, + exog_future: Optional[pd.DataFrame] = None, + **kwargs, + ) -> pd.DataFrame: + """Generate forecasts. + + Supports two prediction modes: + 1. In-sample: Provide x_pred with timestamps and exogenous values + 2. Out-of-sample: Provide periods and exog_future for future forecasting + + Parameters + ---------- + x_pred : pd.DataFrame, optional + Input data for in-sample predictions containing: + - Timestamp column + - Exogenous variables (if model uses them) + + periods : int, optional + Number of future periods to forecast (out-of-sample mode). + + exog_future : pd.DataFrame, optional + Future values of exogenous variables for out-of-sample forecasting. + Must contain all columns in self.exog_cols. + + **kwargs + Additional model-specific parameters. + + Returns + ------- + pd.DataFrame or np.ndarray + Predictions with columns using ORIGINAL names: + - Timestamp column (same name as training data) + - Target column (same name as training data) + - Optionally: prediction intervals, components, etc. + + Notes + ----- + Implementations MUST support both prediction modes: + 1. In-sample predictions (x_pred provided): For calculating metrics on + train/validation/test splits + 2. Out-of-sample predictions (periods provided): For future forecasting + + Implementations should: + 1. Auto-detect timestamp column in x_pred (handle both original name and 'ds') + 2. Validate exogenous variables are present if model requires them + 3. Return predictions with ORIGINAL column names (not model-specific names) + + IMPORTANT: Do NOT raise NotImplementedError for in-sample predictions. + Model evaluation (metrics calculation) requires in-sample predictions. + """ + raise NotImplementedError + + def train( + self, + x_train, + y_train, + x_validation=None, + y_validation=None, + **kwargs, + ) -> "ForecastingModel": + """Compatibility wrapper for the generic DashAI model contract. + + Forecasting jobs train models via ``fit()`` so they can pass + ``temporal_metadata`` and other forecasting-specific arguments. This + wrapper keeps forecasting models instantiable through ``ModelFactory``, + which still expects every model to provide a concrete ``train()`` method. + """ + if x_validation is not None or y_validation is not None: + warnings.warn( + "ForecastingModel.train() ignores validation datasets. " + "Forecasting models should be trained via fit() with the " + "appropriate temporal metadata.", + UserWarning, + stacklevel=2, + ) + + return self.fit(x_train, y_train, **kwargs) + + def get_exogenous_columns(self) -> List[str]: + """Get list of exogenous variable names in original format. + + Returns + ------- + List[str] + List of exogenous variable column names as they appear in the + original dataset (not in model-specific format). + + Examples + -------- + >>> model.fit(x_train, y_train) + >>> model.get_exogenous_columns() + ['temperature', 'humidity', 'wind_speed'] + # NOT ['exog_temperature', 'exog_humidity', 'exog_wind_speed'] + # NOT ['extra_regressor_1', 'extra_regressor_2', 'extra_regressor_3'] + """ + return self.exog_cols.copy() + + def has_exogenous_variables(self) -> bool: + """Check if model uses exogenous variables. + + Returns + ------- + bool + True if model was trained with exogenous variables, False otherwise. + """ + return len(self.exog_cols) > 0 + + def get_column_names(self) -> dict: + """Get all relevant column names in original format. + + Returns + ------- + dict + Dictionary with keys: + - 'timestamp': Timestamp column name + - 'target': Target column name + - 'exogenous': List of exogenous variable names + """ + return { + "timestamp": self.timestamp_col, + "target": self.target_col, + "exogenous": self.exog_cols.copy(), + } + + def _get_seasonal_period(self) -> int: + """Infer seasonal period from the model's stored frequency. + + Returns the number of observations per seasonal cycle, used to + configure STL decomposition in ``get_forecast_components()``. + + Returns + ------- + int + Seasonal period (e.g., 7 for daily data → weekly cycle). + """ + if not self.frequency: + return 7 # default: weekly cycle for daily data + + freq = self.frequency.upper().strip() + + if freq.startswith(("T", "MIN")): + return 60 # minutely → hourly seasonality + if freq.startswith("H"): + return 24 # hourly → daily seasonality + if freq.startswith("D"): + return 7 # daily → weekly seasonality + if freq.startswith("W"): + return 52 # weekly → yearly seasonality + if freq.startswith(("M", "ME", "MS")): + return 12 # monthly → yearly seasonality + if freq.startswith(("Q", "QE", "QS")): + return 4 # quarterly → yearly seasonality + if freq.startswith(("A", "Y", "AE", "AS", "YS", "YE")): + return 1 # yearly → no sub-annual seasonality + + return 7 # fallback + + def _period_to_seasonality_name(self, period: int) -> str: + """Map a seasonal period integer to a human-readable component name. + + Parameters + ---------- + period : int + Number of observations per seasonal cycle. + + Returns + ------- + str + Name for the seasonal component column (e.g., 'weekly', 'yearly'). + """ + mapping = { + 60: "hourly", + 24: "daily", + 7: "weekly", + 52: "yearly", + 12: "yearly", + 4: "yearly", + 365: "yearly", + } + return mapping.get(period, "seasonal") + + def _validate_predict_implementation(self) -> None: + """Validate that subclass implements predict() correctly. + + This method can be called in tests to ensure implementations support + both in-sample and out-of-sample predictions. + + Raises + ------ + NotImplementedError + If predict() raises NotImplementedError for in-sample predictions + ValueError + If predict() doesn't handle both prediction modes + + Notes + ----- + This is a helper for testing - not called automatically during runtime. + Developers should call this in unit tests for their forecasting models. + + Example + ------- + >>> # In test_my_model.py + >>> model = MyForecastingModel() + >>> model.fit(x_train, y_train) + >>> model._validate_predict_implementation() # Ensures correct implementation + """ + warnings.warn( + "ForecastingModel.predict() must support both in-sample (x_pred) " + "and out-of-sample (periods) prediction modes. " + "In-sample predictions are required for metrics calculation.", + UserWarning, + stacklevel=2, + ) diff --git a/DashAI/back/models/forecasting/prophet_model.py b/DashAI/back/models/forecasting/prophet_model.py new file mode 100644 index 000000000..94b9a7c9b --- /dev/null +++ b/DashAI/back/models/forecasting/prophet_model.py @@ -0,0 +1,1029 @@ +"""Prophet model wrapper for DashAI forecasting. + +This model wraps Facebook Prophet for native time series forecasting +with automatic seasonality detection and holiday effects. +""" + +import os +import pickle +from typing import Any, Optional, Union + +import numpy as np +import pandas as pd + +from DashAI.back.core.schema_fields import ( + BaseSchema, + enum_field, + float_field, + int_field, + schema_field, +) +from DashAI.back.dataloaders.classes.dashai_dataset import ( + DashAIDataset, + to_dashai_dataset, +) +from DashAI.back.models.forecasting.base_forecasting_model import ForecastingModel + + +def _patch_prophet_regressor_column_matrix(): + """Patch Prophet for compatibility with newer pandas versions.""" + from prophet import Prophet + + if getattr(Prophet, "_dashai_pandas_compat_patch", False): + return Prophet + + @staticmethod + def _dashai_fourier_series(dates, period, series_order): + """Prophet expects nanosecond timestamps, but newer pandas can + hand back ``datetime64[us]`` arrays. Force nanosecond precision so + weekly/yearly seasonal features keep the correct period. + """ + if not (series_order >= 1): + raise ValueError("series_order must be >= 1") + + ns_dates = ( + pd.to_datetime(dates).to_numpy(dtype="datetime64[ns]").astype(np.int64) + ) + t = ns_dates // 1_000_000_000 / (3600 * 24.0) + + x_T = t * np.pi * 2 + fourier_components = np.empty((dates.shape[0], 2 * series_order)) + for i in range(series_order): + c = x_T * (i + 1) / period + fourier_components[:, 2 * i] = np.sin(c) + fourier_components[:, (2 * i) + 1] = np.cos(c) + return fourier_components + + def _dashai_regressor_column_matrix(self, seasonal_features, modes): + components = pd.DataFrame( + { + "col": np.arange(seasonal_features.shape[1]), + "component": [x.split("_delim_")[0] for x in seasonal_features.columns], + } + ) + + if self.train_holiday_names is not None: + components = self.add_group_component( + components, "holidays", self.train_holiday_names.unique() + ) + + for mode in ["additive", "multiplicative"]: + components = self.add_group_component( + components, mode + "_terms", modes[mode] + ) + regressors_by_mode = [ + r for r, props in self.extra_regressors.items() if props["mode"] == mode + ] + components = self.add_group_component( + components, + "extra_regressors_" + mode, + regressors_by_mode, + ) + modes[mode].append(mode + "_terms") + modes[mode].append("extra_regressors_" + mode) + + modes[self.holidays_mode].append("holidays") + + clean_components = components.reset_index(drop=True) + component_cols = pd.crosstab( + pd.Series(clean_components["col"].to_numpy(), name="col"), + pd.Series(clean_components["component"].to_numpy(), name="component"), + ).sort_index(level="col") + + for name in ["additive_terms", "multiplicative_terms"]: + if name not in component_cols: + component_cols[name] = 0 + + component_cols = component_cols.drop("zeros", axis=1, errors="ignore") + + if ( + max( + component_cols["additive_terms"] + + component_cols["multiplicative_terms"] + ) + > 1 + ): + raise Exception("A bug occurred in seasonal components.") + + if self.train_component_cols is not None: + component_cols = component_cols[self.train_component_cols.columns] + if not component_cols.equals(self.train_component_cols): + raise Exception("A bug occurred in constructing regressors.") + + return component_cols, modes + + Prophet.fourier_series = _dashai_fourier_series + Prophet.regressor_column_matrix = _dashai_regressor_column_matrix + Prophet._dashai_pandas_compat_patch = True + return Prophet + + +class ProphetModelSchema(BaseSchema): + """Schema for Prophet model configuration. + + Prophet is a forecasting procedure designed for business time series data. + It works best with time series that have strong seasonal effects and several + seasons of historical data. Prophet is robust to missing data and shifts in + the trend, and typically handles outliers well. + """ + + seasonality_mode: schema_field( + enum_field(enum=["additive", "multiplicative"]), + placeholder="additive", + description="Type of seasonality. 'additive' assumes seasonal effects are " + "added to the trend. 'multiplicative' assumes seasonal effects are " + "multiplied by the trend.", + ) = "additive" # type: ignore + + yearly_seasonality: schema_field( + enum_field(enum=["auto", "true", "false"]), + placeholder="auto", + description="Yearly seasonality. 'auto' detects automatically, " + "'true' forces yearly seasonality, 'false' disables it.", + ) = "auto" # type: ignore + + weekly_seasonality: schema_field( + enum_field(enum=["auto", "true", "false"]), + placeholder="auto", + description="Weekly seasonality. 'auto' detects automatically, " + "'true' forces weekly seasonality, 'false' disables it.", + ) = "auto" # type: ignore + + daily_seasonality: schema_field( + enum_field(enum=["auto", "true", "false"]), + placeholder="auto", + description="Daily seasonality. 'auto' detects automatically, " + "'true' forces daily seasonality, 'false' disables it.", + ) = "auto" # type: ignore + + growth: schema_field( + enum_field(enum=["linear", "logistic", "flat"]), + placeholder="linear", + description="Growth model. 'linear' for unlimited growth, " + "'logistic' for growth that saturates at a carrying capacity " + "(requires cap_multiplier), 'flat' for no trend.", + ) = "linear" # type: ignore + + cap_multiplier: schema_field( + float_field(ge=1.0, le=10.0), + placeholder=1.5, + description="For logistic growth: multiplier applied to max(y) to set " + "the carrying capacity. E.g., 1.5 means cap = 1.5 * max(y). " + "Only used when growth='logistic'.", + ) = 1.5 # type: ignore + + floor_ratio: schema_field( + float_field(ge=0.0, le=1.0), + placeholder=0.0, + description="For logistic growth: floor as ratio of min(y). " + "E.g., 0.5 means floor = 0.5 * min(y). " + "Only used when growth='logistic'.", + ) = 0.0 # type: ignore + + changepoint_prior_scale: schema_field( + float_field(ge=0.001, le=1.0), + placeholder=0.05, + description="Controls flexibility of automatic changepoint selection. " + "Higher values allow more changepoints (more flexible trend). " + "Lower values result in fewer changepoints (more conservative trend).", + ) = 0.05 # type: ignore + + seasonality_prior_scale: schema_field( + float_field(ge=0.01, le=100.0), + placeholder=10.0, + description="Controls flexibility of seasonality. Higher values allow " + "more seasonal variation. Lower values result in smoother seasonality.", + ) = 10.0 # type: ignore + + holidays_prior_scale: schema_field( + float_field(ge=0.01, le=100.0), + placeholder=10.0, + description="Controls flexibility of holiday effects. Higher values " + "allow larger holiday effects.", + ) = 10.0 # type: ignore + + interval_width: schema_field( + float_field(ge=0.5, le=0.99), + placeholder=0.8, + description="Width of prediction intervals. 0.8 means 80% confidence " + "intervals. Prophet will generate yhat_lower and yhat_upper bounds.", + ) = 0.8 # type: ignore + + uncertainty_samples: schema_field( + int_field(ge=100, le=10000), + placeholder=1000, + description="Number of samples to draw for uncertainty estimation. " + "More samples give smoother intervals but slower prediction.", + ) = 1000 # type: ignore + + +class ProphetModel(ForecastingModel): + """Prophet forecasting model wrapper for DashAI. + + This model implements the ForecastingModel interface, handling all + column name conversions internally. It maintains exogenous variables in + their original format and converts to Prophet's 'ds'/'y' convention only + during internal operations. + """ + + SCHEMA = ProphetModelSchema + COMPATIBLE_COMPONENTS = ["ForecastingTask"] + _task_type = "ForecastingTask" + + def __init__( + self, + seasonality_mode: str = "additive", + yearly_seasonality: str = "auto", + weekly_seasonality: str = "auto", + daily_seasonality: str = "auto", + growth: str = "linear", + cap_multiplier: float = 1.5, + floor_ratio: float = 0.0, + changepoint_prior_scale: float = 0.05, + seasonality_prior_scale: float = 10.0, + holidays_prior_scale: float = 10.0, + interval_width: float = 0.8, + uncertainty_samples: int = 1000, + **kwargs, + ) -> None: + super().__init__(**kwargs) # Pass kwargs to ForecastingModel + + self.seasonality_mode = seasonality_mode + self.yearly_seasonality = self._parse_bool_setting(yearly_seasonality) + self.weekly_seasonality = self._parse_bool_setting(weekly_seasonality) + self.daily_seasonality = self._parse_bool_setting(daily_seasonality) + self.growth = growth + self.cap_multiplier = cap_multiplier + self.floor_ratio = floor_ratio + self.changepoint_prior_scale = changepoint_prior_scale + self.seasonality_prior_scale = seasonality_prior_scale + self.holidays_prior_scale = holidays_prior_scale + self.interval_width = interval_width + self.uncertainty_samples = uncertainty_samples + + # Store cap/floor for predictions when using logistic growth + self._cap_value: Optional[float] = None + self._floor_value: Optional[float] = None + + self.model = None + # exog_cols, timestamp_col, target_col are inherited from ForecastingModel + self.last_ds: Optional[pd.Timestamp] = None + self.frequency: Optional[str] = None + + def _parse_bool_setting(self, setting: str) -> Union[bool, str]: + if setting.lower() == "true": + return True + if setting.lower() == "false": + return False + return "auto" + + def _validate_forecasting_data(self, x: DashAIDataset, y: DashAIDataset) -> None: + """Validate that data is suitable for Prophet. + + Parameters + ---------- + X : DashAIDataset + Input features (must contain a timestamp column) + y : DashAIDataset + Target values (must contain a numeric column) + + Raises + ------ + ValueError + If data is not suitable for Prophet + """ + x_cols = set(x.column_names) + y_cols = set(y.column_names) + + if len(x_cols) == 0: + raise ValueError( + "Prophet requires at least one input column (timestamp). " + "Received empty dataset." + ) + + if len(y_cols) != 1: + raise ValueError( + f"Prophet requires exactly one target column. " + f"Received {len(y_cols)} columns: {list(y_cols)}" + ) + + def fit( + self, + x_train: DashAIDataset, + y: DashAIDataset, + temporal_metadata: dict = None, + **fit_params, + ) -> "ProphetModel": + """Train Prophet forecasting model. + + Implements ForecastingModel.fit() interface. Handles all column name + conversions internally - stores original names in base class attributes, + converts to Prophet's 'ds'/'y' convention only for internal use. + + Parameters + ---------- + x_train : DashAIDataset + Input features containing timestamp and optional exogenous variables + y : DashAIDataset + Target time series (single column) + temporal_metadata : dict, optional + Metadata from ForecastingTask containing: + - timestamp_col: name of timestamp column + - target_col: name of target column + - exog_cols: list of exogenous variable column names + - frequency: time series frequency + If not provided, will attempt auto-detection (legacy behavior) + **fit_params + Additional fitting parameters + + Returns + ------- + ProphetModel + Fitted model instance + """ + try: + Prophet = _patch_prophet_regressor_column_matrix() + except ImportError as e: + raise ImportError( + "Prophet is required for ProphetModel. " + "Install with: pip install prophet" + ) from e + + # Validate data format + self._validate_forecasting_data(x_train, y) + + # Convert to pandas DataFrames + x_df = x_train.to_pandas() + y_df = y.to_pandas() + + # Get column information from metadata (task-agnostic approach) + if temporal_metadata: + timestamp_col = temporal_metadata.get("timestamp_col") + target_col = temporal_metadata.get("target_col") + exog_cols_from_task = temporal_metadata.get("exog_cols", []) + frequency = temporal_metadata.get("frequency", "D") + + print("[ProphetModel] Using temporal metadata from task:") + print(f" - Timestamp: '{timestamp_col}'") + print(f" - Target: '{target_col}'") + print(f" - Frequency: {frequency}") + if exog_cols_from_task: + print(f" - Exogenous variables: {exog_cols_from_task}") + else: + # Legacy: auto-detection if no metadata provided + print( + "[ProphetModel] ⚠️ No temporal_metadata provided, using auto-detection" + ) + + # Get target column name (should be single column) + target_col = y_df.columns[0] + + # Auto-detect timestamp column in x_df + timestamp_col = None + for col in x_df.columns: + try: + pd.to_datetime(x_df[col]) + timestamp_col = col + print(f"[ProphetModel] Detected timestamp column: '{col}'") + break + except Exception: + continue + + if timestamp_col is None: + raise ValueError( + f"No timestamp column found in input data. " + f"Available columns: {list(x_df.columns)}" + ) + + exog_cols_from_task = [] + frequency = fit_params.get("frequency", "D") + + # Store original column names in base class attributes + self.timestamp_col = timestamp_col + self.target_col = target_col + self.frequency = frequency + + # Build Prophet dataframe (internal conversion to 'ds'/'y') + prophet_df = pd.DataFrame( + { + "ds": pd.to_datetime(x_df[timestamp_col]).to_numpy(), + } + ) + + # Check if target column is in x_train (user might have included it by mistake) + target_in_inputs = target_col in x_df.columns + + if target_in_inputs: + # Target is in inputs - use it from there for consistency + print( + "[ProphetModel] ℹ️ Target '{}' found in inputs - using it " + "from there".format(target_col) + ) + prophet_df["y"] = pd.to_numeric( + x_df[target_col], errors="coerce" + ).to_numpy() + else: + # Target is only in y - normal case + prophet_df["y"] = pd.to_numeric( + y_df[target_col], errors="coerce" + ).to_numpy() + + # Add exogenous variables (columns that are not timestamp and are numeric) + # Exclude timestamp and target columns, and only include numeric columns + # Store in ORIGINAL format (as per BaseForecastingModel contract) + self.exog_cols = [] + if temporal_metadata: + candidate_exog_cols = [col for col in exog_cols_from_task if col in x_df] + missing_exog_cols = [col for col in exog_cols_from_task if col not in x_df] + if missing_exog_cols: + print( + "[ProphetModel] ⚠️ Ignoring missing exogenous columns from task: " + f"{missing_exog_cols}" + ) + else: + candidate_exog_cols = [ + col for col in x_df.columns if col not in {timestamp_col, target_col} + ] + + for col in candidate_exog_cols: + if col == target_col and target_in_inputs: + print( + "[ProphetModel] ℹ️ Excluding target '{}' from exogenous " + "variables".format(col) + ) + continue + + # Only add numeric columns + if pd.api.types.is_numeric_dtype(x_df[col]): + self.exog_cols.append(col) # Store ORIGINAL name + prophet_df[col] = x_df[col].to_numpy() + else: + print( + "[ProphetModel] ⚠️ Skipping non-numeric column: '{}' " + "(type: {})".format(col, x_df[col].dtype) + ) + + # Handle logistic growth - requires 'cap' (and optionally 'floor') columns + if self.growth == "logistic": + y_max = prophet_df["y"].max() + y_min = prophet_df["y"].min() + + # Calculate cap and floor based on multipliers + self._cap_value = y_max * self.cap_multiplier + self._floor_value = y_min * self.floor_ratio + + # Add cap column (required for logistic growth) + prophet_df["cap"] = self._cap_value + + # Add floor column if floor_ratio > 0 + if self.floor_ratio > 0: + prophet_df["floor"] = self._floor_value + + print( + f"[ProphetModel] Logistic growth: cap={self._cap_value:.2f} " + f"(max*{self.cap_multiplier}), floor={self._floor_value:.2f}" + ) + + # Store additional metadata + self.last_ds = prophet_df["ds"].max() + + print(f"[ProphetModel] Training with {len(prophet_df)} data points") + print( + f"[ProphetModel] Date range: {prophet_df['ds'].min()} to " + f"{prophet_df['ds'].max()}" + ) + if self.exog_cols: + print(f"[ProphetModel] Exogenous variables: {self.exog_cols}") + + # Initialize Prophet model + self.model = Prophet( + seasonality_mode=self.seasonality_mode, + yearly_seasonality=self.yearly_seasonality, + weekly_seasonality=self.weekly_seasonality, + daily_seasonality=self.daily_seasonality, + growth=self.growth, + changepoint_prior_scale=self.changepoint_prior_scale, + seasonality_prior_scale=self.seasonality_prior_scale, + holidays_prior_scale=self.holidays_prior_scale, + interval_width=self.interval_width, + uncertainty_samples=self.uncertainty_samples, + ) + + # Add exogenous regressors to Prophet (using original names) + for col in self.exog_cols: + self.model.add_regressor(col) + + self.model.fit(prophet_df) + + print("✅ Prophet model training completed") + return self + + def _add_cap_floor_columns(self, dataframe: pd.DataFrame) -> pd.DataFrame: + """Add cap and floor columns for logistic growth predictions. + + Parameters + ---------- + dataframe : pd.DataFrame + DataFrame to add cap/floor columns to + + Returns + ------- + pd.DataFrame + DataFrame with cap (and optionally floor) columns added + """ + if self.growth != "logistic": + return dataframe + + result_df = dataframe.copy() + + if self._cap_value is not None: + result_df["cap"] = self._cap_value + + if self._floor_value is not None and self.floor_ratio > 0: + result_df["floor"] = self._floor_value + + return result_df + + def predict( + self, + x_pred: Optional[Any] = None, + periods: Optional[int] = None, + horizon: Optional[int] = None, + exog_future: Optional[pd.DataFrame] = None, + return_components: bool = False, + **kwargs, + ) -> Union[np.ndarray, pd.DataFrame]: + if self.model is None: + raise ValueError("Prophet model is not fitted yet. Call fit() first.") + + def _extract_predictions( + forecast_df: pd.DataFrame, requested_ds: pd.Series + ) -> Union[np.ndarray, pd.DataFrame]: + """Extract predictions for requested timestamps. + + For timestamps that don't exist in Prophet's forecast (gaps in data), + returns NaN. These will be filtered out by prepare_to_metric(). + """ + # Normalize both forecast and requested timestamps to ensure matching + # Prophet internally normalizes dates, so we need to do the same + forecast_df = forecast_df.copy() + forecast_df["ds"] = pd.to_datetime(forecast_df["ds"]).dt.normalize() + requested_ds_normalized = pd.to_datetime(requested_ds).dt.normalize() + + # Debug: Show sample of dates + print( + f"[ProphetModel] _extract_predictions: " + f"forecast has {len(forecast_df)} rows, " + f"requested {len(requested_ds_normalized)} timestamps" + ) + print( + f"[ProphetModel] Forecast dates range: " + f"{forecast_df['ds'].min()} to {forecast_df['ds'].max()}" + ) + print( + f"[ProphetModel] Requested dates range: " + f"{requested_ds_normalized.min()} to {requested_ds_normalized.max()}" + ) + + aligned = forecast_df.set_index("ds").reindex(requested_ds_normalized) + + # Check for missing predictions + missing_mask = aligned["yhat"].isna() + if missing_mask.any(): + missing_count = missing_mask.sum() + total_count = len(requested_ds) + print( + f"[ProphetModel] ⚠️ {missing_count}/{total_count} timestamps " + f"have no predictions (gaps in data). These will be excluded " + f"from metrics calculation." + ) + # Debug: Show which dates are missing + if missing_count <= 10: + missing_dates = requested_ds_normalized[missing_mask.to_numpy()] + print(f"[ProphetModel] Missing dates: {list(missing_dates)}") + else: + missing_dates = requested_ds_normalized[missing_mask.to_numpy()] + print( + f"[ProphetModel] First 5 missing dates: " + f"{list(missing_dates[:5])}" + ) + + if return_components: + return aligned.reset_index() + return aligned["yhat"].to_numpy() + + if x_pred is not None: + if isinstance(x_pred, (int, np.integer)): + periods = int(x_pred) + else: + if isinstance(x_pred, pd.DataFrame): + input_df = x_pred.copy() + else: + input_df = to_dashai_dataset(x_pred).to_pandas() + + # Auto-detect timestamp column (try 'ds' first for compatibility) + timestamp_col = None + if "ds" in input_df.columns: + timestamp_col = "ds" + else: + # Try to find timestamp column + for col in input_df.columns: + try: + pd.to_datetime(input_df[col]) + timestamp_col = col + break + except Exception: + continue + + if timestamp_col is None: + raise ValueError( + "Prophet predict requires a timestamp column. " + f"Available columns: {list(input_df.columns)}" + ) + + input_df = input_df.copy() + + # Rename to 'ds' for Prophet + if timestamp_col != "ds": + input_df = input_df.rename(columns={timestamp_col: "ds"}) + + # Normalize timestamps to ensure consistent comparison + input_df["ds"] = pd.to_datetime(input_df["ds"]).dt.normalize() + input_df = input_df.sort_values("ds").reset_index(drop=True) + + # Check if we need in-sample predictions (for explainability) + # If any requested date is <= last training date, we need to include + # historical dates in the prediction + # Use Prophet's internal history dataframe to get training date range + if not hasattr(self.model, "history_dates"): + raise ValueError( + "Prophet model has no training history. " + "Ensure the model was fitted before prediction." + ) + # Normalize history dates for consistent comparison + history_dates = pd.Series(self.model.history_dates) + history_dates_normalized = history_dates.dt.normalize() + last_train_date = history_dates_normalized.max() + has_historical = (input_df["ds"] <= last_train_date).any() + + if has_historical: + # For in-sample predictions (explainability use case): + # Include both historical and future dates + # Create a complete dataframe from first training date to last + # requested date. This ensures Prophet generates predictions for + # all dates including historical ones + max_requested_date = input_df["ds"].max() + + # Use make_future_dataframe but include historical dates + future_df = self.model.make_future_dataframe( + periods=0, # Don't extend beyond training + freq=self.frequency or "D", + include_history=True, # Include training dates + ) + + # Add any future dates beyond training if needed + if max_requested_date > last_train_date: + additional_periods = pd.date_range( + start=last_train_date + pd.Timedelta(days=1), + end=max_requested_date, + freq=self.frequency or "D", + ) + additional_df = pd.DataFrame({"ds": additional_periods}) + future_df = pd.concat( + [future_df, additional_df], ignore_index=True + ) + + # Add exogenous variables if present + if self.exog_cols: + missing_cols = [ + col for col in self.exog_cols if col not in input_df.columns + ] + if missing_cols: + raise ValueError( + "Missing exogenous columns for prediction: " + f"{missing_cols}." + ) + + # Merge exogenous data from input_df with future_df + # For historical dates, use the provided values + future_df = future_df.merge( + input_df[["ds"] + self.exog_cols], on="ds", how="left" + ) + + # Check if there are missing exogenous values + if future_df[self.exog_cols].isna().any().any(): + raise ValueError( + "Missing exogenous values for some dates. " + "All dates in prediction range must have " + "exogenous data." + ) + else: + # Normal future forecasting (original behavior) + future_df = input_df[["ds"]].copy() + + if self.exog_cols: + missing_cols = [ + col for col in self.exog_cols if col not in input_df.columns + ] + if missing_cols: + raise ValueError( + "Missing exogenous columns for prediction: " + f"{missing_cols}." + ) + future_df = pd.concat( + [ + future_df, + input_df[self.exog_cols].reset_index(drop=True), + ], + axis=1, + ) + + # Add cap/floor for logistic growth + future_df = self._add_cap_floor_columns(future_df) + + # Debug: Log what we're predicting + print( + f"[ProphetModel] Predicting for {len(future_df)} dates: " + f"{future_df['ds'].min()} to {future_df['ds'].max()}" + ) + print(f"[ProphetModel] has_historical={has_historical}") + + forecast = self.model.predict(future_df) + + # Debug: Log what Prophet returned + print( + f"[ProphetModel] Prophet returned {len(forecast)} predictions: " + f"{forecast['ds'].min()} to {forecast['ds'].max()}" + ) + + return _extract_predictions(forecast, input_df["ds"]) + + # Handle periods/horizon compatibility + if periods is None and horizon is not None: + periods = horizon + + if periods is None: + raise ValueError( + "Prophet predict requires either 'x_pred' data or a 'periods' value." + ) + if periods <= 0: + raise ValueError("Prediction horizon must be a positive integer.") + + frequency = self.frequency or "D" + + # If x_pred is provided with periods, use it to determine start date + start_date = None + if x_pred is not None: + if isinstance(x_pred, pd.DataFrame): + input_df = x_pred.copy() + else: + input_df = to_dashai_dataset(x_pred).to_pandas() + + # Find timestamp column + ts_col = None + if "ds" in input_df.columns: + ts_col = "ds" + elif self.timestamp_col in input_df.columns: + ts_col = self.timestamp_col + + if ts_col: + start_date = pd.to_datetime(input_df[ts_col]).max() + print(f"[ProphetModel] Using input as start date: {start_date}") + + # Also update last_ds for explainers + self.last_ds = start_date + + if start_date: + # Generate future dataframe starting after start_date + future_dates = pd.date_range( + start=start_date, periods=periods + 1, freq=frequency + )[1:] + future_df = pd.DataFrame({"ds": future_dates}) + else: + # Standard behavior (continue from training) + future_df = self.model.make_future_dataframe( + periods=periods, freq=frequency + ) + + if self.exog_cols and exog_future is not None: + missing_cols = [ + col for col in self.exog_cols if col not in exog_future.columns + ] + if missing_cols: + raise ValueError( + f"Missing exogenous columns for future prediction: {missing_cols}." + ) + if len(exog_future) != periods: + raise ValueError( + "Missing exogenous values must match the prediction horizon length." + ) + for col in self.exog_cols: + future_df[col] = exog_future[col].to_numpy() + elif self.exog_cols: + raise ValueError( + f"Future exogenous values required for columns: {self.exog_cols}." + ) + + # Add cap/floor for logistic growth + future_df = self._add_cap_floor_columns(future_df) + forecast = self.model.predict(future_df) + print(f"[ProphetModel] Generated forecast for {periods} periods") + print( + "[ProphetModel] Forecast range: " + f"{forecast['ds'].iloc[-periods:].min()} to " + f"{forecast['ds'].iloc[-periods:].max()}" + ) + + if return_components: + return forecast.tail(periods) + return forecast["yhat"].tail(periods).to_numpy() + + def get_forecast_uncertainty( + self, horizon: int, confidence_level: float = 0.80 + ) -> pd.DataFrame: + """Get forecast with native Prophet prediction intervals. + + Parameters + ---------- + horizon : int + Number of future periods to forecast. + confidence_level : float + Desired confidence level. Note: Prophet's intervals are controlled + by ``interval_width`` set at model creation. This parameter is + accepted for interface uniformity but may not match exactly if the + model was initialized with a different ``interval_width``. + + Returns + ------- + pd.DataFrame + Columns: ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper``. + Intervals come from Prophet's own uncertainty sampling + (``uncertainty_samples`` parameter). + + Raises + ------ + ValueError + If the model was trained with exogenous variables. + """ + if self.model is None: + raise ValueError("Model must be fitted before getting uncertainty.") + + if self.exog_cols: + raise ValueError( + f"Cannot generate forecast uncertainty: model was trained with " + f"exogenous variables {self.exog_cols}. Future exogenous values " + f"are required but not available. " + f"Use ForecastFeatureImportance instead." + ) + + freq = self.frequency or "D" + future_df = self.model.make_future_dataframe(periods=horizon, freq=freq) + future_df = self._add_cap_floor_columns(future_df) + forecast = self.model.predict(future_df) + + fc = forecast.tail(horizon) + return pd.DataFrame( + { + "ds": fc["ds"].to_numpy(), + "yhat": fc["yhat"].to_numpy(), + "yhat_lower": fc["yhat_lower"].to_numpy(), + "yhat_upper": fc["yhat_upper"].to_numpy(), + } + ) + + def get_forecast_components(self, horizon: int) -> pd.DataFrame: + """Get forecast decomposition (trend, seasonality, etc.). + + Note: This method requires making future predictions. If the model + was trained with exogenous variables, this will fail unless future + values for those variables are provided. + + Parameters + ---------- + horizon : int + Number of periods to forecast + + Returns + ------- + pd.DataFrame + Forecast components (trend, seasonal, etc.) + + Raises + ------ + ValueError + If model was trained with exogenous variables (cannot forecast + without future exogenous values) + """ + if self.model is None: + raise ValueError("Model must be fitted before getting components") + + if self.exog_cols: + # Model uses exogenous variables - cannot make valid forecast + raise ValueError( + f"Cannot generate forecast components: model was trained with " + f"exogenous variables {self.exog_cols}.\n" + f"Future forecasting requires known future values for these variables, " + f"which are not available.\n" + f"Recommendation: For models with exogenous variables, use " + f"ForecastFeatureImportance explainer instead." + ) + + # No exogenous variables - can make simple forecast + future_df = self.model.make_future_dataframe( + periods=horizon, freq=self.frequency or "D" + ) + # Add cap/floor for logistic growth + future_df = self._add_cap_floor_columns(future_df) + forecast = self.model.predict(future_df) + + # Return components for the forecast period + component_cols = ["ds", "trend", "seasonal", "weekly", "yearly"] + available_cols = [col for col in component_cols if col in forecast.columns] + return forecast[available_cols].iloc[-horizon:] + + def save(self, filename: str) -> None: + """Save Prophet model to file. + + Parameters + ---------- + filename : str + Path to save the model + """ + model_state = { + "model": self.model, + # Base class attributes (original column names) + "exog_cols": self.exog_cols, + "timestamp_col": self.timestamp_col, + "target_col": self.target_col, + # Prophet-specific metadata + "last_ds": self.last_ds, + "frequency": self.frequency, + # Logistic growth parameters + "_cap_value": self._cap_value, + "_floor_value": self._floor_value, + "cap_multiplier": self.cap_multiplier, + "floor_ratio": self.floor_ratio, + "config": { + "seasonality_mode": self.seasonality_mode, + "yearly_seasonality": self.yearly_seasonality, + "weekly_seasonality": self.weekly_seasonality, + "daily_seasonality": self.daily_seasonality, + "growth": self.growth, + "changepoint_prior_scale": self.changepoint_prior_scale, + "seasonality_prior_scale": self.seasonality_prior_scale, + "holidays_prior_scale": self.holidays_prior_scale, + "interval_width": self.interval_width, + "uncertainty_samples": self.uncertainty_samples, + }, + } + + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "wb") as f: + pickle.dump(model_state, f) + + print(f"✅ Prophet model saved to {filename}") + + def load(self, filename: str) -> "ProphetModel": + """Load Prophet model from file. + + Parameters + ---------- + filename : str + Path to load the model from + + Returns + ------- + ProphetModel + Loaded model instance + """ + _patch_prophet_regressor_column_matrix() + + with open(filename, "rb") as f: + model_state = pickle.load(f) + + self.model = model_state["model"] + + # Restore base class attributes (original column names) + self.exog_cols = model_state["exog_cols"] + self.timestamp_col = model_state.get( + "timestamp_col" + ) # May not exist in old models + self.target_col = model_state.get("target_col") # May not exist in old models + + # Restore Prophet-specific metadata + self.last_ds = model_state["last_ds"] + self.frequency = model_state["frequency"] + + # Restore logistic growth parameters (may not exist in old models) + self._cap_value = model_state.get("_cap_value") + self._floor_value = model_state.get("_floor_value") + self.cap_multiplier = model_state.get("cap_multiplier", 1.5) + self.floor_ratio = model_state.get("floor_ratio", 0.0) + + # Restore configuration + config = model_state["config"] + for key, value in config.items(): + setattr(self, key, value) + + print(f"✅ Prophet model loaded from {filename}") + return self diff --git a/DashAI/back/models/forecasting/sklearn_multistep_forecaster.py b/DashAI/back/models/forecasting/sklearn_multistep_forecaster.py new file mode 100644 index 000000000..8f1928905 --- /dev/null +++ b/DashAI/back/models/forecasting/sklearn_multistep_forecaster.py @@ -0,0 +1,928 @@ +"""Sklearn-based multi-step forecasting model for DashAI. + +This model uses sklearn regressors with a sliding window approach to perform +multi-step-ahead forecasting. It internally creates lag features and can +predict multiple steps into the future. +""" + +import os +import pickle +from typing import Any, List, Optional, Union + +import numpy as np +import pandas as pd +from sklearn.ensemble import RandomForestRegressor as _RandomForestRegressor +from sklearn.linear_model import LinearRegression as _LinearRegression +from sklearn.linear_model import Ridge as _Ridge +from sklearn.multioutput import MultiOutputRegressor + +from DashAI.back.core.schema_fields import ( + BaseSchema, + enum_field, + int_field, + schema_field, +) +from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset +from DashAI.back.models.forecasting.base_forecasting_model import ForecastingModel + + +class SklearnMultiStepForecasterSchema(BaseSchema): + """Schema for SklearnMultiStepForecaster configuration.""" + + base_estimator: schema_field( + enum_field(enum=["linear", "ridge", "random_forest"]), + placeholder="linear", + description=( + "Base estimator for forecasting. " + "'linear': Fast linear regression (best for linear trends). " + "'ridge': Linear regression with L2 regularization " + "(prevents overfitting). " + "'random_forest': Tree-based ensemble (handles non-linear patterns)." + ), + ) = "linear" # type: ignore + + window_size: schema_field( + int_field(ge=1, le=365), + placeholder=3, + description=( + "Number of past time steps (lags) to use as features. " + "Smaller values work better for small datasets. " + "Will be auto-adjusted if dataset is too small." + ), + ) = 3 # type: ignore + + forecast_strategy: schema_field( + enum_field(enum=["direct", "recursive"]), + placeholder="direct", + description=( + "Multi-step forecasting strategy. " + "'direct': Train separate model for each horizon " + "(more accurate, slower). " + "'recursive': Use predictions as inputs for next step " + "(faster, error compounds)." + ), + ) = "direct" # type: ignore + + +class SklearnMultiStepForecaster(ForecastingModel): + """Sklearn-based multi-step forecasting model. + + This model transforms time series forecasting into a supervised learning problem + by creating lag features automatically. It supports: + - Multiple sklearn base estimators (linear, ridge, random_forest) + - Direct multi-step strategy (separate model per horizon) + - Recursive strategy (iterative predictions) + - Exogenous variables + + Example usage in ForecastingTask: + 1. User uploads time series with columns: [timestamp, value, exog1, exog2] + 2. Task identifies timestamp, target, exogenous variables + 3. Model internally creates lags and trains + 4. Prediction works exactly like Prophet/ARIMA + + The key advantage is that users can leverage sklearn's powerful regression + models for forecasting without manually creating lag features. + """ + + SCHEMA = SklearnMultiStepForecasterSchema + COMPATIBLE_COMPONENTS = ["ForecastingTask"] + _task_type = "ForecastingTask" + + def __init__( + self, + base_estimator: str = "linear", + window_size: int = 3, + forecast_strategy: str = "direct", + **kwargs, + ) -> None: + """Initialize SklearnMultiStepForecaster. + + Parameters + ---------- + base_estimator : str + Base sklearn estimator to use ('linear', 'ridge', 'random_forest') + window_size : int + Number of past time steps to use as lag features + forecast_strategy : str + Strategy for multi-step forecasting ('direct' or 'recursive') + **kwargs + Additional arguments passed to ForecastingModel + """ + super().__init__(**kwargs) + + self.base_estimator = base_estimator + self.window_size = window_size + self.forecast_strategy = forecast_strategy + + # Internal state + self.models: List[Any] = [] + self.training_history: Optional[pd.Series] = None + self.training_exog_history: Optional[pd.DataFrame] = None + self.training_full_series: Optional[pd.Series] = None + self.training_full_exog: Optional[pd.DataFrame] = None + self.training_full_exog: Optional[pd.DataFrame] = None + self.max_horizon: int = 1 + self.last_timestamp: Optional[pd.Timestamp] = None + + def _get_base_estimator(self): + """Get instance of base estimator.""" + estimators = { + "linear": _LinearRegression, + "ridge": _Ridge, + "random_forest": _RandomForestRegressor, + } + + if self.base_estimator not in estimators: + raise ValueError( + f"Unknown base_estimator '{self.base_estimator}'. " + f"Supported: {list(estimators.keys())}" + ) + + return estimators[self.base_estimator]() + + def _create_lag_features( + self, series: pd.Series, exog_df: Optional[pd.DataFrame] = None + ) -> pd.DataFrame: + """Create lag features from time series. + + Parameters + ---------- + series : pd.Series + Time series values + exog_df : pd.DataFrame, optional + Exogenous variables (must have same index as series) + + Returns + ------- + pd.DataFrame + DataFrame with lag features and optional exogenous variables + """ + result = pd.DataFrame(index=series.index) + + # Create lags (lag_1 is t-1, lag_2 is t-2, etc.) + for lag in range(1, self.window_size + 1): + result[f"lag_{lag}"] = series.shift(lag) + + # Add exogenous variables if present + if exog_df is not None: + for col in exog_df.columns: + result[col] = exog_df[col] + + return result + + def fit( + self, + x_train: DashAIDataset, + y: DashAIDataset, + temporal_metadata: dict = None, + **fit_params, + ) -> "SklearnMultiStepForecaster": + """Train the multi-step forecasting model. + + Parameters + ---------- + x_train : DashAIDataset + Input features (timestamp + optional exogenous variables) + y : DashAIDataset + Target time series + temporal_metadata : dict + Metadata from ForecastingTask with timestamp_col, target_col, etc. + **fit_params + Additional fitting parameters (can include 'horizon') + + Returns + ------- + SklearnMultiStepForecaster + Fitted model + """ + if temporal_metadata is None: + raise ValueError( + "temporal_metadata is required for SklearnMultiStepForecaster" + ) + + # Get metadata + self.timestamp_col = temporal_metadata.get("timestamp_col") + self.target_col = temporal_metadata.get("target_col") + self.exog_cols = temporal_metadata.get("exog_cols", []) + self.frequency = temporal_metadata.get("frequency", "D") + + print("[SklearnMultiStepForecaster] Using temporal metadata from task:") + print(f" - Timestamp: '{self.timestamp_col}'") + print(f" - Target: '{self.target_col}'") + print(f" - Exogenous: {self.exog_cols}") + print(f" - Frequency: {self.frequency}") + + # Convert to pandas + x_df = x_train.to_pandas() + y_df = y.to_pandas() + + # Get horizon from fit_params (default to 1) + horizon = fit_params.get("horizon", 1) + self.max_horizon = horizon + + # Store last timestamp and full date sequence for future predictions + if self.timestamp_col in x_df.columns: + parsed_dates = pd.to_datetime(x_df[self.timestamp_col]) + self.last_timestamp = parsed_dates.max() + self.training_dates = parsed_dates.to_numpy() + print(f"[SklearnMultiStepForecaster] Last timestamp: {self.last_timestamp}") + else: + self.last_timestamp = pd.Timestamp.now() + self.training_dates = None + print("[SklearnMultiStepForecaster] ⚠️ No timestamp col, default to now()") + + # Get target series + target_in_inputs = self.target_col in x_df.columns + if target_in_inputs: + print( + f"[SklearnMultiStepForecaster] ℹ️ Target '{self.target_col}' " + "found in inputs - using it from there" + ) + target_series = x_df[self.target_col] + else: + target_series = y_df[self.target_col] + + # Extract exogenous variables if present + exog_df = None + if self.exog_cols: + exog_df = x_df[self.exog_cols] + print(f"[SklearnMultiStepForecaster] Exogenous variables: {self.exog_cols}") + + n_target_samples = len(target_series) + + # Auto-adjust window_size for small datasets + # Need: window_size lags + horizon shifts + at least 2 samples to train + min_required = self.window_size + horizon + 2 + + if n_target_samples < min_required: + # Try to fit within constraints by reducing window size + # The available space for window is samples minus horizon minus margin + available_for_window = n_target_samples - horizon - 2 + + if available_for_window < 1: + # Even with window_size=1, we can't fit. Reduce horizon too. + # Minimum setup: window=1, horizon=1, need at least 4 samples + if n_target_samples >= 4: + self.window_size = 1 + horizon = max(1, n_target_samples - 3) + print( + f"[SklearnMultiStepForecaster] ⚠️ Very small dataset " + f"({n_target_samples} samples). " + f"Forced window_size=1, horizon={horizon}" + ) + else: + raise ValueError( + f"Dataset too small for forecasting. Need at least 4 samples, " + f"got {n_target_samples}. Please use more training data." + ) + else: + old_window = self.window_size + self.window_size = max(1, available_for_window) + print( + f"[SklearnMultiStepForecaster] ⚠️ Auto-adjusted window_size: " + f"{old_window} → {self.window_size} " + f"(target series has {n_target_samples} samples)" + ) + + self.max_horizon = horizon + + print(f"[SklearnMultiStepForecaster] Training for horizon: {horizon}") + print(f"[SklearnMultiStepForecaster] Window size: {self.window_size}") + print(f"[SklearnMultiStepForecaster] Strategy: {self.forecast_strategy}") + + # Create lag features + X_with_lags = self._create_lag_features(target_series, exog_df) + + # For direct strategy: train one model per horizon + if self.forecast_strategy == "direct": + self.models = [] + for h in range(1, horizon + 1): + # Create target: y shifted h steps ahead + y_h = target_series.shift(-h) + + # Remove NaN rows + mask = X_with_lags.notna().all(axis=1) & y_h.notna() + X_clean = X_with_lags[mask] + y_clean = y_h[mask] + + if len(X_clean) == 0: + raise ValueError( + f"No valid samples after creating lags and horizon {h}. " + f"Try reducing window_size or using more data." + ) + + # Train model for this horizon + model = MultiOutputRegressor(self._get_base_estimator()) + model.fit(X_clean.to_numpy(), y_clean.to_numpy().reshape(-1, 1)) + self.models.append(model) + + print( + f"[SklearnMultiStepForecaster] Trained {len(self.models)} models " + "(direct strategy)" + ) + + # For recursive strategy: train single model for 1-step ahead + else: # recursive + y_1 = target_series.shift(-1) + mask = X_with_lags.notna().all(axis=1) & y_1.notna() + X_clean = X_with_lags[mask] + y_clean = y_1[mask] + + if len(X_clean) == 0: + raise ValueError( + "No valid samples after creating lags. " + "Try reducing window_size or using more data." + ) + + model = MultiOutputRegressor(self._get_base_estimator()) + model.fit(X_clean.to_numpy(), y_clean.to_numpy().reshape(-1, 1)) + self.models = [model] + + print("[SklearnMultiStepForecaster] Trained 1 model (recursive strategy)") + + # Store FULL training series and exog for in-sample predictions + # We need the complete history to create lags for any subset + self.training_full_series = target_series.copy() + self.training_history = target_series.iloc[-self.window_size :] + if self.exog_cols and exog_df is not None: + self.training_full_exog = exog_df.copy() + self.training_exog_history = exog_df.iloc[-self.window_size :] + + print("[SklearnMultiStepForecaster] ✅ Training completed") + + return self + + def predict( + self, + x_pred: Optional[Any] = None, + periods: Optional[int] = None, + exog_future: Optional[pd.DataFrame] = None, + **kwargs, + ) -> Union[np.ndarray, pd.DataFrame]: + """Generate forecasts. + + Parameters + ---------- + x_pred : Any, optional + Input data for in-sample predictions containing timestamp and + optional exogenous variables. Can also be an integer for compatibility. + periods : int, optional + Number of steps to forecast into the future + exog_future : pd.DataFrame, optional + Future exogenous variable values + **kwargs + Additional parameters (can include 'horizon' as alias for 'periods') + + Returns + ------- + np.ndarray + Predictions array + """ + if not self.models: + raise ValueError("Model not fitted. Call fit() first.") + + # Handle horizon alias + if periods is None and "horizon" in kwargs: + periods = kwargs["horizon"] + + # Handle different input types (compatibility with ForecastingTask) + if x_pred is not None and isinstance(x_pred, (int, np.integer)): + periods = int(x_pred) + x_pred = None + + # Note: If x_pred is provided with periods, use x_pred as history context + + # In-sample predictions (for metrics calculation) + if x_pred is not None and periods is None: + from DashAI.back.dataloaders.classes.dashai_dataset import ( + to_dashai_dataset, + ) + + if isinstance(x_pred, pd.DataFrame): + input_df = x_pred.copy() + else: + input_df = to_dashai_dataset(x_pred).to_pandas() + + print( + f"[SklearnMultiStepForecaster] In-sample prediction for " + f"{len(input_df)} time steps" + ) + + # For in-sample predictions, we need the full training data + # because we create lags from historical values + if not hasattr(self, "training_full_series"): + raise ValueError( + "No training history available. Model may not be fitted properly." + ) + + # Detect whether timestamps are within training range or beyond. + # val/test splits reset their pandas index to 0-based, so index lookup + # in training lag features would return wrong rows. Use timestamp + # comparison to choose the correct prediction strategy. + is_within_training = True + if self.timestamp_col and self.timestamp_col in input_df.columns: + input_ts = pd.to_datetime(input_df[self.timestamp_col]) + is_within_training = input_ts.max() <= self.last_timestamp + + if is_within_training: + # True in-sample: build lag features from training data and look + # up the matching row positions. + exog_df = None + if self.exog_cols: + missing_cols = [ + col for col in self.exog_cols if col not in input_df.columns + ] + if missing_cols: + raise ValueError( + f"Missing exogenous columns for prediction: {missing_cols}" + ) + exog_df = input_df[self.exog_cols] + + target_series = self.training_full_series + full_exog_df = ( + self.training_full_exog + if hasattr(self, "training_full_exog") + else None + ) + + X_with_lags = self._create_lag_features(target_series, full_exog_df) + + if self.exog_cols and exog_df is not None: + for col in self.exog_cols: + X_with_lags.loc[input_df.index, col] = exog_df[col].to_numpy() + + X_subset = X_with_lags.loc[input_df.index] + mask = X_subset.notna().all(axis=1) + X_clean = X_subset[mask] + + if len(X_clean) == 0: + print( + f"[SklearnMultiStepForecaster] ⚠️ No valid samples for " + f"in-sample prediction (need {self.window_size} historical " + f"values). Returning NaN predictions for " + f"{len(input_df)} points." + ) + return np.full(len(input_df), np.nan) + + predictions_full = np.full(len(input_df), np.nan) + predictions = self.models[0].predict(X_clean.to_numpy()) + predictions_full[mask] = predictions.flatten() + + print( + f"[SklearnMultiStepForecaster] Generated {mask.sum()} in-sample " + f"predictions (first {(~mask).sum()} skipped due to lag window)" + ) + return predictions_full + + else: + # Out-of-training (val/test): recursive 1-step-ahead forecast + # seeded from the last window_size training values. + n_steps = len(input_df) + history = list(self.training_history.to_numpy()) + results = [] + for _ in range(n_steps): + features = np.array(history[-self.window_size :]).reshape(1, -1) + pred = float(self.models[0].predict(features).flatten()[0]) + results.append(pred) + history.append(pred) + + print( + f"[SklearnMultiStepForecaster] Generated {n_steps} recursive " + f"out-of-training predictions" + ) + return np.array(results) + + # Out-of-sample forecast + if periods is not None: + if periods <= 0: + raise ValueError("Prediction horizon must be a positive integer.") + + # Validate exogenous variables if needed + if self.exog_cols: + if exog_future is None: + raise ValueError( + f"Future exogenous values required for columns: " + f"{self.exog_cols}" + ) + + missing_cols = [ + col for col in self.exog_cols if col not in exog_future.columns + ] + if missing_cols: + raise ValueError( + f"Missing exogenous columns for prediction: {missing_cols}" + ) + + if len(exog_future) < periods: + raise ValueError( + f"Exogenous data length ({len(exog_future)}) must be at " + f"least {periods} for the requested forecast horizon." + ) + + # Prepare history for prediction + # If x_pred is provided, use it as history (context) + # Otherwise, use training history + history_series = self.training_history + + if x_pred is not None: + # Convert x_pred to pandas if needed + if isinstance(x_pred, pd.DataFrame): + input_df = x_pred.copy() + else: + from DashAI.back.dataloaders.classes.dashai_dataset import ( + to_dashai_dataset, + ) + + input_df = to_dashai_dataset(x_pred).to_pandas() + + # Check if target column is present + if self.target_col in input_df.columns: + print( + f"[SklearnMultiStepForecaster] Using input as context " + f"({len(input_df)} rows)" + ) + history_series = input_df[self.target_col] + + # Also update last_timestamp if available + if self.timestamp_col in input_df.columns: + self.last_timestamp = pd.to_datetime( + input_df[self.timestamp_col] + ).max() + else: + print( + f"[SklearnMultiStepForecaster] ⚠️ No target col " + f"'{self.target_col}', using training history" + ) + + # Ensure we have enough history + if history_series is None or len(history_series) < self.window_size: + history_len = 0 if history_series is None else len(history_series) + print( + f"[SklearnMultiStepForecaster] ⚠️ History length ({history_len}) " + f"is less than window size ({self.window_size}). " + f"Returning NaN predictions for {periods} periods." + ) + return np.full(periods, np.nan) + + # Direct strategy: use pre-trained models + if self.forecast_strategy == "direct": + predictions = [] + + # We need to maintain current_window for recursive fallback + # Initialize with history + current_window = list(history_series.to_numpy()) + + # Determine how many steps we can predict directly + max_direct_horizon = len(self.models) + + for h in range(periods): + # Step h is 0-indexed (0 = 1st step, 1 = 2nd step, etc.) + + # Case 1: Within direct horizon - use specific model + if h < max_direct_horizon: + # Create features from history + lags = history_series.iloc[-self.window_size :].to_numpy() + + # Add exog if needed + if self.exog_cols and exog_future is not None: + exog_h = exog_future.iloc[h][self.exog_cols].to_numpy() + features = np.concatenate([lags, exog_h]) + else: + features = lags + + # Predict using the specific model for this horizon + pred = self.models[h].predict(features.reshape(1, -1))[0, 0] + + # Case 2: Beyond direct horizon - fallback to recursive + else: + # Use the first model (1-step ahead) recursively + # Create features from CURRENT window (updated with predictions) + lags = np.array(current_window[-self.window_size :]) + + # Add exog if needed + if self.exog_cols and exog_future is not None: + exog_h = exog_future.iloc[h][self.exog_cols].to_numpy() + features = np.concatenate([lags, exog_h]) + else: + features = lags + + # Predict next step using model[0] + pred = self.models[0].predict(features.reshape(1, -1))[0, 0] + + predictions.append(pred) + current_window.append(pred) + + return np.array(predictions) + + # Recursive strategy: iterative predictions + else: + predictions = [] + current_window = list(history_series.to_numpy()) + + for h in range(periods): + # Create features + lags = np.array(current_window[-self.window_size :]) + + # Add exog if needed + if self.exog_cols and exog_future is not None: + exog_h = exog_future.iloc[h][self.exog_cols].to_numpy() + features = np.concatenate([lags, exog_h]) + else: + features = lags + + # Predict next step + pred = self.models[0].predict(features.reshape(1, -1))[0, 0] + predictions.append(pred) + + # Update window with prediction + current_window.append(pred) + + return np.array(predictions) + + raise ValueError( + "Either x_pred or periods parameter must be provided for prediction." + ) + + def get_forecast_uncertainty( + self, horizon: int, confidence_level: float = 0.80 + ) -> pd.DataFrame: + """Get forecast with residual-based prediction intervals. + + Because sklearn regression models have no parametric error distribution, + this method estimates prediction uncertainty empirically: + + 1. Compute in-sample residuals on the training data using the 1-step + ahead model (``models[0]``). + 2. Use the residual standard deviation as the base prediction error. + 3. Scale the half-interval by ``sqrt(h)`` for horizon step ``h`` to + simulate how uncertainty accumulates over time. + 4. Apply a z-score corresponding to the requested confidence level. + + The resulting intervals are wider for longer horizons and reflect the + actual in-sample accuracy of the model. + + Parameters + ---------- + horizon : int + Number of future periods to forecast. + confidence_level : float + Confidence level (e.g., 0.80 for 80% intervals). + + Returns + ------- + pd.DataFrame + Columns: ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper``. + + Raises + ------ + ValueError + If the model was trained with exogenous variables. + """ + if not self.models: + raise ValueError("Model must be fitted before getting uncertainty.") + + if self.exog_cols: + raise ValueError( + f"Cannot generate forecast uncertainty: model was trained with " + f"exogenous variables {self.exog_cols}. Future exogenous values " + f"are required but not available. " + f"Use ForecastFeatureImportance instead." + ) + + # --- Estimate residual std from in-sample 1-step predictions --- + X_with_lags = self._create_lag_features(self.training_full_series) + mask = X_with_lags.notna().all(axis=1) + X_clean = X_with_lags[mask] + actual_clean = self.training_full_series[mask].to_numpy() + + if len(X_clean) > 0: + in_sample_preds = self.models[0].predict(X_clean.to_numpy()).flatten() + residuals = actual_clean - in_sample_preds + residual_std = float(np.std(residuals)) + else: + residual_std = 0.0 + + # Guard against zero or near-zero std (perfect in-sample fit) + if residual_std < 1e-10: + # Fall back to 5% of the mean absolute value of the training series + residual_std = float( + np.abs(self.training_full_series.to_numpy()).mean() * 0.05 + ) + residual_std = max(residual_std, 1e-6) + + # --- z-score for the requested confidence level --- + try: + from scipy.stats import norm as _norm + + z = float(_norm.ppf(0.5 + confidence_level / 2.0)) + except ImportError: + # Hardcoded fallback for common levels + _z_table = { + 0.80: 1.282, + 0.85: 1.440, + 0.90: 1.645, + 0.95: 1.960, + 0.99: 2.576, + } + z = _z_table.get(round(confidence_level, 2), 1.645) + + # --- Point forecast --- + predictions = self.predict(periods=horizon) + + # --- Growing intervals: half-width = z * std * sqrt(h) --- + horizon_steps = np.arange(1, horizon + 1) + half_width = z * residual_std * np.sqrt(horizon_steps) + + freq = self.frequency or "D" + future_dates = pd.date_range( + start=self.last_timestamp, periods=horizon + 1, freq=freq + )[1:] + + return pd.DataFrame( + { + "ds": future_dates, + "yhat": predictions, + "yhat_lower": predictions - half_width, + "yhat_upper": predictions + half_width, + } + ) + + def get_forecast_components(self, horizon: int) -> pd.DataFrame: + """Decompose forecast into trend, seasonal, and residual components. + + Because SklearnMultiStepForecaster is a regression-based model with + no intrinsic structural decomposition, this method applies STL + (Seasonal-Trend decomposition using LOESS) to the concatenation of + the historical training series and the out-of-sample forecast. + + The resulting trend, seasonal, and residual components describe the + statistical structure of the full series (history + forecast horizon), + and only the forecast portion is returned. + + Parameters + ---------- + horizon : int + Number of future periods to forecast and decompose. + + Returns + ------- + pd.DataFrame + Columns: ``ds``, ``trend``, ````, ``residual``. + + Raises + ------ + ValueError + If the model was trained with exogenous variables. + """ + if not self.models: + raise ValueError("Model must be fitted before getting components.") + + if self.exog_cols: + raise ValueError( + f"Cannot generate forecast components: model was trained with " + f"exogenous variables {self.exog_cols}. Future exogenous values " + f"are required but not available for decomposition. " + f"Use ForecastFeatureImportance instead." + ) + + try: + from statsmodels.tsa.seasonal import STL + except ImportError as exc: + raise ImportError( + "statsmodels is required for STL decomposition. " + "Install with: pip install statsmodels" + ) from exc + + # Build historical series with a proper DatetimeIndex + freq = self.frequency or "D" + historical_values = self.training_full_series.to_numpy() + n_hist = len(historical_values) + + if self.training_dates is not None: + historical_index = pd.DatetimeIndex(self.training_dates) + else: + # Reconstruct dates ending at last_timestamp + historical_index = pd.date_range( + end=self.last_timestamp, periods=n_hist, freq=freq + ) + + historical_series = pd.Series(historical_values, index=historical_index) + + # Out-of-sample forecast + predictions = self.predict(periods=horizon) + future_dates = pd.date_range( + start=self.last_timestamp, periods=horizon + 1, freq=freq + )[1:] + future_series = pd.Series(predictions, index=future_dates) + + # Combine history + forecast + combined = pd.concat([historical_series, future_series]) + + # Determine period and run STL decomposition + period = self._get_seasonal_period() + component_name = self._period_to_seasonality_name(period) + + n = len(combined) + if period >= 2 and n >= 2 * period: + try: + stl = STL(combined, period=period, robust=True) + result = stl.fit() + trend_vals = result.trend + seasonal_vals = result.seasonal + residual_vals = result.resid + except Exception: + window = min(period, max(2, n // 2)) + trend_vals = combined.rolling( + window=window, center=True, min_periods=1 + ).mean() + seasonal_vals = pd.Series(np.zeros(n), index=combined.index) + residual_vals = combined - trend_vals + else: + window = max(2, min(period, n // 2)) + trend_vals = combined.rolling( + window=window, center=True, min_periods=1 + ).mean() + seasonal_vals = pd.Series(np.zeros(n), index=combined.index) + residual_vals = combined - trend_vals + + return pd.DataFrame( + { + "ds": combined.index[-horizon:], + "trend": trend_vals.to_numpy()[-horizon:], + component_name: seasonal_vals.to_numpy()[-horizon:], + "residual": residual_vals.to_numpy()[-horizon:], + } + ) + + def save(self, filename: str) -> None: + """Save model to file. + + Parameters + ---------- + filename : str + Path to save the model + """ + if not self.models: + raise ValueError("Cannot save model before fitting.") + + model_state = { + "models": self.models, + "training_history": self.training_history, + "training_exog_history": self.training_exog_history, + "training_full_series": self.training_full_series, + "training_full_exog": self.training_full_exog, + "training_dates": getattr(self, "training_dates", None), + "exog_cols": self.exog_cols, + "timestamp_col": self.timestamp_col, + "target_col": self.target_col, + "frequency": self.frequency, + "max_horizon": self.max_horizon, + "last_timestamp": self.last_timestamp, + "config": { + "base_estimator": self.base_estimator, + "window_size": self.window_size, + "forecast_strategy": self.forecast_strategy, + }, + } + + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "wb") as f: + pickle.dump(model_state, f) + + print(f"✅ SklearnMultiStepForecaster saved to {filename}") + + def load(self, filename: str) -> "SklearnMultiStepForecaster": + """Load model from file. + + Parameters + ---------- + filename : str + Path to load the model from + + Returns + ------- + SklearnMultiStepForecaster + Loaded model instance + """ + with open(filename, "rb") as f: + model_state = pickle.load(f) + + self.models = model_state["models"] + self.training_history = model_state["training_history"] + self.training_exog_history = model_state.get("training_exog_history") + self.training_full_series = model_state.get("training_full_series") + self.training_full_exog = model_state.get("training_full_exog") + self.training_dates = model_state.get("training_dates") + self.exog_cols = model_state["exog_cols"] + self.timestamp_col = model_state.get("timestamp_col") + self.target_col = model_state.get("target_col") + self.frequency = model_state.get("frequency") + self.max_horizon = model_state.get("max_horizon", 1) + self.last_timestamp = model_state.get("last_timestamp") + + config = model_state["config"] + for key, value in config.items(): + setattr(self, key, value) + + print(f"✅ SklearnMultiStepForecaster loaded from {filename}") + return self diff --git a/DashAI/back/models/forecasting/statsmodels_arima_model.py b/DashAI/back/models/forecasting/statsmodels_arima_model.py new file mode 100644 index 000000000..8b861ec0d --- /dev/null +++ b/DashAI/back/models/forecasting/statsmodels_arima_model.py @@ -0,0 +1,626 @@ +"""Statsmodels ARIMA model wrapper for DashAI forecasting. + +This model wraps statsmodels ARIMA for time series forecasting with +autoregressive integrated moving average modeling. +""" + +import os +import pickle +from typing import Any, Optional, Union + +import numpy as np +import pandas as pd + +from DashAI.back.core.schema_fields import ( + BaseSchema, + enum_field, + int_field, + schema_field, +) +from DashAI.back.dataloaders.classes.dashai_dataset import ( + DashAIDataset, + to_dashai_dataset, +) +from DashAI.back.models.forecasting.base_forecasting_model import ForecastingModel + + +class StatsmodelsARIMAModelSchema(BaseSchema): + """Schema for Statsmodels ARIMA model configuration. + + ARIMA (AutoRegressive Integrated Moving Average) is a forecasting method + that captures different aspects of time series: + - AR (p): Autoregression - uses past values + - I (d): Integration - differencing to make series stationary + - MA (q): Moving Average - uses past forecast errors + """ + + p: schema_field( + int_field(ge=0, le=10), + placeholder=1, + description="Order of autoregressive (AR) component. Number of lag " + "observations included in the model (how many past values to use).", + ) = 1 # type: ignore + + d: schema_field( + int_field(ge=0, le=3), + placeholder=1, + description="Degree of differencing (I component). Number of times " + "to difference the data to make it stationary. 0=stationary, " + "1=first difference, 2=second difference.", + ) = 1 # type: ignore + + q: schema_field( + int_field(ge=0, le=10), + placeholder=1, + description="Order of moving average (MA) component. Size of the " + "moving average window (how many past forecast errors to use).", + ) = 1 # type: ignore + + trend: schema_field( + enum_field(enum=["n", "c", "t", "ct"]), + placeholder="n", + description=( + "Deterministic trend to include. 'n'=no trend, 'c'=constant " + "(level), 't'=linear trend, 'ct'=constant and linear trend. " + "Note: When d>0, 'c' is not allowed (use 't' instead). " + "When d>1, neither 'c' nor 't' are allowed." + ), + ) = "n" # type: ignore + + +class StatsmodelsARIMAModel(ForecastingModel): + """Statsmodels ARIMA forecasting model wrapper for DashAI. + + This model implements the ForecastingModel interface using statsmodels ARIMA. + It handles column name conversions internally and supports exogenous variables. + """ + + SCHEMA = StatsmodelsARIMAModelSchema + COMPATIBLE_COMPONENTS = ["ForecastingTask"] + _task_type = "ForecastingTask" + + def __init__( + self, + p: int = 1, + d: int = 1, + q: int = 1, + trend: str = "n", + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.p = p + self.d = d + self.q = q + self.trend = trend + self.order = (p, d, q) + + self.model = None + self.model_fit = None + self.frequency: Optional[str] = None + + def _validate_forecasting_data(self, x: DashAIDataset, y: DashAIDataset) -> None: + """Validate that data is suitable for ARIMA. + + Parameters + ---------- + x : DashAIDataset + Input features (must contain a timestamp column) + y : DashAIDataset + Target values (must contain a numeric column) + + Raises + ------ + ValueError + If data is not suitable for ARIMA + """ + x_cols = set(x.column_names) + y_cols = set(y.column_names) + + if len(x_cols) == 0: + raise ValueError( + "ARIMA requires at least one input column (timestamp). " + "Received empty dataset." + ) + + if len(y_cols) != 1: + raise ValueError( + f"ARIMA requires exactly one target column. " + f"Received {len(y_cols)} columns: {list(y_cols)}" + ) + + def fit( + self, + x_train: DashAIDataset, + y: DashAIDataset, + temporal_metadata: dict = None, + **fit_params, + ) -> "StatsmodelsARIMAModel": + """Train ARIMA forecasting model. + + Parameters + ---------- + x_train : DashAIDataset + Input features containing timestamp and optional exogenous variables + y : DashAIDataset + Target time series (single column) + temporal_metadata : dict, optional + Metadata from ForecastingTask containing: + - timestamp_col: name of timestamp column + - target_col: name of target column + - exog_cols: list of exogenous variable column names + - frequency: time series frequency + **fit_params + Additional fitting parameters + + Returns + ------- + StatsmodelsARIMAModel + Fitted model instance + """ + try: + from statsmodels.tsa.arima.model import ARIMA + except ImportError as e: + raise ImportError( + "Statsmodels is required for StatsmodelsARIMAModel. " + "Install with: pip install statsmodels" + ) from e + + # Validate data format + self._validate_forecasting_data(x_train, y) + + # Convert to pandas DataFrames + x_df = x_train.to_pandas() + y_df = y.to_pandas() + + # Get column information from metadata + if temporal_metadata: + timestamp_col = temporal_metadata.get("timestamp_col") + target_col = temporal_metadata.get("target_col") + exog_cols_from_task = temporal_metadata.get("exog_cols", []) + frequency = temporal_metadata.get("frequency") + + print("[StatsmodelsARIMAModel] Using temporal metadata from task:") + print(f" - Timestamp: '{timestamp_col}'") + print(f" - Target: '{target_col}'") + print(f" - Frequency: {frequency}") + if exog_cols_from_task: + print(f" - Exogenous variables: {exog_cols_from_task}") + else: + # Auto-detection if no metadata provided + print( + "[StatsmodelsARIMAModel] ⚠️ No temporal_metadata provided, " + "using auto-detection" + ) + + target_col = y_df.columns[0] + + # Auto-detect timestamp column + timestamp_col = None + for col in x_df.columns: + try: + pd.to_datetime(x_df[col]) + timestamp_col = col + print(f"[StatsmodelsARIMAModel] Detected timestamp column: '{col}'") + break + except Exception: + continue + + if timestamp_col is None: + raise ValueError( + f"No timestamp column found in input data. " + f"Available columns: {list(x_df.columns)}" + ) + + exog_cols_from_task = [] + frequency = fit_params.get("frequency") + + # Store original column names + self.timestamp_col = timestamp_col + self.target_col = target_col + self.frequency = frequency + + # Prepare data for ARIMA + # Create datetime index + dates = pd.to_datetime(x_df[timestamp_col]) + + # Store last training date for forecast generation + self.last_ds = dates.max() + + # Get target series + target_in_inputs = target_col in x_df.columns + if target_in_inputs: + print( + "[StatsmodelsARIMAModel] ℹ️ Target '{}' found in inputs - " + "using it from there".format(target_col) + ) + endog = x_df[target_col].to_numpy() + else: + endog = y_df[target_col].to_numpy() + + # Create time series with datetime index + endog_series = pd.Series(endog, index=dates) + + # Prepare exogenous variables + self.exog_cols = [] + exog = None + + for col in x_df.columns: + if col == timestamp_col: + continue + if col == target_col: + if target_in_inputs: + print( + "[StatsmodelsARIMAModel] ℹ️ Excluding target '{}' from " + "exogenous variables".format(col) + ) + continue + + # Only add numeric columns + if pd.api.types.is_numeric_dtype(x_df[col]): + self.exog_cols.append(col) + else: + print( + "[StatsmodelsARIMAModel] ⚠️ Skipping non-numeric column: '{}' " + "(type: {})".format(col, x_df[col].dtype) + ) + + if self.exog_cols: + exog = x_df[self.exog_cols].to_numpy() + print(f"[StatsmodelsARIMAModel] Exogenous variables: {self.exog_cols}") + + print(f"[StatsmodelsARIMAModel] Training ARIMA{self.order} model") + print(f"[StatsmodelsARIMAModel] Training with {len(endog_series)} data points") + print(f"[StatsmodelsARIMAModel] Date range: {dates.min()} to {dates.max()}") + + # Fit ARIMA model + self.model = ARIMA( + endog=endog_series, + exog=exog, + order=self.order, + trend=self.trend, + ) + + self.model_fit = self.model.fit() + + print("✅ ARIMA model training completed") + print(f"[StatsmodelsARIMAModel] AIC: {self.model_fit.aic:.2f}") + print(f"[StatsmodelsARIMAModel] BIC: {self.model_fit.bic:.2f}") + + return self + + def predict( + self, + x_pred: Optional[Any] = None, + periods: Optional[int] = None, + exog_future: Optional[pd.DataFrame] = None, + **kwargs, + ) -> Union[np.ndarray, pd.DataFrame]: + """Generate forecasts using ARIMA model. + + Parameters + ---------- + x_pred : pd.DataFrame, optional + Input data for in-sample predictions containing timestamp and + exogenous variables (if model uses them) + periods : int, optional + Number of future periods to forecast (out-of-sample mode) + exog_future : pd.DataFrame, optional + Future values of exogenous variables for out-of-sample forecasting + **kwargs + Additional parameters + + Returns + ------- + np.ndarray or pd.DataFrame + Predictions array + """ + if self.model_fit is None: + raise ValueError("ARIMA model is not fitted yet. Call fit() first.") + + # Handle different input types + if x_pred is not None and isinstance(x_pred, (int, np.integer)): + periods = int(x_pred) + x_pred = None + + # Out-of-sample forecasting + if periods is not None and x_pred is None: + if periods <= 0: + raise ValueError("Prediction horizon must be a positive integer.") + + # Prepare exogenous variables for forecast + exog = None + if self.exog_cols: + if exog_future is None: + raise ValueError( + f"Future exogenous values required for columns: " + f"{self.exog_cols}." + ) + + missing_cols = [ + col for col in self.exog_cols if col not in exog_future.columns + ] + if missing_cols: + raise ValueError( + f"Missing exogenous columns for prediction: {missing_cols}." + ) + + if len(exog_future) != periods: + raise ValueError( + f"Exogenous data length ({len(exog_future)}) must match " + f"prediction horizon ({periods})." + ) + + exog = exog_future[self.exog_cols].to_numpy() + + # Generate forecast + forecast = self.model_fit.forecast(steps=periods, exog=exog) + + print(f"[StatsmodelsARIMAModel] Generated forecast for {periods} periods") + return forecast.to_numpy() + + # In-sample predictions + if x_pred is not None: + if isinstance(x_pred, pd.DataFrame): + input_df = x_pred.copy() + else: + input_df = to_dashai_dataset(x_pred).to_pandas() + + # Auto-detect timestamp column + timestamp_col = None + for col in input_df.columns: + try: + pd.to_datetime(input_df[col]) + timestamp_col = col + break + except Exception: + continue + + if timestamp_col is None: + raise ValueError( + "ARIMA predict requires a timestamp column. " + f"Available columns: {list(input_df.columns)}" + ) + + dates = pd.to_datetime(input_df[timestamp_col]) + + # Prepare exogenous variables + exog = None + if self.exog_cols: + missing_cols = [ + col for col in self.exog_cols if col not in input_df.columns + ] + if missing_cols: + raise ValueError( + f"Missing exogenous columns for prediction: {missing_cols}." + ) + exog = input_df[self.exog_cols].to_numpy() + + # Use actual dates so statsmodels predicts the correct period + # (works for both in-sample and out-of-sample dates) + predictions = self.model_fit.predict( + start=dates.iloc[0], end=dates.iloc[-1], exog=exog + ) + + return predictions.to_numpy() + + raise ValueError( + "ARIMA predict requires either 'x_pred' data or a 'periods' value." + ) + + def get_forecast_uncertainty( + self, horizon: int, confidence_level: float = 0.80 + ) -> pd.DataFrame: + """Get forecast with parametric confidence intervals from ARIMA. + + Uses statsmodels ``get_forecast().summary_frame()`` to compute + analytical confidence intervals derived from the model's error + distribution. These are true parametric intervals, not estimates. + + Parameters + ---------- + horizon : int + Number of future periods to forecast. + confidence_level : float + Confidence level (e.g., 0.80 for 80% intervals). + + Returns + ------- + pd.DataFrame + Columns: ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper``. + + Raises + ------ + ValueError + If the model was trained with exogenous variables. + """ + if self.model_fit is None: + raise ValueError("Model must be fitted before getting uncertainty.") + + if self.exog_cols: + raise ValueError( + f"Cannot generate forecast uncertainty: model was trained with " + f"exogenous variables {self.exog_cols}. Future exogenous values " + f"are required but not available. " + f"Use ForecastFeatureImportance instead." + ) + + alpha = 1.0 - confidence_level + forecast_obj = self.model_fit.get_forecast(steps=horizon) + summary = forecast_obj.summary_frame(alpha=alpha) + # summary columns: mean, mean_se, mean_ci_lower, mean_ci_upper + + freq = self.frequency or "D" + future_dates = pd.date_range( + start=self.last_ds, periods=horizon + 1, freq=freq + )[1:] + + return pd.DataFrame( + { + "ds": future_dates, + "yhat": summary["mean"].to_numpy(), + "yhat_lower": summary["mean_ci_lower"].to_numpy(), + "yhat_upper": summary["mean_ci_upper"].to_numpy(), + } + ) + + def get_forecast_components(self, horizon: int) -> pd.DataFrame: + """Decompose forecast into trend, seasonal, and residual components. + + Applies STL (Seasonal-Trend decomposition using LOESS) to the + combination of in-sample fitted values and out-of-sample forecast. + When there is insufficient data for STL, falls back to a centered + moving-average trend. + + ARIMA does not model seasonality explicitly, so the seasonal component + reflects the cyclical pattern extracted from the data by STL. + + Parameters + ---------- + horizon : int + Number of future periods to forecast and decompose. + + Returns + ------- + pd.DataFrame + Columns: ``ds``, ``trend``, ````, ``residual``. + The seasonality column name is derived from the stored frequency + (e.g. ``weekly`` for daily data, ``yearly`` for monthly data). + + Raises + ------ + ValueError + If the model was trained with exogenous variables (future values + would be required but are unavailable here). + """ + if self.model_fit is None: + raise ValueError("Model must be fitted before getting components.") + + if self.exog_cols: + raise ValueError( + f"Cannot generate forecast components: model was trained with " + f"exogenous variables {self.exog_cols}. Future exogenous values " + f"are required but not available for decomposition. " + f"Use ForecastFeatureImportance instead." + ) + + try: + from statsmodels.tsa.seasonal import STL + except ImportError as exc: + raise ImportError("statsmodels is required for STL decomposition.") from exc + + # In-sample fitted values (DatetimeIndex from training) + fitted = self.model_fit.fittedvalues.dropna() + + # Out-of-sample forecast + freq = self.frequency or "D" + forecast_result = self.model_fit.forecast(steps=horizon) + future_dates = pd.date_range( + start=self.last_ds, periods=horizon + 1, freq=freq + )[1:] + future_series = pd.Series(forecast_result.to_numpy(), index=future_dates) + + # Combine history + forecast into one series + combined = pd.concat([fitted, future_series]) + + # Determine period for STL + period = self._get_seasonal_period() + component_name = self._period_to_seasonality_name(period) + + n = len(combined) + if period >= 2 and n >= 2 * period: + try: + stl = STL(combined, period=period, robust=True) + result = stl.fit() + trend_vals = result.trend + seasonal_vals = result.seasonal + residual_vals = result.resid + except Exception: + # Fallback to moving-average trend if STL fails + window = min(period, max(2, n // 2)) + trend_vals = combined.rolling( + window=window, center=True, min_periods=1 + ).mean() + seasonal_vals = pd.Series(np.zeros(n), index=combined.index) + residual_vals = combined - trend_vals + else: + # Not enough data for STL — use simple moving-average trend + window = max(2, min(period, n // 2)) + trend_vals = combined.rolling( + window=window, center=True, min_periods=1 + ).mean() + seasonal_vals = pd.Series(np.zeros(n), index=combined.index) + residual_vals = combined - trend_vals + + # Return only the forecast horizon (the future portion) + return pd.DataFrame( + { + "ds": combined.index[-horizon:], + "trend": trend_vals.to_numpy()[-horizon:], + component_name: seasonal_vals.to_numpy()[-horizon:], + "residual": residual_vals.to_numpy()[-horizon:], + } + ) + + def save(self, filename: str) -> None: + """Save ARIMA model to file. + + Parameters + ---------- + filename : str + Path to save the model + """ + if self.model_fit is None: + raise ValueError("Cannot save model before fitting.") + + model_state = { + "model_fit": self.model_fit, + "exog_cols": self.exog_cols, + "timestamp_col": self.timestamp_col, + "target_col": self.target_col, + "frequency": self.frequency, + "last_ds": getattr(self, "last_ds", None), + "config": { + "p": self.p, + "d": self.d, + "q": self.q, + "trend": self.trend, + "order": self.order, + }, + } + + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "wb") as f: + pickle.dump(model_state, f) + + print(f"✅ ARIMA model saved to {filename}") + + def load(self, filename: str) -> "StatsmodelsARIMAModel": + """Load ARIMA model from file. + + Parameters + ---------- + filename : str + Path to load the model from + + Returns + ------- + StatsmodelsARIMAModel + Loaded model instance + """ + with open(filename, "rb") as f: + model_state = pickle.load(f) + + self.model_fit = model_state["model_fit"] + self.exog_cols = model_state["exog_cols"] + self.timestamp_col = model_state.get("timestamp_col") + self.target_col = model_state.get("target_col") + self.frequency = model_state.get("frequency") + self.last_ds = model_state.get("last_ds") + + config = model_state["config"] + for key, value in config.items(): + setattr(self, key, value) + + print(f"✅ ARIMA model loaded from {filename}") + return self diff --git a/DashAI/back/models/forecasting/statsmodels_sarimax_model.py b/DashAI/back/models/forecasting/statsmodels_sarimax_model.py new file mode 100644 index 000000000..c78749fa6 --- /dev/null +++ b/DashAI/back/models/forecasting/statsmodels_sarimax_model.py @@ -0,0 +1,813 @@ +"""Statsmodels SARIMAX model wrapper for DashAI forecasting. + +This model wraps statsmodels SARIMAX for seasonal time series forecasting with +autoregressive integrated moving average modeling and exogenous variables. +""" + +import os +import pickle +from typing import Any, Optional, Union + +import numpy as np +import pandas as pd + +from DashAI.back.core.schema_fields import ( + BaseSchema, + bool_field, + enum_field, + int_field, + schema_field, +) +from DashAI.back.dataloaders.classes.dashai_dataset import ( + DashAIDataset, + to_dashai_dataset, +) +from DashAI.back.models.forecasting.base_forecasting_model import ForecastingModel + + +class StatsmodelsSARIMAXModelSchema(BaseSchema): + """Schema for Statsmodels SARIMAX model configuration. + + SARIMAX (Seasonal AutoRegressive Integrated Moving Average with eXogenous + regressors) extends ARIMA with seasonal components and external variables: + - (p,d,q): Non-seasonal AR, differencing, MA orders + - (P,D,Q,s): Seasonal AR, differencing, MA orders and periodicity + - Exogenous variables: External predictors + """ + + p: schema_field( + int_field(ge=0, le=10), + placeholder=1, + description="Order of non-seasonal autoregressive (AR) component. " + "Number of lag observations included in the model.", + ) = 1 # type: ignore + + d: schema_field( + int_field(ge=0, le=3), + placeholder=1, + description="Degree of non-seasonal differencing. Number of times " + "to difference the data to make it stationary.", + ) = 1 # type: ignore + + q: schema_field( + int_field(ge=0, le=10), + placeholder=1, + description="Order of non-seasonal moving average (MA) component. " + "Size of the moving average window.", + ) = 1 # type: ignore + + P: schema_field( + int_field(ge=0, le=5), + placeholder=0, + description="Order of seasonal autoregressive component. " + "Seasonal lag observations. Set to 0 to disable seasonality.", + ) = 0 # type: ignore + + D: schema_field( + int_field(ge=0, le=2), + placeholder=0, + description="Degree of seasonal differencing. Seasonal differencing order. " + "Set to 0 to disable seasonal differencing.", + ) = 0 # type: ignore + + Q: schema_field( + int_field(ge=0, le=5), + placeholder=0, + description="Order of seasonal moving average component. " + "Seasonal moving average window. Set to 0 to disable.", + ) = 0 # type: ignore + + s: schema_field( + int_field(ge=1, le=365), + placeholder=1, + description="Seasonal period (observations per cycle). " + "12=monthly, 4=quarterly, 7=weekly. Set to 1 to disable seasonality.", + ) = 1 # type: ignore + + trend: schema_field( + enum_field(enum=["n", "c", "t", "ct"]), + placeholder="n", + description=( + "Deterministic trend to include. 'n'=no trend, 'c'=constant, " + "'t'=linear trend, 'ct'=constant and linear trend. " + "Note: When d>0 or D>0, 'c' is not allowed (use 't' instead). " + "When d+D>1, neither 'c' nor 't' are allowed." + ), + ) = "n" # type: ignore + + enforce_stationarity: schema_field( + bool_field(), + placeholder=True, + description=( + "Whether to enforce stationarity of the autoregressive parameters." + ), + ) = True # type: ignore + + enforce_invertibility: schema_field( + bool_field(), + placeholder=True, + description=( + "Whether to enforce invertibility of the moving average parameters." + ), + ) = True # type: ignore + + +class StatsmodelsSARIMAXModel(ForecastingModel): + """Statsmodels SARIMAX forecasting model wrapper for DashAI. + + This model implements the ForecastingModel interface using statsmodels SARIMAX. + It handles seasonal patterns, exogenous variables, and column name conversions. + """ + + SCHEMA = StatsmodelsSARIMAXModelSchema + COMPATIBLE_COMPONENTS = ["ForecastingTask"] + _task_type = "ForecastingTask" + + def __init__( + self, + p: int = 1, + d: int = 1, + q: int = 1, + P: int = 0, # noqa: N803 + D: int = 0, # noqa: N803 + Q: int = 0, # noqa: N803 + s: int = 1, + trend: str = "n", + enforce_stationarity: bool = True, + enforce_invertibility: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.p = p + self.d = d + self.q = q + self.P = P + self.D = D + self.Q = Q + self.s = s + self.trend = trend + self.enforce_stationarity = enforce_stationarity + self.enforce_invertibility = enforce_invertibility + + self.order = (p, d, q) + self.seasonal_order = (P, D, Q, s) + + self.model = None + self.model_fit = None + self.frequency: Optional[str] = None + + def _validate_forecasting_data(self, x: DashAIDataset, y: DashAIDataset) -> None: + """Validate that data is suitable for SARIMAX. + + Parameters + ---------- + x : DashAIDataset + Input features (must contain a timestamp column) + y : DashAIDataset + Target values (must contain a numeric column) + + Raises + ------ + ValueError + If data is not suitable for SARIMAX + """ + x_cols = set(x.column_names) + y_cols = set(y.column_names) + + if len(x_cols) == 0: + raise ValueError( + "SARIMAX requires at least one input column (timestamp). " + "Received empty dataset." + ) + + if len(y_cols) != 1: + raise ValueError( + f"SARIMAX requires exactly one target column. " + f"Received {len(y_cols)} columns: {list(y_cols)}" + ) + + def fit( + self, + x_train: DashAIDataset, + y: DashAIDataset, + temporal_metadata: dict = None, + **fit_params, + ) -> "StatsmodelsSARIMAXModel": + """Train SARIMAX forecasting model. + + Parameters + ---------- + x_train : DashAIDataset + Input features containing timestamp and optional exogenous variables + y : DashAIDataset + Target time series (single column) + temporal_metadata : dict, optional + Metadata from ForecastingTask containing: + - timestamp_col: name of timestamp column + - target_col: name of target column + - exog_cols: list of exogenous variable column names + - frequency: time series frequency + **fit_params + Additional fitting parameters + + Returns + ------- + StatsmodelsSARIMAXModel + Fitted model instance + """ + try: + from statsmodels.tsa.statespace.sarimax import SARIMAX + except ImportError as e: + raise ImportError( + "Statsmodels is required for StatsmodelsSARIMAXModel. " + "Install with: pip install statsmodels" + ) from e + + # Validate data format + self._validate_forecasting_data(x_train, y) + + # Convert to pandas DataFrames + x_df = x_train.to_pandas() + y_df = y.to_pandas() + + # Get column information from metadata + if temporal_metadata: + timestamp_col = temporal_metadata.get("timestamp_col") + target_col = temporal_metadata.get("target_col") + exog_cols_from_task = temporal_metadata.get("exog_cols", []) + frequency = temporal_metadata.get("frequency") + + print("[StatsmodelsSARIMAXModel] Using temporal metadata from task:") + print(f" - Timestamp: '{timestamp_col}'") + print(f" - Target: '{target_col}'") + print(f" - Frequency: {frequency}") + if exog_cols_from_task: + print(f" - Exogenous variables: {exog_cols_from_task}") + else: + # Auto-detection if no metadata provided + print( + "[StatsmodelsSARIMAXModel] ⚠️ No temporal_metadata provided, " + "using auto-detection" + ) + + target_col = y_df.columns[0] + + # Auto-detect timestamp column + timestamp_col = None + for col in x_df.columns: + try: + pd.to_datetime(x_df[col]) + timestamp_col = col + print( + f"[StatsmodelsSARIMAXModel] Detected timestamp column: '{col}'" + ) + break + except Exception: + continue + + if timestamp_col is None: + raise ValueError( + f"No timestamp column found in input data. " + f"Available columns: {list(x_df.columns)}" + ) + + exog_cols_from_task = [] + frequency = fit_params.get("frequency") + + # Store original column names + self.timestamp_col = timestamp_col + self.target_col = target_col + self.frequency = frequency + + # Prepare data for SARIMAX + # Create datetime index + dates = pd.to_datetime(x_df[timestamp_col]) + + # Store last training date for forecast generation + self.last_ds = dates.max() + + # Get target series + target_in_inputs = target_col in x_df.columns + if target_in_inputs: + print( + "[StatsmodelsSARIMAXModel] ℹ️ Target '{}' found in inputs - " + "using it from there".format(target_col) + ) + endog = x_df[target_col].to_numpy() + else: + endog = y_df[target_col].to_numpy() + + # Create time series with datetime index + endog_series = pd.Series(endog, index=dates) + + # Prepare exogenous variables + self.exog_cols = [] + exog = None + + for col in x_df.columns: + if col == timestamp_col: + continue + if col == target_col: + if target_in_inputs: + print( + "[StatsmodelsSARIMAXModel] ℹ️ Excluding target '{}' from " + "exogenous variables".format(col) + ) + continue + + # Only add numeric columns + if pd.api.types.is_numeric_dtype(x_df[col]): + self.exog_cols.append(col) + else: + print( + "[StatsmodelsSARIMAXModel] ⚠️ Skipping non-numeric column: '{}' " + "(type: {})".format(col, x_df[col].dtype) + ) + + if self.exog_cols: + exog = x_df[self.exog_cols].to_numpy() + print(f"[StatsmodelsSARIMAXModel] Exogenous variables: {self.exog_cols}") + + print( + f"[StatsmodelsSARIMAXModel] Training " + f"SARIMAX{self.order}x{self.seasonal_order} model" + ) + print( + f"[StatsmodelsSARIMAXModel] Training with {len(endog_series)} data points" + ) + print(f"[StatsmodelsSARIMAXModel] Date range: {dates.min()} to {dates.max()}") + + # Auto-adjust parameters for small datasets + n_samples = len(endog_series) + + # Check if seasonality is enabled (s>1 and any seasonal param > 0) + has_seasonality = self.s > 1 and (self.P > 0 or self.D > 0 or self.Q > 0) + + if has_seasonality: + # SARIMAX needs: s + d + D + max(p, P) + max(q, Q) samples + min_required = ( + self.s + self.d + self.D + max(self.p, self.P) + max(self.q, self.Q) + 2 + ) + + if n_samples < min_required: + print( + f"[StatsmodelsSARIMAXModel] ⚠️ Dataset too small " + f"({n_samples} samples) for seasonal params " + f"(need {min_required}). Disabling seasonality..." + ) + # Disable seasonality entirely for small datasets + self.P = 0 + self.D = 0 + self.Q = 0 + self.s = 1 + has_seasonality = False + + # For non-seasonal ARIMA, check basic requirements + min_arima_samples = self.p + self.d + self.q + 3 + if n_samples < min_arima_samples: + print("[StatsmodelsSARIMAXModel] ⚠️ Adjusting ARIMA orders...") + # Reduce AR/MA orders if needed + max_order = max(1, (n_samples - self.d - 2) // 2) + if self.p > max_order: + old_p = self.p + self.p = max(0, max_order) + print(f" - Reduced p (AR order): {old_p} → {self.p}") + if self.q > max_order: + old_q = self.q + self.q = max(0, max_order) + print(f" - Reduced q (MA order): {old_q} → {self.q}") + if self.d > 1 and n_samples < 10: + old_d = self.d + self.d = min(1, self.d) + print(f" - Reduced d (differencing): {old_d} → {self.d}") + + self.order = (self.p, self.d, self.q) + + # Set seasonal_order: None if no seasonality, otherwise tuple + if has_seasonality: + self.seasonal_order = (self.P, self.D, self.Q, self.s) + print( + f"[StatsmodelsSARIMAXModel] Final parameters: " + f"SARIMAX{self.order}x{self.seasonal_order}" + ) + else: + self.seasonal_order = (0, 0, 0, 0) # Disable seasonality + print( + f"[StatsmodelsSARIMAXModel] Final parameters: " + f"ARIMA{self.order} (no seasonality)" + ) + + # Fit SARIMAX model (use ARIMA when no seasonality for stability) + try: + if has_seasonality: + self.model = SARIMAX( + endog=endog_series, + exog=exog, + order=self.order, + seasonal_order=self.seasonal_order, + trend=self.trend, + enforce_stationarity=self.enforce_stationarity, + enforce_invertibility=self.enforce_invertibility, + ) + self.model_fit = self.model.fit() + else: + # Use ARIMA (no seasonal component) for small datasets or when s <= 1 + # SARIMAX doesn't accept seasonal_order with s <= 1 + from statsmodels.tsa.arima.model import ARIMA + + print("[StatsmodelsSARIMAXModel] Using ARIMA (no seasonal component)") + self.model = ARIMA( + endog=endog_series, + exog=exog, + order=self.order, + trend=self.trend, + enforce_stationarity=self.enforce_stationarity, + enforce_invertibility=self.enforce_invertibility, + ) + self.model_fit = self.model.fit() + + print("✅ SARIMAX model training completed") + print(f"[StatsmodelsSARIMAXModel] AIC: {self.model_fit.aic:.2f}") + print(f"[StatsmodelsSARIMAXModel] BIC: {self.model_fit.bic:.2f}") + except ValueError as e: + # Fallback: if SARIMAX fails due to seasonality issues, try ARIMA + if "Seasonal periodicity" in str(e) or "seasonal" in str(e).lower(): + print(f"[StatsmodelsSARIMAXModel] ⚠️ SARIMAX seasonality error: {e}") + print("[StatsmodelsSARIMAXModel] Falling back to ARIMA") + from statsmodels.tsa.arima.model import ARIMA + + self.model = ARIMA( + endog=endog_series, + exog=exog, + order=self.order, + trend=self.trend, + enforce_stationarity=self.enforce_stationarity, + enforce_invertibility=self.enforce_invertibility, + ) + self.model_fit = self.model.fit() + self.seasonal_order = (0, 0, 0, 0) + print("✅ ARIMA model training completed (fallback)") + print(f"[StatsmodelsSARIMAXModel] AIC: {self.model_fit.aic:.2f}") + print(f"[StatsmodelsSARIMAXModel] BIC: {self.model_fit.bic:.2f}") + else: + print(f"[StatsmodelsSARIMAXModel] ❌ Training failed: {e}") + raise + except Exception as e: + print(f"[StatsmodelsSARIMAXModel] ❌ Training failed: {e}") + raise + + return self + + def predict( + self, + x_pred: Optional[Any] = None, + periods: Optional[int] = None, + exog_future: Optional[pd.DataFrame] = None, + **kwargs, + ) -> Union[np.ndarray, pd.DataFrame]: + """Generate forecasts using SARIMAX model. + + Parameters + ---------- + x_pred : pd.DataFrame, optional + Input data for in-sample predictions containing timestamp and + exogenous variables (if model uses them) + periods : int, optional + Number of future periods to forecast (out-of-sample mode) + exog_future : pd.DataFrame, optional + Future values of exogenous variables for out-of-sample forecasting + **kwargs + Additional parameters + + Returns + ------- + np.ndarray or pd.DataFrame + Predictions array + """ + if self.model_fit is None: + raise ValueError("SARIMAX model is not fitted yet. Call fit() first.") + + # Handle different input types + if x_pred is not None and isinstance(x_pred, (int, np.integer)): + periods = int(x_pred) + x_pred = None + + # Out-of-sample forecasting + if periods is not None and x_pred is None: + if periods <= 0: + raise ValueError("Prediction horizon must be a positive integer.") + + # Prepare exogenous variables for forecast + exog = None + if self.exog_cols: + if exog_future is None: + raise ValueError( + f"Future exogenous values required for columns: " + f"{self.exog_cols}." + ) + + missing_cols = [ + col for col in self.exog_cols if col not in exog_future.columns + ] + if missing_cols: + raise ValueError( + f"Missing exogenous columns for prediction: {missing_cols}." + ) + + if len(exog_future) != periods: + raise ValueError( + f"Exogenous data length ({len(exog_future)}) must match " + f"prediction horizon ({periods})." + ) + + exog = exog_future[self.exog_cols].to_numpy() + + # Generate forecast + forecast = self.model_fit.forecast(steps=periods, exog=exog) + + print(f"[StatsmodelsSARIMAXModel] Generated forecast for {periods} periods") + return forecast.to_numpy() + + # In-sample predictions + if x_pred is not None: + if isinstance(x_pred, pd.DataFrame): + input_df = x_pred.copy() + else: + input_df = to_dashai_dataset(x_pred).to_pandas() + + # Auto-detect timestamp column + timestamp_col = None + for col in input_df.columns: + try: + pd.to_datetime(input_df[col]) + timestamp_col = col + break + except Exception: + continue + + if timestamp_col is None: + raise ValueError( + "SARIMAX predict requires a timestamp column. " + f"Available columns: {list(input_df.columns)}" + ) + + dates = pd.to_datetime(input_df[timestamp_col]) + + # Prepare exogenous variables + exog = None + if self.exog_cols: + missing_cols = [ + col for col in self.exog_cols if col not in input_df.columns + ] + if missing_cols: + raise ValueError( + f"Missing exogenous columns for prediction: {missing_cols}." + ) + exog = input_df[self.exog_cols].to_numpy() + + # Use actual dates so statsmodels predicts the correct period + # (works for both in-sample and out-of-sample dates) + print( + f"[StatsmodelsSARIMAXModel] In-sample prediction: {len(dates)} points " + f"({dates.min()} to {dates.max()})" + ) + + predictions = self.model_fit.predict( + start=dates.iloc[0], end=dates.iloc[-1], exog=exog + ) + + print(f"[StatsmodelsSARIMAXModel] Generated {len(predictions)} predictions") + + return predictions.to_numpy() + + raise ValueError( + "SARIMAX predict requires either 'x_pred' data or a 'periods' value." + ) + + def get_forecast_uncertainty( + self, horizon: int, confidence_level: float = 0.80 + ) -> pd.DataFrame: + """Get forecast with parametric confidence intervals from SARIMAX. + + Uses statsmodels ``get_forecast().summary_frame()`` to compute + analytical confidence intervals derived from the model's error + distribution. These are true parametric intervals that reflect both + the non-seasonal and seasonal uncertainty of the model. + + Parameters + ---------- + horizon : int + Number of future periods to forecast. + confidence_level : float + Confidence level (e.g., 0.80 for 80% intervals). + + Returns + ------- + pd.DataFrame + Columns: ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper``. + + Raises + ------ + ValueError + If the model was trained with exogenous variables. + """ + if self.model_fit is None: + raise ValueError("Model must be fitted before getting uncertainty.") + + if self.exog_cols: + raise ValueError( + f"Cannot generate forecast uncertainty: model was trained with " + f"exogenous variables {self.exog_cols}. Future exogenous values " + f"are required but not available. " + f"Use ForecastFeatureImportance instead." + ) + + alpha = 1.0 - confidence_level + forecast_obj = self.model_fit.get_forecast(steps=horizon) + summary = forecast_obj.summary_frame(alpha=alpha) + # summary columns: mean, mean_se, mean_ci_lower, mean_ci_upper + + freq = self.frequency or "D" + future_dates = pd.date_range( + start=self.last_ds, periods=horizon + 1, freq=freq + )[1:] + + return pd.DataFrame( + { + "ds": future_dates, + "yhat": summary["mean"].to_numpy(), + "yhat_lower": summary["mean_ci_lower"].to_numpy(), + "yhat_upper": summary["mean_ci_upper"].to_numpy(), + } + ) + + def get_forecast_components(self, horizon: int) -> pd.DataFrame: + """Decompose forecast into trend, seasonal, and residual components. + + Applies STL (Seasonal-Trend decomposition using LOESS) to the + combination of in-sample fitted values and out-of-sample forecast. + + When a seasonal order was configured (``s > 1``), the explicit + seasonal period ``s`` is used for STL so the decomposition reflects + the model's actual seasonal structure. Otherwise the period is + inferred from the stored frequency. + + Parameters + ---------- + horizon : int + Number of future periods to forecast and decompose. + + Returns + ------- + pd.DataFrame + Columns: ``ds``, ``trend``, ````, ``residual``. + The seasonality column name depends on the period (e.g. ``weekly`` + for s=7, ``yearly`` for s=12). + + Raises + ------ + ValueError + If the model was trained with exogenous variables. + """ + if self.model_fit is None: + raise ValueError("Model must be fitted before getting components.") + + if self.exog_cols: + raise ValueError( + f"Cannot generate forecast components: model was trained with " + f"exogenous variables {self.exog_cols}. Future exogenous values " + f"are required but not available for decomposition. " + f"Use ForecastFeatureImportance instead." + ) + + try: + from statsmodels.tsa.seasonal import STL + except ImportError as exc: + raise ImportError("statsmodels is required for STL decomposition.") from exc + + # In-sample fitted values (DatetimeIndex from training) + fitted = self.model_fit.fittedvalues.dropna() + + # Out-of-sample forecast + freq = self.frequency or "D" + forecast_result = self.model_fit.forecast(steps=horizon) + future_dates = pd.date_range( + start=self.last_ds, periods=horizon + 1, freq=freq + )[1:] + future_series = pd.Series(forecast_result.to_numpy(), index=future_dates) + + # Combine history + forecast + combined = pd.concat([fitted, future_series]) + + # Use explicit seasonal period s when seasonality is active; + # otherwise infer from frequency + explicit_s = ( + self.seasonal_order[3] + if hasattr(self, "seasonal_order") and self.seasonal_order[3] > 1 + else None + ) + period = explicit_s if explicit_s else self._get_seasonal_period() + component_name = self._period_to_seasonality_name(period) + + n = len(combined) + if period >= 2 and n >= 2 * period: + try: + stl = STL(combined, period=period, robust=True) + result = stl.fit() + trend_vals = result.trend + seasonal_vals = result.seasonal + residual_vals = result.resid + except Exception: + window = min(period, max(2, n // 2)) + trend_vals = combined.rolling( + window=window, center=True, min_periods=1 + ).mean() + seasonal_vals = pd.Series(np.zeros(n), index=combined.index) + residual_vals = combined - trend_vals + else: + window = max(2, min(period, n // 2)) + trend_vals = combined.rolling( + window=window, center=True, min_periods=1 + ).mean() + seasonal_vals = pd.Series(np.zeros(n), index=combined.index) + residual_vals = combined - trend_vals + + return pd.DataFrame( + { + "ds": combined.index[-horizon:], + "trend": trend_vals.to_numpy()[-horizon:], + component_name: seasonal_vals.to_numpy()[-horizon:], + "residual": residual_vals.to_numpy()[-horizon:], + } + ) + + def save(self, filename: str) -> None: + """Save SARIMAX model to file. + + Parameters + ---------- + filename : str + Path to save the model + """ + if self.model_fit is None: + raise ValueError("Cannot save model before fitting.") + + model_state = { + "model_fit": self.model_fit, + "exog_cols": self.exog_cols, + "timestamp_col": self.timestamp_col, + "target_col": self.target_col, + "frequency": self.frequency, + "last_ds": getattr(self, "last_ds", None), + "config": { + "p": self.p, + "d": self.d, + "q": self.q, + "P": self.P, + "D": self.D, + "Q": self.Q, + "s": self.s, + "trend": self.trend, + "order": self.order, + "seasonal_order": self.seasonal_order, + "enforce_stationarity": self.enforce_stationarity, + "enforce_invertibility": self.enforce_invertibility, + }, + } + + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "wb") as f: + pickle.dump(model_state, f) + + print(f"✅ SARIMAX model saved to {filename}") + + def load(self, filename: str) -> "StatsmodelsSARIMAXModel": + """Load SARIMAX model from file. + + Parameters + ---------- + filename : str + Path to load the model from + + Returns + ------- + StatsmodelsSARIMAXModel + Loaded model instance + """ + with open(filename, "rb") as f: + model_state = pickle.load(f) + + self.model_fit = model_state["model_fit"] + self.exog_cols = model_state["exog_cols"] + self.timestamp_col = model_state.get("timestamp_col") + self.target_col = model_state.get("target_col") + self.frequency = model_state.get("frequency") + self.last_ds = model_state.get("last_ds") + + config = model_state["config"] + for key, value in config.items(): + setattr(self, key, value) + + print(f"✅ SARIMAX model loaded from {filename}") + return self diff --git a/DashAI/back/models/model_factory.py b/DashAI/back/models/model_factory.py index b9e77765f..c777c0ef2 100644 --- a/DashAI/back/models/model_factory.py +++ b/DashAI/back/models/model_factory.py @@ -1,9 +1,24 @@ +import math + from kink import di from DashAI.back.metrics.base_metric import BaseMetric from DashAI.back.metrics.classification_metric import ClassificationMetric +def sanitize_metric_value(value): + """Convert non-JSON-serializable float values to None. + + JSON doesn't support NaN, Infinity, or -Infinity, so we convert + these to None which becomes null in JSON. + """ + if value is None: + return None + if isinstance(value, float) and (math.isnan(value) or math.isinf(value)): + return None + return value + + class ModelFactory: """ A factory class for creating and configuring models. @@ -257,29 +272,60 @@ def evaluate(self, x, y, metrics): results = {} for split in ["train", "validation", "test"]: split_results = {} + + # Skip empty splits if x[split].shape[0] == 0: split_results = {metric.__name__: None for metric in metrics} results[split] = split_results continue - predictions = self.model.predict(x[split]) + + # Predict — wrapped in try/except for forecasting models + # where a split may be too small for the window_size + try: + predictions = self.model.predict(x[split]) + except Exception as e: + print( + f"[ModelFactory] ⚠️ Prediction failed for {split} split: {e}. " + f"Setting all metrics to null." + ) + for metric in metrics: + split_results[metric.__name__] = None + results[split] = split_results + continue + + # Transform y using model's own method (upstream convention) if hasattr(self.model, "prepare_output"): transformed_y = self.model.prepare_output(y[split]) else: transformed_y = self.model.prepare_dataset(y[split]) + for metric in metrics: - if ( - isinstance(metric, type) - and issubclass(metric, ClassificationMetric) - and "multiclass" in metric.score.__code__.co_varnames - and multiclass is not None - ): - score = metric.score( - transformed_y, predictions, multiclass=multiclass - ) - else: - score = metric.score(transformed_y, predictions) - - split_results[metric.__name__] = score + try: + if ( + isinstance(metric, type) + and issubclass(metric, ClassificationMetric) + and "multiclass" in metric.score.__code__.co_varnames + and multiclass is not None + ): + score = metric.score( + transformed_y, predictions, multiclass=multiclass + ) + else: + score = metric.score(transformed_y, predictions) + + # Sanitize score to ensure JSON compatibility + split_results[metric.__name__] = sanitize_metric_value(score) + except ValueError as e: + # Handle case where all predictions are NaN + if "All values are NaN" in str(e): + print( + f"[ModelFactory] ⚠️ {split}/{metric.__name__}: " + f"All predictions are NaN (split too small?). " + f"Setting metric to null." + ) + split_results[metric.__name__] = None + else: + raise results[split] = split_results diff --git a/DashAI/back/models/parameters/models_schemas/MultiOutputRegression.json b/DashAI/back/models/parameters/models_schemas/MultiOutputRegression.json new file mode 100644 index 000000000..731ae6f03 --- /dev/null +++ b/DashAI/back/models/parameters/models_schemas/MultiOutputRegression.json @@ -0,0 +1,49 @@ +{ + "additionalProperties": false, + "error_msg": "The parameters for MultiOutputRegression must include one or more of ['fit_intercept', 'copy_X', 'n_jobs', 'positive'].", + "description": "MultiOutputRegression trains an independent regressor for each output. By default, it uses LinearRegression for each output.", + "properties": { + "fit_intercept": { + "oneOf": [ + { + "error_msg": "The 'fit_intercept' parameter must be of type boolean.", + "description": "Determines whether to calculate the intercept for this model.", + "type": "boolean", + "default": true + } + ] + }, + "copy_X": { + "oneOf": [ + { + "error_msg": "The 'copy_X' parameter must be of type boolean.", + "description": "Determines whether to copy the input matrix X.", + "type": "boolean", + "default": true + } + ] + }, + "n_jobs": { + "oneOf": [ + { + "error_msg": "The 'n_jobs' parameter must be an integer or null.", + "description": "The number of jobs to use for computation. -1 means using all processors.", + "type": ["integer", "null"], + "default": null, + "minimum": -1 + } + ] + }, + "positive": { + "oneOf": [ + { + "error_msg": "The 'positive' parameter must be of type boolean.", + "description": "When set to True, forces the coefficients to be positive.", + "type": "boolean", + "default": false + } + ] + } + }, + "type": "object" +} diff --git a/DashAI/back/models/scikit_learn/__init__.py b/DashAI/back/models/scikit_learn/__init__.py index e69de29bb..9c358d4ab 100644 --- a/DashAI/back/models/scikit_learn/__init__.py +++ b/DashAI/back/models/scikit_learn/__init__.py @@ -0,0 +1,7 @@ +"""Scikit-learn based models.""" + +from .multi_output_regression import MultiOutputRegression + +__all__ = [ + "MultiOutputRegression", +] diff --git a/DashAI/back/models/scikit_learn/multi_output_regression.py b/DashAI/back/models/scikit_learn/multi_output_regression.py new file mode 100644 index 000000000..72ac850a1 --- /dev/null +++ b/DashAI/back/models/scikit_learn/multi_output_regression.py @@ -0,0 +1,154 @@ +""" +MultiOutput regression model for DashAI. + +This model is a wrapper around sklearn.multioutput.MultiOutputRegressor. +By default it uses LinearRegression as base estimator but you can select +other sklearn regressors by passing the `base_estimator` parameter and, +optionally, `base_params` (a dict with kwargs for the base estimator). +""" + +from typing import Any, Dict, Optional + +from sklearn.ensemble import RandomForestRegressor as _RandomForestRegressor +from sklearn.linear_model import LinearRegression as _LinearRegression +from sklearn.linear_model import Ridge as _Ridge +from sklearn.multioutput import MultiOutputRegressor + +from DashAI.back.core.schema_fields import ( + BaseSchema, + enum_field, + schema_field, +) +from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset +from DashAI.back.models.regression_model import RegressionModel +from DashAI.back.models.scikit_learn.sklearn_like_regressor import SklearnLikeRegressor + + +class MultiOutputRegressionSchema(BaseSchema): + """Multi-output regression using sklearn's MultiOutputRegressor. + + This meta-estimator fits one regressor per target variable, allowing you to + predict multiple continuous outputs simultaneously. Choose from different base + estimators depending on your needs: linear models for interpretability, + tree-based models for non-linear relationships. + """ + + base_estimator: schema_field( + enum_field(enum=["linear", "ridge", "random_forest"]), + placeholder="linear", + description="Base estimator to use for each output target. " + "'linear': Fast linear regression (no regularization). " + "'ridge': Linear regression with L2 regularization (prevents overfitting). " + "'random_forest': Tree-based ensemble (handles non-linear relationships).", + ) = "linear" # type: ignore + + +class MultiOutputRegression(RegressionModel, SklearnLikeRegressor): + """Meta-model using sklearn's MultiOutputRegressor.""" + + SCHEMA = MultiOutputRegressionSchema + + COMPATIBLE_COMPONENTS = ["RegressionTask"] + + def __init__( + self, + base_estimator: str = "linear", + base_params: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + """ + Parameters + ---------- + base_estimator : str + Identifier of the base estimator. Supported: "linear", "ridge", + "random_forest" + base_params : dict, optional + Keyword args to forward to the base estimator constructor. + kwargs : dict + Extra args (kept for compatibility with existing infrastructure). + """ + super().__init__(**kwargs) + + if base_params is None: + base_params = {} + + # Map string identifiers to sklearn estimators (you can extend with more). + estimators = { + "linear": _LinearRegression, + "ridge": _Ridge, + "random_forest": _RandomForestRegressor, + } + + if base_estimator not in estimators: + raise ValueError( + f"Unknown base_estimator '{base_estimator}'. " + f"Supported: {list(estimators.keys())}" + ) + + base_cls = estimators[base_estimator] + base_instance = base_cls(**base_params) + + # The actual sklearn model we will fit/predict with + self.sklearn_model = MultiOutputRegressor(base_instance) + + # If SklearnLikeRegressor expects certain attributes/methods, adapt accordingly. + # We implement fit/predict here to be explicit. + + def fit(self, x_train: DashAIDataset, y_train: DashAIDataset, **fit_params): + """ + Fit the multioutput regressor. + x_train: DashAIDataset with input features + y_train: DashAIDataset with output targets + """ + import numpy as np + + # CRITICAL: Convert DashAI datasets to pandas first + x_pandas = x_train.to_pandas() + y_pandas = y_train.to_pandas() + + # Convert pandas to numpy arrays + X = np.asarray(x_pandas) + y = np.asarray(y_pandas) + + # KEY FIX: Ensure y is 2D for MultiOutputRegressor + # sklearn's MultiOutputRegressor requires y to have at least 2 dimensions + if y.ndim == 1: + print( + f"[MultiOutputRegression] Converting 1D y (shape {y.shape}) " + f"to 2D for multi-output regression" + ) + y = y.reshape(-1, 1) + + print( + f"[MultiOutputRegression] Training with X shape: {X.shape}, " + f"y shape: {y.shape}" + ) + print(f"[MultiOutputRegression] X columns: {list(x_pandas.columns)}") + print(f"[MultiOutputRegression] y columns: {list(y_pandas.columns)}") + + # Now this will work with both 1D and 2D y arrays + self.sklearn_model.fit(X, y, **fit_params) + return self + + def predict(self, x_pred: DashAIDataset): + """ + Predict multi-output targets. + x_pred: DashAIDataset with input features + Returns array shape (n_samples, n_outputs) + """ + import numpy as np + + # CRITICAL: Convert DashAI dataset to pandas first (same as fit method) + x_pandas = x_pred.to_pandas() + + # Convert pandas to numpy array + X = np.asarray(x_pandas) + + print(f"[MultiOutputRegression] Predicting with X shape: {X.shape}") + + # Now this will work with clean numpy array + return self.sklearn_model.predict(X) + + # If DashAI base classes expect `save` and `load`, SklearnLikeRegressor + # if not, you should rely on the SklearnLikeRegressor implementations. If necessary, + # override save/load following the project's conventions. diff --git a/DashAI/back/optimizers/base_optimizer.py b/DashAI/back/optimizers/base_optimizer.py index 7390c791b..15582d391 100644 --- a/DashAI/back/optimizers/base_optimizer.py +++ b/DashAI/back/optimizers/base_optimizer.py @@ -357,7 +357,7 @@ def importance_plot(self, trials, goal_metric): ) sorted_items = sorted(importances.items(), key=lambda item: item[1]) - param_names, importance_values = zip(*sorted_items) + param_names, importance_values = zip(*sorted_items, strict=False) fig = go.Figure( data=[ go.Bar( diff --git a/DashAI/back/tasks/forecasting_task.py b/DashAI/back/tasks/forecasting_task.py new file mode 100644 index 000000000..3b000a819 --- /dev/null +++ b/DashAI/back/tasks/forecasting_task.py @@ -0,0 +1,597 @@ +"""Forecasting Task for time series prediction in DashAI. + +This task enables native time series forecasting with models like Prophet, +as well as tabular approaches using TimeSeriesWindowConverter. +""" + +from typing import Any, Dict, List, Optional, Union + +import pandas as pd +from datasets import DatasetDict, Value + +from DashAI.back.dataloaders.classes.dashai_dataset import ( + DashAIDataset, + to_dashai_dataset, +) +from DashAI.back.tasks.base_task import BaseTask + + +class ForecastingTask(BaseTask): + """Task for time series forecasting. + + This task handles two main forecasting approaches: + 1. Native forecasting (Prophet, ARIMA): Uses ds (datetime) + y + optional exogenous + 2. Tabular forecasting: Uses TimeSeriesWindowConverter + regression models + + Key differences from RegressionTask: + - Requires temporal column (ds) and proper time ordering + - Uses causal splits (no shuffle) to respect temporal causality + - Supports forecasting-specific metrics (MAPE, sMAPE, MASE) + - Native models can predict variable horizons + """ + + DESCRIPTION: str = """ + Time series forecasting predicts future values based on historical patterns. + Supports both native forecasting models (Prophet) that work directly with + timestamps and target values, and tabular approaches that convert time series + into supervised learning problems using lag features and future windows. + """ + + metadata = { + "inputs_types": [Value], # ds (datetime) + optional exogenous variables + "outputs_types": [Value], # y (target time series) + "inputs_cardinality": "n", # ds + optional exogenous features + "outputs_cardinality": 1, # Single target series + } + + def __init__(self): + """Initialize ForecastingTask.""" + super().__init__() + self._temporal_metadata: Optional[Dict[str, Any]] = None + + def validate_dataset_for_task( + self, + dataset: DashAIDataset, + dataset_name: str, + input_columns: List[str], + output_columns: List[str], + ) -> None: + """Validate a dataset for forecasting task.""" + + print("\n🔍 VALIDATE_DATASET_FOR_TASK INICIO") + print(f"📄 Dataset: {dataset_name}") + print(f"📥 Input columns: {input_columns}") + print(f"📤 Output columns: {output_columns}") + + metadata = self.metadata + allowed_input_types = tuple(metadata["inputs_types"]) + allowed_output_types = tuple(metadata["outputs_types"]) + + # 🔍 DEBUG: Print full metadata + print("\n📐 Metadata:") + print(f" - allowed_input_types: {allowed_input_types}") + print(f" - allowed_output_types: {allowed_output_types}") + print(f" - input_cardinality: {metadata.get('inputs_cardinality')}") + print(f" - output_cardinality: {metadata.get('outputs_cardinality')}") + + # Validate cardinality + if len(input_columns) < 1: + raise ValueError( + "ForecastingTask requires at least 1 input column.\n" + "Include a timestamp and optional exogenous variables." + ) + + if len(output_columns) != 1: + raise ValueError( + "ForecastingTask requires exactly 1 output column " + f"(target to forecast). Got: {len(output_columns)} outputs." + ) + + dataset_df = dataset.to_pandas() + if not isinstance(dataset_df, pd.DataFrame): + dataset_df = pd.concat(dataset_df, ignore_index=True) + + # 🔬 Revisar tipos de columnas en dataset.features + print("\n🧪 DEBUG: Column types from dataset.features") + for col_name, col_type in dataset.features.items(): + print(f" - {col_name}: {col_type} ({type(col_type)})") + + # Validate all input columns exist and have correct types + timestamp_found = False + detected_timestamp = None + + for input_col in input_columns: + if input_col not in dataset.features: + raise ValueError( + f"Input column '{input_col}' not found in dataset. " + f"Available columns: {list(dataset.features.keys())}" + ) + + input_col_type = dataset.features[input_col] + + # Print individual type check + print( + f"🔍 Checking input '{input_col}' type: {input_col_type} " + f"({type(input_col_type)})" + ) + + if not isinstance(input_col_type, allowed_input_types): + print("❌ Input column type mismatch") + raise TypeError( + f"Input column '{input_col}' has type " + f"{type(input_col_type).__name__}, but expected one of: " + f"{allowed_input_types}." + ) + + # Try to detect if it's the timestamp + if not timestamp_found: + try: + pd.to_datetime(dataset_df[input_col]) + timestamp_found = True + detected_timestamp = input_col + print(f"✅ Detected timestamp column: '{input_col}'") + except Exception: + pass + + if not timestamp_found: + raise ValueError( + "No timestamp column detected in input columns. " + "ForecastingTask requires a datetime column for temporal ordering. " + f"Checked columns: {input_columns}" + ) + + # OUTPUT VALIDATION + output_col = output_columns[0] + if output_col not in dataset.features: + raise ValueError( + f"Output column '{output_col}' not found in dataset. " + f"Available: {list(dataset.features.keys())}" + ) + + output_col_type = dataset.features[output_col] + print( + f"\n🔍 Checking output '{output_col}' type: {output_col_type} " + f"({type(output_col_type)})" + ) + + if not isinstance(output_col_type, allowed_output_types): + print("❌ Output column type mismatch") + raise TypeError( + f"Output column '{output_col}' has type " + f"{type(output_col_type).__name__}, but expected one of: " + f"{allowed_output_types}." + ) + + # Validate target column + try: + pd.to_numeric(dataset_df[output_col]) + print(f"✅ Target column '{output_col}' is numeric") + except Exception as e: + raise TypeError( + f"Output column '{output_col}' cannot be converted to numeric: {e}" + ) from e + + # Check minimum data points + if len(dataset) < 5: + raise ValueError( + f"Dataset '{dataset_name}' has only {len(dataset)} rows. " + "Minimum 5 rows required for forecasting." + ) + + # ✅ VALIDATION PASSED + print("✅ ForecastingTask validation PASSED:") + print(f" - Inputs: {input_columns} (timestamp: {detected_timestamp})") + print(f" - Output: {output_col}") + print(f" - Total rows: {len(dataset)}\n") + + @property + def schema(self) -> Dict[str, Any]: + """Get the schema for ForecastingTask.""" + return { + "type": "object", + "properties": { + "timestamp_column": { + "type": "string", + "description": ( + "Name of the datetime column (will be renamed to 'ds')" + ), + }, + "target_column": { + "type": "string", + "description": ( + "Name of the target time series column (will be renamed to 'y')" + ), + }, + "exogenous_columns": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "Optional exogenous variables (holidays, weather, etc.)" + ), + "default": [], + }, + "frequency": { + "type": "string", + "description": ( + "Time series frequency (D, H, M, etc.). " + "Auto-detected if not specified" + ), + "default": "auto", + }, + }, + "required": ["timestamp_column", "target_column"], + } + + def validate_temporal_data( + self, + dataset: DashAIDataset, + timestamp_col: str, + target_col: str, + exog_cols: Optional[List[str]] = None, + ) -> None: + """Validate that the dataset is suitable for forecasting. + + Parameters + ---------- + dataset : DashAIDataset + Dataset to validate + timestamp_col : str + Name of timestamp column + target_col : str + Name of target column + exog_cols : Optional[List[str]] + Names of exogenous columns + + Raises + ------ + ValueError + If dataset is not suitable for forecasting + """ + if exog_cols is None: + exog_cols = [] + + # Check required columns exist + available_cols = set(dataset.column_names) + + if timestamp_col not in available_cols: + raise ValueError( + f"Timestamp column '{timestamp_col}' not found in dataset. " + f"Available columns: {list(available_cols)}" + ) + + if target_col not in available_cols: + raise ValueError( + f"Target column '{target_col}' not found in dataset. " + f"Available columns: {list(available_cols)}" + ) + + missing_exog = set(exog_cols) - available_cols + if missing_exog: + raise ValueError( + f"Exogenous columns not found: {list(missing_exog)}. " + f"Available columns: {list(available_cols)}" + ) + + # Convert to pandas for validation + dataset_df = dataset.to_pandas() # type: ignore + if not isinstance(dataset_df, pd.DataFrame): + dataset_df = pd.concat(dataset_df, ignore_index=True) + + # Validate timestamp column can be converted to datetime + try: + timestamp_series = pd.to_datetime(dataset_df[timestamp_col]) + except Exception as e: + raise ValueError( + f"Cannot convert timestamp column '{timestamp_col}' to datetime: {e}" + ) from e + + # Check for duplicate timestamps + if timestamp_series.duplicated().any(): + duplicates = timestamp_series[timestamp_series.duplicated()].unique() + raise ValueError( + f"Found duplicate timestamps in '{timestamp_col}': " + f"{duplicates[:5].tolist()}{'...' if len(duplicates) > 5 else ''}" + ) + + # Validate target is numeric + try: + target_series = pd.to_numeric(dataset_df[target_col]) + except Exception as e: + raise ValueError( + f"Target column '{target_col}' must be numeric: {e}" + ) from e + + # Check for too many missing values in target + missing_pct = target_series.isna().mean() + if missing_pct > 0.5: + raise ValueError( + f"Target column '{target_col}' has {missing_pct:.1%} missing values. " + "Maximum allowed is 50%." + ) + + # Minimum data points check + if len(dataset_df) < 5: + raise ValueError( + f"Dataset has only {len(dataset_df)} rows. " + "Minimum 5 data points required for forecasting." + ) + + print( + f"✅ Validation passed: {len(dataset_df)} data points, " + f"timestamp range: {timestamp_series.min()} to {timestamp_series.max()}" + ) + + def detect_frequency(self, timestamp_series: pd.Series) -> str: + """Auto-detect time series frequency. + + Parameters + ---------- + timestamp_series : pd.Series + Datetime series + + Returns + ------- + str + Detected frequency code (D, H, M, etc.) + """ + try: + # Sort timestamps and calculate differences + sorted_ts = timestamp_series.sort_values() + diffs = sorted_ts.diff().dropna() + + # Get most common difference + mode_diff = ( + diffs.mode().iloc[0] if len(diffs.mode()) > 0 else diffs.median() + ) + + # Map to pandas frequency codes + if mode_diff >= pd.Timedelta(days=365): # type: ignore + return "A" # Annual + elif mode_diff >= pd.Timedelta(days=30): # type: ignore + return "M" # Monthly + elif mode_diff >= pd.Timedelta(days=7): # type: ignore + return "W" # Weekly + elif mode_diff >= pd.Timedelta(days=1): # type: ignore + return "D" # Daily + elif mode_diff >= pd.Timedelta(hours=1): # type: ignore + return "H" # Hourly + else: + return "T" # Minute + + except Exception: + # Fallback to daily + return "D" + + def detect_timestamp_column( + self, dataset: DashAIDataset, candidate_columns: List[str] + ) -> Optional[str]: + """Auto-detect which column is the timestamp from a list of candidates. + + Parameters + ---------- + dataset : DashAIDataset + Dataset to analyze + candidate_columns : List[str] + List of column names to check + + Returns + ------- + Optional[str] + Name of detected timestamp column, or None if not found + """ + # Convert to pandas for analysis + dataset_df = dataset.to_pandas() # type: ignore + if not isinstance(dataset_df, pd.DataFrame): + dataset_df = pd.concat(dataset_df, ignore_index=True) + + # Strategy 1: Check by column name + for col in candidate_columns: + col_lower = col.lower() + if any( + keyword in col_lower + for keyword in [ + "date", + "time", + "timestamp", + "ds", + "datetime", + "fecha", + ] + ): + # Verify it can be converted to datetime + try: + pd.to_datetime(dataset_df[col]) + return col + except Exception: + continue + + # Strategy 2: Try to convert each column to datetime + for col in candidate_columns: + try: + pd.to_datetime(dataset_df[col]) + return col + except Exception: + continue + + return None + + def prepare_for_task( + self, + dataset: Optional[Union[DatasetDict, DashAIDataset]] = None, + outputs_columns: Optional[List[str]] = None, + inputs_columns: Optional[List[str]] = None, + **kwargs, + ) -> DashAIDataset: + """Prepare dataset for forecasting task. + + Cambios mínimos: + - Acepta `datasetdict` (alias que usa experiments.py). + - Si no vienen `inputs_columns` ni `timestamp_column`, intenta + detectar el timestamp usando todos los nombres de columnas. + """ + # --- Soporte para alias `datasetdict` usado por experiments.py --- + if dataset is None and "datasetdict" in kwargs: + dataset = kwargs.pop("datasetdict") + if inputs_columns is None and "input_columns" in kwargs: + inputs_columns = kwargs.pop("input_columns") + if outputs_columns is None and "output_columns" in kwargs: + outputs_columns = kwargs.pop("output_columns") + + # Convertir a DashAIDataset si viene como DatasetDict + if isinstance(dataset, DatasetDict): + split_name = "train" if "train" in dataset else list(dataset.keys())[0] + dashai_dataset = to_dashai_dataset(dataset[split_name]) + elif dataset is not None: + dashai_dataset = dataset + else: + raise ValueError("dataset parameter is required for prepare_for_task") + + # Validaciones básicas de parámetros + if not outputs_columns or len(outputs_columns) != 1: + raise ValueError( + "ForecastingTask requires exactly 1 output column (target variable). " + f"Got {len(outputs_columns) if outputs_columns else 0} columns." + ) + target_col = outputs_columns[0] + + # Obtener o detectar columna timestamp + timestamp_col = kwargs.get("timestamp_column") + if not timestamp_col: + # Si no nos dan inputs_columns, intentamos con TODAS las columnas + candidate_inputs = ( + inputs_columns if inputs_columns else list(dashai_dataset.column_names) + ) + timestamp_col = self.detect_timestamp_column( + dashai_dataset, candidate_inputs + ) + if not timestamp_col: + raise ValueError( + "Could not auto-detect timestamp column. " + "Provide `timestamp_column` o incluya una columna con fecha/tiempo " + "('date', 'timestamp', 'ds', 'datetime', etc.)." + ) + print(f"🔍 Auto-detected timestamp column: '{timestamp_col}'") + + # Exógenas: si no vienen inputs, por defecto ninguna + if inputs_columns: + exog_cols = [c for c in inputs_columns if c != timestamp_col] + else: + exog_cols = kwargs.get("exogenous_columns", []) + + frequency = kwargs.get("frequency", "auto") + + # Validar datos + self.validate_temporal_data( + dashai_dataset, timestamp_col, target_col, exog_cols + ) + + # Procesamiento pandas + dataset_df = dashai_dataset.to_pandas() # type: ignore + if not isinstance(dataset_df, pd.DataFrame): + dataset_df = pd.concat(dataset_df, ignore_index=True) + + # NO renombrar columnas - mantener nombres originales + # El modelo (ej: Prophet) hará el renombramiento si lo necesita + + # Orden temporal + # If the timestamp column is numeric (int/float), treat values as + # sequential time-step indices rather than nanosecond epoch offsets. + if pd.api.types.is_integer_dtype( + dataset_df[timestamp_col] + ) or pd.api.types.is_float_dtype(dataset_df[timestamp_col]): + base_date = pd.Timestamp("2000-01-01") + step_vals = dataset_df[timestamp_col] + min_val = step_vals.min() + dataset_df[timestamp_col] = base_date + pd.to_timedelta( + (step_vals - min_val).astype(int), unit="D" + ) + print( + f"ℹ️ Column '{timestamp_col}' contains numeric values — " + f"converted to day offsets starting from {base_date.date()}" + ) + else: + dataset_df[timestamp_col] = pd.to_datetime(dataset_df[timestamp_col]) + dataset_df = dataset_df.sort_values(timestamp_col).reset_index(drop=True) + + # Frecuencia + if frequency == "auto": + frequency = self.detect_frequency(dataset_df[timestamp_col]) + + # Guardar metadatos con nombres ORIGINALES + self._temporal_metadata = { + "timestamp_col": timestamp_col, + "target_col": target_col, + "exog_cols": exog_cols, + "frequency": frequency, + "start_date": dataset_df[timestamp_col].min(), + "end_date": dataset_df[timestamp_col].max(), + "n_periods": len(dataset_df), + } + + print("✅ Prepared forecasting dataset:") + print(f" - Timestamp: {timestamp_col}") + print(f" - Target: {target_col}") + print(f" - Frequency: {frequency}") + print(f" - Periods: {len(dataset_df)}") + if exog_cols: + print(f" - Exogenous vars: {', '.join(exog_cols)}") + + # Volver a DashAIDataset + from datasets import Dataset + + hf_dataset = Dataset.from_pandas(dataset_df) + return to_dashai_dataset(hf_dataset) + + def process_predictions( + self, dataset: DashAIDataset, predictions: Any, target_column: str + ) -> Any: + """Process forecasting predictions. + + For forecasting, predictions can be: + - Simple array of values (point forecasts) + - DataFrame with ds, yhat, yhat_lower, yhat_upper (Prophet style) + - Dictionary with forecasts and confidence intervals + + Parameters + ---------- + dataset : DashAIDataset + Original dataset + predictions : Any + Model predictions + target_column : str + Target column name + + Returns + ------- + Any + Processed predictions + """ + # If predictions is a DataFrame (Prophet style), extract yhat + if hasattr(predictions, "yhat"): + return predictions["yhat"].to_numpy() + + # If it's already an array, return as-is + if hasattr(predictions, "shape"): + return predictions + + # Handle list/tuple + if isinstance(predictions, (list, tuple)): + import numpy as np + + return np.array(predictions) + + return predictions + + def num_labels(self, dataset: DashAIDataset, output_column: str) -> None: + """Return None — forecasting predicts continuous values, not discrete labels.""" + return None + + def get_temporal_metadata(self) -> Optional[Dict[str, Any]]: + """Get temporal metadata from the last prepare_for_task call. + + Returns + ------- + Optional[Dict[str, Any]] + Temporal metadata including frequency, date range, etc. + """ + return self._temporal_metadata diff --git a/DashAI/back/tasks/multi_output_regression_task.py b/DashAI/back/tasks/multi_output_regression_task.py new file mode 100644 index 000000000..005cffda3 --- /dev/null +++ b/DashAI/back/tasks/multi_output_regression_task.py @@ -0,0 +1,87 @@ +from typing import List + +from datasets import DatasetDict, Value + +from DashAI.back.dataloaders.classes.dashai_dataset import ( + DashAIDataset, + to_dashai_dataset, +) +from DashAI.back.tasks.base_task import BaseTask + + +class MultiOutputRegressionTask(BaseTask): + """Task for handling multi-output regression problems. + + Multi-output regression involves predicting multiple continuous outputs + for each input sample. This task sets up the necessary metadata and + processing functions to support training models that generate multiple + outputs per sample. + """ + + DESCRIPTION: str = """ + Multi-output regression extends standard regression by predicting more + than one continuous value per input instance. Each output dimension is + treated as a separate regression target, and models can be trained to + jointly predict all outputs, capturing correlations between them. + """ + + metadata = { + "inputs_types": [Value], + "outputs_types": [Value], + "inputs_cardinality": "n", + "outputs_cardinality": "n", + } + + def prepare_for_task( + self, datasetdict: DatasetDict, outputs_columns: List[str] + ) -> DashAIDataset: + """Change the column types to suit the multi-output regression task. + + Parameters + ---------- + datasetdict : DatasetDict + Dataset to be changed + outputs_columns : List[str] + Output columns for the task + + Returns + ------- + DashAIDataset + Dataset with the new types + """ + return to_dashai_dataset(datasetdict) + + def process_predictions(self, dataset, predictions, output_column): + """ + Process predictions for multi-output regression. + + For multi-output regression, we return the predictions as-is since they + are already in the correct format (n_samples, n_outputs) from sklearn. + + Parameters + ---------- + dataset : DashAIDataset + The original dataset. + predictions : np.ndarray + Array 2D with predictions. Shape: (n_samples, n_outputs) + output_column : str + Not used directly for multi-output regression. + + Returns + ------- + np.ndarray + The predictions array as-is for compatibility with DashAI + prediction pipeline. + """ + # For multi-output, predictions are already in correct format + # Shape should be (n_samples, n_outputs) + print( + f"[MultiOutputRegressionTask] Processing predictions with " + f"shape: {predictions.shape}" + ) + + # Ensure predictions are 2D (which they should be from MultiOutputRegressor) + if predictions.ndim == 1: + predictions = predictions.reshape(-1, 1) + + return predictions diff --git a/DashAI/front/src/api/datasets.ts b/DashAI/front/src/api/datasets.ts index 281ac1627..a0b78c601 100644 --- a/DashAI/front/src/api/datasets.ts +++ b/DashAI/front/src/api/datasets.ts @@ -101,6 +101,27 @@ export const deleteDataset = async (id: string): Promise => { return response; }; +export const getDatasetTemporalInfo = async ( + id: number, + timestampColumn: string, +): Promise<{ + frequency_code: string; + frequency_label: string; + frequency_description: string; + frequency_example: string; + average_interval: string; + start_date: string; + end_date: string; + total_periods: number; + detected_gaps: number; + timestamp_column: string; +}> => { + const response = await api.get(`${datasetEndpoint}/${id}/temporal-info`, { + params: { timestamp_column: timestampColumn }, + }); + return response.data; +}; + export const getDatasetFile = async (path: string, page = 0, pageSize = 5) => { const response = await api.get(`${datasetEndpoint}/file/`, { params: { path, page, page_size: pageSize }, diff --git a/DashAI/front/src/api/job.ts b/DashAI/front/src/api/job.ts index 570a73adc..be60b1c47 100644 --- a/DashAI/front/src/api/job.ts +++ b/DashAI/front/src/api/job.ts @@ -35,9 +35,14 @@ export const getJobStatus = async (jobId: string): Promise => { return response.data; }; -export const enqueueRunnerJob = async (runId: number): Promise => { +export const enqueueRunnerJob = async ( + runId: number, + taskName?: string, +): Promise => { + const jobType = + taskName === "ForecastingTask" ? "ForecastingJob" : "ModelJob"; const data = { - job_type: "ModelJob", + job_type: jobType, kwargs: { run_id: runId }, }; const formData = new FormData(); @@ -123,6 +128,7 @@ export const enqueueExplorerJob = async ( export const enqueuePredictionJob = async ( prediction_id: number, manual_input_data?: object[], + forecast_periods?: number, ): Promise => { const formData = new FormData(); @@ -138,9 +144,16 @@ export const enqueuePredictionJob = async ( return cleanObj; }); + const kwargs: any = { prediction_id, manual_input_data: simpleManualData }; + + // Add forecast_periods only if provided (for ForecastingTask) + if (forecast_periods !== undefined && forecast_periods > 0) { + kwargs.forecast_periods = forecast_periods; + } + const data = { job_type: "PredictJob", - kwargs: { prediction_id, manual_input_data: simpleManualData }, + kwargs, }; formData.append("job_type", data.job_type); diff --git a/DashAI/front/src/components/explainers/ConfigureExplainerStep.jsx b/DashAI/front/src/components/explainers/ConfigureExplainerStep.jsx index f3eca46b2..854bb4271 100644 --- a/DashAI/front/src/components/explainers/ConfigureExplainerStep.jsx +++ b/DashAI/front/src/components/explainers/ConfigureExplainerStep.jsx @@ -11,6 +11,7 @@ import PropTypes from "prop-types"; import FormSchema from "../shared/FormSchema"; import FormSchemaLayout from "../shared/FormSchemaLayout"; import useSchema from "../../hooks/useSchema"; +import ForecastingExplainerInfo from "./ForecastingExplainerInfo"; import { useTranslation } from "react-i18next"; function ConfigureExplainerStep({ @@ -19,6 +20,10 @@ function ConfigureExplainerStep({ setNextEnabled, formSubmitRef, scope, + temporalInfo, + temporalInfoLoading, + modelName, + isForecastingTask, }) { const { defaultValues } = useSchema({ modelName: newExpl.explainer_name }); const [error, setError] = useState(false); @@ -88,6 +93,20 @@ function ConfigureExplainerStep({ {t("explainers:label.configureExplainer")} + + {/* Forecasting explainer temporal info */} + {isForecastingTask && ( + + + + )} + {/* Configure dataloader parameters */} + + + + ); + } + + if (!temporalInfo) { + return null; + } + + // Get explainer-specific description + const getExplainerDescription = () => { + switch (explainerName) { + case "ForecastDecomposition": + return { + title: "Forecast Decomposition", + icon: , + color: "primary", + description: `This explainer will decompose ${horizon || 30} future ${temporalInfo.frequency_label?.toLowerCase() || "periods"} into interpretable components: trend, seasonality, and external factors.`, + details: [ + "Trend: Shows the long-term direction of your forecast", + "Seasonality: Reveals repeating patterns (daily, weekly, yearly)", + "Residuals: Captures unexplained variations", + ], + }; + case "ForecastFeatureImportance": + return { + title: "Feature Importance", + icon: , + color: "secondary", + description: + "This explainer measures how each external variable (exogenous feature) contributes to forecast accuracy.", + details: [ + "Permutation-based importance scoring", + "Shows which features have the most impact", + "Helps identify which external data to prioritize", + ], + }; + case "ForecastUncertainty": + return { + title: "Uncertainty Analysis", + icon: , + color: "warning", + description: `This explainer will analyze prediction confidence for ${horizon || 30} future ${temporalInfo.frequency_label?.toLowerCase() || "periods"}, showing how uncertainty grows over time.`, + details: [ + "Confidence intervals for each forecast step", + "Best/worst case scenario bounds", + "Critical for risk management and planning", + ], + }; + default: + return { + title: "Forecasting Explainer", + icon: , + color: "info", + description: "This explainer will analyze your forecasting model.", + details: [], + }; + } + }; + + const explainerInfo = getExplainerDescription(); + + return ( + + {/* Model & Training Data Info */} + + + + Model Training Data Properties + {modelName && ( + + )} + + + + + + + + + Detected Frequency + + + + + + + + + + + + + + Training Period + + + {new Date(temporalInfo.start_date).toLocaleDateString()} →{" "} + {new Date(temporalInfo.end_date).toLocaleDateString()} + + + + + + + + + Training Periods + + + {temporalInfo.total_periods}{" "} + {temporalInfo.frequency_label?.toLowerCase()} + + + + + + + + Average Interval + + + {temporalInfo.average_interval} + + + + + + + {/* Explainer-specific Info */} + {explainerName && ( + + + {explainerInfo.icon} + {explainerInfo.title} Analysis + + + + {explainerInfo.description} + + + {explainerInfo.details.length > 0 && ( + <> + + + What you'll learn: + + + {explainerInfo.details.map((detail, index) => ( + + {detail} + + ))} + + + )} + + {horizon && ( + }> + + Forecast Window: The explainer will analyze{" "} + {horizon}{" "} + {temporalInfo.frequency_label?.toLowerCase() || "periods"} into + the future, from{" "} + + {new Date(temporalInfo.end_date).toLocaleDateString()} + {" "} + onwards. + + + )} + + )} + + ); +} + +ForecastingExplainerInfo.propTypes = { + temporalInfo: PropTypes.shape({ + frequency_label: PropTypes.string, + frequency_code: PropTypes.string, + start_date: PropTypes.string, + end_date: PropTypes.string, + total_periods: PropTypes.number, + average_interval: PropTypes.string, + frequency_example: PropTypes.string, + timestamp_column: PropTypes.string, + }), + loading: PropTypes.bool, + explainerName: PropTypes.string, + modelName: PropTypes.string, + horizon: PropTypes.number, +}; + +ForecastingExplainerInfo.defaultProps = { + temporalInfo: null, + loading: false, + explainerName: null, + modelName: null, + horizon: null, +}; + +export default ForecastingExplainerInfo; diff --git a/DashAI/front/src/components/explainers/NewGlobalExplainerModal.jsx b/DashAI/front/src/components/explainers/NewGlobalExplainerModal.jsx index 48dbd0211..fdc9946f2 100644 --- a/DashAI/front/src/components/explainers/NewGlobalExplainerModal.jsx +++ b/DashAI/front/src/components/explainers/NewGlobalExplainerModal.jsx @@ -27,6 +27,9 @@ import { getExplainers, } from "../../api/explainer"; import { enqueueExplainerJob as enqueueExplainerJobRequest } from "../../api/job"; +import { getRunById } from "../../api/run"; +import { getModelSessionById } from "../../api/modelSession"; +import { getDatasetTemporalInfo } from "../../api/datasets"; import ConfigureExplainerStep from "./ConfigureExplainerStep"; import SetNameAndExplainerStep from "./SetNameAndExplainerStep"; @@ -92,11 +95,16 @@ export default function NewGlobalExplainerModal({ const [nextEnabled, setNextEnabled] = useState(false); const [newGlobalExpl, setNewGlobalExpl] = useState(defaultNewGlobalExpl); const [existingGlobalExplainers, setExistingGlobalExplainers] = useState([]); + const [temporalInfo, setTemporalInfo] = useState(null); + const [temporalInfoLoading, setTemporalInfoLoading] = useState(false); + const [modelName, setModelName] = useState(null); const [existingGlobalExplainersLoaded, setExistingGlobalExplainersLoaded] = useState(false); const [isLoading, setIsLoading] = useState(false); + const isForecastingTask = taskName === "ForecastingTask"; + const { updateFlag: updateExplainers } = useUpdateFlag({ flag: flags.EXPLAINERS, }); @@ -113,12 +121,50 @@ export default function NewGlobalExplainerModal({ } }; + // Fetch temporal info for forecasting tasks + const fetchTemporalInfo = async () => { + if (!isForecastingTask || !runId) return; + + setTemporalInfoLoading(true); + try { + const run = await getRunById(runId.toString()); + setModelName(run.name); + + const experiment = await getModelSessionById( + run.model_session_id.toString(), + ); + + // For forecasting, the first input column is the timestamp column + const inputCols = experiment.input_columns || []; + if (inputCols.length > 0 && experiment.dataset_id) { + const timestampColumn = inputCols[0]; + console.log( + "[NewGlobalExplainerModal] Fetching temporal info with timestamp column:", + timestampColumn, + ); + + const info = await getDatasetTemporalInfo( + experiment.dataset_id, + timestampColumn, + ); + setTemporalInfo(info); + } + } catch (error) { + console.error("Error fetching temporal info for explainer:", error); + } finally { + setTemporalInfoLoading(false); + } + }; + useEffect(() => { if (open) { setExistingGlobalExplainersLoaded(false); loadExistingExplainers(); + if (isForecastingTask) { + fetchTemporalInfo(); + } } - }, [open]); + }, [open, isForecastingTask]); useEffect(() => { if (!open || !existingGlobalExplainersLoaded || newGlobalExpl.name.trim()) { @@ -213,6 +259,8 @@ export default function NewGlobalExplainerModal({ setOpen(false); setNewGlobalExpl(defaultNewGlobalExpl); setNextEnabled(false); + setTemporalInfo(null); + setModelName(null); }; const handleStepButton = (stepIndex) => () => { @@ -342,6 +390,10 @@ export default function NewGlobalExplainerModal({ setNextEnabled={setNextEnabled} scope={"global"} formSubmitRef={formSubmitRef} + temporalInfo={temporalInfo} + temporalInfoLoading={temporalInfoLoading} + modelName={modelName} + isForecastingTask={isForecastingTask} /> )} diff --git a/DashAI/front/src/components/explainers/NewLocalExplainerModal.jsx b/DashAI/front/src/components/explainers/NewLocalExplainerModal.jsx index a85e559fe..ee1aecb50 100644 --- a/DashAI/front/src/components/explainers/NewLocalExplainerModal.jsx +++ b/DashAI/front/src/components/explainers/NewLocalExplainerModal.jsx @@ -24,6 +24,9 @@ import useMediaQuery from "@mui/material/useMediaQuery"; import { createLocalExplainer as createLocalExplainerRequest } from "../../api/explainer"; import { enqueueExplainerJob as enqueueExplainerJobRequest } from "../../api/job"; import { getExplainers } from "../../api/explainer"; +import { getRunById } from "../../api/run"; +import { getModelSessionById } from "../../api/modelSession"; +import { getDatasetTemporalInfo } from "../../api/datasets"; import { startJobPolling } from "../../utils/jobPoller"; import ConfigureExplainerStep from "./ConfigureExplainerStep"; @@ -100,6 +103,11 @@ export default function NewLocalExplainerModal({ const [existingLocalExplainersLoaded, setExistingLocalExplainersLoaded] = useState(false); const [isLoading, setIsLoading] = useState(false); + const [temporalInfo, setTemporalInfo] = useState(null); + const [temporalInfoLoading, setTemporalInfoLoading] = useState(false); + const [modelName, setModelName] = useState(null); + + const isForecastingTask = taskName === "ForecastingTask"; const { updateFlag: updateExplainers } = useUpdateFlag({ flag: flags.EXPLAINERS, @@ -117,12 +125,50 @@ export default function NewLocalExplainerModal({ } }; + // Fetch temporal info for forecasting tasks + const fetchTemporalInfo = async () => { + if (!isForecastingTask || !runId) return; + + setTemporalInfoLoading(true); + try { + const run = await getRunById(runId.toString()); + setModelName(run.name); + + const experiment = await getModelSessionById( + run.model_session_id.toString(), + ); + + // For forecasting, the first input column is the timestamp column + const inputCols = experiment.input_columns || []; + if (inputCols.length > 0 && experiment.dataset_id) { + const timestampColumn = inputCols[0]; + console.log( + "[NewLocalExplainerModal] Fetching temporal info with timestamp column:", + timestampColumn, + ); + + const info = await getDatasetTemporalInfo( + experiment.dataset_id, + timestampColumn, + ); + setTemporalInfo(info); + } + } catch (error) { + console.error("Error fetching temporal info for explainer:", error); + } finally { + setTemporalInfoLoading(false); + } + }; + useEffect(() => { if (open) { setExistingLocalExplainersLoaded(false); loadExistingExplainers(); + if (isForecastingTask) { + fetchTemporalInfo(); + } } - }, [open]); + }, [open, isForecastingTask]); useEffect(() => { if (!open || !existingLocalExplainersLoaded || newLocalExpl.name.trim()) { @@ -219,6 +265,8 @@ export default function NewLocalExplainerModal({ setOpen(false); setNewLocalExpl(defaultNewLocalExpl); setNextEnabled(false); + setTemporalInfo(null); + setModelName(null); }; const handleStepButton = (stepIndex) => () => { @@ -355,6 +403,10 @@ export default function NewLocalExplainerModal({ setNextEnabled={setNextEnabled} formSubmitRef={formSubmitRef} scope={"Local"} + temporalInfo={temporalInfo} + temporalInfoLoading={temporalInfoLoading} + modelName={modelName} + isForecastingTask={isForecastingTask} /> )} diff --git a/DashAI/front/src/components/models/LiveMetricsChart.jsx b/DashAI/front/src/components/models/LiveMetricsChart.jsx index c779f0a60..5861aa2c6 100644 --- a/DashAI/front/src/components/models/LiveMetricsChart.jsx +++ b/DashAI/front/src/components/models/LiveMetricsChart.jsx @@ -22,6 +22,10 @@ import { import { useEffect, useMemo, useRef, useState } from "react"; import { useTranslation } from "react-i18next"; import { getModelSessionById } from "../../api/modelSession"; +import { + formatScalarMetricsForChart, + isFiniteMetricValue, +} from "../../utils/metricUtils"; export function LiveMetricsChart({ run }) { const { t } = useTranslation("models"); @@ -47,17 +51,9 @@ export function LiveMetricsChart({ run }) { setData((prev) => { const next = structuredClone(prev); - const formattedTestMetrics = {}; - for (const metricName in run.test_metrics) { - const value = run.test_metrics[metricName]; - if (Array.isArray(value)) { - formattedTestMetrics[metricName] = value; - } else { - formattedTestMetrics[metricName] = [ - { step: 1, value: value, timestamp: new Date().toISOString() }, - ]; - } - } + const formattedTestMetrics = formatScalarMetricsForChart( + run.test_metrics, + ); next.TEST = { TRIAL: formattedTestMetrics, @@ -140,17 +136,9 @@ export function LiveMetricsChart({ run }) { const next = structuredClone(prev); - const formattedTestMetrics = {}; - for (const metricName in run.test_metrics) { - const value = run.test_metrics[metricName]; - if (Array.isArray(value)) { - formattedTestMetrics[metricName] = value; - } else { - formattedTestMetrics[metricName] = [ - { step: 1, value: value, timestamp: new Date().toISOString() }, - ]; - } - } + const formattedTestMetrics = formatScalarMetricsForChart( + run.test_metrics, + ); next.TEST = { TRIAL: formattedTestMetrics, diff --git a/DashAI/front/src/components/models/modelSession/ConfigureModelsStep.jsx b/DashAI/front/src/components/models/modelSession/ConfigureModelsStep.jsx index 09c1ba4f8..34e0c49db 100644 --- a/DashAI/front/src/components/models/modelSession/ConfigureModelsStep.jsx +++ b/DashAI/front/src/components/models/modelSession/ConfigureModelsStep.jsx @@ -1,11 +1,15 @@ import { AddCircleOutline as AddIcon } from "@mui/icons-material"; +import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined"; import { Button, Grid, MenuItem, TextField, Typography, + Alert, Box, + Chip, + Collapse, } from "@mui/material"; import { useSnackbar } from "notistack"; import PropTypes from "prop-types"; @@ -18,6 +22,48 @@ import { generateSequentialName } from "../../utils/nameGenerator"; import { useTourContext } from "../tour/TourProvider"; import { useTranslation } from "react-i18next"; +// Model hints for forecasting models - helps users understand model requirements +const FORECASTING_MODEL_HINTS = { + ProphetModel: { + minDataPoints: 30, + description: "Facebook's Prophet model for business time series", + strengths: [ + "Handles missing data", + "Automatic seasonality", + "Good for daily/weekly patterns", + ], + limitations: [ + "Needs consistent frequency", + "Better with >2 years of data for yearly seasonality", + ], + smallDatasetNote: + "Works with small datasets but yearly seasonality detection may be limited.", + }, + StatsmodelsSARIMAXModel: { + minDataPoints: 20, + description: "Statistical ARIMA/SARIMAX model", + strengths: [ + "Classic statistical approach", + "Interpretable parameters", + "Good for stationary data", + ], + limitations: ["Requires parameter tuning", "Sensitive to non-stationarity"], + smallDatasetNote: + "With small datasets, seasonality will be auto-disabled and simpler ARIMA will be used.", + }, + SklearnMultiStepForecaster: { + minDataPoints: 10, + description: "Machine learning-based forecaster using sklearn regressors", + strengths: ["Flexible", "Works with small datasets", "Fast training"], + limitations: [ + "May overfit with very few samples", + "No built-in seasonality", + ], + smallDatasetNote: + "Recommended for small datasets. Window size auto-adjusts based on available data.", + }, +}; + /** * Step of the experiment modal: add models to the experiment and configure its parameters * @param {object} newExp object that contains the Experiment Modal state @@ -159,6 +205,24 @@ function ConfigureModelsStep({ newExp, setNewExp, setNextEnabled }) { } }, [selectedModel, defaultName]); + // Check if this is a forecasting task + const isForecastingTask = newExp.task_name === "ForecastingTask"; + + // Get dataset size from splits info (approximate) + const datasetSize = useMemo(() => { + if (newExp.splits && newExp.splits.train) { + // If we have percentage splits, we can estimate from the dataset + // This is a rough estimate - the actual size comes from the dataset + return newExp.dataset?.total_rows || null; + } + return null; + }, [newExp.splits, newExp.dataset]); + + // Get model hint if available + const selectedModelHint = selectedModel + ? FORECASTING_MODEL_HINTS[selectedModel] + : null; + return ( + + {/* Model Info Panel for Forecasting */} + + + } sx={{ mt: 1 }}> + + {selectedModelHint?.description} + + + + + + Strengths: + + + {selectedModelHint?.strengths.map((s, i) => ( + + ))} + + + + + + + + Limitations: + + + {selectedModelHint?.limitations.map((l, i) => ( + + ))} + + + + + {selectedModelHint?.smallDatasetNote && ( + + 💡 Small dataset tip:{" "} + {selectedModelHint.smallDatasetNote} + + )} + + + + + {/* Models table */} diff --git a/DashAI/front/src/components/models/modelSession/PrepareDatasetStep.jsx b/DashAI/front/src/components/models/modelSession/PrepareDatasetStep.jsx index f416dfc52..1a6334d6f 100644 --- a/DashAI/front/src/components/models/modelSession/PrepareDatasetStep.jsx +++ b/DashAI/front/src/components/models/modelSession/PrepareDatasetStep.jsx @@ -11,8 +11,10 @@ import { } from "@mui/material"; import DivideDatasetColumns from "./DivideDatasetColumns"; import SplitDatasetRows from "./SplitDatasetRows"; +import SplitDatasetTemporal from "./SplitDatasetTemporal"; import { getDatasetInfo as getDatasetInfoRequest, + getDatasetTemporalInfo, getDatasetTypes as getDatasetTypesRequest, } from "../../../api/datasets"; import { getComponents as getComponentsRequest } from "../../../api/component"; @@ -54,9 +56,13 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { const [columnsReady, setColumnsReady] = useState(false); const [columnsAreValid, setColumnsAreValid] = useState(false); + const [columnsValidationError, setColumnsValidationError] = useState(""); const [shuffle, setShuffle] = useState(true); const [stratify, setStratify] = useState(false); const [seed, setSeed] = useState(42); + const [gap, setGap] = useState(0); + const [temporalInfo, setTemporalInfo] = useState(null); + const [temporalInfoLoading, setTemporalInfoLoading] = useState(false); const defaultParitionsIndex = { train: [], @@ -81,6 +87,7 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { RANDOM: "random", MANUAL: "manual", PREDEFINED: "predefined", + TEMPORAL: "temporal", }; const [splitType, setSplitType] = useState(""); @@ -190,11 +197,13 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { datasetInfo.column_names.length === 0 ) { setColumnsAreValid(false); + setColumnsValidationError(""); return; } if (inputColumnNames.length === 0 || outputColumnNames.length === 0) { setColumnsAreValid(false); + setColumnsValidationError(""); return; } @@ -205,7 +214,9 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { inputColumnNames, outputColumnNames, ); - setColumnsAreValid(validation.dataset_status === "valid"); + const isValid = validation.dataset_status === "valid"; + setColumnsAreValid(isValid); + setColumnsValidationError(isValid ? "" : validation.error || ""); } catch (error) { enqueueSnackbar(t("experiments:error.errorFetchingColumnsValidation")); if (error.response) { @@ -216,6 +227,7 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { console.error("Unknown Error", error.message); } setColumnsAreValid(false); + setColumnsValidationError(""); } }; @@ -228,29 +240,39 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { return; } + const effectiveSplitType = isForecastingTask + ? SPLIT_TYPES.TEMPORAL + : splitType; + const updatedExpData = { ...newExp, input_columns: inputColumnNames, output_columns: outputColumnNames, }; - if (splitType === SPLIT_TYPES.MANUAL) { + if (effectiveSplitType === SPLIT_TYPES.MANUAL) { updatedExpData.splits = { ...rowsPartitionsIndex, - splitType: splitType, + splitType: effectiveSplitType, }; - } else if (splitType === SPLIT_TYPES.RANDOM) { + } else if (effectiveSplitType === SPLIT_TYPES.RANDOM) { updatedExpData.splits = { ...rowsPartitionsPercentage, shuffle: shuffle, stratify: stratify, seed: seed === "" || seed == null ? 42 : Number(seed), - splitType: splitType, + splitType: effectiveSplitType, }; - } else if (splitType === SPLIT_TYPES.PREDEFINED) { + } else if (effectiveSplitType === SPLIT_TYPES.PREDEFINED) { updatedExpData.splits = { ...datasetPartitionsIndex, - splitType: splitType, + splitType: effectiveSplitType, + }; + } else if (effectiveSplitType === SPLIT_TYPES.TEMPORAL) { + updatedExpData.splits = { + ...rowsPartitionsPercentage, + gap: gap, + splitType: effectiveSplitType, }; } setNewExp(updatedExpData); @@ -267,7 +289,6 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { useEffect(() => { if ( columnsReady && - splitsReady && datasetInfo && datasetInfo.column_names && datasetInfo.column_names.length > 0 @@ -275,14 +296,9 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { validateColumns(); } else { setColumnsAreValid(false); + setColumnsValidationError(""); } - }, [ - columnsReady, - splitsReady, - inputColumnNames, - outputColumnNames, - datasetInfo, - ]); + }, [columnsReady, inputColumnNames, outputColumnNames, datasetInfo]); useEffect(() => { if (columnsAreValid && splitsReady && columnsReady) { @@ -299,6 +315,7 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { shuffle, stratify, seed, + gap, inputColumnNames, outputColumnNames, ]); @@ -308,6 +325,49 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { getTaskRequirements(); }, []); + // Check if current task is ForecastingTask + const isForecastingTask = taskRequirements.name === "ForecastingTask"; + + // Set split type to TEMPORAL for forecasting tasks + useEffect(() => { + if (isForecastingTask && splitType !== SPLIT_TYPES.TEMPORAL) { + setSplitType(SPLIT_TYPES.TEMPORAL); + } + }, [isForecastingTask, splitType]); + + // Fetch temporal info when input columns change for forecasting tasks + useEffect(() => { + const fetchTemporalInfo = async () => { + if ( + !isForecastingTask || + inputColumnNames.length === 0 || + !newExp.dataset?.id + ) { + setTemporalInfo(null); + return; + } + + // Use the first input column as the timestamp column + const timestampColumn = inputColumnNames[0]; + + setTemporalInfoLoading(true); + try { + const info = await getDatasetTemporalInfo( + newExp.dataset.id, + timestampColumn, + ); + setTemporalInfo(info); + } catch (error) { + console.error("Error fetching temporal info:", error); + setTemporalInfo(null); + } finally { + setTemporalInfoLoading(false); + } + }; + + fetchTemporalInfo(); + }, [isForecastingTask, inputColumnNames, newExp.dataset?.id]); + const renderTypesAsChips = (typesList) => { if (!typesList || typesList.length === 0) { return {t("common:any")}; @@ -345,6 +405,12 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { ); }; + const alertTitleKey = columnsAreValid + ? "experiments:label.columnsValidRequirements" + : columnsValidationError + ? "experiments:label.datasetInvalidForTask" + : "experiments:label.columnsInvalidRequirements"; + return ( {taskRequirements - ? t( - columnsAreValid - ? "experiments:label.columnsValidRequirements" - : "experiments:label.columnsInvalidRequirements", - { taskName: taskRequirements.display_name }, - ) + ? t(alertTitleKey, { taskName: taskRequirements.display_name }) : null} @@ -414,6 +475,17 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { + {!columnsAreValid && columnsValidationError ? ( + + {t("experiments:label.validationDetails")}:{" "} + {columnsValidationError} + + ) : null} + {isForecastingTask ? ( + + {t("experiments:label.forecastingValidationHint")} + + ) : null} {!infoLoading && datasetInfo.nan ? ( Object.values(datasetInfo.nan).some((v) => v > 0) ? ( @@ -449,23 +521,36 @@ function PrepareDatasetStep({ newExp, setNewExp, setNextEnabled }) { } /> - + {isForecastingTask ? ( + + ) : ( + + )} ) : ( diff --git a/DashAI/front/src/components/models/modelSession/RunnerDialog.jsx b/DashAI/front/src/components/models/modelSession/RunnerDialog.jsx index 5290e1499..e3985548f 100644 --- a/DashAI/front/src/components/models/modelSession/RunnerDialog.jsx +++ b/DashAI/front/src/components/models/modelSession/RunnerDialog.jsx @@ -135,7 +135,10 @@ function RunnerDialog({ const enqueueRunnerJob = async (runId) => { try { - const response = await enqueueRunnerJobRequest(runId); + const response = await enqueueRunnerJobRequest( + runId, + experiment.task_name, + ); if (response && response.id) { setTrackedJobIds((prev) => new Set(prev).add(response.id)); @@ -238,7 +241,10 @@ function RunnerDialog({ ), ); - const response = await enqueueRunnerJobRequest(run.id); + const response = await enqueueRunnerJobRequest( + run.id, + experiment.task_name, + ); if (response && response.id) { enqueueSnackbar(`Run ${run.name} started successfully`, { @@ -359,10 +365,6 @@ function RunnerDialog({ setRows((prevRows) => prevRows.filter((row) => row.id !== params.row.id), ); - if (rows.length === 1) { - setOpen(false); - deleteExperiment(); - } }} />, ], diff --git a/DashAI/front/src/components/models/modelSession/SelectDatasetStep.jsx b/DashAI/front/src/components/models/modelSession/SelectDatasetStep.jsx index fe3d158c2..1c0b79d06 100644 --- a/DashAI/front/src/components/models/modelSession/SelectDatasetStep.jsx +++ b/DashAI/front/src/components/models/modelSession/SelectDatasetStep.jsx @@ -16,7 +16,7 @@ import { Link as RouterLink } from "react-router-dom"; import { useTourContext } from "../tour/TourProvider"; import { getDatasets as getDatasetsRequest } from "../../api/datasets"; import { formatDate } from "../../utils"; -import { useTranslation } from "react-i18next"; +import { Trans, useTranslation } from "react-i18next"; function SelectDatasetStep({ newExp, setNewExp, setNextEnabled }) { const { enqueueSnackbar } = useSnackbar(); diff --git a/DashAI/front/src/components/models/modelSession/SplitDatasetTemporal.jsx b/DashAI/front/src/components/models/modelSession/SplitDatasetTemporal.jsx new file mode 100644 index 000000000..0de705aef --- /dev/null +++ b/DashAI/front/src/components/models/modelSession/SplitDatasetTemporal.jsx @@ -0,0 +1,539 @@ +import React, { useEffect, useState } from "react"; +import PropTypes from "prop-types"; +import { + Grid, + TextField, + Typography, + FormHelperText, + Slider, + Box, + Alert, + AlertTitle, + Chip, + CircularProgress, + Paper, + Collapse, +} from "@mui/material"; +import CalendarTodayIcon from "@mui/icons-material/CalendarToday"; +import AccessTimeIcon from "@mui/icons-material/AccessTime"; +import TrendingUpIcon from "@mui/icons-material/TrendingUp"; +import WarningAmberIcon from "@mui/icons-material/WarningAmber"; + +/** + * Component for temporal splitting of time series data for forecasting tasks. + * Unlike random splitting, this maintains temporal order to prevent data leakage. + */ +function SplitDatasetTemporal({ + datasetInfo, + rowsPartitionsPercentage, + setRowsPartitionsPercentage, + setSplitsReady, + gap, + setGap, + temporalInfo, + temporalInfoLoading, +}) { + const totalRows = datasetInfo.total_rows; + + const [splitError, setSplitError] = useState(false); + const [splitErrorText, setSplitErrorText] = useState(""); + + // Minimum sizes for temporal splits - scaled based on dataset size + // For small datasets (testing/demo), use proportional minimums + const isSmallDataset = totalRows < 100; + const MIN_TRAIN_SIZE = isSmallDataset + ? Math.max(3, Math.floor(totalRows * 0.5)) + : 50; + const MIN_VAL_SIZE = isSmallDataset + ? Math.max(1, Math.floor(totalRows * 0.15)) + : 10; + const MIN_TEST_SIZE = isSmallDataset + ? Math.max(1, Math.floor(totalRows * 0.15)) + : 10; + + const checkTemporalSplit = (train, validation, test, gapValue) => { + // Convert percentages to actual row counts + const trainRows = Math.floor(totalRows * train); + const valRows = Math.floor(totalRows * validation); + const testRows = Math.floor(totalRows * test); + + // Total rows needed including gaps + const totalNeeded = trainRows + valRows + testRows + 2 * gapValue; + + if (totalNeeded > totalRows) { + setSplitErrorText( + `Not enough data. Need ${totalNeeded} rows but have ${totalRows}. Try reducing gap or split sizes.`, + ); + return false; + } + + if (trainRows < MIN_TRAIN_SIZE) { + setSplitErrorText( + `Training set too small: ${trainRows} < ${MIN_TRAIN_SIZE}. Increase train proportion.`, + ); + return false; + } + + if (valRows < MIN_VAL_SIZE) { + setSplitErrorText( + `Validation set too small: ${valRows} < ${MIN_VAL_SIZE}. Increase validation proportion.`, + ); + return false; + } + + if (testRows < MIN_TEST_SIZE) { + setSplitErrorText( + `Test set too small: ${testRows} < ${MIN_TEST_SIZE}. Increase test proportion.`, + ); + return false; + } + + // Use tolerance for floating point comparison (0.7 + 0.2 + 0.1 !== 1 in JS) + const sum = train + validation + test; + if (Math.abs(sum - 1) > 0.0001) { + setSplitErrorText( + "Splits should be numbers between 0 and 1 and should add 1 in total", + ); + return false; + } + + return true; + }; + + const handleRowsChange = (event) => { + const value = parseFloat(event.target.value); + const id = event.target.id; + + let newSplit = { ...rowsPartitionsPercentage }; + switch (id) { + case "train": + newSplit = { ...newSplit, train: value }; + break; + case "validation": + newSplit = { ...newSplit, validation: value }; + break; + case "test": + newSplit = { ...newSplit, test: value }; + break; + } + + setRowsPartitionsPercentage(newSplit); + + if ( + !checkTemporalSplit( + newSplit.train, + newSplit.validation, + newSplit.test, + gap, + ) + ) { + setSplitError(true); + } else { + setSplitError(false); + setSplitErrorText(""); + } + }; + + const handleGapChange = (event, newValue) => { + setGap(newValue); + + if ( + !checkTemporalSplit( + rowsPartitionsPercentage.train, + rowsPartitionsPercentage.validation, + rowsPartitionsPercentage.test, + newValue, + ) + ) { + setSplitError(true); + } else { + setSplitError(false); + setSplitErrorText(""); + } + }; + + useEffect(() => { + // Validate splits on mount and when data changes + // Ensure we have dataset info before validating + if (!totalRows || totalRows <= 0) { + setSplitsReady(false); + return; + } + + const isValid = + !splitError && + rowsPartitionsPercentage.train > 0 && + rowsPartitionsPercentage.validation > 0 && + rowsPartitionsPercentage.test > 0 && + checkTemporalSplit( + rowsPartitionsPercentage.train, + rowsPartitionsPercentage.validation, + rowsPartitionsPercentage.test, + gap, + ); + + setSplitsReady(isValid); + }, [rowsPartitionsPercentage, splitError, gap, totalRows]); + + // Calculate actual row numbers for display + const trainRows = Math.floor(totalRows * rowsPartitionsPercentage.train); + const valRows = Math.floor(totalRows * rowsPartitionsPercentage.validation); + const testRows = Math.floor(totalRows * rowsPartitionsPercentage.test); + + // Format date for display + const formatDate = (isoString) => { + if (!isoString) return ""; + const date = new Date(isoString); + return date.toLocaleDateString(undefined, { + year: "numeric", + month: "short", + day: "numeric", + }); + }; + + return ( + + + {/* Temporal Information Panel */} + + + + + Detected Time Series Properties + + + {temporalInfoLoading ? ( + + + + Analyzing temporal patterns... + + + ) : temporalInfo ? ( + + + + + + + Frequency + + + + + + {temporalInfo.frequency_description} + + + + + + + + + + + Date Range + + + {formatDate(temporalInfo.start_date)} →{" "} + {formatDate(temporalInfo.end_date)} + + + + + + + + + Total Periods + + + {temporalInfo.total_periods} data points + + + + + + + + Average Interval + + + {temporalInfo.average_interval} + + {temporalInfo.detected_gaps > 0 && ( + + ⚠️ {temporalInfo.detected_gaps} gaps detected + + )} + + + + + + + Prediction interpretation: When you + forecast {gap > 0 ? `with a ${gap} period gap` : ""}, each + prediction step represents{" "} + + 1{" "} + {temporalInfo.frequency_label + .toLowerCase() + .slice(0, -2)} + + . {temporalInfo.frequency_example} + + + + + ) : ( + + Select an input column (timestamp) to analyze temporal + properties. + + )} + + + + + + Temporal Splitting for Time Series + + For forecasting tasks, data is split chronologically to prevent + data leakage: + +
    +
  • Training data comes first (oldest)
  • +
  • Validation data follows training data
  • +
  • Test data comes last (most recent)
  • +
  • + Optional gap between splits to simulate real-world scenarios +
  • +
+
+
+ + + + Select proportions for temporal splits + + +
+ + + + + + + + + + + + + + + + Gap between splits{" "} + {temporalInfo + ? `(${temporalInfo.frequency_label.toLowerCase()} to skip)` + : "(periods to skip)"} + + + + Gap helps simulate real-world forecasting by adding delay between + training and prediction. + {temporalInfo && ( + <> + {" "} + Current gap: {gap}{" "} + {temporalInfo.frequency_label.toLowerCase()}. + + )} + + + + + + {splitError && ( + + {splitErrorText} + + )} + + + + Timeline preview: Train ({trainRows}{" "} + {temporalInfo ? temporalInfo.frequency_label.toLowerCase() : "rows"}) + {gap > 0 && + ` → Gap (${gap} ${temporalInfo ? temporalInfo.frequency_label.toLowerCase() : "rows"})`}{" "} + → Validation ({valRows}{" "} + {temporalInfo ? temporalInfo.frequency_label.toLowerCase() : "rows"}) + {gap > 0 && + ` → Gap (${gap} ${temporalInfo ? temporalInfo.frequency_label.toLowerCase() : "rows"})`}{" "} + → Test ({testRows}{" "} + {temporalInfo ? temporalInfo.frequency_label.toLowerCase() : "rows"}) + + + + {/* Small Dataset Warnings for Forecasting */} + + }> + Small Dataset Detected ({totalRows} rows) + + Your dataset is relatively small for time series forecasting. This + may affect model performance: + + + {trainRows < 10 && ( +
  • + + Training set ({trainRows} rows): Some models + may auto-adjust their parameters (e.g., reduced lag window, + disabled seasonality) to work with limited data. + +
  • + )} + {valRows < 5 && ( +
  • + + Validation set ({valRows} rows): Very small + validation sets may result in unreliable or NaN metrics. + Consider increasing validation proportion. + +
  • + )} + {testRows < 5 && ( +
  • + + Test set ({testRows} rows): Very small test + sets may not provide meaningful evaluation metrics. + +
  • + )} + {totalRows < 20 && ( +
  • + + Seasonal models (SARIMAX): Seasonality will + be automatically disabled as there isn't enough data to + detect seasonal patterns. + +
  • + )} +
    + + 💡 Tip: For best results, time series models + typically need at least 50-100 data points. With small datasets, + simpler models like SklearnMultiStepForecaster{" "} + often perform better than complex ones. + +
    +
    +
    + ); +} + +SplitDatasetTemporal.propTypes = { + datasetInfo: PropTypes.shape({ + total_rows: PropTypes.number, + }).isRequired, + rowsPartitionsPercentage: PropTypes.shape({ + train: PropTypes.number, + validation: PropTypes.number, + test: PropTypes.number, + }).isRequired, + setRowsPartitionsPercentage: PropTypes.func.isRequired, + setSplitsReady: PropTypes.func.isRequired, + gap: PropTypes.number.isRequired, + setGap: PropTypes.func.isRequired, + temporalInfo: PropTypes.shape({ + frequency_code: PropTypes.string, + frequency_label: PropTypes.string, + frequency_description: PropTypes.string, + frequency_example: PropTypes.string, + average_interval: PropTypes.string, + start_date: PropTypes.string, + end_date: PropTypes.string, + total_periods: PropTypes.number, + detected_gaps: PropTypes.number, + timestamp_column: PropTypes.string, + }), + temporalInfoLoading: PropTypes.bool, +}; + +SplitDatasetTemporal.defaultProps = { + temporalInfo: null, + temporalInfoLoading: false, +}; + +export default SplitDatasetTemporal; diff --git a/DashAI/front/src/components/models/modelSession/runButtons/DeleteRun.jsx b/DashAI/front/src/components/models/modelSession/runButtons/DeleteRun.jsx index 858f38dd5..0f9195a18 100644 --- a/DashAI/front/src/components/models/modelSession/runButtons/DeleteRun.jsx +++ b/DashAI/front/src/components/models/modelSession/runButtons/DeleteRun.jsx @@ -37,9 +37,16 @@ export default function DeleteRun({ run, onRunDelete }) { }); } catch (error) { console.error("Error deleting run:", error); - enqueueSnackbar(t("message.errorDeletingRun"), { - variant: "error", - }); + const detail = + error?.response?.data?.detail || error?.message || ""; + enqueueSnackbar( + detail + ? `${t("message.errorDeletingRun")}: ${detail}` + : t("message.errorDeletingRun"), + { + variant: "error", + }, + ); } finally { setOpen(false); } diff --git a/DashAI/front/src/components/predictions/ForecastingOptions.jsx b/DashAI/front/src/components/predictions/ForecastingOptions.jsx new file mode 100644 index 000000000..fa1f59791 --- /dev/null +++ b/DashAI/front/src/components/predictions/ForecastingOptions.jsx @@ -0,0 +1,514 @@ +import React, { useState, useEffect, useCallback } from "react"; +import PropTypes from "prop-types"; +import { + Alert, + AlertTitle, + Box, + Chip, + CircularProgress, + FormControl, + FormControlLabel, + Grid, + Paper, + Radio, + RadioGroup, + TextField, + Typography, +} from "@mui/material"; +import InfoIcon from "@mui/icons-material/Info"; +import AccessTimeIcon from "@mui/icons-material/AccessTime"; +import CalendarTodayIcon from "@mui/icons-material/CalendarToday"; +import TrendingUpIcon from "@mui/icons-material/TrendingUp"; +import WarningAmberIcon from "@mui/icons-material/WarningAmber"; +import AutoAwesomeIcon from "@mui/icons-material/AutoAwesome"; +import UploadFileIcon from "@mui/icons-material/UploadFile"; +import CheckCircleIcon from "@mui/icons-material/CheckCircle"; + +import { getDatasetTemporalInfo } from "../../api/datasets"; + +/** + * ForecastingOptions — child of PredictionModal. + * + * Renders when the experiment task is ForecastingTask. + * Provides: + * 1. Training-data time-series summary (frequency, date range, periods). + * 2. Forecast mode selector: + * - "auto-generate" → user enters the number of future periods. + * - "dataset" → user picks a dataset (handled externally); + * this component validates its temporal frequency. + * 3. Frequency-mismatch / match feedback when a prediction dataset is + * selected. + * + * Props + * ───── + * temporalInfo — temporal metadata of the *training* dataset + * (frequency_code, frequency_label, start_date, …). + * forecastMode — current mode ("auto-generate" | "dataset"). + * setForecastMode — setter for mode. + * forecastPeriods — number of future periods (auto-generate mode). + * setForecastPeriods — setter for periods. + * selectedDataset — currently selected prediction dataset (or null). + * When non-null and mode === "dataset", frequency + * validation runs automatically. + */ +export default function ForecastingOptions({ + temporalInfo, + forecastMode, + setForecastMode, + forecastPeriods, + setForecastPeriods, + selectedDataset, +}) { + // ----- frequency validation state ----- + const [selectedDatasetTemporalInfo, setSelectedDatasetTemporalInfo] = + useState(null); + const [frequencyMismatch, setFrequencyMismatch] = useState(false); + const [loadingTemporalInfo, setLoadingTemporalInfo] = useState(false); + + // Expose mismatch to parent via callback (so it can disable Submit) + // We use the convention: parent reads `frequencyMismatch` via a ref or + // the component simply blocks via canPredict logic in parent. + + // ----- validate temporal frequency of the selected prediction dataset ----- + useEffect(() => { + const validateSelectedDatasetFrequency = async () => { + if (!selectedDataset || forecastMode !== "dataset" || !temporalInfo) { + setSelectedDatasetTemporalInfo(null); + setFrequencyMismatch(false); + return; + } + + setLoadingTemporalInfo(true); + try { + const timestampColumn = temporalInfo.timestamp_column; + const predictionDatasetInfo = await getDatasetTemporalInfo( + selectedDataset.id, + timestampColumn, + ); + setSelectedDatasetTemporalInfo(predictionDatasetInfo); + + if ( + predictionDatasetInfo.frequency_code !== temporalInfo.frequency_code + ) { + setFrequencyMismatch(true); + } else { + setFrequencyMismatch(false); + } + } catch (error) { + console.error("Error validating prediction dataset frequency:", error); + setSelectedDatasetTemporalInfo(null); + setFrequencyMismatch(false); + } finally { + setLoadingTemporalInfo(false); + } + }; + + validateSelectedDatasetFrequency(); + }, [selectedDataset, forecastMode, temporalInfo]); + + // Reset validation when switching to auto-generate + const handleModeChange = useCallback( + (newMode) => { + setForecastMode(newMode); + if (newMode === "auto-generate") { + setSelectedDatasetTemporalInfo(null); + setFrequencyMismatch(false); + } else { + setForecastPeriods(null); + } + }, + [setForecastMode, setForecastPeriods], + ); + + if (!temporalInfo) return null; + + return ( + + {/* ── 1. Training data time-series summary ── */} + + + + Training Data Time Series Properties + + + + + + + + + Frequency + + + + + + + + + + + + + + Training Period + + + {new Date(temporalInfo.start_date).toLocaleDateString()} →{" "} + {new Date(temporalInfo.end_date).toLocaleDateString()} + + + + + + + + + Training Periods + + + {temporalInfo.total_periods}{" "} + {temporalInfo.frequency_label.toLowerCase()} + + + + + + + + Average Interval + + + {temporalInfo.average_interval} + + + + + + + + What this means: The model was trained on{" "} + {temporalInfo.frequency_label.toLowerCase()} data. + Each prediction step will forecast{" "} + + 1 {temporalInfo.frequency_label.toLowerCase().slice(0, -2)} + {" "} + into the future. {temporalInfo.frequency_example} + + + + + {/* ── 2. Forecast mode selector ── */} + + + Choose Prediction Method + + + + handleModeChange(e.target.value)} + > + {/* ── Auto-generate option ── */} + handleModeChange("auto-generate")} + > + } + label={ + + + + + Auto-generate Future Timestamps + + + Automatically generate future dates from the last + training date. + {` Starting from ${new Date(temporalInfo.end_date).toLocaleDateString()}.`} + + + + } + sx={{ m: 0, width: "100%" }} + /> + + {forecastMode === "auto-generate" && ( + + { + const value = e.target.value; + if (value === "") { + setForecastPeriods(null); + } else { + const numValue = parseInt(value, 10); + if (numValue > 0 && numValue <= 1000) { + setForecastPeriods(numValue); + } + } + }} + helperText={`Forecast ${forecastPeriods || "N"} ${temporalInfo.frequency_label.toLowerCase()} into the future`} + inputProps={{ min: 1, max: 1000 }} + /> + }> + + This option is not available for models + trained with exogenous variables, as future values of + those variables are required. + + + + )} + + + {/* ── Upload-dataset option ── */} + handleModeChange("dataset")} + > + } + label={ + + + + + Upload Dataset with Timestamps + + + Use a dataset containing specific timestamps you want to + predict. Required if the model uses exogenous variables. + + + + } + sx={{ m: 0, width: "100%" }} + /> + + + + + + {/* ── 3. Dataset requirements info (only in dataset mode) ── */} + {forecastMode === "dataset" && ( + } sx={{ mb: 2 }}> + Dataset Requirements + + For forecasting predictions: +
      +
    • + Dataset must include a ds (timestamp) column + with dates to predict (past, present, or future) +
    • +
    • + Timestamps must be strictly increasing and + match the training frequency + ({temporalInfo.frequency_label}) +
    • +
    • + If the model used exogenous regressors during training, include + those columns with values for all timestamps +
    • +
    • + Any y (target) column will be ignored during + prediction +
    • +
    +
    +
    + )} + + {/* ── 4. Frequency validation feedback ── */} + + {/* Loading spinner */} + {forecastMode === "dataset" && loadingTemporalInfo && selectedDataset && ( + + + + Validating dataset temporal frequency… + + + )} + + {/* Mismatch error */} + {forecastMode === "dataset" && + frequencyMismatch && + selectedDatasetTemporalInfo && + temporalInfo && ( + } sx={{ mt: 1 }}> + Temporal Frequency Mismatch + + The selected dataset has a{" "} + different temporal frequency than the training + data: + + + + Training Data + + + + + ({temporalInfo.average_interval}) + + + + + + Selected Dataset + + + + + ({selectedDatasetTemporalInfo.average_interval}) + + + + + + This will produce incorrect predictions. Please + select a dataset with{" "} + {temporalInfo.frequency_label.toLowerCase()}{" "} + frequency, or use the auto-generate option above. + + + + )} + + {/* Success match */} + {forecastMode === "dataset" && + !frequencyMismatch && + selectedDatasetTemporalInfo && + temporalInfo && + !loadingTemporalInfo && ( + } sx={{ mt: 1 }}> + + Frequency match! The selected dataset has the + same temporal frequency ( + {selectedDatasetTemporalInfo.frequency_label}) as + the training data. Period:{" "} + {new Date( + selectedDatasetTemporalInfo.start_date, + ).toLocaleDateString()}{" "} + →{" "} + {new Date( + selectedDatasetTemporalInfo.end_date, + ).toLocaleDateString()}{" "} + ({selectedDatasetTemporalInfo.total_periods} periods) + + + )} +
    + ); +} + +/** Whether the current forecasting configuration blocks prediction. */ +ForecastingOptions.isBlocked = ({ + forecastMode, + forecastPeriods, + selectedDataset, + frequencyMismatch, +}) => { + if (forecastMode === "auto-generate") { + return !forecastPeriods || forecastPeriods <= 0; + } + // dataset mode + if (!selectedDataset) return true; + if (frequencyMismatch) return true; + return false; +}; + +ForecastingOptions.propTypes = { + temporalInfo: PropTypes.shape({ + frequency_code: PropTypes.string, + frequency_label: PropTypes.string, + frequency_description: PropTypes.string, + frequency_example: PropTypes.string, + average_interval: PropTypes.string, + start_date: PropTypes.string, + end_date: PropTypes.string, + total_periods: PropTypes.number, + detected_gaps: PropTypes.number, + timestamp_column: PropTypes.string, + }), + forecastMode: PropTypes.oneOf(["auto-generate", "dataset"]).isRequired, + setForecastMode: PropTypes.func.isRequired, + forecastPeriods: PropTypes.number, + setForecastPeriods: PropTypes.func.isRequired, + selectedDataset: PropTypes.object, +}; diff --git a/DashAI/front/src/components/predictions/PredictionModal.jsx b/DashAI/front/src/components/predictions/PredictionModal.jsx index c38a6cd25..319294075 100644 --- a/DashAI/front/src/components/predictions/PredictionModal.jsx +++ b/DashAI/front/src/components/predictions/PredictionModal.jsx @@ -22,6 +22,7 @@ import { } from "@mui/icons-material"; import ModeSelector from "./ModeSelector"; import DatasetSelector from "./DatasetSelector"; +import ForecastingOptions from "./ForecastingOptions"; import ResultsTable from "./ResultsTable"; import PredictionsTable from "./PredictionsTable"; import ManualInput from "./ManualInput"; @@ -31,12 +32,18 @@ import { getPredictions, deletePrediction, } from "../../api/predict"; -import { getDatasetInfo, exportDatasetCsvByPath } from "../../api/datasets"; +import { + getDatasetInfo, + exportDatasetCsvByPath, + getDatasetTemporalInfo, + getDatasetTypes, + getDatasetSample, +} from "../../api/datasets"; import { enqueuePredictionJob } from "../../api/job"; import { getModelSessionById } from "../../api/modelSession"; -import { getDatasetTypes, getDatasetSample } from "../../api/datasets"; import { useSnackbar } from "notistack"; import { useTranslation } from "react-i18next"; +import { getPredictionStatus } from "../../utils/predictionStatus"; export default function PredictionModal({ isOpen, onClose, run }) { const [activeTab, setActiveTab] = useState(0); @@ -57,6 +64,13 @@ export default function PredictionModal({ isOpen, onClose, run }) { const [loadingExperiment, setLoadingExperiment] = useState(true); const [sample, setSample] = useState(null); + // ── Forecasting-specific state ── + const [temporalInfo, setTemporalInfo] = useState(null); + const [forecastMode, setForecastMode] = useState("dataset"); // "auto-generate" | "dataset" + const [forecastPeriods, setForecastPeriods] = useState(null); + + const isForecastingTask = experiment?.task_name === "ForecastingTask"; + const { enqueueSnackbar } = useSnackbar(); const { t } = useTranslation(["prediction", "common"]); @@ -72,6 +86,10 @@ export default function PredictionModal({ isOpen, onClose, run }) { setIsLoading(false); setManualRows([]); setSelectedPrediction(null); + // Reset forecasting state + setTemporalInfo(null); + setForecastMode("dataset"); + setForecastPeriods(null); } }, [isOpen]); @@ -114,21 +132,43 @@ export default function PredictionModal({ isOpen, onClose, run }) { } }; - // Fetch experiment details + // Fetch experiment details + forecasting temporal info const fetchExperiment = async () => { setLoadingExperiment(true); try { - if (run && run.experiment_id) { + if (run && run.model_session_id) { const experimentData = await getModelSessionById( run.model_session_id, ); setExperiment(experimentData); + const datasetTypes = await getDatasetTypes(experimentData.dataset_id); setTypes(datasetTypes); const datasetSample = await getDatasetSample( experimentData.dataset_id, ); setSample(datasetSample); + + // ── Forecasting: fetch temporal info if applicable ── + if (experimentData.task_name === "ForecastingTask") { + try { + const inputCols = experimentData.input_columns || []; + // In forecasting, the first input column is the timestamp column + if (inputCols.length > 0 && experimentData.dataset_id) { + const timestampColumn = + typeof inputCols === "string" + ? JSON.parse(inputCols)[0] + : inputCols[0]; + const info = await getDatasetTemporalInfo( + experimentData.dataset_id, + timestampColumn, + ); + setTemporalInfo(info); + } + } catch (error) { + console.error("Error fetching temporal info:", error); + } + } } } catch (error) { console.error("Error fetching experiment:", error); @@ -144,32 +184,38 @@ export default function PredictionModal({ isOpen, onClose, run }) { // Handle prediction execution const submitPredictionJob = async () => { - // 1.- Set loading to true setIsLoading(true); try { - // 2.- Create a prediction in database - const prediction = await createPrediction( - run.id, - predictionMode === "dataset" ? selectedDataset.id : null, - ); + // Determine dataset_id — null when auto-generating timestamps + const datasetId = + isForecastingTask && forecastMode === "auto-generate" + ? null + : predictionMode === "dataset" + ? (selectedDataset?.id ?? null) + : null; + + // 1. Create prediction record + const prediction = await createPrediction(run.id, datasetId); + + // 2. Enqueue prediction job (with forecast_periods when applicable) + const periodsArg = + isForecastingTask && forecastMode === "auto-generate" + ? forecastPeriods + : undefined; - // 3.- Enqueue prediction job await enqueuePredictionJob( prediction.id, predictionMode === "manual" ? manualRows : null, + periodsArg, ); enqueueSnackbar(t("prediction:message.predictionJobSubmitted"), { variant: "success", }); - // 4.- Change prediction status to Delivered + // 3. Update status and navigate prediction.status = 1; // Delivered - - // 5.- Add prediction to the list setPredictions([prediction, ...predictions]); - - // 6.- Set selected prediction to the new one and switch tab setSelectedPrediction(prediction); setActiveTab(1); } catch (error) { @@ -178,7 +224,6 @@ export default function PredictionModal({ isOpen, onClose, run }) { variant: "error", }); } finally { - // 7.- Set loading to false setIsLoading(false); } }; @@ -224,9 +269,9 @@ export default function PredictionModal({ isOpen, onClose, run }) { enqueueSnackbar( `${t("prediction:label.prediction")} ${ updated.id - } ${statusText.toLowerCase()}.`, + } ${getPredictionStatus(statusText, t).toLowerCase()}.`, { - variant: statusText === 3 ? "success" : "error", // Finished + variant: statusText === 3 ? "success" : "error", }, ); clearInterval(intervals[prediction.id]); @@ -238,10 +283,9 @@ export default function PredictionModal({ isOpen, onClose, run }) { variant: "error", }); } - }, 2000); // Every 2 seconds + }, 2000); }); - // Cleanup return () => { Object.values(intervals).forEach((interval) => clearInterval(interval)); }; @@ -283,6 +327,16 @@ export default function PredictionModal({ isOpen, onClose, run }) { }; const canPredict = () => { + // ── Forecasting-specific logic ── + if (isForecastingTask) { + if (forecastMode === "auto-generate") { + return forecastPeriods != null && forecastPeriods > 0; + } + // forecastMode === "dataset": need a dataset selected + return selectedDataset !== null; + } + + // ── Standard tasks ── if (predictionMode === "dataset") { return selectedDataset !== null; } @@ -355,26 +409,52 @@ export default function PredictionModal({ isOpen, onClose, run }) { <> {viewMode === "input" ? ( - - {predictionMode === "dataset" ? ( - + {/* ── Forecasting: show ForecastingOptions instead of ModeSelector ── */} + {isForecastingTask ? ( + <> + + {/* Show DatasetSelector only when forecasting uses dataset mode */} + {forecastMode === "dataset" && ( + + )} + ) : ( - + <> + {/* ── Standard tasks: original mode selector ── */} + + {predictionMode === "dataset" ? ( + + ) : ( + + )} + )} ) : ( diff --git a/DashAI/front/src/components/predictions/PredictionsTable.jsx b/DashAI/front/src/components/predictions/PredictionsTable.jsx index b508c3548..1d0f24145 100644 --- a/DashAI/front/src/components/predictions/PredictionsTable.jsx +++ b/DashAI/front/src/components/predictions/PredictionsTable.jsx @@ -119,7 +119,7 @@ function PredictionsTable({ predictions, onItemClick, onItemDelete }) { flex: 1, minWidth: 100, renderCell: (params) => { - const statusText = getPredictionStatus(params?.row?.status); + const statusText = getPredictionStatus(params?.row?.status, t); return ( {statusText} diff --git a/DashAI/front/src/components/predictions/ResultsTable.jsx b/DashAI/front/src/components/predictions/ResultsTable.jsx index 753fc5118..b7661befd 100644 --- a/DashAI/front/src/components/predictions/ResultsTable.jsx +++ b/DashAI/front/src/components/predictions/ResultsTable.jsx @@ -12,7 +12,6 @@ import { CircularProgress, } from "@mui/material"; import { useTheme } from "@mui/material/styles"; -import { getPredictionStatus } from "../../utils/predictionStatus"; import DatasetTable from "../notebooks/dataset/DatasetTable"; import { getDatasetFile } from "../../api/datasets"; import { useTranslation } from "react-i18next"; @@ -22,7 +21,7 @@ const RUNNING_STATUSES = [1, 2]; // Delivered or Started function ResultsTable({ selectedPrediction }) { const theme = useTheme(); const [loadingExecution, setLoadingExecution] = useState( - RUNNING_STATUSES.includes(getPredictionStatus(selectedPrediction?.status)), + RUNNING_STATUSES.includes(selectedPrediction?.status), ); const { t } = useTranslation(["prediction"]); @@ -40,11 +39,7 @@ function ResultsTable({ selectedPrediction }) { useEffect(() => { if (!selectedPrediction) return; - setLoadingExecution( - RUNNING_STATUSES.includes( - getPredictionStatus(selectedPrediction?.status), - ), - ); + setLoadingExecution(RUNNING_STATUSES.includes(selectedPrediction?.status)); }, [selectedPrediction]); return ( diff --git a/DashAI/front/src/pages/results/Results.jsx b/DashAI/front/src/pages/results/Results.jsx index 4636a0eaf..71496a8b9 100644 --- a/DashAI/front/src/pages/results/Results.jsx +++ b/DashAI/front/src/pages/results/Results.jsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect } from "react"; +import React, { useState } from "react"; import PropTypes from "prop-types"; import { IconButton } from "@mui/material"; import VisibilityIcon from "@mui/icons-material/Visibility"; @@ -7,7 +7,7 @@ import TimestampWrapper from "../../components/shared/TimestampWrapper"; import { TIMESTAMP_KEYS } from "../../constants/timestamp"; import { useTourContext } from "../../components/tour/TourProvider"; -function Results({ experiment, handleDeleteExperiment }) { +function Results({ experiment }) { const [open, setOpen] = useState(false); const [showTable, setShowTable] = useState(true); const tourContext = useTourContext(); @@ -49,7 +49,6 @@ function Results({ experiment, handleDeleteExperiment }) { showTable={showTable} handleShowTable={handleShowTable} handleShowGraphs={handleShowGraphs} - handleDeleteExperiment={handleDeleteExperiment} /> )} diff --git a/DashAI/front/src/pages/results/components/LiveMetricsChart.jsx b/DashAI/front/src/pages/results/components/LiveMetricsChart.jsx index 64c759d47..b5710d0f1 100644 --- a/DashAI/front/src/pages/results/components/LiveMetricsChart.jsx +++ b/DashAI/front/src/pages/results/components/LiveMetricsChart.jsx @@ -22,6 +22,10 @@ import { import { useEffect, useRef, useState } from "react"; import { useTranslation } from "react-i18next"; import { getModelSessionById } from "../../../api/modelSession"; +import { + formatScalarMetricsForChart, + isFiniteMetricValue, +} from "../../../utils/metricUtils"; export function LiveMetricsChart({ run }) { const { t } = useTranslation(["models", "common"]); @@ -50,20 +54,9 @@ export function LiveMetricsChart({ run }) { setData((prev) => { const next = structuredClone(prev); - // Convert old format to new format if needed - const formattedTestMetrics = {}; - for (const metricName in run.test_metrics) { - const value = run.test_metrics[metricName]; - // Check if it's already in new format (array of objects) - if (Array.isArray(value)) { - formattedTestMetrics[metricName] = value; - } else { - // Convert old format (single value) to new format - formattedTestMetrics[metricName] = [ - { step: 1, value: value, timestamp: new Date().toISOString() }, - ]; - } - } + const formattedTestMetrics = formatScalarMetricsForChart( + run.test_metrics, + ); next.TEST = { TRIAL: formattedTestMetrics, @@ -121,20 +114,9 @@ export function LiveMetricsChart({ run }) { setData((prev) => { const next = structuredClone(prev); - // Convert old format to new format if needed - const formattedTestMetrics = {}; - for (const metricName in run.test_metrics) { - const value = run.test_metrics[metricName]; - // Check if it's already in new format (array of objects) - if (Array.isArray(value)) { - formattedTestMetrics[metricName] = value; - } else { - // Convert old format (single value) to new format - formattedTestMetrics[metricName] = [ - { step: 1, value: value, timestamp: new Date().toISOString() }, - ]; - } - } + const formattedTestMetrics = formatScalarMetricsForChart( + run.test_metrics, + ); next.TEST = { TRIAL: formattedTestMetrics, @@ -184,7 +166,12 @@ export function LiveMetricsChart({ run }) { const allowed = availableMetrics[split] ?? []; const filteredMetrics = Object.fromEntries( - Object.entries(metrics).filter(([name]) => allowed.includes(name)), + Object.entries(metrics).filter( + ([name, metricValues]) => + allowed.includes(name) && + Array.isArray(metricValues) && + metricValues.some((point) => isFiniteMetricValue(point?.value)), + ), ); // Transform new data structure to chart format diff --git a/DashAI/front/src/pages/results/components/MetricsCard.jsx b/DashAI/front/src/pages/results/components/MetricsCard.jsx index 21ecd6737..2c6a204e9 100644 --- a/DashAI/front/src/pages/results/components/MetricsCard.jsx +++ b/DashAI/front/src/pages/results/components/MetricsCard.jsx @@ -1,9 +1,11 @@ import React from "react"; import { Box, Divider, Paper, Typography } from "@mui/material"; import { useTranslation } from "react-i18next"; +import { getNumericMetricEntries } from "../../../utils/metricUtils"; export default function MetricsCard({ title, metrics }) { const { t } = useTranslation(["models"]); + const numericMetrics = getNumericMetricEntries(metrics); return ( @@ -11,8 +13,8 @@ export default function MetricsCard({ title, metrics }) { {title} - {metrics && Object.keys(metrics).length > 0 ? ( - Object.entries(metrics).map(([key, value]) => ( + {numericMetrics.length > 0 ? ( + numericMetrics.map(([key, value]) => ( { try { - const response = await enqueueRunnerJobRequest(runId); + const response = await enqueueRunnerJobRequest( + runId, + experiment.task_name, + ); if (response && response.id) { setTrackedJobIds((prev) => new Set(prev).add(response.id)); @@ -226,7 +228,10 @@ function ResultsDialogLayout({ const initialUpdatedRun = await resetRunById(run.id); // Enqueue the run - const response = await enqueueRunnerJobRequest(run.id); + const response = await enqueueRunnerJobRequest( + run.id, + experiment.task_name, + ); if (!response || !response.id) { enqueueSnackbar( @@ -403,22 +408,25 @@ function ResultsDialogLayout({ }} onConfirm={async () => { try { + await deleteRun(runToDelete); setRuns((prevRuns) => prevRuns.filter((run) => run.id !== runToDelete), ); - if (runs.length === 1) { - handleDeleteExperiment(experiment.id); - } else { - await deleteRun(runToDelete); - } enqueueSnackbar(t("models:message.runDeletedSuccessfully"), { variant: "success", }); } catch (error) { console.error("Error deleting run:", error); - enqueueSnackbar(t("models:error.errorDeletingRun"), { - variant: "error", - }); + const detail = + error?.response?.data?.detail || error?.message || ""; + enqueueSnackbar( + detail + ? `${t("models:error.errorDeletingRun")}: ${detail}` + : t("models:error.errorDeletingRun"), + { + variant: "error", + }, + ); } finally { setOpenDeleteModal(false); setRunToDelete(null); diff --git a/DashAI/front/src/pages/results/components/ResultsTable.jsx b/DashAI/front/src/pages/results/components/ResultsTable.jsx index 56d25c23e..86e04280d 100644 --- a/DashAI/front/src/pages/results/components/ResultsTable.jsx +++ b/DashAI/front/src/pages/results/components/ResultsTable.jsx @@ -59,7 +59,7 @@ function ResultsTable({ }; const handleExplainer = (run) => { - navigate(`../app/explainers/runs/${run.id}`, { + navigate(`/app/explainers/runs/${run.id}`, { state: { modelName: run.name, taskName: experiment.task_name, diff --git a/DashAI/front/src/pages/results/constants/extractColumns.jsx b/DashAI/front/src/pages/results/constants/extractColumns.jsx index 1f4a57d49..b4b0b2a6f 100644 --- a/DashAI/front/src/pages/results/constants/extractColumns.jsx +++ b/DashAI/front/src/pages/results/constants/extractColumns.jsx @@ -28,8 +28,10 @@ export const extractColumns = ( // Not Started, Delivered, Started return "-"; - return row.test_metrics[metric.name] !== undefined - ? Number(row.test_metrics[metric.name]).toFixed(2) + const testMetrics = row.test_metrics ?? {}; + + return testMetrics[metric.name] !== undefined + ? Number(testMetrics[metric.name]).toFixed(2) : "-"; }, })); diff --git a/DashAI/front/src/utils/i18n/locales/en/experiments.json b/DashAI/front/src/utils/i18n/locales/en/experiments.json index 359dfc2d4..32ccde5a5 100644 --- a/DashAI/front/src/utils/i18n/locales/en/experiments.json +++ b/DashAI/front/src/utils/i18n/locales/en/experiments.json @@ -34,6 +34,7 @@ "addOptimizersToExperiment": "Add Optimizers to your Experiment", "columnsInvalidRequirements": "Current Input and Output columns do not match {{taskName}} requirements", "columnsValidRequirements": "Current Input and Output columns match {{taskName}} requirements", + "datasetInvalidForTask": "The selected dataset and columns are not valid for {{taskName}}", "configureExperimentsSubtitle": "Configure experiments to train models.", "configureModels": "Configure models", "configureOptimizer": "Configure hyperparameter optimization", @@ -75,10 +76,12 @@ "startTime": "Start Time", "stratify": "Stratify", "stratifyDescription": "Defines whether the data will be proportionally separated according to the distribution of classes in each set. Shuffle must be true to stratify the data.", + "forecastingValidationHint": "For forecasting, one input column must contain date or time values and the output column must be numeric. The backend label 'Value' is generic and can correspond to columns shown here as Text, Float, Integer, Date or Timestamp.", "useManualSplittingBySpecifyingRowIndexes": "Use manual splitting by specifying the row indexes of each subset", "usePredefinedSplitsFromDataset": "Use predefined splits from dataset", "usePredefinedSplitsFromDatasetNotAvailable": "Use predefined splits from dataset (not available)", - "useRandomRowsBySpecifyingPortion": "Use random rows by specifying which portion of the dataset you want to use for each subset" + "useRandomRowsBySpecifyingPortion": "Use random rows by specifying which portion of the dataset you want to use for each subset", + "validationDetails": "Validation details" }, "message": { "confirmDeleteRun": "Are you sure you want to delete this run? This action cannot be undone.", diff --git a/DashAI/front/src/utils/i18n/locales/es/experiments.json b/DashAI/front/src/utils/i18n/locales/es/experiments.json index d1ea6cace..63ee1e18c 100644 --- a/DashAI/front/src/utils/i18n/locales/es/experiments.json +++ b/DashAI/front/src/utils/i18n/locales/es/experiments.json @@ -34,6 +34,7 @@ "addOptimizersToExperiment": "Agregar Optimizadores a su Experimento", "columnsInvalidRequirements": "Las columnas de Entrada y Salida actuales no coinciden con los requisitos de {{taskName}}", "columnsValidRequirements": "Las columnas de Entrada y Salida actuales coinciden con los requisitos de {{taskName}}", + "datasetInvalidForTask": "El dataset y las columnas seleccionadas no son válidos para {{taskName}}", "configureExperimentsSubtitle": "Configure experimentos para entrenar modelos.", "configureModels": "Configurar modelos", "configureOptimizer": "Configurar optimización de hiperparámetros", @@ -75,10 +76,12 @@ "startTime": "Hora de Inicio", "stratify": "Estratificar", "stratifyDescription": "Define si los datos se separarán proporcionalmente según la distribución de clases en cada conjunto. Shuffle debe ser verdadero para estratificar los datos.", + "forecastingValidationHint": "Para forecasting, una columna de entrada debe contener fechas u horas y la columna de salida debe ser numérica. La etiqueta del backend 'Value' es genérica y puede corresponder aquí a columnas mostradas como Text, Float, Integer, Date o Timestamp.", "useManualSplittingBySpecifyingRowIndexes": "Usar división manual especificando los índices de fila de cada subconjunto", "usePredefinedSplitsFromDataset": "Usar divisiones predefinidas del dataset", "usePredefinedSplitsFromDatasetNotAvailable": "Usar divisiones predefinidas del dataset (no disponible)", - "useRandomRowsBySpecifyingPortion": "Usar filas aleatorias especificando qué porción del dataset desea usar para cada subconjunto" + "useRandomRowsBySpecifyingPortion": "Usar filas aleatorias especificando qué porción del dataset desea usar para cada subconjunto", + "validationDetails": "Detalle de validación" }, "message": { "confirmDeleteRun": "¿Está seguro de que desea eliminar esta ejecución? Esta acción no se puede deshacer.", diff --git a/DashAI/front/src/utils/metricUtils.js b/DashAI/front/src/utils/metricUtils.js new file mode 100644 index 000000000..07dc0fb97 --- /dev/null +++ b/DashAI/front/src/utils/metricUtils.js @@ -0,0 +1,21 @@ +export const isFiniteMetricValue = (value) => + typeof value === "number" && Number.isFinite(value); + +export const getNumericMetrics = (metrics = {}) => + Object.fromEntries( + Object.entries(metrics).filter(([, value]) => isFiniteMetricValue(value)), + ); + +export const getNumericMetricEntries = (metrics = {}) => + Object.entries(getNumericMetrics(metrics)); + +export const formatScalarMetricsForChart = (metrics = {}) => { + const now = new Date().toISOString(); + + return Object.fromEntries( + getNumericMetricEntries(metrics).map(([metricName, value]) => [ + metricName, + [{ step: 1, value, timestamp: now }], + ]), + ); +}; diff --git a/reproduce_issue.py b/reproduce_issue.py new file mode 100644 index 000000000..d67bf9bfb --- /dev/null +++ b/reproduce_issue.py @@ -0,0 +1,86 @@ +import numpy as np +import pandas as pd + +from DashAI.back.dataloaders.classes.dashai_dataset import to_dashai_dataset +from DashAI.back.models.forecasting.prophet_model import ProphetModel +from DashAI.back.models.forecasting.sklearn_multistep_forecaster import ( + SklearnMultiStepForecaster, +) + + +def create_dummy_data(): + dates = pd.date_range(start="2023-01-01", periods=100, freq="D") + values = np.sin(np.linspace(0, 10, 100)) + np.random.normal(0, 0.1, 100) + data = pd.DataFrame({"ds": dates, "y": values}) + return to_dashai_dataset(data) + + +def test_sklearn_forecaster(): + print("\nTesting SklearnMultiStepForecaster...") + dataset = create_dummy_data() + + # Create x (features) and y (target) datasets + x_df = dataset.to_pandas() + y_df = x_df[["y"]] + y_dataset = to_dashai_dataset(y_df) + + model = SklearnMultiStepForecaster(window_size=5) + + # Metadata usually comes from task, mocking it here + temporal_metadata = {"timestamp_col": "ds", "target_col": "y", "frequency": "D"} + + model.fit(dataset, y_dataset, temporal_metadata=temporal_metadata) + + # Test 1: predict with periods (standard) + try: + pred = model.predict(periods=5) + print(f"✅ predict(periods=5) successful. Shape: {pred.shape}") + except Exception as e: + print(f"❌ predict(periods=5) failed: {e}") + + # Test 2: predict with horizon (alias) + try: + pred = model.predict(horizon=5) + print(f"✅ predict(horizon=5) successful (alias). Shape: {pred.shape}") + except Exception as e: + print(f"❌ predict(horizon=5) failed: {e}") + + +def test_prophet_model(): + print("\nTesting ProphetModel...") + dataset = create_dummy_data() + + # Create x (features) and y (target) datasets + x_df = dataset.to_pandas() + y_df = x_df[["y"]] + y_dataset = to_dashai_dataset(y_df) + + model = ProphetModel() + + # Metadata usually comes from task, mocking it here + temporal_metadata = {"timestamp_col": "ds", "target_col": "y", "frequency": "D"} + + model.fit(dataset, y_dataset, temporal_metadata=temporal_metadata) + + # Test 1: predict with periods (new standard) + try: + pred = model.predict(periods=5) + print(f"✅ predict(periods=5) successful. Shape: {pred.shape}") + except Exception as e: + print(f"❌ predict(periods=5) failed: {e}") + + # Test 2: predict with horizon (backward compatibility) + try: + pred = model.predict(horizon=5) + print(f"✅ predict(horizon=5) successful (compat). Shape: {pred.shape}") + except Exception as e: + print(f"❌ predict(horizon=5) failed: {e}") + + +if __name__ == "__main__": + try: + test_sklearn_forecaster() + test_prophet_model() + print("\nAll tests completed.") + except Exception as e: + print(f"\nGlobal error: {e}") diff --git a/reproduce_length_mismatch.py b/reproduce_length_mismatch.py new file mode 100644 index 000000000..8c9887348 --- /dev/null +++ b/reproduce_length_mismatch.py @@ -0,0 +1,86 @@ +import numpy as np +import pandas as pd + +from DashAI.back.dataloaders.classes.dashai_dataset import to_dashai_dataset +from DashAI.back.explainability.explainers.forecasting_explainers import ( + forecast_decomposition, +) +from DashAI.back.models.forecasting.sklearn_multistep_forecaster import ( + SklearnMultiStepForecaster, +) + + +def create_dummy_data(): + dates = pd.date_range(start="2023-01-01", periods=100, freq="D") + values = np.sin(np.linspace(0, 10, 100)) + np.random.normal(0, 0.1, 100) + data = pd.DataFrame({"ds": dates, "y": values}) + return to_dashai_dataset(data) + + +def reproduce_error(): + print("\nReproducing length mismatch error...") + dataset = create_dummy_data() + + # Create x (features) and y (target) datasets + x_df = dataset.to_pandas() + y_df = x_df[["y"]] + y_dataset = to_dashai_dataset(y_df) + + # Train with default horizon (which is 1 usually, or from fit_params) + # If we don't specify horizon, it defaults to 1 in SklearnMultiStepForecaster + model = SklearnMultiStepForecaster(window_size=5, forecast_strategy="direct") + + temporal_metadata = {"timestamp_col": "ds", "target_col": "y", "frequency": "D"} + + print("Training model (default horizon=1)...") + model.fit(dataset, y_dataset, temporal_metadata=temporal_metadata) + + # Create a dataset that extends beyond the training data + # Training was 100 days from 2023-01-01 (ends ~2023-04-10) + # Let's create a "validation" dataset that goes up to 2023-06-01 + dates_extended = pd.date_range(start="2023-01-01", end="2023-06-01", freq="D") + values_extended = np.sin( + np.linspace(0, 15, len(dates_extended)) + ) + np.random.normal(0, 0.1, len(dates_extended)) + df_extended = pd.DataFrame({"ds": dates_extended, "y": values_extended}) + dataset_extended = to_dashai_dataset(df_extended) + + # Create x (features) and y (target) datasets for extended data + x_df_ext = dataset_extended.to_pandas() + y_df_ext = x_df_ext[["y"]] + y_dataset_ext = to_dashai_dataset(y_df_ext) + + print(f"Extended dataset ends at: {dates_extended.max()}") + + # Now try to explain with horizon 30 using the EXTENDED dataset + print("Attempting explain with horizon=30 using extended dataset...") + explainer = forecast_decomposition.ForecastDecomposition(model, horizon=30) + + try: + # explain() needs a dataset tuple (x, y) + explanation = explainer.explain((dataset_extended, y_dataset_ext)) + print("✅ Explanation successful!") + print(f"Explanation keys: {explanation.keys()}") + print(f"Explanation ds length: {len(explanation['ds'])}") + ds_series = pd.Series(explanation["ds"]) + print(f"Explanation start date: {ds_series.min()}") + print(f"Explanation end date: {ds_series.max()}") + + # Verify start date is after extended dataset + expected_start = dates_extended.max() + pd.Timedelta(days=1) + if ds_series.min() == expected_start: + print("✅ Start date matches expected (from extended dataset)") + else: + print( + f"❌ Start mismatch! Expected {expected_start}, got {ds_series.min()}" + ) + + except Exception as e: + print(f"❌ Explanation failed: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + reproduce_error() diff --git a/reproduce_uncertainty_error.py b/reproduce_uncertainty_error.py new file mode 100644 index 000000000..9df1408f9 --- /dev/null +++ b/reproduce_uncertainty_error.py @@ -0,0 +1,58 @@ +import numpy as np +import pandas as pd + +from DashAI.back.dataloaders.classes.dashai_dataset import to_dashai_dataset +from DashAI.back.explainability.explainers.forecasting_explainers import ( + forecast_uncertainty, +) +from DashAI.back.models.forecasting.sklearn_multistep_forecaster import ( + SklearnMultiStepForecaster, +) + + +def reproduce_error(): + # Create dummy data + dates = pd.date_range(start="2023-01-01", periods=100, freq="D") + values = np.sin(np.linspace(0, 10, 100)) + np.random.normal(0, 0.1, 100) + data = pd.DataFrame({"ds": dates, "y": values}) + + # Create DashAI datasets + dataset = to_dashai_dataset(data) + + # Create x (features) and y (target) datasets + # For this simple case, we'll just use the same dataset for both structure, + # but in reality x would be features and y would be target + x_df = data.drop(columns=["y"]) + y_df = data[["y"]] + + to_dashai_dataset(x_df) # Not used directly, validation only + y_dataset = to_dashai_dataset(y_df) + + # Initialize model + model = SklearnMultiStepForecaster() + + # Fit model + print("Training model...") + temporal_metadata = {"timestamp_col": "ds", "target_col": "y", "frequency": "D"} + model.fit(dataset, y_dataset, temporal_metadata=temporal_metadata) + + # Initialize explainer + print("Initializing ForecastUncertainty explainer...") + explainer = forecast_uncertainty.ForecastUncertainty(model, horizon=30) + + # Explain + print("Attempting to explain...") + try: + explanation = explainer.explain((dataset, y_dataset)) + print("✅ Explanation successful!") + print(f"Explanation keys: {explanation.keys()}") + print(f"Explanation ds length: {len(explanation['ds'])}") + except Exception as e: + print(f"❌ Explanation failed: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + reproduce_error() diff --git a/requirements.txt b/requirements.txt index ca025d1df..2bdb4dfda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,4 +45,6 @@ greenery==3.2 xlrd filetype torchmetrics +prophet +statsmodels pywebview diff --git a/tests/back/dataloaders/test_dashai_dataset.py b/tests/back/dataloaders/test_dashai_dataset.py index cfefa5228..ef22f4f84 100644 --- a/tests/back/dataloaders/test_dashai_dataset.py +++ b/tests/back/dataloaders/test_dashai_dataset.py @@ -6,6 +6,7 @@ from typing import List import datasets +import pandas as pd import pytest from datasets import DatasetDict from pyarrow.lib import ArrowInvalid @@ -16,6 +17,7 @@ DashAIDataset, get_column_names_from_indexes, load_dataset, + prepare_for_forecasting_experiment, save_dataset, select_columns, split_dataset, @@ -313,6 +315,48 @@ def test_split_dataset( assert totals_rows == train_rows + test_rows + validation_rows +def test_prepare_for_forecasting_experiment_forces_temporal_split_for_random_type(): + """Forecasting should never silently fall back to random splitting.""" + dataframe = pd.DataFrame( + { + "ds": pd.date_range("2020-01-01", periods=100, freq="D"), + "y": list(range(100)), + } + ) + dataset = to_dashai_dataset(dataframe) + + prepared_dataset, split_indices = prepare_for_forecasting_experiment( + dataset=dataset, + splits={ + "splitType": "random", + "train": 0.6, + "validation": 0.2, + "test": 0.2, + }, + timestamp_col="ds", + output_columns=["y"], + ) + + train_df = prepared_dataset["train"].to_pandas() + validation_df = prepared_dataset["validation"].to_pandas() + test_df = prepared_dataset["test"].to_pandas() + + assert len(train_df) == 60 + assert len(validation_df) == 20 + assert len(test_df) == 20 + + assert train_df["ds"].is_monotonic_increasing + assert validation_df["ds"].is_monotonic_increasing + assert test_df["ds"].is_monotonic_increasing + + assert train_df["ds"].max() < validation_df["ds"].min() + assert validation_df["ds"].max() < test_df["ds"].min() + + assert split_indices["train_indexes"] == list(range(60)) + assert split_indices["val_indexes"] == list(range(60, 80)) + assert split_indices["test_indexes"] == list(range(80, 100)) + + # ---------------------------------------------------------------------------- # fixture: split dashai datasetdict diff --git a/tests/back/models/test_forecasting_models.py b/tests/back/models/test_forecasting_models.py new file mode 100644 index 000000000..6cf99056f --- /dev/null +++ b/tests/back/models/test_forecasting_models.py @@ -0,0 +1,473 @@ +import os +import tempfile + +import numpy as np +import pandas as pd +import pytest + +from DashAI.back.dataloaders.classes.dashai_dataset import to_dashai_dataset +from DashAI.back.dependencies.registry import ComponentRegistry +from DashAI.back.models.forecasting.prophet_model import ( + ProphetModel, + _patch_prophet_regressor_column_matrix, +) +from DashAI.back.models.forecasting.sklearn_multistep_forecaster import ( + SklearnMultiStepForecaster, +) +from DashAI.back.models.forecasting.statsmodels_arima_model import ( + StatsmodelsARIMAModel, +) +from DashAI.back.models.model_factory import ModelFactory + +# --------------------------------------------------------------------------- +# Registry fixture (required by conftest client fixture) +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True, name="test_registry") +def setup_test_registry(client, monkeypatch: pytest.MonkeyPatch): + container = client.app.container + + test_registry = ComponentRegistry( + initial_components=[ + SklearnMultiStepForecaster, + ] + ) + + monkeypatch.setitem( + container._services, + "component_registry", + test_registry, + ) + return test_registry + + +# --------------------------------------------------------------------------- +# Shared data fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def daily_series(): + """200 days of synthetic daily temperature data (sinusoidal + noise).""" + n = 200 + dates = pd.date_range("2023-01-01", periods=n, freq="D") + np.random.seed(42) + values = 15 + 10 * np.sin(np.linspace(0, 4 * np.pi, n)) + np.random.randn(n) + + x_df = pd.DataFrame({"date": dates.astype(str)}) + y_df = pd.DataFrame({"temp": values}) + + metadata = { + "timestamp_col": "date", + "target_col": "temp", + "exog_cols": [], + "frequency": "D", + } + + return { + "x": to_dashai_dataset(x_df), + "y": to_dashai_dataset(y_df), + "x_df": x_df, + "y_df": y_df, + "dates": dates, + "values": values, + "metadata": metadata, + } + + +@pytest.fixture(scope="module") +def small_series(): + """Small dataset (12 rows) to test auto window-size adjustment.""" + n = 12 + dates = pd.date_range("2023-01-01", periods=n, freq="D") + values = np.arange(n, dtype=float) + + x_df = pd.DataFrame({"date": dates.astype(str)}) + y_df = pd.DataFrame({"temp": values}) + + metadata = { + "timestamp_col": "date", + "target_col": "temp", + "exog_cols": [], + "frequency": "D", + } + + return { + "x": to_dashai_dataset(x_df), + "y": to_dashai_dataset(y_df), + "metadata": metadata, + } + + +# --------------------------------------------------------------------------- +# Original tests (kept intact) +# --------------------------------------------------------------------------- + + +def test_forecasting_model_factory_can_instantiate_model(): + factory = ModelFactory( + SklearnMultiStepForecaster, + { + "base_estimator": "linear", + "window_size": 3, + "forecast_strategy": "direct", + }, + ) + + assert isinstance(factory.model, SklearnMultiStepForecaster) + + +def test_prophet_patch_preserves_weekly_periodicity(): + Prophet = _patch_prophet_regressor_column_matrix() + + dates = pd.Series(pd.date_range("2024-01-01", periods=14, freq="D")) + features = Prophet.fourier_series(dates, period=7, series_order=3) + + assert np.allclose(features[0], features[7], atol=1e-9) + assert np.allclose(features[1], features[8], atol=1e-9) + + +# --------------------------------------------------------------------------- +# SklearnMultiStepForecaster — estimator variants +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("estimator", ["linear", "ridge", "random_forest"]) +def test_sklearn_estimators_fit_and_predict_outsample(daily_series, estimator): + """All three base estimators should fit and produce out-of-sample forecasts.""" + model = SklearnMultiStepForecaster( + base_estimator=estimator, window_size=5, forecast_strategy="recursive" + ) + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + horizon=10, + ) + + preds = model.predict(periods=10) + assert isinstance(preds, np.ndarray) + assert len(preds) == 10 + assert not np.all(np.isnan(preds)) + + +# --------------------------------------------------------------------------- +# SklearnMultiStepForecaster — strategies +# --------------------------------------------------------------------------- + + +def test_sklearn_direct_strategy_produces_forecast(daily_series): + model = SklearnMultiStepForecaster( + base_estimator="linear", window_size=5, forecast_strategy="direct" + ) + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + horizon=7, + ) + preds = model.predict(periods=7) + assert len(preds) == 7 + + +def test_sklearn_recursive_strategy_produces_forecast(daily_series): + model = SklearnMultiStepForecaster( + base_estimator="linear", window_size=5, forecast_strategy="recursive" + ) + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + horizon=7, + ) + preds = model.predict(periods=7) + assert len(preds) == 7 + + +# --------------------------------------------------------------------------- +# SklearnMultiStepForecaster — in-sample predictions +# --------------------------------------------------------------------------- + + +def test_sklearn_insample_predictions_shape(daily_series): + """In-sample predictions should have the same length as the input slice.""" + model = SklearnMultiStepForecaster( + base_estimator="linear", window_size=5, forecast_strategy="recursive" + ) + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + horizon=10, + ) + + # Use a slice of the training x_df as in-sample input + x_slice = daily_series["x_df"].iloc[10:30].copy() + preds = model.predict(x_pred=x_slice) + assert len(preds) == len(x_slice) + + +# --------------------------------------------------------------------------- +# SklearnMultiStepForecaster — auto window-size adjustment +# --------------------------------------------------------------------------- + + +def test_sklearn_auto_adjusts_window_size_for_small_dataset(small_series): + """Model should not raise when window_size > available samples.""" + model = SklearnMultiStepForecaster( + base_estimator="linear", window_size=20, forecast_strategy="recursive" + ) + # Should not raise — window_size is auto-adjusted internally + model.fit( + small_series["x"], + small_series["y"], + temporal_metadata=small_series["metadata"], + horizon=2, + ) + assert model.window_size < 20 # was reduced + + +# --------------------------------------------------------------------------- +# SklearnMultiStepForecaster — forecast uncertainty & components +# --------------------------------------------------------------------------- + + +def test_sklearn_forecast_uncertainty_columns(daily_series): + model = SklearnMultiStepForecaster( + base_estimator="linear", window_size=5, forecast_strategy="recursive" + ) + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + horizon=10, + ) + + result = model.get_forecast_uncertainty(horizon=10) + assert set(result.columns) >= {"ds", "yhat", "yhat_lower", "yhat_upper"} + assert len(result) == 10 + assert (result["yhat_upper"] >= result["yhat_lower"]).all() + + +def test_sklearn_forecast_components_columns(daily_series): + model = SklearnMultiStepForecaster( + base_estimator="linear", window_size=5, forecast_strategy="recursive" + ) + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + horizon=10, + ) + + result = model.get_forecast_components(horizon=10) + assert "ds" in result.columns + assert "trend" in result.columns + assert "residual" in result.columns + assert len(result) == 10 + + +# --------------------------------------------------------------------------- +# SklearnMultiStepForecaster — save / load +# --------------------------------------------------------------------------- + + +def test_sklearn_save_and_load_preserves_predictions(daily_series): + model = SklearnMultiStepForecaster( + base_estimator="linear", window_size=5, forecast_strategy="recursive" + ) + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + horizon=5, + ) + preds_before = model.predict(periods=5) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "sklearn_model.pkl") + model.save(path) + + loaded = SklearnMultiStepForecaster() + loaded.load(path) + + preds_after = loaded.predict(periods=5) + np.testing.assert_array_almost_equal(preds_before, preds_after) + + +# --------------------------------------------------------------------------- +# SklearnMultiStepForecaster — edge cases +# --------------------------------------------------------------------------- + + +def test_sklearn_predict_before_fit_raises(): + model = SklearnMultiStepForecaster() + with pytest.raises(ValueError, match="Model not fitted"): + model.predict(periods=5) + + +def test_sklearn_negative_periods_raises(daily_series): + model = SklearnMultiStepForecaster(window_size=5) + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + horizon=5, + ) + with pytest.raises(ValueError, match="Prediction horizon must be a positive"): + model.predict(periods=-1) + + +# --------------------------------------------------------------------------- +# StatsmodelsARIMAModel +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def fitted_arima(daily_series): + model = StatsmodelsARIMAModel(p=1, d=1, q=1, trend="n") + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + ) + return model + + +def test_arima_fit_stores_model(fitted_arima): + assert fitted_arima.model_fit is not None + + +def test_arima_outsample_forecast_shape(fitted_arima): + preds = fitted_arima.predict(periods=10) + assert isinstance(preds, np.ndarray) + assert len(preds) == 10 + assert not np.all(np.isnan(preds)) + + +def test_arima_insample_predict_shape(daily_series, fitted_arima): + x_slice = daily_series["x_df"].iloc[5:20].copy() + preds = fitted_arima.predict(x_pred=x_slice) + assert len(preds) == len(x_slice) + + +def test_arima_forecast_uncertainty_columns(fitted_arima): + result = fitted_arima.get_forecast_uncertainty(horizon=10) + assert set(result.columns) >= {"ds", "yhat", "yhat_lower", "yhat_upper"} + assert len(result) == 10 + assert (result["yhat_upper"] >= result["yhat_lower"]).all() + + +def test_arima_forecast_components_columns(fitted_arima): + result = fitted_arima.get_forecast_components(horizon=10) + assert "ds" in result.columns + assert "trend" in result.columns + assert "residual" in result.columns + assert len(result) == 10 + + +def test_arima_save_and_load_preserves_predictions(daily_series): + model = StatsmodelsARIMAModel(p=1, d=1, q=1, trend="n") + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + ) + preds_before = model.predict(periods=5) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "arima_model.pkl") + model.save(path) + + loaded = StatsmodelsARIMAModel() + loaded.load(path) + + preds_after = loaded.predict(periods=5) + np.testing.assert_array_almost_equal(preds_before, preds_after) + + +def test_arima_predict_before_fit_raises(): + model = StatsmodelsARIMAModel() + with pytest.raises(ValueError, match="not fitted"): + model.predict(periods=5) + + +@pytest.mark.parametrize("order", [(1, 0, 0), (0, 1, 1), (2, 1, 0)]) +def test_arima_different_orders_fit(daily_series, order): + """Various ARIMA orders should fit without error.""" + p, d, q = order + model = StatsmodelsARIMAModel(p=p, d=d, q=q, trend="n") + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + ) + assert model.model_fit is not None + + +# --------------------------------------------------------------------------- +# ProphetModel +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def fitted_prophet(daily_series): + model = ProphetModel() + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + ) + return model + + +def test_prophet_fit_stores_model(fitted_prophet): + assert fitted_prophet.model is not None + + +def test_prophet_outsample_forecast_shape(fitted_prophet): + preds = fitted_prophet.predict(periods=30) + assert isinstance(preds, np.ndarray) + assert len(preds) == 30 + assert not np.all(np.isnan(preds)) + + +def test_prophet_insample_predict_shape(daily_series, fitted_prophet): + x_slice = daily_series["x_df"].iloc[:20].copy() + preds = fitted_prophet.predict(x_pred=x_slice) + assert len(preds) == len(x_slice) + + +def test_prophet_forecast_uncertainty_columns(fitted_prophet): + result = fitted_prophet.get_forecast_uncertainty(horizon=14) + assert set(result.columns) >= {"ds", "yhat", "yhat_lower", "yhat_upper"} + assert len(result) == 14 + assert (result["yhat_upper"] >= result["yhat_lower"]).all() + + +def test_prophet_forecast_components_columns(fitted_prophet): + result = fitted_prophet.get_forecast_components(horizon=14) + assert "ds" in result.columns + assert "trend" in result.columns + assert len(result) == 14 + + +def test_prophet_save_and_load_preserves_predictions(daily_series): + model = ProphetModel() + model.fit( + daily_series["x"], + daily_series["y"], + temporal_metadata=daily_series["metadata"], + ) + preds_before = model.predict(periods=7) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "prophet_model.pkl") + model.save(path) + + loaded = ProphetModel() + loaded.load(path) + + preds_after = loaded.predict(periods=7) + np.testing.assert_array_almost_equal(preds_before, preds_after) diff --git a/tests/back/tasks/test_tasks.py b/tests/back/tasks/test_tasks.py index ad5847c1f..287df92ec 100644 --- a/tests/back/tasks/test_tasks.py +++ b/tests/back/tasks/test_tasks.py @@ -1,6 +1,7 @@ import os import pathlib +import pandas as pd import PIL import pytest from datasets import DatasetDict @@ -15,6 +16,7 @@ from DashAI.back.dataloaders.classes.json_dataloader import JSONDataLoader from DashAI.back.dependencies.database.models import ProcessData from DashAI.back.tasks.controlnet_task import ControlNetTask +from DashAI.back.tasks.forecasting_task import ForecastingTask from DashAI.back.tasks.tabular_classification_task import TabularClassificationTask from DashAI.back.tasks.text_classification_task import TextClassificationTask from DashAI.back.tasks.text_to_image_generation_task import TextToImageGenerationTask @@ -142,6 +144,31 @@ def test_get_tabular_class_task_metadata(): assert metadata["outputs_cardinality"] == 1 +def test_prepare_forecasting_task_accepts_singular_column_aliases(): + dataset = to_dashai_dataset( + pd.DataFrame( + { + "date": pd.date_range("2025-01-01", periods=5, freq="D").astype(str), + "temperature_C": [20.5, 21.0, 19.8, 22.1, 20.9], + } + ) + ) + forecasting_task = ForecastingTask() + + prepared = forecasting_task.prepare_for_task( + dataset=dataset, + input_columns=["date"], + output_columns=["temperature_C"], + ) + + temporal_metadata = forecasting_task.get_temporal_metadata() + + assert prepared.num_rows == 5 + assert temporal_metadata is not None + assert temporal_metadata["timestamp_col"] == "date" + assert temporal_metadata["target_col"] == "temperature_C" + + @pytest.fixture(scope="module", name="text_classification_dataset") def text_classification_dataset_fixture(): test_dataset_path = "tests/back/tasks/ImdbSentimentDatasetSmall.json"