Fixes for class ConnectionPool()

- added mutexes for thread safety
- when calling get_connection thread pool will now not return the next
  free connection, but the connection that was not used the longest time.
This commit is contained in:
Georg Richter
2019-12-01 05:47:19 +01:00
parent d207566c68
commit 3b02464b63
9 changed files with 120 additions and 22 deletions

View File

@ -30,6 +30,25 @@
#include <malloc.h>
#include <docs/common.h>
#if defined(_WIN32)
#include <windows.h>
typedef CRITICAL_SECTION pthread_mutex_t;
#define pthread_mutex_init(A,B) InitializeCriticalSection(A)
#define pthread_mutex_lock(A) (EnterCriticalSection(A),0)
#define pthread_mutex_unlock(A) LeaveCriticalSection(A)
#define pthread_mutex_destroy(A) DeleteCriticalSection(A)
#define pthread_self() GetCurrentThreadId()
#else
#include <pthread.h>
#endif /* defined(_WIN32) */
#ifdef _WIN32
int clock_gettime(int dummy, struct timespec *ct);
#define CLOCK_MONOTONIC_RAW 1
#endif
#define REQUIRED_CC_VERSION 30103
#if MARIADB_PACKAGE_VERSION_ID < REQUIRED_CC_VERSION
@ -127,10 +146,12 @@ typedef struct {
struct mrdb_pool *pool;
uint8_t inuse;
uint8_t status;
struct timespec last_used;
} MrdbConnection;
typedef struct mrdb_pool{
PyObject_HEAD
pthread_mutex_t lock;
char *pool_name;
size_t pool_name_length;
uint32_t pool_size;
@ -313,6 +334,9 @@ uint8_t MrdbParser_parse(MrdbParser *p, uint8_t is_batch, char *errmsg, size_t e
#define MAX_POOL_SIZE 64
#define TIMEDIFF(a,b)\
((a).tv_sec * 1E09 + (a).tv_nsec) - ((b).tv_sec * 1E09 + (b).tv_nsec)
/* Helper macros */
#define MrdbIndicator_Check(a)\

View File

@ -55,14 +55,16 @@ def get_config(options):
cfg.version = cc_version[0]
libs = mariadb_config(config_prg, "libs")
extra_libs= mariadb_config(config_prg, "libs_sys")
cfg.lib_dirs = [dequote(i[2:]) for i in libs if i.startswith("-L")]
cfg.libs = [dequote(i[2:]) for i in libs if i.startswith("-l")]
includes = mariadb_config(config_prg, "include")
mariadb_includes = [dequote(i[2:]) for i in includes if i.startswith("-I")]
mariadb_includes.extend(["./include"])
if static:
if static.lower() == "on":
cfg.extra_link_args= ["-u mysql_ps_fetch_functions"]
cfg.extra_objects = ['{}/lib{}.a'.format(cfg.lib_dirs[0], l) for l in ["mariadbclient"]]
cfg.libs = []
cfg.libs = [dequote(i[2:]) for i in extra_libs if i.startswith("-l")]
cfg.includes = mariadb_includes
return cfg

View File

@ -42,7 +42,7 @@ def get_config(options):
print("MariaDB Connector/Python requires MariaDB Connector/C >= %s (found version: %s") \
% (required_version, cc_version[0])
sys.exit(2)
mariadb_dir = QueryValueEx(connector_key, "InstallDir")
mariadb_dir = QueryValueEx(connector_key, "InstallDir")[0]
except:
print("Could not find InstallationDir of MariaDB Connector/C. "
@ -50,10 +50,10 @@ def get_config(options):
"MariaDB Connector/C by setting the environment variable MARIADB_CC_INSTALL_DIR.")
sys.exit(3)
print("Found MariaDB Connector/C in '%s'" % mariadb_dir[0])
print("Found MariaDB Connector/C in '%s'" % mariadb_dir)
cfg = MariaDBConfiguration()
cfg.includes = [".\\include", mariadb_dir[0] + "\\include", mariadb_dir[0] + "\\include\\mysql"]
cfg.lib_dirs = [mariadb_dir[0] + "\\lib"]
cfg.includes = [".\\include", mariadb_dir + "\\include", mariadb_dir + "\\include\\mysql"]
cfg.lib_dirs = [mariadb_dir + "\\lib"]
cfg.libs = ["ws2_32", "advapi32", "kernel32", "shlwapi", "crypt32"]
if static.lower() == "on":
cfg.libs.append("mariadbclient")

View File

@ -54,6 +54,7 @@ setup(name='mariadb',
library_dirs=cfg.lib_dirs,
libraries=cfg.libs,
extra_compile_args = cfg.extra_compile_args,
extra_link_args = cfg.extra_link_args
extra_link_args = cfg.extra_link_args,
extra_objects= cfg.extra_objects
)],
)

View File

@ -885,4 +885,41 @@ uint8_t mariadb_param_update(void *data, MYSQL_BIND *bind, uint32_t row_nr)
}
return 0;
}
#ifdef _WIN32
/* windows equivalent for clock_gettime.
Code based on https://stackoverflow.com/questions/5404277/porting-clock-gettime-to-windows
*/
static uint8_t g_first_time = 1;
static LARGE_INTEGER g_counts_per_sec;
int clock_gettime(int dummy, struct timespec *ct)
{
LARGE_INTEGER count;
if (g_first_time)
{
g_first_time = 0;
if (0 == QueryPerformanceFrequency(&g_counts_per_sec))
{
g_counts_per_sec.QuadPart = 0;
}
}
if ((NULL == ct) || (g_counts_per_sec.QuadPart <= 0) ||
(0 == QueryPerformanceCounter(&count)))
{
return -1;
}
ct->tv_sec = count.QuadPart / g_counts_per_sec.QuadPart;
ct->tv_nsec = ((count.QuadPart % g_counts_per_sec.QuadPart) * 1E09) / g_counts_per_sec.QuadPart;
return 0;
}
#endif
/* }}} */

View File

@ -484,8 +484,16 @@ PyObject *MrdbConnection_close(MrdbConnection *self)
if (self->pool)
{
if (!mysql_reset_connection(self->mysql))
int rc= 0;
pthread_mutex_lock(&self->pool->lock);
if (self->pool->reset_session)
rc= mysql_reset_connection(self->mysql);
if (!rc)
{
self->inuse= 0;
clock_gettime(CLOCK_MONOTONIC_RAW, &self->last_used);
}
pthread_mutex_unlock(&self->pool->lock);
return Py_None;
}

View File

@ -1172,6 +1172,7 @@ PyObject *MrdbCursor_executemany(MrdbCursor *self,
Py_END_ALLOW_THREADS;
if (rc)
{
printf("Error: %s\n", mysql_stmt_error(self->stmt));
mariadb_throw_exception(self->stmt, NULL, 1, NULL);
goto error;
}
@ -1181,6 +1182,7 @@ PyObject *MrdbCursor_executemany(MrdbCursor *self,
Py_END_ALLOW_THREADS;
if (rc)
{
printf("rc=%d Error: %s\n", rc, mysql_stmt_error(self->stmt));
mariadb_throw_exception(self->stmt, NULL, 1, NULL);
goto error;
}

View File

@ -136,6 +136,7 @@ static int MrdbPool_initialize(MrdbPool *self, PyObject *args,
goto error;
}
pthread_mutex_init(&self->lock, NULL);
self->pool_name= strdup(pool_name);
self->pool_name_length= pool_name_length;
self->pool_size= pool_size;
@ -157,6 +158,7 @@ static int MrdbPool_initialize(MrdbPool *self, PyObject *args,
if (!(self->connection[i]=
(MrdbConnection *)MrdbConnection_connect(NULL, args, self->configuration)))
goto error;
clock_gettime(CLOCK_MONOTONIC_RAW, &self->connection[i]->last_used);
Py_INCREF(self->connection[i]);
self->connection[i]->pool= self;
}
@ -284,6 +286,7 @@ void MrdbPool_dealloc(MrdbPool *self)
self->pool_size= 0;
MARIADB_FREE_MEM(self->connection);
self->connection= NULL;
pthread_mutex_destroy(&self->lock);
Py_TYPE(self)->tp_free((PyObject*)self);
}
@ -311,15 +314,28 @@ MrdbPool_add(
PyObject *MrdbPool_getconnection(MrdbPool *self)
{
uint32_t i;
MrdbConnection *conn= NULL;
uint64_t tdiff= 0;
struct timespec now;
clock_gettime(CLOCK_MONOTONIC_RAW, &now);
pthread_mutex_lock(&self->lock);
for (i=0; i < self->pool_size; i++)
{
if (!self->connection[i]->inuse)
if (self->connection[i] && !self->connection[i]->inuse)
{
if (self->connection[i] && !mysql_ping(self->connection[i]->mysql))
if (self->connection[i])
{
self->connection[i]->inuse= 1;
return (PyObject *)self->connection[i];
if (!mysql_ping(self->connection[i]->mysql))
{
uint64_t t= TIMEDIFF(now, self->connection[i]->last_used);
if (t >= tdiff)
{
conn= self->connection[i];
tdiff= t;
}
} else {
self->connection[i]->pool= NULL;
MrdbConnection_close(self->connection[i]);
@ -327,6 +343,12 @@ PyObject *MrdbPool_getconnection(MrdbPool *self)
}
}
}
}
if (conn)
conn->inuse= 1;
pthread_mutex_unlock(&self->lock);
if (conn)
return (PyObject *)conn;
mariadb_throw_exception(NULL, Mariadb_PoolError, 0,
"No more connections from pool '%s' available",
self->pool_name);
@ -335,11 +357,6 @@ PyObject *MrdbPool_getconnection(MrdbPool *self)
static PyObject *MrdbPool_setconfig(MrdbPool *self, PyObject *args, PyObject *kwargs)
{
/* PyObject *conf= NULL;
if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &conf))
return NULL;
*/
self->configuration= kwargs;
Py_RETURN_NONE;
}
@ -368,6 +385,7 @@ static PyObject * MrdbPool_addconnection(MrdbPool *self, PyObject *args)
return NULL;
}
pthread_mutex_lock(&self->lock);
for (i=0; i < self->pool_size; i++)
{
if (!self->connection[i])
@ -377,11 +395,15 @@ static PyObject * MrdbPool_addconnection(MrdbPool *self, PyObject *args)
return NULL;
self->connection[i]= conn;
self->connection[i]->inuse= 0;
clock_gettime(CLOCK_MONOTONIC_RAW, &self->connection[i]->last_used);
conn->pool= self;
pthread_mutex_unlock(&self->lock);
Py_RETURN_NONE;
}
}
pthread_mutex_unlock(&self->lock);
mariadb_throw_exception(NULL, Mariadb_PoolError, 0,
"Couldn't add connection to pool '%s' (no free slot available).",
self->pool_name);

View File

@ -47,11 +47,13 @@ class TestPooling(unittest.TestCase):
connections.append(pool.get_connection())
try:
x= pool.get_connection()
print("ok")
except mariadb.PoolError:
pass
for c in connections:
c.close()
x= pool.get_connection()
print("ok")
del pool
def test_connection_pool_add(self):