-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathworker_train.lua
More file actions
160 lines (139 loc) · 4.53 KB
/
worker_train.lua
File metadata and controls
160 lines (139 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
require "love.image"
require "love.math"
local hyperneat = require "hyperneat"
local neat = require "neat/neat"
-- Arguments handed to this thread, received through the chunk's `...`:
-- the channel results are pushed to, the seed genome, and the dataset list.
local result_channel, start_genome, dataset = ...

-- Run-control constants.
local ITERATIONS = 32 -- NOTE(review): not referenced anywhere in this file
local POPULATION_SIZE = 64 -- number of substrates evolved in parallel
local TARGET_ACCURACY = 0.8 -- training stops once any individual reaches this
local MAX_GENS = 1000 -- hard cap on the number of generations

-- Settings for the inner NEAT genomes (stored as genome.settings by
-- clone_genome and handed to neat.mutate).
local neat_settings = {
  input_count = 7,
  output_count = 1,
  hidden_layers_count = 2,
  nodes_per_layer = 8,
  mutation_activation_prob = 0.3,
  mutation_add_node_prob = 0.1,
  mutation_add_conn_prob = 0.1,
}

-- Substrate-level settings passed to the hyperneat module.
local hyperneat_settings = {
  input_res = 28,
  output_count = 10,
  weight_range = 5.0,
  neat_settings = neat_settings,
}
--- Deep-copy a purified genome into a fresh, independently mutable genome.
-- Node records keep only `id`, `type` and `activation`; connection records
-- keep `in_node`, `out_node`, `weight`, `innovation` and `enabled`.
-- The settings table is shared by reference, not copied.
-- @tparam table source genome with `nodes` and `connections` sequences
-- @treturn table the new working genome (fitness defaults to 0, settings
--   default to hyperneat_settings.neat_settings)
local function clone_genome(source)
  local fitness = source.fitness or 0
  local settings = source.settings or hyperneat_settings.neat_settings

  local node_copies = {}
  for _, n in ipairs(source.nodes) do
    node_copies[#node_copies + 1] = {
      id = n.id,
      type = n.type,
      activation = n.activation,
    }
  end

  local conn_copies = {}
  for _, c in ipairs(source.connections) do
    conn_copies[#conn_copies + 1] = {
      in_node = c.in_node,
      out_node = c.out_node,
      weight = c.weight,
      innovation = c.innovation,
      enabled = c.enabled,
    }
  end

  return {
    nodes = node_copies,
    connections = conn_copies,
    fitness = fitness,
    settings = settings,
  }
end
-- Build the starting population: POPULATION_SIZE substrates, each holding its
-- own deep copy of start_genome that is mutated once so the individuals get
-- some initial variation.
local population = {}
for slot = 1, POPULATION_SIZE do
  local genome = clone_genome(start_genome)
  local substrate = hyperneat.create_2dsubstrate(hyperneat_settings)
  substrate.genome = genome
  neat.mutate(genome, genome.settings)
  population[slot] = substrate
end
--- Decode every dataset image into a flat list of network inputs.
-- Pixels are read row by row (y outer, x inner). Only the alpha channel is
-- used, mapped with `a * 2 - 1`. This assumes LÖVE 11+ getPixel components
-- in [0, 1], which gives inputs in [-1, 1].
-- @tparam table dataset sequence of `{ path = <file path>, label = <class> }`
-- @treturn table sequence of `{ inputs = {number...}, label = <class> }`,
--   in the same order as `dataset`
local function gather_dataset_inputs(dataset)
  local processed = {}
  for i, item in ipairs(dataset) do
    local file_data = love.filesystem.newFileData(item.path)
    local image_data = love.image.newImageData(file_data)
    local w, h = image_data:getDimensions()

    local inputs = {}
    local k = 0
    for y = 0, h - 1 do
      for x = 0, w - 1 do
        local _, _, _, a = image_data:getPixel(x, y)
        k = k + 1
        inputs[k] = a * 2 - 1
      end
    end

    -- Free the decoded pixels and raw file bytes now rather than whenever
    -- the garbage collector gets to them. When decoding a whole dataset in a
    -- loop, these otherwise pile up. (Assigning nil to locals that are about
    -- to go out of scope frees nothing.)
    image_data:release()
    file_data:release()

    processed[i] = { inputs = inputs, label = item.label }
  end
  return processed
end
-- Decode every image once up front; every generation is scored on this set.
local train_data = gather_dataset_inputs(dataset)

-- Accuracy below is correct / #train_data. With no samples this is 0/0 (NaN).
-- NaN never compares greater than anything, so best_acc would stay 0 and the
-- loop would run all MAX_GENS generations for nothing. Fail loudly instead.
if #train_data == 0 then
  error("worker_train: dataset produced no training samples")
end

-- Number of network outputs (classes) scored per sample.
local class_count = hyperneat_settings.output_count

-- Index of the largest of values[1..count] (missing entries read as 0).
-- Ties keep the earliest index. Returns nil if no entry exceeds -math.huge
-- (e.g. all NaN), so such an output never counts as a correct prediction.
local function argmax(values, count)
  local best_idx, best_val = nil, -math.huge
  for n = 1, count do
    local v = values[n] or 0
    if v > best_val then
      best_idx, best_val = n, v
    end
  end
  return best_idx
end

-- Score one substrate on every sample and store total_error, correct_numbers
-- and fitness on it, all computed fresh for this call. Returns accuracy.
local function evaluate_substrate(substrate, samples)
  local total_error, correct, fitness = 0, 0, 0
  for _, sample in ipairs(samples) do
    local results = hyperneat.evaluate(substrate, sample.inputs)
    -- The label is used directly as an index into results (1..class_count).
    local label = sample.label

    -- The network's output for the correct class is the primary fitness driver.
    fitness = fitness + (results[label] or 0)

    -- Sum of squared errors against a one-hot target, used for tie-breaking.
    local error_sum = 0
    for n = 1, class_count do
      local target = (n == label) and 1.0 or 0.0
      local diff = target - (results[n] or 0)
      error_sum = error_sum + diff * diff
    end
    total_error = total_error + error_sum

    if argmax(results, class_count) == label then
      correct = correct + 1
    end
  end

  substrate.total_error = total_error
  substrate.correct_numbers = correct
  -- Overwrite rather than add. Like total_error and correct_numbers, fitness
  -- must describe only this evaluation, so a substrate that survives into
  -- the next generation does not keep inflating its score.
  substrate.fitness = fitness
  return correct / #samples
end

local gen = 0
local best_acc = 0
while best_acc < TARGET_ACCURACY and gen < MAX_GENS do
  gen = gen + 1

  -- Best training accuracy reached by any individual this generation.
  local current_max_acc = 0
  for _, substrate in ipairs(population) do
    local acc = evaluate_substrate(substrate, train_data)
    if acc > current_max_acc then
      current_max_acc = acc
    end
  end
  best_acc = current_max_acc

  -- Evolve only if another generation will be evaluated. After the last
  -- pass the just-scored population is kept, so its stats stay valid for
  -- the selection step that follows the loop.
  if best_acc < TARGET_ACCURACY and gen < MAX_GENS then
    population = hyperneat.evolve_population(population, hyperneat_settings)
  end
end
-- Pick the best performer of the last evaluated generation and report it.
-- The training loop skips evolve_population after its final evaluation, so
-- correct_numbers and fitness on each substrate describe this population.
local n_samples = #train_data
local best_substrate = population[1]
local max_fitness = -math.huge
for _, substrate in ipairs(population) do
  -- Same denominator as the training loop. Guard the 0/0 (NaN) case.
  local accuracy = 0
  if n_samples > 0 then
    accuracy = substrate.correct_numbers / n_samples
  end

  -- Final score = 2 * accuracy + the fitness accumulated during evaluation
  -- (the correct-class output summed over all samples, not averaged).
  -- NOTE(review): the accuracy term is capped at 2 while the summed term
  -- scales with the dataset size, so the sum usually dominates. Confirm
  -- this weighting is intended.
  substrate.fitness = (accuracy * 2) + substrate.fitness
  substrate.genome.fitness = substrate.fitness
  if substrate.fitness > max_fitness then
    max_fitness = substrate.fitness
    best_substrate = substrate
  end
end

local best_accuracy = 0
if n_samples > 0 then
  best_accuracy = best_substrate.correct_numbers / n_samples
end

-- Send the winning genome first, then a human-readable summary line.
result_channel:push(neat.purify_genome(best_substrate.genome))
result_channel:push(string.format("Best Acc: %.1f%% | Fitness: %.3f",
  best_accuracy * 100, max_fitness))