diff --git a/include/mariadb_python.h b/include/mariadb_python.h index 0381a49..ebbadb0 100644 --- a/include/mariadb_python.h +++ b/include/mariadb_python.h @@ -273,6 +273,8 @@ extern PyObject *Mariadb_NotSupportedError; extern PyObject *Mariadb_Warning; extern PyObject *Mrdb_Pickle; +extern PyObject *decimal_module, + *decimal_type; /* Object types */ extern PyTypeObject MrdbPool_Type; diff --git a/src/mariadb.c b/src/mariadb.c index dd14f1d..3813c79 100644 --- a/src/mariadb.c +++ b/src/mariadb.c @@ -26,6 +26,8 @@ PyObject *Mrdb_Pickle= NULL; PyObject *cnx_pool= NULL; +PyObject *decimal_module= NULL, + *decimal_type= NULL; extern uint16_t max_pool_size; int @@ -146,6 +148,13 @@ PyMODINIT_FUNC PyInit_mariadb(void) goto error; } + /* Import Decimal support (CONPY-49) */ + if (!(decimal_module= PyImport_ImportModule("decimal")) || + !(decimal_type= PyObject_GetAttrString(decimal_module, "Decimal"))) + { + goto error; + } + /* we need pickle for object serialization */ Mrdb_Pickle= PyImport_ImportModule("pickle"); diff --git a/src/mariadb_codecs.c b/src/mariadb_codecs.c index 445cd07..cca3256 100644 --- a/src/mariadb_codecs.c +++ b/src/mariadb_codecs.c @@ -413,12 +413,19 @@ field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column) (Py_ssize_t)length[column]); } break; + case MYSQL_TYPE_NEWDECIMAL: + { + PyObject *decimal; + + decimal= PyObject_CallFunction(decimal_type, "s", (const char *)data); + self->values[column]= decimal; + break; + } case MYSQL_TYPE_STRING: case MYSQL_TYPE_VAR_STRING: case MYSQL_TYPE_JSON: case MYSQL_TYPE_VARCHAR: case MYSQL_TYPE_DECIMAL: - case MYSQL_TYPE_NEWDECIMAL: case MYSQL_TYPE_SET: case MYSQL_TYPE_ENUM: { @@ -703,12 +710,17 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo) } else if (PyDateTime_CheckExact(obj)) { paraminfo->type= MYSQL_TYPE_DATETIME; return 0; - } else if (Py_TYPE(obj) == &PyUnicode_Type) { + } else if (Py_TYPE(obj) == &PyUnicode_Type) { paraminfo->type= MYSQL_TYPE_VAR_STRING; return 0; } else if (obj == Py_None) { paraminfo->type= MYSQL_TYPE_NULL; return 0; + } else if (!strcmp(Py_TYPE(obj)->tp_name, "decimal.Decimal")) { + /* CINPY-49: C-API has no correspondent data type for DECUMAL column type, + so we need to convert decimal.Decimal Object to string during callback */ + paraminfo->type= MYSQL_TYPE_NEWDECIMAL; + return 0; } else { /* no corresponding object, return error */ @@ -868,7 +880,7 @@ mariadb_get_parameter_info(MrdbCursor *self, if (rc == 2) { mariadb_throw_exception(NULL, Mariadb_DataError, 0, - "Data type '%s' in column %d not supported in MariaDB", + "Data type '%s' in column %d not supported in MariaDB Connector/Python", Py_TYPE(paramvalue.value)->tp_name, column_nr); } @@ -1174,9 +1186,19 @@ mariadb_param_to_bind(MYSQL_BIND *bind, bind->buffer= &value->tm; mariadb_pydate_to_tm(bind->buffer_type, value->value, &value->tm); break; + case MYSQL_TYPE_NEWDECIMAL: + { + Py_ssize_t len; + PyObject *obj= PyObject_Str(value->value); + bind->buffer= (void *)PyUnicode_AsUTF8AndSize(obj, &len); + bind->buffer_length= (unsigned long)len; + Py_DECREF(obj); + } + break; case MYSQL_TYPE_VAR_STRING: { Py_ssize_t len; + bind->buffer= (void *)PyUnicode_AsUTF8AndSize(value->value, &len); bind->buffer_length= (unsigned long)len; break; diff --git a/src/mariadb_cursor.c b/src/mariadb_cursor.c index a751273..f6876e9 100644 --- a/src/mariadb_cursor.c +++ b/src/mariadb_cursor.c @@ -726,6 +726,7 @@ PyObject *MrdbCursor_execute(MrdbCursor *self, /* we need to clear the result first, otherwise the cursor remains in usuable state (query out of order) */ + if ((result= mysql_store_result(self->stmt->mysql))) mysql_free_result(result); diff --git a/test/integration/test_cursor.py b/test/integration/test_cursor.py index 433b27e..5a25cbd 100644 --- a/test/integration/test_cursor.py +++ b/test/integration/test_cursor.py @@ -5,6 +5,7 @@ import collections import datetime import unittest import os +from decimal import Decimal import mariadb @@ -786,7 +787,7 @@ class TestCursor(unittest.TestCase): self.assertEqual(row[0], 0) del con - def test_conpy49(self): + def test_conpy48(self): con= create_connection() cur=con.cursor() cur.execute("select %s", [True]) @@ -834,6 +835,15 @@ class TestCursor(unittest.TestCase): cur.execute("drop table if exists temp") del con + def test_conpy49(self): + con= create_connection() + cur=con.cursor() + cur.execute("create temporary table t1 (a decimal(10,2))") + cur.execute("insert into t1 values (?)", (Decimal('10.2'),)) + cur.execute("select a from t1") + row=cur.fetchone() + self.assertEqual(row[0], Decimal('10.20')) + del con if __name__ == '__main__': unittest.main()