Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
Changed the line diameter to 16 to match the original dataset. Reference: googlecreativelab/quickdraw-dataset#19 (comment)
  • Loading branch information
samuelmbiya committed Jun 22, 2021
1 parent 934081d commit 6e67d66
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions PaintShape.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def forward(self, x):
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x


# load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load('shapeclassifier.pth')
Expand Down Expand Up @@ -88,7 +90,7 @@ def predict_image(image):
image_tensor = test_transforms(image).float()
image_tensor = image_tensor.unsqueeze_(0)
input = Variable(image_tensor)
input = input.to(device)
input = input.to(device).cpu()
output = model(input)
index = output.data.cpu().numpy().argmax()
return LABELS[index]
Expand Down Expand Up @@ -165,7 +167,7 @@ def mouseMoveEvent(self, event):
# creating painter object
painter = QPainter(self.image)
# set the pen of the painter
painter.setPen(QPen(Qt.black, 10, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin))
painter.setPen(QPen(Qt.black, 16, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin))
# draw line from the last point of cursor to the current point
# this will draw only one step
painter.drawLine(self.lastPoint, event.pos())
Expand Down

0 comments on commit 6e67d66

Please sign in to comment.