Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(examples/implicit): fix iMAML example with functional APIs #108

Merged
merged 8 commits into from
Nov 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix implicit MAML omniglot few-shot classification example by [@XuehaiPan](https://github.com/XuehaiPan) in [#108](https://github.com/metaopt/torchopt/pull/108).
- Align results of distributed examples by [@XuehaiPan](https://github.com/XuehaiPan) in [#95](https://github.com/metaopt/torchopt/pull/95).
- Fix `None` in module containers by [@XuehaiPan](https://github.com/XuehaiPan).
- Fix backward errors when using inplace `sqrt_` and `add_` by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan).
Expand Down
Binary file modified examples/iMAML/imaml-accs-functional.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/iMAML/imaml-accs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
94 changes: 45 additions & 49 deletions examples/iMAML/imaml_omniglot_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,21 @@ def main():
# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.
net.train()
fnet, params = functorch.make_functional(net)
fnet, meta_params = model = functorch.make_functional(net)
meta_opt = torchopt.adam(lr=1e-3)
meta_opt_state = meta_opt.init(params)
meta_opt_state = meta_opt.init(meta_params)

log = []
test(db, [params, fnet], epoch=-1, log=log, args=args)
test(db, model, epoch=-1, log=log, args=args)
for epoch in range(10):
meta_opt, meta_opt_state = train(
db, [params, fnet], (meta_opt, meta_opt_state), epoch, log, args
)
test(db, [params, fnet], epoch, log, args)
meta_opt, meta_opt_state = train(db, model, (meta_opt, meta_opt_state), epoch, log, args)
test(db, model, epoch, log, args)
plot(log)


def train(db, net, meta_opt_and_state, epoch, log, args):
def train(db, model, meta_opt_and_state, epoch, log, args):
n_train_iter = db.x_train.shape[0] // db.batchsz
params, fnet = net
fnet, meta_params = model
meta_opt, meta_opt_state = meta_opt_and_state
# Given this module we've created, rip out the parameters and buffers
# and return a functional version of the module. `fnet` is stateless
Expand All @@ -133,21 +131,22 @@ def train(db, net, meta_opt_and_state, epoch, log, args):

n_inner_iter = args.inner_steps
reg_param = args.reg_params

qry_losses = []
qry_accs = []

init_params_copy = pytree.tree_map(
lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
)

for i in range(task_num):
# Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.

init_params = pytree.tree_map(
lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad),
meta_params,
)
optimal_params = train_imaml_inner_solver(
init_params_copy,
params,
init_params,
meta_params,
(x_spt[i], y_spt[i]),
(fnet, n_inner_iter, reg_param),
)
Expand All @@ -156,17 +155,15 @@ def train(db, net, meta_opt_and_state, epoch, log, args):
# These will be used to update the model's meta-parameters.
qry_logits = fnet(optimal_params, x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
# Update the model's meta-parameters to optimize the query
# losses across all of the tasks sampled in this batch.
# qry_loss = qry_loss / task_num # scale gradients
meta_grads = torch.autograd.grad(qry_loss / task_num, params)
meta_updates, meta_opt_state = meta_opt.update(meta_grads, meta_opt_state)
params = torchopt.apply_updates(params, meta_updates)
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
qry_losses.append(qry_loss.item())
qry_losses.append(qry_loss)
qry_accs.append(qry_acc.item())

qry_losses = np.mean(qry_losses)
qry_losses = torch.mean(torch.stack(qry_losses))
meta_grads = torch.autograd.grad(qry_losses, meta_params)
meta_updates, meta_opt_state = meta_opt.update(meta_grads, meta_opt_state)
meta_params = torchopt.apply_updates(meta_params, meta_updates)
qry_losses = qry_losses.item()
qry_accs = 100.0 * np.mean(qry_accs)
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
Expand All @@ -188,26 +185,19 @@ def train(db, net, meta_opt_and_state, epoch, log, args):
return (meta_opt, meta_opt_state)


def test(db, net, epoch, log, args):
def test(db, model, epoch, log, args):
# Crucially in our testing procedure here, we do *not* fine-tune
# the model during testing for simplicity.
# Most research papers using MAML for this task do an extra
# stage of fine-tuning here that should be added if you are
# adapting this code for research.
params, fnet = net
# fnet, params, buffers = functorch.make_functional_with_buffers(net)
fnet, meta_params = model
n_test_iter = db.x_test.shape[0] // db.batchsz

qry_losses = []
qry_accs = []

# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
n_inner_iter = args.inner_steps
reg_param = args.reg_params
init_params_copy = pytree.tree_map(
lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
)
qry_losses = []
qry_accs = []

for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
Expand All @@ -219,9 +209,13 @@ def test(db, net, epoch, log, args):
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.

init_params = pytree.tree_map(
lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad),
meta_params,
)
optimal_params = test_imaml_inner_solver(
init_params_copy,
params,
init_params,
meta_params,
(x_spt[i], y_spt[i]),
(fnet, n_inner_iter, reg_param),
)
Expand Down Expand Up @@ -249,12 +243,12 @@ def test(db, net, epoch, log, args):
)


def imaml_objective(params, meta_params, data, aux):
    """Proximal-regularized inner objective for iMAML.

    Args:
        params: task-adapted parameters (iterable of tensors).
        meta_params: meta-parameters used as the proximal anchor.
        data: ``(x_spt, y_spt)`` — support-set inputs and integer labels.
        aux: ``(fnet, n_inner_iter, reg_param)``; only ``fnet`` (the
            functional model) and ``reg_param`` are used here.

    Returns:
        Scalar tensor: support-set cross-entropy plus
        ``0.5 * reg_param * ||params - meta_params||^2``.
    """
    x_spt, y_spt = data
    fnet, n_inner_iter, reg_param = aux
    y_pred = fnet(params, x_spt)
    # Proximal term keeps the adapted parameters close to the meta-parameters,
    # which is what makes the implicit-gradient solve well-posed.
    regularization_loss = 0
    for p1, p2 in zip(params, meta_params):
        regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
    loss = F.cross_entropy(y_pred, y_spt) + regularization_loss
    return loss
Expand All @@ -266,11 +260,10 @@ def imaml_objective(optimal_params, init_params, data, aux):
has_aux=False,
solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
)
def train_imaml_inner_solver(init_params_copy, init_params, data, aux):
def train_imaml_inner_solver(params, meta_params, data, aux):
x_spt, y_spt = data
fnet, n_inner_iter, reg_param = aux
# Initial functional optimizer based on TorchOpt
params = init_params_copy
inner_opt = torchopt.sgd(lr=1e-1)
inner_opt_state = inner_opt.init(params)
with torch.enable_grad():
Expand All @@ -280,20 +273,21 @@ def train_imaml_inner_solver(init_params_copy, init_params, data, aux):
loss = F.cross_entropy(pred, y_spt) # compute loss
# Compute regularization loss
regularization_loss = 0
for p1, p2 in zip(params, init_params):
for p1, p2 in zip(params, meta_params):
regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
final_loss = loss + regularization_loss
grads = torch.autograd.grad(final_loss, params) # compute gradients
updates, inner_opt_state = inner_opt.update(grads, inner_opt_state) # get updates
params = torchopt.apply_updates(params, updates)
updates, inner_opt_state = inner_opt.update(
grads, inner_opt_state, inplace=True
) # get updates
params = torchopt.apply_updates(params, updates, inplace=True)
return params


def test_imaml_inner_solver(init_params_copy, init_params, data, aux):
def test_imaml_inner_solver(params, meta_params, data, aux):
x_spt, y_spt = data
fnet, n_inner_iter, reg_param = aux
# Initial functional optimizer based on TorchOpt
params = init_params_copy
inner_opt = torchopt.sgd(lr=1e-1)
inner_opt_state = inner_opt.init(params)
with torch.enable_grad():
Expand All @@ -303,12 +297,14 @@ def test_imaml_inner_solver(init_params_copy, init_params, data, aux):
loss = F.cross_entropy(pred, y_spt) # compute loss
# Compute regularization loss
regularization_loss = 0
for p1, p2 in zip(params, init_params):
for p1, p2 in zip(params, meta_params):
regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
final_loss = loss + regularization_loss
grads = torch.autograd.grad(final_loss, params) # compute gradients
updates, inner_opt_state = inner_opt.update(grads, inner_opt_state) # get updates
params = torchopt.apply_updates(params, updates)
updates, inner_opt_state = inner_opt.update(
grads, inner_opt_state, inplace=True
) # get updates
params = torchopt.apply_updates(params, updates, inplace=True)
return params


Expand Down
61 changes: 29 additions & 32 deletions tests/test_implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,22 +126,21 @@ def test_imaml(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int
optim_jax = optax.sgd(lr)
optim_state_jax = optim_jax.init(jax_params)

def imaml_objective_torchopt(params, meta_params, data):
    """iMAML inner objective: task loss plus a proximal regularizer.

    Args:
        params: adapted parameters (iterable of tensors).
        meta_params: meta-parameters used as the proximal anchor.
        data: ``(x, y, f)`` — inputs, integer labels, and the functional model.

    Returns:
        Scalar tensor:
        ``cross_entropy(f(params, x), y) + 0.5 * ||params - meta_params||^2``.
    """
    x, y, f = data
    y_pred = f(params, x)
    # Fixed regularization strength of 0.5 per squared-distance term.
    regularization_loss = 0
    for p1, p2 in zip(params, meta_params):
        regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2))
    loss = F.cross_entropy(y_pred, y) + regularization_loss
    return loss

@torchopt.diff.implicit.custom_root(
functorch.grad(imaml_objective_torchopt, argnums=0), argnums=1, has_aux=True
)
def inner_solver_torchopt(init_params_copy, init_params, data):
def inner_solver_torchopt(params, meta_params, data):
# Initial functional optimizer based on TorchOpt
x, y, f = data
params = init_params_copy
optimizer = torchopt.sgd(lr=inner_lr)
opt_state = optimizer.init(params)
with torch.enable_grad():
Expand All @@ -151,54 +150,53 @@ def inner_solver_torchopt(init_params_copy, init_params, data):
loss = F.cross_entropy(pred, y) # compute loss
# Compute regularization loss
regularization_loss = 0
for p1, p2 in zip(params, init_params):
for p1, p2 in zip(params, meta_params):
regularization_loss += 0.5 * torch.sum(torch.square(p1 - p2))
final_loss = loss + regularization_loss
grads = torch.autograd.grad(final_loss, params) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = torchopt.apply_updates(params, updates)
updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates
params = torchopt.apply_updates(params, updates, inplace=True)
return params, (0, {'a': 1, 'b': 2})

def imaml_objective_jax(params, meta_params, x, y):
    """JAX counterpart of the iMAML inner objective (task loss + proximal term).

    NOTE(review): relies on ``jax_model`` from the enclosing scope; ``params``
    and ``meta_params`` are dicts of arrays (``.values()`` is iterated below).
    """
    y_pred = jax_model(params, x)
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y))
    # Proximal regularizer anchoring the adapted params to the meta-params.
    regularization_loss = 0
    for p1, p2 in zip(params.values(), meta_params.values()):
        regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
    loss = loss + regularization_loss
    return loss

@jaxopt.implicit_diff.custom_root(jax.grad(imaml_objective_jax, argnums=0), has_aux=True)
def inner_solver_jax(init_params_copy, init_params, x, y):
def inner_solver_jax(params, meta_params, x, y):
"""Solve ridge regression by conjugate gradient."""
# Initial functional optimizer based on torchopt
params = init_params_copy
optimizer = optax.sgd(inner_lr)
opt_state = optimizer.init(params)

def compute_loss(params, meta_params, x, y):
    """Support loss plus the ``0.5 * ||params - meta_params||^2`` proximal term.

    NOTE(review): uses ``jax_model`` from the enclosing scope; ``params`` and
    ``meta_params`` are dicts of arrays (``.values()`` is iterated below).
    """
    pred = jax_model(params, x)
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y))
    # Compute regularization loss
    regularization_loss = 0
    for p1, p2 in zip(params.values(), meta_params.values()):
        regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
    final_loss = loss + regularization_loss
    return final_loss

for i in range(inner_update):
grads = jax.grad(compute_loss)(params, init_params, x, y) # compute gradients
grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = optax.apply_updates(params, updates)
return params, (0, {'a': 1, 'b': 2})

for xs, ys in loader:
xs = xs.to(dtype=dtype)
data = (xs, ys, fmodel)
init_params_copy = pytree.tree_map(
meta_params_copy = pytree.tree_map(
lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
)
optimal_params, aux = inner_solver_torchopt(init_params_copy, params, data)
optimal_params, aux = inner_solver_torchopt(meta_params_copy, params, data)
assert aux == (0, {'a': 1, 'b': 2})
outer_loss = fmodel(optimal_params, xs).mean()

Expand Down Expand Up @@ -275,35 +273,34 @@ def solve(self, x, y):
optim_jax = optax.sgd(lr)
optim_state_jax = optim_jax.init(jax_params)

def imaml_objective_jax(params, meta_params, x, y):
    """JAX counterpart of the iMAML inner objective (task loss + proximal term).

    NOTE(review): relies on ``jax_model`` from the enclosing scope; ``params``
    and ``meta_params`` are dicts of arrays (``.values()`` is iterated below).
    """
    y_pred = jax_model(params, x)
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(y_pred, y))
    # Proximal regularizer anchoring the adapted params to the meta-params.
    regularization_loss = 0
    for p1, p2 in zip(params.values(), meta_params.values()):
        regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
    loss = loss + regularization_loss
    return loss

@jaxopt.implicit_diff.custom_root(jax.grad(imaml_objective_jax, argnums=0), has_aux=True)
def inner_solver_jax(init_params_copy, init_params, x, y):
def inner_solver_jax(params, meta_params, x, y):
"""Solve ridge regression by conjugate gradient."""
# Initial functional optimizer based on torchopt
params = init_params_copy
optimizer = optax.sgd(inner_lr)
opt_state = optimizer.init(params)

def compute_loss(params, meta_params, x, y):
    """Support loss plus the ``0.5 * ||params - meta_params||^2`` proximal term.

    NOTE(review): uses ``jax_model`` from the enclosing scope; ``params`` and
    ``meta_params`` are dicts of arrays (``.values()`` is iterated below).
    """
    pred = jax_model(params, x)
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred, y))
    # Compute regularization loss
    regularization_loss = 0
    for p1, p2 in zip(params.values(), meta_params.values()):
        regularization_loss += 0.5 * jnp.sum(jnp.square((p1 - p2)))
    final_loss = loss + regularization_loss
    return final_loss

for i in range(inner_update):
grads = jax.grad(compute_loss)(params, init_params, x, y) # compute gradients
grads = jax.grad(compute_loss)(params, meta_params, x, y) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = optax.apply_updates(params, updates)
return params, (0, {'a': 1, 'b': 2})
Expand Down Expand Up @@ -374,7 +371,7 @@ def ridge_objective_torch(params, l2reg, data):
return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss

@torchopt.diff.implicit.custom_root(functorch.grad(ridge_objective_torch, argnums=0), argnums=1)
def ridge_solver_torch(init_params, l2reg, data):
def ridge_solver_torch(params, l2reg, data):
"""Solve ridge regression by conjugate gradient."""
X_tr, y_tr = data

Expand All @@ -383,7 +380,7 @@ def matvec(u):

solve = torchopt.linear_solve.solve_cg(
ridge=len(y_tr) * l2reg.item(),
init=init_params,
init=params,
maxiter=20,
)

Expand All @@ -396,7 +393,7 @@ def ridge_objective_jax(params, l2reg, X_tr, y_tr):
return 0.5 * jnp.mean(jnp.square(residuals)) + regularization_loss

@jaxopt.implicit_diff.custom_root(jax.grad(ridge_objective_jax, argnums=0))
def ridge_solver_jax(init_params, l2reg, X_tr, y_tr):
def ridge_solver_jax(params, l2reg, X_tr, y_tr):
"""Solve ridge regression by conjugate gradient."""

def matvec(u):
Expand All @@ -406,7 +403,7 @@ def matvec(u):
matvec=matvec,
b=X_tr.T @ y_tr,
ridge=len(y_tr) * l2reg.item(),
init=init_params,
init=params,
maxiter=20,
)

Expand All @@ -428,8 +425,8 @@ def matvec(u):
xq = jnp.array(xq.numpy(), dtype=np_dtype)
yq = jnp.array(yq.numpy(), dtype=np_dtype)

def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
    """Outer (meta) loss: fit ridge weights on ``(xs, ys)``, evaluate mean
    squared error on the query set ``(xq, yq)``.

    NOTE(review): calls the sibling ``ridge_solver_jax`` defined elsewhere in
    this test file.
    """
    w_fit = ridge_solver_jax(params_jax, l2reg_jax, xs, ys)
    y_pred = xq @ w_fit
    loss_value = jnp.mean(jnp.square(y_pred - yq))
    return loss_value
Expand Down
Loading