Fix for CONPY-48:

According to PEP-249 parameters for cursor methods execute() and executemany() are passed as a sequence.
The current implementation accepted Tuple only, this fix also allows List as parameter.

Valid examples:

cursor.execute("SELECT %s", [1])
cursor.execute("SELECT %s", (1,))

cursor.executemany("INSERT INTO t1 VALUES (%s)", [[1],[2]])
cursor.executemany("INSERT INTO t1 VALUES (%s)", [(1,),(2,)])
cursor.executemany("INSERT INTO t1 VALUES (%s)", [[1],(2,)])
This commit is contained in:
Georg Richter
2020-04-03 18:36:54 +02:00
parent c96cb47825
commit 434a490539
3 changed files with 48 additions and 4 deletions

View File

@ -718,6 +718,19 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)
return 1;
}
static PyObject *ListOrTuple_GetItem(PyObject *obj, Py_ssize_t index)
{
if (Py_TYPE(obj) == &PyList_Type)
{
return PyList_GetItem(obj, index);
} else if (Py_TYPE(obj) == &PyTuple_Type)
{
return PyTuple_GetItem(obj, index);
}
/* this should never happen, since the type was checked before */
return NULL;
}
/*
mariadb_get_parameter()
@ -742,6 +755,7 @@ mariadb_get_parameter(MrdbCursor *self,
PyObject *row= NULL,
*column= NULL;
if (is_bulk)
{
/* check if row_nr and column_nr are in the range from
@ -769,7 +783,7 @@ mariadb_get_parameter(MrdbCursor *self,
if (self->parser->paramstyle != PYFORMAT)
{
if (!(column= PyTuple_GetItem(row, column_nr)))
if (!(column= ListOrTuple_GetItem(row, column_nr)))
{
mariadb_throw_exception(self->stmt, Mariadb_DataError, 0,
"Can't access column number %d at row %d",
@ -935,6 +949,19 @@ mariadb_get_parameter_info(MrdbCursor *self,
return 0;
}
static Py_ssize_t ListOrTuple_Size(PyObject *obj)
{
if (Py_TYPE(obj) == &PyList_Type)
{
return PyList_Size(obj);
} else if (Py_TYPE(obj) == &PyTuple_Type)
{
return PyTuple_Size(obj);
}
/* this should never happen, since the type was checked before */
return 0;
}
/* mariadb_check_bulk_parameters
This function validates the specified bulk parameters and
translates the field types to MYSQL_TYPE_*.
@ -956,7 +983,8 @@ mariadb_check_bulk_parameters(MrdbCursor *self,
{
PyObject *obj= PyList_GetItem(data, i);
if (self->parser->paramstyle != PYFORMAT &&
Py_TYPE(obj) != &PyTuple_Type)
(Py_TYPE(obj) != &PyTuple_Type &&
Py_TYPE(obj) != &PyList_Type))
{
mariadb_throw_exception(NULL, Mariadb_DataError, 0,
"Invalid parameter type in row %d. "\
@ -977,7 +1005,7 @@ mariadb_check_bulk_parameters(MrdbCursor *self,
if (!self->param_count ||
(self->parser->paramstyle != PYFORMAT &&
self->param_count != PyTuple_Size(obj)))
self->param_count != ListOrTuple_Size(obj)))
{
mariadb_throw_exception(self->stmt, Mariadb_DataError, 1,
"Invalid number of parameters in row %d", i+1);

View File

@ -662,7 +662,7 @@ PyObject *MrdbCursor_execute(MrdbCursor *self,
goto error;
}
}
else if (Py_TYPE(Data) != &PyTuple_Type)
else if (Py_TYPE(Data) != &PyTuple_Type && Py_TYPE(Data) != &PyList_Type)
{
PyErr_SetString(PyExc_TypeError, "argument 2 must be tuple!");
goto error;

View File

@ -786,6 +786,22 @@ class TestCursor(unittest.TestCase):
self.assertEqual(row[0], 0)
del con
def test_conpy49(self):
con= create_connection()
cur=con.cursor()
cur.execute("select %s", [True])
row= cur.fetchone()
self.assertEqual(row[0], 1)
cur.execute("create temporary table t1 (a int)")
cur.executemany("insert into t1 values (%s)", [[1],(2,)])
cur.execute("select a from t1")
row= cur.fetchone()
self.assertEqual(row[0], 1)
row= cur.fetchone()
self.assertEqual(row[0], 2)
del con
if __name__ == '__main__':
unittest.main()