Test fixes

This commit is contained in:
Georg Richter
2021-07-18 18:16:38 +02:00
parent 8ee02ac7e3
commit 2bd40ca1de
11 changed files with 169 additions and 972 deletions

View File

@ -41,8 +41,8 @@ class Connection(mariadb._mariadb.connection):
super().__init__(*args, **kwargs)
def cursor(self, **kwargs):
return mariadb.Cursor(self, **kwargs)
cursor= mariadb.Cursor(self, **kwargs)
return cursor
def close(self):
if self._Connection__pool:
@ -58,6 +58,9 @@ class Connection(mariadb._mariadb.connection):
"Closes connection."
self.close()
def get_server_version(self):
return self.server_version_info
@property
def character_set(self):
"""Client character set."""
@ -73,6 +76,11 @@ class Connection(mariadb._mariadb.connection):
"""Returns server status flags."""
return super()._server_status
@property
def server_version_info(self):
version= self.server_version
return (int(version / 10000), int((version % 10000) / 100), version % 100)
@property
def socket(self):
"""Returns the socket used for database connection"""

View File

@ -54,7 +54,6 @@ class Cursor(mariadb._mariadb.cursor):
self._description= None
self._transformed_statement= None
self._prepared= False
self._parsed= False
self._prev_stmt= None
self._force_binary= None
@ -155,11 +154,12 @@ class Cursor(mariadb._mariadb.cursor):
"""
logging.debug("parse_execute: %s" % statement)
if not statement:
raise mariadb.ProgrammingError("empty statement")
# parse statement
if self._prev_stmt != statement:
if self.statement != statement:
super()._parse(statement)
self._prev_stmt= statement
self._reprepare= True
@ -184,7 +184,8 @@ class Cursor(mariadb._mariadb.cursor):
place holder for execute() description
"""
logging.debug("execute prepared %s" % self._prepared)
# Parse statement
do_parse= True
if buffered:
self.buffered= True
@ -193,24 +194,22 @@ class Cursor(mariadb._mariadb.cursor):
logging.debug("clearing result set")
self._clear_result()
logging.debug("cleared")
# if we have a prepared cursor, we have to set statement
# to previous statement
if self._prepared and self._prev_stmt:
statement= self._prev_stmt
self_parsed= False
# to previous statement and don't need to parse
if self._prepared and self.statement:
statement= self.statement
do_parse= False
if statement != self._prev_stmt:
self._parsed= False
# Avoid reparsing of same statement
if statement == self.statement:
do_parse= True
# parse statement and check param style
if not self._parsed:
if do_parse:
self._parse_execute(statement, (data))
logging.debug("transformed: %s -> %s" % (statement, self._transformed_statement))
self._parsed= True
self._description= None
# check if data parameters are passed in correct format
if (self._paramstyle == 3 and not isinstance(data, dict)):
raise TypeError("Argument 2 must be Dict")
elif self._paramstyle < 3 and (not isinstance(data, (tuple, list))):
@ -220,11 +219,12 @@ class Cursor(mariadb._mariadb.cursor):
self._data= data
else:
self._data= None
# No need for binary protocol, if no parameters supplied, except CALL statement
# If statement doesn't contain parameters we force to run in text
# mode, unless a server side cursor or stored procedure will be
# executed.
if self._command != SQL_CALL and self._cursor_type == 0:
self._text= True
logging.debug("Executing %s in %s mode" % (statement, "text" if self._text else "binary"))
if self._force_binary:
self._text= False
@ -232,10 +232,10 @@ class Cursor(mariadb._mariadb.cursor):
# in text mode we need to substitute parameters
# and store transformed statement
if (self.paramcount > 0):
logging.debug("transform")
self._transformed_statement= self._add_text_params()
else:
self._transformed_statement= self.statement
logging.debug("text mode")
self._execute_text(self._transformed_statement)
self._readresponse()
else:
@ -275,8 +275,6 @@ class Cursor(mariadb._mariadb.cursor):
def _fetch_row(self):
# if there is no result set, PEP-249 requires to raise an
# exception
if not self.field_count:
@ -296,6 +294,9 @@ class Cursor(mariadb._mariadb.cursor):
row= tuple(l)
return row
def close(self):
super().close()
def fetchone(self):
row= self._fetch_row()
if not row:
@ -309,14 +310,57 @@ class Cursor(mariadb._mariadb.cursor):
ret= row
return ret
def fetchmany(self, size=0):
rows=[]
if size == 0:
size= self.arraysize
for count in range(0, size):
row= self.fetchone()
if row:
rows.append(row)
return rows
def fetchall(self):
rows=[];
for row in self:
rows.append((row))
return rows
def close(self):
super().close()
def scroll(self, value, mode="relative"):
"""
Scroll the cursor in the result set to a new position according to mode.
If mode is "relative" (default), value is taken as offset to the current
position in the result set, if set to absolute, value states an absolute
target position.
"""
if self.field_count == 0:
raise mariadb.ProgrammingError("Cursor doesn't have a result set")
if not self.buffered:
raise mariadb.ProgrammingError("This method is available only for cursors "\
"with a buffered result set.")
if mode != "absolute" and mode != "relative":
raise mariadb.DataError("Invalid or unknown scroll mode specified.")
if value == 0 and mode != "absolute":
raise mariadb.DataError("Invalid position value 0.")
if mode == "relative":
if self.rownumber + value < 0 or \
self.rownumber + value > self.rowcount:
raise mariadb.DataError("Position value is out of range.")
new_pos= self.rownumber + value
else:
if value < 0 or value >= self.rowcount:
raise mariadb.DataError("Position value is out of range.")
new_pos= value
self._seek(new_pos);
self._rownumber= new_pos;
def __enter__(self):
"""Returns a copy of the cursor."""
@ -326,6 +370,16 @@ class Cursor(mariadb._mariadb.cursor):
"""Closes cursor."""
self.close()
def __del__(self):
self.close()
@property
def lastrowid(self):
id= self.insert_id
if id > 0:
return id
return None
@property
def connection(self):
"""Read-Only attribute which returns the reference to the connection

View File

@ -81,12 +81,6 @@ MrdbConnection_setdb(MrdbConnection *self, PyObject *arg, void *closure);
static PyObject *
MrdbConnection_escape_string(MrdbConnection *self, PyObject *args);
static PyObject *
MrdbConnection_server_version(MrdbConnection *self);
static PyObject *
MrdbConnection_server_info(MrdbConnection *self);
static PyObject *
MrdbConnection_warnings(MrdbConnection *self);
@ -97,9 +91,6 @@ static int
MrdbConnection_setautocommit(MrdbConnection *self, PyObject *arg,
void *closure);
static PyObject *
MrdbConnection_get_server_version(MrdbConnection *self);
static PyObject *
MrdbConnection_get_server_status(MrdbConnection *self);
@ -126,10 +117,6 @@ MrdbConnection_sets[]=
connection_warnings__doc__, NULL},
{"_server_status", (getter)MrdbConnection_get_server_status, NULL,
NULL, NULL},
{"server_version", (getter)MrdbConnection_server_version, NULL,
connection_server_version__doc__, NULL},
{"server_info", (getter)MrdbConnection_server_info, NULL,
connection_server_info__doc__, NULL},
GETTER_EXCEPTION("Error", Mariadb_Error, ""),
GETTER_EXCEPTION("Warning", Mariadb_Warning, ""),
GETTER_EXCEPTION("InterfaceError", Mariadb_InterfaceError, ""),
@ -201,11 +188,6 @@ MrdbConnection_Methods[] =
METH_VARARGS,
connection_change_user__doc__
},
{ "get_server_version",
(PyCFunction)MrdbConnection_get_server_version,
METH_NOARGS,
connection_get_server_version__doc__,
},
{ "kill",
(PyCFunction)MrdbConnection_kill,
METH_VARARGS,
@ -262,11 +244,16 @@ PyMemberDef MrdbConnection_Members[] =
offsetof(MrdbConnection, port),
READONLY,
"Database server TCP/IP port"},
{"server_version_info",
T_OBJECT,
offsetof(MrdbConnection, server_version_info),
{"server_version",
T_ULONG,
offsetof(MrdbConnection, server_version),
READONLY,
"Server version in tuple format"},
"Server version"},
{"server_info",
T_STRING,
offsetof(MrdbConnection, server_info),
READONLY,
"Server info"},
{"unix_socket",
T_STRING,
offsetof(MrdbConnection, unix_socket),
@ -310,7 +297,6 @@ static void MrdbConnection_GetCapabilities(MrdbConnection *self)
&self->client_capabilities);
}
void MrdbConnection_SetAttributes(MrdbConnection *self)
{
mariadb_get_infov(self->mysql, MARIADB_CONNECTION_HOST, &self->host);
@ -491,17 +477,8 @@ MrdbConnection_Initialize(MrdbConnection *self,
}
}
/* CONPY-129: server_version_info */
if ((self->server_version_info= PyTuple_New(3)))
{
long major= mysql_get_server_version(self->mysql) / 10000;
long minor = (mysql_get_server_version(self->mysql) % 10000) / 100;
long patch= mysql_get_server_version(self->mysql) % 100;
if (PyTuple_SetItem(self->server_version_info, 0, PyLong_FromLong(major)) ||
PyTuple_SetItem(self->server_version_info, 1, PyLong_FromLong(minor)) ||
PyTuple_SetItem(self->server_version_info, 2, PyLong_FromLong(patch)))
goto end;
}
self->server_version= mysql_get_server_version(self->mysql);
self->server_info= mysql_get_server_info(self->mysql);
/*
if (asynchronous && PyObject_IsTrue(asynchronous))
{
@ -627,8 +604,8 @@ void MrdbConnection_dealloc(MrdbConnection *self)
Py_BEGIN_ALLOW_THREADS
mysql_close(self->mysql);
Py_END_ALLOW_THREADS
self->mysql= NULL;
}
Py_XDECREF(self->server_version_info);
Py_TYPE(self)->tp_free((PyObject*)self);
}
}
@ -656,51 +633,6 @@ MrdbConnection_executecommand(MrdbConnection *self,
Py_RETURN_NONE;
}
/*
PyObject *
MrdbConnection_readresponse(MrdbConnection *self)
{
int rc;
PyObject *result, *tmp;
MARIADB_CHECK_CONNECTION(self, NULL);
Py_BEGIN_ALLOW_THREADS;
rc= self->mysql->methods->db_read_query_result(self->mysql);
Py_END_ALLOW_THREADS;
if (rc)
{
mariadb_throw_exception(self->mysql, NULL, 0, NULL);
return NULL;
}
result= PyDict_New();
tmp= PyLong_FromLong((long)mysql_field_count(self->mysql));
PyDict_SetItemString(result, "field_count", tmp);
Py_DECREF(tmp);
tmp= PyLong_FromLong((long)mysql_affected_rows(self->mysql));
PyDict_SetItemString(result, "affected_rows", tmp);
Py_DECREF(tmp);
tmp= PyLong_FromLong((long)mysql_insert_id(self->mysql));
PyDict_SetItemString(result, "insert_id", tmp);
Py_DECREF(tmp);
tmp= PyLong_FromLong((long)self->mysql->server_status);
PyDict_SetItemString(result, "server_status", tmp);
Py_DECREF(tmp);
tmp= PyLong_FromLong((long)mysql_warning_count(self->mysql));
PyDict_SetItemString(result, "warning_count", tmp);
Py_DECREF(tmp);
return result;
}
*/
PyObject *MrdbConnection_close(MrdbConnection *self)
{
MARIADB_CHECK_CONNECTION(self, NULL);
@ -714,8 +646,7 @@ PyObject *MrdbConnection_close(MrdbConnection *self)
mysql_close(self->mysql);
Py_END_ALLOW_THREADS
self->mysql= NULL;
Py_INCREF(Py_None);
return Py_None;
Py_RETURN_NONE;
}
static PyObject *MrdbConnection_cursor(MrdbConnection *self,
@ -1373,22 +1304,6 @@ static PyObject *MrdbConnection_escape_string(MrdbConnection *self,
}
/* }}} */
/* {{{ MrdbConnection_server_version */
static PyObject *MrdbConnection_server_version(MrdbConnection *self)
{
MARIADB_CHECK_CONNECTION(self, NULL);
return PyLong_FromLong((long)mysql_get_server_version(self->mysql));
}
/* }}} */
/* {{{ MrdbConnection_server_info */
static PyObject *MrdbConnection_server_info(MrdbConnection *self)
{
MARIADB_CHECK_CONNECTION(self, NULL);
return PyUnicode_FromString(mysql_get_server_info(self->mysql));
}
/* }}} */
/* {{{ MrdbConnection_setautocommit */
static int MrdbConnection_setautocommit(MrdbConnection *self, PyObject *arg,
void *closure)
@ -1429,11 +1344,6 @@ static PyObject *MrdbConnection_getautocommit(MrdbConnection *self)
}
/* }}} */
static PyObject *MrdbConnection_get_server_version(MrdbConnection *self)
{
return self->server_version_info;
}
static PyObject *MrdbConnection_get_server_status(MrdbConnection *self)
{
uint32_t server_status;

File diff suppressed because it is too large Load Diff