Skip to content

Commit

Permalink
Merge pull request #1815 from opentensor/fix/abe/regencoldkey
Browse files Browse the repository at this point in the history
Support for string mnemonic thru cli when regenerating coldkeys
  • Loading branch information
gus-opentensor committed May 1, 2024
2 parents 5a719a2 + 81bcb8a commit 1c69ad6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
2 changes: 2 additions & 0 deletions bittensor/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,8 @@ def regenerate_coldkey(
if mnemonic is not None:
if isinstance(mnemonic, str):
mnemonic = mnemonic.split()
elif isinstance(mnemonic, list) and len(mnemonic) == 1:
mnemonic = mnemonic[0].split()
if len(mnemonic) not in [12, 15, 18, 21, 24]:
raise ValueError(
"Mnemonic has invalid size. This should be 12,15,18,21 or 24 words"
Expand Down
64 changes: 64 additions & 0 deletions tests/unit_tests/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,70 @@ def test_regen_hotkey_from_hex_seed_str(mock_wallet):
mock_wallet.regenerate_hotkey(seed=seed_str_bad, overwrite=True, suppress=True)


@pytest.mark.parametrize(
"mnemonic, expected_exception",
[
# Input is in a string format
(
"fiscal prevent noise record smile believe quote front weasel book axis legal",
None,
),
# Input is in a list format (acquired by encapsulating mnemonic arg in a string "" in the cli)
(
[
"fiscal prevent noise record smile believe quote front weasel book axis legal"
],
None,
),
# Input is in a full list format (aquired by pasting mnemonic arg simply w/o quotes in cli)
(
[
"fiscal",
"prevent",
"noise",
"record",
"smile",
"believe",
"quote",
"front",
"weasel",
"book",
"axis",
"legal",
],
None,
),
# Incomplete mnemonic
("word1 word2 word3", ValueError),
# No mnemonic added
(None, ValueError),
],
ids=[
"string-format",
"list-format-thru-string",
"list-format",
"incomplete-mnemonic",
"no-mnemonic",
],
)
def test_regen_coldkey_mnemonic(mock_wallet, mnemonic, expected_exception):
"""Test the `regenerate_coldkey` method of the wallet class, which regenerates the cold key pair from a mnemonic.
We test different input formats of mnemonics and check if the function works as expected.
"""
with patch.object(mock_wallet, "set_coldkey") as mock_set_coldkey, patch.object(
mock_wallet, "set_coldkeypub"
) as mock_set_coldkeypub:
if expected_exception:
with pytest.raises(expected_exception):
mock_wallet.regenerate_coldkey(
mnemonic=mnemonic, overwrite=True, suppress=True
)
else:
mock_wallet.regenerate_coldkey(mnemonic=mnemonic)
mock_set_coldkey.assert_called_once()
mock_set_coldkeypub.assert_called_once()


@pytest.mark.parametrize(
"overwrite, user_input, expected_exception",
[
Expand Down

0 comments on commit 1c69ad6

Please sign in to comment.