Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sidgoyal78 committed Mar 8, 2021
1 parent d5600b5 commit b647a9d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import torch
from torch.utils.data import Dataset

Expand Down
7 changes: 4 additions & 3 deletions benchmarks/experimental/experimental_ampnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class MySGD(Optimizer):
lr (float): learning rate (required)
"""

def __init__(self, params, lr=0.01):
def __init__(self, params, lr):
defaults = dict(lr=lr)
super(MySGD, self).__init__(params, defaults)

Expand Down Expand Up @@ -167,10 +167,10 @@ class SpectrainSGDMomentum(Optimizer):
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate (required)
momentum (float): momentum (default=0.9)
"""

def __init__(self, params, lr=0.01, momentum=0.9):
print("called")
def __init__(self, params, lr, momentum=0.9):
defaults = dict(lr=lr, momentum=momentum)
params = list(params)
super(SpectrainSGDMomentum, self).__init__(params, defaults)
Expand Down Expand Up @@ -232,6 +232,7 @@ def update_weight_using_future_predictions(self, model_index, num_gpus, forward)
def step(self, weight_prediction=True, closure=None):
""" Performs a single optimization step.
Arguments:
weight_prediction (bool, optional): Enable weight prediction based updates
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
Expand Down

0 comments on commit b647a9d

Please sign in to comment.