Fixed various memory leaks and address sanitizer related problems:

- When catching errors from cursor the try/except construct must be either
  within a context or we need to explicitly call traceback.clear_frames().
  Otherwise traceback will hold a reference to the cursor which
  generates a msan error.
- Before executing a cursor command via execute, callproc or executemany we
  now reset the cursor.
- Fixed various tests, which didn't close cursor or connection properly.
This commit is contained in:
Georg Richter
2025-06-16 12:44:14 +02:00
parent 7b0fbd85c5
commit 8090efa833
9 changed files with 236 additions and 139 deletions

View File

@ -17,6 +17,7 @@ from ._mariadb import (
ProgrammingError,
Warning,
mariadbapi_version,
_have_asan,
)
from .field import fieldinfo
@ -36,7 +37,7 @@ __all__ = ["DataError", "DatabaseError", "Error", "IntegrityError",
"InterfaceError", "InternalError", "NotSupportedError",
"OperationalError", "PoolError", "ProgrammingError",
"Warning", "Connection", "__version__", "__version_info__",
"__author__", "Cursor", "fieldinfo"]
"__author__", "Cursor", "fieldinfo", "_have_asan"]
def connect(*args, connectionclass=mariadb.connections.Connection, **kwargs):

View File

@ -190,6 +190,7 @@ class Cursor(mariadb._mariadb.cursor):
if self.connection.auto_reconnect:
self._thread_id= self.connection.thread_id
self.check_closed()
self._reset()
# create statement
params = ""
@ -260,6 +261,8 @@ class Cursor(mariadb._mariadb.cursor):
self._thread_id= self.connection.thread_id
self.check_closed()
if not self._prepared:
self._reset()
self.connection._last_executed_statement = statement
@ -341,6 +344,7 @@ class Cursor(mariadb._mariadb.cursor):
self._thread_id= self.connection.thread_id
self.check_closed()
self._reset()
if not parameters or not len(parameters):
raise mariadb.ProgrammingError("No data provided")

View File

@ -24,6 +24,22 @@
#include <structmember.h>
#include <datetime.h>
#ifdef __clang__
# if defined(__has_feature) && __has_feature(address_sanitizer)
# define HAVE_ASAN Py_True
# else
# define HAVE_ASAN Py_False
# endif
#elif defined(__GNUC__)
# ifdef __SANITIZE_ADDRESS__
# define HAVE_ASAN Py_True
# else
# define HAVE_ASAN Py_False
# endif
#else
# define HAVE_ASAN Py_False
#endif
extern int codecs_datetime_init(void);
extern int connection_datetime_init(void);
@ -182,6 +198,8 @@ PyMODINIT_FUNC PyInit__mariadb(void)
Py_INCREF(&MrdbConnection_Type);
PyModule_AddObject(module, "connection", (PyObject *)&MrdbConnection_Type);
PyModule_AddObject(module, "_have_asan", HAVE_ASAN);
Py_INCREF(HAVE_ASAN);
return module;
error:

View File

@ -509,6 +509,11 @@ static int MrdbConnection_traverse(
visitproc visit,
void *arg)
{
Py_VISIT(self->last_executed_stmt);
Py_VISIT(self->converter);
#if MARIADB_PACKAGE_VERSION_ID > 30301
Py_VISIT(self->status_callback);
#endif
return 0;
}
@ -525,6 +530,30 @@ static PyObject *MrdbConnection_repr(MrdbConnection *self)
return PyUnicode_FromString(cobj_repr);
}
static void ma_connection_close(MrdbConnection *conn)
{
if (conn)
{
if (conn->mysql)
{
Py_BEGIN_ALLOW_THREADS
mysql_close(conn->mysql);
Py_END_ALLOW_THREADS
conn->mysql= NULL;
}
}
}
static void MrdbConnection_dealloc(PyObject *obj)
{
MrdbConnection *self = (MrdbConnection *)obj;
if (self && self->mysql)
ma_connection_close(self);
Py_TYPE(self)->tp_free((PyObject *)self);
}
PyTypeObject MrdbConnection_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
.tp_name = "mariadb.connection",
@ -539,6 +568,7 @@ PyTypeObject MrdbConnection_Type = {
.tp_getset = MrdbConnection_sets,
.tp_init = (initproc)MrdbConnection_Initialize,
.tp_alloc = PyType_GenericAlloc,
.tp_dealloc = MrdbConnection_dealloc,
.tp_finalize = (destructor)MrdbConnection_finalize
};
@ -564,16 +594,7 @@ MrdbConnection_connect(
static
void MrdbConnection_finalize(MrdbConnection *self)
{
if (self)
{
if (self->mysql)
{
Py_BEGIN_ALLOW_THREADS
mysql_close(self->mysql);
Py_END_ALLOW_THREADS
self->mysql= NULL;
}
}
ma_connection_close(self);
}
static PyObject *
@ -603,10 +624,7 @@ PyObject *MrdbConnection_close(MrdbConnection *self)
{
MARIADB_CHECK_CONNECTION(self, NULL);
Py_BEGIN_ALLOW_THREADS
mysql_close(self->mysql);
Py_END_ALLOW_THREADS
self->mysql= NULL;
ma_connection_close(self);
self->closed= 1;
Py_RETURN_NONE;
}

View File

@ -27,6 +27,8 @@ MrdbCursor_finalize(MrdbCursor *self);
static PyObject *
MrdbCursor_close(MrdbCursor *self);
static PyObject *
MrdbCursor_reset(MrdbCursor *self);
static PyObject *
MrdbCursor_nextset(MrdbCursor *self);
@ -149,6 +151,9 @@ static PyMethodDef MrdbCursor_Methods[] =
{"_check_text_types", (PyCFunction) MrdbCursor_check_text_types,
METH_NOARGS,
NULL},
{"_reset", (PyCFunction)MrdbCursor_reset,
METH_NOARGS,
NULL},
{"_seek", (PyCFunction)MrdbCursor_seek,
METH_O,
NULL},
@ -332,6 +337,15 @@ static int MrdbCursor_traverse(
visitproc visit,
void *arg)
{
Py_VISIT(self->connection);
Py_VISIT(self->data);
return 0;
}
static int MrdbCursor_tpclear(MrdbCursor *self)
{
Py_CLEAR(self->connection);
Py_CLEAR(self->data);
return 0;
}
@ -369,6 +383,7 @@ PyTypeObject MrdbCursor_Type =
.tp_init= (initproc)MrdbCursor_initialize,
.tp_new= PyType_GenericNew,
.tp_dealloc= MrdbCursor_dealloc,
.tp_clear = (inquiry)MrdbCursor_tpclear,
.tp_finalize= (destructor)MrdbCursor_finalize
};
@ -496,13 +511,8 @@ static void ma_set_result_column_value(MrdbCursor *self, PyObject *row, uint32_t
}
}
/* {{{ ma_cursor_close
closes the statement handle of current cursor. After call to
cursor_close the cursor can't be reused anymore
*/
static
void ma_cursor_close(MrdbCursor *self)
void ma_cursor_reset(MrdbCursor *self)
{
if (!self->closed)
{
@ -519,8 +529,24 @@ void ma_cursor_close(MrdbCursor *self)
MrdbCursor_clear(self, 0);
MrdbCursor_clearparseinfo(&self->parseinfo);
}
}
/* {{{ ma_cursor_close
closes the statement handle of current cursor. After call to
cursor_close the cursor can't be reused anymore
*/
static
void ma_cursor_close(MrdbCursor *self)
{
ma_cursor_reset(self);
self->closed= 1;
}
static PyObject * MrdbCursor_reset(MrdbCursor *self)
{
ma_cursor_reset(self);
Py_RETURN_NONE;
}
static

View File

@ -12,6 +12,7 @@ from mariadb.constants import STATUS
import platform
from packaging.version import parse as parse_version
from packaging import version
import traceback, sys
class TestConnection(unittest.TestCase):
@ -32,6 +33,9 @@ class TestConnection(unittest.TestCase):
port=default_conf["port"],
host=default_conf["host"])
except (mariadb.OperationalError,):
# make asan happy
tb = sys.exc_info()[2]
traceback.clear_frames(tb)
pass
def test_connection_default_file(self):
@ -61,6 +65,7 @@ class TestConnection(unittest.TestCase):
# revert
conn.autocommit = True
self.assertEqual(conn.autocommit, True)
conn.close()
def test_local_infile(self):
default_conf = conf()
@ -70,6 +75,10 @@ class TestConnection(unittest.TestCase):
try:
cursor.execute("LOAD DATA LOCAL INFILE 'x.x' INTO TABLE t1")
except (mariadb.OperationalError,):
# make asan happy
if mariadb._have_asan:
tb = sys.exc_info()[2]
traceback.clear_frames(tb)
pass
del cursor
del new_conn
@ -116,7 +125,7 @@ class TestConnection(unittest.TestCase):
self.skipTest("MAXSCALE doesn't tell schema change for now")
default_conf = conf()
conn = self.connection
conn = create_connection()
self.assertEqual(conn.database, default_conf["database"])
cursor = conn.cursor()
cursor.execute("DROP SCHEMA IF EXISTS test1")
@ -125,12 +134,14 @@ class TestConnection(unittest.TestCase):
self.assertEqual(conn.database, "test1")
conn.database = default_conf["database"]
self.assertEqual(conn.database, default_conf["database"])
cursor.close()
conn.close()
def test_ping(self):
if is_maxscale():
self.skipTest("MAXSCALE wrong thread id")
conn = self.connection
cursor = conn.cursor()
with create_connection() as conn:
with conn.cursor() as cursor:
oldid = conn.connection_id
try:
@ -152,7 +163,7 @@ class TestConnection(unittest.TestCase):
if self.connection.server_version < 100122:
self.skipTest("ed25519 not supported")
conn = self.connection
conn = create_connection()
curs = conn.cursor(buffered=True)
if self.connection.server_name == "localhost":
@ -190,7 +201,7 @@ class TestConnection(unittest.TestCase):
except (mariadb.OperationalError):
pass
cursor.execute("DROP USER IF EXISTS eduser")
del cursor, conn2
del cursor, conn2, conn
def test_conpy46(self):
with create_connection() as con:
@ -211,12 +222,14 @@ class TestConnection(unittest.TestCase):
default_conf = conf()
c1 = mariadb.connect(**default_conf)
self.assertEqual(c1.autocommit, False)
c1.close()
c1 = mariadb.connect(**default_conf, autocommit=True)
self.assertEqual(c1.autocommit, True)
c1.close()
def test_db_attribute(self):
con = create_connection()
cursor = con.cursor()
with create_connection() as con:
with con.cursor() as cursor:
db = con.database
try:
cursor.execute("create schema test123")
@ -232,7 +245,6 @@ class TestConnection(unittest.TestCase):
self.assertEqual(row[0], db)
self.assertEqual(row[0], con.database)
cursor.execute("drop schema test123")
del cursor
def test_server_status(self):
con = create_connection()
@ -319,31 +331,27 @@ class TestConnection(unittest.TestCase):
def test_conpy278(self):
if is_maxscale():
self.skipTest("MAXSCALE bug MXS-4961")
test_conf= conf()
test_conf["reconnect"]= True
conn= mariadb.connect(**test_conf)
with create_connection({"reconnect" : True}) as conn:
old_id= conn.connection_id
try:
conn.kill(conn.connection_id)
except mariadb.OperationalError:
conn.ping()
self.assertNotEqual(old_id, conn.connection_id)
conn.close()
conn= mariadb.connect(**test_conf)
with create_connection({"reconnect" : True}) as conn:
old_id= conn.connection_id
try:
conn.kill(conn.connection_id)
except mariadb.OperationalError:
conn.ping()
self.assertNotEqual(old_id, conn.connection_id)
conn.close()
conn= mariadb.connect(**test_conf)
with create_connection({"reconnect" : True}) as conn:
old_id= conn.connection_id
try:
conn.kill(conn.connection_id)
except mariadb.OperationalError:
pass
cursor= conn.cursor()
with conn.cursor() as cursor:
try:
cursor.execute("set @a:=1")
except mariadb.InterfaceError:
@ -354,7 +362,6 @@ class TestConnection(unittest.TestCase):
old_id= conn.connection_id
conn.reconnect()
self.assertNotEqual(old_id, conn.connection_id)
conn.close()
if __name__ == '__main__':

View File

@ -729,7 +729,8 @@ class TestCursor(unittest.TestCase):
def test_conpy34(self):
cursor = self.connection.cursor()
with create_connection() as conn:
with conn.cursor() as cursor:
cursor.execute("CREATE TEMPORARY TABLE t1 (a varchar(20),"
"b varchar(20))")
try:
@ -737,7 +738,6 @@ class TestCursor(unittest.TestCase):
(("Walker", "Percy"), ("Flannery", "O'Connor")))
except (mariadb.ProgrammingError, mariadb.NotSupportedError):
pass
del cursor
def test_scroll(self):
cursor = self.connection.cursor(buffered=True)
@ -987,6 +987,7 @@ class TestCursor(unittest.TestCase):
row = cursor.fetchone()
self.assertEqual(row[0], 1)
cursor.execute("DROP PROCEDURE IF EXISTS p1")
con.close()
def test_sp2(self):
con = create_connection()
@ -1052,6 +1053,7 @@ class TestCursor(unittest.TestCase):
b'\x00\xf0?'])
self.assertEqual(row[0], expected)
del cursor
con.close()
def test_conpy35(self):
con = create_connection()
@ -1071,6 +1073,7 @@ class TestCursor(unittest.TestCase):
i = i + 1
self.assertEqual(row[0], i)
del cursor
con.close()
def test_conpy45(self):
con = create_connection()
@ -1108,6 +1111,7 @@ class TestCursor(unittest.TestCase):
cursor.execute("SELECT ?", (False,))
row = cursor.fetchone()
self.assertEqual(row[0], 0)
cursor.close()
del con
def test_conpy48(self):
@ -1123,6 +1127,7 @@ class TestCursor(unittest.TestCase):
self.assertEqual(row[0], 1)
row = cur.fetchone()
self.assertEqual(row[0], 2)
cur.close()
del con
def test_conpy51(self):
@ -1136,6 +1141,7 @@ class TestCursor(unittest.TestCase):
self.assertEqual(row[0][0], 1)
self.assertEqual(row[1][0], 2)
self.assertEqual(row[2][0], 3)
cur.close()
del con
def test_conpy52(self):
@ -1156,6 +1162,7 @@ class TestCursor(unittest.TestCase):
self.assertEqual(row[1][0], 2)
self.assertEqual(row[2][0], 3)
cur.execute("drop table if exists temp")
cur.close()
del con
def test_conpy49(self):
@ -1166,6 +1173,7 @@ class TestCursor(unittest.TestCase):
cur.execute("select a from t1")
row = cur.fetchone()
self.assertEqual(row[0], Decimal('10.20'))
cur.close()
del con
def test_conpy56(self):
@ -1175,6 +1183,7 @@ class TestCursor(unittest.TestCase):
row = cur.fetchone()
self.assertEqual(row["foo"], "bar")
self.assertEqual(row["bar"], "foo")
cur.close()
del con
def test_conpy53(self):
@ -1186,6 +1195,7 @@ class TestCursor(unittest.TestCase):
cur.execute("select 1", [])
row = cur.fetchone()
self.assertEqual(row[0], 1)
cur.close()
del con
def test_conpy58(self):
@ -1201,6 +1211,7 @@ class TestCursor(unittest.TestCase):
row = cursor.fetchall()
self.assertEqual(row[0][0], 1)
self.assertEqual(row[1][0], 2)
cursor.close()
del con
def test_conpy59(self):
@ -1211,6 +1222,7 @@ class TestCursor(unittest.TestCase):
cursor.execute("SELECT a FROM t1")
row = cursor.fetchone()
self.assertEqual(row[0], None)
cursor.close()
del con
def test_conpy61(self):
@ -1243,6 +1255,7 @@ class TestCursor(unittest.TestCase):
self.assertEqual(row[2], None)
del cursor
con.close()
def test_conpy62(self):
con = create_connection()
@ -1392,7 +1405,7 @@ class TestCursor(unittest.TestCase):
def test_conpy133(self):
if is_mysql():
self.skipTest("Skip (MySQL)")
conn = create_connection()
with create_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT /*! ? */", (1,))
@ -1418,19 +1431,17 @@ class TestCursor(unittest.TestCase):
self.assertEqual(row[0], 1)
del cursor
cursor = conn.cursor()
with conn.cursor() as cursor:
try:
cursor.execute("SELECT /*!50701 ? */", (1,))
except mariadb.ProgrammingError:
pass
del cursor
cursor = conn.cursor()
with conn.cursor() as cursor:
try:
cursor.execute("SELECT /*!250701 ? */", (1,))
except mariadb.ProgrammingError:
pass
del cursor
def check_closed(self):
conn = create_connection()
@ -1552,11 +1563,11 @@ class TestCursor(unittest.TestCase):
pass
cursor.close()
conn.close()
def test_conpy203(self):
conn = create_connection()
cursor = conn.cursor()
with create_connection() as conn:
with conn.cursor() as cursor:
try:
cursor.execute("SELECT")
except mariadb.ProgrammingError as err:

View File

@ -4,6 +4,7 @@
import unittest
from datetime import datetime
import mariadb
import sys, traceback
from test.base_test import create_connection
@ -25,6 +26,9 @@ class TestException(unittest.TestCase):
self.assertEqual(err.errno, 1064)
self.assertTrue(err.errmsg.find("You have an error "
"in your SQL syntax") > -1)
if mariadb._have_asan:
tb = sys.exc_info()[2]
traceback.clear_frames(tb)
pass
del cursor
@ -36,6 +40,9 @@ class TestException(unittest.TestCase):
self.assertEqual(err.sqlstate, "42000")
self.assertEqual(err.errno, 1049)
self.assertTrue(err.errmsg.find("Unknown database 'unknown'") > -1)
if mariadb._have_asan:
tb = sys.exc_info()[2]
traceback.clear_frames(tb)
pass
def test_conn_timeout_exception(self):
@ -50,6 +57,9 @@ class TestException(unittest.TestCase):
difference = end - start
self.assertEqual(difference.days, 0)
self.assertTrue(difference.seconds, 1)
if mariadb._have_asan:
tb = sys.exc_info()[2]
traceback.clear_frames(tb)
pass

View File

@ -87,6 +87,7 @@ class TestPooling(unittest.TestCase):
pconn = pool.get_connection()
old_id = pconn.connection_id
cursor.execute("KILL %s" % (old_id,))
cursor.close()
pconn.close()
pconn = pool.get_connection()
@ -110,6 +111,7 @@ class TestPooling(unittest.TestCase):
pconn = pool.get_connection()
old_id = pconn.connection_id
cursor.execute("KILL %s" % (old_id,))
cursor.close()
pconn.close()
pconn = pool.get_connection()