-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
57 lines (43 loc) · 1.93 KB
/
Copy pathmodel.py
File metadata and controls
57 lines (43 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import timm
import torch.nn as nn
import lightning as L
class MirrorGazeModel(L.LightningModule):
    """Gaze regressor: a timm image backbone followed by a small MLP head.

    The backbone is created with ``num_classes=0`` (no classification head),
    so it yields one feature vector per image; the head maps that vector to a
    2-value gaze prediction.  Training, validation and test share ``_step``,
    which logs the MSE loss and a mean Euclidean error scaled to centimetres.

    NOTE: ``mobilenetv4_conv_small.e1200_r224_in1k`` was used as an alternative
    backbone; switching to it requires ``FEATURE_DIM = 1280`` instead of 768.
    """

    # timm model id of the feature extractor.
    BACKBONE_NAME = "fastvit_t8.apple_dist_in1k"
    # Width of the backbone's output features; must match BACKBONE_NAME.
    FEATURE_DIM = 768
    # Width of the hidden layer of the regression head.
    HIDDEN_DIM = 128
    # Scale applied to the Euclidean error to report it in centimetres.
    # Original note: "14.3915 cm is 871.6 px" -- presumably the targets are
    # normalised so that one unit spans that length; confirm against the
    # dataset code.
    GAZE_UNITS_TO_CM = 14.3915

    def __init__(self, pretrained: bool = True, learning_rate: float = 3e-5):
        """Build the backbone, the regression head and the loss.

        Args:
            pretrained: Forwarded to ``timm.create_model`` to select
                pretrained backbone weights.
            learning_rate: Step size handed to Adam in
                ``configure_optimizers``.  Defaults to the previously
                hard-coded value, so existing callers are unaffected.
        """
        super().__init__()
        # num_classes=0 means no classification head.
        self.backbone = timm.create_model(
            self.BACKBONE_NAME,
            pretrained=pretrained,
            num_classes=0,
        )
        self.mlp = nn.Sequential(
            nn.Linear(self.FEATURE_DIM, self.HIDDEN_DIM),
            nn.GELU(),
            nn.Linear(self.HIDDEN_DIM, 2),
        )
        self.loss_fn = nn.MSELoss()
        self.learning_rate = learning_rate

    def forward(self, rgb):
        """Return the 2-value gaze prediction for a batch of images."""
        features = self.backbone(rgb)  # (batch_size, FEATURE_DIM)
        return self.mlp(features)

    def _step(self, batch, batch_idx, step_type="train"):
        """Run one batch, log ``<step_type>_loss`` / ``<step_type>_error_cm``.

        Args:
            batch: A ``(frames, gts)`` pair; ``frames`` is fed to ``forward``
                and ``gts`` is the regression target.
            batch_idx: Unused; present to match the Lightning hook signature.
            step_type: Prefix for the logged metric names.

        Returns:
            The MSE loss between the prediction and ``gts``.
        """
        frames, gts = batch
        batch_size = len(gts)

        pred_gaze = self(frames)
        gaze_loss = self.loss_fn(pred_gaze, gts)
        self.log(f"{step_type}_loss", gaze_loss, batch_size=batch_size)

        # The error is a reporting metric only, so compute it on a detached
        # tensor instead of extending the autograd graph.
        diff = (pred_gaze - gts).detach()
        # Per-sample Euclidean distance, averaged over the batch, in cm.
        distance = torch.sqrt(torch.sum(diff ** 2, dim=1))
        error = distance.mean() * self.GAZE_UNITS_TO_CM
        self.log(f"{step_type}_error_cm", error, batch_size=batch_size)

        return gaze_loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: one training batch, metrics prefixed ``train``."""
        return self._step(batch, batch_idx, step_type="train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: one validation batch, metrics prefixed ``val``."""
        return self._step(batch, batch_idx, step_type="val")

    def test_step(self, batch, batch_idx):
        """Lightning hook: one test batch, metrics prefixed ``test``."""
        return self._step(batch, batch_idx, step_type="test")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)