Skip to content

Commit 5a5f69f

Browse files
committed
Add Mean Shift clustering algorithm in machine_learning/
1 parent e3b01ec commit 5a5f69f

1 file changed

Lines changed: 261 additions & 0 deletions

File tree

machine_learning/mean_shift.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
"""
2+
Mean Shift Clustering
3+
4+
A non-parametric, centroid-based clustering algorithm that does not require
5+
specifying the number of clusters in advance. It works by iteratively shifting
6+
each data point toward the mean of points within a given bandwidth (radius),
7+
until convergence.
8+
9+
How it works:
10+
1. Each point starts as its own candidate centroid.
11+
2. For each candidate, compute the mean of all points within `bandwidth`
12+
distance (the "window").
13+
3. Shift the candidate to that mean.
14+
4. Repeat until candidates stop moving (convergence).
15+
5. Merge candidates that are closer than `bandwidth` to each other.
16+
6. Assign each original point to its nearest final centroid.
17+
18+
Key Properties:
19+
- No need to specify number of clusters (unlike K-Means)
20+
- Can find arbitrarily shaped clusters (like DBSCAN)
21+
- Sensitive to the `bandwidth` parameter
22+
- Deterministic (no random initialization)
23+
24+
Time Complexity: O(n² * iterations) with brute-force window search
25+
Space Complexity: O(n)
26+
27+
References:
28+
- https://en.wikipedia.org/wiki/Mean_shift
29+
- Comaniciu, D. & Meer, P. "Mean Shift: A Robust Approach Toward
30+
Feature Space Analysis." IEEE TPAMI, 2002.
31+
https://doi.org/10.1109/34.1000236
32+
"""
33+
34+
35+
def euclidean_distance(point_a: list[float], point_b: list[float]) -> float:
36+
"""
37+
Compute the Euclidean distance between two n-dimensional points.
38+
39+
>>> euclidean_distance([0.0, 0.0], [3.0, 4.0])
40+
5.0
41+
>>> euclidean_distance([1.0, 1.0], [1.0, 1.0])
42+
0.0
43+
>>> euclidean_distance([0.0], [5.0])
44+
5.0
45+
>>> euclidean_distance([0.0, 0.0], [1.0])
46+
Traceback (most recent call last):
47+
...
48+
ValueError: Both points must have the same number of dimensions.
49+
"""
50+
if len(point_a) != len(point_b):
51+
raise ValueError("Both points must have the same number of dimensions.")
52+
return sum((a - b) ** 2 for a, b in zip(point_a, point_b)) ** 0.5
53+
54+
55+
def get_points_within_bandwidth(
56+
data: list[list[float]], center: list[float], bandwidth: float
57+
) -> list[list[float]]:
58+
"""
59+
Return all points in data that lie within `bandwidth` distance of `center`.
60+
61+
>>> data = [[0.0, 0.0], [0.5, 0.5], [5.0, 5.0]]
62+
>>> get_points_within_bandwidth(data, [0.0, 0.0], 1.0)
63+
[[0.0, 0.0], [0.5, 0.5]]
64+
>>> get_points_within_bandwidth(data, [5.0, 5.0], 1.0)
65+
[[5.0, 5.0]]
66+
>>> get_points_within_bandwidth(data, [0.0, 0.0], 10.0)
67+
[[0.0, 0.0], [0.5, 0.5], [5.0, 5.0]]
68+
"""
69+
return [
70+
point for point in data if euclidean_distance(point, center) <= bandwidth
71+
]
72+
73+
74+
def compute_mean(points: list[list[float]]) -> list[float]:
75+
"""
76+
Compute the element-wise mean of a list of points.
77+
78+
>>> compute_mean([[1.0, 2.0], [3.0, 4.0]])
79+
[2.0, 3.0]
80+
>>> compute_mean([[0.0, 0.0, 0.0]])
81+
[0.0, 0.0, 0.0]
82+
>>> compute_mean([])
83+
Traceback (most recent call last):
84+
...
85+
ValueError: Cannot compute mean of empty list.
86+
"""
87+
if not points:
88+
raise ValueError("Cannot compute mean of empty list.")
89+
n_dims = len(points[0])
90+
return [sum(point[dim] for point in points) / len(points) for dim in range(n_dims)]
91+
92+
93+
def shift_point(
94+
point: list[float], data: list[list[float]], bandwidth: float
95+
) -> list[float]:
96+
"""
97+
Shift a single point to the mean of all data points within `bandwidth`.
98+
99+
If no points fall within the bandwidth, the point remains unchanged.
100+
101+
>>> data = [[1.0, 1.0], [1.5, 1.5], [10.0, 10.0]]
102+
>>> shift_point([1.0, 1.0], data, 2.0)
103+
[1.25, 1.25]
104+
>>> shift_point([10.0, 10.0], data, 1.0)
105+
[10.0, 10.0]
106+
"""
107+
neighbors = get_points_within_bandwidth(data, point, bandwidth)
108+
if not neighbors:
109+
return point
110+
return compute_mean(neighbors)
111+
112+
113+
def has_converged(
114+
old_point: list[float], new_point: list[float], tolerance: float
115+
) -> bool:
116+
"""
117+
Check whether a point has converged (moved less than `tolerance`).
118+
119+
>>> has_converged([1.0, 1.0], [1.0000001, 1.0000001], 1e-4)
120+
True
121+
>>> has_converged([1.0, 1.0], [1.5, 1.5], 1e-4)
122+
False
123+
"""
124+
return euclidean_distance(old_point, new_point) < tolerance
125+
126+
127+
def merge_centroids(
128+
centroids: list[list[float]], bandwidth: float
129+
) -> list[list[float]]:
130+
"""
131+
Merge centroids that are within `bandwidth` distance of each other.
132+
133+
Iterates through centroids and greedily merges any that are close enough,
134+
keeping the first encountered as the representative.
135+
136+
>>> centroids = [[1.0, 1.0], [1.1, 1.1], [10.0, 10.0]]
137+
>>> merged = merge_centroids(centroids, 1.0)
138+
>>> len(merged)
139+
2
140+
>>> centroids = [[0.0, 0.0], [5.0, 5.0], [10.0, 10.0]]
141+
>>> len(merge_centroids(centroids, 1.0))
142+
3
143+
"""
144+
merged: list[list[float]] = []
145+
for centroid in centroids:
146+
if all(
147+
euclidean_distance(centroid, existing) >= bandwidth
148+
for existing in merged
149+
):
150+
merged.append(centroid)
151+
return merged
152+
153+
154+
def mean_shift(
155+
data: list[list[float]],
156+
bandwidth: float,
157+
max_iterations: int = 300,
158+
tolerance: float = 1e-4,
159+
) -> list[int]:
160+
"""
161+
Perform Mean Shift clustering on a dataset.
162+
163+
Args:
164+
data: List of n-dimensional data points.
165+
bandwidth: Radius of the window used to compute the mean.
166+
Must be greater than 0.
167+
max_iterations: Maximum number of shift iterations per point.
168+
Must be at least 1.
169+
tolerance: Convergence threshold — stop shifting when movement
170+
is smaller than this value. Must be greater than 0.
171+
172+
Returns:
173+
A list of integer cluster labels, one per input point.
174+
Cluster IDs start from 0.
175+
176+
Raises:
177+
ValueError: If data is empty.
178+
ValueError: If bandwidth is not positive.
179+
ValueError: If max_iterations is less than 1.
180+
ValueError: If tolerance is not positive.
181+
182+
Example — two well-separated clusters:
183+
>>> data = [
184+
... [1.0, 1.0], [1.2, 1.0], [1.0, 1.2],
185+
... [9.0, 9.0], [9.2, 9.0], [9.0, 9.2],
186+
... ]
187+
>>> labels = mean_shift(data, bandwidth=2.0)
188+
>>> len(set(labels)) # two clusters
189+
2
190+
>>> labels[0] == labels[1] == labels[2] # first group same cluster
191+
True
192+
>>> labels[3] == labels[4] == labels[5] # second group same cluster
193+
True
194+
>>> labels[0] != labels[3] # different clusters
195+
True
196+
197+
Example — single cluster (all points close together):
198+
>>> data = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [0.1, 0.1]]
199+
>>> labels = mean_shift(data, bandwidth=2.0)
200+
>>> len(set(labels))
201+
1
202+
203+
Example — invalid inputs:
204+
>>> mean_shift([], bandwidth=1.0)
205+
Traceback (most recent call last):
206+
...
207+
ValueError: Data must not be empty.
208+
>>> mean_shift([[1.0, 2.0]], bandwidth=0.0)
209+
Traceback (most recent call last):
210+
...
211+
ValueError: Bandwidth must be greater than 0.
212+
>>> mean_shift([[1.0, 2.0]], bandwidth=1.0, max_iterations=0)
213+
Traceback (most recent call last):
214+
...
215+
ValueError: max_iterations must be at least 1.
216+
>>> mean_shift([[1.0, 2.0]], bandwidth=1.0, tolerance=0.0)
217+
Traceback (most recent call last):
218+
...
219+
ValueError: Tolerance must be greater than 0.
220+
"""
221+
if not data:
222+
raise ValueError("Data must not be empty.")
223+
if bandwidth <= 0:
224+
raise ValueError("Bandwidth must be greater than 0.")
225+
if max_iterations < 1:
226+
raise ValueError("max_iterations must be at least 1.")
227+
if tolerance <= 0:
228+
raise ValueError("Tolerance must be greater than 0.")
229+
230+
# each point starts as its own candidate centroid
231+
candidates = [point[:] for point in data]
232+
233+
for _ in range(max_iterations):
234+
new_candidates = [
235+
shift_point(candidate, data, bandwidth) for candidate in candidates
236+
]
237+
if all(
238+
has_converged(old, new, tolerance)
239+
for old, new in zip(candidates, new_candidates)
240+
):
241+
break
242+
candidates = new_candidates
243+
244+
centroids = merge_centroids(candidates, bandwidth)
245+
246+
# assign each original point to its nearest centroid
247+
labels = [
248+
min(
249+
range(len(centroids)),
250+
key=lambda i: euclidean_distance(point, centroids[i]),
251+
)
252+
for point in data
253+
]
254+
255+
return labels
256+
257+
258+
if __name__ == "__main__":
259+
import doctest
260+
261+
doctest.testmod(verbose=True)

0 commit comments

Comments
 (0)