Investigate high loss of Mixtral #931

Closed
5 tasks done
casper-hansen opened this issue Dec 10, 2023 · 0 comments · Fixed by #932
Labels
enhancement New feature or request

@casper-hansen
Collaborator

⚠️ Please check that this feature request hasn't been suggested before.

  • I searched previous Ideas in Discussions and didn't find any similar feature requests.
  • I searched previous Issues and didn't find any similar feature requests.

🔖 Feature description

The axolotl implementation of the Mixtral router is not aligned with the MegaBlocks implementation:

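import torch

class LearnedRouter(torch.nn.Module):
    # Excerpt from MegaBlocks: __init__, self.layer, self.jitter, self._top_k
    # and _uniform_expert_assignment are defined elsewhere in MegaBlocks.
    def forward(self, x):
        # During training, optionally multiply the input by jitter noise.
        if self.training and self.args.moe_jitter_eps is not None:
            x = x * self.jitter(x)

        # Flatten to (num_tokens, hidden_size), project to per-expert logits,
        # and softmax over all experts.
        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
        # Select the top-k experts (and their scores) for each token.
        expert_weights, expert_indices = self._top_k(scores)

        # Optionally replace the learned assignment with a uniform one.
        expert_indices = (
            _uniform_expert_assignment(expert_indices, self.args.moe_num_experts)
            if self.args.uniform_expert_assignment else expert_indices
        )
        return scores, expert_weights, expert_indices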

The current axolotl implementation:

https://github.com/OpenAccess-AI-Collective/axolotl/blob/68b227a7d8045d0f428d7ca3b9750f837d03611f/src/axolotl/models/mixtral/modeling_moe_mistral.py#L223-L232
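One key point when comparing routers is how the top-k expert weights are normalized. The sketch below assumes router logits of shape (num_tokens, num_experts) and leaves out jitter and uniform assignment; both function names are hypothetical. `route_megablocks_style` follows the MegaBlocks excerpt above (softmax over all experts, then top-k with no renormalization), while `route_renormalized` takes the top-k logits and applies softmax over only those k experts, so each token's weights sum to 1. Renormalizing the first result recovers the second. To my understanding, the renormalized form is what Hugging Face transformers' Mixtral implementation computes, but that is an assumption worth checking against its source.

```python
import torch


def route_megablocks_style(logits: torch.Tensor, top_k: int):
    """Softmax over all experts, then keep the top-k probabilities as-is."""
    scores = logits.softmax(dim=-1)
    weights, indices = torch.topk(scores, top_k, dim=-1)
    return weights, indices


def route_renormalized(logits: torch.Tensor, top_k: int):
    """Take the top-k logits, then softmax over only those k experts."""
    top_logits, indices = torch.topk(logits, top_k, dim=-1)
    weights = top_logits.softmax(dim=-1)
    return weights, indices


torch.manual_seed(0)
logits = torch.randn(4, 8)  # 4 tokens, 8 experts (Mixtral routes each token to 2 of 8)

w_mb, i_mb = route_megablocks_style(logits, top_k=2)
w_rn, i_rn = route_renormalized(logits, top_k=2)

# Softmax is monotonic, so both variants pick the same experts in the same order.
print(torch.equal(i_mb, i_rn))
print(w_mb.sum(-1))  # each entry < 1: the other 6 experts keep some probability
print(w_rn.sum(-1))  # each entry ~1.0
# Renormalizing the MegaBlocks-style weights gives the renormalized variant.
print(torch.allclose(w_mb / w_mb.sum(-1, keepdim=True), w_rn))
```

With top-2 of 8 experts the un-renormalized weights are always smaller, which scales every expert output that is added back into the residual stream.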

Mistral also commented on this (screenshot in the original issue, not reproduced here).

✔️ Solution

Investigate how to adopt the most correct router implementation in axolotl. One way to test this is to measure the initial loss. For reference, when I implemented sliding-window attention for Mistral, the initial loss dropped from 9.98 on main to 1.9 with the PR.
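As a rough sanity check on those numbers (a worked sketch, not part of the original measurements): a model that assigns equal probability to every token has a cross-entropy of ln(V) nats, where V is the vocabulary size. Assuming a 32,000-token vocabulary (the Mistral/Mixtral tokenizer size), that baseline is about 10.37, so an initial loss near 10 suggests near-uniform predictions, while a loss near 2 is consistent with the pretrained weights being used correctly. `uniform_baseline_loss` is a hypothetical helper.

```python
import math


def uniform_baseline_loss(vocab_size: int) -> float:
    """Cross-entropy (in nats) of a model that predicts every token with equal probability."""
    return math.log(vocab_size)


# Assumption: Mistral/Mixtral tokenizer vocabulary of 32,000 tokens.
VOCAB_SIZE = 32_000
baseline = uniform_baseline_loss(VOCAB_SIZE)
print(f"uniform-guess loss: {baseline:.2f}")  # 10.37

# Initial losses quoted above for the sliding-window change.
for label, loss in [("main", 9.98), ("sliding-window PR", 1.9)]:
    print(f"{label}: {loss:.2f} ({loss / baseline:.0%} of the uniform baseline)")
```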

Measure the loss on both short- and long-context data, e.g. using casperhansen/longalpaca_1k_test in alpaca format.
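To report short- and long-context results separately, per-example losses can be bucketed by sequence length and averaged per token. A minimal sketch: `loss_by_length` and its 4,096-token default threshold are hypothetical choices for illustration (4,096 is Mistral 7B's sliding-window size, so longer examples exercise the windowed path), and the values in the usage example are toy numbers, not measurements.

```python
def loss_by_length(samples, threshold=4096):
    """Token-weighted mean loss for short (<= threshold) and long (> threshold) examples.

    `samples` is an iterable of (num_tokens, mean_loss_per_token) pairs.
    Returns a dict with "short" and "long" keys; an empty bucket maps to NaN.
    """
    totals = {"short": [0, 0.0], "long": [0, 0.0]}  # [token count, summed loss]
    for num_tokens, loss in samples:
        bucket = "long" if num_tokens > threshold else "short"
        totals[bucket][0] += num_tokens
        totals[bucket][1] += num_tokens * loss
    return {
        name: (loss_sum / count if count else float("nan"))
        for name, (count, loss_sum) in totals.items()
    }


# Toy values for illustration only, not real measurements.
samples = [(512, 1.8), (1024, 2.0), (8192, 2.6)]
print(loss_by_length(samples))
```

Weighting by token count makes each bucket's value match the mean loss over all of its tokens, rather than letting short examples count as much as long ones.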

❓ Alternatives

No response

📝 Additional Context

No response

Acknowledgements

  • My issue title is concise, descriptive, and in title case.
  • I have searched the existing issues to make sure this feature has not been requested yet.
  • I have provided enough information for the maintainers to understand and evaluate this request.