train.py
import pytorch_lightning as pl
import timm
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from dataloader_cloudsen12 import testing_data, training_data, validation_data

class lit_dataloader(pl.LightningDataModule):
    def __init__(self, batch_size):
        """
        LightningDataModule for loading and preparing the CloudSEN12 data.

        Args:
            batch_size (int): The batch size for the data loaders.
        """
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self):
        """
        Get the training data loader.

        Returns:
            torch.utils.data.DataLoader: The training data loader.
        """
        return torch.utils.data.DataLoader(
            training_data, batch_size=self.batch_size, shuffle=True, pin_memory=False
        )

    def val_dataloader(self):
        """
        Get the validation data loader.

        Returns:
            torch.utils.data.DataLoader: The validation data loader.
        """
        return torch.utils.data.DataLoader(
            validation_data, batch_size=self.batch_size, shuffle=False, pin_memory=False
        )

    def test_dataloader(self):
        """
        Get the test data loader. This hook was missing although testing_data
        is imported and reg_model defines a test_step; without it,
        Trainer.test() cannot pull batches from this datamodule.

        Returns:
            torch.utils.data.DataLoader: The test data loader.
        """
        return torch.utils.data.DataLoader(
            testing_data, batch_size=self.batch_size, shuffle=False, pin_memory=False
        )
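
# Usage sketch (the batch size here is hypothetical; the real value is set in
# __main__ below). training_step unpacks batches as (image, label) pairs, and
# the model expects 13-band inputs, so X should be (batch, 13, H, W):
#
#   dm = lit_dataloader(batch_size=32)
#   X, y = next(iter(dm.train_dataloader()))  # X: (32, 13, H, W), y: labels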

class reg_model(pl.LightningModule):
    def __init__(self):
        """
        LightningModule for hardness-index prediction: a ResNet-10t backbone
        with 13 input channels and a single output logit, trained with
        BCE-with-logits loss.
        """
        super().__init__()
        self.model = timm.create_model(
            "resnet10t", pretrained=True, num_classes=1, in_chans=13
        )
        self.loss = torch.nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        X, y = batch
        # squeeze(-1) drops only the trailing output dimension, so a batch of
        # size 1 keeps its batch dimension (a bare .squeeze() would collapse
        # it too and break the loss shapes).
        yhat = self(X).squeeze(-1)
        loss = self.loss(yhat, y.float())
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        yhat = self(X).squeeze(-1)
        loss = self.loss(yhat, y.float())
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        X, y = batch
        yhat = self(X).squeeze(-1)
        loss = self.loss(yhat, y.float())
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        """
        Configure the optimizer and learning rate scheduler.

        Returns:
            dict: A dictionary containing the optimizer, the learning rate
            scheduler, and the metric to monitor.
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        # Halve the learning rate when val_loss plateaus for 10 epochs.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=10, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }
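
# Quick shape sanity check (a minimal sketch; the 13 input channels match
# in_chans=13 above, but the 64x64 patch size is only illustrative):
#
#   model = reg_model()
#   dummy = torch.randn(2, 13, 64, 64)  # batch of 2 hypothetical 13-band patches
#   print(model(dummy).shape)           # torch.Size([2, 1]): one logit per patch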

if __name__ == "__main__":
    # Define training parameters
    batch_size = 64
    nepochs = 250

    # Logging (renamed from `logging` to avoid shadowing the stdlib module)
    wandb_logger = WandbLogger(project="IGARS2023", entity="csaybar")

    # Define callbacks: keep the single best checkpoint by val_loss, and stop
    # early after 20 epochs without improvement.
    callback1 = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=1,
        mode="min",
        filename="{epoch}-{val_loss:.2f}",
        dirpath="weights/",
        save_weights_only=True,
    )
    callback2 = EarlyStopping(monitor="val_loss", patience=20)
    callbacks = [callback1, callback2]

    # Define trainer
    trainer = pl.Trainer(
        callbacks=callbacks,
        logger=wandb_logger,
        max_epochs=nepochs,
        accelerator="gpu",
        devices=[0],
    )

    # Train model
    lit_model = reg_model()
    lit_dataset = lit_dataloader(batch_size)
    trainer.fit(lit_model, lit_dataset)
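
    # Optional follow-up (a sketch, not part of the original run): evaluate
    # the best checkpoint saved by callback1 on the held-out test split, which
    # exercises test_step via the datamodule's test_dataloader:
    #
    #   trainer.test(lit_model, datamodule=lit_dataset, ckpt_path="best")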