-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathclassification_test.lua
More file actions
119 lines (95 loc) · 2.92 KB
/
Copy pathclassification_test.lua
File metadata and controls
119 lines (95 loc) · 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
require 'cunn'
require 'cudnn'
require 'nn'
require 'image'
require 'lfs'
require 'cutorch'
require 'optim'
-- Optional: pin a specific GPU before loading the model, e.g. cutorch.setDevice(1).

-- Command-line interface. The default paths are the original author's local
-- paths and will normally need to be overridden on other machines.
local cmd = torch.CmdLine()
cmd:option('-net', '/media/eunbin/Data2/Dis_model/ggiruk/baseline-epoch-9.net', 'trained model')
cmd:option('-dir', '/home/eunbin/TorchNet/test/test_bad2', 'target directory')
-- Score above which an image is counted as "good" (previously hard-coded 0.5).
cmd:option('-threshold', 0.5, 'score above which an image is classified as good')
local opt = cmd:parse(arg or {})

-- Running tallies updated by the classification loop and reported at the end.
local num_bad = 0
local num_good = 0

print(opt.net)
print(opt.dir)

-- Fail fast on a bad target directory, before paying for the model load.
assert(lfs.attributes(opt.dir, 'mode') == 'directory',
       'target directory does not exist or is not a directory: ' .. opt.dir)

-- Load the serialized network, move it to the GPU and switch it to inference
-- mode (affects layers such as dropout / batch normalization).
local model = torch.load(opt.net):cuda()
model:evaluate()
-- require 'hdf5'
-- print('Loading test data...')
-- testFile = hdf5.open('discriminator_test.h5', 'r')
-- testData = testFile:all()
-- testFile:close()
-- dataset = testData
-- data = dataset.data
-- label = dataset.label
-- dataset.size = data:size()[1]
-- batchSize = 32
-- classes = {'o', 'x'}
-- confusion = optim.ConfusionMatrix(classes)
-- print('<trainer> on testing Set:')
-- for t = 1,dataset.size,batchSize do
-- xlua.progress(t, dataset.size)
-- -- prepare input batch
-- local inputs = data:sub(t,math.min(t+batchSize-1, dataset.size))
-- inputs = inputs:cuda()
-- local targets = label:sub(t,math.min(t+batchSize-1, dataset.size))+1
-- --local targets_tl = targets:narrow(4,1,1)
-- --local targets_br = targets:narrow(4,2,1)
-- --local targets_both = {targets_tl, targets_br}
-- targets = targets:cuda()
-- local outputs = model:forward(inputs)
-- for i = 1,batchSize do
-- confusion:add(outputs[i]:view(-1), targets[i]:view(-1)[1])
-- print(outputs[i]:view(-1), targets[i]:view(-1)[1])
-- --confusion_br:add(outputs[2][i]:view(-1), targets_both[2][i]:view(-1)[1])
-- end
-- print(outputs)
-- end
-- print(confusion)
-- --print(confusion_br)
-- --testLogger:add{['% mean class accuracy (test set, top left)'] = confusion_tl.totalValid * 100,
-- -- ['% mean class accuracy (test set, bottom right)'] = confusion_br.totalValid * 100}
-- confusion:zero()
-- Classify every regular file in opt.dir with the loaded model and print one
-- "<file> : <result>" line per image, where result is 1 (good) or 0 (bad),
-- updating num_good / num_bad accordingly.
--
-- Each image is loaded as 3-channel float, resized to 224x224 (bicubic), given
-- a leading batch dimension of 1 and forwarded on the GPU. output[1][1] is
-- compared against the threshold. This assumes the network emits a single
-- "good" score per sample (e.g. a sigmoid output) -- TODO confirm against the
-- model. If it instead emits per-class scores, an argmax decision
-- (torch.max(output, 2)) would be the appropriate rule.
--
-- The threshold is opt.threshold when the CLI defines it, otherwise 0.5.
-- Files are visited in sorted order so runs are reproducible (lfs.dir order is
-- filesystem-dependent). Files that image.load cannot decode are reported on
-- stderr and skipped instead of aborting the whole run.
do
  local threshold = opt.threshold or 0.5
  -- Torch may run on LuaJIT (global `unpack`) or Lua 5.2+ (`table.unpack`).
  local unpack_dims = table.unpack or unpack

  -- Collect regular files only; lfs.dir also yields ".", ".." and subdirectories.
  local files = {}
  for name in lfs.dir(opt.dir) do
    if lfs.attributes(opt.dir .. '/' .. name, 'mode') == 'file' then
      files[#files + 1] = name
    end
  end
  table.sort(files)

  for _, file in ipairs(files) do
    local ok, img = pcall(image.load, opt.dir .. '/' .. file, 3, 'float')
    if not ok then
      io.stderr:write('skipping ' .. file .. ' : ' .. tostring(img) .. '\n')
    else
      local input = image.scale(img, '224x224', 'bicubic')
      -- Prepend a batch dimension of size 1: CxHxW -> 1xCxHxW.
      local batch = input:view(1, unpack_dims(input:size():totable()))
      local output = model:forward(batch:cuda())
      local result
      if output[1][1] > threshold then
        num_good = num_good + 1
        result = 1
      else
        num_bad = num_bad + 1
        result = 0
      end
      print(file .. ' : ' .. result)
    end
  end
end
-- Final report: how many images fell on each side of the decision threshold.
local tallies = {
  { 'good', num_good },
  { 'bad', num_bad },
}
for _, row in ipairs(tallies) do
  local label, count = row[1], row[2]
  print(('number of %s : %d'):format(label, count))
end
print('done.')