This commit is contained in:
Georg Richter
2019-11-12 11:13:08 +01:00
41 changed files with 4834 additions and 3461 deletions

1
.gitignore vendored
View File

@ -50,3 +50,4 @@ modules.order
Module.symvers Module.symvers
Mkfile.old Mkfile.old
dkms.conf dkms.conf
client.cnf

69
.travis.yml Normal file
View File

@ -0,0 +1,69 @@
sudo: true
language: python
dist: bionic
services:
- docker
addons:
hosts:
- mariadb.example.com
before_install:
- chmod +x .travis/script.sh
- sudo apt-get install software-properties-common
- sudo apt-key adv --recv-keys --keyserver hkp://keyserver.ubuntu.com:80 0xF1656F24C74CD1D8
- sudo add-apt-repository 'deb [arch=amd64,arm64,ppc64el] http://mirrors.accretive-networks.net/mariadb/repo/10.4/ubuntu bionic main'
- sudo apt-get remove --purge mysql*
- sudo apt update
- sudo apt-get install -f libmariadb3 libmariadb-dev libssl1.1
- sudo apt-get install -f
install:
- wget -qO- 'https://github.com/tianon/pgp-happy-eyeballs/raw/master/hack-my-builds.sh' | bash
# generate SSL certificates
- mkdir tmp
- chmod +x .travis/gen-ssl.sh
- chmod +x .travis/build/build.sh
- chmod +x .travis/build/docker-entrypoint.sh
- chmod 777 .travis/build/
- .travis/gen-ssl.sh mariadb.example.com tmp
- export PROJ_PATH=`pwd`
- export SSLCERT=$PROJ_PATH/tmp
- export TEST_SSL_CA_FILE=$SSLCERT/server.crt
- export TEST_SSL_CLIENT_KEY_FILE=$SSLCERT/client.key
- export TEST_SSL_CLIENT_CERT_FILE=$SSLCERT/client.crt
- export TEST_SSL_CLIENT_KEYSTORE_FILE=$SSLCERT/client-keystore.p12
env:
global:
- TEST_PORT=3305
- TEST_HOST=mariadb.example.com
matrix:
include:
- python: "2.7"
env: DB=mariadb:10.4
- python: "3.6"
env: DB=mariadb:10.4
- python: "3.8"
env: DB=mariadb:10.4
- env: DB=mariadb:10.4 MAXSCALE_VERSION=2.2.9 TEST_PORT=4007 TEST_USER=bob TEXT_DATABASE=test2 SKIP_LEAK=1
- env: DB=mariadb:5.5
- env: DB=mariadb:10.0
- env: DB=mariadb:10.1
- env: DB=mariadb:10.2
- env: DB=mariadb:10.3
- env: DB=mysql:5.5
- env: DB=mysql:5.6
- env: DB=mysql:5.7
notifications:
email: false
script:
- python setup.py build
- python setup.py install
- npm install nyc -g
- .travis/script.sh

99
.travis/build/Dockerfile Normal file
View File

@ -0,0 +1,99 @@
# vim:set ft=dockerfile:
FROM ubuntu:bionic
# add our user and group first to make sure their IDs get assigned consistently, regardless of whatever dependencies get added
RUN groupadd -r mysql && useradd -r -g mysql mysql
# https://bugs.debian.org/830696 (apt uses gpgv by default in newer releases, rather than gpg)
RUN set -ex; \
apt-get update; \
if ! which gpg; then \
apt-get install -y --no-install-recommends gnupg; \
fi; \
# Ubuntu includes "gnupg" (not "gnupg2", but still 2.x), but not dirmngr, and gnupg 2.x requires dirmngr
# so, if we're not running gnupg 1.x, explicitly install dirmngr too
if ! gpg --version | grep -q '^gpg (GnuPG) 1\.'; then \
apt-get install -y --no-install-recommends dirmngr; \
fi; \
rm -rf /var/lib/apt/lists/*
# add gosu for easy step-down from root
ENV GOSU_VERSION 1.10
RUN set -ex; \
\
fetchDeps=' \
ca-certificates \
wget \
'; \
apt-get update; \
apt-get install -y --no-install-recommends $fetchDeps; \
rm -rf /var/lib/apt/lists/*; \
\
dpkgArch="$(dpkg --print-architecture | awk -F- '{ print $NF }')"; \
wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-$dpkgArch"; \
wget -O /usr/local/bin/gosu.asc "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-$dpkgArch.asc"; \
\
# verify the signature
export GNUPGHOME="$(mktemp -d)"; \
gpg --batch --keyserver ha.pool.sks-keyservers.net --recv-keys B42F6819007F00F88E364FD4036A9C25BF357DD4; \
gpg --batch --verify /usr/local/bin/gosu.asc /usr/local/bin/gosu; \
command -v gpgconf > /dev/null && gpgconf --kill all || :; \
rm -r "$GNUPGHOME" /usr/local/bin/gosu.asc; \
\
chmod +x /usr/local/bin/gosu; \
# verify that the binary works
gosu nobody true; \
\
apt-get purge -y --auto-remove $fetchDeps
RUN mkdir /docker-entrypoint-initdb.d
# install "pwgen" for randomizing passwords
# install "apt-transport-https" for Percona's repo (switched to https-only)
RUN apt-get update && apt-get install -y --no-install-recommends \
apt-transport-https ca-certificates \
tzdata \
pwgen \
&& rm -rf /var/lib/apt/lists/*
RUN { \
echo "mariadb-server-10.4" mysql-server/root_password password 'unused'; \
echo "mariadb-server-10.4" mysql-server/root_password_again password 'unused'; \
} | debconf-set-selections
RUN apt-get update -y
RUN apt-get install -y software-properties-common wget
RUN apt-key adv --recv-keys --keyserver keyserver.ubuntu.com 0xcbcb082a1bb943db
RUN apt-key adv --recv-keys --keyserver ha.pool.sks-keyservers.net F1656F24C74CD1D8
RUN echo 'deb http://yum.mariadb.org/galera/repo/deb bionic main' > /etc/apt/sources.list.d/galera-test-repo.list
RUN apt-get update -y
RUN apt-get install -y curl libdbi-perl rsync socat galera3 libnuma1 libaio1 zlib1g-dev libreadline5 libjemalloc1 libsnappy1v5 libcrack2
COPY *.deb /root/
RUN chmod 777 /root/*
RUN dpkg --install /root/mysql-common*
RUN dpkg --install /root/mariadb-common*
RUN dpkg -R --unpack /root/
RUN apt-get install -f -y
RUN rm -rf /var/lib/apt/lists/* \
&& sed -ri 's/^user\s/#&/' /etc/mysql/my.cnf /etc/mysql/conf.d/* \
&& rm -rf /var/lib/mysql && mkdir -p /var/lib/mysql /var/run/mysqld \
&& chown -R mysql:mysql /var/lib/mysql /var/run/mysqld \
&& chmod 777 /var/run/mysqld \
&& find /etc/mysql/ -name '*.cnf' -print0 \
| xargs -0 grep -lZE '^(bind-address|log)' \
| xargs -rt -0 sed -Ei 's/^(bind-address|log)/#&/' \
&& echo '[mysqld]\nskip-host-cache\nskip-name-resolve' > /etc/mysql/conf.d/docker.cnf
VOLUME /var/lib/mysql
COPY docker-entrypoint.sh /usr/local/bin/
RUN ln -s usr/local/bin/docker-entrypoint.sh / # backwards compat
ENTRYPOINT ["docker-entrypoint.sh"]
EXPOSE 3306
CMD ["mysqld"]

33
.travis/build/build.sh Normal file
View File

@ -0,0 +1,33 @@
#!/usr/bin/env bash
echo "**************************************************************************"
echo "* searching for last complete build"
echo "**************************************************************************"
wget -q -o /dev/null index.html http://hasky.askmonty.org/archive/10.4/
grep -o ">build-[0-9]*" index.html | grep -o "[0-9]*" | tac | while read -r line ; do
curl -s --head http://hasky.askmonty.org/archive/10.4/build-$line/kvm-deb-bionic-amd64/md5sums.txt | head -n 1 | grep "HTTP/1.[01] [23].." > /dev/null
if [ $? = "0" ]; then
echo "**************************************************************************"
echo "* Processing $line"
echo "**************************************************************************"
wget -q -o /dev/null -O $line.html http://hasky.askmonty.org/archive/10.4/build-$line/kvm-deb-bionic-amd64/debs/binary/
grep -o ">[^\"]*\.deb" $line.html | grep -o "[^>]*\.deb" | while read -r file ; do
if [[ "$file" =~ ^mariadb-plugin.* ]] ;
then
echo "skipped file: $file"
else
echo "download file: $file"
wget -q -o /dev/null -O .travis/build/$file http://hasky.askmonty.org/archive/10.4/build-$line/kvm-deb-bionic-amd64/debs/binary/$file
fi
done
exit
else
echo "skip build $line"
fi
done

View File

@ -0,0 +1,196 @@
#!/bin/bash
set -eo pipefail
shopt -s nullglob
# if command starts with an option, prepend mysqld
if [ "${1:0:1}" = '-' ]; then
set -- mysqld "$@"
fi
# skip setup if they want an option that stops mysqld
wantHelp=
for arg; do
case "$arg" in
-'?'|--help|--print-defaults|-V|--version)
wantHelp=1
break
;;
esac
done
# usage: file_env VAR [DEFAULT]
# ie: file_env 'XYZ_DB_PASSWORD' 'example'
# (will allow for "$XYZ_DB_PASSWORD_FILE" to fill in the value of
# "$XYZ_DB_PASSWORD" from a file, especially for Docker's secrets feature)
file_env() {
local var="$1"
local fileVar="${var}_FILE"
local def="${2:-}"
if [ "${!var:-}" ] && [ "${!fileVar:-}" ]; then
echo >&2 "error: both $var and $fileVar are set (but are exclusive)"
exit 1
fi
local val="$def"
if [ "${!var:-}" ]; then
val="${!var}"
elif [ "${!fileVar:-}" ]; then
val="$(< "${!fileVar}")"
fi
export "$var"="$val"
unset "$fileVar"
}
_check_config() {
toRun=( "$@" --verbose --help --log-bin-index="$(mktemp -u)" )
if ! errors="$("${toRun[@]}" 2>&1 >/dev/null)"; then
cat >&2 <<-EOM
ERROR: mysqld failed while attempting to check config
command was: "${toRun[*]}"
$errors
EOM
exit 1
fi
}
# Fetch value from server config
# We use mysqld --verbose --help instead of my_print_defaults because the
# latter only show values present in config files, and not server defaults
_get_config() {
local conf="$1"; shift
"$@" --verbose --help --log-bin-index="$(mktemp -u)" 2>/dev/null \
| awk '$1 == "'"$conf"'" && /^[^ \t]/ { sub(/^[^ \t]+[ \t]+/, ""); print; exit }'
# match "datadir /some/path with/spaces in/it here" but not "--xyz=abc\n datadir (xyz)"
}
# allow the container to be started with `--user`
if [ "$1" = 'mysqld' -a -z "$wantHelp" -a "$(id -u)" = '0' ]; then
_check_config "$@"
DATADIR="$(_get_config 'datadir' "$@")"
mkdir -p "$DATADIR"
find "$DATADIR" \! -user mysql -exec chown mysql '{}' +
exec gosu mysql "$BASH_SOURCE" "$@"
fi
if [ "$1" = 'mysqld' -a -z "$wantHelp" ]; then
# still need to check config, container may have started with --user
_check_config "$@"
# Get config
DATADIR="$(_get_config 'datadir' "$@")"
if [ ! -d "$DATADIR/mysql" ]; then
file_env 'MYSQL_ROOT_PASSWORD'
if [ -z "$MYSQL_ROOT_PASSWORD" -a -z "$MYSQL_ALLOW_EMPTY_PASSWORD" -a -z "$MYSQL_RANDOM_ROOT_PASSWORD" ]; then
echo >&2 'error: database is uninitialized and password option is not specified '
echo >&2 ' You need to specify one of MYSQL_ROOT_PASSWORD, MYSQL_ALLOW_EMPTY_PASSWORD and MYSQL_RANDOM_ROOT_PASSWORD'
exit 1
fi
mkdir -p "$DATADIR"
echo 'Initializing database'
installArgs=( --datadir="$DATADIR" --rpm )
if { mysql_install_db --help || :; } | grep -q -- '--auth-root-authentication-method'; then
# beginning in 10.4.3, install_db uses "socket" which only allows system user root to connect, switch back to "normal" to allow mysql root without a password
# see https://github.com/MariaDB/server/commit/b9f3f06857ac6f9105dc65caae19782f09b47fb3
# (this flag doesn't exist in 10.0 and below)
installArgs+=( --auth-root-authentication-method=normal )
fi
# "Other options are passed to mysqld." (so we pass all "mysqld" arguments directly here)
mysql_install_db "${installArgs[@]}" "${@:2}"
echo 'Database initialized'
SOCKET="$(_get_config 'socket' "$@")"
"$@" --skip-networking --socket="${SOCKET}" &
pid="$!"
mysql=( mysql --protocol=socket -uroot -hlocalhost --socket="${SOCKET}" )
for i in {30..0}; do
if echo 'SELECT 1' | "${mysql[@]}" &> /dev/null; then
break
fi
echo 'MySQL init process in progress...'
sleep 1
done
if [ "$i" = 0 ]; then
echo >&2 'MySQL init process failed.'
exit 1
fi
if [ -z "$MYSQL_INITDB_SKIP_TZINFO" ]; then
# sed is for https://bugs.mysql.com/bug.php?id=20545
mysql_tzinfo_to_sql /usr/share/zoneinfo | sed 's/Local time zone must be set--see zic manual page/FCTY/' | "${mysql[@]}" mysql
fi
if [ ! -z "$MYSQL_RANDOM_ROOT_PASSWORD" ]; then
export MYSQL_ROOT_PASSWORD="$(pwgen -1 32)"
echo "GENERATED ROOT PASSWORD: $MYSQL_ROOT_PASSWORD"
fi
rootCreate=
# default root to listen for connections from anywhere
file_env 'MYSQL_ROOT_HOST' '%'
if [ ! -z "$MYSQL_ROOT_HOST" -a "$MYSQL_ROOT_HOST" != 'localhost' ]; then
# no, we don't care if read finds a terminating character in this heredoc
# https://unix.stackexchange.com/questions/265149/why-is-set-o-errexit-breaking-this-read-heredoc-expression/265151#265151
read -r -d '' rootCreate <<-EOSQL || true
CREATE USER 'root'@'${MYSQL_ROOT_HOST}' IDENTIFIED BY '${MYSQL_ROOT_PASSWORD}' ;
GRANT ALL ON *.* TO 'root'@'${MYSQL_ROOT_HOST}' WITH GRANT OPTION ;
EOSQL
fi
"${mysql[@]}" <<-EOSQL
-- What's done in this file shouldn't be replicated
-- or products like mysql-fabric won't work
SET @@SESSION.SQL_LOG_BIN=0;
DELETE FROM mysql.user WHERE user NOT IN ('mysql.sys', 'mysqlxsys', 'root') OR host NOT IN ('localhost') ;
SET PASSWORD FOR 'root'@'localhost'=PASSWORD('${MYSQL_ROOT_PASSWORD}') ;
GRANT ALL ON *.* TO 'root'@'localhost' WITH GRANT OPTION ;
${rootCreate}
DROP DATABASE IF EXISTS test ;
FLUSH PRIVILEGES ;
EOSQL
if [ ! -z "$MYSQL_ROOT_PASSWORD" ]; then
mysql+=( -p"${MYSQL_ROOT_PASSWORD}" )
fi
file_env 'MYSQL_DATABASE'
if [ "$MYSQL_DATABASE" ]; then
echo "CREATE DATABASE IF NOT EXISTS \`$MYSQL_DATABASE\` ;" | "${mysql[@]}"
mysql+=( "$MYSQL_DATABASE" )
fi
file_env 'MYSQL_USER'
file_env 'MYSQL_PASSWORD'
if [ "$MYSQL_USER" -a "$MYSQL_PASSWORD" ]; then
echo "CREATE USER '$MYSQL_USER'@'%' IDENTIFIED BY '$MYSQL_PASSWORD' ;" | "${mysql[@]}"
if [ "$MYSQL_DATABASE" ]; then
echo "GRANT ALL ON \`$MYSQL_DATABASE\`.* TO '$MYSQL_USER'@'%' ;" | "${mysql[@]}"
fi
fi
echo
for f in /docker-entrypoint-initdb.d/*; do
case "$f" in
*.sh) echo "$0: running $f"; . "$f" ;;
*.sql) echo "$0: running $f"; "${mysql[@]}" < "$f"; echo ;;
*.sql.gz) echo "$0: running $f"; gunzip -c "$f" | "${mysql[@]}"; echo ;;
*) echo "$0: ignoring $f" ;;
esac
echo
done
if ! kill -s TERM "$pid" || ! wait "$pid"; then
echo >&2 'MySQL init process failed.'
exit 1
fi
echo
echo 'MySQL init process done. Ready for start up.'
echo
fi
fi
exec "$@"

View File

@ -0,0 +1,17 @@
version: '2'
services:
db:
image: $DB
command: --innodb-log-file-size=400m --max-allowed-packet=40m --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --ssl-ca=/etc/sslcert/ca.crt --ssl-cert=/etc/sslcert/server.crt --ssl-key=/etc/sslcert/server.key --bind-address=0.0.0.0 $ADDITIONAL_CONF
ports:
- 3305:3306
volumes:
- $SSLCERT:/etc/sslcert
- $ENTRYPOINT:/pam
environment:
MYSQL_DATABASE: testp
MYSQL_ALLOW_EMPTY_PASSWORD: 1
MYSQL_ROOT_PASSWORD:

View File

@ -0,0 +1,11 @@
CREATE USER 'bob'@'%';
GRANT ALL ON *.* TO 'bob'@'%' with grant option;
CREATE USER 'boby'@'%' identified by 'heyPassw0@rd';
GRANT ALL ON *.* TO 'boby'@'%' with grant option;
INSTALL PLUGIN pam SONAME 'auth_pam';
FLUSH PRIVILEGES;
CREATE DATABASE test2;

16
.travis/entrypoint/pam.sh Normal file
View File

@ -0,0 +1,16 @@
#!/bin/bash
tee /etc/pam.d/mariadb << EOF
auth required pam_unix.so audit
auth required pam_unix.so audit
account required pam_unix.so audit
EOF
useradd testPam
chpasswd << EOF
testPam:myPwd
EOF
usermod -a -G shadow mysql
echo "pam configuration done"

128
.travis/gen-ssl.sh Normal file
View File

@ -0,0 +1,128 @@
#!/bin/bash
set -e
log () {
echo "$@" 1>&2
}
print_error () {
echo "$@" 1>&2
exit 1
}
print_usage () {
print_error "Usage: gen-ssl-cert-key <fqdn> <output-dir>"
}
gen_cert_subject () {
local fqdn="$1"
[[ "${fqdn}" != "" ]] || print_error "FQDN cannot be blank"
echo "/C=XX/ST=X/O=X/localityName=X/CN=${fqdn}/organizationalUnitName=X/emailAddress=X/"
}
main () {
local fqdn="$1"
local sslDir="$2"
[[ "${fqdn}" != "" ]] || print_usage
[[ -d "${sslDir}" ]] || print_error "Directory does not exist: ${sslDir}"
local caCertFile="${sslDir}/ca.crt"
local caKeyFile="${sslDir}/ca.key"
local certFile="${sslDir}/server.crt"
local keyFile="${sslDir}/server.key"
local csrFile="${sslDir}/csrFile.key"
local clientCertFile="${sslDir}/client.crt"
local clientKeyFile="${sslDir}/client.key"
local clientKeystoreFile="${sslDir}/client-keystore.p12"
local pcks12FullKeystoreFile="${sslDir}/fullclient-keystore.p12"
local clientReqFile="${sslDir}/clientReqFile.key"
log "Generating CA key"
openssl genrsa -out "${caKeyFile}" 2048
log "Generating CA certificate"
openssl req \
-sha1 \
-new \
-x509 \
-nodes \
-days 3650 \
-subj "$(gen_cert_subject ca.example.com)" \
-key "${caKeyFile}" \
-out "${caCertFile}"
log "Generating private key"
openssl genrsa -out "${keyFile}" 2048
log "Generating certificate signing request"
openssl req \
-new \
-batch \
-sha1 \
-subj "$(gen_cert_subject "$fqdn")" \
-set_serial 01 \
-key "${keyFile}" \
-out "${csrFile}" \
-nodes
log "Generating X509 certificate"
openssl x509 \
-req \
-sha1 \
-set_serial 01 \
-CA "${caCertFile}" \
-CAkey "${caKeyFile}" \
-days 3650 \
-in "${csrFile}" \
-signkey "${keyFile}" \
-out "${certFile}"
log "Generating client certificate"
openssl req \
-batch \
-newkey rsa:2048 \
-days 3600 \
-subj "$(gen_cert_subject "$fqdn")" \
-nodes \
-keyout "${clientKeyFile}" \
-out "${clientReqFile}"
openssl x509 \
-req \
-in "${clientReqFile}" \
-days 3600 \
-CA "${caCertFile}" \
-CAkey "${caKeyFile}" \
-set_serial 01 \
-out "${clientCertFile}"
# Now generate a keystore with the client cert & key
log "Generating client keystore"
openssl pkcs12 \
-export \
-in "${clientCertFile}" \
-inkey "${clientKeyFile}" \
-out "${clientKeystoreFile}" \
-name "mysqlAlias" \
-passout pass:kspass
# Now generate a full keystore with the client cert & key + trust certificates
log "Generating full client keystore"
openssl pkcs12 \
-export \
-in "${clientCertFile}" \
-inkey "${clientKeyFile}" \
-out "${pcks12FullKeystoreFile}" \
-name "mysqlAlias" \
-passout pass:kspass
# Clean up CSR file:
rm "$csrFile"
rm "$clientReqFile"
log "Generated key file and certificate in: ${sslDir}"
ls -l "${sslDir}"
}
main "$@"

View File

@ -0,0 +1,25 @@
version: '2.1'
services:
maxscale:
depends_on:
- db
ports:
- 4006:4006
- 4007:4007
- 4008:4008
build:
context: .
dockerfile: maxscale/Dockerfile
args:
MAXSCALE_VERSION: $MAXSCALE_VERSION
db:
image: $DB
command: --max-connections=500 --max-allowed-packet=40m --innodb-log-file-size=400m --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --ssl-ca=/etc/sslcert/ca.crt --ssl-cert=/etc/sslcert/server.crt --ssl-key=/etc/sslcert/server.key --bind-address=0.0.0.0
ports:
- 3305:3306
volumes:
- $SSLCERT:/etc/sslcert
- $ENTRYPOINT:/docker-entrypoint-initdb.d
environment:
MYSQL_DATABASE: testp
MYSQL_ALLOW_EMPTY_PASSWORD: 1

View File

@ -0,0 +1,24 @@
FROM centos:7
ARG MAXSCALE_VERSION
ENV MAXSCALE_VERSION ${MAXSCALE_VERSION:-2.2.9}
COPY maxscale/mariadb.repo /etc/yum.repos.d/
RUN rpm --import https://yum.mariadb.org/RPM-GPG-KEY-MariaDB \
&& yum -y install https://downloads.mariadb.com/MaxScale/${MAXSCALE_VERSION}/centos/7/x86_64/maxscale-${MAXSCALE_VERSION}-1.centos.7.x86_64.rpm \
&& yum -y update
RUN yum -y install maxscale-${MAXSCALE_VERSION} MariaDB-client \
&& yum clean all \
&& rm -rf /tmp/*
COPY maxscale/docker-entrypoint.sh /
RUN chmod 777 /etc/maxscale.cnf
COPY maxscale/maxscale.cnf /etc/
RUN chmod 777 /docker-entrypoint.sh
EXPOSE 4006 4007 4008
ENTRYPOINT ["/docker-entrypoint.sh"]

View File

@ -0,0 +1,35 @@
#!/usr/bin/env bash
set -e
echo 'creating configuration done'
sleep 15
#################################################################################################
# wait for db availability for 60s
#################################################################################################
mysql=( mysql --protocol=tcp -ubob -hdb --port=3306 )
for i in {60..0}; do
if echo 'use test2' | "${mysql[@]}" &> /dev/null; then
break
fi
echo 'DB init process in progress...'
sleep 1
done
echo 'use test2' | "${mysql[@]}"
if [ "$i" = 0 ]; then
echo 'DB init process failed.'
exit 1
fi
echo 'maxscale launching ...'
tail -n 500 /etc/maxscale.cnf
/usr/bin/maxscale --user=root --nodaemon
cd /var/log/maxscale
ls -lrt
tail -n 500 /var/log/maxscale/maxscale.log

View File

@ -0,0 +1,7 @@
# MariaDB 10.3 CentOS repository list - created 2018-11-09 14:50 UTC
# http://downloads.mariadb.org/mariadb/repositories/
[mariadb]
name = MariaDB
baseurl = http://yum.mariadb.org/10.3/centos7-amd64
gpgkey=https://yum.mariadb.org/RPM-GPG-KEY-MariaDB
gpgcheck=1

View File

@ -0,0 +1,125 @@
# MaxScale documentation on GitHub:
# https://github.com/mariadb-corporation/MaxScale/blob/2.1/Documentation/Documentation-Contents.md
# Global parameters
#
# Complete list of configuration options:
# https://github.com/mariadb-corporation/MaxScale/blob/2.1/Documentation/Getting-Started/Configuration-Guide.md
[maxscale]
threads=2
log_messages=1
log_trace=1
log_debug=1
# Server definitions
#
# Set the address of the server to the network
# address of a MySQL server.
#
[server1]
type=server
address=db
port=3306
protocol=MariaDBBackend
authenticator_options=skip_authentication=true
router_options=master
# Monitor for the servers
#
# This will keep MaxScale aware of the state of the servers.
# MySQL Monitor documentation:
# https://github.com/mariadb-corporation/MaxScale/blob/2.1/Documentation/Monitors/MySQL-Monitor.md
[MySQLMonitor]
type=monitor
module=mariadbmon
servers=server1
user=boby
passwd=heyPassw0@rd
monitor_interval=10000
# Service definitions
#
# Service Definition for a read-only service and
# a read/write splitting service.
#
# ReadConnRoute documentation:
# https://github.com/mariadb-corporation/MaxScale/blob/2.1/Documentation/Routers/ReadConnRoute.md
[Read-OnlyService]
enable_root_user=1
version_string=10.4.99-MariaDB-maxScale
type=service
router=readconnroute
servers=server1
user=boby
passwd=heyPassw0@rd
router_options=slave
localhost_match_wildcard_host=1
[Read-WriteService]
enable_root_user=1
version_string=10.4.99-MariaDB-maxScale
type=service
router=readwritesplit
servers=server1
user=boby
passwd=heyPassw0@rd
localhost_match_wildcard_host=1
[WriteService]
type=service
router=readconnroute
servers=server1
user=boby
passwd=heyPassw0@rd
router_options=master
localhost_match_wildcard_host=1
version_string=10.4.99-MariaDB-maxScale
# This service enables the use of the MaxAdmin interface
# MaxScale administration guide:
# https://github.com/mariadb-corporation/MaxScale/blob/2.1/Documentation/Reference/MaxAdmin.mda
[MaxAdminService]
enable_root_user=1
version_string=10.4.99-MariaDB-maxScale
type=service
router=cli
# Listener definitions for the services
#
# These listeners represent the ports the
# services will listen on.
#
[WriteListener]
type=listener
service=WriteService
protocol=MariaDBClient
port=4007
#socket=/var/lib/maxscale/writeconn.sock
[Read-OnlyListener]
type=listener
service=Read-OnlyService
protocol=MariaDBClient
port=4008
#socket=/var/lib/maxscale/readconn.sock
[Read-WriteListener]
type=listener
service=Read-WriteService
protocol=MariaDBClient
port=4006
#socket=/var/lib/maxscale/rwsplit.sock
[MaxAdminListener]
type=listener
service=MaxAdminService
protocol=maxscaled
socket=/tmp/maxadmin.sock

57
.travis/script.sh Normal file
View File

@ -0,0 +1,57 @@
#!/bin/bash
set -x
set -e
###################################################################################################################
# test different type of configuration
###################################################################################################################
mysql=( mysql --protocol=tcp -ubob -h127.0.0.1 --port=3305 )
if [ "$DB" = "build" ] ; then
.travis/build/build.sh
docker build -t build:latest --label build .travis/build/
fi
export ENTRYPOINT=$PROJ_PATH/.travis/entrypoint
if [ -n "$MAXSCALE_VERSION" ] ; then
###################################################################################################################
# launch Maxscale with one server
###################################################################################################################
export COMPOSE_FILE=.travis/maxscale-compose.yml
export ENTRYPOINT=$PROJ_PATH/.travis/sql
docker-compose -f ${COMPOSE_FILE} build
docker-compose -f ${COMPOSE_FILE} up -d
mysql=( mysql --protocol=tcp -ubob -h127.0.0.1 --port=4007 )
else
docker-compose -f .travis/docker-compose.yml up -d
fi
for i in {60..0}; do
if echo 'SELECT 1' | "${mysql[@]}" &> /dev/null; then
break
fi
echo 'data server still not active'
sleep 1
done
if [ -z "$MAXSCALE_VERSION" ] ; then
docker-compose -f .travis/docker-compose.yml exec -u root db bash /pam/pam.sh
sleep 1
docker-compose -f .travis/docker-compose.yml stop db
sleep 1
docker-compose -f .travis/docker-compose.yml up -d
docker-compose -f .travis/docker-compose.yml logs db
for i in {60..0}; do
if echo 'SELECT 1' | "${mysql[@]}" &> /dev/null; then
break
fi
echo 'data server still not active'
sleep 1
done
fi
python -m unittest discover -v

9
.travis/sql/dbinit.sql Normal file
View File

@ -0,0 +1,9 @@
CREATE USER 'bob'@'%';
GRANT ALL ON *.* TO 'bob'@'%' with grant option;
CREATE USER 'boby'@'%' identified by 'heyPassw0@rd';
GRANT ALL ON *.* TO 'boby'@'%' with grant option;
FLUSH PRIVILEGES;
CREATE DATABASE test2;

View File

@ -1,2 +1,23 @@
# mariadb-connector-python <p align="center">
MariaDB Connector/Python <a href="http://mariadb.com/">
<img src="https://mariadb.com/kb/static/images/logo-2018-black.png">
</a>
</p>
# MariaDB python connector
[![License (LGPL version 2.1)][licence-image]][licence-url]
[![Python 3.6][python-image]][python-url]
This package contains a MariaDB client library for connecting to MariaDB and MySQL
database servers based on PEP-249.
## Prerequisites
* Python 3.6 (or newer). Older Python 3 versions might work, but aren't tested.
* MariaDB Connector/C, minimum required version is 3.1
[licence-image]:https://img.shields.io/badge/license-GNU%20LGPL%20version%202.1-green.svg?style=flat-square
[licence-url]:http://opensource.org/licenses/LGPL-2.1
[python-image]:https://img.shields.io/badge/python-3.6-blue.svg
[python-url]:https://www.python.org/downloads/release/python-360/

File diff suppressed because it is too large Load Diff

View File

@ -1,54 +1,58 @@
#!/usr/bin/env python #!/usr/bin/env python
import sys
import os import os
import string import sys
class MariaDBConfiguration(): class MariaDBConfiguration():
lib_dirs= "" lib_dirs = ""
libs= "" libs = ""
version= "" version = ""
includes= "" includes = ""
def mariadb_config(config, option): def mariadb_config(config, option):
from os import popen from os import popen
file= popen("%s --%s" % (config, option)) file = popen("%s --%s" % (config, option))
data= file.read().strip().split() data = file.read().strip().split()
rc= file.close() rc = file.close()
if rc: if rc:
if rc/256: if rc / 256:
data= [] data = []
if rc/256 > 1: if rc / 256 > 1:
raise EnvironmentError("mariadb_config not found.\nMake sure that MariaDB Connector/C is installed (e.g. on Debian or Ubuntu 'sudo apt-get install libmariadb-dev'\nIf mariadb_config is not installed in a default path, please set the environment variable MARIADB_CONFIG which points to the location of mariadb_config utility, e.g. MARIADB_CONFIG=/opt/mariadb/bin/mariadb_config") raise EnvironmentError(
return data "mariadb_config not found.\nMake sure that MariaDB Connector/C is installed (e.g. on Debian or Ubuntu 'sudo apt-get install libmariadb-dev'\nIf mariadb_config is not installed in a default path, please set the environment variable MARIADB_CONFIG which points to the location of mariadb_config utility, e.g. MARIADB_CONFIG=/opt/mariadb/bin/mariadb_config")
return data
def dequote(s): def dequote(s):
if s[0] in "\"'" and s[0] == s[-1]: if s[0] in "\"'" and s[0] == s[-1]:
s = s[1:-1] s = s[1:-1]
return s return s
def get_config(): def get_config():
required_version="3.1.0" required_version = "3.1.0"
no_env= 0 no_env = 0
try: try:
config_prg= os.environ["MARIADB_CONFIG"] config_prg = os.environ["MARIADB_CONFIG"]
except KeyError: except KeyError:
config_prg= 'mariadb_config' config_prg = 'mariadb_config'
cc_version= mariadb_config(config_prg, "cc_version") cc_version = mariadb_config(config_prg, "cc_version")
if cc_version[0] < required_version: if cc_version[0] < required_version:
print ('MariaDB Connector/Python requires MariaDB Connector/C >= %s, found version %s' % (required_version, cc_version[0])) print ('MariaDB Connector/Python requires MariaDB Connector/C >= %s, found version %s' % (
sys.exit(2) required_version, cc_version[0]))
cfg= MariaDBConfiguration() sys.exit(2)
cfg.version= cc_version[0] cfg = MariaDBConfiguration()
cfg.version = cc_version[0]
libs= mariadb_config(config_prg, "libs") libs = mariadb_config(config_prg, "libs")
cfg.lib_dirs = [ dequote(i[2:]) for i in libs if i.startswith("-L") ] cfg.lib_dirs = [dequote(i[2:]) for i in libs if i.startswith("-L")]
cfg.libs = [ dequote(i[2:]) for i in libs if i.startswith("-l") ] cfg.libs = [dequote(i[2:]) for i in libs if i.startswith("-l")]
includes= mariadb_config(config_prg, "include") includes = mariadb_config(config_prg, "include")
mariadb_includes = [ dequote(i[2:]) for i in includes if i.startswith("-I") ] mariadb_includes = [dequote(i[2:]) for i in includes if i.startswith("-I")]
mariadb_includes.extend(["./include"]) mariadb_includes.extend(["./include"])
cfg.includes= mariadb_includes cfg.includes = mariadb_includes
return cfg return cfg

View File

@ -1,40 +1,53 @@
import sys
import os import os
import string import platform
import sys
from winreg import * from winreg import *
class MariaDBConfiguration(): class MariaDBConfiguration():
lib_dirs= "" lib_dirs = ""
libs= "" libs = ""
version= "" version = ""
includes= "" includes = ""
def get_config(): def get_config():
required_version="3.1.0" required_version = "3.1.0"
try: try:
config_prg= os.environ["MARIADB_CC_DIR"] config_prg = os.environ["MARIADB_CC_INSTALL_DIR"]
cc_version= ["",""] cc_version = ["", ""]
cc_instdir= [config_prg, ""] cc_instdir = [config_prg, ""]
print("using environment configuration " + config_prg) print("using environment configuration " + config_prg)
except KeyError: except KeyError:
Registry= ConnectRegistry(None, HKEY_LOCAL_MACHINE)
Key= OpenKey(Registry, "SOFTWARE\MariaDB Corporation\MariaDB Connector C 64-bit")
if Key:
cc_version= QueryValueEx(Key, "Version")
if cc_version[0] < required_version:
print("MariaDB Connector/Python requires MariaDB Connector/C >= %s (found version: %s") % (required_version, cc_version[0])
sys.exit(2) try:
cc_instdir= QueryValueEx(Key, "InstallDir") local_reg = ConnectRegistry(None, HKEY_LOCAL_MACHINE)
if cc_instdir is None: if platform.architecture()[0] == '32bit':
print("Could not find InstallationDir of MariaDB Connector/C. Please make sure MariaDB Connector/C is installed or specify the InstallationDir of MariaDB Connector/C by setting the environment variable MARIADB_CC_INSTALL_DIR.") connector_key = OpenKey(local_reg,
sys.exit(3) 'SOFTWARE\\MariaDB Corporation\\MariaDB Connector C')
else:
connector_key = OpenKey(local_reg,
'SOFTWARE\\MariaDB Corporation\\MariaDB Connector C 64-bit',
access=KEY_READ | KEY_WOW64_64KEY)
cc_version = QueryValueEx(connector_key, "Version")
if cc_version[0] < required_version:
print(
"MariaDB Connector/Python requires MariaDB Connector/C >= %s (found version: %s") \
% (required_version, cc_version[0])
sys.exit(2)
cc_instdir = QueryValueEx(connector_key, "InstallDir")
except:
cfg= MariaDBConfiguration() print("Could not find InstallationDir of MariaDB Connector/C. "
cfg.version= cc_version[0] "Please make sure MariaDB Connector/C is installed or specify the InstallationDir of "
cfg.includes= [".\\include", cc_instdir[0] + "\\include", cc_instdir[0] + "\\include\\mysql"] "MariaDB Connector/C by setting the environment variable MARIADB_CC_INSTALL_DIR.")
cfg.lib_dirs= [cc_instdir[0] + "\\lib"] sys.exit(3)
cfg.libs= ["mariadbclient", "ws2_32", "advapi32", "kernel32", "shlwapi", "crypt32"]
return cfg cfg = MariaDBConfiguration()
cfg.version = cc_version[0]
cfg.includes = [".\\include", cc_instdir[0] + "\\include", cc_instdir[0] + "\\include\\mysql"]
cfg.lib_dirs = [cc_instdir[0] + "\\lib"]
cfg.libs = ["mariadbclient", "ws2_32", "advapi32", "kernel32", "shlwapi", "crypt32"]
return cfg

View File

@ -1,28 +1,49 @@
#!/usr/bin/env python #!/usr/bin/env python
import os import os
import sys
import subprocess from distutils.core import setup, Extension
import string
if os.name == "posix":
from distutils.core import setup, Extension from mariadb_posix import get_config
if os.name == "nt":
if os.name == "posix": from mariadb_windows import get_config
from mariadb_posix import get_config
if os.name == "nt": cfg = get_config()
from mariadb_windows import get_config
setup(name='mariadb',
cfg= get_config() version='0.9.1',
classifiers = [
setup(name='mariadb', 'Development Status :: 3 - Alpha',
version='0.9.1', 'Environment :: Console',
description='Python MariaDB extension', 'Environment :: MacOS X',
author='Georg Richter', 'Environment :: Win32 (MS Windows)',
license='LGPL 2.1', 'Environment :: Posix',
url='http://www.mariadb.com', 'License :: OSI Approved :: GNU Lesser General Public License v2 or later (LGPLv2+)',
ext_modules=[Extension('mariadb', ['src/mariadb.c', 'src/mariadb_connection.c', 'src/mariadb_exception.c', 'src/mariadb_cursor.c', 'src/mariadb_codecs.c', 'src/mariadb_field.c', 'src/mariadb_dbapitype.c', 'src/mariadb_indicator.c'], 'Programming Language :: C',
include_dirs=cfg.includes, 'Programming Language :: Python',
library_dirs= cfg.lib_dirs, 'Programming Language :: Python :: 3.6',
libraries= cfg.libs 'Programming Language :: Python :: 3.7',
)], 'Programming Language :: Python :: 3.8',
) 'Operating System :: Microsoft :: Windows',
'Operating System :: MacOS',
'Operating System :: POSIX',
'Intended Audience :: End Users/Desktop',
'Intended Audience :: Developers',
'Intended Audience :: System Administrators',
'Topic :: Database'
],
description='Python MariaDB extension',
author='Georg Richter',
license='LGPL 2.1',
url='https://www.github.com/MariaDB/mariadb-connector-python',
ext_modules=[Extension('mariadb', ['src/mariadb.c', 'src/mariadb_connection.c',
'src/mariadb_exception.c', 'src/mariadb_cursor.c',
'src/mariadb_codecs.c', 'src/mariadb_field.c',
'src/mariadb_parser.c',
'src/mariadb_dbapitype.c', 'src/mariadb_indicator.c'],
include_dirs=cfg.includes,
library_dirs=cfg.lib_dirs,
libraries=cfg.libs
)],
)

View File

@ -75,7 +75,7 @@ static PyObject *mariadb_get_pickled(unsigned char *data, size_t length)
PyObject *obj= NULL; PyObject *obj= NULL;
if (length < 3) if (length < 3)
return NULL; return NULL;
if (*data == 0x80 && *(data +1) == 0x03 && *(data + length - 1) == 0x2E) if (*data == 0x80 && *(data +1) <= 0x04 && *(data + length - 1) == 0x2E)
{ {
PyObject *byte= PyBytes_FromStringAndSize((char *)data, length); PyObject *byte= PyBytes_FromStringAndSize((char *)data, length);
obj= PyObject_CallMethod(Mrdb_Pickle, "loads", "O", byte); obj= PyObject_CallMethod(Mrdb_Pickle, "loads", "O", byte);
@ -169,7 +169,7 @@ void field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column)
unsigned long utf8len; unsigned long utf8len;
self->values[column]= PyUnicode_FromStringAndSize((const char *)data, (Py_ssize_t)length[column]); self->values[column]= PyUnicode_FromStringAndSize((const char *)data, (Py_ssize_t)length[column]);
utf8len= PyUnicode_GET_LENGTH(self->values[column]); utf8len= (unsigned long)PyUnicode_GET_LENGTH(self->values[column]);
if (utf8len > self->fields[column].max_length) if (utf8len > self->fields[column].max_length)
self->fields[column].max_length= utf8len; self->fields[column].max_length= utf8len;
break; break;
@ -237,7 +237,7 @@ void field_fetch_callback(void *data, unsigned int column, unsigned char **row)
long long l= sint8korr(*row); long long l= sint8korr(*row);
self->values[column]= (self->fields[column].flags & UNSIGNED_FLAG) ? self->values[column]= (self->fields[column].flags & UNSIGNED_FLAG) ?
PyLong_FromUnsignedLongLong((unsigned long long)l) : PyLong_FromUnsignedLongLong((unsigned long long)l) :
PyLong_FromLong(l); PyLong_FromLongLong(l);
*row+= 8; *row+= 8;
break; break;
} }
@ -368,7 +368,7 @@ void field_fetch_callback(void *data, unsigned int column, unsigned char **row)
length= mysql_net_field_length(row); length= mysql_net_field_length(row);
self->values[column]= PyUnicode_FromStringAndSize((const char *)*row, (Py_ssize_t)length); self->values[column]= PyUnicode_FromStringAndSize((const char *)*row, (Py_ssize_t)length);
utf8len= PyUnicode_GET_LENGTH(self->values[column]); utf8len= (unsigned long)PyUnicode_GET_LENGTH(self->values[column]);
if (utf8len > self->fields[column].max_length) if (utf8len > self->fields[column].max_length)
self->fields[column].max_length= utf8len; self->fields[column].max_length= utf8len;
*row+= length; *row+= length;
@ -509,7 +509,7 @@ static uint8_t mariadb_get_parameter(MrdbCursor *self,
"MariaDB %s doesn't support indicator variables. Required version is 10.2.6 or newer", mysql_get_server_info(self->stmt->mysql)); "MariaDB %s doesn't support indicator variables. Required version is 10.2.6 or newer", mysql_get_server_info(self->stmt->mysql));
return 1; return 1;
} }
param->indicator= MrdbIndicator_AsLong(column); param->indicator= (char)MrdbIndicator_AsLong(column);
param->value= NULL; /* you can't have both indicator and value */ param->value= NULL; /* you can't have both indicator and value */
} else if (column == Py_None) } else if (column == Py_None)
{ {
@ -559,7 +559,7 @@ static uint8_t mariadb_get_parameter_info(MrdbCursor *self,
return 1; return 1;
} }
param->buffer_type= pinfo.type; param->buffer_type= pinfo.type;
bits= pinfo.bits; bits= (uint32_t)pinfo.bits;
} }
for (i=0; i < self->array_size; i++) for (i=0; i < self->array_size; i++)
@ -580,7 +580,7 @@ static uint8_t mariadb_get_parameter_info(MrdbCursor *self,
if (pinfo.type == MYSQL_TYPE_LONGLONG) if (pinfo.type == MYSQL_TYPE_LONGLONG)
{ {
if (pinfo.bits > bits) if (pinfo.bits > bits)
bits= pinfo.bits; bits= (uint32_t)pinfo.bits;
} }
@ -631,7 +631,7 @@ uint8_t mariadb_check_bulk_parameters(MrdbCursor *self,
{ {
uint32_t i; uint32_t i;
if (!(self->array_size= PyList_Size(data))) if (!(self->array_size= (uint32_t)PyList_Size(data)))
{ {
mariadb_throw_exception(self->stmt, Mariadb_InterfaceError, 1, mariadb_throw_exception(self->stmt, Mariadb_InterfaceError, 1,
"Empty parameter list. At least one row must be specified"); "Empty parameter list. At least one row must be specified");
@ -649,7 +649,7 @@ uint8_t mariadb_check_bulk_parameters(MrdbCursor *self,
} }
if (!self->param_count && !self->is_prepared) if (!self->param_count && !self->is_prepared)
self->param_count= PyTuple_Size(obj); self->param_count= (uint32_t)PyTuple_Size(obj);
if (!self->param_count || if (!self->param_count ||
self->param_count != PyTuple_Size(obj)) self->param_count != PyTuple_Size(obj))
{ {
@ -695,7 +695,7 @@ uint8_t mariadb_check_execute_parameters(MrdbCursor *self,
{ {
uint32_t i; uint32_t i;
if (!self->is_prepared) if (!self->is_prepared)
self->param_count= PyTuple_Size(data); self->param_count= (uint32_t)PyTuple_Size(data);
if (!self->param_count) if (!self->param_count)
{ {

View File

@ -265,7 +265,7 @@ MrdbConnection_Initialize(MrdbConnection *self,
}; };
if (!PyArg_ParseTupleAndKeywords(args, dsnargs, if (!PyArg_ParseTupleAndKeywords(args, dsnargs,
"|sssssisiiipissssssssssipi:connect", "|sssssisiiipissssssssssipis:connect",
dsn_keys, dsn_keys,
&dsn, &host, &user, &password, &schema, &port, &socket, &dsn, &host, &user, &password, &schema, &port, &socket,
&connect_timeout, &read_timeout, &write_timeout, &connect_timeout, &read_timeout, &write_timeout,
@ -798,7 +798,7 @@ end:
/* }}} */ /* }}} */
/* {{{ MrdbConnection_ping */ /* {{{ MrdbConnection_ping */
PyObject *MrdbConnection_ping(MrdbConnection *self, PyObject *args) PyObject *MrdbConnection_ping(MrdbConnection *self)
{ {
int rc; int rc;

File diff suppressed because it is too large Load Diff

211
src/mariadb_parser.c Executable file
View File

@ -0,0 +1,211 @@
/************************************************************************************
Copyright (C) 2019 Georg Richter and MariaDB Corporation AB
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Library General Public
License as published by the Free Software Foundation; either
version 2 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Library General Public License for more details.
You should have received a copy of the GNU Library General Public
License along with this library; if not see <http://www.gnu.org/licenses>
or write to the Free Software Foundation, Inc.,
51 Franklin St., Fifth Floor, Boston, MA 02110, USA
*************************************************************************************/
#include <mariadb_python.h>
#define IS_WHITESPACE(a) (a==32 || a==9 || a==10 || a==13)
#define IN_LITERAL(p) ((p)->in_literal[0] || (p)->in_literal[1] || (p)->in_literal[2])
const char *comment_start= "/*";
const char *comment_end= "*/";
const char literals[3]= {'\'', '\"', '`'};
static uint8_t check_keyword(char* ofs, char* end, char* keyword, size_t keylen)
{
int i;
if (end - ofs < keylen + 1)
return 0;
for (i = 0; i < keylen; i++)
if (toupper(*(ofs + i)) != keyword[i])
return 0;
if (!IS_WHITESPACE(*(ofs + keylen)))
return 0;
return 1;
}
void Mrdb_Parser_end(Mrdb_Parser* p)
{
if (p)
{
MARIADB_FREE_MEM(p->statement.str);
MARIADB_FREE_MEM(p);
}
}
Mrdb_Parser *Mrdb_Parser_init(const char *statement, size_t length)
{
Mrdb_Parser *p;
if (!statement || !length)
return NULL;
if ((p= PyMem_RawCalloc(1, sizeof(Mrdb_Parser))))
{
if (!(p->statement.str = PyMem_RawCalloc(1, length + 1)))
{
MARIADB_FREE_MEM(p);
return NULL;
}
memcpy(p->statement.str, statement, length);
p->statement.length= length;
}
return p;
}
void Mrdb_Parser_parse(Mrdb_Parser *p, uint8_t is_batch)
{
char *a, *end;
char lastchar= 0;
uint8_t i;
if (!p || !p->statement.str)
return;
a= p->statement.str;
end= a + p->statement.length - 1;
while (a <= end)
{
/* check literals */
for (i=0; i < 3; i++)
{
if (*a == literals[i])
{
p->in_literal[i]= !(p->in_literal[i]);
a++;
continue;
}
}
/* nothing to do, if we are inside a comment or literal */
if (IN_LITERAL(p))
{
a++;
continue;
}
/* check comment */
if (!p->in_comment)
{
/* Style 1 */
if (*a == '/' && *(a + 1) == '*')
{
a+= 2;
p->in_comment= 1;
continue;
}
/* Style 2 */
if (*a == '#')
{
a++;
p->comment_eol= 1;
}
/* Style 3 */
if (*a == '-' && *(a+1) == '-')
{
if (((a+2) < end) && *(a+2) == ' ')
{
a+= 3;
p->comment_eol= 1;
}
}
} else
{
if (*a == '*' && *(a + 1) == '/')
{
a+= 2;
p->in_comment= 0;
continue;
} else {
a++;
continue;
}
}
if (p->comment_eol) {
if (*a == '\0' || *a == '\n')
{
a++;
p->comment_eol= 0;
continue;
}
a++;
continue;
}
/* checking for different paramstyles */
/* parmastyle = qmark */
if (*a == '?')
{
p->param_count++;
a++;
continue;
}
/* paramstype = pyformat */
if (*a == '%' && lastchar != '\\')
{
if (*(a+1) == 's' || *(a+1) == 'd')
{
*a= '?';
memmove(a+1, a+2, end - a);
end--;
a++;
p->param_count++;
continue;
}
if (*(a+1) == '(')
{
char *val_end= strstr(a+1, ")s");
if (val_end)
{
int keylen= val_end - a + 1;
*a= '?';
p->param_count++;
memmove(a+1, val_end+2, end - a - keylen);
end -= keylen;
continue;
}
}
}
if (is_batch)
{
/* Do we have an insert statement ? */
if (!p->is_insert && check_keyword(a, end, "INSERT", 6))
{
if (lastchar == 0 || (IS_WHITESPACE(lastchar)) || lastchar == '/')
{
p->is_insert = 1;
a += 7;
}
}
if (p->is_insert && check_keyword(a, end, "VALUES", 6))
{
p->value_ofs = a + 7;
a += 7;
continue;
}
}
lastchar= *a;
a++;
}
/* Update length */
p->statement.length= end - p->statement.str + 1;
}

0
test/__init__.py Normal file
View File

16
test/base_test.py Normal file
View File

@ -0,0 +1,16 @@
#!/usr/bin/env python -O
# -*- coding: utf-8 -*-
import mariadb
from .conf_test import conf
def create_connection(additional_conf=None):
default_conf = conf()
if additional_conf is None:
c = {key: value for (key, value) in (default_conf.items())}
else:
c = {key: value for (key, value) in (list(default_conf.items()) + list(
additional_conf.items()))}
return mariadb.connect(**c)

16
test/conf_test.py Normal file
View File

@ -0,0 +1,16 @@
#!/usr/bin/env python -O
# -*- coding: utf-8 -*-
import os
def conf():
d = {
"user": os.environ.get('TEST_USER', 'root'),
"host": os.environ.get('TEST_HOST', 'localhost'),
"database": os.environ.get('TEST_DATABASE', 'testp'),
"port": int(os.environ.get('TEST_PORT', '3306'))
}
if os.environ.get('TEST_PASSWORD'):
d["password"] = os.environ.get('TEST_PASSWORD')
return d

View File

@ -1,498 +0,0 @@
#!/usr/bin/env python -O
import mariadb
import datetime
import unittest
import collections
class CursorTest(unittest.TestCase):
def setUp(self):
self.connection= mariadb.connection(default_file='default.cnf')
self.connection.autocommit= False
def tearDown(self):
self.connection.rollback()
del self.connection
def test_date(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1(c1 TIMESTAMP(6), c2 TIME(6), c3 DATETIME(6), c4 DATE)")
t= datetime.datetime(2018,6,20,12,22,31,123456)
c1= t
c2= t.time()
c3= t
c4= t.date()
cursor.execute("INSERT INTO t1 VALUES (?,?,?,?)", (c1, c2, c3, c4))
cursor.execute("SELECT c1,c2,c3,c4 FROM t1")
row= cursor.fetchone()
self.assertEqual(row[0],c1)
self.assertEqual(row[1],c2)
self.assertEqual(row[2],c3)
self.assertEqual(row[3],c4)
cursor.close()
def test_numbers(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1 (a tinyint unsigned, b smallint unsigned, c mediumint unsigned, d int unsigned, e bigint unsigned, f double)")
c1= 4
c2= 200
c3= 167557
c4= 28688817
c5= 7330133222578
c6= 3.1415925
cursor.execute("insert into t1 values (?,?,?,?,?,?)", (c1,c2,c3,c4,c5,c6))
cursor.execute("select * from t1")
row= cursor.fetchone()
self.assertEqual(row[0],c1)
self.assertEqual(row[1],c2)
self.assertEqual(row[2],c3)
self.assertEqual(row[3],c4)
self.assertEqual(row[4],c5)
self.assertEqual(row[5],c6)
del cursor
def test_string(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1 (a char(5), b varchar(100), c tinytext, d mediumtext, e text, f longtext)");
c1= "12345";
c2= "The length of this text is < 100 characters"
c3= "This should also fit into tinytext which has a maximum of 255 characters"
c4= 'a' * 1000;
c5= 'b' * 6000;
c6= 'c' * 67000;
cursor.execute("INSERT INTO t1 VALUES (?,?,?,?,?,?)", (c1,c2,c3,c4,c5,c6))
cursor.execute("SELECT * from t1")
row= cursor.fetchone()
self.assertEqual(row[0],c1)
self.assertEqual(row[1],c2)
self.assertEqual(row[2],c3)
self.assertEqual(row[3],c4)
self.assertEqual(row[4],c5)
self.assertEqual(row[5],c6)
del cursor
def test_blob(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1 (a tinyblob, b mediumblob, c blob, d longblob)")
c1= b'a' * 100;
c2= b'b' * 1000;
c3= b'c' * 10000;
c4= b'd' * 100000;
a= (None, None, None, None)
cursor.execute("INSERT INTO t1 VALUES (?,?,?,?)", (c1, c2, c3, c4))
cursor.execute("SELECT * FROM t1")
row= cursor.fetchone()
self.assertEqual(row[0],c1)
self.assertEqual(row[1],c2)
self.assertEqual(row[2],c3)
self.assertEqual(row[3],c4)
del cursor
def test_fetchmany(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t01 (id int, name varchar(64), city varchar(64))");
params= [(1, u"Jack", u"Boston"),
(2, u"Martin", u"Ohio"),
(3, u"James", u"Washington"),
(4, u"Rasmus", u"Helsinki"),
(5, u"Andrey", u"Sofia")]
cursor.executemany("INSERT INTO t01 VALUES (?,?,?)", params);
#test Errors
# a) if no select was executed
self.assertRaises(mariadb.Error, cursor.fetchall)
#b ) if cursor was not executed
del cursor
cursor= self.connection.cursor()
self.assertRaises(mariadb.Error, cursor.fetchall)
cursor.execute("SELECT id, name, city FROM t01 ORDER BY id")
self.assertEqual(0, cursor.rowcount)
row = cursor.fetchall()
self.assertEqual(row, params)
self.assertEqual(5, cursor.rowcount)
cursor.execute("SELECT id, name, city FROM t01 ORDER BY id")
self.assertEqual(0, cursor.rowcount)
row= cursor.fetchmany(1)
self.assertEqual(row,[params[0]])
self.assertEqual(1, cursor.rowcount)
row= cursor.fetchmany(2)
self.assertEqual(row,([params[1], params[2]]))
self.assertEqual(3, cursor.rowcount)
cursor.arraysize= 1
row= cursor.fetchmany()
self.assertEqual(row,[params[3]])
self.assertEqual(4, cursor.rowcount)
cursor.arraysize= 2
row= cursor.fetchmany()
self.assertEqual(row,[params[4]])
self.assertEqual(5, cursor.rowcount)
del cursor
def test1_multi_result(self):
cursor= self.connection.cursor()
sql= """
CREATE OR REPLACE PROCEDURE p1()
BEGIN
SELECT 1 FROM DUAL;
SELECT 2 FROM DUAL;
END
"""
cursor.execute(sql)
cursor.execute("call p1()")
row= cursor.fetchone()
self.assertEqual(row[0], 1)
cursor.nextset()
row= cursor.fetchone()
self.assertEqual(row[0], 2)
del cursor
def test_buffered(self):
cursor= self.connection.cursor()
cursor.execute("SELECT 1 UNION SELECT 2 UNION SELECT 3", buffered=True)
self.assertEqual(cursor.rowcount, 3)
cursor.scroll(1)
row= cursor.fetchone()
self.assertEqual(row[0],2)
del cursor
def test_xfield_types(self):
cursor= self.connection.cursor()
fieldinfo= mariadb.fieldinfo()
cursor.execute("CREATE OR REPLACE TABLE t1 (a tinyint not null auto_increment primary key, b smallint, c int, d bigint, e float, f decimal, g double, h char(10), i varchar(255), j blob, index(b))");
info= cursor.description
self.assertEqual(info, None)
cursor.execute("SELECT * FROM t1")
info= cursor.description
self.assertEqual(fieldinfo.type(info[0]), "TINY")
self.assertEqual(fieldinfo.type(info[1]), "SHORT")
self.assertEqual(fieldinfo.type(info[2]), "LONG")
self.assertEqual(fieldinfo.type(info[3]), "LONGLONG")
self.assertEqual(fieldinfo.type(info[4]), "FLOAT")
self.assertEqual(fieldinfo.type(info[5]), "NEWDECIMAL")
self.assertEqual(fieldinfo.type(info[6]), "DOUBLE")
self.assertEqual(fieldinfo.type(info[7]), "STRING")
self.assertEqual(fieldinfo.type(info[8]), "VAR_STRING")
self.assertEqual(fieldinfo.type(info[9]), "BLOB")
self.assertEqual(fieldinfo.flag(info[0]), "NOT_NULL | PRIMARY_KEY | AUTO_INCREMENT | NUMERIC")
self.assertEqual(fieldinfo.flag(info[1]), "PART_KEY | NUMERIC")
self.assertEqual(fieldinfo.flag(info[9]), "BLOB | BINARY")
del cursor
def test_bulk_delete(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE bulk_delete (id int, name varchar(64), city varchar(64))");
params= [(1, u"Jack", u"Boston"),
(2, u"Martin", u"Ohio"),
(3, u"James", u"Washington"),
(4, u"Rasmus", u"Helsinki"),
(5, u"Andrey", u"Sofia")]
cursor.executemany("INSERT INTO bulk_delete VALUES (?,?,?)", params)
self.assertEqual(cursor.rowcount, 5)
params= [(1,2)]
cursor.executemany("DELETE FROM bulk_delete WHERE id=?", params)
self.assertEqual(cursor.rowcount, 2)
def test_named_tuple(self):
cursor= self.connection.cursor(named_tuple=1)
cursor.execute("CREATE OR REPLACE TABLE t1 (id int, name varchar(64), city varchar(64))");
params= [(1, u"Jack", u"Boston"),
(2, u"Martin", u"Ohio"),
(3, u"James", u"Washington"),
(4, u"Rasmus", u"Helsinki"),
(5, u"Andrey", u"Sofia")]
cursor.executemany("INSERT INTO t1 VALUES (?,?,?)", params);
cursor.execute("SELECT * FROM t1 ORDER BY id")
row= cursor.fetchone()
self.assertEqual(cursor.statement, "SELECT * FROM t1 ORDER BY id")
self.assertEqual(row.id, 1)
self.assertEqual(row.name, "Jack")
self.assertEqual(row.city, "Boston")
del cursor
def test_laststatement(self):
cursor= self.connection.cursor(named_tuple=1)
cursor.execute("CREATE OR REPLACE TABLE t1 (id int, name varchar(64), city varchar(64))");
self.assertEqual(cursor.statement, "CREATE OR REPLACE TABLE t1 (id int, name varchar(64), city varchar(64))")
params= [(1, u"Jack", u"Boston"),
(2, u"Martin", u"Ohio"),
(3, u"James", u"Washington"),
(4, u"Rasmus", u"Helsinki"),
(5, u"Andrey", u"Sofia")]
cursor.executemany("INSERT INTO t1 VALUES (?,?,?)", params);
cursor.execute("SELECT * FROM t1 ORDER BY id")
self.assertEqual(cursor.statement, "SELECT * FROM t1 ORDER BY id")
del cursor
def test_multi_cursor(self):
cursor= self.connection.cursor()
cursor1= self.connection.cursor(cursor_type=1)
cursor2= self.connection.cursor(cursor_type=1)
cursor.execute("CREATE OR REPLACE TABLE t1 (a int)")
cursor.execute("INSERT INTO t1 VALUES (1),(2),(3),(4),(5),(6),(7),(8)")
del cursor
cursor1.execute("SELECT a FROM t1 ORDER BY a")
cursor2.execute("SELECT a FROM t1 ORDER BY a DESC")
for i in range (0,8):
self.assertEqual(cursor1.rownumber, i)
row1= cursor1.fetchone()
row2= cursor2.fetchone()
self.assertEqual(cursor1.rownumber, cursor2.rownumber)
self.assertEqual(row1[0]+row2[0], 9)
del cursor1
del cursor2
def test_connection_attr(self):
cursor= self.connection.cursor()
self.assertEqual(cursor.connection, self.connection)
del cursor
def test_dbapi_type(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1 (a int, b varchar(20), c blob, d datetime, e decimal)")
cursor.execute("INSERT INTO t1 VALUES (1, 'foo', 'blabla', now(), 10.2)");
cursor.execute("SELECT * FROM t1 ORDER BY a")
expected_typecodes= [
mariadb.NUMBER,
mariadb.STRING,
mariadb.BINARY,
mariadb.DATETIME,
mariadb.NUMBER
]
row= cursor.fetchone()
typecodes= [row[1] for row in cursor.description]
self.assertEqual(expected_typecodes, typecodes)
del cursor
def test_tuple(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE dyncol1 (a blob)");
tpl=(1,2,3)
cursor.execute("INSERT INTO dyncol1 VALUES (?)", tpl);
del cursor
def test_indicator(self):
if self.connection.server_version < 100206:
self.skipTest("Requires server version >= 10.2.6")
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE ind1 (a int, b int default 2,c int)");
vals= (mariadb.indicator_null, mariadb.indicator_default, 3)
cursor.executemany("INSERT INTO ind1 VALUES (?,?,?)", [vals])
cursor.execute("SELECT a, b, c FROM ind1")
row= cursor.fetchone()
self.assertEqual(row[0], None)
self.assertEqual(row[1], 2)
self.assertEqual(row[2], 3)
def test_tuple(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE dyncol1 (a blob)");
t= datetime.datetime(2018,6,20,12,22,31,123456)
val=([1,t,3,(1,2,3)],)
cursor.execute("INSERT INTO dyncol1 VALUES (?)", val);
cursor.execute("SELECT a FROM dyncol1")
row= cursor.fetchone()
self.assertEqual(row,val);
del cursor
def test_set(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE dyncol1 (a blob)");
t= datetime.datetime(2018,6,20,12,22,31,123456)
a= collections.OrderedDict([('apple', 4), ('banana', 3), ('orange', 2), ('pear', 1), ('4',3), (4,4)])
val=([1,t,3,(1,2,3), {1,2,3},a],)
cursor.execute("INSERT INTO dyncol1 VALUES (?)", val);
cursor.execute("SELECT a FROM dyncol1")
row= cursor.fetchone()
self.assertEqual(row,val);
del cursor
def test_reset(self):
cursor= self.connection.cursor()
cursor.execute("SELECT 1 UNION SELECT 2", buffered=False)
cursor.execute("SELECT 1 UNION SELECT 2")
del cursor
def test_fake_pickle(self):
cursor= self.connection.cursor()
cursor.execute("create or replace table t1 (a blob)")
k=bytes([0x80,0x03,0x00,0x2E])
cursor.execute("insert into t1 values (?)", (k,))
cursor.execute("select * from t1");
row= cursor.fetchone()
self.assertEqual(row[0],k)
del cursor
def test_no_result(self):
cursor= self.connection.cursor()
cursor.execute("set @a:=1")
try:
row= cursor.fetchone()
except mariadb.ProgrammingError:
pass
del cursor
def test_collate(self):
cursor= self.connection.cursor()
cursor.execute("CREATE TABLE IF NOT EXISTS `tt` (`test` varchar(500) COLLATE utf8mb4_unicode_ci NOT NULL) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci")
cursor.execute("SET NAMES utf8mb4")
cursor.execute("SELECT * FROM `tt` WHERE `test` LIKE 'jj' COLLATE utf8mb4_unicode_ci")
del cursor
def test_conpy_8(self):
cursor=self.connection.cursor()
sql= """
CREATE OR REPLACE PROCEDURE p1()
BEGIN
SELECT 1 FROM DUAL UNION SELECT 0 FROM DUAL;
SELECT 2 FROM DUAL;
END
"""
cursor.execute(sql)
cursor.execute("call p1()")
cursor.nextset()
row= cursor.fetchone()
self.assertEqual(row[0],2);
del cursor
def test_conpy_7(self):
cursor=self.connection.cursor()
stmt= "SELECT 1 UNION SELECT 2 UNION SELECT 3 UNION SELECT 4"
cursor.execute(stmt, buffered=True)
cursor.scroll(2, mode='relative')
row= cursor.fetchone()
self.assertEqual(row[0],3)
cursor.scroll(-2, mode='relative')
row= cursor.fetchone()
del cursor
def test_compy_9(self):
cursor=self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1 (a varchar(20), b double(6,3), c double)");
cursor.execute("INSERT INTO t1 VALUES ('€uro', 123.345, 12345.678)")
cursor.execute("SELECT a,b,c FROM t1")
cursor.fetchone()
d= cursor.description;
self.assertEqual(d[0][2], 4); # 4 code points only
self.assertEqual(d[0][3], -1); # variable length
self.assertEqual(d[1][2], 7); # length=precision + 1
self.assertEqual(d[1][4], 6); # precision
self.assertEqual(d[1][5], 3); # decimals
del cursor
def test_conpy_15(self):
cursor=self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1 (a int not null auto_increment primary key, b varchar(20))");
self.assertEqual(cursor.lastrowid, 0)
cursor.execute("INSERT INTO t1 VALUES (null, 'foo')")
self.assertEqual(cursor.lastrowid, 1)
cursor.execute("SELECT LAST_INSERT_ID()")
row= cursor.fetchone()
self.assertEqual(row[0], 1)
vals= [(3, "bar"), (4, "this")]
cursor.executemany("INSERT INTO t1 VALUES (?,?)", vals)
self.assertEqual(cursor.lastrowid, 4)
# Bug MDEV-16847
# cursor.execute("SELECT LAST_INSERT_ID()")
# row= cursor.fetchone()
# self.assertEqual(row[0], 4)
# Bug MDEV-16593
# vals= [(None, "bar"), (None, "foo")]
# cursor.executemany("INSERT INTO t1 VALUES (?,?)", vals)
# self.assertEqual(cursor.lastrowid, 6)
del cursor
def test_conpy_14(self):
cursor=self.connection.cursor()
self.assertEqual(cursor.rowcount, -1)
cursor.execute("CREATE OR REPLACE TABLE t1 (a int not null auto_increment primary key, b varchar(20))");
self.assertEqual(cursor.rowcount, -1)
cursor.execute("INSERT INTO t1 VALUES (null, 'foo')")
self.assertEqual(cursor.rowcount, 1)
vals= [(3, "bar"), (4, "this")]
cursor.executemany("INSERT INTO t1 VALUES (?,?)", vals)
self.assertEqual(cursor.rowcount, 2)
del cursor
def test_closed(self):
cursor= self.connection.cursor()
cursor.close()
cursor.close()
self.assertEqual(cursor.closed, True)
try:
cursor.execute("set @a:=1")
except mariadb.ProgrammingError:
pass
del cursor
def test_emptycursor(self):
cursor= self.connection.cursor()
try:
cursor.execute("")
except mariadb.DatabaseError:
pass
del cursor
def test_iterator(self):
cursor= self.connection.cursor()
cursor.execute("select 1 union select 2 union select 3 union select 4 union select 5")
for i, row in enumerate(cursor):
self.assertEqual(i+1, cursor.rownumber)
self.assertEqual(i+1, row[0])
def test_update_bulk(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1 (a int primary key, b int)")
vals= [(i,) for i in range(1000)]
cursor.executemany("INSERT INTO t1 VALUES (?, NULL)", vals);
self.assertEqual(cursor.rowcount, 1000)
self.connection.autocommit= False
cursor.executemany("UPDATE t1 SET b=2 WHERE a=?", vals);
self.connection.commit()
self.assertEqual(cursor.rowcount, 1000)
self.connection.autocommit= True
del cursor
def test_multi_execute(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1 (a int auto_increment primary key, b int)")
self.connection.autocommit= False
for i in range(1,1000):
cursor.execute("INSERT INTO t1 VALUES (?,1)", (i,))
self.connection.autocommit= True
del cursor
def test_conpy21(self):
conn= mariadb.connection(default_file='default.cnf')
cursor=conn.cursor()
self.assertFalse(cursor.closed)
conn.close()
self.assertTrue(cursor.closed)
del cursor, conn

View File

@ -1,80 +0,0 @@
#!/usr/bin/env python -O
import mariadb
import datetime
import unittest
class CursorTest(unittest.TestCase):
def setUp(self):
self.connection= mariadb.connection(default_file='default.cnf')
def tearDown(self):
self.connection.close()
del self.connection
def test_insert_parameter(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1(a int not null auto_increment primary key, b int, c int, d varchar(20),e date)")
# cursor.execute("set @@autocommit=0");
list_in= []
for i in range(1, 300001):
row= (i,i,i,"bar", datetime.date(2019,1,1))
list_in.append(row)
cursor.executemany("INSERT INTO t1 VALUES (?,?,?,?,?)", list_in)
self.assertEqual(len(list_in), cursor.rowcount)
self.connection.commit()
cursor.execute("SELECT * FROM t1 order by a")
list_out= cursor.fetchall()
self.assertEqual(len(list_in), cursor.rowcount);
self.assertEqual(list_in,list_out)
cursor.close()
def test_update_parameter(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1(a int not null auto_increment primary key, b int, c int, d varchar(20),e date)")
cursor.execute("set @@autocommit=0");
list_in= []
for i in range(1, 300001):
row= (i,i,i,"bar", datetime.date(2019,1,1))
list_in.append(row)
cursor.executemany("INSERT INTO t1 VALUES (?,?,?,?,?)", list_in)
self.assertEqual(len(list_in), cursor.rowcount)
self.connection.commit()
cursor.close()
list_update= [];
cursor= self.connection.cursor()
cursor.execute("set @@autocommit=0");
for i in range(1, 300001):
row= (i+1, i);
list_update.append(row);
cursor.executemany("UPDATE t1 SET b=? WHERE a=?", list_update);
self.assertEqual(cursor.rowcount, 300000)
self.connection.commit();
cursor.close()
def test_delete_parameter(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1(a int not null auto_increment primary key, b int, c int, d varchar(20),e date)")
cursor.execute("set @@autocommit=0");
list_in= []
for i in range(1, 300001):
row= (i,i,i,"bar", datetime.date(2019,1,1))
list_in.append(row)
cursor.executemany("INSERT INTO t1 VALUES (?,?,?,?,?)", list_in)
self.assertEqual(len(list_in), cursor.rowcount)
self.connection.commit()
cursor.close()
list_delete= [];
cursor= self.connection.cursor()
cursor.execute("set @@autocommit=0");
for i in range(1, 300001):
list_delete.append((i,));
cursor.executemany("DELETE FROM t1 WHERE a=?", list_delete);
self.assertEqual(cursor.rowcount, 300000)
self.connection.commit();
cursor.close()

View File

@ -1,34 +0,0 @@
#!/usr/bin/env python -O
import mysql.connector
import datetime
import unittest
class CursorTest(unittest.TestCase):
def setUp(self):
self.connection= mysql.connector.connect(user='root', database='test')
def tearDown(self):
self.connection.close()
del self.connection
def test_parameter(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1(a int auto_increment primary key not null, b int, c int, d varchar(20),e date)")
cursor.execute("SET @@autocommit=0");
c = (1,2,3, "bar", datetime.date(2018,11,11))
list_in= []
for i in range(1,300001):
row= (i,i,i,"bar", datetime.date(2019,1,1))
list_in.append(row)
cursor.executemany("INSERT INTO t1 VALUES (%s,%s,%s,%s,%s)", list_in)
print("rows inserted:", len(list_in))
self.connection.commit()
cursor.execute("SELECT * FROM t1 order by a")
list_out= cursor.fetchall()
print("rows fetched: ", len(list_out))
self.assertEqual(list_in,list_out)
cursor.close()

View File

@ -1,787 +0,0 @@
#!/usr/bin/env python
''' Python DB API 2.0 driver compliance unit test suite.
This software is Public Domain and may be used without restrictions.
"Now we have booze and barflies entering the discussion, plus rumours of
DBAs on drugs... and I won't tell you what flashes through my mind each
time I read the subject line with 'Anal Compliance' in it. All around
this is turning out to be a thoroughly unwholesome unit test."
-- Ian Bicking
'''
__rcs_id__ = '$Id$'
__version__ = '$Revision$'[11:-2]
__author__ = 'Stuart Bishop <zen@shangri-la.dropbear.id.au>'
import unittest
import time
import mariadb
# $Log$
# Revision 1.1.2.1 2006/02/25 03:44:32 adustman
# Generic DB-API unit test module
#
# Revision 1.10 2003/10/09 03:14:14 zenzen
# Add test for DB API 2.0 optional extension, where database exceptions
# are exposed as attributes on the Connection object.
#
# Revision 1.9 2003/08/13 01:16:36 zenzen
# Minor tweak from Stefan Fleiter
#
# Revision 1.8 2003/04/10 00:13:25 zenzen
# Changes, as per suggestions by M.-A. Lemburg
# - Add a table prefix, to ensure namespace collisions can always be avoided
#
# Revision 1.7 2003/02/26 23:33:37 zenzen
# Break out DDL into helper functions, as per request by David Rushby
#
# Revision 1.6 2003/02/21 03:04:33 zenzen
# Stuff from Henrik Ekelund:
# added test_None
# added test_nextset & hooks
#
# Revision 1.5 2003/02/17 22:08:43 zenzen
# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize
# defaults to 1 & generic cursor.callproc test added
#
# Revision 1.4 2003/02/15 00:16:33 zenzen
# Changes, as per suggestions and bug reports by M.-A. Lemburg,
# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
# - Class renamed
# - Now a subclass of TestCase, to avoid requiring the driver stub
# to use multiple inheritance
# - Reversed the polarity of buggy test in test_description
# - Test exception hierarchy correctly
# - self.populate is now self._populate(), so if a driver stub
# overrides self.ddl1 this change propagates
# - VARCHAR columns now have a width, which will hopefully make the
# DDL even more portible (this will be reversed if it causes more problems)
# - cursor.rowcount being checked after various execute and fetchXXX methods
# - Check for fetchall and fetchmany returning empty lists after results
# are exhausted (already checking for empty lists if select retrieved
# nothing
# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
#
class DatabaseAPI20Test(unittest.TestCase):
''' Test a database self.driver for DB API 2.0 compatibility.
This implementation tests Gadfly, but the TestCase
is structured so that other self.drivers can subclass this
test case to ensure compiliance with the DB-API. It is
expected that this TestCase may be expanded in the future
if ambiguities or edge conditions are discovered.
The 'Optional Extensions' are not yet being tested.
self.drivers should subclass this test, overriding setUp, tearDown,
self.driver, connect_args and connect_kw_args. Class specification
should be as follows:
import dbapi20
class mytest(dbapi20.DatabaseAPI20Test):
[...]
Don't 'import DatabaseAPI20Test from dbapi20', or you will
confuse the unit tester - just 'import dbapi20'.
'''
# The self.driver module. This should be the module where the 'connect'
# method is to be found
driver = mariadb
connect_args = () # List of arguments to pass to connect
connect_kw_args = {"user=root", "database=test"} # Keyword arguments for connect
table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
xddl1 = 'drop table %sbooze' % table_prefix
xddl2 = 'drop table %sbarflys' % table_prefix
lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
# Some drivers may need to override these helpers, for example adding
# a 'commit' after the execute.
def executeDDL1(self,cursor):
cursor.execute(self.ddl1)
def executeDDL2(self,cursor):
cursor.execute(self.ddl2)
def setUp(self):
''' self.drivers should override this method to perform required setup
if any is necessary, such as creating the database.
'''
pass
def tearDown(self):
''' self.drivers should override this method to perform required cleanup
if any is necessary, such as deleting the test database.
The default drops the tables that may be created.
'''
con = self._connect()
try:
cur = con.cursor()
for ddl in (self.xddl1,self.xddl2):
try:
cur.execute(ddl)
con.commit()
except self.driver.Error:
# Assume table didn't exist. Other tests will check if
# execute is busted.
pass
finally:
con.close()
def _connect(self):
try:
return self.driver.connect(default_file='default.cnf')
except AttributeError:
self.fail("No connect method found in self.driver module")
def test_connect(self):
con = self._connect()
con.close()
def test_apilevel(self):
try:
# Must exist
apilevel = self.driver.apilevel
# Must equal 2.0
self.assertEqual(apilevel,'2.0')
except AttributeError:
self.fail("Driver doesn't define apilevel")
def test_threadsafety(self):
try:
# Must exist
threadsafety = self.driver.threadsafety
# Must be a valid value
self.assertTrue(threadsafety in (0,1,2,3))
except AttributeError:
self.fail("Driver doesn't define threadsafety")
def test_paramstyle(self):
try:
# Must exist
paramstyle = self.driver.paramstyle
# Must be a valid value
self.assertTrue(paramstyle in (
'qmark','numeric','named','format','pyformat'
))
except AttributeError:
self.fail("Driver doesn't define paramstyle")
def test_Exceptions(self):
# Make sure required exceptions exist, and are in the
# defined hierarchy.
self.assertTrue(issubclass(self.driver.Warning,Exception))
self.assertTrue(issubclass(self.driver.Error,Exception))
self.assertTrue(
issubclass(self.driver.InterfaceError,self.driver.Error)
)
self.assertTrue(
issubclass(self.driver.DatabaseError,self.driver.Error)
)
self.assertTrue(
issubclass(self.driver.OperationalError,self.driver.Error)
)
self.assertTrue(
issubclass(self.driver.IntegrityError,self.driver.Error)
)
self.assertTrue(
issubclass(self.driver.InternalError,self.driver.Error)
)
self.assertTrue(
issubclass(self.driver.ProgrammingError,self.driver.Error)
)
self.assertTrue(
issubclass(self.driver.NotSupportedError,self.driver.Error)
)
def test_ExceptionsAsConnectionAttributes(self):
# OPTIONAL EXTENSION
# Test for the optional DB API 2.0 extension, where the exceptions
# are exposed as attributes on the Connection object
# I figure this optional extension will be implemented by any
# driver author who is using this test suite, so it is enabled
# by default.
con = self._connect()
drv = self.driver
self.assertTrue(con.Warning is drv.Warning)
self.assertTrue(con.Error is drv.Error)
self.assertTrue(con.InterfaceError is drv.InterfaceError)
self.assertTrue(con.DatabaseError is drv.DatabaseError)
self.assertTrue(con.OperationalError is drv.OperationalError)
self.assertTrue(con.IntegrityError is drv.IntegrityError)
self.assertTrue(con.InternalError is drv.InternalError)
self.assertTrue(con.ProgrammingError is drv.ProgrammingError)
self.assertTrue(con.NotSupportedError is drv.NotSupportedError)
def test_commit(self):
con = self._connect()
try:
# Commit must work, even if it doesn't do anything
con.commit()
finally:
con.close()
def test_rollback(self):
con = self._connect()
# If rollback is defined, it should either work or throw
# the documented exception
if hasattr(con,'rollback'):
try:
con.rollback()
except self.driver.NotSupportedError:
pass
def test_cursor(self):
con = self._connect()
try:
cur = con.cursor()
finally:
con.close()
def test_cursor_isolation(self):
con = self._connect()
try:
# Make sure cursors created from the same connection have
# the documented transaction isolation level
cur1 = con.cursor()
cur2 = con.cursor()
self.executeDDL1(cur1)
cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (
self.table_prefix
))
cur2.execute("select name from %sbooze" % self.table_prefix)
booze = cur2.fetchall()
self.assertEqual(len(booze),1)
self.assertEqual(len(booze[0]),1)
self.assertEqual(booze[0][0],'Victoria Bitter')
finally:
con.close()
def test_description(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
self.assertEqual(cur.description,None,
'cursor.description should be none after executing a '
'statement that can return no rows (such as DDL)'
)
cur.execute('select name from %sbooze' % self.table_prefix)
self.assertEqual(len(cur.description),1,
'cursor.description describes too many columns'
)
self.assertEqual(len(cur.description[0]),8,
'cursor.description[x] tuples must have 8 elements'
)
self.assertEqual(cur.description[0][0].lower(),'name',
'cursor.description[x][0] must return column name'
)
self.assertEqual(cur.description[0][1],self.driver.STRING,
'cursor.description[x][1] must return column type. Got %r'
% cur.description[0][1]
)
# Make sure self.description gets reset
self.executeDDL2(cur)
self.assertEqual(cur.description,None,
'cursor.description not being set to None when executing '
'no-result statements (eg. DDL)'
)
finally:
con.close()
def test_rowcount(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
self.assertEqual(cur.rowcount,-1,
'cursor.rowcount should be -1 after executing no-result '
'statements'
)
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
self.table_prefix
))
self.assertTrue(cur.rowcount in (-1,1),
'cursor.rowcount should == number or rows inserted, or '
'set to -1 after executing an insert statement'
)
cur.execute("select name from %sbooze" % self.table_prefix, buffered=True)
self.assertTrue(cur.rowcount in (-1,1),
'cursor.rowcount should == number of rows returned, or '
'set to -1 after executing a select statement'
)
self.executeDDL2(cur)
self.assertEqual(cur.rowcount,-1,
'cursor.rowcount not being reset to -1 after executing '
'no-result statements'
)
finally:
con.close()
lower_func = 'lower'
def test_close(self):
con = self._connect()
try:
cur = con.cursor()
finally:
con.close()
# cursor.execute should raise an Error if called after connection
# closed
self.assertRaises(self.driver.Error,self.executeDDL1,cur)
# connection.commit should raise an Error if called after connection'
# closed.'
self.assertRaises(self.driver.Error,con.commit)
# connection.close should raise an Error if called more than once
self.assertRaises(self.driver.Error,con.close)
def test_execute(self):
con = self._connect()
try:
cur = con.cursor()
self._paraminsert(cur)
finally:
con.close()
def _paraminsert(self,cur):
self.executeDDL1(cur)
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
self.table_prefix
))
self.assertTrue(cur.rowcount in (-1,1))
if self.driver.paramstyle == 'qmark':
cur.execute(
'insert into %sbooze values (?)' % self.table_prefix,
("Cooper's",)
)
elif self.driver.paramstyle == 'numeric':
cur.execute(
'insert into %sbooze values (:1)' % self.table_prefix,
("Cooper's",)
)
elif self.driver.paramstyle == 'named':
cur.execute(
'insert into %sbooze values (:beer)' % self.table_prefix,
{'beer':"Cooper's"}
)
elif self.driver.paramstyle == 'format':
cur.execute(
'insert into %sbooze values (%%s)' % self.table_prefix,
("Cooper's",)
)
elif self.driver.paramstyle == 'pyformat':
cur.execute(
'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
{'beer':"Cooper's"}
)
else:
self.fail('Invalid paramstyle')
self.assertTrue(cur.rowcount in (-1,1))
cur.execute('select name from %sbooze' % self.table_prefix)
res = cur.fetchall()
self.assertEqual(len(res),2,'cursor.fetchall returned too few rows')
beers = [res[0][0],res[1][0]]
beers.sort()
self.assertEqual(beers[0],"Cooper's",
'cursor.fetchall retrieved incorrect data, or data inserted '
'incorrectly'
)
self.assertEqual(beers[1],"Victoria Bitter",
'cursor.fetchall retrieved incorrect data, or data inserted '
'incorrectly'
)
def test_executemany(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
largs = [ ("Cooper's",) , ("Boag's",) ]
margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ]
if self.driver.paramstyle == 'qmark':
cur.executemany(
'insert into %sbooze values (?)' % self.table_prefix,
largs
)
elif self.driver.paramstyle == 'numeric':
cur.executemany(
'insert into %sbooze values (:1)' % self.table_prefix,
largs
)
elif self.driver.paramstyle == 'named':
cur.executemany(
'insert into %sbooze values (:beer)' % self.table_prefix,
margs
)
elif self.driver.paramstyle == 'format':
cur.executemany(
'insert into %sbooze values (%%s)' % self.table_prefix,
largs
)
elif self.driver.paramstyle == 'pyformat':
cur.executemany(
'insert into %sbooze values (%%(beer)s)' % (
self.table_prefix
),
margs
)
else:
self.fail('Unknown paramstyle')
self.assertTrue(cur.rowcount in (-1,2),
'insert using cursor.executemany set cursor.rowcount to '
'incorrect value %r' % cur.rowcount
)
cur.execute('select name from %sbooze' % self.table_prefix)
res = cur.fetchall()
self.assertEqual(len(res),2,
'cursor.fetchall retrieved incorrect number of rows'
)
beers = [res[0][0],res[1][0]]
beers.sort()
self.assertEqual(beers[0],"Boag's",'incorrect data retrieved')
self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved')
finally:
con.close()
def test_fetchone(self):
con = self._connect()
try:
cur = con.cursor()
# cursor.fetchone should raise an Error if called before
# executing a select-type query
self.assertRaises(self.driver.Error,cur.fetchone)
# cursor.fetchone should raise an Error if called after
# executing a query that cannot return rows
self.executeDDL1(cur)
self.assertRaises(self.driver.Error,cur.fetchone)
cur.execute('select name from %sbooze' % self.table_prefix)
self.assertEqual(cur.fetchone(),None,
'cursor.fetchone should return None if a query retrieves '
'no rows'
)
self.assertTrue(cur.rowcount in (-1,0))
# cursor.fetchone should raise an Error if called after
# executing a query that cannot return rows
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
self.table_prefix
))
self.assertRaises(self.driver.Error,cur.fetchone)
cur.execute('select name from %sbooze' % self.table_prefix, buffered=True)
r = cur.fetchone()
self.assertEqual(len(r),1,
'cursor.fetchone should have retrieved a single row'
)
self.assertEqual(r[0],'Victoria Bitter',
'cursor.fetchone retrieved incorrect data'
)
self.assertEqual(cur.fetchone(),None,
'cursor.fetchone should return None if no more rows available'
)
self.assertTrue(cur.rowcount in (-1,1))
finally:
con.close()
samples = [
'Carlton Cold',
'Carlton Draft',
'Mountain Goat',
'Redback',
'Victoria Bitter',
'XXXX'
]
def _populate(self):
''' Return a list of sql commands to setup the DB for the fetch
tests.
'''
populate = [
"insert into %sbooze values ('%s')" % (self.table_prefix,s)
for s in self.samples
]
return populate
def test_fetchmany(self):
con = self._connect()
try:
cur = con.cursor()
# cursor.fetchmany should raise an Error if called without
#issuing a query
self.assertRaises(self.driver.Error,cur.fetchmany,4)
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
cur.execute('select name from %sbooze' % self.table_prefix)
r = cur.fetchmany()
self.assertEqual(len(r),1,
'cursor.fetchmany retrieved incorrect number of rows, '
'default of arraysize is one.'
)
cur.arraysize=10
r = cur.fetchmany(3) # Should get 3 rows
self.assertEqual(len(r),3,
'cursor.fetchmany retrieved incorrect number of rows'
)
r = cur.fetchmany(4) # Should get 2 more
self.assertEqual(len(r),2,
'cursor.fetchmany retrieved incorrect number of rows'
)
r = cur.fetchmany(4) # Should be an empty sequence
self.assertEqual(len(r),0,
'cursor.fetchmany should return an empty sequence after '
'results are exhausted'
)
self.assertTrue(cur.rowcount in (-1,6))
# Same as above, using cursor.arraysize
cur.arraysize=4
cur.execute('select name from %sbooze' % self.table_prefix)
r = cur.fetchmany() # Should get 4 rows
self.assertEqual(len(r),4,
'cursor.arraysize not being honoured by fetchmany'
)
r = cur.fetchmany() # Should get 2 more
self.assertEqual(len(r),2)
r = cur.fetchmany() # Should be an empty sequence
self.assertEqual(len(r),0)
self.assertTrue(cur.rowcount in (-1,6))
cur.arraysize=6
cur.execute('select name from %sbooze' % self.table_prefix)
rows = cur.fetchmany() # Should get all rows
self.assertTrue(cur.rowcount in (-1,6))
self.assertEqual(len(rows),6)
self.assertEqual(len(rows),6)
rows = [r[0] for r in rows]
rows.sort()
# Make sure we get the right data back out
for i in range(0,6):
self.assertEqual(rows[i],self.samples[i],
'incorrect data retrieved by cursor.fetchmany'
)
rows = cur.fetchmany() # Should return an empty list
self.assertEqual(len(rows),0,
'cursor.fetchmany should return an empty sequence if '
'called after the whole result set has been fetched'
)
self.assertTrue(cur.rowcount in (-1,6))
self.executeDDL2(cur)
cur.execute('select name from %sbarflys' % self.table_prefix)
r = cur.fetchmany() # Should get empty sequence
self.assertEqual(len(r),0,
'cursor.fetchmany should return an empty sequence if '
'query retrieved no rows'
)
self.assertTrue(cur.rowcount in (-1,0))
finally:
con.close()
def test_fetchall(self):
con = self._connect()
try:
cur = con.cursor()
# cursor.fetchall should raise an Error if called
# without executing a query that may return rows (such
# as a select)
self.assertRaises(self.driver.Error, cur.fetchall)
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
# cursor.fetchall should raise an Error if called
# after executing a a statement that cannot return rows
self.assertRaises(self.driver.Error,cur.fetchall)
cur.execute('select name from %sbooze' % self.table_prefix)
rows = cur.fetchall()
self.assertTrue(cur.rowcount in (-1,len(self.samples)))
self.assertEqual(len(rows),len(self.samples),
'cursor.fetchall did not retrieve all rows'
)
rows = [r[0] for r in rows]
rows.sort()
for i in range(0,len(self.samples)):
self.assertEqual(rows[i],self.samples[i],
'cursor.fetchall retrieved incorrect rows'
)
rows = cur.fetchall()
self.assertEqual(
len(rows),0,
'cursor.fetchall should return an empty list if called '
'after the whole result set has been fetched'
)
self.assertTrue(cur.rowcount in (-1,len(self.samples)))
self.executeDDL2(cur)
cur.execute('select name from %sbarflys' % self.table_prefix)
rows = cur.fetchall()
self.assertTrue(cur.rowcount in (-1,0))
self.assertEqual(len(rows),0,
'cursor.fetchall should return an empty list if '
'a select query returns no rows'
)
finally:
con.close()
def test_mixedfetch(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
cur.execute('select name from %sbooze' % self.table_prefix)
rows1 = cur.fetchone()
rows23 = cur.fetchmany(2)
rows4 = cur.fetchone()
rows56 = cur.fetchall()
self.assertTrue(cur.rowcount in (-1,6))
self.assertEqual(len(rows23),2,
'fetchmany returned incorrect number of rows'
)
self.assertEqual(len(rows56),2,
'fetchall returned incorrect number of rows'
)
rows = [rows1[0]]
rows.extend([rows23[0][0],rows23[1][0]])
rows.append(rows4[0])
rows.extend([rows56[0][0],rows56[1][0]])
rows.sort()
for i in range(0,len(self.samples)):
self.assertEqual(rows[i],self.samples[i],
'incorrect data retrieved or inserted'
)
finally:
con.close()
def help_nextset_setUp(self,cur):
''' Should create a procedure called deleteme
that returns two result sets, first the
number of rows in booze then "name from booze"
'''
raise NotImplementedError('Helper not implemented')
#sql="""
# create procedure deleteme as
# begin
# select count(*) from booze
# select name from booze
# end
#"""
#cur.execute(sql)
def help_nextset_tearDown(self,cur):
'If cleaning up is needed after nextSetTest'
raise NotImplementedError('Helper not implemented')
#cur.execute("drop procedure deleteme")
def test_arraysize(self):
# Not much here - rest of the tests for this are in test_fetchmany
con = self._connect()
try:
cur = con.cursor()
self.assertTrue(hasattr(cur,'arraysize'),
'cursor.arraysize must be defined'
)
finally:
con.close()
def test_setinputsizes(self):
con = self._connect()
try:
cur = con.cursor()
cur.setinputsizes( (25,) )
self._paraminsert(cur) # Make sure cursor still works
finally:
con.close()
def test_None(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
cur.execute('insert into %sbooze values (NULL)' % self.table_prefix)
cur.execute('select name from %sbooze' % self.table_prefix)
r = cur.fetchall()
self.assertEqual(len(r),1)
self.assertEqual(len(r[0]),1)
self.assertEqual(r[0][0],None,'NULL value not returned as None')
finally:
con.close()
def test_Date(self):
d1 = self.driver.Date(2002,12,25)
d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0)))
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(d1),str(d2))
def test_Time(self):
t1 = self.driver.Time(13,45,30)
t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0)))
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(t1),str(t2))
def test_Timestamp(self):
t1 = self.driver.Timestamp(2002,12,25,13,45,30)
t2 = self.driver.TimestampFromTicks(
time.mktime((2002,12,25,13,45,30,0,0,0))
)
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(t1),str(t2))
def test_Binary(self):
b = self.driver.Binary(b'Something')
b = self.driver.Binary(b'')
def test_STRING(self):
self.assertTrue(hasattr(self.driver,'STRING'),
'module.STRING must be defined'
)
def test_BINARY(self):
self.assertTrue(hasattr(self.driver,'BINARY'),
'module.BINARY must be defined.'
)
def test_NUMBER(self):
self.assertTrue(hasattr(self.driver,'NUMBER'),
'module.NUMBER must be defined.'
)
def test_DATETIME(self):
self.assertTrue(hasattr(self.driver,'DATETIME'),
'module.DATETIME must be defined.'
)
def test_ROWID(self):
self.assertTrue(hasattr(self.driver,'ROWID'),
'module.ROWID must be defined.'
)

View File

@ -1,5 +0,0 @@
[client]
host=127.0.0.1
port=3306
user=root
database=test

View File

View File

@ -0,0 +1,57 @@
#!/usr/bin/env python -O
# -*- coding: utf-8 -*-
import os
import unittest
import mariadb
from test.base_test import create_connection
from test.conf_test import conf
class TestConnection(unittest.TestCase):
def setUp(self):
self.connection = create_connection()
def tearDown(self):
del self.connection
def test_connection_default_file(self):
if os.path.exists("client.cnf"):
os.remove("client.cnf")
default_conf = conf()
f = open("client.cnf", "w+")
f.write("[client]\n")
f.write("host=%s\n" % default_conf["host"])
f.write("port=%i\n" % default_conf["port"])
f.write("database=%s\n" % default_conf["database"])
f.close()
new_conn = mariadb.connect(default_file="./client.cnf")
self.assertEqual(new_conn.database, default_conf["database"])
del new_conn
def test_autocommit(self):
conn = self.connection
cursor = conn.cursor()
self.assertEqual(conn.autocommit, True)
conn.autocommit = False
self.assertEqual(conn.autocommit, False)
conn.reset()
def test_schema(self):
default_conf = conf()
conn = self.connection
self.assertEqual(conn.database, default_conf["database"])
cursor = conn.cursor()
cursor.execute("CREATE OR REPLACE SCHEMA test1")
cursor.execute("USE test1")
self.assertEqual(conn.database, "test1")
conn.database = default_conf["database"]
self.assertEqual(conn.database, default_conf["database"])
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,558 @@
#!/usr/bin/env python -O
# -*- coding: utf-8 -*-
import collections
import datetime
import unittest
import mariadb
from test.base_test import create_connection
class TestCursor(unittest.TestCase):
def setUp(self):
self.connection = create_connection()
self.connection.autocommit = False
def tearDown(self):
del self.connection
def test_date(self):
if self.connection.server_version < 50500:
self.skipTest("microsecond not supported")
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_date(c1 TIMESTAMP(6), c2 TIME(6), c3 DATETIME(6), c4 DATE)")
t = datetime.datetime(2018, 6, 20, 12, 22, 31, 123456)
c1 = t
c2 = t.time()
c3 = t
c4 = t.date()
cursor.execute("INSERT INTO test_date VALUES (?,?,?,?)", (c1, c2, c3, c4))
cursor.execute("SELECT c1,c2,c3,c4 FROM test_date")
row = cursor.fetchone()
self.assertEqual(row[0], c1)
self.assertEqual(row[1], c2)
self.assertEqual(row[2], c3)
self.assertEqual(row[3], c4)
cursor.close()
def test_numbers(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_numbers (a tinyint unsigned, b smallint unsigned, c mediumint "
"unsigned, d int unsigned, e bigint unsigned, f double)")
c1 = 4
c2 = 200
c3 = 167557
c4 = 28688817
c5 = 7330133222578
c6 = 3.1415925
cursor.execute("insert into test_numbers values (?,?,?,?,?,?)", (c1, c2, c3, c4, c5, c6))
cursor.execute("select * from test_numbers")
row = cursor.fetchone()
self.assertEqual(row[0], c1)
self.assertEqual(row[1], c2)
self.assertEqual(row[2], c3)
self.assertEqual(row[3], c4)
self.assertEqual(row[4], c5)
self.assertEqual(row[5], c6)
del cursor
def test_string(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_string (a char(5), b varchar(100), c tinytext, "
"d mediumtext, e text, f longtext)");
c1 = "12345";
c2 = "The length of this text is < 100 characters"
c3 = "This should also fit into tinytext which has a maximum of 255 characters"
c4 = 'a' * 1000;
c5 = 'b' * 6000;
c6 = 'c' * 67000;
cursor.execute("INSERT INTO test_string VALUES (?,?,?,?,?,?)", (c1, c2, c3, c4, c5, c6))
cursor.execute("SELECT * from test_string")
row = cursor.fetchone()
self.assertEqual(row[0], c1)
self.assertEqual(row[1], c2)
self.assertEqual(row[2], c3)
self.assertEqual(row[3], c4)
self.assertEqual(row[4], c5)
self.assertEqual(row[5], c6)
del cursor
def test_blob(self):
cursor = self.connection.cursor()
cursor.execute("CREATE TEMPORARY TABLE test_blob (a tinyblob, b mediumblob, c blob, "
"d longblob)")
c1 = b'a' * 100;
c2 = b'b' * 1000;
c3 = b'c' * 10000;
c4 = b'd' * 100000;
a = (None, None, None, None)
cursor.execute("INSERT INTO test_blob VALUES (?,?,?,?)", (c1, c2, c3, c4))
cursor.execute("SELECT * FROM test_blob")
row = cursor.fetchone()
self.assertEqual(row[0], c1)
self.assertEqual(row[1], c2)
self.assertEqual(row[2], c3)
self.assertEqual(row[3], c4)
del cursor
def test_fetchmany(self):
cursor = self.connection.cursor()
cursor.execute("CREATE TEMPORARY TABLE test_fetchmany (id int, name varchar(64), "
"city varchar(64))");
params = [(1, u"Jack", u"Boston"),
(2, u"Martin", u"Ohio"),
(3, u"James", u"Washington"),
(4, u"Rasmus", u"Helsinki"),
(5, u"Andrey", u"Sofia")]
cursor.executemany("INSERT INTO test_fetchmany VALUES (?,?,?)", params);
# test Errors
# a) if no select was executed
self.assertRaises(mariadb.Error, cursor.fetchall)
# b ) if cursor was not executed
del cursor
cursor = self.connection.cursor()
self.assertRaises(mariadb.Error, cursor.fetchall)
cursor.execute("SELECT id, name, city FROM test_fetchmany ORDER BY id")
self.assertEqual(0, cursor.rowcount)
row = cursor.fetchall()
self.assertEqual(row, params)
self.assertEqual(5, cursor.rowcount)
cursor.execute("SELECT id, name, city FROM test_fetchmany ORDER BY id")
self.assertEqual(0, cursor.rowcount)
row = cursor.fetchmany(1)
self.assertEqual(row, [params[0]])
self.assertEqual(1, cursor.rowcount)
row = cursor.fetchmany(2)
self.assertEqual(row, ([params[1], params[2]]))
self.assertEqual(3, cursor.rowcount)
cursor.arraysize = 1
row = cursor.fetchmany()
self.assertEqual(row, [params[3]])
self.assertEqual(4, cursor.rowcount)
cursor.arraysize = 2
row = cursor.fetchmany()
self.assertEqual(row, [params[4]])
self.assertEqual(5, cursor.rowcount)
del cursor
def test1_multi_result(self):
cursor = self.connection.cursor()
sql = """
CREATE OR REPLACE PROCEDURE p1()
BEGIN
SELECT 1 FROM DUAL;
SELECT 2 FROM DUAL;
END
"""
cursor.execute(sql)
cursor.execute("call p1()")
row = cursor.fetchone()
self.assertEqual(row[0], 1)
cursor.nextset()
row = cursor.fetchone()
self.assertEqual(row[0], 2)
del cursor
def test_buffered(self):
cursor = self.connection.cursor()
cursor.execute("SELECT 1 UNION SELECT 2 UNION SELECT 3", buffered=True)
self.assertEqual(cursor.rowcount, 3)
cursor.scroll(1)
row = cursor.fetchone()
self.assertEqual(row[0], 2)
del cursor
def test_xfield_types(self):
cursor = self.connection.cursor()
fieldinfo = mariadb.fieldinfo()
cursor.execute(
"CREATE TEMPORARY TABLE test_xfield_types (a tinyint not null auto_increment primary "
"key, b smallint, c int, d bigint, e float, f decimal, g double, h char(10), i varchar(255), j blob, index(b))");
info = cursor.description
self.assertEqual(info, None)
cursor.execute("SELECT * FROM test_xfield_types")
info = cursor.description
self.assertEqual(fieldinfo.type(info[0]), "TINY")
self.assertEqual(fieldinfo.type(info[1]), "SHORT")
self.assertEqual(fieldinfo.type(info[2]), "LONG")
self.assertEqual(fieldinfo.type(info[3]), "LONGLONG")
self.assertEqual(fieldinfo.type(info[4]), "FLOAT")
self.assertEqual(fieldinfo.type(info[5]), "NEWDECIMAL")
self.assertEqual(fieldinfo.type(info[6]), "DOUBLE")
self.assertEqual(fieldinfo.type(info[7]), "STRING")
self.assertEqual(fieldinfo.type(info[8]), "VAR_STRING")
self.assertEqual(fieldinfo.type(info[9]), "BLOB")
self.assertEqual(fieldinfo.flag(info[0]),
"NOT_NULL | PRIMARY_KEY | AUTO_INCREMENT | NUMERIC")
self.assertEqual(fieldinfo.flag(info[1]), "PART_KEY | NUMERIC")
self.assertEqual(fieldinfo.flag(info[9]), "BLOB | BINARY")
del cursor
def test_bulk_delete(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE bulk_delete (id int, name varchar(64), city varchar(64))");
params = [(1, u"Jack", u"Boston"),
(2, u"Martin", u"Ohio"),
(3, u"James", u"Washington"),
(4, u"Rasmus", u"Helsinki"),
(5, u"Andrey", u"Sofia")]
cursor.executemany("INSERT INTO bulk_delete VALUES (?,?,?)", params)
self.assertEqual(cursor.rowcount, 5)
params = [(1, 2)]
cursor.executemany("DELETE FROM bulk_delete WHERE id=?", params)
self.assertEqual(cursor.rowcount, 2)
def test_named_tuple(self):
cursor = self.connection.cursor(named_tuple=1)
cursor.execute(
"CREATE TEMPORARY TABLE test_named_tuple (id int, name varchar(64), city varchar(64))");
params = [(1, u"Jack", u"Boston"),
(2, u"Martin", u"Ohio"),
(3, u"James", u"Washington"),
(4, u"Rasmus", u"Helsinki"),
(5, u"Andrey", u"Sofia")]
cursor.executemany("INSERT INTO test_named_tuple VALUES (?,?,?)", params);
cursor.execute("SELECT * FROM test_named_tuple ORDER BY id")
row = cursor.fetchone()
self.assertEqual(cursor.statement, "SELECT * FROM test_named_tuple ORDER BY id")
self.assertEqual(row.id, 1)
self.assertEqual(row.name, "Jack")
self.assertEqual(row.city, "Boston")
del cursor
def test_laststatement(self):
cursor = self.connection.cursor(named_tuple=1)
cursor.execute("CREATE TEMPORARY TABLE test_laststatement (id int, name varchar(64), "
"city varchar(64))");
self.assertEqual(cursor.statement,
"CREATE TEMPORARY TABLE test_laststatement (id int, name varchar(64), city varchar(64))")
params = [(1, u"Jack", u"Boston"),
(2, u"Martin", u"Ohio"),
(3, u"James", u"Washington"),
(4, u"Rasmus", u"Helsinki"),
(5, u"Andrey", u"Sofia")]
cursor.executemany("INSERT INTO test_laststatement VALUES (?,?,?)", params);
cursor.execute("SELECT * FROM test_laststatement ORDER BY id")
self.assertEqual(cursor.statement, "SELECT * FROM test_laststatement ORDER BY id")
del cursor
def test_multi_cursor(self):
cursor = self.connection.cursor()
cursor1 = self.connection.cursor(cursor_type=1)
cursor2 = self.connection.cursor(cursor_type=1)
cursor.execute("CREATE TEMPORARY TABLE test_multi_cursor (a int)")
cursor.execute("INSERT INTO test_multi_cursor VALUES (1),(2),(3),(4),(5),(6),(7),(8)")
del cursor
cursor1.execute("SELECT a FROM test_multi_cursor ORDER BY a")
cursor2.execute("SELECT a FROM test_multi_cursor ORDER BY a DESC")
for i in range(0, 8):
self.assertEqual(cursor1.rownumber, i)
row1 = cursor1.fetchone()
row2 = cursor2.fetchone()
self.assertEqual(cursor1.rownumber, cursor2.rownumber)
self.assertEqual(row1[0] + row2[0], 9)
del cursor1
del cursor2
def test_connection_attr(self):
cursor = self.connection.cursor()
self.assertEqual(cursor.connection, self.connection)
del cursor
def test_dbapi_type(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_dbapi_type (a int, b varchar(20), c blob, d datetime, e decimal)")
cursor.execute("INSERT INTO test_dbapi_type VALUES (1, 'foo', 'blabla', now(), 10.2)");
cursor.execute("SELECT * FROM test_dbapi_type ORDER BY a")
expected_typecodes = [
mariadb.NUMBER,
mariadb.STRING,
mariadb.BINARY,
mariadb.DATETIME,
mariadb.NUMBER
]
row = cursor.fetchone()
typecodes = [row[1] for row in cursor.description]
self.assertEqual(expected_typecodes, typecodes)
del cursor
def test_tuple(self):
cursor = self.connection.cursor()
cursor.execute("CREATE TEMPORARY TABLE dyncol1 (a blob)")
tpl = (1, 2, 3)
cursor.execute("INSERT INTO dyncol1 VALUES (?)", tpl)
del cursor
def test_indicator(self):
if self.connection.server_version < 100206:
self.skipTest("Requires server version >= 10.2.6")
cursor = self.connection.cursor()
cursor.execute("CREATE TEMPORARY TABLE ind1 (a int, b int default 2,c int)")
vals = (mariadb.indicator_null, mariadb.indicator_default, 3)
cursor.executemany("INSERT INTO ind1 VALUES (?,?,?)", [vals])
cursor.execute("SELECT a, b, c FROM ind1")
row = cursor.fetchone()
self.assertEqual(row[0], None)
self.assertEqual(row[1], 2)
self.assertEqual(row[2], 3)
def test_tuple2(self):
cursor = self.connection.cursor()
cursor.execute("CREATE TEMPORARY TABLE dyncol1 (a blob)");
t = datetime.datetime(2018, 6, 20, 12, 22, 31, 123456)
val = ([1, t, 3, (1, 2, 3)],)
cursor.execute("INSERT INTO dyncol1 VALUES (?)", val);
cursor.execute("SELECT a FROM dyncol1")
row = cursor.fetchone()
self.assertEqual(row, val);
del cursor
def test_set(self):
cursor = self.connection.cursor()
cursor.execute("CREATE TEMPORARY TABLE dyncol1 (a blob)")
t = datetime.datetime(2018, 6, 20, 12, 22, 31, 123456)
a = collections.OrderedDict(
[('apple', 4), ('banana', 3), ('orange', 2), ('pear', 1), ('4', 3), (4, 4)])
val = ([1, t, 3, (1, 2, 3), {1, 2, 3}, a],)
cursor.execute("INSERT INTO dyncol1 VALUES (?)", val)
cursor.execute("SELECT a FROM dyncol1")
row = cursor.fetchone()
self.assertEqual(row, val)
del cursor
def test_reset(self):
cursor = self.connection.cursor()
cursor.execute("SELECT 1 UNION SELECT 2", buffered=False)
cursor.execute("SELECT 1 UNION SELECT 2")
del cursor
def test_fake_pickle(self):
cursor = self.connection.cursor()
cursor.execute("CREATE TEMPORARY TABLE test_fake_pickle (a blob)")
k = bytes([0x80, 0x03, 0x00, 0x2E])
cursor.execute("insert into test_fake_pickle values (?)", (k,))
cursor.execute("select * from test_fake_pickle");
row = cursor.fetchone()
self.assertEqual(row[0], k)
del cursor
def test_no_result(self):
cursor = self.connection.cursor()
cursor.execute("set @a:=1")
try:
row = cursor.fetchone()
except mariadb.ProgrammingError:
pass
del cursor
def test_collate(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE `test_collate` (`test` varchar(500) COLLATE "
"utf8mb4_unicode_ci NOT NULL) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci")
cursor.execute("SET NAMES utf8mb4")
cursor.execute(
"SELECT * FROM `test_collate` WHERE `test` LIKE 'jj' COLLATE utf8mb4_unicode_ci")
del cursor
def test_conpy_8(self):
cursor = self.connection.cursor()
sql = """
CREATE OR REPLACE PROCEDURE p1()
BEGIN
SELECT 1 FROM DUAL UNION SELECT 0 FROM DUAL;
SELECT 2 FROM DUAL;
END
"""
cursor.execute(sql)
cursor.execute("call p1()")
cursor.nextset()
row = cursor.fetchone()
self.assertEqual(row[0], 2);
del cursor
def test_conpy_7(self):
cursor = self.connection.cursor()
stmt = "SELECT 1 UNION SELECT 2 UNION SELECT 3 UNION SELECT 4"
cursor.execute(stmt, buffered=True)
cursor.scroll(2, mode='relative')
row = cursor.fetchone()
self.assertEqual(row[0], 3)
cursor.scroll(-2, mode='relative')
row = cursor.fetchone()
del cursor
def test_compy_9(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_compy_9 (a varchar(20), b double(6,3), c double)");
cursor.execute("INSERT INTO test_compy_9 VALUES ('€uro', 123.345, 12345.678)")
cursor.execute("SELECT a,b,c FROM test_compy_9")
cursor.fetchone()
d = cursor.description;
self.assertEqual(d[0][2], 4); # 4 code points only
self.assertEqual(d[0][3], -1); # variable length
self.assertEqual(d[1][2], 7); # length=precision + 1
self.assertEqual(d[1][4], 6); # precision
self.assertEqual(d[1][5], 3); # decimals
del cursor
def test_conpy_15(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_conpy_15 (a int not null auto_increment primary key, b varchar(20))");
self.assertEqual(cursor.lastrowid, 0)
cursor.execute("INSERT INTO test_conpy_15 VALUES (null, 'foo')")
self.assertEqual(cursor.lastrowid, 1)
cursor.execute("SELECT LAST_INSERT_ID()")
row = cursor.fetchone()
self.assertEqual(row[0], 1)
# del cursor
# cursor= self.connection.cursor()
vals = [(3, "bar"), (4, "this")]
cursor.executemany("INSERT INTO test_conpy_15 VALUES (?,?)", vals)
self.assertEqual(cursor.lastrowid, 4)
# Bug MDEV-16847
# cursor.execute("SELECT LAST_INSERT_ID()")
# row= cursor.fetchone()
# self.assertEqual(row[0], 4)
# Bug MDEV-16593
# vals= [(None, "bar"), (None, "foo")]
# cursor.executemany("INSERT INTO t1 VALUES (?,?)", vals)
# self.assertEqual(cursor.lastrowid, 6)
del cursor
def test_conpy_14(self):
cursor = self.connection.cursor()
self.assertEqual(cursor.rowcount, -1)
cursor.execute(
"CREATE TEMPORARY TABLE test_conpy_14 (a int not null auto_increment primary key, b varchar(20))");
self.assertEqual(cursor.rowcount, -1)
cursor.execute("INSERT INTO test_conpy_14 VALUES (null, 'foo')")
self.assertEqual(cursor.rowcount, 1)
vals = [(3, "bar"), (4, "this")]
cursor.executemany("INSERT INTO test_conpy_14 VALUES (?,?)", vals)
self.assertEqual(cursor.rowcount, 2)
del cursor
def test_closed(self):
cursor = self.connection.cursor()
cursor.close()
cursor.close()
self.assertEqual(cursor.closed, True)
try:
cursor.execute("set @a:=1")
except mariadb.ProgrammingError:
pass
del cursor
def test_emptycursor(self):
cursor = self.connection.cursor()
try:
cursor.execute("")
except mariadb.DatabaseError:
pass
del cursor
def test_iterator(self):
cursor = self.connection.cursor()
cursor.execute("select 1 union select 2 union select 3 union select 4 union select 5")
for i, row in enumerate(cursor):
self.assertEqual(i + 1, cursor.rownumber)
self.assertEqual(i + 1, row[0])
def test_update_bulk(self):
cursor = self.connection.cursor()
cursor.execute("CREATE TEMPORARY TABLE test_update_bulk (a int primary key, b int)")
vals = [(i,) for i in range(1000)]
cursor.executemany("INSERT INTO test_update_bulk VALUES (?, NULL)", vals);
self.assertEqual(cursor.rowcount, 1000)
self.connection.autocommit = False
cursor.executemany("UPDATE test_update_bulk SET b=2 WHERE a=?", vals);
self.connection.commit()
self.assertEqual(cursor.rowcount, 1000)
self.connection.autocommit = True
del cursor
def test_multi_execute(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_multi_execute (a int auto_increment primary key, b int)")
self.connection.autocommit = False
for i in range(1, 1000):
cursor.execute("INSERT INTO test_multi_execute VALUES (?,1)", (i,))
self.connection.autocommit = True
del cursor
def test_conpy21(self):
conn = self.connection
cursor = conn.cursor()
self.assertFalse(cursor.closed)
conn.close()
self.assertTrue(cursor.closed)
del cursor, conn
def test_utf8(self):
# F0 9F 98 8E 😎 unicode 6 smiling face with sunglasses
# F0 9F 8C B6 🌶 unicode 7 hot pepper
# F0 9F 8E A4 🎤 unicode 8 no microphones
# F0 9F A5 82 🥂 unicode 9 champagne glass
con = create_connection({"charset": "utf8mb4"})
cursor = con.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE `test_utf8` (`test` blob)")
cursor.execute("INSERT INTO test_utf8 VALUES (?)", ("😎🌶🎤🥂",))
cursor.execute("SELECT * FROM test_utf8")
row = cursor.fetchone()
self.assertEqual(row[0], b"\xf0\x9f\x98\x8e\xf0\x9f\x8c\xb6\xf0\x9f\x8e\xa4\xf0\x9f\xa5\x82")
del cursor, con
def test_latin2(self):
con = create_connection({"charset": "cp1251"})
print(con.character_set)
cursor = con.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE `test_latin2` (`test` blob)")
cursor.execute("INSERT INTO test_latin2 VALUES (?)", (b"\xA9\xB0",))
cursor.execute("SELECT * FROM test_latin2")
row = cursor.fetchone()
self.assertEqual(row[0], b"\xA9\xB0")
del cursor, con
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,91 @@
#!/usr/bin/env python -O
# -*- coding: utf-8 -*-
import datetime
import unittest
from test.base_test import create_connection
class CursorMariaDBTest(unittest.TestCase):
def setUp(self):
self.connection = create_connection()
def tearDown(self):
del self.connection
def test_insert_parameter(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_insert_parameter(a int not null auto_increment primary key, b int, c int, d varchar(20),e date)")
# cursor.execute("set @@autocommit=0");
list_in = []
for i in range(1, 300001):
row = (i, i, i, "bar", datetime.date(2019, 1, 1))
list_in.append(row)
cursor.executemany("INSERT INTO test_insert_parameter VALUES (?,?,?,?,?)", list_in)
self.assertEqual(len(list_in), cursor.rowcount)
self.connection.commit()
cursor.execute("SELECT * FROM test_insert_parameter order by a")
list_out = cursor.fetchall()
self.assertEqual(len(list_in), cursor.rowcount);
self.assertEqual(list_in, list_out)
cursor.close()
def test_update_parameter(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_update_parameter(a int not null auto_increment "
"primary key, b int, c int, d varchar(20),e date)")
cursor.execute("set @@autocommit=0");
list_in = []
for i in range(1, 300001):
row = (i, i, i, "bar", datetime.date(2019, 1, 1))
list_in.append(row)
cursor.executemany("INSERT INTO test_update_parameter VALUES (?,?,?,?,?)", list_in)
self.assertEqual(len(list_in), cursor.rowcount)
self.connection.commit()
cursor.close()
list_update = [];
cursor = self.connection.cursor()
cursor.execute("set @@autocommit=0");
for i in range(1, 300001):
row = (i + 1, i);
list_update.append(row);
cursor.executemany("UPDATE test_update_parameter SET b=? WHERE a=?", list_update);
self.assertEqual(cursor.rowcount, 300000)
self.connection.commit();
cursor.close()
def test_delete_parameter(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_delete_parameter(a int not null auto_increment "
"primary key, b int, c int, d varchar(20),e date)")
cursor.execute("set @@autocommit=0");
list_in = []
for i in range(1, 300001):
row = (i, i, i, "bar", datetime.date(2019, 1, 1))
list_in.append(row)
cursor.executemany("INSERT INTO test_delete_parameter VALUES (?,?,?,?,?)", list_in)
self.assertEqual(len(list_in), cursor.rowcount)
self.connection.commit()
cursor.close()
list_delete = [];
cursor = self.connection.cursor()
cursor.execute("set @@autocommit=0");
for i in range(1, 300001):
list_delete.append((i,));
cursor.executemany("DELETE FROM test_delete_parameter WHERE a=?", list_delete);
self.assertEqual(cursor.rowcount, 300000)
self.connection.commit();
cursor.close()
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python -O
# -*- coding: utf-8 -*-
import datetime
import unittest
from test.base_test import create_connection
class CursorMySQLTest(unittest.TestCase):
def setUp(self):
self.connection = create_connection()
def tearDown(self):
del self.connection
def test_parameter(self):
cursor = self.connection.cursor()
cursor.execute("CREATE TEMPORARY TABLE test_parameter(a int auto_increment primary key not "
"null, b int, c int, d varchar(20),e date)")
cursor.execute("SET @@autocommit=0")
c = (1, 2, 3, "bar", datetime.date(2018, 11, 11))
list_in = []
for i in range(1, 30000):
row = (i, i, i, "bar", datetime.date(2019, 1, 1))
list_in.append(row)
cursor.executemany("INSERT INTO test_parameter VALUES (%s,%s,%s,%s,%s)", list_in)
print("rows inserted:", len(list_in))
self.connection.commit()
cursor.execute("SELECT * FROM test_parameter order by a")
list_out = cursor.fetchall()
print("rows fetched: ", len(list_out))
self.assertEqual(list_in, list_out)
cursor.close()
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,794 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
''' Python DB API 2.0 driver compliance unit test suite.
This software is Public Domain and may be used without restrictions.
"Now we have booze and barflies entering the discussion, plus rumours of
DBAs on drugs... and I won't tell you what flashes through my mind each
time I read the subject line with 'Anal Compliance' in it. All around
this is turning out to be a thoroughly unwholesome unit test."
-- Ian Bicking
'''
__rcs_id__ = '$Id$'
__version__ = '$Revision$'[11:-2]
__author__ = 'Stuart Bishop <zen@shangri-la.dropbear.id.au>'
import time
import unittest
import mariadb as mariadb
# $Log$
# Revision 1.1.2.1 2006/02/25 03:44:32 adustman
# Generic DB-API unit test module
#
# Revision 1.10 2003/10/09 03:14:14 zenzen
# Add test for DB API 2.0 optional extension, where database exceptions
# are exposed as attributes on the Connection object.
#
# Revision 1.9 2003/08/13 01:16:36 zenzen
# Minor tweak from Stefan Fleiter
#
# Revision 1.8 2003/04/10 00:13:25 zenzen
# Changes, as per suggestions by M.-A. Lemburg
# - Add a table prefix, to ensure namespace collisions can always be avoided
#
# Revision 1.7 2003/02/26 23:33:37 zenzen
# Break out DDL into helper functions, as per request by David Rushby
#
# Revision 1.6 2003/02/21 03:04:33 zenzen
# Stuff from Henrik Ekelund:
# added test_None
# added test_nextset & hooks
#
# Revision 1.5 2003/02/17 22:08:43 zenzen
# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize
# defaults to 1 & generic cursor.callproc test added
#
# Revision 1.4 2003/02/15 00:16:33 zenzen
# Changes, as per suggestions and bug reports by M.-A. Lemburg,
# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
# - Class renamed
# - Now a subclass of TestCase, to avoid requiring the driver stub
# to use multiple inheritance
# - Reversed the polarity of buggy test in test_description
# - Test exception hierarchy correctly
# - self.populate is now self._populate(), so if a driver stub
# overrides self.ddl1 this change propagates
# - VARCHAR columns now have a width, which will hopefully make the
# DDL even more portible (this will be reversed if it causes more problems)
# - cursor.rowcount being checked after various execute and fetchXXX methods
# - Check for fetchall and fetchmany returning empty lists after results
# are exhausted (already checking for empty lists if select retrieved
# nothing
# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
#
from test.conf_test import conf
class DatabaseAPI20Test(unittest.TestCase):
''' Test a database self.driver for DB API 2.0 compatibility.
This implementation tests Gadfly, but the TestCase
is structured so that other self.drivers can subclass this
test case to ensure compiliance with the DB-API. It is
expected that this TestCase may be expanded in the future
if ambiguities or edge conditions are discovered.
The 'Optional Extensions' are not yet being tested.
self.drivers should subclass this test, overriding setUp, tearDown,
self.driver, connect_args and connect_kw_args. Class specification
should be as follows:
import dbapi20
class mytest(dbapi20.DatabaseAPI20Test):
[...]
Don't 'import DatabaseAPI20Test from dbapi20', or you will
confuse the unit tester - just 'import dbapi20'.
'''
# The self.driver module. This should be the module where the 'connect'
# method is to be found
driver = mariadb
connect_args = () # List of arguments to pass to connect
connect_kw_args = conf() # Keyword arguments for connect
table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
xddl1 = 'drop table %sbooze' % table_prefix
xddl2 = 'drop table %sbarflys' % table_prefix
lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
# Some drivers may need to override these helpers, for example adding
# a 'commit' after the execute.
def executeDDL1(self, cursor):
cursor.execute(self.ddl1)
def executeDDL2(self, cursor):
cursor.execute(self.ddl2)
def setUp(self):
''' self.drivers should override this method to perform required setup
if any is necessary, such as creating the database.
'''
pass
def tearDown(self):
''' self.drivers should override this method to perform required cleanup
if any is necessary, such as deleting the test database.
The default drops the tables that may be created.
'''
con = self._connect()
try:
cur = con.cursor()
for ddl in (self.xddl1, self.xddl2):
try:
cur.execute(ddl)
con.commit()
except self.driver.Error:
# Assume table didn't exist. Other tests will check if
# execute is busted.
pass
finally:
con.close()
def _connect(self):
try:
return self.driver.connect(**self.connect_kw_args)
except AttributeError:
self.fail("No connect method found in self.driver module")
def test_connect(self):
con = self._connect()
con.close()
def test_apilevel(self):
try:
# Must exist
apilevel = self.driver.apilevel
# Must equal 2.0
self.assertEqual(apilevel, '2.0')
except AttributeError:
self.fail("Driver doesn't define apilevel")
# def test_threadsafety(self):
# try:
# # Must exist
# threadsafety = self.driver.threadsafety
# # Must be a valid value
# self.assertTrue(threadsafety in (0, 1, 2, 3))
# except AttributeError:
# self.fail("Driver doesn't define threadsafety")
# def test_paramstyle(self):
# try:
# # Must exist
# paramstyle = self.driver.paramstyle
# # Must be a valid value
# self.assertTrue(paramstyle in (
# 'qmark', 'numeric', 'named', 'format', 'pyformat'
# ))
# except AttributeError:
# self.fail("Driver doesn't define paramstyle")
# def test_Exceptions(self):
# # Make sure required exceptions exist, and are in the
# # defined hierarchy.
# self.assertTrue(issubclass(self.driver.Warning, Exception))
# self.assertTrue(issubclass(self.driver.Error, Exception))
# self.assertTrue(
# issubclass(self.driver.InterfaceError, self.driver.Error)
# )
# self.assertTrue(
# issubclass(self.driver.DatabaseError, self.driver.Error)
# )
# self.assertTrue(
# issubclass(self.driver.OperationalError, self.driver.Error)
# )
# self.assertTrue(
# issubclass(self.driver.IntegrityError, self.driver.Error)
# )
# self.assertTrue(
# issubclass(self.driver.InternalError, self.driver.Error)
# )
# self.assertTrue(
# issubclass(self.driver.ProgrammingError, self.driver.Error)
# )
# self.assertTrue(
# issubclass(self.driver.NotSupportedError, self.driver.Error)
# )
# def test_ExceptionsAsConnectionAttributes(self):
# # OPTIONAL EXTENSION
# # Test for the optional DB API 2.0 extension, where the exceptions
# # are exposed as attributes on the Connection object
# # I figure this optional extension will be implemented by any
# # driver author who is using this test suite, so it is enabled
# # by default.
# con = self._connect()
# drv = self.driver
# self.assertTrue(con.Warning is drv.Warning)
# self.assertTrue(con.Error is drv.Error)
# self.assertTrue(con.InterfaceError is drv.InterfaceError)
# self.assertTrue(con.DatabaseError is drv.DatabaseError)
# self.assertTrue(con.OperationalError is drv.OperationalError)
# self.assertTrue(con.IntegrityError is drv.IntegrityError)
# self.assertTrue(con.InternalError is drv.InternalError)
# self.assertTrue(con.ProgrammingError is drv.ProgrammingError)
# self.assertTrue(con.NotSupportedError is drv.NotSupportedError)
# def test_commit(self):
# con = self._connect()
# try:
# # Commit must work, even if it doesn't do anything
# con.commit()
# finally:
# con.close()
# def test_rollback(self):
# con = self._connect()
# # If rollback is defined, it should either work or throw
# # the documented exception
# if hasattr(con, 'rollback'):
# try:
# con.rollback()
# except self.driver.NotSupportedError:
# pass
# def test_cursor(self):
# con = self._connect()
# try:
# cur = con.cursor()
# finally:
# con.close()
# def test_cursor_isolation(self):
# con = self._connect()
# try:
# # Make sure cursors created from the same connection have
# # the documented transaction isolation level
# cur1 = con.cursor()
# cur2 = con.cursor()
# self.executeDDL1(cur1)
# cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (
# self.table_prefix
# ))
# cur2.execute("select name from %sbooze" % self.table_prefix)
# booze = cur2.fetchall()
# self.assertEqual(len(booze), 1)
# self.assertEqual(len(booze[0]), 1)
# self.assertEqual(booze[0][0], 'Victoria Bitter')
# finally:
# con.close()
# def test_description(self):
# con = self._connect()
# try:
# cur = con.cursor()
# self.executeDDL1(cur)
# self.assertEqual(cur.description, None,
# 'cursor.description should be none after executing a '
# 'statement that can return no rows (such as DDL)'
# )
# cur.execute('select name from %sbooze' % self.table_prefix)
# self.assertEqual(len(cur.description), 1,
# 'cursor.description describes too many columns'
# )
# self.assertEqual(len(cur.description[0]), 8,
# 'cursor.description[x] tuples must have 8 elements'
# )
# self.assertEqual(cur.description[0][0].lower(), 'name',
# 'cursor.description[x][0] must return column name'
# )
# self.assertEqual(cur.description[0][1], self.driver.STRING,
# 'cursor.description[x][1] must return column type. Got %r'
# % cur.description[0][1]
# )
# # Make sure self.description gets reset
# self.executeDDL2(cur)
# self.assertEqual(cur.description, None,
# 'cursor.description not being set to None when executing '
# 'no-result statements (eg. DDL)'
# )
# finally:
# con.close()
# def test_rowcount(self):
# con = self._connect()
# try:
# cur = con.cursor()
# self.executeDDL1(cur)
# self.assertEqual(cur.rowcount, -1,
# 'cursor.rowcount should be -1 after executing no-result '
# 'statements'
# )
# cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
# self.table_prefix
# ))
# self.assertTrue(cur.rowcount in (-1, 1),
# 'cursor.rowcount should == number or rows inserted, or '
# 'set to -1 after executing an insert statement'
# )
# cur.execute("select name from %sbooze" % self.table_prefix, buffered=True)
# self.assertTrue(cur.rowcount in (-1, 1),
# 'cursor.rowcount should == number of rows returned, or '
# 'set to -1 after executing a select statement'
# )
# self.executeDDL2(cur)
# self.assertEqual(cur.rowcount, -1,
# 'cursor.rowcount not being reset to -1 after executing '
# 'no-result statements'
# )
# finally:
# con.close()
# lower_func = 'lower'
# def test_close(self):
# con = self._connect()
# try:
# cur = con.cursor()
# finally:
# con.close()
# # cursor.execute should raise an Error if called after connection
# # closed
# self.assertRaises(self.driver.Error, self.executeDDL1, cur)
# # connection.commit should raise an Error if called after connection'
# # closed.'
# self.assertRaises(self.driver.Error, con.commit)
# # connection.close should raise an Error if called more than once
# self.assertRaises(self.driver.Error, con.close)
# def test_execute(self):
# con = self._connect()
# try:
# cur = con.cursor()
# self._paraminsert(cur)
# finally:
# con.close()
# def _paraminsert(self, cur):
# self.executeDDL1(cur)
# cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
# self.table_prefix
# ))
# self.assertTrue(cur.rowcount in (-1, 1))
# if self.driver.paramstyle == 'qmark':
# cur.execute(
# 'insert into %sbooze values (?)' % self.table_prefix,
# ("Cooper's",)
# )
# elif self.driver.paramstyle == 'numeric':
# cur.execute(
# 'insert into %sbooze values (:1)' % self.table_prefix,
# ("Cooper's",)
# )
# elif self.driver.paramstyle == 'named':
# cur.execute(
# 'insert into %sbooze values (:beer)' % self.table_prefix,
# {'beer': "Cooper's"}
# )
# elif self.driver.paramstyle == 'format':
# cur.execute(
# 'insert into %sbooze values (%%s)' % self.table_prefix,
# ("Cooper's",)
# )
# elif self.driver.paramstyle == 'pyformat':
# cur.execute(
# 'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
# {'beer': "Cooper's"}
# )
# else:
# self.fail('Invalid paramstyle')
# self.assertTrue(cur.rowcount in (-1, 1))
# cur.execute('select name from %sbooze' % self.table_prefix)
# res = cur.fetchall()
# self.assertEqual(len(res), 2, 'cursor.fetchall returned too few rows')
# beers = [res[0][0], res[1][0]]
# beers.sort()
# self.assertEqual(beers[0], "Cooper's",
# 'cursor.fetchall retrieved incorrect data, or data inserted '
# 'incorrectly'
# )
# self.assertEqual(beers[1], "Victoria Bitter",
# 'cursor.fetchall retrieved incorrect data, or data inserted '
# 'incorrectly'
# )
def test_executemany(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
largs = [("Cooper's",), ("Boag's",)]
margs = [{'beer': "Cooper's"}, {'beer': "Boag's"}]
if self.driver.paramstyle == 'qmark':
cur.executemany(
'insert into %sbooze values (?)' % self.table_prefix,
largs
)
elif self.driver.paramstyle == 'numeric':
cur.executemany(
'insert into %sbooze values (:1)' % self.table_prefix,
largs
)
elif self.driver.paramstyle == 'named':
cur.executemany(
'insert into %sbooze values (:beer)' % self.table_prefix,
margs
)
elif self.driver.paramstyle == 'format':
cur.executemany(
'insert into %sbooze values (%%s)' % self.table_prefix,
largs
)
elif self.driver.paramstyle == 'pyformat':
cur.executemany(
'insert into %sbooze values (%%(beer)s)' % (
self.table_prefix
),
margs
)
else:
self.fail('Unknown paramstyle')
self.assertTrue(cur.rowcount in (-1, 2),
'insert using cursor.executemany set cursor.rowcount to '
'incorrect value %r' % cur.rowcount
)
cur.execute('select name from %sbooze' % self.table_prefix)
res = cur.fetchall()
self.assertEqual(len(res), 2,
'cursor.fetchall retrieved incorrect number of rows'
)
beers = [res[0][0], res[1][0]]
beers.sort()
self.assertEqual(beers[0], "Boag's", 'incorrect data retrieved')
self.assertEqual(beers[1], "Cooper's", 'incorrect data retrieved')
finally:
con.close()
# def test_fetchone(self):
# con = self._connect()
# try:
# cur = con.cursor()
# # cursor.fetchone should raise an Error if called before
# # executing a select-type query
# self.assertRaises(self.driver.Error, cur.fetchone)
# # cursor.fetchone should raise an Error if called after
# # executing a query that cannot return rows
# self.executeDDL1(cur)
# self.assertRaises(self.driver.Error, cur.fetchone)
# cur.execute('select name from %sbooze' % self.table_prefix)
# self.assertEqual(cur.fetchone(), None,
# 'cursor.fetchone should return None if a query retrieves '
# 'no rows'
# )
# self.assertTrue(cur.rowcount in (-1, 0))
# # cursor.fetchone should raise an Error if called after
# # executing a query that cannot return rows
# cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
# self.table_prefix
# ))
# self.assertRaises(self.driver.Error, cur.fetchone)
# cur.execute('select name from %sbooze' % self.table_prefix, buffered=True)
# r = cur.fetchone()
# self.assertEqual(len(r), 1,
# 'cursor.fetchone should have retrieved a single row'
# )
# self.assertEqual(r[0], 'Victoria Bitter',
# 'cursor.fetchone retrieved incorrect data'
# )
# self.assertEqual(cur.fetchone(), None,
# 'cursor.fetchone should return None if no more rows available'
# )
# self.assertTrue(cur.rowcount in (-1, 1))
# finally:
# con.close()
# samples = [
# 'Carlton Cold',
# 'Carlton Draft',
# 'Mountain Goat',
# 'Redback',
# 'Victoria Bitter',
# 'XXXX'
# ]
# def _populate(self):
# ''' Return a list of sql commands to setup the DB for the fetch
# tests.
# '''
# populate = [
# "insert into %sbooze values ('%s')" % (self.table_prefix, s)
# for s in self.samples
# ]
# return populate
# def test_fetchmany(self):
# con = self._connect()
# try:
# cur = con.cursor()
# # cursor.fetchmany should raise an Error if called without
# # issuing a query
# self.assertRaises(self.driver.Error, cur.fetchmany, 4)
# self.executeDDL1(cur)
# for sql in self._populate():
# cur.execute(sql)
# cur.execute('select name from %sbooze' % self.table_prefix)
# r = cur.fetchmany()
# self.assertEqual(len(r), 1,
# 'cursor.fetchmany retrieved incorrect number of rows, '
# 'default of arraysize is one.'
# )
# cur.arraysize = 10
# r = cur.fetchmany(3) # Should get 3 rows
# self.assertEqual(len(r), 3,
# 'cursor.fetchmany retrieved incorrect number of rows'
# )
# r = cur.fetchmany(4) # Should get 2 more
# self.assertEqual(len(r), 2,
# 'cursor.fetchmany retrieved incorrect number of rows'
# )
# r = cur.fetchmany(4) # Should be an empty sequence
# self.assertEqual(len(r), 0,
# 'cursor.fetchmany should return an empty sequence after '
# 'results are exhausted'
# )
# self.assertTrue(cur.rowcount in (-1, 6))
# # Same as above, using cursor.arraysize
# cur.arraysize = 4
# cur.execute('select name from %sbooze' % self.table_prefix)
# r = cur.fetchmany() # Should get 4 rows
# self.assertEqual(len(r), 4,
# 'cursor.arraysize not being honoured by fetchmany'
# )
# r = cur.fetchmany() # Should get 2 more
# self.assertEqual(len(r), 2)
# r = cur.fetchmany() # Should be an empty sequence
# self.assertEqual(len(r), 0)
# self.assertTrue(cur.rowcount in (-1, 6))
# cur.arraysize = 6
# cur.execute('select name from %sbooze' % self.table_prefix)
# rows = cur.fetchmany() # Should get all rows
# self.assertTrue(cur.rowcount in (-1, 6))
# self.assertEqual(len(rows), 6)
# self.assertEqual(len(rows), 6)
# rows = [r[0] for r in rows]
# rows.sort()
# # Make sure we get the right data back out
# for i in range(0, 6):
# self.assertEqual(rows[i], self.samples[i],
# 'incorrect data retrieved by cursor.fetchmany'
# )
# rows = cur.fetchmany() # Should return an empty list
# self.assertEqual(len(rows), 0,
# 'cursor.fetchmany should return an empty sequence if '
# 'called after the whole result set has been fetched'
# )
# self.assertTrue(cur.rowcount in (-1, 6))
# self.executeDDL2(cur)
# cur.execute('select name from %sbarflys' % self.table_prefix)
# r = cur.fetchmany() # Should get empty sequence
# self.assertEqual(len(r), 0,
# 'cursor.fetchmany should return an empty sequence if '
# 'query retrieved no rows'
# )
# self.assertTrue(cur.rowcount in (-1, 0))
# finally:
# con.close()
# def test_fetchall(self):
# con = self._connect()
# try:
# cur = con.cursor()
# # cursor.fetchall should raise an Error if called
# # without executing a query that may return rows (such
# # as a select)
# self.assertRaises(self.driver.Error, cur.fetchall)
# self.executeDDL1(cur)
# for sql in self._populate():
# cur.execute(sql)
# # cursor.fetchall should raise an Error if called
# # after executing a a statement that cannot return rows
# self.assertRaises(self.driver.Error, cur.fetchall)
# cur.execute('select name from %sbooze' % self.table_prefix)
# rows = cur.fetchall()
# self.assertTrue(cur.rowcount in (-1, len(self.samples)))
# self.assertEqual(len(rows), len(self.samples),
# 'cursor.fetchall did not retrieve all rows'
# )
# rows = [r[0] for r in rows]
# rows.sort()
# for i in range(0, len(self.samples)):
# self.assertEqual(rows[i], self.samples[i],
# 'cursor.fetchall retrieved incorrect rows'
# )
# rows = cur.fetchall()
# self.assertEqual(
# len(rows), 0,
# 'cursor.fetchall should return an empty list if called '
# 'after the whole result set has been fetched'
# )
# self.assertTrue(cur.rowcount in (-1, len(self.samples)))
# self.executeDDL2(cur)
# cur.execute('select name from %sbarflys' % self.table_prefix)
# rows = cur.fetchall()
# self.assertTrue(cur.rowcount in (-1, 0))
# self.assertEqual(len(rows), 0,
# 'cursor.fetchall should return an empty list if '
# 'a select query returns no rows'
# )
# finally:
# con.close()
# def test_mixedfetch(self):
# con = self._connect()
# try:
# cur = con.cursor()
# self.executeDDL1(cur)
# for sql in self._populate():
# cur.execute(sql)
# cur.execute('select name from %sbooze' % self.table_prefix)
# rows1 = cur.fetchone()
# rows23 = cur.fetchmany(2)
# rows4 = cur.fetchone()
# rows56 = cur.fetchall()
# self.assertTrue(cur.rowcount in (-1, 6))
# self.assertEqual(len(rows23), 2,
# 'fetchmany returned incorrect number of rows'
# )
# self.assertEqual(len(rows56), 2,
# 'fetchall returned incorrect number of rows'
# )
# rows = [rows1[0]]
# rows.extend([rows23[0][0], rows23[1][0]])
# rows.append(rows4[0])
# rows.extend([rows56[0][0], rows56[1][0]])
# rows.sort()
# for i in range(0, len(self.samples)):
# self.assertEqual(rows[i], self.samples[i],
# 'incorrect data retrieved or inserted'
# )
# finally:
# con.close()
# def help_nextset_setUp(self, cur):
# ''' Should create a procedure called deleteme
# that returns two result sets, first the
# number of rows in booze then "name from booze"
# '''
# raise NotImplementedError('Helper not implemented')
# # sql="""
# # create procedure deleteme as
# # begin
# # select count(*) from booze
# # select name from booze
# # end
# # """
# # cur.execute(sql)
# def help_nextset_tearDown(self, cur):
# 'If cleaning up is needed after nextSetTest'
# raise NotImplementedError('Helper not implemented')
# # cur.execute("drop procedure deleteme")
# def test_arraysize(self):
# # Not much here - rest of the tests for this are in test_fetchmany
# con = self._connect()
# try:
# cur = con.cursor()
# self.assertTrue(hasattr(cur, 'arraysize'),
# 'cursor.arraysize must be defined'
# )
# finally:
# con.close()
# def test_setinputsizes(self):
# con = self._connect()
# try:
# cur = con.cursor()
# cur.setinputsizes((25,))
# self._paraminsert(cur) # Make sure cursor still works
# finally:
# con.close()
# def test_None(self):
# con = self._connect()
# try:
# cur = con.cursor()
# self.executeDDL1(cur)
# cur.execute('insert into %sbooze values (NULL)' % self.table_prefix)
# cur.execute('select name from %sbooze' % self.table_prefix)
# r = cur.fetchall()
# self.assertEqual(len(r), 1)
# self.assertEqual(len(r[0]), 1)
# self.assertEqual(r[0][0], None, 'NULL value not returned as None')
# finally:
# con.close()
# def test_Date(self):
# d1 = self.driver.Date(2002, 12, 25)
# d2 = self.driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)))
# # Can we assume this? API doesn't specify, but it seems implied
# # self.assertEqual(str(d1),str(d2))
# def test_Time(self):
# t1 = self.driver.Time(13, 45, 30)
# t2 = self.driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)))
# # Can we assume this? API doesn't specify, but it seems implied
# # self.assertEqual(str(t1),str(t2))
# def test_Timestamp(self):
# t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30)
# t2 = self.driver.TimestampFromTicks(
# time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0))
# )
# # Can we assume this? API doesn't specify, but it seems implied
# # self.assertEqual(str(t1),str(t2))
# def test_Binary(self):
# b = self.driver.Binary(b'Something')
# b = self.driver.Binary(b'')
# def test_STRING(self):
# self.assertTrue(hasattr(self.driver, 'STRING'),
# 'module.STRING must be defined'
# )
# def test_BINARY(self):
# self.assertTrue(hasattr(self.driver, 'BINARY'),
# 'module.BINARY must be defined.'
# )
# def test_NUMBER(self):
# self.assertTrue(hasattr(self.driver, 'NUMBER'),
# 'module.NUMBER must be defined.'
# )
# def test_DATETIME(self):
# self.assertTrue(hasattr(self.driver, 'DATETIME'),
# 'module.DATETIME must be defined.'
# )
# def test_ROWID(self):
# self.assertTrue(hasattr(self.driver, 'ROWID'),
# 'module.ROWID must be defined.'
# )
#f __name__ == '__main__':
# unittest.main()

View File

@ -0,0 +1,105 @@
#!/usr/bin/env python -O
# -*- coding: utf-8 -*-
import unittest
import mariadb
from test.base_test import create_connection
from test.conf_test import conf
class CursorTest(unittest.TestCase):
def setUp(self):
self.connection = create_connection()
def tearDown(self):
del self.connection
def test_ping(self):
new_conn = create_connection()
id = new_conn.connection_id;
self.connection.kill(id)
try:
new_conn.ping()
except mariadb.DatabaseError:
pass
del new_conn
new_conn = create_connection()
new_conn.auto_reconnect = True
id = new_conn.connection_id;
self.connection.kill(id)
new_conn.ping()
new_id = new_conn.connection_id
self.assertTrue(id != new_id)
del new_conn
def test_change_user(self):
cursor = self.connection.cursor()
cursor.execute("create or replace user foo@localhost")
cursor.execute("GRANT ALL on testp.* TO foo@localhost")
new_conn = create_connection()
default_conf = conf()
new_conn.change_user("foo", "", default_conf["database"])
self.assertEqual("foo", new_conn.user)
del new_conn
del cursor
def test_reconnect(self):
new_conn = create_connection()
conn1_id = new_conn.connection_id
self.connection.kill(conn1_id)
new_conn.reconnect()
conn2_id = new_conn.connection_id
self.assertFalse(conn1_id == conn2_id)
del new_conn
def test_reset(self):
cursor = self.connection.cursor()
cursor.execute("SELECT 1 UNION SELECT 2")
try:
self.connection.ping()
except mariadb.DatabaseError:
pass
self.connection.reset()
self.connection.ping()
del cursor
def test_warnings(self):
conn = self.connection
cursor = conn.cursor()
cursor.execute("SET session sql_mode=''")
cursor.execute("CREATE TEMPORARY TABLE test_warnings (a tinyint)")
cursor.execute("INSERT INTO test_warnings VALUES (300)")
self.assertEqual(conn.warnings, 1)
self.assertEqual(conn.warnings, cursor.warnings)
del cursor
def test_server_infos(self):
self.assertTrue(self.connection.server_info)
self.assertTrue(self.connection.server_version > 0);
def test_escape(self):
cursor = self.connection.cursor()
cursor.execute("CREATE TEMPORARY TABLE test_escape (a varchar(100))")
str = 'This is a \ and a "'
cmd = "INSERT INTO test_escape VALUES('%s')" % str
try:
cursor.execute(cmd)
except mariadb.DatabaseError:
pass
str = self.connection.escape_string(str)
cmd = "INSERT INTO test_escape VALUES('%s')" % str
cursor.execute(cmd)
del cursor
if __name__ == '__main__':
unittest.main()

View File

@ -1,110 +0,0 @@
import mariadb
import datetime
import unittest
import collections
import time
class CursorTest(unittest.TestCase):
def setUp(self):
self.connection= mariadb.connection(default_file='default.cnf')
def tearDown(self):
del self.connection
def test_ping(self):
new_conn= mariadb.connection(default_file='default.cnf')
id= new_conn.connection_id;
self.connection.kill(id)
try:
new_conn.ping()
except mariadb.DatabaseError:
pass
del new_conn
new_conn= mariadb.connection(default_file='default.cnf')
new_conn.auto_reconnect= True
id= new_conn.connection_id;
self.connection.kill(id)
new_conn.ping()
new_id= new_conn.connection_id
self.assertTrue(id != new_id)
del new_conn
def test_change_user(self):
cursor= self.connection.cursor()
cursor.execute("create or replace user foo@localhost")
cursor.execute("GRANT ALL on test.* TO 'foo'@'localhost'")
new_conn= mariadb.connection(default_file='default.cnf')
new_conn.change_user("foo", "", "test")
self.assertEqual("foo", new_conn.user)
del new_conn
cursor.execute("drop user foo@localhost")
del cursor
def test_db(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE schema test1")
self.assertEqual(self.connection.database, "test")
self.connection.database= "test1"
self.assertEqual(self.connection.database, "test1")
self.connection.database= "test"
self.assertEqual(self.connection.database, "test")
cursor.execute("USE test1")
self.assertEqual(self.connection.database, "test1")
cursor.execute("USE test")
cursor.execute("DROP SCHEMA test1")
del cursor
def test_reconnect(self):
new_conn= mariadb.connection(default_file='default.cnf')
conn1_id= new_conn.connection_id
self.connection.kill(conn1_id)
new_conn.reconnect()
conn2_id= new_conn.connection_id
self.assertFalse(conn1_id == conn2_id)
del new_conn
def test_reset(self):
cursor= self.connection.cursor()
cursor.execute("SELECT 1 UNION SELECT 2")
try:
self.connection.ping()
except mariadb.DatabaseError:
pass
self.connection.reset()
self.connection.ping()
del cursor
def test_warnings(self):
conn= self.connection
cursor= conn.cursor()
cursor.execute("SET session sql_mode=''")
cursor.execute("CREATE OR REPLACE TABLE t1 (a tinyint)")
cursor.execute("INSERT INTO t1 VALUES (300)")
self.assertEqual(conn.warnings,1)
self.assertEqual(conn.warnings, cursor.warnings)
del cursor
def test_server_infos(self):
self.assertTrue(self.connection.server_info)
self.assertTrue(self.connection.server_version > 0);
def test_escape(self):
cursor= self.connection.cursor()
cursor.execute("CREATE OR REPLACE TABLE t1 (a varchar(100))")
str= 'This is a \ and a "'
cmd= "INSERT INTO t1 VALUES('%s')" % str
try:
cursor.execute(cmd)
except mariadb.DatabaseError:
pass
str= self.connection.escape_string(str)
cmd= "INSERT INTO t1 VALUES('%s')" % str
cursor.execute(cmd)
del cursor