diff --git a/AUTHORS.rst b/AUTHORS.rst index 1319b54b2..2ef0f30d2 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -153,3 +153,4 @@ Contributors (chronological) - Juan Norris `@juannorris `_ - 장준영 `@jun0jang `_ - `@ebargtuo `_ +- Michał Getka `@mgetka `_ diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 5f412da53..91f8e8071 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -5,6 +5,7 @@ import datetime as dt import numbers import uuid +import ipaddress import decimal import math import typing @@ -51,6 +52,9 @@ "Url", "URL", "Email", + "IP", + "IPv4", + "IPv6", "Method", "Function", "Str", @@ -1625,6 +1629,56 @@ def __init__(self, *args, **kwargs): self.validators.insert(0, validator) +class IP(Field): + """A IP address field. + + :param bool exploded: If `True`, serialize ipv6 address in long form, ie. with groups + consisting entirely of zeros included.""" + + default_error_messages = {"invalid_ip": "Not a valid IP address."} + + DESERIALIZATION_CLASS = None # type: typing.Optional[typing.Type] + + def __init__(self, *args, exploded=False, **kwargs): + super().__init__(*args, **kwargs) + self.exploded = exploded + + def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]: + if value is None: + return None + if self.exploded: + return value.exploded + return value.compressed + + def _deserialize( + self, value, attr, data, **kwargs + ) -> typing.Optional[typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]: + if value is None: + return None + try: + return (self.DESERIALIZATION_CLASS or ipaddress.ip_address)( + utils.ensure_text_type(value) + ) + except (ValueError, TypeError) as error: + raise self.make_error("invalid_ip") from error + + +class IPv4(IP): + """A IPv4 address field.""" + + default_error_messages = {"invalid_ip": "Not a valid IPv4 address."} + + DESERIALIZATION_CLASS = ipaddress.IPv4Address + + +class IPv6(IP): + """A IPv6 address field.""" + + default_error_messages = {"invalid_ip": "Not a valid IPv6 address."} + + DESERIALIZATION_CLASS = ipaddress.IPv6Address + + class Method(Field): """A field that takes the value returned by a `Schema` method. diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index e74f467a7..6ce39300d 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -1,5 +1,6 @@ import datetime as dt import uuid +import ipaddress import decimal import math @@ -850,6 +851,63 @@ def test_invalid_uuid_deserialization(self, in_value): assert excinfo.value.args[0] == "Not a valid UUID." + def test_ip_field_deserialization(self): + field = fields.IP() + ipv4_str = "140.82.118.3" + result = field.deserialize(ipv4_str) + assert isinstance(result, ipaddress.IPv4Address) + assert str(result) == ipv4_str + + ipv6_str = "2a00:1450:4001:824::200e" + result = field.deserialize(ipv6_str) + assert isinstance(result, ipaddress.IPv6Address) + assert str(result) == ipv6_str + + @pytest.mark.parametrize( + "in_value", ["malformed", 123, b"\x01\x02\03", "192.168", "ff::aa:1::2"] + ) + def test_invalid_ip_deserialization(self, in_value): + field = fields.IP() + with pytest.raises(ValidationError) as excinfo: + field.deserialize(in_value) + + assert excinfo.value.args[0] == "Not a valid IP address." + + def test_ipv4_field_deserialization(self): + field = fields.IPv4() + ipv4_str = "140.82.118.3" + result = field.deserialize(ipv4_str) + assert isinstance(result, ipaddress.IPv4Address) + assert str(result) == ipv4_str + + @pytest.mark.parametrize( + "in_value", + ["malformed", 123, b"\x01\x02\03", "192.168", "2a00:1450:4001:81d::200e"], + ) + def test_invalid_ipv4_deserialization(self, in_value): + field = fields.IPv4() + with pytest.raises(ValidationError) as excinfo: + field.deserialize(in_value) + + assert excinfo.value.args[0] == "Not a valid IPv4 address." + + def test_ipv6_field_deserialization(self): + field = fields.IPv6() + ipv6_str = "2a00:1450:4001:824::200e" + result = field.deserialize(ipv6_str) + assert isinstance(result, ipaddress.IPv6Address) + assert str(result) == ipv6_str + + @pytest.mark.parametrize( + "in_value", ["malformed", 123, b"\x01\x02\03", "ff::aa:1::2", "192.168.0.1"] + ) + def test_invalid_ipv6_deserialization(self, in_value): + field = fields.IPv6() + with pytest.raises(ValidationError) as excinfo: + field.deserialize(in_value) + + assert excinfo.value.args[0] == "Not a valid IPv6 address." + def test_deserialization_function_must_be_callable(self): with pytest.raises(ValueError): fields.Function(lambda x: None, deserialize="notvalid") diff --git a/tests/test_serialization.py b/tests/test_serialization.py index cbc6ca5e4..9b6480547 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -4,6 +4,7 @@ import itertools import decimal import uuid +import ipaddress import pytest @@ -143,6 +144,56 @@ def test_uuid_field(self, user): assert field.serialize("uuid1", user) == "12345678-1234-5678-1234-567812345678" assert field.serialize("uuid2", user) is None + def test_ip_address_field(self, user): + + ipv4_string = "192.168.0.1" + ipv6_string = "ffff::ffff" + ipv6_exploded_string = ipaddress.ip_address("ffff::ffff").exploded + + user.ipv4 = ipaddress.ip_address(ipv4_string) + user.ipv6 = ipaddress.ip_address(ipv6_string) + user.empty_ip = None + + field_compressed = fields.IP() + assert isinstance(field_compressed.serialize("ipv4", user), str) + assert field_compressed.serialize("ipv4", user) == ipv4_string + assert isinstance(field_compressed.serialize("ipv6", user), str) + assert field_compressed.serialize("ipv6", user) == ipv6_string + assert field_compressed.serialize("empty_ip", user) is None + + field_exploded = fields.IP(exploded=True) + assert isinstance(field_exploded.serialize("ipv6", user), str) + assert field_exploded.serialize("ipv6", user) == ipv6_exploded_string + + def test_ipv4_address_field(self, user): + + ipv4_string = "192.168.0.1" + + user.ipv4 = ipaddress.ip_address(ipv4_string) + user.empty_ip = None + + field = fields.IPv4() + assert isinstance(field.serialize("ipv4", user), str) + assert field.serialize("ipv4", user) == ipv4_string + assert field.serialize("empty_ip", user) is None + + def test_ipv6_address_field(self, user): + + ipv6_string = "ffff::ffff" + ipv6_exploded_string = ipaddress.ip_address("ffff::ffff").exploded + + user.ipv6 = ipaddress.ip_address(ipv6_string) + user.empty_ip = None + + field_compressed = fields.IPv6() + assert isinstance(field_compressed.serialize("ipv6", user), str) + assert field_compressed.serialize("ipv6", user) == ipv6_string + assert field_compressed.serialize("empty_ip", user) is None + + field_exploded = fields.IPv6(exploded=True) + assert isinstance(field_exploded.serialize("ipv6", user), str) + assert field_exploded.serialize("ipv6", user) == ipv6_exploded_string + def test_decimal_field(self, user): user.m1 = 12 user.m2 = "12.355"