Skip to content

Commit

Permalink
MLP fix
Browse files Browse the repository at this point in the history
  • Loading branch information
atticusg committed Jan 16, 2024
1 parent 42fcdc2 commit 7dd1b62
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 115 deletions.
16 changes: 8 additions & 8 deletions pyvene/models/mlp/modelings_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
max_position_embeddings=512,
n_layer=2,
h_dim=512,
num_labels=2,
num_classes=2,
activation_function="gelu",
pdrop=0.3,
problem_type="single_label_classification",
Expand All @@ -34,7 +34,7 @@ def __init__(
self.h_dim = h_dim
self.activation_function = activation_function
self.pdrop = pdrop
self.num_labels = num_labels
self.num_classes = num_classes
self.problem_type = problem_type
self.include_bias = include_bias
self.squeeze_output = squeeze_output
Expand Down Expand Up @@ -108,10 +108,10 @@ def forward(
class MLPForClassification(PreTrainedModel):
def __init__(self, config):
    """Build the classification head on top of the MLP backbone.

    Args:
        config: model configuration; this method reads ``num_classes``
            (size of the logit output), ``squeeze_output``, ``h_dim``
            (backbone hidden size) and ``include_bias``.
    """
    super().__init__(config)
    # The commit renames the config field num_labels -> num_classes;
    # only the num_classes form is kept here.
    self.num_classes = config.num_classes
    self.squeeze_output = config.squeeze_output
    self.mlp = MLPModel(config)
    # Linear scoring layer mapping hidden features to per-class logits.
    self.score = nn.Linear(config.h_dim, self.num_classes, bias=config.include_bias)

    # Initialize weights and apply final processing
    self.post_init()
Expand Down Expand Up @@ -140,9 +140,9 @@ def forward(
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
if self.num_classes == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (
elif self.num_classes > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
Expand All @@ -151,14 +151,14 @@ def forward(

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
if self.num_classes == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(
pooled_logits.view(-1, self.num_labels), labels.view(-1)
pooled_logits.view(-1, self.num_classes), labels.view(-1)
)
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
Expand Down
Loading

0 comments on commit 7dd1b62

Please sign in to comment.