diff --git a/main.py b/main.py index 8c43c30..318849c 100644 --- a/main.py +++ b/main.py @@ -43,7 +43,7 @@ def train(X, Y, model, args): sum_loss += float(loss) - print("Epoch: {:4d}\tloss: {}".format(epoch, sum_loss / N)) + print("Epoch: {:4d}\tloss: {}".format(epoch, sum_loss / (N / args.batchsize))) def visualize(X, Y, model):