Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
b244eb7
Add TimeSeriesWindowConverter for transforming time series data into …
ivan-salas Oct 17, 2025
f1970cc
Merge branch 'develop' into feature/forecasting-task to add new chang…
ivan-salas Oct 17, 2025
a68d6cb
Implement MultiOutputRegression model and associated task; update met…
ivan-salas Oct 18, 2025
f6c79a2
feat(forecasting): soporte nativo de series de tiempo y Prophet en Da…
ivan-salas Oct 20, 2025
ed45fa9
fix: actualizar imports de sMAPE a SMAPE
ivan-salas Oct 20, 2025
4a6f371
feat: Add ExtendTimeSeriesConverter for forecasting preparation
ivan-salas Oct 20, 2025
5267d7d
fix: Add jsonable_encoder for timestamp serialization in dataset endp…
ivan-salas Oct 20, 2025
a0d2e39
Add forecasting explainers and base model classes
ivan-salas Oct 26, 2025
8453e22
Add ARIMA and SARIMAX forecasting models
ivan-salas Oct 26, 2025
29acf87
Register ARIMA and SARIMAX models in initial_components
ivan-salas Oct 26, 2025
b0eb8ed
Add ARIMA and SARIMAX imports to models __init__
ivan-salas Oct 26, 2025
2abba7c
Add SklearnMultiStepForecaster and remove MultiOutputRegressionTask
ivan-salas Oct 26, 2025
f4c7161
Fix ARIMA/SARIMAX default trend parameter
ivan-salas Oct 27, 2025
b5c5920
fix: Pass temporal_metadata to ForecastingTask models in model_job
ivan-salas Oct 27, 2025
0a910f3
feat: Add in-sample predictions support to SklearnMultiStepForecaster
ivan-salas Oct 27, 2025
64facab
fix: Use full training history for in-sample predictions
ivan-salas Oct 27, 2025
fb1b03c
fix: Filter NaN values before computing metrics
ivan-salas Oct 27, 2025
7de822e
Fix Prophet gaps handling and optimizer metric directions
ivan-salas Nov 3, 2025
1a76999
Remove test notebook from Git tracking (keep in gitignore)
ivan-salas Nov 3, 2025
53d60ff
feat: Add auto-generate forecast periods feature for predictions
ivan-salas Nov 3, 2025
32d54e3
feat: Add forecasting task support
ivan-salas Nov 30, 2025
80ae821
feat: Enhance forecasting capabilities in explainer modals and predic…
ivan-salas Dec 1, 2025
a5766d5
feat: Improve floating point comparison for dataset splits in tempora…
ivan-salas Mar 2, 2026
e41b60e
Merge remote-tracking branch 'upstream/develop' into feature/forecast…
ivan-salas Mar 10, 2026
81f56da
feat: Enhance column validation with error messaging for forecasting …
ivan-salas Mar 11, 2026
1068e47
fix: forecasting_job fixes
ivan-salas Mar 11, 2026
e934b1b
Enhance forecasting models and prediction handling
Mar 17, 2026
791d70e
fix: pre-commit mistakes
Mar 17, 2026
f9ff13d
add: prophet and statsmodels to requirements.
Mar 17, 2026
c00124b
feat: Add comprehensive tests for SklearnMultiStepForecaster and Stat…
Mar 17, 2026
95d9210
Merge upstream/develop into feature/forecasting-task
Mar 31, 2026
4946ee6
fix: correct import errors introduced during upstream merge
Mar 31, 2026
4a368b3
fix: quote DashAIDataset annotations in PredictJob method signatures
Mar 31, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,6 @@ job_queue.db-wal

db.sqlite
trained_models/


test_forecasting_models.ipynb
186 changes: 186 additions & 0 deletions DashAI/back/api/api_v1/endpoints/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,189 @@ async def get_info(
return info


@router.get("/{dataset_id}/temporal-info")
@inject
async def get_temporal_info(
dataset_id: int,
timestamp_column: str,
session_factory: "sessionmaker" = Depends(lambda: di["session_factory"]),
):
"""Get temporal information about a dataset for forecasting tasks.

This endpoint analyzes a timestamp column to detect frequency, date range,
and other temporal characteristics useful for time series forecasting.

Parameters
----------
dataset_id : int
ID of the dataset to analyze.
timestamp_column : str
Name of the column containing timestamps.

Returns
-------
dict
Dictionary with temporal information including:
- frequency_code: Short code (D, H, M, W, A, T)
- frequency_label: Human-readable label
- frequency_description: Detailed description
- start_date: First timestamp in the series
- end_date: Last timestamp in the series
- total_periods: Number of data points
- detected_gaps: Number of missing periods detected
"""
import os

import pandas as pd
import pyarrow.ipc as ipc

from DashAI.back.core.enums.status import DatasetStatus

with session_factory() as db:
try:
dataset = db.get(Dataset, dataset_id)
if not dataset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Dataset not found",
)

if dataset.status != DatasetStatus.FINISHED:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Dataset is not in finished state",
)

# Load the dataset
dataset_path = f"{dataset.file_path}/dataset"
data_filepath = os.path.join(dataset_path, "data.arrow")

with pa.OSFile(data_filepath, "rb") as source:
reader = ipc.open_file(source)
table = reader.read_all()

data_frame = table.to_pandas()

if timestamp_column not in data_frame.columns:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Column '{timestamp_column}' not found in dataset",
)

# Convert to datetime
try:
timestamps = pd.to_datetime(data_frame[timestamp_column])
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot parse '{timestamp_column}' as datetime: {str(e)}",
) from e

# Sort and analyze
sorted_ts = timestamps.sort_values()
diffs = sorted_ts.diff().dropna()

if len(diffs) == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Not enough data points to detect frequency",
)

# Get most common difference (mode)
mode_diff = (
diffs.mode().iloc[0] if len(diffs.mode()) > 0 else diffs.median()
)

# Frequency mapping with detailed info
frequency_map = {
"T": {
"code": "T",
"label": "Minutely",
"description": "Each row represents one minute",
"example": "e.g., 10:00, 10:01, 10:02...",
},
"H": {
"code": "H",
"label": "Hourly",
"description": "Each row represents one hour",
"example": "e.g., 10:00, 11:00, 12:00...",
},
"D": {
"code": "D",
"label": "Daily",
"description": "Each row represents one day",
"example": "e.g., Jan 1, Jan 2, Jan 3...",
},
"W": {
"code": "W",
"label": "Weekly",
"description": "Each row represents one week",
"example": "e.g., Week 1, Week 2, Week 3...",
},
"M": {
"code": "M",
"label": "Monthly",
"description": "Each row represents one month",
"example": "e.g., Jan, Feb, Mar...",
},
"A": {
"code": "A",
"label": "Yearly",
"description": "Each row represents one year",
"example": "e.g., 2022, 2023, 2024...",
},
}

# Detect frequency
if mode_diff >= pd.Timedelta(days=365):
freq_code = "A"
elif mode_diff >= pd.Timedelta(days=28):
freq_code = "M"
elif mode_diff >= pd.Timedelta(days=7):
freq_code = "W"
elif mode_diff >= pd.Timedelta(days=1):
freq_code = "D"
elif mode_diff >= pd.Timedelta(hours=1):
freq_code = "H"
else:
freq_code = "T"

freq_info = frequency_map[freq_code]

# Calculate average difference in human-readable format
avg_diff = diffs.mean()
if avg_diff >= pd.Timedelta(days=1):
avg_diff_str = f"{avg_diff.days} days"
elif avg_diff >= pd.Timedelta(hours=1):
avg_diff_str = f"{avg_diff.seconds // 3600} hours"
else:
avg_diff_str = f"{avg_diff.seconds // 60} minutes"

# Detect gaps (periods where diff is significantly larger than mode)
gap_threshold = mode_diff * 1.5
gaps = (diffs > gap_threshold).sum()

return {
"frequency_code": freq_info["code"],
"frequency_label": freq_info["label"],
"frequency_description": freq_info["description"],
"frequency_example": freq_info["example"],
"average_interval": avg_diff_str,
"start_date": sorted_ts.min().isoformat(),
"end_date": sorted_ts.max().isoformat(),
"total_periods": len(data_frame),
"detected_gaps": int(gaps),
"timestamp_column": timestamp_column,
}

except exc.SQLAlchemyError as e:
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal database error",
) from e


@router.get("/{dataset_id}/model-sessions-exist")
@inject
async def get_model_sessions_exist(
Expand Down Expand Up @@ -1264,6 +1447,9 @@ async def get_dataset_file(
col: sliced_batch[col][j].as_py()
for col in sliced_batch.schema.names
}
# Use jsonable_encoder to handle Timestamp and other
# non-JSON-serializable types
row = jsonable_encoder(row)
rows.append(row)
rows_collected += 1
if rows_collected >= page_size:
Expand Down
18 changes: 17 additions & 1 deletion DashAI/back/api/api_v1/endpoints/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,24 @@ async def upload_global_explainer(
status_code=status.HTTP_404_NOT_FOUND, detail="Run not found"
)

# Check if name exists and append suffix if needed
base_name = params.name
counter = 1
new_name = base_name

while True:
existing = db.scalars(
select(GlobalExplainer).where(GlobalExplainer.name == new_name)
).first()

if not existing:
break

counter += 1
new_name = f"{base_name}_{counter}"

explainer = GlobalExplainer(
name=params.name,
name=new_name,
run_id=params.run_id,
explainer_name=params.explainer_name,
parameters=params.parameters,
Expand Down
2 changes: 2 additions & 0 deletions DashAI/back/api/api_v1/endpoints/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ async def get_job_details(


@router.post("/", status_code=status.HTTP_201_CREATED)
@router.post("/start", status_code=status.HTTP_201_CREATED)
@router.post("/start/", status_code=status.HTTP_201_CREATED)
@inject
async def enqueue_job(
request: Request,
Expand Down
9 changes: 9 additions & 0 deletions DashAI/back/api/api_v1/endpoints/model_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ async def validate_columns(
validation_response = {}

try:
# For ForecastingTask, validate BEFORE prepare_for_task
# (column names change after prepare, so validate with originals)
if params.task_name == "ForecastingTask":
task.validate_dataset_for_task(
dataset=minimal_dataset,
dataset_name=dataset.name,
input_columns=inputs_names,
output_columns=outputs_names,
)
task.prepare_for_task(
dataset=minimal_dataset,
input_columns=inputs_names,
Expand Down
15 changes: 15 additions & 0 deletions DashAI/back/api/api_v1/endpoints/predict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import math
from typing import TYPE_CHECKING, Dict, List

from fastapi import APIRouter, Depends, Query, Request, status
Expand All @@ -21,6 +22,20 @@

from DashAI.back.dependencies.registry.component_registry import ComponentRegistry


def sanitize_for_json(value):
"""Convert NaN/Inf float values to None for JSON serialization."""
if isinstance(value, dict):
return {k: sanitize_for_json(v) for k, v in value.items()}
elif isinstance(value, list):
return [sanitize_for_json(item) for item in value]
elif isinstance(value, float):
if math.isnan(value) or math.isinf(value):
return None
return value
return value


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

Expand Down
50 changes: 45 additions & 5 deletions DashAI/back/api/api_v1/endpoints/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,22 +327,62 @@ async def delete_run(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Run not found"
)

# Delete orphan-prone operations that are not linked by FK constraints.
global_explainers = (
db.query(GlobalExplainer).filter(GlobalExplainer.run_id == run_id).all()
)
for explainer in global_explainers:
if explainer.plot_path and os.path.exists(explainer.plot_path):
remove_path(explainer.plot_path)
if explainer.explanation_path and os.path.exists(
explainer.explanation_path
):
remove_path(explainer.explanation_path)
db.delete(explainer)

local_explainers = (
db.query(LocalExplainer).filter(LocalExplainer.run_id == run_id).all()
)
for explainer in local_explainers:
if explainer.plots_path and os.path.exists(explainer.plots_path):
remove_path(explainer.plots_path)
if explainer.explanation_path and os.path.exists(
explainer.explanation_path
):
remove_path(explainer.explanation_path)
db.delete(explainer)

for prediction in run.predictions:
if prediction.results_path and os.path.exists(prediction.results_path):
remove_path(prediction.results_path)

for path in [
run.run_path,
run.plot_history_path,
run.plot_slice_path,
run.plot_contour_path,
run.plot_importance_path,
]:
if path and os.path.exists(path):
remove_path(path)

db.delete(run)
if run.status == RunStatus.FINISHED:
os.remove(run.run_path)
db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
except HTTPException:
raise
except exc.SQLAlchemyError as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal database error",
detail=f"Internal database error: {e}",
) from e
except OSError as e:
except (OSError, ValueError) as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to delete directory",
detail=f"Failed to delete run resources: {e}",
) from e


Expand Down
1 change: 1 addition & 0 deletions DashAI/back/api/api_v1/schemas/job_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class JobParams(BaseModel):

job_type: Literal[
"ModelJob",
"ForecastingJob",
"ExplainerJob",
"PredictJob",
"DatasetJob",
Expand Down
12 changes: 11 additions & 1 deletion DashAI/back/api/api_v1/schemas/predict_params.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from pydantic import BaseModel
from typing import Optional

from pydantic import BaseModel, Field


class PredictParams(BaseModel):
run_id: int
forecast_periods: Optional[int] = Field(
default=None,
description="Number of future periods to forecast (ForecastingTask only). "
"If provided, timestamps will be generated automatically from the last "
"training date.",
gt=0,
le=1000,
)


class RenameRequest(BaseModel):
Expand Down
Loading
Loading