Skip to content

Commit

Permalink
Merge pull request #341 from bpotard/master
Browse files Browse the repository at this point in the history
Fixes #339 and add support for object specific filtering
  • Loading branch information
freakboy3742 committed Sep 30, 2023
2 parents b64cea0 + e661f46 commit b617239
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 34 deletions.
85 changes: 77 additions & 8 deletions src/xero/basemanager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import io
import json
import requests
from datetime import datetime
from datetime import date, datetime
from urllib.parse import parse_qs
from uuid import UUID
from xml.etree.ElementTree import Element, SubElement, tostring
from xml.parsers.expat import ExpatError

Expand Down Expand Up @@ -40,6 +41,46 @@ class BaseManager:
"Invoices": ["email", "online_invoice"],
"Organisations": ["actions"],
}
OBJECT_FILTER_FIELDS = {
"Invoices": {
"createdByMyApp": bool,
"summaryOnly": bool,
"IDs": list,
"InvoiceNumbers": list,
"ContactIDs": list,
"Statuses": list,
},
"PurchaseOrders": {
"DateFrom": date,
"DateTo": date,
"Status": str,
},
"Quotes": {
"ContactID": UUID,
"ExpiryDateFrom": date,
"ExpiryDateTo": date,
"DateFrom": date,
"DateTo": date,
"Status": str,
"QuoteNumber": str,
},
"Journals": {
"paymentsOnly": bool,
},
"Budgets": {
"DateFrom": date,
"DateTo": date,
},
"Contacts": {
"IDs": list,
"includeArchived": bool,
"summaryOnly": bool,
"searchTerm": str,
},
"TrackingCategories": {
"includeArchived": bool,
},
}
DATETIME_FIELDS = (
"UpdatedDateUTC",
"Updated",
Expand Down Expand Up @@ -397,10 +438,19 @@ def _filter(self, **kwargs):
headers = self.prepare_filtering_date(val)
del kwargs["since"]

# Accept IDs parameter for Invoices and Contacts endpoints
if "IDs" in kwargs:
params["IDs"] = ",".join(kwargs["IDs"])
del kwargs["IDs"]
def get_filter_value(key, value, value_type=None):
if key in self.BOOLEAN_FIELDS or value_type == bool:
return "true" if value else "false"
elif key in self.DATE_FIELDS or value_type == date:
return f"{value.year}-{value.month}-{value.day}"
elif key in self.DATETIME_FIELDS or value_type == datetime:
return value.isoformat()
elif key.endswith("ID") or value_type == UUID:
return "%s" % (
value.hex if type(value) == UUID else UUID(value).hex
)
else:
return value

def get_filter_params(key, value):
last_key = key.split("_")[-1]
Expand Down Expand Up @@ -440,11 +490,30 @@ def generate_param(key, value):
field = field.replace("_", ".")
return fmt % (field, get_filter_params(key, value))

KNOWN_PARAMETERS = ["order", "offset", "page"]
object_params = self.OBJECT_FILTER_FIELDS.get(self.name, {})
LIST_PARAMETERS = list(
filter(lambda x: object_params[x] == list, object_params)
)
EXTRA_PARAMETERS = list(
filter(lambda x: object_params[x] != list, object_params)
)

# Move any known parameter names to the query string
KNOWN_PARAMETERS = ["order", "offset", "page", "includeArchived"]
for param in KNOWN_PARAMETERS:
for param in KNOWN_PARAMETERS + EXTRA_PARAMETERS:
if param in kwargs:
params[param] = get_filter_value(
param, kwargs.pop(param), object_params.get(param, None)
)
# Support xero optimised list filtering; validate IDs we send but may need other validation
for param in LIST_PARAMETERS:
if param in kwargs:
params[param] = kwargs.pop(param)
if param.endswith("IDs"):
params[param] = ",".join(
map(lambda x: UUID(x).hex, kwargs.pop(param))
)
else:
params[param] = ",".join(kwargs.pop(param))

filter_params = []

Expand Down
14 changes: 10 additions & 4 deletions src/xero/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,16 @@ def __init__(self, response):

elif response.headers["content-type"].startswith("text/html"):
payload = parse_qs(response.text)
self.errors = [payload["oauth_problem"][0]]
self.problem = self.errors[0]
super().__init__(response, payload["oauth_problem_advice"][0])

if payload:
self.errors = [payload["oauth_problem"][0]]
self.problem = self.errors[0]
super().__init__(response, payload["oauth_problem_advice"][0])
else:
# Sometimes xero returns the error message as pure text
# Not sure how to validate this is always the case
self.errors = [response.text]
self.problem = self.errors[0]
super().__init__(response, response.text)
else:
# Extract the messages from the text.
# parseString takes byte content, not unicode.
Expand Down
49 changes: 27 additions & 22 deletions tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class ManagerTest(unittest.TestCase):
def test_serializer(self):
credentials = Mock(base_url="")
manager = Manager("contacts", credentials)
manager = Manager("Invoice", credentials)

example_invoice_input = {
"Date": datetime.datetime(2015, 6, 6, 16, 25, 2, 711109),
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_serializer(self):

def test_serializer_phones_addresses(self):
credentials = Mock(base_url="")
manager = Manager("contacts", credentials)
manager = Manager("Contacts", credentials)

example_contact_input = {
"ContactID": "565acaa9-e7f3-4fbf-80c3-16b081ddae10",
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_serializer_phones_addresses(self):

def test_serializer_nested_singular(self):
credentials = Mock(base_url="")
manager = Manager("contacts", credentials)
manager = Manager("Invoice", credentials)

example_invoice_input = {
"Date": datetime.datetime(2015, 6, 6, 16, 25, 2, 711109),
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_serializer_nested_singular(self):
def test_filter(self):
"""The filter function should correctly handle various arguments"""
credentials = Mock(base_url="")
manager = Manager("contacts", credentials)
manager = Manager("Contacts", credentials)

uri, params, method, body, headers, singleobject = manager._filter(
order="LastName",
Expand Down Expand Up @@ -202,7 +202,7 @@ def test_filter(self):
self.assertEqual(params, {})
self.assertIsNone(headers)

manager = Manager("invoices", credentials)
manager = Manager("Invoices", credentials)
uri, params, method, body, headers, singleobject = manager._filter(
**{"Contact.ContactID": "3e776c4b-ea9e-4bb1-96be-6b0c7a71a37f"}
)
Expand All @@ -223,22 +223,27 @@ def test_filter(self):
def test_filter_ids(self):
"""The filter function should correctly handle various arguments"""
credentials = Mock(base_url="")
manager = Manager("contacts", credentials)
manager = Manager("Contacts", credentials)

uri, params, method, body, headers, singleobject = manager._filter(
IDs=["1", "2", "3", "4", "5"]
IDs=[
"3e776c4b-ea9e-4bb1-96be-6b0c7a71a37f",
"12345678901234567890123456789012",
]
)

self.assertEqual(method, "get")
self.assertFalse(singleobject)

expected_params = {"IDs": "1,2,3,4,5"}
expected_params = {
"IDs": "3e776c4bea9e4bb196be6b0c7a71a37f,12345678901234567890123456789012"
}
self.assertEqual(params, expected_params)

def test_rawfilter(self):
"""The filter function should correctly handle various arguments"""
credentials = Mock(base_url="")
manager = Manager("invoices", credentials)
manager = Manager("Invoices", credentials)
uri, params, method, body, headers, singleobject = manager._filter(
Status="VOIDED", raw='Name.ToLower()=="test contact"'
)
Expand All @@ -249,7 +254,7 @@ def test_rawfilter(self):
def test_boolean_filter(self):
"""The filter function should correctly handle various arguments"""
credentials = Mock(base_url="")
manager = Manager("invoices", credentials)
manager = Manager("Invoices", credentials)
uri, params, method, body, headers, singleobject = manager._filter(
CanApplyToRevenue=True
)
Expand All @@ -259,14 +264,14 @@ def test_magnitude_filters(self):
"""The filter function should correctlu handle date arguments and gt, lt operators"""
credentials = Mock(base_url="")

manager = Manager("invoices", credentials)
manager = Manager("Invoices", credentials)
uri, params, method, body, headers, singleobject = manager._filter(
**{"Date__gt": datetime.datetime(2007, 12, 6)}
)

self.assertEqual(params, {"where": "Date>DateTime(2007,12,6)"})

manager = Manager("invoices", credentials)
manager = Manager("Invoices", credentials)
uri, params, method, body, headers, singleobject = manager._filter(
**{"Date__lte": datetime.datetime(2007, 12, 6)}
)
Expand All @@ -279,25 +284,25 @@ def test_unit4dps(self):
credentials = Mock(base_url="")

# test 4dps is disabled by default
manager = Manager("contacts", credentials)
manager = Manager("Contacts", credentials)
uri, params, method, body, headers, singleobject = manager._filter()
self.assertEqual(params, {}, "test 4dps not enabled by default")

# test 4dps is enabled by default
manager = Manager("contacts", credentials, unit_price_4dps=True)
manager = Manager("Contacts", credentials, unit_price_4dps=True)
uri, params, method, body, headers, singleobject = manager._filter()
self.assertEqual(params, {"unitdp": 4}, "test 4dps can be enabled explicitly")

# test 4dps can be disable explicitly
manager = Manager("contacts", credentials, unit_price_4dps=False)
manager = Manager("Contacts", credentials, unit_price_4dps=False)
uri, params, method, body, headers, singleobject = manager._filter()
self.assertEqual(params, {}, "test 4dps can be disabled explicitly")

def test_get_params(self):
"""The 'get' methods should pass GET parameters if provided."""

credentials = Mock(base_url="")
manager = Manager("reports", credentials)
manager = Manager("Reports", credentials)

# test no parameters or headers sent by default
uri, params, method, body, headers, singleobject = manager._get("ProfitAndLoss")
Expand All @@ -314,7 +319,7 @@ def test_get_params(self):
self.assertEqual(params, passed_params, "test params can be set")

# test params respect, but can override, existing configuration
manager = Manager("reports", credentials, unit_price_4dps=True)
manager = Manager("Reports", credentials, unit_price_4dps=True)
uri, params, method, body, headers, singleobject = manager._get(
"ProfitAndLoss", params=passed_params
)
Expand All @@ -329,17 +334,17 @@ def test_user_agent_inheritance(self):

# Default used when no user_agent set on manager and credentials has nothing to offer.
credentials = Mock(base_url="", user_agent=None)
manager = Manager("reports", credentials)
manager = Manager("Reports", credentials)
self.assertTrue(manager.user_agent.startswith("pyxero/"))

# Taken from credentials when no user_agent set on manager.
credentials = Mock(base_url="", user_agent="MY_COMPANY-MY_CONSUMER_KEY")
manager = Manager("reports", credentials)
manager = Manager("Reports", credentials)
self.assertEqual(manager.user_agent, "MY_COMPANY-MY_CONSUMER_KEY")

# Manager's user_agent used when explicitly set.
credentials = Mock(base_url="", user_agent="MY_COMPANY-MY_CONSUMER_KEY")
manager = Manager("reports", credentials, user_agent="DemoCompany-1234567890")
manager = Manager("Reports", credentials, user_agent="DemoCompany-1234567890")
self.assertEqual(manager.user_agent, "DemoCompany-1234567890")

@patch("xero.basemanager.requests.post")
Expand All @@ -348,7 +353,7 @@ def test_request_content_type(self, request):

# Default used when no user_agent set on manager and credentials has nothing to offer.
credentials = Mock(base_url="", user_agent=None)
manager = Manager("reports", credentials)
manager = Manager("Reports", credentials)
try:
manager._get_data(lambda: ("_", {}, "post", {}, {}, True))()
except XeroExceptionUnknown:
Expand All @@ -362,7 +367,7 @@ def test_request_body_format(self):

# Default used when no user_agent set on manager and credentials has nothing to offer.
credentials = Mock(base_url="", user_agent=None)
manager = Manager("reports", credentials)
manager = Manager("Reports", credentials)

body = manager.save_or_put({"bing": "bong"})[3]

Expand Down

0 comments on commit b617239

Please sign in to comment.