mirror of
https://github.com/mariadb-corporation/mariadb-connector-python.git
synced 2025-08-07 11:39:43 +00:00
Fix for CONPY-48:
According to PEP-249 parameters for cursor methods execute() and executemany() are passed as a sequence. The current implementation accepted Tuple only, this fix also allows List as parameter. Valid examples: cursor.execute("SELECT %s", [1]) cursor.execute("SELECT %s", (1,)) cursor.executemany("INSERT INTO t1 VALUES (%s)", [[1],[2]]) cursor.executemany("INSERT INTO t1 VALUES (%s)", [(1,),(2,)]) cursor.executemany("INSERT INTO t1 VALUES (%s)", [[1],(2,)])
This commit is contained in:
@ -718,6 +718,19 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)
|
||||
return 1;
|
||||
}
|
||||
|
||||
static PyObject *ListOrTuple_GetItem(PyObject *obj, Py_ssize_t index)
|
||||
{
|
||||
if (Py_TYPE(obj) == &PyList_Type)
|
||||
{
|
||||
return PyList_GetItem(obj, index);
|
||||
} else if (Py_TYPE(obj) == &PyTuple_Type)
|
||||
{
|
||||
return PyTuple_GetItem(obj, index);
|
||||
}
|
||||
/* this should never happen, since the type was checked before */
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
mariadb_get_parameter()
|
||||
@ -742,6 +755,7 @@ mariadb_get_parameter(MrdbCursor *self,
|
||||
PyObject *row= NULL,
|
||||
*column= NULL;
|
||||
|
||||
|
||||
if (is_bulk)
|
||||
{
|
||||
/* check if row_nr and column_nr are in the range from
|
||||
@ -769,7 +783,7 @@ mariadb_get_parameter(MrdbCursor *self,
|
||||
|
||||
if (self->parser->paramstyle != PYFORMAT)
|
||||
{
|
||||
if (!(column= PyTuple_GetItem(row, column_nr)))
|
||||
if (!(column= ListOrTuple_GetItem(row, column_nr)))
|
||||
{
|
||||
mariadb_throw_exception(self->stmt, Mariadb_DataError, 0,
|
||||
"Can't access column number %d at row %d",
|
||||
@ -935,6 +949,19 @@ mariadb_get_parameter_info(MrdbCursor *self,
|
||||
return 0;
|
||||
}
|
||||
|
||||
static Py_ssize_t ListOrTuple_Size(PyObject *obj)
|
||||
{
|
||||
if (Py_TYPE(obj) == &PyList_Type)
|
||||
{
|
||||
return PyList_Size(obj);
|
||||
} else if (Py_TYPE(obj) == &PyTuple_Type)
|
||||
{
|
||||
return PyTuple_Size(obj);
|
||||
}
|
||||
/* this should never happen, since the type was checked before */
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* mariadb_check_bulk_parameters
|
||||
This function validates the specified bulk parameters and
|
||||
translates the field types to MYSQL_TYPE_*.
|
||||
@ -956,7 +983,8 @@ mariadb_check_bulk_parameters(MrdbCursor *self,
|
||||
{
|
||||
PyObject *obj= PyList_GetItem(data, i);
|
||||
if (self->parser->paramstyle != PYFORMAT &&
|
||||
Py_TYPE(obj) != &PyTuple_Type)
|
||||
(Py_TYPE(obj) != &PyTuple_Type &&
|
||||
Py_TYPE(obj) != &PyList_Type))
|
||||
{
|
||||
mariadb_throw_exception(NULL, Mariadb_DataError, 0,
|
||||
"Invalid parameter type in row %d. "\
|
||||
@ -977,7 +1005,7 @@ mariadb_check_bulk_parameters(MrdbCursor *self,
|
||||
|
||||
if (!self->param_count ||
|
||||
(self->parser->paramstyle != PYFORMAT &&
|
||||
self->param_count != PyTuple_Size(obj)))
|
||||
self->param_count != ListOrTuple_Size(obj)))
|
||||
{
|
||||
mariadb_throw_exception(self->stmt, Mariadb_DataError, 1,
|
||||
"Invalid number of parameters in row %d", i+1);
|
||||
|
@ -662,7 +662,7 @@ PyObject *MrdbCursor_execute(MrdbCursor *self,
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
else if (Py_TYPE(Data) != &PyTuple_Type)
|
||||
else if (Py_TYPE(Data) != &PyTuple_Type && Py_TYPE(Data) != &PyList_Type)
|
||||
{
|
||||
PyErr_SetString(PyExc_TypeError, "argument 2 must be tuple!");
|
||||
goto error;
|
||||
|
@ -786,6 +786,22 @@ class TestCursor(unittest.TestCase):
|
||||
self.assertEqual(row[0], 0)
|
||||
del con
|
||||
|
||||
def test_conpy49(self):
|
||||
con= create_connection()
|
||||
cur=con.cursor()
|
||||
cur.execute("select %s", [True])
|
||||
row= cur.fetchone()
|
||||
self.assertEqual(row[0], 1)
|
||||
cur.execute("create temporary table t1 (a int)")
|
||||
cur.executemany("insert into t1 values (%s)", [[1],(2,)])
|
||||
cur.execute("select a from t1")
|
||||
row= cur.fetchone()
|
||||
self.assertEqual(row[0], 1)
|
||||
row= cur.fetchone()
|
||||
self.assertEqual(row[0], 2)
|
||||
del con
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user