diff --git a/src/cursor.cpp b/src/cursor.cpp index 07df9fcf..2b9ba96c 100644 --- a/src/cursor.cpp +++ b/src/cursor.cpp @@ -211,7 +211,7 @@ static bool create_name_map(Cursor* cur, SQLSMALLINT field_count, bool lower) TRACE("Col %d: type=%s (%d) colsize=%d\n", (i+1), SqlTypeName(nDataType), (int)nDataType, (int)nColSize); - Object name(TextBufferToObject(enc, szName, cbName)); + Object name(TextBufferToObject(enc, (byte*)szName, cbName)); if (!name) goto done; diff --git a/src/decimal.cpp b/src/decimal.cpp new file mode 100644 index 00000000..9e08d326 --- /dev/null +++ b/src/decimal.cpp @@ -0,0 +1,157 @@ + +#include "pyodbc.h" +#include "wrapper.h" +#include "textenc.h" +#include "decimal.h" + +static PyObject* decimal = 0; +// The Decimal constructor. + +static PyObject* re_sub = 0; +static PyObject* re_compile = 0; +static PyObject* re_escape = 0; + +// In Python 2.7, the 3 strings below are bytes objects. In 3.x they are Unicode objects. + + +static PyObject* pDecimalPoint = 0; +// A "." object which we replace pLocaleDecimal with if they are not the same. + +static PyObject* pLocaleDecimal = 0; +// The decimal character used by the locale. This can be overridden by the user. +// +// In 2.7 this is a bytes object, otherwise unicode. + +static PyObject* pLocaleDecimalEscaped = 0; +// A version of pLocaleDecimal escaped to be used in a regular expression. (The character +// could be something special in regular expressions.) This is zero when pLocaleDecimal is +// ".", indicating no replacement is necessary. + +static PyObject* pRegExpRemove = 0; +// A regular expression that matches characters we want to remove before parsing. + + +bool InitializeDecimal() { + // This is called when the module is initialized and creates globals. + + Object d(PyImport_ImportModule("decimal")); + decimal = PyObject_GetAttrString(d, "Decimal"); + if (!decimal) + return 0; + Object re(PyImport_ImportModule("re")); + re_sub = PyObject_GetAttrString(re, "sub"); + re_escape = PyObject_GetAttrString(re, "escape"); + re_compile = PyObject_GetAttrString(re, "compile"); + + Object module(PyImport_ImportModule("locale")); + Object ldict(PyObject_CallMethod(module, "localeconv", 0)); + Object point(PyDict_GetItemString(ldict, "decimal_point")); + + if (!point) + return false; + +#if PY_MAJOR_VERSION >= 3 + pDecimalPoint = PyUnicode_FromString("."); +#else + pDecimalPoint = PyBytes_FromString("."); +#endif + + if (!pDecimalPoint) + return false; + +#if PY_MAJOR_VERSION >= 3 + if (!SetDecimalPoint(point)) + return false; +#else + // In 2.7, we only support non-Unicode right now. + if (PyBytes_Check(point)) + if (!SetDecimalPoint(point)) + return false; +#endif + + return true; +} + +PyObject* GetDecimalPoint() { + Py_INCREF(pLocaleDecimal); + return pLocaleDecimal; +} + +bool SetDecimalPoint(PyObject* pNew) +{ + if (PyObject_RichCompareBool(pNew, pDecimalPoint, Py_EQ) == 1) + { + // They are the same. + Py_XDECREF(pLocaleDecimal); + pLocaleDecimal = pDecimalPoint; + Py_INCREF(pLocaleDecimal); + + Py_XDECREF(pLocaleDecimalEscaped); + pLocaleDecimalEscaped = 0; + } + else + { + // They are different, so we'll need a regular expression to match it so it can be + // replaced in getdata GetDataDecimal. + + Py_XDECREF(pLocaleDecimal); + pLocaleDecimal = pNew; + Py_INCREF(pLocaleDecimal); + + Object e(PyObject_CallFunctionObjArgs(re_escape, pNew, 0)); + if (!e) + return false; + + Py_XDECREF(pLocaleDecimalEscaped); + pLocaleDecimalEscaped = e.Detach(); + } + +#if PY_MAJOR_VERSION >= 3 + Object s(PyUnicode_FromFormat("[^0-9%U-]+", pLocaleDecimal)); +#else + Object s(PyBytes_FromFormat("[^0-9%s-]+", PyString_AsString(pLocaleDecimal))); +#endif + if (!s) + return false; + + Object r(PyObject_CallFunctionObjArgs(re_compile, s.Get(), 0)); + if (!r) + return false; + + Py_XDECREF(pRegExpRemove); + pRegExpRemove = r.Detach(); + + return true; +} + + +PyObject* DecimalFromText(const TextEnc& enc, const byte* pb, Py_ssize_t cb) +{ + // Creates a Decimal object from a text buffer. + + // The Decimal constructor requires the decimal point to be '.', so we need to convert the + // locale's decimal to it. We also need to remove non-decimal characters such as thousands + // separators and currency symbols. + // + // Remember that the thousands separate will often be '.', so have to do this carefully. + // We'll create a regular expression with 0-9 and whatever the thousands separator is. + + Object text(TextBufferToObject(enc, pb, cb)); + if (!text) + return 0; + + Object cleaned = PyObject_CallMethod(pRegExpRemove, "sub", "sO", "", text.Get()); + if (!cleaned) + return 0; + + if (pLocaleDecimalEscaped) + { + Object c2(PyObject_CallFunctionObjArgs(re_sub, pLocaleDecimalEscaped, pDecimalPoint, 0)); + if (!c2) + return 0; + cleaned.Attach(c2.Detach()); + } + + PyObject* result = PyObject_CallFunctionObjArgs(decimal, cleaned.Get(), 0); + return result; +} diff --git a/src/decimal.h b/src/decimal.h new file mode 100644 index 00000000..32af7122 --- /dev/null +++ b/src/decimal.h @@ -0,0 +1,7 @@ +#pragma once + +bool InitializeDecimal(); +PyObject* GetDecimalPoint(); +bool SetDecimalPoint(PyObject* pNew); + +PyObject* DecimalFromText(const TextEnc& enc, const byte* pb, Py_ssize_t cb); diff --git a/src/getdata.cpp b/src/getdata.cpp index 7982783b..cd664a83 100644 --- a/src/getdata.cpp +++ b/src/getdata.cpp @@ -10,6 +10,7 @@ #include "connection.h" #include "errors.h" #include "dbspecific.h" +#include "decimal.h" #include #include @@ -381,86 +382,11 @@ static PyObject* GetDataDecimal(Cursor* cur, Py_ssize_t iCol) Py_RETURN_NONE; } - Object result(TextBufferToObject(enc, pbData, cbData)); + Object result(DecimalFromText(enc, pbData, cbData)); pyodbc_free(pbData); - if (!result) - return 0; - - // Remove non-digits and convert the databases decimal to a '.' (required by decimal ctor). - // - // We are assuming that the decimal point and digits fit within the size of ODBCCHAR. - - // If Unicode, convert to UTF-8 and copy the digits and punctuation out. Since these are - // all ASCII characters, we can ignore any multiple-byte characters. Fortunately, if a - // character is multi-byte all bytes will have the high bit set. - - char* pch; - Py_ssize_t cch; - -#if PY_MAJOR_VERSION >= 3 - if (PyUnicode_Check(result)) - { - pch = (char*)PyUnicode_AsUTF8AndSize(result, &cch); - } - else - { - int n = PyBytes_AsStringAndSize(result, &pch, &cch); - if (n < 0) - pch = 0; - } -#else - Object encoded; - if (PyUnicode_Check(result)) - { - encoded = PyUnicode_AsUTF8String(result); - if (!encoded) - return 0; - result = encoded.Detach(); - } - int n = PyString_AsStringAndSize(result, &pch, &cch); - if (n < 0) - pch = 0; -#endif - - if (!pch) - return 0; - - // TODO: Why is this limited to 100? Also, can we perform a check on the original and use - // it as-is? - char ascii[100]; - size_t asciilen = 0; - - const char* pchMax = pch + cch; - while (pch < pchMax) - { - if ((*pch & 0x80) == 0) - { - if (*pch == chDecimal) - { - // Must force it to use '.' since the Decimal class doesn't pay attention to the locale. - ascii[asciilen++] = '.'; - } - else if ((*pch >= '0' && *pch <= '9') || *pch == '-') - { - ascii[asciilen++] = (char)(*pch); - } - } - pch++; - } - - ascii[asciilen] = 0; - - Object str(PyString_FromStringAndSize(ascii, (Py_ssize_t)asciilen)); - if (!str) - return 0; - PyObject* decimal_type = GetClassForThread("decimal", "Decimal"); - if (!decimal_type) - return 0; - PyObject* decimal = PyObject_CallFunction(decimal_type, "O", str.Get()); - Py_DECREF(decimal_type); - return decimal; + return result.Detach(); } static PyObject* GetDataBit(Cursor* cur, Py_ssize_t iCol) @@ -875,4 +801,4 @@ PyObject* GetData(Cursor* cur, Py_ssize_t iCol) return RaiseErrorV("HY106", ProgrammingError, "ODBC SQL type %d is not yet supported. column-index=%zd type=%d", (int)pinfo->sql_type, iCol, (int)pinfo->sql_type); -} \ No newline at end of file +} diff --git a/src/pyodbcmodule.cpp b/src/pyodbcmodule.cpp index f30dfd18..d0b69c58 100644 --- a/src/pyodbcmodule.cpp +++ b/src/pyodbcmodule.cpp @@ -20,6 +20,7 @@ #include "cnxninfo.h" #include "params.h" #include "dbspecific.h" +#include "decimal.h" #include #include @@ -152,9 +153,6 @@ bool UseNativeUUID() HENV henv = SQL_NULL_HANDLE; -Py_UNICODE chDecimal = '.'; - - PyObject* GetClassForThread(const char* szModule, const char* szClass) { // Returns the given class, specific to the current thread's interpreter. For performance @@ -249,36 +247,6 @@ bool IsInstanceForThread(PyObject* param, const char* szModule, const char* szCl } -// Initialize the global decimal character and thousands separator character, used when parsing decimal -// objects. -// -static void init_locale_info() -{ - Object module(PyImport_ImportModule("locale")); - if (!module) - { - PyErr_Clear(); - return; - } - - Object ldict(PyObject_CallMethod(module, "localeconv", 0)); - if (!ldict) - { - PyErr_Clear(); - return; - } - - PyObject* value = PyDict_GetItemString(ldict, "decimal_point"); - if (value) - { - if (PyBytes_Check(value) && PyBytes_Size(value) == 1) - chDecimal = (Py_UNICODE)PyBytes_AS_STRING(value)[0]; - if (PyUnicode_Check(value) && PyUnicode_GET_SIZE(value) == 1) - chDecimal = PyUnicode_AS_UNICODE(value)[0]; - } -} - - static bool import_types() { // Note: We can only import types from C extensions since they are shared among all @@ -300,6 +268,8 @@ static bool import_types() GetData_init(); if (!Params_init()) return false; + if (!InitializeDecimal()) + return false; return true; } @@ -708,24 +678,25 @@ static PyObject* mod_timestampfromticks(PyObject* self, PyObject* args) static PyObject* mod_setdecimalsep(PyObject* self, PyObject* args) { UNUSED(self); - if (!PyString_Check(PyTuple_GET_ITEM(args, 0)) && !PyUnicode_Check(PyTuple_GET_ITEM(args, 0))) - return PyErr_Format(PyExc_TypeError, "argument 1 must be a string or unicode object"); - PyObject* value = PyUnicode_FromObject(PyTuple_GetItem(args, 0)); - if (value) - { - if (PyBytes_Check(value) && PyBytes_Size(value) == 1) - chDecimal = (Py_UNICODE)PyBytes_AS_STRING(value)[0]; - if (PyUnicode_Check(value) && PyUnicode_GET_SIZE(value) == 1) - chDecimal = PyUnicode_AS_UNICODE(value)[0]; - } +#if PY_MAJOR_VERSION >= 3 + const char* type = "U"; +#else + const char* type = "S"; +#endif + + PyObject* p; + if (!PyArg_ParseTuple(args, type, &p)) + return 0; + if (!SetDecimalPoint(p)) + return 0; Py_RETURN_NONE; } static PyObject* mod_getdecimalsep(PyObject* self) { UNUSED(self); - return PyUnicode_FromUnicode(&chDecimal, 1); + return GetDecimalPoint(); } static char connect_doc[] = @@ -1245,8 +1216,6 @@ initpyodbc(void) if (!module || !import_types() || !CreateExceptions()) return MODRETURN(0); - init_locale_info(); - const char* szVersion = TOSTRING(PYODBC_VERSION); PyModule_AddStringConstant(module, "version", (char*)szVersion); diff --git a/src/pyodbcmodule.h b/src/pyodbcmodule.h index d0566a7b..2e4e7c76 100644 --- a/src/pyodbcmodule.h +++ b/src/pyodbcmodule.h @@ -56,8 +56,6 @@ inline bool lowercase() return PyObject_GetAttrString(pModule, "lowercase") == Py_True; } -extern Py_UNICODE chDecimal; - bool UseNativeUUID(); // Returns True if pyodbc.native_uuid is true, meaning uuid.UUID objects should be returned. diff --git a/src/textenc.cpp b/src/textenc.cpp index 5553baf5..ca209c2d 100644 --- a/src/textenc.cpp +++ b/src/textenc.cpp @@ -148,7 +148,7 @@ PyObject* EncodeStr(PyObject* str, const TextEnc& enc) } #endif -PyObject* TextBufferToObject(const TextEnc& enc, void* pbData, Py_ssize_t cbData) +PyObject* TextBufferToObject(const TextEnc& enc, const byte* pbData, Py_ssize_t cbData) { // cbData // The length of data in bytes (cb == 'count of bytes'). diff --git a/src/textenc.h b/src/textenc.h index a4ca3463..1a8bed91 100644 --- a/src/textenc.h +++ b/src/textenc.h @@ -139,7 +139,7 @@ struct SQLWChar }; -PyObject* TextBufferToObject(const TextEnc& enc, void* p, Py_ssize_t len); +PyObject* TextBufferToObject(const TextEnc& enc, const byte* p, Py_ssize_t len); // Convert a text buffer to a Python object using the given encoding. // // The buffer can be a SQLCHAR array or SQLWCHAR array. The text encoding diff --git a/tests3/pgtests.py b/tests3/pgtests.py old mode 100644 new mode 100755 index 8724f884..1c2bb406 --- a/tests3/pgtests.py +++ b/tests3/pgtests.py @@ -108,14 +108,14 @@ def tearDown(self): # If we've already closed the cursor or connection, exceptions are thrown. pass - def _simpletest(datatype, value): + def _simpletest(datatype, inval): # A simple test that can be used for any data type where the Python # type we write is also what we expect to receive. def _t(self): - self.cursor.execute('create table t1(value %s)' % datatype) - self.cursor.execute('insert into t1 values (?)', value) - result = self.cursor.execute("select value from t1").fetchone()[0] - self.assertEqual(result, value) + self.cursor.execute('create table t1(inval %s)' % datatype) + self.cursor.execute('insert into t1 values (?)', inval) + outval = self.cursor.execute("select inval from t1").fetchone()[0] + self.assertEqual(outval, inval) return _t def test_drivers(self): @@ -267,6 +267,11 @@ def test_large_bytea_array(self): name = value.replace('.', '_').replace('-', 'neg_') locals()['test_numeric_%s' % name] = _simpletest('numeric(20,6)', Decimal(value)) + def test_large_decimal(self): + # Version 4.0.35 had a buffer overflow here. + self.cursor.execute("SELECT 991113333311111333331111133333111113333311111333337711133333111113333311111333331111133333881113333321341235123512351123.1231245123512341241234::decimal AS n") + self.cursor.fetchone() + def test_small_decimal(self): value = Decimal('100010') # (I use this because the ODBC docs tell us how the bytes should look in the C struct) self.cursor.execute("create table t1(d numeric(19))") @@ -500,7 +505,7 @@ def test_executemany_failure(self): def test_row_slicing(self): - self.cursor.execute("create table t1(a int, b int, c int, d int)"); + self.cursor.execute("create table t1(a int, b int, c int, d int)") self.cursor.execute("insert into t1 values(1,2,3,4)") row = self.cursor.execute("select * from t1").fetchone()