-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathwebapp.py
More file actions
193 lines (162 loc) · 6.32 KB
/
Copy pathwebapp.py
File metadata and controls
193 lines (162 loc) · 6.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import base64
import io
import threading

import matplotlib
matplotlib.use('Agg')  # non-interactive backend; must run before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
from flask import Flask, render_template, request
from PIL import Image
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Ensure these paths match your local project structure
from src.multiclass.models import KNNScratch, LinearSVM, OneVsRestClassifier
from src.multiclass.model_files.naive_bayes_pipeline import GaussianNB
from src.multiclass.model_files.decision_tree_pipeline import DecisionTreeScratch
from src.multiclass.model_files.logistic_regression import SoftmaxRegressionScratch
from src.multiclass.model_files.kernel_svm import MulticlassKernelSVM
app = Flask(__name__)

# Process-wide state, filled in lazily by initialize_app().
MODELS: dict = {}  # display name -> trained classifier instance
DATA: dict = {}  # cached test-split arrays read by the index view
# Fitted preprocessing steps; both stay None until initialize_app() runs.
SCALER = PCA_TRANSFORMER = None
def load_mnist():
    """Download ``mnist_784`` from OpenML and split it into train/test sets.

    Pixel values are divided by 255.0 and cast to float32; labels are cast to
    int. 20% of the samples go to the test set, using a fixed
    ``random_state=42`` so the split is reproducible.

    Returns:
        A tuple ``(X_train, X_test, y_train, y_test)``.
    """
    print("Loading MNIST from OpenML...")
    bunch = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    features = bunch.data.astype(np.float32) / 255.0
    labels = bunch.target.astype(int)
    parts = train_test_split(features, labels, test_size=0.2, random_state=42)
    return tuple(parts)
# Serializes the one-time initialization: Flask's development server handles
# requests on multiple threads, and every request calls initialize_app().
_INIT_LOCK = threading.Lock()


def _build_models():
    """Return a fresh, untrained instance of every classifier the app serves."""
    return {
        "KNN": KNNScratch(k=5, weighted=True, batch_size=100),
        "SVM": OneVsRestClassifier(
            estimator_class=LinearSVM,
            learning_rate=0.01,
            lambda_param=0.001,
            n_iters=500,
        ),
        "Naive Bayes": GaussianNB(eps=1e-3),
        "Decision Tree": DecisionTreeScratch(max_depth=12, criterion='entropy'),
        "Logistic Regression": SoftmaxRegressionScratch(
            learning_rate=0.05,
            iterations=1000,
        ),
        "Kernel SVM": MulticlassKernelSVM(
            C=1.0,
            gamma=None,
            tol=1e-3,
            max_passes=5,
            random_state=42,
        ),
    }


def initialize_app():
    """Load MNIST, fit scaler + PCA, cache the test split and train all models.

    Cheap to call on every request: once initialization has succeeded it
    returns immediately. All work is done on local variables, and the module
    globals are published only after every model has trained. ``MODELS`` is
    assigned last because it doubles as the "ready" flag. As a result, an
    exception part-way through leaves the app uninitialized, and the next
    call retries instead of serving untrained models.
    """
    global MODELS, PCA_TRANSFORMER, SCALER
    if MODELS:  # fast path: already initialized
        return
    with _INIT_LOCK:
        if MODELS:  # another thread finished initializing while we waited
            return

        X_train_raw, X_test_raw, y_train, y_test = load_mnist()

        # 1. Scaling (fit on the training split only).
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train_raw)
        X_test_scaled = scaler.transform(X_test_raw)

        # 2. PCA dimensionality reduction to 100 components.
        pca = PCA(n_components=100)
        X_train_pca = pca.fit_transform(X_train_scaled)
        X_test_pca = pca.transform(X_test_scaled)

        # 3. Train on subsets for speed: 8000 samples for most models and
        #    1500 for Kernel SVM.
        X_train_small = X_train_pca[:8000]
        y_train_small = y_train[:8000]
        models = _build_models()
        for name, model in models.items():
            print(f"Training {name}...")
            if name == "Kernel SVM":
                model.fit(X_train_pca[:1500], y_train[:1500])
            else:
                model.fit(X_train_small, y_train_small)

        # 4. Publish shared state; MODELS goes last (see docstring).
        SCALER = scaler
        PCA_TRANSFORMER = pca
        DATA.update(
            x_test_features=X_test_pca,
            y_test=y_test,
            sample_count=len(X_test_pca),
            # Raw pixels (0-1 range) kept for visual display.
            x_test_images=X_test_raw,
        )
        MODELS = models
def generate_base64_plot(img_array):
    """Render a 784-element digit as a grayscale PNG, base64-encoded for HTML.

    Args:
        img_array: Array-like with 784 values. It is reshaped to 28x28 and
            drawn with the 'gray' colormap; matplotlib autoscales the range.

    Returns:
        The PNG bytes encoded as a base64 text string.
    """
    # Use the object-oriented Figure API instead of pyplot. pyplot keeps a
    # global registry of open figures, so if savefig raised before
    # plt.close() the figure leaked. pyplot is also not safe to share across
    # Flask's request threads. A standalone Figure is simply garbage-collected.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(3, 3))
    ax = fig.subplots()
    ax.imshow(np.asarray(img_array).reshape(28, 28), cmap='gray')
    ax.axis('off')
    buf = io.BytesIO()
    # Tight bounding box with zero padding removes the white margins.
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    return base64.b64encode(buf.getvalue()).decode('utf-8')
def preprocess_uploaded_image(image_file):
    """Convert an uploaded image into the models' PCA feature space.

    The image is converted to grayscale ('L'), resized to 28x28 and divided
    by 255.0. It is then passed through the fitted SCALER and PCA_TRANSFORMER,
    the same preprocessing initialize_app() applies to the training data.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object such as the uploaded file from the request).

    Returns:
        ``(features, img_display_b64)``: a ``(1, n_components)`` feature row
        and a base64 PNG of the 28x28 grayscale image. Returns
        ``(None, None)`` if the upload cannot be decoded as an image.
    """
    # Only decoding the untrusted upload is best-effort. Failures in the
    # fitted transformers are programming errors and must not be reported to
    # the user as a bad upload.
    try:
        # `with` closes the decoder; convert()/resize() return independent copies.
        with Image.open(image_file) as img:
            gray = img.convert('L').resize((28, 28))
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # OSError covers PIL's UnidentifiedImageError and truncated files.
        print("Upload Preprocess Error:", e)
        return None, None

    img_array = np.asarray(gray, dtype=np.float32) / 255.0
    # NOTE(review): MNIST digits are light strokes on a dark background; an
    # upload drawn dark-on-light is not inverted here — confirm this is intended.

    # Raw image for display
    img_display_b64 = generate_base64_plot(img_array)

    # Same scaling + PCA as the training data
    feat = SCALER.transform(img_array.reshape(1, -1))
    feat = PCA_TRANSFORMER.transform(feat)
    return feat, img_display_b64
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the demo page and, on POST, classify a single digit image.

    Form fields read from the request:
      * ``model``        -- key into MODELS; defaults to "KNN".
      * ``image``        -- optional uploaded image file; takes precedence
                            over ``sample_index`` when present.
      * ``sample_index`` -- optional index into the held-out test split.

    The template receives the prediction, the true label (test-split samples
    only), an error message, and base64 PNGs in ``selected_image`` /
    ``uploaded_image`` for display.
    """
    initialize_app()
    model_name = request.form.get("model", "KNN")
    sample_index = request.form.get("sample_index", "")
    uploaded_file = request.files.get("image")
    error = None
    prediction = None
    actual = None
    selected_image = None  # Matches HTML: {% if selected_image %}
    uploaded_image = None  # Matches HTML: {% if uploaded_image %}
    model = MODELS.get(model_name)
    has_upload = bool(uploaded_file and uploaded_file.filename)

    if (has_upload or sample_index) and model is None:
        # The model name is client-controlled and may not exist; report it
        # instead of crashing on model.predict() with a 500.
        error = f"Unknown model: {model_name}"
    # --- CASE A: USER UPLOADED AN IMAGE ---
    elif has_upload:
        features, img_b64 = preprocess_uploaded_image(uploaded_file)
        if features is not None:
            prediction = int(model.predict(features)[0])
            uploaded_image = img_b64
        else:
            error = "Could not process the uploaded image."
    # --- CASE B: USER ENTERED AN INDEX ---
    elif sample_index:
        # Only the integer parse belongs in the try: a ValueError raised by
        # predict() must not be reported to the user as a bad index.
        try:
            idx = int(sample_index)
        except ValueError:
            error = "Please enter a valid integer for the sample index."
        else:
            if 0 <= idx < DATA["sample_count"]:
                features = DATA["x_test_features"][idx:idx + 1]
                prediction = int(model.predict(features)[0])
                actual = int(DATA["y_test"][idx])
                # Display the raw pixels (not the PCA features).
                raw_pixels = DATA["x_test_images"][idx]
                selected_image = generate_base64_plot(raw_pixels)
            else:
                error = f"Index must be between 0 and {DATA['sample_count'] - 1}."

    return render_template(
        "index.html",
        model_name=model_name,
        sample_index=sample_index,
        models=list(MODELS.keys()),
        sample_count=DATA.get("sample_count", 0),
        prediction=prediction,
        actual=actual,
        error=error,
        selected_image=selected_image,
        uploaded_image=uploaded_image,
    )
if __name__ == "__main__":
    import os

    # The server binds to all interfaces, and Werkzeug's interactive debugger
    # lets anyone who can reach it execute arbitrary Python. Debug mode is
    # therefore opt-in for local development (FLASK_DEBUG=1), never the default.
    debug = os.environ.get("FLASK_DEBUG", "0").strip().lower() in {"1", "true", "yes"}
    app.run(host="0.0.0.0", port=5001, debug=debug)