Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions DashAI/back/initial_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@
from DashAI.back.metrics.translation.bleu import Bleu
from DashAI.back.metrics.translation.chrf import Chrf
from DashAI.back.metrics.translation.ter import Ter
from DashAI.back.models.cnn_image_classifier import CNNImageClassifier
from DashAI.back.models.efficientnet_b0_image_classifier import (
EfficientNetB0ImageClassifier,
)

# Models
from DashAI.back.models.hugging_face.albert_transformer import AlbertTransformer
Expand Down Expand Up @@ -203,7 +207,10 @@
XlmRobertaTransformer,
)
from DashAI.back.models.hugging_face.xlnet_transformer import XlnetTransformer
from DashAI.back.models.lenet5_image_classifier import LeNet5ImageClassifier
from DashAI.back.models.mlp_image_classifier import MLPImageClassifier
from DashAI.back.models.resnet18_image_classifier import ResNet18ImageClassifier
from DashAI.back.models.resnet50_image_classifier import ResNet50ImageClassifier
from DashAI.back.models.scikit_learn.adaboost_classifier import AdaBoostClassifier
from DashAI.back.models.scikit_learn.adaboost_regression import AdaBoostRegression
from DashAI.back.models.scikit_learn.bagging_classifier import BaggingClassifier
Expand Down Expand Up @@ -380,6 +387,11 @@ def get_initial_components():
XlmRobertaTransformer,
XlnetTransformer,
MLPImageClassifier,
CNNImageClassifier,
LeNet5ImageClassifier,
ResNet18ImageClassifier,
ResNet50ImageClassifier,
EfficientNetB0ImageClassifier,
# Dataloaders
ARFFDataLoader,
CSVDataLoader,
Expand Down
13 changes: 13 additions & 0 deletions DashAI/back/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Base Model abstract class."""

import logging
import math
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Final, final

Expand All @@ -12,6 +14,8 @@
if TYPE_CHECKING:
from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset

logger = logging.getLogger(__name__)


class BaseModel(ConfigObject, metaclass=ABCMeta):
"""Abstract base class for all machine learning models in DashAI.
Expand Down Expand Up @@ -277,6 +281,15 @@ def calculate_metrics(
results = {}
for metric in metrics:
score = metric.score(y_transformed, y_pred)
if not math.isfinite(score):
logger.warning(
"Metric %s returned a non-finite value (%s) for split %s "
"(e.g. only one class present in the split). Skipping.",
metric.__name__,
score,
split,
)
continue
results[metric.__name__] = score

# Save to database
Expand Down
Loading
Loading