Skip to content

Commit

Permalink
Fixed ScaleType serialization issues with Array and HashMap types
Browse files Browse the repository at this point in the history
  • Loading branch information
arjanz committed Aug 1, 2024
1 parent a11a6e3 commit ddb5caa
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 10 deletions.
3 changes: 1 addition & 2 deletions scalecodec/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ def new(self, **kwargs) -> 'ScaleType':

return obj


def impl(self, scale_type_cls: type = None, runtime_config=None) -> 'ScaleTypeDef':
"""
Expand Down Expand Up @@ -286,7 +285,7 @@ def deserialize(self, value_serialized: any):
return self.value_object

self.value_object = self.type_def.deserialize(value_serialized)
self.value_serialized = value_serialized
self.value_serialized = self.type_def.serialize(self.value_object)

return self.value_object

Expand Down
66 changes: 58 additions & 8 deletions scalecodec/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,18 @@ def deserialize(self, value: str) -> str:
return value


class ArrayObject(ScaleType):

def to_bytes(self) -> bytes:
if self.type_def.type_def is not U8:
raise ScaleDeserializeException('Only an Array of U8 can be represented as bytes')
return self.value_object


class Array(ScaleTypeDef):

scale_type_cls = ArrayObject

def __init__(self, type_def: ScaleTypeDef, length: int):
self.type_def = type_def
self.length = length
Expand Down Expand Up @@ -715,12 +726,27 @@ def serialize(self, value: Union[list, bytes]) -> Union[list, str]:
return f'0x{value.hex()}'

def deserialize(self, value: Union[list, str, bytes]) -> Union[list, bytes]:

if type(value) not in [list, str, bytes]:
raise ScaleDeserializeException('value should be of type list, str or bytes')

if type(value) is str:
if value[0:2] == '0x':
return bytes.fromhex(value[2:])
value = bytes.fromhex(value[2:])
else:
return value.encode()
else:
value = value.encode()

if len(value) != self.length:
raise ScaleDeserializeException('Length of array does not match size of value')

if type(value) is bytes:
if self.type_def is not U8:
raise ScaleDeserializeException('Only an Array of U8 can be represented as (hex)bytes')

return value

if type(value) is list:

value_object = []

for item in value:
Expand Down Expand Up @@ -793,7 +819,13 @@ def decode(self, data: ScaleBytes) -> list:
return value

def serialize(self, value: list) -> list:
return [(k.value_serialized, v.value_serialized) for k, v in value]
output = []
for k, v in value:
if type(k) is ScaleType and type(v) is ScaleType:
output.append((k.value_serialized, v.value_serialized))
else:
output.append((k, v))
return output

def deserialize(self, value: list) -> list:
return [(self.key_def.deserialize(k), self.value_def.deserialize(v)) for k, v in value]
Expand Down Expand Up @@ -833,12 +865,20 @@ def decode(self, data: ScaleBytes) -> bytearray:
def serialize(self, value: bytearray) -> str:
return f'0x{value.hex()}'

def deserialize(self, value: str) -> bytearray:
if type(value) is str:
def deserialize(self, value: Union[bytes, str, list]) -> bytes:

if type(value) in (list, bytearray):
value = bytes(value)

elif type(value) is str:
if value[0:2] == '0x':
value = bytearray.fromhex(value[2:])
value = bytes.fromhex(value[2:])
else:
value = value.encode('utf-8')

if type(value) is not bytes:
raise ScaleDeserializeException(f'Cannot deserialize type "{type(value)}"')

return value

def example_value(self, _recursion_level: int = 0, max_recursion: int = TYPE_DECOMP_MAX_RECURSIVE):
Expand All @@ -860,13 +900,19 @@ def serialize(self, value: str) -> str:
def deserialize(self, value: str) -> str:
return value


def create_example(self, _recursion_level: int = 0):
return 'String'


class HashDefObject(ScaleType):
def to_bytes(self) -> bytes:
return self.value_object


class HashDef(ScaleTypeDef):

scale_type_cls = HashDefObject

def __init__(self, bits: int):
super().__init__()
self.bits = bits
Expand Down Expand Up @@ -897,6 +943,10 @@ def serialize(self, value: bytes) -> str:
def deserialize(self, value: Union[str, bytes]) -> bytes:
if type(value) is str:
value = bytes.fromhex(value[2:])

if type(value) is not bytes:
raise ScaleDeserializeException('value should be of type str or bytes')

return value

def example_value(self, _recursion_level: int = 0, max_recursion: int = TYPE_DECOMP_MAX_RECURSIVE):
Expand Down
5 changes: 5 additions & 0 deletions test/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def test_bool_encode_decode(self):

self.assertEqual(value, scale_obj.value)

def test_bool_encode_false(self):
scale_obj = Bool().new()
data = scale_obj.encode(False)
self.assertEqual(ScaleBytes("0x00"), data)


if __name__ == '__main__':
unittest.main()

0 comments on commit ddb5caa

Please sign in to comment.