Skip to content

Commit

Permalink
Merge pull request #1485 from mgetka/ip-address-field
Browse files Browse the repository at this point in the history
Add IP address field type.
  • Loading branch information
lafrech authored Sep 16, 2020
2 parents 2708795 + 1a7d92d commit a857318
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 0 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,4 @@ Contributors (chronological)
- Juan Norris `@juannorris <https://github.com/juannorris>`_
- 장준영 `@jun0jang <https://github.com/jun0jang>`_
- `@ebargtuo <https://github.com/ebargtuo>`_
- Michał Getka `@mgetka <https://github.com/mgetka>`_
54 changes: 54 additions & 0 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import datetime as dt
import numbers
import uuid
import ipaddress
import decimal
import math
import typing
Expand Down Expand Up @@ -51,6 +52,9 @@
"Url",
"URL",
"Email",
"IP",
"IPv4",
"IPv6",
"Method",
"Function",
"Str",
Expand Down Expand Up @@ -1625,6 +1629,56 @@ def __init__(self, *args, **kwargs):
self.validators.insert(0, validator)


class IP(Field):
"""A IP address field.
:param bool exploded: If `True`, serialize ipv6 address in long form, ie. with groups
consisting entirely of zeros included."""

default_error_messages = {"invalid_ip": "Not a valid IP address."}

DESERIALIZATION_CLASS = None # type: typing.Optional[typing.Type]

def __init__(self, *args, exploded=False, **kwargs):
super().__init__(*args, **kwargs)
self.exploded = exploded

def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
if value is None:
return None
if self.exploded:
return value.exploded
return value.compressed

def _deserialize(
self, value, attr, data, **kwargs
) -> typing.Optional[typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]:
if value is None:
return None
try:
return (self.DESERIALIZATION_CLASS or ipaddress.ip_address)(
utils.ensure_text_type(value)
)
except (ValueError, TypeError) as error:
raise self.make_error("invalid_ip") from error


class IPv4(IP):
"""A IPv4 address field."""

default_error_messages = {"invalid_ip": "Not a valid IPv4 address."}

DESERIALIZATION_CLASS = ipaddress.IPv4Address


class IPv6(IP):
"""A IPv6 address field."""

default_error_messages = {"invalid_ip": "Not a valid IPv6 address."}

DESERIALIZATION_CLASS = ipaddress.IPv6Address


class Method(Field):
"""A field that takes the value returned by a `Schema` method.
Expand Down
58 changes: 58 additions & 0 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime as dt
import uuid
import ipaddress
import decimal
import math

Expand Down Expand Up @@ -850,6 +851,63 @@ def test_invalid_uuid_deserialization(self, in_value):

assert excinfo.value.args[0] == "Not a valid UUID."

def test_ip_field_deserialization(self):
field = fields.IP()
ipv4_str = "140.82.118.3"
result = field.deserialize(ipv4_str)
assert isinstance(result, ipaddress.IPv4Address)
assert str(result) == ipv4_str

ipv6_str = "2a00:1450:4001:824::200e"
result = field.deserialize(ipv6_str)
assert isinstance(result, ipaddress.IPv6Address)
assert str(result) == ipv6_str

@pytest.mark.parametrize(
"in_value", ["malformed", 123, b"\x01\x02\03", "192.168", "ff::aa:1::2"]
)
def test_invalid_ip_deserialization(self, in_value):
field = fields.IP()
with pytest.raises(ValidationError) as excinfo:
field.deserialize(in_value)

assert excinfo.value.args[0] == "Not a valid IP address."

def test_ipv4_field_deserialization(self):
field = fields.IPv4()
ipv4_str = "140.82.118.3"
result = field.deserialize(ipv4_str)
assert isinstance(result, ipaddress.IPv4Address)
assert str(result) == ipv4_str

@pytest.mark.parametrize(
"in_value",
["malformed", 123, b"\x01\x02\03", "192.168", "2a00:1450:4001:81d::200e"],
)
def test_invalid_ipv4_deserialization(self, in_value):
field = fields.IPv4()
with pytest.raises(ValidationError) as excinfo:
field.deserialize(in_value)

assert excinfo.value.args[0] == "Not a valid IPv4 address."

def test_ipv6_field_deserialization(self):
field = fields.IPv6()
ipv6_str = "2a00:1450:4001:824::200e"
result = field.deserialize(ipv6_str)
assert isinstance(result, ipaddress.IPv6Address)
assert str(result) == ipv6_str

@pytest.mark.parametrize(
"in_value", ["malformed", 123, b"\x01\x02\03", "ff::aa:1::2", "192.168.0.1"]
)
def test_invalid_ipv6_deserialization(self, in_value):
field = fields.IPv6()
with pytest.raises(ValidationError) as excinfo:
field.deserialize(in_value)

assert excinfo.value.args[0] == "Not a valid IPv6 address."

def test_deserialization_function_must_be_callable(self):
with pytest.raises(ValueError):
fields.Function(lambda x: None, deserialize="notvalid")
Expand Down
51 changes: 51 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import itertools
import decimal
import uuid
import ipaddress

import pytest

Expand Down Expand Up @@ -143,6 +144,56 @@ def test_uuid_field(self, user):
assert field.serialize("uuid1", user) == "12345678-1234-5678-1234-567812345678"
assert field.serialize("uuid2", user) is None

def test_ip_address_field(self, user):

ipv4_string = "192.168.0.1"
ipv6_string = "ffff::ffff"
ipv6_exploded_string = ipaddress.ip_address("ffff::ffff").exploded

user.ipv4 = ipaddress.ip_address(ipv4_string)
user.ipv6 = ipaddress.ip_address(ipv6_string)
user.empty_ip = None

field_compressed = fields.IP()
assert isinstance(field_compressed.serialize("ipv4", user), str)
assert field_compressed.serialize("ipv4", user) == ipv4_string
assert isinstance(field_compressed.serialize("ipv6", user), str)
assert field_compressed.serialize("ipv6", user) == ipv6_string
assert field_compressed.serialize("empty_ip", user) is None

field_exploded = fields.IP(exploded=True)
assert isinstance(field_exploded.serialize("ipv6", user), str)
assert field_exploded.serialize("ipv6", user) == ipv6_exploded_string

def test_ipv4_address_field(self, user):

ipv4_string = "192.168.0.1"

user.ipv4 = ipaddress.ip_address(ipv4_string)
user.empty_ip = None

field = fields.IPv4()
assert isinstance(field.serialize("ipv4", user), str)
assert field.serialize("ipv4", user) == ipv4_string
assert field.serialize("empty_ip", user) is None

def test_ipv6_address_field(self, user):

ipv6_string = "ffff::ffff"
ipv6_exploded_string = ipaddress.ip_address("ffff::ffff").exploded

user.ipv6 = ipaddress.ip_address(ipv6_string)
user.empty_ip = None

field_compressed = fields.IPv6()
assert isinstance(field_compressed.serialize("ipv6", user), str)
assert field_compressed.serialize("ipv6", user) == ipv6_string
assert field_compressed.serialize("empty_ip", user) is None

field_exploded = fields.IPv6(exploded=True)
assert isinstance(field_exploded.serialize("ipv6", user), str)
assert field_exploded.serialize("ipv6", user) == ipv6_exploded_string

def test_decimal_field(self, user):
user.m1 = 12
user.m2 = "12.355"
Expand Down

0 comments on commit a857318

Please sign in to comment.