diff --git a/mariadb/connectionpool.py b/mariadb/connectionpool.py index a9a2e1b..8743023 100644 --- a/mariadb/connectionpool.py +++ b/mariadb/connectionpool.py @@ -134,6 +134,23 @@ class ConnectionPool(object): # store connection pool in _CONNECTION_POOLS mariadb._CONNECTION_POOLS[self._pool_args["name"]] = self + def _replace_connection(self, connection): + """ + Removes the given connection and adds a new connection. + """ + + if connection: + if connection in self._connections_free: + x = self._connections_free.index(connection) + del self._connections_free[x] + elif connection in self._connections_used: + x = self._connections_used.index(connection) + del self._connections_used[x] + + connection._Connection__pool = None + connection.close() + return self.add_connection() + def __repr__(self): if (self.__closed): return "= self._pool_args["size"]: raise mariadb.PoolError("Can't add connection to pool %s: " "No free slot available (%s)." % @@ -177,6 +194,7 @@ class ConnectionPool(object): connection._Connection__pool = self connection.__last_used = time.perf_counter_ns() self._connections_free.append(connection) + return connection def get_connection(self): """ @@ -188,15 +206,16 @@ class ConnectionPool(object): with self._lock_pool: for i in range(0, len(self._connections_free)): - dt = (time.perf_counter_ns() - - self._connections_free[i].__last_used) / 1000000 + conn = self._connections_free[i] + dt = (time.perf_counter_ns() - conn.__last_used) / 1000000 if dt > self._pool_args["validation_interval"]: try: - self._connections_free[i].ping() + conn.ping() except mariadb.Error: - continue + conn = self._replace_connection(conn) + if not conn: + continue - conn = self._connections_free[i] conn._used += 1 self._connections_used.append(conn) del self._connections_free[i] @@ -209,20 +228,22 @@ class ConnectionPool(object): Returns connection to the pool. Internally used by connection object. """ - - if self._pool_args["reset_connection"]: - connection.reset() - elif connection.server_status & STATUS.IN_TRANS: - connection.rollback() - with self._lock_pool: - for i in range(0, len(self._connections_used)): - if self._connections_used[i] == connection: - del self._connections_used[i] + try: + if self._pool_args["reset_connection"]: + connection.reset() + elif connection.server_status & STATUS.IN_TRANS: + connection.rollback() + except mariadb.Error: + self._replace_connection(connection) + + if connection: + if connection in self._connections_used: + x = self._connections_used.index(connection) + del self._connections_used[x] connection.__last_used = time.perf_counter_ns() self._connections_free.append(connection) - return def set_config(self, **kwargs): """ @@ -238,10 +259,7 @@ class ConnectionPool(object): def close(self): """Closes connection pool and all connections.""" try: - for c in self._connections_free: - c._Connection__pool = None - c.close() - for c in self._connections_used: + for c in (self._connections_free + self._connections_used): c._Connection__pool = None c.close() finally: diff --git a/testing/test/integration/test_pooling.py b/testing/test/integration/test_pooling.py index 32ee79e..47518a3 100644 --- a/testing/test/integration/test_pooling.py +++ b/testing/test/integration/test_pooling.py @@ -60,6 +60,100 @@ class TestPooling(unittest.TestCase): conn.close() pool.close() + def test_conpy247_1(self): + default_conf = conf() + pool = mariadb.ConnectionPool(pool_name="CONPY247_1", + pool_size=1, + pool_reset_connection=False, + pool_validation_interval=0, + **default_conf) + + # service connection + conn = create_connection() + cursor = conn.cursor() + + pconn = pool.get_connection() + old_id = pconn.connection_id + cursor.execute("KILL %s" % (old_id,)) + pconn.close() + + pconn = pool.get_connection() + self.assertNotEqual(old_id, pconn.connection_id) + + conn.close() + pool.close() + + def test_conpy247_2(self): + default_conf = conf() + pool = mariadb.ConnectionPool(pool_name="CONPY247_2", + pool_size=1, + pool_reset_connection=True, + pool_validation_interval=0, + **default_conf) + + # service connection + conn = create_connection() + cursor = conn.cursor() + + pconn = pool.get_connection() + old_id = pconn.connection_id + cursor.execute("KILL %s" % (old_id,)) + pconn.close() + + pconn = pool.get_connection() + self.assertNotEqual(old_id, pconn.connection_id) + + conn.close() + pool.close() + + def test_conpy247_3(self): + default_conf = conf() + pool = mariadb.ConnectionPool(pool_name="CONPY247_3", + pool_size=10, + pool_reset_connection=True, + pool_validation_interval=0, + **default_conf) + + # service connection + conn = create_connection() + cursor = conn.cursor() + ids = [] + + sql = """CREATE OR REPLACE PROCEDURE p1() + BEGIN + SELECT 1; + SELECT 2; + END""" + + cursor.execute(sql) + + for i in range(0, 10): + pconn = pool.get_connection() + ids.append(pconn.connection_id) + cursor.execute("KILL %s" % (pconn.connection_id,)) + pconn.close() + + new_ids = [] + + for i in range(0, 10): + pconn = pool.get_connection() + new_ids.append(pconn.connection_id) + self.assertEqual(pconn.connection_id in ids, False) + cursor = pconn.cursor(buffered=False, binary=False) + cursor.callproc("P1") + pconn.close() + + print("new_ids", new_ids) + + for i in range(0, 10): + pconn = pool.get_connection() + print("new_id: ", pconn.connection_id) + self.assertEqual(pconn.connection_id in new_ids, True) + pconn.close() + + conn.close() + pool.close() + def test_conpy245(self): # we can't test performance here, but we can check if LRU works. # All connections must have been used the same number of times.