example: Use torch in mnist example (#927)
* use torch in mnist example

Signed-off-by: Jinjing.Zhou <allenzhou@tensorchord.ai>

* fix

Signed-off-by: Jinjing.Zhou <allenzhou@tensorchord.ai>

* fix metadata

Signed-off-by: Jinjing.Zhou <allenzhou@tensorchord.ai>

* license

Signed-off-by: Jinjing.Zhou <allenzhou@tensorchord.ai>

* format

Signed-off-by: Jinjing.Zhou <allenzhou@tensorchord.ai>

Signed-off-by: Jinjing.Zhou <allenzhou@tensorchord.ai>
Signed-off-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
VoVAllen committed Sep 29, 2022
1 parent 395bedc commit 0b73548
Showing 3 changed files with 332 additions and 164 deletions.
24 changes: 18 additions & 6 deletions examples/mnist/build.envd
@@ -1,21 +1,33 @@
-def build():
+def base_env():
     base(os="ubuntu20.04", language="python3")
     install.vscode_extensions(
         [
             "ms-python.python",
         ]
     )
+    shell("zsh")
+    config.jupyter()
     # config.pip_index(url = "https://pypi.tuna.tsinghua.edu.cn/simple")
+
+
+def build():
+    base_env()
     install.python_packages(
         [
-            "tensorflow",
-            "numpy",
+            "torch",
+            "torchvision",
+            "--extra-index-url https://download.pytorch.org/whl/cpu",
         ]
     )
-    shell("zsh")
-    config.jupyter()
 
 
 def build_gpu():
-    build()
+    base_env()
     install.cuda(version="11.2.0", cudnn="8")
+    install.python_packages(
+        [
+            "torch",
+            "torchvision",
+            "--extra-index-url https://download.pytorch.org/whl/cu113",
+        ]
+    )
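
Not part of the commit: the GPU variant layers the cu113 PyTorch wheels on top of install.cuda(version="11.2.0", cudnn="8"). A minimal sanity check one might run inside an environment built from build_gpu(), using only standard torch attributes, to confirm the CUDA build of torch is the one installed:

import torch

# Wheels from the cu113 index report a "+cu113" version suffix, e.g. "1.12.1+cu113".
print(torch.__version__)
# True when a GPU and driver are visible to the container.
print(torch.cuda.is_available())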
271 changes: 192 additions & 79 deletions examples/mnist/main.py
@@ -1,79 +1,192 @@
"""
Title: Simple MNIST convnet
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2015/06/19
Last modified: 2020/04/21
Description: A simple convnet that achieves ~99% test accuracy on MNIST.
"""

"""
## Setup
"""

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

"""
## Prepare the data
"""

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")


# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

"""
## Build the model
"""

model = keras.Sequential(
[
keras.Input(shape=input_shape),
layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(num_classes, activation="softmax"),
]
)

model.summary()

"""
## Train the model
"""

batch_size = 128
epochs = 1

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

"""
## Evaluate the trained model
"""

score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# from https://github.com/pytorch/examples/blob/main/mnist/main.py
+
+from __future__ import print_function
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+
+def train(args, model, device, train_loader, optimizer, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % args.log_interval == 0:
+            print(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                    epoch,
+                    batch_idx * len(data),
+                    len(train_loader.dataset),
+                    100.0 * batch_idx / len(train_loader),
+                    loss.item(),
+                )
+            )
+            if args.dry_run:
+                break
+
+
+def test(model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            # sum up batch loss
+            test_loss += F.nll_loss(output, target, reduction="sum").item()
+            # get the index of the max log-probability
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+
+    print(
+        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
+            test_loss,
+            correct,
+            len(test_loader.dataset),
+            100.0 * correct / len(test_loader.dataset),
+        )
+    )
+
+
+def main():
+    # Training settings
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=256,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=1,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
+    args = parser.parse_args()
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+
+    torch.manual_seed(args.seed)
+
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    train_kwargs = {"batch_size": args.batch_size}
+    test_kwargs = {"batch_size": args.test_batch_size}
+    if use_cuda:
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
+        train_kwargs.update(cuda_kwargs)
+        test_kwargs.update(cuda_kwargs)
+
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )
+    dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST("../data", train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+    model = Net().to(device)
+    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+
+    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, device, train_loader, optimizer, epoch)
+        test(model, device, test_loader)
+        scheduler.step()
+
+    if args.save_model:
+        torch.save(model.state_dict(), "mnist_cnn.pt")
+
+
+if __name__ == "__main__":
+    main()
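
Not part of the commit: a brief sketch of how the checkpoint written by the new main.py (when run with --save-model) might be loaded for inference. It assumes the script above sits in the same directory and is importable as main; the checkpoint name mnist_cnn.pt, the Net class, and the normalization constants all come straight from the diff.

import torch
from torchvision import datasets, transforms

from main import Net  # Net is defined in examples/mnist/main.py above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restore the trained weights saved by `python main.py --save-model`.
model = Net().to(device)
model.load_state_dict(torch.load("mnist_cnn.pt", map_location=device))
model.eval()

# Same normalization as the training transform in main.py.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
test_set = datasets.MNIST("../data", train=False, download=True, transform=transform)

image, label = test_set[0]
with torch.no_grad():
    # The model expects a batch dimension: (N, 1, 28, 28).
    pred = model(image.unsqueeze(0).to(device)).argmax(dim=1).item()
print("predicted", pred, "ground truth", label)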