-
Notifications
You must be signed in to change notification settings - Fork 0
/
traineval.py
154 lines (126 loc) · 4.38 KB
/
traineval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import wandb
from dataset import WakeDataset # Assuming dataset.py is in the same directory
from net import EfficientNetB0KeypointDetector
def custom_collate_fn(batch):
    """
    Collate a list of sample dicts into batched tensors.

    Parameters:
    - batch: list of dicts, each with "image" and "keypoints" tensors.

    Returns:
    - (images, keypoints): two tensors, each stacked along a new
      leading batch dimension.
    """
    stacked_images = torch.stack([sample["image"] for sample in batch], 0)
    stacked_keypoints = torch.stack([sample["keypoints"] for sample in batch], 0)
    return stacked_images, stacked_keypoints
def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int,
    model_save_path: str,
):
    """
    Trains the model for one epoch with an RMSE objective, logs the epoch's
    average loss to Weights & Biases, and saves a checkpoint.

    Parameters:
    - model: The neural network model.
    - dataloader: DataLoader providing the training data.
    - optimizer: Optimizer used for model training.
    - device: The device to train on.
    - epoch: Current epoch index; used to name the checkpoint file.
    - model_save_path: Directory the checkpoint is written into (created
      if it does not exist).
    """
    model.train()
    # Construct the loss module once instead of once per batch.
    criterion = nn.MSELoss()
    total_loss = 0.0
    for images, keypoints in dataloader:
        images, keypoints = images.to(device), keypoints.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        # RMSE: sqrt of MSE, so the loss is in the same units as the keypoints.
        r_loss = torch.sqrt(criterion(outputs, keypoints))
        r_loss.backward()
        optimizer.step()
        total_loss += r_loss.item()
    average_loss = total_loss / len(dataloader)
    wandb.log({"train_loss": average_loss})
    # Save the model checkpoint; ensure the target directory exists so
    # torch.save does not fail on the first epoch.
    os.makedirs(model_save_path, exist_ok=True)
    model_filename = f"model_epoch_{epoch}.pth"
    model_save_path_full = os.path.join(model_save_path, model_filename)
    torch.save(model.state_dict(), model_save_path_full)
    print(f"Model saved to {model_save_path_full}")
def evaluate(model: nn.Module, dataloader: DataLoader, device: torch.device):
    """
    Evaluates the model on the validation set and logs the average RMSE
    per batch to Weights & Biases.

    Parameters:
    - model: The neural network model.
    - dataloader: DataLoader providing the validation data.
    - device: The device to evaluate on.
    """
    model.eval()
    # Construct the loss module once instead of once per batch.
    criterion = nn.MSELoss()
    total_loss = 0.0
    with torch.no_grad():
        for images, keypoints in dataloader:
            images, keypoints = images.to(device), keypoints.to(device)
            outputs = model(images)
            # RMSE, matching the training objective.
            total_loss += torch.sqrt(criterion(outputs, keypoints)).item()
    average_loss = total_loss / len(dataloader)
    wandb.log({"val_loss": average_loss})
def main():
    """
    Entry point: sets up device, model, data, then trains the keypoint
    detector for a fixed number of epochs, evaluating after each one.
    """
    # Initialize Weights & Biases run for experiment tracking.
    wandb.init(project="wake_model_llm_assist")
    # Device selection: prefer Apple-Silicon MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():  # Check if MPS backend is available
        print("Using MPS backend for acceleration on Apple Silicon.")
        device = torch.device("mps")  # Use MPS device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EfficientNetB0KeypointDetector().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # Load dataset; images are resized to the EfficientNet input size and
    # replicated to 3 channels so grayscale inputs fit the pretrained stem.
    dataset = WakeDataset(
        data_dir="ShipWakes",
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.Grayscale(num_output_channels=3),
            ]
        ),
    )
    # Split dataset into train and validation sets (80/20).
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    # DataLoaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=128,
        shuffle=True,
        collate_fn=custom_collate_fn,
        num_workers=1,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False,
        collate_fn=custom_collate_fn,
        num_workers=1,
        pin_memory=True,
    )
    model_save_path = "./model_checkpoints"  # Define your model save directory
    # Create the checkpoint directory up front so saving never fails.
    os.makedirs(model_save_path, exist_ok=True)
    # Training loop
    num_epochs = 500  # Define the number of epochs
    for epoch in range(num_epochs):
        train_one_epoch(
            model, train_dataloader, optimizer, device, epoch, model_save_path
        )
        evaluate(model, val_dataloader, device)
        # Log additional metrics or images if needed
    wandb.finish()


if __name__ == "__main__":
    main()