Fix for CONPY-94:

Don't check for exact type but also for subtype.
This commit is contained in:
Georg Richter
2020-08-06 15:03:48 +02:00
parent cf90a9de5e
commit bfd71e2fa1
2 changed files with 20 additions and 5 deletions

View File

@ -19,6 +19,10 @@
#include "mariadb_python.h"
#include <datetime.h>
#define CHECK_TYPE(obj, type) \
(Py_TYPE((obj)) == type || \
PyType_IsSubtype(Py_TYPE((obj)), type))
/*
converts a Python date/time/datetime object to MYSQL_TIME
*/
@ -721,7 +725,7 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)
return 0;
}
if (Py_TYPE(obj) == &PyLong_Type)
if (CHECK_TYPE(obj, &PyLong_Type))
{
size_t b= _PyLong_NumBits(obj);
if (b > paraminfo->bits)
@ -730,13 +734,13 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)
paraminfo->is_negative= 1;
paraminfo->type= MYSQL_TYPE_LONGLONG;
return 0;
} else if (Py_TYPE(obj) == &PyBool_Type) {
} else if (CHECK_TYPE(obj, &PyBool_Type)) {
paraminfo->type= MYSQL_TYPE_TINY;
return 0;
} else if (Py_TYPE(obj) == &PyFloat_Type) {
} else if (CHECK_TYPE(obj, &PyFloat_Type)) {
paraminfo->type= MYSQL_TYPE_DOUBLE;
return 0;
} else if (Py_TYPE(obj) == &PyBytes_Type) {
} else if (CHECK_TYPE(obj, &PyBytes_Type)) {
paraminfo->type= MYSQL_TYPE_LONG_BLOB;
return 0;
} else if (PyDate_CheckExact(obj)) {
@ -748,7 +752,7 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)
} else if (PyDateTime_CheckExact(obj)) {
paraminfo->type= MYSQL_TYPE_DATETIME;
return 0;
} else if (Py_TYPE(obj) == &PyUnicode_Type) {
} else if (CHECK_TYPE(obj, &PyUnicode_Type)) {
paraminfo->type= MYSQL_TYPE_VAR_STRING;
return 0;
} else if (obj == Py_None) {

View File

@ -14,6 +14,8 @@ from test.base_test import create_connection
server_indicator_version= 100206
class foo(int):
def bar(self):pass
class TestCursor(unittest.TestCase):
@ -1003,5 +1005,14 @@ class TestCursor(unittest.TestCase):
self.assertEqual(row[0], 1);
del cur
def test_conpy94(self):
con= create_connection()
cur= con.cursor()
a= foo(2)
cur.execute("SELECT ?", (a,))
row= cur.fetchone()
self.assertEqual(row[0], 2)
del cur
if __name__ == '__main__':
unittest.main()