First implementation for status callback

This commit is contained in:
Georg Richter
2022-08-05 14:41:43 +02:00
parent 09e5cad2c0
commit a9ad1fcaee
5 changed files with 236 additions and 66 deletions

View File

@ -31,6 +31,9 @@
#include <docs/common.h>
#include <limits.h>
#define CHECK_TYPE(obj, type) \
(Py_TYPE((obj)) == type || PyType_IsSubtype(Py_TYPE((obj)), type))
#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE)
static inline void _Py_SET_TYPE(PyObject *ob, PyTypeObject *type)
{ ob->ob_type = type; }
@ -55,8 +58,6 @@ typedef CRITICAL_SECTION pthread_mutex_t;
#include <limits.h>
#endif /* defined(_WIN32) */
#define CHECK_TYPE(obj, type) \
(Py_TYPE((obj)) == type || PyType_IsSubtype(Py_TYPE((obj)), type))
#ifndef MIN
#define MIN(a,b) (a) < (b) ? (a) : (b)
@ -181,6 +182,7 @@ typedef struct st_parser {
/* PEP-249: Connection object */
typedef struct {
PyObject_HEAD
PyThreadState *thread_state;
MYSQL *mysql;
int open;
uint8_t is_buffered;
@ -199,10 +201,13 @@ typedef struct {
uint8_t status;
uint8_t asynchronous;
struct timespec last_used;
PyThreadState *thread_state;
unsigned long thread_id;
char *server_info;
uint8_t closed;
#if MARIADB_PACKAGE_VERSION_ID > 30301
PyObject *status_callback;
#endif
PyObject *last_executed_stmt;
} MrdbConnection;
typedef struct {
@ -275,7 +280,6 @@ typedef struct {
uint8_t fetched;
uint8_t closed;
uint8_t reprepare;
PyThreadState *thread_state;
enum enum_paramstyle paramstyle;
} MrdbCursor;
@ -736,6 +740,33 @@ MrdbParser_parse(MrdbParser *p, uint8_t is_batch, char *errmsg, size_t errmsg_le
#endif /* __i386__ OR _WIN32 */
#ifdef _WIN32
//#define alloca _malloca
#endif
/* Due to callback functions we cannot use PY_BEGIN/END_ALLOW_THREADS */
#define MARIADB_BEGIN_ALLOW_THREADS(obj)\
{\
(obj)->thread_state= PyEval_SaveThread();\
}
#define MARIADB_END_ALLOW_THREADS(obj)\
if ((obj)->thread_state)\
{\
PyEval_RestoreThread((obj)->thread_state);\
(obj)->thread_state= NULL;\
}
#define MARIADB_UNBLOCK_THREADS(obj)\
{\
if ((obj)->thread_state)\
{\
_save= (obj)->thread_state;\
PyEval_RestoreThread(_save);\
(obj)->thread_state= NULL;\
}\
}
#define MARIADB_BLOCK_THREADS(obj)\
if (_save)\
{\
(obj)->thread_state= PyEval_SaveThread();\
_save= NULL;\
}

View File

@ -48,7 +48,7 @@ class Connection(mariadb._mariadb.connection):
Establishes a connection to a database server and returns a connection
object.
"""
self._last_executed_statement= None
self._socket= None
self.__in_use= 0
self.__pool = None
@ -440,7 +440,7 @@ class Connection(mariadb._mariadb.connection):
"""Get default database for connection."""
self._check_closed()
return self._mariadb_get_info(INFO.SCHEMA, str)
return self._mariadb_get_info(INFO.SCHEMA)
@database.setter
def database(self, schema):
@ -462,7 +462,7 @@ class Connection(mariadb._mariadb.connection):
"""
self._check_closed()
return self._mariadb_get_info(INFO.USER, str)
return self._mariadb_get_info(INFO.USER)
@property
def character_set(self):
@ -479,42 +479,42 @@ class Connection(mariadb._mariadb.connection):
"""Client capability flags."""
self._check_closed()
return self._mariadb_get_info(INFO.CLIENT_CAPABILITIES, int)
return self._mariadb_get_info(INFO.CLIENT_CAPABILITIES)
@property
def server_capabilities(self):
"""Server capability flags."""
self._check_closed()
return self._mariadb_get_info(INFO.SERVER_CAPABILITIES, int)
return self._mariadb_get_info(INFO.SERVER_CAPABILITIES)
@property
def extended_server_capabilities(self):
"""Extended server capability flags (only for MariaDB database servers)."""
self._check_closed()
return self._mariadb_get_info(INFO.EXTENDED_SERVER_CAPABILITIES, int)
return self._mariadb_get_info(INFO.EXTENDED_SERVER_CAPABILITIES)
@property
def server_port(self):
"""Database server TCP/IP port. This value will be 0 in case of a unix socket connection."""
self._check_closed()
return self._mariadb_get_info(INFO.PORT, int)
return self._mariadb_get_info(INFO.PORT)
@property
def unix_socket(self):
"""Unix socket name."""
self._check_closed()
return self._mariadb_get_info(INFO.UNIX_SOCKET, str)
return self._mariadb_get_info(INFO.UNIX_SOCKET)
@property
def server_name(self):
"""Name or IP address of database server."""
self._check_closed()
return self._mariadb_get_info(INFO.HOST, str)
return self._mariadb_get_info(INFO.HOST)
@property
def collation(self):
@ -527,21 +527,21 @@ class Connection(mariadb._mariadb.connection):
"""Server version in alphanumerical format (str)"""
self._check_closed()
return self._mariadb_get_info(INFO.SERVER_VERSION, str)
return self._mariadb_get_info(INFO.SERVER_VERSION)
@property
def tls_cipher(self):
"""TLS cipher suite if a secure connection is used."""
self._check_closed()
return self._mariadb_get_info(INFO.SSL_CIPHER, str)
return self._mariadb_get_info(INFO.SSL_CIPHER)
@property
def tls_version(self):
"""TLS protocol version if a secure connection is used."""
self._check_closed()
return self._mariadb_get_info(INFO.TLS_VERSION, str)
return self._mariadb_get_info(INFO.TLS_VERSION)
@property
def server_status(self):
@ -550,7 +550,7 @@ class Connection(mariadb._mariadb.connection):
"""
self._check_closed()
return self._mariadb_get_info(INFO.SERVER_STATUS, int)
return self._mariadb_get_info(INFO.SERVER_STATUS)
@property
def server_version(self):
@ -562,7 +562,7 @@ class Connection(mariadb._mariadb.connection):
"""
self._check_closed()
return self._mariadb_get_info(INFO.SERVER_VERSION_ID, int)
return self._mariadb_get_info(INFO.SERVER_VERSION_ID)
@property
def server_version_info(self):

View File

@ -242,6 +242,8 @@ class Cursor(mariadb._mariadb.cursor):
self.check_closed()
self.connection._last_executed_statement= statement
# Parse statement
do_parse= True
self._rowcount= 0
@ -314,6 +316,8 @@ class Cursor(mariadb._mariadb.cursor):
if not parameters or not len(parameters):
raise mariadb.ProgrammingError("No data provided")
self.connection._last_executed_statement= statement
# clear pending results
if self.field_count:
self._clear_result()
@ -373,7 +377,8 @@ class Cursor(mariadb._mariadb.cursor):
The cursor will be unusable from this point forward; an Error (or subclass)
exception will be raised if any operation is attempted with the cursor."
"""
super().close()
if not self.connection.is_closed:
super().close()
def fetchone(self):
"""

View File

@ -39,6 +39,7 @@ char *dsn_keys[]= {
"client_flag", "pool_name", "pool_size",
"pool_reset_connection", "plugin_dir",
"username", "db", "passwd",
"status_callback",
NULL
};
@ -176,6 +177,94 @@ PyMemberDef MrdbConnection_Members[] =
"Indicates if connection was closed"},
{NULL} /* always last */
};
#if MARIADB_PACKAGE_VERSION_ID > 30301
void MrdbConnection_process_status_info(void *data, enum enum_mariadb_status_info type, ...)
{
va_list ap;
PyThreadState *_save= NULL;
MrdbConnection *self= (MrdbConnection *)data;
PyObject *dict= NULL;
PyObject *dict_key= NULL, *dict_val= NULL;
va_start(ap, type);
if (self->status_callback) {
if (type == STATUS_TYPE)
{
unsigned int server_status= va_arg(ap, int);
MARIADB_UNBLOCK_THREADS(self);
dict_key= PyUnicode_FromString("server_status");
dict_val= PyLong_FromLong(server_status);
dict= PyDict_New();
PyDict_SetItem(dict, dict_key, dict_val);
Py_DECREF(dict_key);
Py_DECREF(dict_val);
PyObject_CallFunction(self->status_callback, "OO", (PyObject *)data, dict);
MARIADB_BLOCK_THREADS(self);
}
}
if (type == SESSION_TRACK_TYPE)
{
enum enum_session_state_type track_type= va_arg(ap, enum enum_session_state_type);
MARIADB_UNBLOCK_THREADS(self);
if (self->status_callback) {
switch (track_type) {
case SESSION_TRACK_SCHEMA:
dict_key= PyUnicode_FromString("schema");
break;
case SESSION_TRACK_STATE_CHANGE:
dict_key= PyUnicode_FromString("state_change");
break;
default:
break;
}
}
if (dict_key)
{
MARIADB_CONST_STRING *val= va_arg(ap, MARIADB_CONST_STRING *);
dict_val= PyUnicode_FromStringAndSize(val->str, val->length);
dict= PyDict_New();
PyDict_SetItem(dict, dict_key, dict_val);
Py_DECREF(dict_key);
Py_DECREF(dict_val);
PyObject_CallFunction(self->status_callback, "OO", (PyObject *)data, dict);
}
if (track_type == SESSION_TRACK_SYSTEM_VARIABLES)
{
MARIADB_CONST_STRING *key= va_arg(ap, MARIADB_CONST_STRING *);
MARIADB_CONST_STRING *val= va_arg(ap, MARIADB_CONST_STRING *);
if (!strncmp(key->str, "character_set_client", key->length) &&
strncmp(val->str, "utf8mb4", val->length))
{
char charset[128];
memcpy(charset, val->str, val->length);
charset[val->length]= 0;
va_end(ap);
mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 1,
"Character set '%s' is not supported", charset);
}
if (self->status_callback)
{
dict_key= PyUnicode_FromStringAndSize(key->str, key->length);
dict_val= PyUnicode_FromStringAndSize(val->str, val->length);
dict= PyDict_New();
PyDict_SetItem(dict, dict_key, dict_val);
Py_DECREF(dict_key);
Py_DECREF(dict_val);
PyObject_CallFunction(self->status_callback, "OO", (PyObject *)data, dict);
}
}
MARIADB_BLOCK_THREADS(self);
}
va_end(ap);
}
#endif
static int
MrdbConnection_Initialize(MrdbConnection *self,
@ -197,9 +286,10 @@ MrdbConnection_Initialize(MrdbConnection *self,
unsigned int local_infile= 0xFF;
unsigned int connect_timeout=0, read_timeout=0, write_timeout=0,
compress= 0, ssl_verify_cert= 0;
PyObject *status_callback= NULL;
if (!PyArg_ParseTupleAndKeywords(args, dsnargs,
"|zzzzziziiibbzzzzzzzzzzibizibzzzz:connect",
"|zzzzziziiibbzzzzzzzzzzibizibzzzzO:connect",
dsn_keys,
&dsn, &host, &user, &password, &schema, &port, &socket,
&connect_timeout, &read_timeout, &write_timeout,
@ -210,7 +300,7 @@ MrdbConnection_Initialize(MrdbConnection *self,
&ssl_verify_cert, &ssl_enforce,
&client_flags, &pool_name, &pool_size,
&reset_session, &plugin_dir,
&user, &schema, &password))
&user, &schema, &password, &status_callback))
{
return -1;
}
@ -222,6 +312,16 @@ MrdbConnection_Initialize(MrdbConnection *self,
return -1;
}
#if MARIADB_PACKAGE_VERSION_ID < 30302
if (status_callback)
{
mariadb_throw_exception(NULL, Mariadb_OperationalError, 1,
"Use of status_callback requires Connector/C version 3.3.2 or higher.");
return -1;
}
#endif
self->status_callback= status_callback;
if (!(self->mysql= mysql_init(NULL)))
{
mariadb_throw_exception(self->mysql, Mariadb_OperationalError, 1,
@ -229,7 +329,13 @@ MrdbConnection_Initialize(MrdbConnection *self,
return -1;
}
Py_BEGIN_ALLOW_THREADS;
#if MARIADB_PACKAGE_VERSION_ID > 30301
if (mysql_optionsv(self->mysql, MARIADB_OPT_STATUS_CALLBACK, MrdbConnection_process_status_info, self))
goto end;
#endif
MARIADB_BEGIN_ALLOW_THREADS(self);
if (mysql_options(self->mysql, MYSQL_SET_CHARSET_NAME, mariadb_default_charset))
goto end;
@ -326,7 +432,7 @@ MrdbConnection_Initialize(MrdbConnection *self,
has_error= 0;
end:
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self);
if (has_error)
{
@ -454,9 +560,9 @@ void MrdbConnection_dealloc(MrdbConnection *self)
{
if (self->mysql)
{
Py_BEGIN_ALLOW_THREADS
MARIADB_BEGIN_ALLOW_THREADS(self)
mysql_close(self->mysql);
Py_END_ALLOW_THREADS
MARIADB_END_ALLOW_THREADS(self)
self->mysql= NULL;
}
Py_TYPE(self)->tp_free((PyObject*)self);
@ -474,9 +580,9 @@ MrdbConnection_executecommand(MrdbConnection *self,
if (!PyArg_ParseTuple(args, "s", &cmd))
return NULL;
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self);
rc= mysql_send_query(self->mysql, cmd, (long)strlen(cmd));
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self);
if (rc)
{
@ -492,9 +598,9 @@ PyObject *MrdbConnection_close(MrdbConnection *self)
/* Todo: check if all the cursor stuff is deleted (when using prepared
statements this should be handled in mysql_close) */
Py_BEGIN_ALLOW_THREADS
MARIADB_BEGIN_ALLOW_THREADS(self)
mysql_close(self->mysql);
Py_END_ALLOW_THREADS
MARIADB_END_ALLOW_THREADS(self)
self->mysql= NULL;
self->closed= 1;
Py_RETURN_NONE;
@ -516,9 +622,9 @@ PyObject *MrdbConnection_ping(MrdbConnection *self)
MARIADB_CHECK_CONNECTION(self, NULL);
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self);
rc= mysql_ping(self->mysql);
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self);
if (rc) {
mariadb_throw_exception(self->mysql, Mariadb_InterfaceError, 0, NULL);
@ -545,9 +651,9 @@ PyObject *MrdbConnection_change_user(MrdbConnection *self,
if (!PyArg_ParseTuple(args, "sss", &user, &password, &database))
return NULL;
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self);
rc= mysql_change_user(self->mysql, user, password, database);
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self);
if (rc)
{
@ -607,11 +713,9 @@ MrdbConnection_getinfo(MrdbConnection *self, PyObject *args)
uint8_t b;
} val;
PyObject *type;
uint32_t option;
if (!PyArg_ParseTuple(args, "iO", &option, &type))
if (!PyArg_ParseTuple(args, "i", &option))
return NULL;
memset(&val, 0, sizeof(val));
@ -623,15 +727,45 @@ MrdbConnection_getinfo(MrdbConnection *self, PyObject *args)
return NULL;
}
if ((PyTypeObject *)type == &PyUnicode_Type)
{
switch (option) {
case MARIADB_CONNECTION_UNIX_SOCKET:
case MARIADB_CONNECTION_USER:
case MARIADB_CHARSET_NAME:
case MARIADB_TLS_LIBRARY:
case MARIADB_CLIENT_VERSION:
case MARIADB_CONNECTION_HOST:
case MARIADB_CONNECTION_INFO:
case MARIADB_CONNECTION_SCHEMA:
case MARIADB_CONNECTION_SQLSTATE:
case MARIADB_CONNECTION_SOCKET:
case MARIADB_CONNECTION_SSL_CIPHER:
case MARIADB_CONNECTION_TLS_VERSION:
case MARIADB_CONNECTION_SERVER_VERSION:
return PyUnicode_FromString(val.str ? val.str : "");
}
if ((PyTypeObject *)type == &PyLong_Type)
break;
case MARIADB_CHARSET_ID:
case MARIADB_CLIENT_VERSION_ID:
case MARIADB_CONNECTION_ASYNC_TIMEOUT:
case MARIADB_CONNECTION_ASYNC_TIMEOUT_MS:
case MARIADB_CONNECTION_PORT:
case MARIADB_CONNECTION_PROTOCOL_VERSION_ID:
case MARIADB_CONNECTION_SERVER_TYPE:
case MARIADB_CONNECTION_SERVER_VERSION_ID:
case MARIADB_CONNECTION_TLS_VERSION_ID:
case MARIADB_MAX_ALLOWED_PACKET:
case MARIADB_NET_BUFFER_LENGTH:
case MARIADB_CONNECTION_SERVER_STATUS:
case MARIADB_CONNECTION_SERVER_CAPABILITIES:
case MARIADB_CONNECTION_EXTENDED_SERVER_CAPABILITIES:
case MARIADB_CONNECTION_CLIENT_CAPABILITIES:
case MARIADB_CONNECTION_BYTES_READ:
case MARIADB_CONNECTION_BYTES_SENT:
return PyLong_FromLong((long)val.num);
if ((PyTypeObject *)type == &PyBool_Type)
return val.b ? Py_True : Py_False;
Py_RETURN_NONE;
break;
default:
Py_RETURN_NONE;
}
}
/* {{{ MrdbConnection_reconnect */
@ -648,9 +782,9 @@ PyObject *MrdbConnection_reconnect(MrdbConnection *self)
if (!save_reconnect)
mysql_optionsv(self->mysql, MYSQL_OPT_RECONNECT, &reconnect);
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self);
rc= mariadb_reconnect(self->mysql);
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self);
if (!save_reconnect)
mysql_optionsv(self->mysql, MYSQL_OPT_RECONNECT, &save_reconnect);
@ -672,9 +806,9 @@ PyObject *MrdbConnection_reset(MrdbConnection *self)
int rc;
MARIADB_CHECK_CONNECTION(self, NULL);
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self);
rc= mysql_reset_connection(self->mysql);
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self);
if (rc)
{
@ -728,9 +862,9 @@ MrdbConnection_dump_debug_info(MrdbConnection *self)
int rc;
MARIADB_CHECK_CONNECTION(self, NULL);
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self);
rc= mysql_dump_debug_info(self->mysql);
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self);
if (rc)
{
@ -744,9 +878,9 @@ static PyObject *MrdbConnection_readresponse(MrdbConnection *self)
{
int rc;
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self);
rc= self->mysql->methods->db_read_query_result(self->mysql);
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self);
if (rc)
{

View File

@ -534,9 +534,9 @@ void ma_cursor_close(MrdbCursor *self)
{
/* Todo: check if all the cursor stuff is deleted (when using prepared
statements this should be handled in mysql_stmt_close) */
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self->connection);
mysql_stmt_close(self->stmt);
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self->connection);
self->stmt= NULL;
}
MrdbCursor_clear(self, 0);
@ -662,7 +662,7 @@ static int Mrdb_execute_direct(MrdbCursor *self,
{
int rc;
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self->connection);
long ext_caps;
mariadb_get_infov(self->connection->mysql,
@ -692,7 +692,7 @@ static int Mrdb_execute_direct(MrdbCursor *self,
rc= mariadb_stmt_execute_direct(self->stmt, statement, statement_len);
}
end:
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self->connection);
return rc;
}
@ -855,12 +855,12 @@ static PyObject *MrdbCursor_seek(MrdbCursor *self, PyObject *args)
{
return NULL;
}
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self->connection);
if (self->parseinfo.is_text)
mysql_data_seek(self->result, new_position);
else
mysql_stmt_data_seek(self->stmt, new_position);
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self->connection);
Py_RETURN_NONE;
}
@ -890,7 +890,7 @@ MrdbCursor_nextset(MrdbCursor *self)
return NULL;
}
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self->connection);
if (!self->parseinfo.is_text)
rc= mysql_stmt_next_result(self->stmt);
else
@ -902,7 +902,7 @@ MrdbCursor_nextset(MrdbCursor *self)
}
rc= mysql_next_result(self->connection->mysql);
}
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self->connection);
if (rc)
{
@ -1124,9 +1124,9 @@ MrdbCursor_execute_text(MrdbCursor *self, PyObject *args)
db= self->connection->mysql;
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self->connection);
rc= mysql_send_query(db, statement, (long)statement_len);
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self->connection);
if (rc)
{
@ -1148,9 +1148,9 @@ MrdbCursor_readresponse(MrdbCursor *self)
if (self->parseinfo.is_text)
{
Py_BEGIN_ALLOW_THREADS;
MARIADB_BEGIN_ALLOW_THREADS(self->connection);
rc= db->methods->db_read_query_result(db);
Py_END_ALLOW_THREADS;
MARIADB_END_ALLOW_THREADS(self->connection);
if (rc)
{