# Mirror of https://github.com/mariadb-corporation/mariadb-connector-python.git
# (synced 2025-07-30 12:57:47 +00:00; 368 lines, 13 KiB, Python)
#!/usr/bin/env python -O
# -*- coding: utf-8 -*-

import os
import platform
import sys
import tempfile
import traceback
import unittest

import mariadb
from mariadb.constants import STATUS
from packaging import version
from packaging.version import parse as parse_version

from test.base_test import create_connection, is_skysql, is_maxscale, get_host_suffix
from test.conf_test import conf


class TestConnection(unittest.TestCase):
    """Integration tests for MariaDB Connector/Python connection objects.

    Every test talks to a live server using the settings returned by
    ``test.conf_test.conf()`` / ``test.base_test.create_connection()``.
    Several tests skip themselves on Windows, MaxScale or SkySQL, or when
    the server / client library version is too old for the feature.
    """

    def setUp(self):
        # Shared connection for tests that only need a plain session.
        self.connection = create_connection()

    def tearDown(self):
        # Close explicitly instead of relying on garbage collection.
        self.connection.close()
        del self.connection

    def test_conpy36(self):
        """CONPY-36: connect through a unix socket path that does not exist.

        An OperationalError is expected and handled; the test does not fail
        if the connect succeeds.
        """
        if platform.system() == "Windows":
            self.skipTest("unix_socket not supported on Windows")
        default_conf = conf()
        try:
            conn = mariadb.connect(user=default_conf["user"],
                                   unix_socket="/does_not_exist/x.sock",
                                   port=default_conf["port"],
                                   host=default_conf["host"])
        except (mariadb.OperationalError,):
            # make asan happy
            tb = sys.exc_info()[2]
            traceback.clear_frames(tb)
        else:
            # NOTE(review): presumably the client fell back to TCP for a
            # non-local host; don't leak the connection in that case.
            conn.close()

    def test_connection_default_file(self):
        """Read host/port/user/password/database from a ``default_file``."""
        default_conf = conf()
        # Use a private temporary directory instead of the working
        # directory, so an existing ./client.cnf is neither deleted nor
        # left behind when the test fails.
        with tempfile.TemporaryDirectory() as tmpdir:
            cnf_path = os.path.join(tmpdir, "client.cnf")
            with open(cnf_path, "w") as f:
                f.write("[client]\n")
                f.write("host =%s\n" % default_conf["host"])
                f.write("port =%i\n" % default_conf["port"])
                f.write("user =%s\n" % default_conf["user"])
                if "password" in default_conf:
                    f.write("password =%s\n" % default_conf["password"])
                f.write("database =%s\n" % default_conf["database"])

            new_conn = mariadb.connect(user=default_conf["user"], ssl=True,
                                       default_file=cnf_path)
            try:
                self.assertEqual(new_conn.database, default_conf["database"])
            finally:
                new_conn.close()

    def test_autocommit(self):
        """Toggle the ``autocommit`` property and read it back."""
        conn = self.connection
        conn.autocommit = False
        self.assertEqual(conn.autocommit, False)
        # switch it on again and verify that change is visible too
        conn.autocommit = True
        self.assertEqual(conn.autocommit, True)

    def test_local_infile(self):
        """Run LOAD DATA LOCAL INFILE on a connection with local_infile=False.

        An OperationalError from the statement is expected and handled.
        """
        default_conf = conf()
        new_conn = mariadb.connect(**default_conf, local_infile=False)
        self.addCleanup(new_conn.close)
        cursor = new_conn.cursor()
        cursor.execute("CREATE TEMPORARY TABLE t1 (a int)")
        try:
            cursor.execute("LOAD DATA LOCAL INFILE 'x.x' INTO TABLE t1")
        except (mariadb.OperationalError,):
            # make asan happy
            if mariadb._have_asan:
                tb = sys.exc_info()[2]
                traceback.clear_frames(tb)
        cursor.close()

    def test_tls_version(self):
        """Requesting tls_version=TLSv1.2 must yield a TLSv1.2 session."""
        if is_maxscale():
            self.skipTest("MAXSCALE test has no SSL on port by default")
        default_conf = conf()
        with mariadb.connect(**default_conf, tls_version="TLSv1.2") as conn:
            with conn.cursor() as cursor:
                cursor.execute("SHOW STATUS LIKE 'ssl_version'")
                row = cursor.fetchone()
                self.assertEqual(row[1], "TLSv1.2")

    def test_init_command(self):
        """The ``init_command`` must run when the session is established."""
        default_conf = conf()
        with mariadb.connect(**default_conf,
                             init_command="SET @a:=1") as new_conn:
            with new_conn.cursor() as cursor:
                cursor.execute("SELECT @a")
                row = cursor.fetchone()
                self.assertEqual(row[0], 1)

    def test_compress(self):
        """compress=True must turn on protocol compression (except via MaxScale)."""
        default_conf = conf()
        with mariadb.connect(**default_conf, compress=True) as new_conn:
            with new_conn.cursor() as cursor:
                cursor.execute("SHOW SESSION STATUS LIKE 'compression'")
                row = cursor.fetchone()
                if is_maxscale():
                    self.assertEqual(row[1], "OFF")
                else:
                    self.assertEqual(row[1], "ON")

    def test_schema(self):
        """``Connection.database`` must follow USE statements and assignment."""
        if self.connection.server_version < 100202:
            self.skipTest("session tracking not supported")
        if is_maxscale():
            self.skipTest("MAXSCALE doesn't tell schema change for now")

        default_conf = conf()
        with create_connection() as conn:
            self.assertEqual(conn.database, default_conf["database"])
            with conn.cursor() as cursor:
                cursor.execute("DROP SCHEMA IF EXISTS test1")
                cursor.execute("CREATE SCHEMA test1")
                try:
                    cursor.execute("USE test1")
                    self.assertEqual(conn.database, "test1")
                    conn.database = default_conf["database"]
                    self.assertEqual(conn.database, default_conf["database"])
                finally:
                    # don't leave the scratch schema behind
                    cursor.execute("DROP SCHEMA IF EXISTS test1")

    def test_ping(self):
        """After killing its own session, ping() must reconnect with a new id."""
        if is_maxscale():
            self.skipTest("MAXSCALE wrong thread id")
        with create_connection() as conn:
            with conn.cursor() as cursor:
                oldid = conn.connection_id

                # The KILL may or may not report an error to this client.
                try:
                    cursor.execute("KILL {id}".format(id=oldid))
                except mariadb.Error:
                    pass

                conn.auto_reconnect = True
                conn.ping()
                self.assertNotEqual(oldid, conn.connection_id)
                self.assertNotEqual(0, conn.connection_id)

    def test_ed25519(self):
        """Create an ed25519-authenticated user and log in as that user."""
        if is_skysql():
            self.skipTest("Test fail on SkySQL")
        default_conf = conf()
        if is_maxscale():
            self.skipTest("MAXSCALE doesn't support ed25519 for now")
        if self.connection.server_version < 100122:
            self.skipTest("ed25519 not supported")

        conn = create_connection()
        # Cleanups run in LIFO order: registered first, so it closes last.
        self.addCleanup(conn.close)

        if self.connection.server_name == "localhost":
            curs = conn.cursor(buffered=True)
            curs.execute("select * from information_schema.plugins where "
                         "plugin_name ='unix_socket' and "
                         "plugin_status ='ACTIVE'")
            unix_socket_active = curs.rowcount > 0
            curs.close()
            if unix_socket_active:
                self.skipTest("unix_socket is active")

        cursor = conn.cursor()
        self.addCleanup(cursor.close)
        try:
            cursor.execute("INSTALL SONAME 'auth_ed25519'")
        except (mariadb.DatabaseError, mariadb.OperationalError):
            self.skipTest("Server couldn't load auth_ed25519")

        user = "eduser" + get_host_suffix()
        # Drop with the same host part used by CREATE/GRANT below: a bare
        # 'eduser' only names eduser@'%' and would miss a leftover account
        # created with a different host suffix by an earlier failed run.
        cursor.execute("DROP USER IF EXISTS " + user)
        if self.connection.server_version < 100400:
            cursor.execute("CREATE USER " + user + " IDENTIFIED VIA ed25519 "
                           "USING "
                           "'6aW9C7ENlasUfymtfMvMZZtnkCVlcb1ssxOLJ0kj/AA'")
        else:
            cursor.execute("CREATE USER " + user + " IDENTIFIED VIA ed25519 "
                           "USING PASSWORD('MySup8%rPassw@ord')")
        # Drop the account even if the login below fails.
        self.addCleanup(cursor.execute, "DROP USER IF EXISTS " + user)
        cursor.execute("GRANT ALL on " + default_conf["database"] +
                       ".* to " + user)

        conn2 = create_connection({"user": "eduser",
                                   "password": "MySup8%rPassw@ord"})
        conn2.close()
        # TODO: disabled check - connecting with a wrong plugin_dir should
        # fail because the authentication plugin cannot be found:
        # try:
        #     create_connection({"user": "eduser",
        #                        "password": "MySup8%rPassw@ord",
        #                        "plugin_dir": "wrong_plugin_dir"})
        #     self.fail("wrong plugin directory, must not have found "
        #               "authentication plugin")
        # except (mariadb.OperationalError):
        #     pass

    def test_conpy46(self):
        """CONPY-46: use a cursor / connection after its ``with`` block ended.

        A ProgrammingError is expected and handled in both cases.
        """
        with create_connection() as con:
            with con.cursor() as cursor:
                cursor.execute("SELECT 'foo'")
                row = cursor.fetchone()
                self.assertEqual(row[0], "foo")
            try:
                cursor.execute("SELECT 'bar'")
            except mariadb.ProgrammingError:
                pass
        try:
            cursor = con.cursor()
        except mariadb.ProgrammingError:
            pass

    def test_conpy101(self):
        """autocommit is off by default and honours autocommit=True."""
        default_conf = conf()
        with mariadb.connect(**default_conf) as c1:
            self.assertEqual(c1.autocommit, False)
        with mariadb.connect(**default_conf, autocommit=True) as c1:
            self.assertEqual(c1.autocommit, True)

    def test_db_attribute(self):
        """Assigning ``Connection.database`` must switch the default schema."""
        with create_connection() as con:
            with con.cursor() as cursor:
                db = con.database
                try:
                    cursor.execute("create schema test123")
                except mariadb.Error:
                    # tolerated, e.g. when the schema is left over already
                    pass
                try:
                    con.database = "test123"
                    cursor.execute("select database()", buffered=True)
                    row = cursor.fetchone()
                    self.assertEqual(row[0], "test123")
                    con.database = db
                    cursor.execute("select database()", buffered=True)
                    row = cursor.fetchone()
                    self.assertEqual(row[0], db)
                    self.assertEqual(row[0], con.database)
                finally:
                    cursor.execute("drop schema if exists test123")

    def test_server_status(self):
        """The AUTOCOMMIT bit of ``server_status`` must track ``autocommit``."""
        with create_connection() as con:
            self.assertFalse(con.server_status & STATUS.AUTOCOMMIT)
            con.autocommit = True
            self.assertTrue(con.server_status & STATUS.AUTOCOMMIT)
            con.autocommit = False
            self.assertFalse(con.server_status & STATUS.AUTOCOMMIT)

    def test_conpy175(self):
        """escape_string() must honour the session's NO_BACKSLASH_ESCAPES mode."""
        default_conf = conf()
        text = "Bob's"
        with mariadb.connect(**default_conf) as conn:
            with conn.cursor() as cursor:
                cursor.execute("SET session sql_mode='NO_BACKSLASH_ESCAPES'")
                self.assertEqual(conn.escape_string(text), "Bob''s")
                cursor.execute("SET session sql_mode=''")
                self.assertEqual(conn.escape_string(text), "Bob\\'s")

    def test_closed(self):
        """Call cursor() on a closed connection.

        A ProgrammingError is expected and handled.
        """
        default_conf = conf()
        conn = mariadb.connect(**default_conf)
        conn.close()
        try:
            conn.cursor()
        except mariadb.ProgrammingError:
            pass

    def test_multi_host(self):
        """Connect with a comma-separated host list whose first entry is bogus."""
        default_conf = conf()
        default_conf["host"] = "non_existant," + default_conf["host"]
        try:
            conn = mariadb.connect(**default_conf)
        except mariadb.ProgrammingError:
            # A ProgrammingError is only acceptable with C/C < 3.3.0.
            self.assertLess(parse_version(mariadb.mariadbapi_version),
                            parse_version('3.3.0'))
        else:
            conn.close()

    def test_tls_verification(self):
        """_tls_verify_status is None without TLS and set with TLS."""
        if is_maxscale():
            self.skipTest("MAXSCALE test has no SSL on port by default")
        if version.Version(mariadb.mariadbapi_version) < \
                version.Version('3.4.2'):
            self.skipTest("Requires C/C 3.4.2 or newer")
        default_conf = conf()
        default_conf["ssl"] = False
        with mariadb.connect(**default_conf) as conn:
            self.assertIsNone(conn._tls_verify_status)

        default_conf = conf()
        default_conf["ssl"] = True
        with mariadb.connect(**default_conf) as conn:
            self.assertIsNotNone(conn._tls_verify_status)

    def test_tls_fp(self):
        """Reconnect pinned to the peer certificate fingerprint (tls_fp)."""
        if is_maxscale():
            self.skipTest("MAXSCALE test has no SSL on port by default")
        if version.Version(mariadb.mariadbapi_version) < \
                version.Version('3.4.2'):
            self.skipTest("Requires C/C 3.4.2 or newer")
        default_conf = conf()
        default_conf["ssl"] = True
        with mariadb.connect(**default_conf) as conn:
            self.assertEqual(conn._tls, True)
            x509_info = conn.tls_peer_cert_info
            if not x509_info:
                self.skipTest("Peer certificate information not supported")
            fp = x509_info["fingerprint"]
            self.assertEqual(len(fp), 64)

        default_conf = conf()
        default_conf["tls_fp"] = fp
        with mariadb.connect(**default_conf) as conn:
            self.assertEqual(conn._tls, True)
            x509_info = conn.tls_peer_cert_info
            self.assertEqual(fp, x509_info["fingerprint"])

    def test_conpy278(self):
        """Reconnect after kill() with reconnect enabled, via ping/query/reconnect()."""
        if is_maxscale():
            self.skipTest("MAXSCALE bug MXS-4961")

        # kill own session, then let ping() reconnect; done on two fresh
        # connections
        for _ in range(2):
            with create_connection({"reconnect": True}) as conn:
                old_id = conn.connection_id
                try:
                    conn.kill(conn.connection_id)
                except mariadb.OperationalError:
                    conn.ping()
                self.assertNotEqual(old_id, conn.connection_id)

        with create_connection({"reconnect": True}) as conn:
            old_id = conn.connection_id
            try:
                conn.kill(conn.connection_id)
            except mariadb.OperationalError:
                pass
            with conn.cursor() as cursor:
                # The first statement after the kill may raise
                # InterfaceError; the repeated one must succeed.
                try:
                    cursor.execute("set @a:=1")
                except mariadb.InterfaceError:
                    pass
                cursor.execute("set @a:=1")
            self.assertNotEqual(old_id, conn.connection_id)

            # an explicit reconnect() must also yield a new session id
            old_id = conn.connection_id
            conn.reconnect()
            self.assertNotEqual(old_id, conn.connection_id)
if __name__ == '__main__':
    # Executed directly (not imported): hand control to the unittest runner.
    unittest.main()