#!/usr/bin/env python -O
# -*- coding: utf-8 -*-

"""Integration tests for mariadb connection options and attributes.

All tests talk to a live server whose settings come from ``conf()``
(test.conf_test) and ``create_connection()`` (test.base_test).
"""

import os
import platform
import sys
import traceback
import unittest

import mariadb
from mariadb.constants import STATUS
from packaging import version
from packaging.version import parse as parse_version

from test.base_test import (create_connection, is_skysql, is_maxscale,
                            get_host_suffix)
from test.conf_test import conf


class TestConnection(unittest.TestCase):
    """Tests for connect() options, connection attributes and lifecycle."""

    def setUp(self):
        # Shared connection; several tests read server_version/server_name
        # from it or use it directly.
        self.connection = create_connection()

    def tearDown(self):
        del self.connection

    def test_conpy36(self):
        """Connecting through a non-existent unix socket must not crash.

        An OperationalError is tolerated. If the connect succeeds anyway,
        the connection is closed instead of being leaked.
        """
        if platform.system() == "Windows":
            self.skipTest("unix_socket not supported on Windows")
        default_conf = conf()
        try:
            conn = mariadb.connect(user=default_conf["user"],
                                   unix_socket="/does_not_exist/x.sock",
                                   port=default_conf["port"],
                                   host=default_conf["host"])
        except mariadb.OperationalError:
            # make asan happy: release frame references held by the traceback
            traceback.clear_frames(sys.exc_info()[2])
        else:
            conn.close()

    def test_connection_default_file(self):
        """connect() reads host/port/user/password/database from default_file.

        The temporary option file is always removed, and the connection is
        always closed, even when connect() or the assertion fails.
        """
        cnf_file = "client.cnf"
        if os.path.exists(cnf_file):
            os.remove(cnf_file)
        default_conf = conf()
        with open(cnf_file, "w+") as f:
            f.write("[client]\n")
            f.write("host =%s\n" % default_conf["host"])
            f.write("port =%i\n" % default_conf["port"])
            f.write("user =%s\n" % default_conf["user"])
            if "password" in default_conf:
                f.write("password =%s\n" % default_conf["password"])
            f.write("database =%s\n" % default_conf["database"])
        try:
            new_conn = mariadb.connect(user=default_conf["user"], ssl=True,
                                       default_file="./client.cnf")
            try:
                self.assertEqual(new_conn.database, default_conf["database"])
            finally:
                new_conn.close()
        finally:
            os.remove(cnf_file)

    def test_autocommit(self):
        """The autocommit attribute can be switched off and back on."""
        conn = self.connection
        conn.autocommit = False
        self.assertEqual(conn.autocommit, False)
        # revert
        conn.autocommit = True
        self.assertEqual(conn.autocommit, True)

    def test_local_infile(self):
        """Run LOAD DATA LOCAL INFILE on a local_infile=False connection.

        An OperationalError from the statement is tolerated.
        """
        default_conf = conf()
        new_conn = mariadb.connect(**default_conf, local_infile=False)
        try:
            cursor = new_conn.cursor()
            cursor.execute("CREATE TEMPORARY TABLE t1 (a int)")
            try:
                cursor.execute("LOAD DATA LOCAL INFILE 'x.x' INTO TABLE t1")
            except mariadb.OperationalError:
                # make asan happy
                if mariadb._have_asan:
                    traceback.clear_frames(sys.exc_info()[2])
            cursor.close()
        finally:
            new_conn.close()

    def test_tls_version(self):
        """tls_version="TLSv1.2" is what the server reports in ssl_version."""
        if is_maxscale():
            self.skipTest("MAXSCALE test has no SSL on port by default")
        default_conf = conf()
        conn = mariadb.connect(**default_conf, tls_version="TLSv1.2")
        cursor = conn.cursor()
        cursor.execute("SHOW STATUS LIKE 'ssl_version'")
        row = cursor.fetchone()
        self.assertEqual(row[1], "TLSv1.2")
        cursor.close()
        conn.close()

    def test_init_command(self):
        """init_command is executed when the connection is established."""
        default_conf = conf()
        new_conn = mariadb.connect(**default_conf, init_command="SET @a:=1")
        cursor = new_conn.cursor()
        cursor.execute("SELECT @a")
        row = cursor.fetchone()
        self.assertEqual(row[0], 1)
        cursor.close()
        new_conn.close()

    def test_compress(self):
        """compress=True turns session compression ON (OFF behind MaxScale)."""
        default_conf = conf()
        new_conn = mariadb.connect(**default_conf, compress=True)
        cursor = new_conn.cursor()
        cursor.execute("SHOW SESSION STATUS LIKE 'compression'")
        row = cursor.fetchone()
        expected = "OFF" if is_maxscale() else "ON"
        self.assertEqual(row[1], expected)
        cursor.close()
        new_conn.close()

    def test_schema(self):
        """conn.database follows USE (session tracking) and can be assigned.

        The helper schema test1 is dropped afterwards so no state is left
        on the server.
        """
        if self.connection.server_version < 100202:
            self.skipTest("session tracking not supported")
        if is_maxscale():
            self.skipTest("MAXSCALE doesn't tell schema change for now")
        default_conf = conf()
        conn = create_connection()
        try:
            self.assertEqual(conn.database, default_conf["database"])
            cursor = conn.cursor()
            cursor.execute("DROP SCHEMA IF EXISTS test1")
            cursor.execute("CREATE SCHEMA test1")
            try:
                cursor.execute("USE test1")
                self.assertEqual(conn.database, "test1")
                conn.database = default_conf["database"]
                self.assertEqual(conn.database, default_conf["database"])
            finally:
                cursor.execute("DROP SCHEMA IF EXISTS test1")
                cursor.close()
        finally:
            conn.close()

    def test_ping(self):
        """After KILLing its own thread, ping() with auto_reconnect reconnects."""
        if is_maxscale():
            self.skipTest("MAXSCALE wrong thread id")
        with create_connection() as conn:
            with conn.cursor() as cursor:
                oldid = conn.connection_id
                try:
                    cursor.execute("KILL {id}".format(id=oldid))
                except mariadb.Error:
                    pass
                conn.auto_reconnect = True
                conn.ping()
                self.assertNotEqual(oldid, conn.connection_id)
                self.assertNotEqual(0, conn.connection_id)

    def test_ed25519(self):
        """A user created with the ed25519 auth plugin can connect."""
        if is_skysql():
            self.skipTest("Test fail on SkySQL")
        default_conf = conf()
        if is_maxscale():
            self.skipTest("MAXSCALE doesn't support ed25519 for now")
        if self.connection.server_version < 100122:
            self.skipTest("ed25519 not supported")

        # Use the same user@host spec for DROP/CREATE/GRANT. An unqualified
        # "DROP USER eduser" only matches eduser@'%'. A leftover account
        # with another host part (from an earlier failed run) would then
        # survive and make CREATE USER fail.
        ed_user = "eduser" + get_host_suffix()
        conn = create_connection()
        try:
            if self.connection.server_name == "localhost":
                with conn.cursor(buffered=True) as curs:
                    curs.execute("select * from information_schema.plugins where "
                                 "plugin_name ='unix_socket' and "
                                 "plugin_status ='ACTIVE'")
                    unix_socket_active = curs.rowcount > 0
                if unix_socket_active:
                    self.skipTest("unix_socket is active")
            with conn.cursor() as cursor:
                try:
                    cursor.execute("INSTALL SONAME 'auth_ed25519'")
                except (mariadb.DatabaseError, mariadb.OperationalError):
                    self.skipTest("Server couldn't load auth_ed25519")
                cursor.execute("DROP USER IF EXISTS " + ed_user)
                try:
                    if self.connection.server_version < 100400:
                        cursor.execute("CREATE USER " + ed_user +
                                       " IDENTIFIED VIA ed25519 "
                                       "USING "
                                       "'6aW9C7ENlasUfymtfMvMZZtnkCVlcb1ssxOLJ0kj/AA'")
                    else:
                        cursor.execute("CREATE USER " + ed_user +
                                       " IDENTIFIED VIA ed25519 "
                                       "USING PASSWORD('MySup8%rPassw@ord')")
                    cursor.execute("GRANT ALL on " + default_conf["database"] +
                                   ".* to " + ed_user)
                    conn2 = create_connection({"user": "eduser",
                                               "password": "MySup8%rPassw@ord"})
                    conn2.close()
                    # disabling this test part for now
                    # try:
                    #     create_connection({"user": "eduser",
                    #                        "password": "MySup8%rPassw@ord",
                    #                        "plugin_dir": "wrong_plugin_dir"})
                    #     self.fail("wrong plugin directory, must not have found "
                    #               "authentication plugin")
                    # except (mariadb.OperationalError):
                    #     pass
                finally:
                    # Remove the test account even if the login check failed.
                    cursor.execute("DROP USER IF EXISTS " + ed_user)
        finally:
            conn.close()

    def test_conpy46(self):
        """Connection and cursor can be used as context managers."""
        with create_connection() as con:
            with con.cursor() as cursor:
                cursor.execute("SELECT 'foo'")
                row = cursor.fetchone()
                self.assertEqual(row[0], "foo")
            # The cursor's with-block has ended; a ProgrammingError
            # (closed cursor) is tolerated here.
            try:
                cursor.execute("SELECT 'bar'")
            except mariadb.ProgrammingError:
                pass
        # The connection's with-block has ended; a ProgrammingError is
        # tolerated here.
        try:
            cursor = con.cursor()
        except mariadb.ProgrammingError:
            pass

    def test_conpy101(self):
        """autocommit is False by default and can be enabled via connect()."""
        default_conf = conf()
        c1 = mariadb.connect(**default_conf)
        self.assertEqual(c1.autocommit, False)
        c1.close()
        c1 = mariadb.connect(**default_conf, autocommit=True)
        self.assertEqual(c1.autocommit, True)
        c1.close()

    def test_db_attribute(self):
        """Assigning conn.database switches the server's current schema."""
        with create_connection() as con:
            with con.cursor() as cursor:
                db = con.database
                try:
                    cursor.execute("create schema test123")
                except mariadb.Error:
                    pass
                con.database = "test123"
                cursor.execute("select database()", buffered=True)
                row = cursor.fetchone()
                self.assertEqual(row[0], "test123")
                con.database = db
                cursor.execute("select database()", buffered=True)
                row = cursor.fetchone()
                self.assertEqual(row[0], db)
                self.assertEqual(row[0], con.database)
                cursor.execute("drop schema test123")

    def test_server_status(self):
        """The STATUS.AUTOCOMMIT bit follows the autocommit attribute."""
        con = create_connection()
        try:
            self.assertTrue(not con.server_status & STATUS.AUTOCOMMIT)
            con.autocommit = True
            self.assertTrue(con.server_status & STATUS.AUTOCOMMIT)
            con.autocommit = False
            self.assertTrue(not con.server_status & STATUS.AUTOCOMMIT)
        finally:
            con.close()

    def test_conpy175(self):
        """escape_string() honours the NO_BACKSLASH_ESCAPES sql_mode."""
        default_conf = conf()
        conn = mariadb.connect(**default_conf)
        text = "Bob's"
        cursor = conn.cursor()
        cursor.execute("SET session sql_mode='NO_BACKSLASH_ESCAPES'")
        newstr = conn.escape_string(text)
        self.assertEqual(newstr, "Bob''s")
        cursor.execute("SET session sql_mode=''")
        newstr = conn.escape_string(text)
        self.assertEqual(newstr, "Bob\\'s")
        conn.close()

    def test_closed(self):
        """Creating a cursor on a closed connection raises ProgrammingError."""
        default_conf = conf()
        conn = mariadb.connect(**default_conf)
        conn.close()
        with self.assertRaises(mariadb.ProgrammingError):
            conn.cursor()

    def test_multi_host(self):
        """Connect with a host list whose first entry does not resolve.

        Only Connector/C older than 3.3.0 may reject this with
        ProgrammingError. A successful connection is closed.
        """
        default_conf = conf()
        default_conf["host"] = "non_existant," + default_conf["host"]
        try:
            conn = mariadb.connect(**default_conf)
        except mariadb.ProgrammingError:
            self.assertLess(parse_version(mariadb.mariadbapi_version),
                            parse_version('3.3.0'))
        else:
            conn.close()

    def test_tls_verification(self):
        """_tls_verify_status is None without TLS and set once TLS is used."""
        if is_maxscale():
            self.skipTest("MAXSCALE test has no SSL on port by default")
        if version.Version(mariadb.mariadbapi_version) < \
                version.Version('3.4.2'):
            self.skipTest("Requires C/C 3.4.2 or newer")
        default_conf = conf()
        default_conf["ssl"] = False
        conn = mariadb.connect(**default_conf)
        self.assertEqual(conn._tls_verify_status, None)
        conn.close()
        default_conf = conf()
        default_conf["ssl"] = True
        conn = mariadb.connect(**default_conf)
        self.assertNotEqual(conn._tls_verify_status, None)
        conn.close()

    def test_tls_fp(self):
        """Reconnect with tls_fp set to the server certificate fingerprint."""
        if is_maxscale():
            self.skipTest("MAXSCALE test has no SSL on port by default")
        if version.Version(mariadb.mariadbapi_version) < \
                version.Version('3.4.2'):
            self.skipTest("Requires C/C 3.4.2 or newer")
        default_conf = conf()
        default_conf["ssl"] = True
        conn = mariadb.connect(**default_conf)
        self.assertEqual(conn._tls, True)
        x509_info = conn.tls_peer_cert_info
        if not x509_info:
            conn.close()
            self.skipTest("Peer certificate information not supported")
        fp = x509_info["fingerprint"]
        # 64 characters: presumably a hex-encoded SHA-256 digest
        self.assertEqual(len(fp), 64)
        conn.close()
        default_conf = conf()
        default_conf["tls_fp"] = fp
        conn = mariadb.connect(**default_conf)
        self.assertEqual(conn._tls, True)
        x509_info = conn.tls_peer_cert_info
        self.assertEqual(fp, x509_info["fingerprint"])
        conn.close()

    def test_conpy278(self):
        """Connections opened with reconnect=True recover after being killed."""
        if is_maxscale():
            self.skipTest("MAXSCALE bug MXS-4961")
        # kill() + ping() must yield a new connection id; checked on two
        # fresh connections.
        for _ in range(2):
            with create_connection({"reconnect": True}) as conn:
                old_id = conn.connection_id
                try:
                    conn.kill(conn.connection_id)
                except mariadb.OperationalError:
                    conn.ping()
                self.assertNotEqual(old_id, conn.connection_id)

        # kill() followed by statement execution, then an explicit reconnect().
        with create_connection({"reconnect": True}) as conn:
            old_id = conn.connection_id
            try:
                conn.kill(conn.connection_id)
            except mariadb.OperationalError:
                pass
            with conn.cursor() as cursor:
                try:
                    cursor.execute("set @a:=1")
                except mariadb.InterfaceError:
                    pass
                cursor.execute("set @a:=1")
            self.assertNotEqual(old_id, conn.connection_id)
            old_id = conn.connection_id
            conn.reconnect()
            self.assertNotEqual(old_id, conn.connection_id)


if __name__ == '__main__':
    unittest.main()