SGD Linear Model Fixes for Lime #938

Closed · wants to merge 9 commits
Changes from 7 commits
139 changes: 69 additions & 70 deletions captum/_utils/models/linear_model/train.py
@@ -99,7 +99,6 @@ def sgd_train_linear_model(
This will return the final training loss (averaged with
`running_loss_window`)
"""

loss_window: List[torch.Tensor] = []
min_avg_loss = None
convergence_counter = 0
@@ -145,77 +144,77 @@ def get_point(datapoint):
if model.linear.bias is not None:
model.linear.bias.zero_()

optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
if reduce_lr:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optim, factor=0.5, patience=patience, threshold=threshold
)

t1 = time.time()
epoch = 0
i = 0
while epoch < max_epoch:
while True: # for x, y, w in dataloader
if running_loss_window is None:
running_loss_window = x.shape[0] * len(dataloader)

y = y.view(x.shape[0], -1)
if w is not None:
w = w.view(x.shape[0], -1)

i += 1

out = model(x)

loss = loss_fn(y, out, w)
if reg_term is not None:
reg = torch.norm(model.linear.weight, p=reg_term)
loss += reg.sum() * alpha

if len(loss_window) >= running_loss_window:
loss_window = loss_window[1:]
loss_window.append(loss.clone().detach())
assert len(loss_window) <= running_loss_window

average_loss = torch.mean(torch.stack(loss_window))
if min_avg_loss is not None:
# if we haven't improved by at least `threshold`
if average_loss > min_avg_loss or torch.isclose(
min_avg_loss, average_loss, atol=threshold
):
convergence_counter += 1
if convergence_counter >= patience:
converged = True
break
else:
convergence_counter = 0
if min_avg_loss is None or min_avg_loss >= average_loss:
min_avg_loss = average_loss.clone()

if debug:
print(
f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
+ "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
)

loss.backward()

optim.step()
model.zero_grad()
if scheduler:
scheduler.step(average_loss)

temp = next(data_iter, None)
if temp is None:
with torch.enable_grad():
optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
if reduce_lr:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optim, factor=0.5, patience=patience, threshold=threshold
)

t1 = time.time()
epoch = 0
i = 0
while epoch < max_epoch:
while True: # for x, y, w in dataloader
if running_loss_window is None:
running_loss_window = x.shape[0] * len(dataloader)

y = y.view(x.shape[0], -1)
if w is not None:
w = w.view(x.shape[0], -1)

i += 1

out = model(x)

loss = loss_fn(y, out, w)
if reg_term is not None:
reg = torch.norm(model.linear.weight, p=reg_term)
loss += reg.sum() * alpha

if len(loss_window) >= running_loss_window:
loss_window = loss_window[1:]
loss_window.append(loss.clone().detach())
assert len(loss_window) <= running_loss_window

average_loss = torch.mean(torch.stack(loss_window))
if min_avg_loss is not None:
# if we haven't improved by at least `threshold`
if average_loss > min_avg_loss or torch.isclose(
min_avg_loss, average_loss, atol=threshold
):
convergence_counter += 1
if convergence_counter >= patience:
converged = True
break
else:
convergence_counter = 0
if min_avg_loss is None or min_avg_loss >= average_loss:
min_avg_loss = average_loss.clone()

if debug:
print(
f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
+ "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
)

loss.backward()
optim.step()
model.zero_grad()
if scheduler:
scheduler.step(average_loss)

temp = next(data_iter, None)
if temp is None:
break
x, y, w = get_point(temp)

if converged:
break

epoch += 1
data_iter = iter(dataloader)
x, y, w = get_point(next(data_iter))

t2 = time.time()
return {
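The substantive change to train.py is wrapping the SGD training loop in torch.enable_grad(). As a hedged, standalone sketch (assumed illustration, not Captum code): enable_grad() restores autograd inside an outer torch.no_grad() region, presumably so loss.backward() keeps working even when the surrogate model is trained from a caller that has disabled gradients.

import torch

# Assumed illustration: torch.enable_grad() re-enables autograd inside an
# outer torch.no_grad() block, so the SGD step below still receives gradients.
model = torch.nn.Linear(3, 1)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 3), torch.randn(8, 1)

with torch.no_grad():  # e.g. an inference-only caller
    with torch.enable_grad():
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()  # without enable_grad() this would raise, since no grad is tracked
        optim.step()
        model.zero_grad()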
6 changes: 3 additions & 3 deletions captum/attr/_core/lime.py
@@ -512,17 +512,17 @@ def attribute(
if show_progress:
attr_progress.close()

combined_interp_inps = torch.cat(interpretable_inps).double()
combined_interp_inps = torch.cat(interpretable_inps).float()
combined_outputs = (
torch.cat(outputs)
if len(outputs[0].shape) > 0
else torch.stack(outputs)
).double()
).float()
combined_sim = (
torch.cat(similarities)
if len(similarities[0].shape) > 0
else torch.stack(similarities)
).double()
).float()
dataset = TensorDataset(
combined_interp_inps, combined_outputs, combined_sim
)
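The lime.py change casts the interpretable dataset from double to float. A plausible reading (assumed repro below, not Captum code): the torch-based SGD linear models carry float32 parameters by default, so float64 training data would hit a dtype mismatch, and keeping the tensors in float32 avoids it.

import torch

# Assumed repro of the dtype mismatch the .float() casts avoid:
# a default nn.Linear holds float32 weights, so double inputs fail.
lin = torch.nn.Linear(3, 1)           # float32 parameters by default
x_double = torch.randn(4, 3).double()

try:
    lin(x_double)
except RuntimeError as err:
    print(f"double input rejected: {err}")

out = lin(x_double.float())           # float32 input works
print(out.dtype)                      # torch.float32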
31 changes: 27 additions & 4 deletions tests/attr/test_lime.py
@@ -3,10 +3,12 @@
import io
import unittest
import unittest.mock
from typing import Any, Callable, Generator, List, Tuple, Union
from functools import partial
from typing import Any, Callable, Generator, List, Tuple, Optional, Union

import torch
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.models.linear_model import SkLearnLasso, SGDLasso
from captum._utils.models.model import Model
from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.lime import get_exp_kernel_similarity_function, Lime, LimeBase
from captum.attr._utils.batching import _batch_example_iterator
Expand Down Expand Up @@ -120,6 +122,22 @@ def test_simple_lime(self) -> None:
test_generator=True,
)

def test_simple_lime_sgd_model(self) -> None:
net = BasicModel_MultiLayer()
inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
interpretable_model = SGDLasso()
interpretable_model.fit = partial( # type: ignore
interpretable_model.fit, initial_lr=0.1, max_epoch=500
)
self._lime_test_assert(
net,
inp,
[[73.3716, 193.3349, 113.3349]],
n_samples=1000,
expected_coefs_only=[[73.3716, 193.3349, 113.3349]],
interpretable_model=interpretable_model,
)

def test_simple_lime_with_mask(self) -> None:
net = BasicModel_MultiLayer()
inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
@@ -487,12 +505,15 @@ def _lime_test_assert(
batch_attr: bool = False,
test_generator: bool = False,
show_progress: bool = False,
interpretable_model: Optional[Model] = None,
) -> None:
for batch_size in perturbations_per_eval:
lime = Lime(
model,
similarity_func=get_exp_kernel_similarity_function("cosine", 10.0),
interpretable_model=SkLearnLasso(alpha=1.0),
interpretable_model=interpretable_model
if interpretable_model
else SkLearnLasso(alpha=1.0),
)
attributions = lime.attribute(
test_input,
@@ -526,7 +547,9 @@ def _lime_test_assert(

lime_alt = LimeBase(
model,
SkLearnLasso(alpha=1.0),
interpretable_model
if interpretable_model
else SkLearnLasso(alpha=1.0),
get_exp_kernel_similarity_function("euclidean", 1000.0),
alt_perturb_generator if test_generator else alt_perturb_func,
False,
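For reference, the new test_simple_lime_sgd_model boils down to the usage pattern below. This is a hedged sketch rather than documented API guidance: the model being explained and its input are placeholders (the test uses BasicModel_MultiLayer from captum's test helpers), and the initial_lr/max_epoch values simply mirror the ones chosen in the test so the SGD-trained Lasso converges.

from functools import partial

import torch
from captum._utils.models.linear_model import SGDLasso
from captum.attr import Lime

# Placeholder model being explained (assumption, not part of the PR).
net = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.ReLU(), torch.nn.Linear(4, 1))
inp = torch.tensor([[20.0, 50.0, 30.0]])

# As in the new test: re-bind fit() so SGD training uses a larger learning
# rate and epoch budget than the defaults.
interpretable_model = SGDLasso()
interpretable_model.fit = partial(interpretable_model.fit, initial_lr=0.1, max_epoch=500)

lime = Lime(net, interpretable_model=interpretable_model)
attributions = lime.attribute(inp, n_samples=1000)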