Skip to content

Commit

Permalink
feat: add deterministic flag (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuhatry authored and NOBLES5E committed Sep 15, 2021
1 parent 42da5c7 commit f947e43
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions examples/mnist/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import print_function
import argparse
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -123,9 +125,6 @@ def main():
metavar="M",
help="Learning rate step gamma (default: 0.7)",
)
parser.add_argument(
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
)
parser.add_argument(
"--log-interval",
type=int,
Expand All @@ -145,10 +144,23 @@ def main():
default="gradient_allreduce",
help="gradient_allreduce, bytegrad, decentralized, low_precision_decentralized, qadam",
)
parser.add_argument(
"--set-deterministic",
action="store_true",
default=False,
help="whether set deterministic",
)

args = parser.parse_args()

torch.manual_seed(args.seed)
if args.set_deterministic:
print("set_deterministic: True")
np.random.seed(666)
random.seed(666)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(666)
torch.cuda.manual_seed_all(666 + int(bagua.get_rank()))
torch.set_printoptions(precision=10)

torch.cuda.set_device(bagua.get_local_rank())
bagua.init_process_group()
Expand Down

0 comments on commit f947e43

Please sign in to comment.