-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_classifier.py
46 lines (39 loc) · 1.58 KB
/
test_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import time
# os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from utils.data_utils import get_raw_data, SingleClassDataset, PoisonedDataset
# A small helper for evaluating a trained classifier on a dataset
# (a "happy little" Bob Ross-style function).
# NOTE: a cleaner version of this was written on the lab computer but was never pushed to GitHub.
def test_classifier(model, dataset, iterations=1000):
    """Count how many of the first ``iterations`` samples ``model`` classifies correctly.

    Each sample is fed to ``model`` on its own as a batch of one
    (``img.unsqueeze(0)``), and the arg-max over dimension 1 of the output is
    taken as the predicted class. The count is printed and returned.

    Args:
        model: Callable (e.g. an ``nn.Module``) whose output on a batch of one
            can be arg-maxed along dimension 1 to give a single class index.
            This function does not change the model's train/eval mode; call
            ``model.eval()`` beforehand if the model has dropout/batch-norm.
        dataset: Iterable yielding ``(img, lbl)`` pairs, where ``img`` is a
            tensor and ``lbl`` is a single class index (a Python ``int`` or a
            one-element tensor).
        iterations: Maximum number of samples to evaluate. ``None`` evaluates
            the whole dataset; values ``<= 0`` evaluate nothing.

    Returns:
        The number of correctly classified samples (a count, not a ratio).
    """
    from itertools import islice

    limit = None if iterations is None else max(0, iterations)
    correct = 0
    # Inference only: disabling autograd avoids building a graph per sample.
    with torch.no_grad():
        for img, lbl in islice(dataset, limit):
            outputs = model(img.unsqueeze(0))
            _, preds = torch.max(outputs, 1)
            # int() accepts both plain int labels and one-element tensors;
            # the previous ``lbl.data`` comparison crashed on int labels.
            if int(preds.item()) == int(lbl):
                correct += 1
    print(f"correct: {correct}")
    return correct


# This is an evaluation helper, not a pytest test. Its name (and this file's
# name) match pytest's default discovery patterns, so opt out of collection;
# otherwise pytest would try to run it and fail on the missing fixtures.
test_classifier.__test__ = False
if __name__ == "__main__":
    # Evaluate one fine-tuned ResNet-18 checkpoint per trigger shape.
    train = get_raw_data()
    for trigger in ["square", "triangle", "L"]:
        # The model is never moved off the CPU below, so load the checkpoint
        # onto the CPU too; otherwise GPU-saved checkpoints fail on CPU-only hosts.
        checkpoint = torch.load(f"data/models/{trigger}_model.pth", map_location="cpu")
        # Every weight is overwritten by the checkpoint (load_state_dict is
        # strict by default), so skip downloading ImageNet weights first.
        model = models.resnet18()
        num_ftrs = model.fc.in_features
        # Head width must match the checkpoint's head or the strict load fails.
        model.fc = nn.Linear(num_ftrs, 20)
        model.load_state_dict(checkpoint)
        model.eval()
        dataset = PoisonedDataset(train, trigger=trigger, poison_ratio=1)
        correct = test_classifier(model, dataset)
        # NOTE(review): the printed value is a count of correct predictions over
        # at most 1000 samples (test_classifier's default), not a ratio; and the
        # dataset is built with poison_ratio=1 although the label says
        # "clean data" — confirm which is intended.
        print(f"{trigger} accuracy on clean data:\t{correct}")