diff --git a/.gitignore b/.gitignore index c6127b3..aa29d58 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,4 @@ modules.order Module.symvers Mkfile.old dkms.conf +client.cnf \ No newline at end of file diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..fd0c5ad --- /dev/null +++ b/.travis.yml @@ -0,0 +1,69 @@ +sudo: true +language: python +dist: bionic + +services: +- docker + +addons: + hosts: + - mariadb.example.com + +before_install: +- chmod +x .travis/script.sh +- sudo apt-get install software-properties-common +- sudo apt-key adv --recv-keys --keyserver hkp://keyserver.ubuntu.com:80 0xF1656F24C74CD1D8 +- sudo add-apt-repository 'deb [arch=amd64,arm64,ppc64el] http://mirrors.accretive-networks.net/mariadb/repo/10.4/ubuntu bionic main' +- sudo apt-get remove --purge mysql* +- sudo apt update +- sudo apt-get install -f libmariadb3 libmariadb-dev libssl1.1 +- sudo apt-get install -f + +install: + - wget -qO- 'https://github.com/tianon/pgp-happy-eyeballs/raw/master/hack-my-builds.sh' | bash + # generate SSL certificates + - mkdir tmp + - chmod +x .travis/gen-ssl.sh + - chmod +x .travis/build/build.sh + - chmod +x .travis/build/docker-entrypoint.sh + - chmod 777 .travis/build/ + - .travis/gen-ssl.sh mariadb.example.com tmp + - export PROJ_PATH=`pwd` + - export SSLCERT=$PROJ_PATH/tmp + - export TEST_SSL_CA_FILE=$SSLCERT/server.crt + - export TEST_SSL_CLIENT_KEY_FILE=$SSLCERT/client.key + - export TEST_SSL_CLIENT_CERT_FILE=$SSLCERT/client.crt + - export TEST_SSL_CLIENT_KEYSTORE_FILE=$SSLCERT/client-keystore.p12 + +env: + global: + - TEST_PORT=3305 + - TEST_HOST=mariadb.example.com + + +matrix: + include: + - python: "2.7" + env: DB=mariadb:10.4 + - python: "3.6" + env: DB=mariadb:10.4 + - python: "3.8" + env: DB=mariadb:10.4 + - env: DB=mariadb:10.4 MAXSCALE_VERSION=2.2.9 TEST_PORT=4007 TEST_USER=bob TEXT_DATABASE=test2 SKIP_LEAK=1 + - env: DB=mariadb:5.5 + - env: DB=mariadb:10.0 + - env: DB=mariadb:10.1 + - env: DB=mariadb:10.2 + - env: DB=mariadb:10.3 + - env: DB=mysql:5.5 + - env: DB=mysql:5.6 + - env: DB=mysql:5.7 + +notifications: + email: false + +script: +- python setup.py build +- python setup.py install +- npm install nyc -g +- .travis/script.sh diff --git a/.travis/build/Dockerfile b/.travis/build/Dockerfile new file mode 100644 index 0000000..d5b0296 --- /dev/null +++ b/.travis/build/Dockerfile @@ -0,0 +1,99 @@ +# vim:set ft=dockerfile: +FROM ubuntu:bionic + +# add our user and group first to make sure their IDs get assigned consistently, regardless of whatever dependencies get added +RUN groupadd -r mysql && useradd -r -g mysql mysql + +# https://bugs.debian.org/830696 (apt uses gpgv by default in newer releases, rather than gpg) +RUN set -ex; \ + apt-get update; \ + if ! which gpg; then \ + apt-get install -y --no-install-recommends gnupg; \ + fi; \ +# Ubuntu includes "gnupg" (not "gnupg2", but still 2.x), but not dirmngr, and gnupg 2.x requires dirmngr +# so, if we're not running gnupg 1.x, explicitly install dirmngr too + if ! gpg --version | grep -q '^gpg (GnuPG) 1\.'; then \ + apt-get install -y --no-install-recommends dirmngr; \ + fi; \ + rm -rf /var/lib/apt/lists/* + +# add gosu for easy step-down from root +ENV GOSU_VERSION 1.10 +RUN set -ex; \ + \ + fetchDeps=' \ + ca-certificates \ + wget \ + '; \ + apt-get update; \ + apt-get install -y --no-install-recommends $fetchDeps; \ + rm -rf /var/lib/apt/lists/*; \ + \ + dpkgArch="$(dpkg --print-architecture | awk -F- '{ print $NF }')"; \ + wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-$dpkgArch"; \ + wget -O /usr/local/bin/gosu.asc "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-$dpkgArch.asc"; \ + \ +# verify the signature + export GNUPGHOME="$(mktemp -d)"; \ + gpg --batch --keyserver ha.pool.sks-keyservers.net --recv-keys B42F6819007F00F88E364FD4036A9C25BF357DD4; \ + gpg --batch --verify /usr/local/bin/gosu.asc /usr/local/bin/gosu; \ + command -v gpgconf > /dev/null && gpgconf --kill all || :; \ + rm -r "$GNUPGHOME" /usr/local/bin/gosu.asc; \ + \ + chmod +x /usr/local/bin/gosu; \ +# verify that the binary works + gosu nobody true; \ + \ + apt-get purge -y --auto-remove $fetchDeps + +RUN mkdir /docker-entrypoint-initdb.d + +# install "pwgen" for randomizing passwords +# install "apt-transport-https" for Percona's repo (switched to https-only) +RUN apt-get update && apt-get install -y --no-install-recommends \ + apt-transport-https ca-certificates \ + tzdata \ + pwgen \ + && rm -rf /var/lib/apt/lists/* + +RUN { \ + echo "mariadb-server-10.4" mysql-server/root_password password 'unused'; \ + echo "mariadb-server-10.4" mysql-server/root_password_again password 'unused'; \ + } | debconf-set-selections + +RUN apt-get update -y +RUN apt-get install -y software-properties-common wget +RUN apt-key adv --recv-keys --keyserver keyserver.ubuntu.com 0xcbcb082a1bb943db +RUN apt-key adv --recv-keys --keyserver ha.pool.sks-keyservers.net F1656F24C74CD1D8 +RUN echo 'deb http://yum.mariadb.org/galera/repo/deb bionic main' > /etc/apt/sources.list.d/galera-test-repo.list +RUN apt-get update -y + +RUN apt-get install -y curl libdbi-perl rsync socat galera3 libnuma1 libaio1 zlib1g-dev libreadline5 libjemalloc1 libsnappy1v5 libcrack2 + +COPY *.deb /root/ +RUN chmod 777 /root/* + +RUN dpkg --install /root/mysql-common* +RUN dpkg --install /root/mariadb-common* +RUN dpkg -R --unpack /root/ +RUN apt-get install -f -y + +RUN rm -rf /var/lib/apt/lists/* \ + && sed -ri 's/^user\s/#&/' /etc/mysql/my.cnf /etc/mysql/conf.d/* \ + && rm -rf /var/lib/mysql && mkdir -p /var/lib/mysql /var/run/mysqld \ + && chown -R mysql:mysql /var/lib/mysql /var/run/mysqld \ + && chmod 777 /var/run/mysqld \ + && find /etc/mysql/ -name '*.cnf' -print0 \ + | xargs -0 grep -lZE '^(bind-address|log)' \ + | xargs -rt -0 sed -Ei 's/^(bind-address|log)/#&/' \ + && echo '[mysqld]\nskip-host-cache\nskip-name-resolve' > /etc/mysql/conf.d/docker.cnf + +VOLUME /var/lib/mysql + +COPY docker-entrypoint.sh /usr/local/bin/ +RUN ln -s usr/local/bin/docker-entrypoint.sh / # backwards compat +ENTRYPOINT ["docker-entrypoint.sh"] + +EXPOSE 3306 +CMD ["mysqld"] + diff --git a/.travis/build/build.sh b/.travis/build/build.sh new file mode 100644 index 0000000..2975752 --- /dev/null +++ b/.travis/build/build.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +echo "**************************************************************************" +echo "* searching for last complete build" +echo "**************************************************************************" + +wget -q -o /dev/null index.html http://hasky.askmonty.org/archive/10.4/ +grep -o ">build-[0-9]*" index.html | grep -o "[0-9]*" | tac | while read -r line ; do + + curl -s --head http://hasky.askmonty.org/archive/10.4/build-$line/kvm-deb-bionic-amd64/md5sums.txt | head -n 1 | grep "HTTP/1.[01] [23].." > /dev/null + if [ $? = "0" ]; then + echo "**************************************************************************" + echo "* Processing $line" + echo "**************************************************************************" + wget -q -o /dev/null -O $line.html http://hasky.askmonty.org/archive/10.4/build-$line/kvm-deb-bionic-amd64/debs/binary/ + grep -o ">[^\"]*\.deb" $line.html | grep -o "[^>]*\.deb" | while read -r file ; do + if [[ "$file" =~ ^mariadb-plugin.* ]] ; + then + echo "skipped file: $file" + else + echo "download file: $file" + wget -q -o /dev/null -O .travis/build/$file http://hasky.askmonty.org/archive/10.4/build-$line/kvm-deb-bionic-amd64/debs/binary/$file + fi + done + + exit + else + echo "skip build $line" + fi +done + + + diff --git a/.travis/build/docker-entrypoint.sh b/.travis/build/docker-entrypoint.sh new file mode 100644 index 0000000..721cc7d --- /dev/null +++ b/.travis/build/docker-entrypoint.sh @@ -0,0 +1,196 @@ +#!/bin/bash +set -eo pipefail +shopt -s nullglob + +# if command starts with an option, prepend mysqld +if [ "${1:0:1}" = '-' ]; then + set -- mysqld "$@" +fi + +# skip setup if they want an option that stops mysqld +wantHelp= +for arg; do + case "$arg" in + -'?'|--help|--print-defaults|-V|--version) + wantHelp=1 + break + ;; + esac +done + +# usage: file_env VAR [DEFAULT] +# ie: file_env 'XYZ_DB_PASSWORD' 'example' +# (will allow for "$XYZ_DB_PASSWORD_FILE" to fill in the value of +# "$XYZ_DB_PASSWORD" from a file, especially for Docker's secrets feature) +file_env() { + local var="$1" + local fileVar="${var}_FILE" + local def="${2:-}" + if [ "${!var:-}" ] && [ "${!fileVar:-}" ]; then + echo >&2 "error: both $var and $fileVar are set (but are exclusive)" + exit 1 + fi + local val="$def" + if [ "${!var:-}" ]; then + val="${!var}" + elif [ "${!fileVar:-}" ]; then + val="$(< "${!fileVar}")" + fi + export "$var"="$val" + unset "$fileVar" +} + +_check_config() { + toRun=( "$@" --verbose --help --log-bin-index="$(mktemp -u)" ) + if ! errors="$("${toRun[@]}" 2>&1 >/dev/null)"; then + cat >&2 <<-EOM + ERROR: mysqld failed while attempting to check config + command was: "${toRun[*]}" + $errors + EOM + exit 1 + fi +} + +# Fetch value from server config +# We use mysqld --verbose --help instead of my_print_defaults because the +# latter only show values present in config files, and not server defaults +_get_config() { + local conf="$1"; shift + "$@" --verbose --help --log-bin-index="$(mktemp -u)" 2>/dev/null \ + | awk '$1 == "'"$conf"'" && /^[^ \t]/ { sub(/^[^ \t]+[ \t]+/, ""); print; exit }' + # match "datadir /some/path with/spaces in/it here" but not "--xyz=abc\n datadir (xyz)" +} + +# allow the container to be started with `--user` +if [ "$1" = 'mysqld' -a -z "$wantHelp" -a "$(id -u)" = '0' ]; then + _check_config "$@" + DATADIR="$(_get_config 'datadir' "$@")" + mkdir -p "$DATADIR" + find "$DATADIR" \! -user mysql -exec chown mysql '{}' + + exec gosu mysql "$BASH_SOURCE" "$@" +fi + +if [ "$1" = 'mysqld' -a -z "$wantHelp" ]; then + # still need to check config, container may have started with --user + _check_config "$@" + # Get config + DATADIR="$(_get_config 'datadir' "$@")" + + if [ ! -d "$DATADIR/mysql" ]; then + file_env 'MYSQL_ROOT_PASSWORD' + if [ -z "$MYSQL_ROOT_PASSWORD" -a -z "$MYSQL_ALLOW_EMPTY_PASSWORD" -a -z "$MYSQL_RANDOM_ROOT_PASSWORD" ]; then + echo >&2 'error: database is uninitialized and password option is not specified ' + echo >&2 ' You need to specify one of MYSQL_ROOT_PASSWORD, MYSQL_ALLOW_EMPTY_PASSWORD and MYSQL_RANDOM_ROOT_PASSWORD' + exit 1 + fi + + mkdir -p "$DATADIR" + + echo 'Initializing database' + installArgs=( --datadir="$DATADIR" --rpm ) + if { mysql_install_db --help || :; } | grep -q -- '--auth-root-authentication-method'; then + # beginning in 10.4.3, install_db uses "socket" which only allows system user root to connect, switch back to "normal" to allow mysql root without a password + # see https://github.com/MariaDB/server/commit/b9f3f06857ac6f9105dc65caae19782f09b47fb3 + # (this flag doesn't exist in 10.0 and below) + installArgs+=( --auth-root-authentication-method=normal ) + fi + # "Other options are passed to mysqld." (so we pass all "mysqld" arguments directly here) + mysql_install_db "${installArgs[@]}" "${@:2}" + echo 'Database initialized' + + SOCKET="$(_get_config 'socket' "$@")" + "$@" --skip-networking --socket="${SOCKET}" & + pid="$!" + + mysql=( mysql --protocol=socket -uroot -hlocalhost --socket="${SOCKET}" ) + + for i in {30..0}; do + if echo 'SELECT 1' | "${mysql[@]}" &> /dev/null; then + break + fi + echo 'MySQL init process in progress...' + sleep 1 + done + if [ "$i" = 0 ]; then + echo >&2 'MySQL init process failed.' + exit 1 + fi + + if [ -z "$MYSQL_INITDB_SKIP_TZINFO" ]; then + # sed is for https://bugs.mysql.com/bug.php?id=20545 + mysql_tzinfo_to_sql /usr/share/zoneinfo | sed 's/Local time zone must be set--see zic manual page/FCTY/' | "${mysql[@]}" mysql + fi + + if [ ! -z "$MYSQL_RANDOM_ROOT_PASSWORD" ]; then + export MYSQL_ROOT_PASSWORD="$(pwgen -1 32)" + echo "GENERATED ROOT PASSWORD: $MYSQL_ROOT_PASSWORD" + fi + + rootCreate= + # default root to listen for connections from anywhere + file_env 'MYSQL_ROOT_HOST' '%' + if [ ! -z "$MYSQL_ROOT_HOST" -a "$MYSQL_ROOT_HOST" != 'localhost' ]; then + # no, we don't care if read finds a terminating character in this heredoc + # https://unix.stackexchange.com/questions/265149/why-is-set-o-errexit-breaking-this-read-heredoc-expression/265151#265151 + read -r -d '' rootCreate <<-EOSQL || true + CREATE USER 'root'@'${MYSQL_ROOT_HOST}' IDENTIFIED BY '${MYSQL_ROOT_PASSWORD}' ; + GRANT ALL ON *.* TO 'root'@'${MYSQL_ROOT_HOST}' WITH GRANT OPTION ; + EOSQL + fi + + "${mysql[@]}" <<-EOSQL + -- What's done in this file shouldn't be replicated + -- or products like mysql-fabric won't work + SET @@SESSION.SQL_LOG_BIN=0; + DELETE FROM mysql.user WHERE user NOT IN ('mysql.sys', 'mysqlxsys', 'root') OR host NOT IN ('localhost') ; + SET PASSWORD FOR 'root'@'localhost'=PASSWORD('${MYSQL_ROOT_PASSWORD}') ; + GRANT ALL ON *.* TO 'root'@'localhost' WITH GRANT OPTION ; + ${rootCreate} + DROP DATABASE IF EXISTS test ; + FLUSH PRIVILEGES ; + EOSQL + + if [ ! -z "$MYSQL_ROOT_PASSWORD" ]; then + mysql+=( -p"${MYSQL_ROOT_PASSWORD}" ) + fi + + file_env 'MYSQL_DATABASE' + if [ "$MYSQL_DATABASE" ]; then + echo "CREATE DATABASE IF NOT EXISTS \`$MYSQL_DATABASE\` ;" | "${mysql[@]}" + mysql+=( "$MYSQL_DATABASE" ) + fi + + file_env 'MYSQL_USER' + file_env 'MYSQL_PASSWORD' + if [ "$MYSQL_USER" -a "$MYSQL_PASSWORD" ]; then + echo "CREATE USER '$MYSQL_USER'@'%' IDENTIFIED BY '$MYSQL_PASSWORD' ;" | "${mysql[@]}" + + if [ "$MYSQL_DATABASE" ]; then + echo "GRANT ALL ON \`$MYSQL_DATABASE\`.* TO '$MYSQL_USER'@'%' ;" | "${mysql[@]}" + fi + fi + + echo + for f in /docker-entrypoint-initdb.d/*; do + case "$f" in + *.sh) echo "$0: running $f"; . "$f" ;; + *.sql) echo "$0: running $f"; "${mysql[@]}" < "$f"; echo ;; + *.sql.gz) echo "$0: running $f"; gunzip -c "$f" | "${mysql[@]}"; echo ;; + *) echo "$0: ignoring $f" ;; + esac + echo + done + + if ! kill -s TERM "$pid" || ! wait "$pid"; then + echo >&2 'MySQL init process failed.' + exit 1 + fi + + echo + echo 'MySQL init process done. Ready for start up.' + echo + fi +fi + +exec "$@" \ No newline at end of file diff --git a/.travis/docker-compose.yml b/.travis/docker-compose.yml new file mode 100644 index 0000000..0792a73 --- /dev/null +++ b/.travis/docker-compose.yml @@ -0,0 +1,17 @@ +version: '2' +services: + db: + image: $DB + command: --innodb-log-file-size=400m --max-allowed-packet=40m --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --ssl-ca=/etc/sslcert/ca.crt --ssl-cert=/etc/sslcert/server.crt --ssl-key=/etc/sslcert/server.key --bind-address=0.0.0.0 $ADDITIONAL_CONF + ports: + - 3305:3306 + volumes: + - $SSLCERT:/etc/sslcert + - $ENTRYPOINT:/pam + environment: + MYSQL_DATABASE: testp + MYSQL_ALLOW_EMPTY_PASSWORD: 1 + MYSQL_ROOT_PASSWORD: + + + diff --git a/.travis/entrypoint/dbinit.sql b/.travis/entrypoint/dbinit.sql new file mode 100644 index 0000000..665bca2 --- /dev/null +++ b/.travis/entrypoint/dbinit.sql @@ -0,0 +1,11 @@ +CREATE USER 'bob'@'%'; +GRANT ALL ON *.* TO 'bob'@'%' with grant option; + +CREATE USER 'boby'@'%' identified by 'heyPassw0@rd'; +GRANT ALL ON *.* TO 'boby'@'%' with grant option; + +INSTALL PLUGIN pam SONAME 'auth_pam'; + +FLUSH PRIVILEGES; + +CREATE DATABASE test2; \ No newline at end of file diff --git a/.travis/entrypoint/pam.sh b/.travis/entrypoint/pam.sh new file mode 100644 index 0000000..4a879a0 --- /dev/null +++ b/.travis/entrypoint/pam.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +tee /etc/pam.d/mariadb << EOF +auth required pam_unix.so audit +auth required pam_unix.so audit +account required pam_unix.so audit +EOF + +useradd testPam +chpasswd << EOF +testPam:myPwd +EOF + +usermod -a -G shadow mysql + +echo "pam configuration done" \ No newline at end of file diff --git a/.travis/gen-ssl.sh b/.travis/gen-ssl.sh new file mode 100644 index 0000000..01c8fce --- /dev/null +++ b/.travis/gen-ssl.sh @@ -0,0 +1,128 @@ +#!/bin/bash +set -e + +log () { + echo "$@" 1>&2 +} + +print_error () { + echo "$@" 1>&2 + exit 1 +} + +print_usage () { + print_error "Usage: gen-ssl-cert-key " +} + +gen_cert_subject () { + local fqdn="$1" + [[ "${fqdn}" != "" ]] || print_error "FQDN cannot be blank" + echo "/C=XX/ST=X/O=X/localityName=X/CN=${fqdn}/organizationalUnitName=X/emailAddress=X/" +} + +main () { + local fqdn="$1" + local sslDir="$2" + [[ "${fqdn}" != "" ]] || print_usage + [[ -d "${sslDir}" ]] || print_error "Directory does not exist: ${sslDir}" + + local caCertFile="${sslDir}/ca.crt" + local caKeyFile="${sslDir}/ca.key" + local certFile="${sslDir}/server.crt" + local keyFile="${sslDir}/server.key" + local csrFile="${sslDir}/csrFile.key" + local clientCertFile="${sslDir}/client.crt" + local clientKeyFile="${sslDir}/client.key" + local clientKeystoreFile="${sslDir}/client-keystore.p12" + local pcks12FullKeystoreFile="${sslDir}/fullclient-keystore.p12" + local clientReqFile="${sslDir}/clientReqFile.key" + + log "Generating CA key" + openssl genrsa -out "${caKeyFile}" 2048 + + log "Generating CA certificate" + openssl req \ + -sha1 \ + -new \ + -x509 \ + -nodes \ + -days 3650 \ + -subj "$(gen_cert_subject ca.example.com)" \ + -key "${caKeyFile}" \ + -out "${caCertFile}" + + log "Generating private key" + openssl genrsa -out "${keyFile}" 2048 + + log "Generating certificate signing request" + openssl req \ + -new \ + -batch \ + -sha1 \ + -subj "$(gen_cert_subject "$fqdn")" \ + -set_serial 01 \ + -key "${keyFile}" \ + -out "${csrFile}" \ + -nodes + + log "Generating X509 certificate" + openssl x509 \ + -req \ + -sha1 \ + -set_serial 01 \ + -CA "${caCertFile}" \ + -CAkey "${caKeyFile}" \ + -days 3650 \ + -in "${csrFile}" \ + -signkey "${keyFile}" \ + -out "${certFile}" + + log "Generating client certificate" + openssl req \ + -batch \ + -newkey rsa:2048 \ + -days 3600 \ + -subj "$(gen_cert_subject "$fqdn")" \ + -nodes \ + -keyout "${clientKeyFile}" \ + -out "${clientReqFile}" + + openssl x509 \ + -req \ + -in "${clientReqFile}" \ + -days 3600 \ + -CA "${caCertFile}" \ + -CAkey "${caKeyFile}" \ + -set_serial 01 \ + -out "${clientCertFile}" + + # Now generate a keystore with the client cert & key + log "Generating client keystore" + openssl pkcs12 \ + -export \ + -in "${clientCertFile}" \ + -inkey "${clientKeyFile}" \ + -out "${clientKeystoreFile}" \ + -name "mysqlAlias" \ + -passout pass:kspass + + # Now generate a full keystore with the client cert & key + trust certificates + log "Generating full client keystore" + openssl pkcs12 \ + -export \ + -in "${clientCertFile}" \ + -inkey "${clientKeyFile}" \ + -out "${pcks12FullKeystoreFile}" \ + -name "mysqlAlias" \ + -passout pass:kspass + + # Clean up CSR file: + rm "$csrFile" + rm "$clientReqFile" + + log "Generated key file and certificate in: ${sslDir}" + ls -l "${sslDir}" +} + +main "$@" + diff --git a/.travis/maxscale-compose.yml b/.travis/maxscale-compose.yml new file mode 100644 index 0000000..b5b5095 --- /dev/null +++ b/.travis/maxscale-compose.yml @@ -0,0 +1,25 @@ +version: '2.1' +services: + maxscale: + depends_on: + - db + ports: + - 4006:4006 + - 4007:4007 + - 4008:4008 + build: + context: . + dockerfile: maxscale/Dockerfile + args: + MAXSCALE_VERSION: $MAXSCALE_VERSION + db: + image: $DB + command: --max-connections=500 --max-allowed-packet=40m --innodb-log-file-size=400m --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --ssl-ca=/etc/sslcert/ca.crt --ssl-cert=/etc/sslcert/server.crt --ssl-key=/etc/sslcert/server.key --bind-address=0.0.0.0 + ports: + - 3305:3306 + volumes: + - $SSLCERT:/etc/sslcert + - $ENTRYPOINT:/docker-entrypoint-initdb.d + environment: + MYSQL_DATABASE: testp + MYSQL_ALLOW_EMPTY_PASSWORD: 1 diff --git a/.travis/maxscale/Dockerfile b/.travis/maxscale/Dockerfile new file mode 100644 index 0000000..61e4181 --- /dev/null +++ b/.travis/maxscale/Dockerfile @@ -0,0 +1,24 @@ +FROM centos:7 + +ARG MAXSCALE_VERSION +ENV MAXSCALE_VERSION ${MAXSCALE_VERSION:-2.2.9} + +COPY maxscale/mariadb.repo /etc/yum.repos.d/ + +RUN rpm --import https://yum.mariadb.org/RPM-GPG-KEY-MariaDB \ + && yum -y install https://downloads.mariadb.com/MaxScale/${MAXSCALE_VERSION}/centos/7/x86_64/maxscale-${MAXSCALE_VERSION}-1.centos.7.x86_64.rpm \ + && yum -y update + +RUN yum -y install maxscale-${MAXSCALE_VERSION} MariaDB-client \ + && yum clean all \ + && rm -rf /tmp/* + +COPY maxscale/docker-entrypoint.sh / +RUN chmod 777 /etc/maxscale.cnf +COPY maxscale/maxscale.cnf /etc/ +RUN chmod 777 /docker-entrypoint.sh + + +EXPOSE 4006 4007 4008 + +ENTRYPOINT ["/docker-entrypoint.sh"] \ No newline at end of file diff --git a/.travis/maxscale/docker-entrypoint.sh b/.travis/maxscale/docker-entrypoint.sh new file mode 100644 index 0000000..1f2d02c --- /dev/null +++ b/.travis/maxscale/docker-entrypoint.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +set -e + +echo 'creating configuration done' + +sleep 15 + +################################################################################################# +# wait for db availability for 60s +################################################################################################# +mysql=( mysql --protocol=tcp -ubob -hdb --port=3306 ) +for i in {60..0}; do + if echo 'use test2' | "${mysql[@]}" &> /dev/null; then + break + fi + echo 'DB init process in progress...' + sleep 1 +done + +echo 'use test2' | "${mysql[@]}" +if [ "$i" = 0 ]; then + echo 'DB init process failed.' + exit 1 +fi + +echo 'maxscale launching ...' + +tail -n 500 /etc/maxscale.cnf + +/usr/bin/maxscale --user=root --nodaemon + +cd /var/log/maxscale +ls -lrt +tail -n 500 /var/log/maxscale/maxscale.log diff --git a/.travis/maxscale/mariadb.repo b/.travis/maxscale/mariadb.repo new file mode 100644 index 0000000..d15c559 --- /dev/null +++ b/.travis/maxscale/mariadb.repo @@ -0,0 +1,7 @@ +# MariaDB 10.3 CentOS repository list - created 2018-11-09 14:50 UTC +# http://downloads.mariadb.org/mariadb/repositories/ +[mariadb] +name = MariaDB +baseurl = http://yum.mariadb.org/10.3/centos7-amd64 +gpgkey=https://yum.mariadb.org/RPM-GPG-KEY-MariaDB +gpgcheck=1 \ No newline at end of file diff --git a/.travis/maxscale/maxscale.cnf b/.travis/maxscale/maxscale.cnf new file mode 100644 index 0000000..59788bd --- /dev/null +++ b/.travis/maxscale/maxscale.cnf @@ -0,0 +1,125 @@ +# MaxScale documentation on GitHub: +# https://github.com/mariadb-corporation/MaxScale/blob/2.1/Documentation/Documentation-Contents.md + +# Global parameters +# +# Complete list of configuration options: +# https://github.com/mariadb-corporation/MaxScale/blob/2.1/Documentation/Getting-Started/Configuration-Guide.md + + +[maxscale] +threads=2 +log_messages=1 +log_trace=1 +log_debug=1 + +# Server definitions +# +# Set the address of the server to the network +# address of a MySQL server. +# + +[server1] +type=server +address=db +port=3306 +protocol=MariaDBBackend +authenticator_options=skip_authentication=true +router_options=master + +# Monitor for the servers +# +# This will keep MaxScale aware of the state of the servers. +# MySQL Monitor documentation: +# https://github.com/mariadb-corporation/MaxScale/blob/2.1/Documentation/Monitors/MySQL-Monitor.md + +[MySQLMonitor] +type=monitor +module=mariadbmon +servers=server1 +user=boby +passwd=heyPassw0@rd +monitor_interval=10000 + +# Service definitions +# +# Service Definition for a read-only service and +# a read/write splitting service. +# + +# ReadConnRoute documentation: +# https://github.com/mariadb-corporation/MaxScale/blob/2.1/Documentation/Routers/ReadConnRoute.md + +[Read-OnlyService] +enable_root_user=1 +version_string=10.4.99-MariaDB-maxScale +type=service +router=readconnroute +servers=server1 +user=boby +passwd=heyPassw0@rd +router_options=slave +localhost_match_wildcard_host=1 + +[Read-WriteService] +enable_root_user=1 +version_string=10.4.99-MariaDB-maxScale +type=service +router=readwritesplit +servers=server1 +user=boby +passwd=heyPassw0@rd +localhost_match_wildcard_host=1 + +[WriteService] +type=service +router=readconnroute +servers=server1 +user=boby +passwd=heyPassw0@rd +router_options=master +localhost_match_wildcard_host=1 +version_string=10.4.99-MariaDB-maxScale + + +# This service enables the use of the MaxAdmin interface +# MaxScale administration guide: +# https://github.com/mariadb-corporation/MaxScale/blob/2.1/Documentation/Reference/MaxAdmin.mda + +[MaxAdminService] +enable_root_user=1 +version_string=10.4.99-MariaDB-maxScale +type=service +router=cli + +# Listener definitions for the services +# +# These listeners represent the ports the +# services will listen on. +# +[WriteListener] +type=listener +service=WriteService +protocol=MariaDBClient +port=4007 +#socket=/var/lib/maxscale/writeconn.sock + +[Read-OnlyListener] +type=listener +service=Read-OnlyService +protocol=MariaDBClient +port=4008 +#socket=/var/lib/maxscale/readconn.sock + +[Read-WriteListener] +type=listener +service=Read-WriteService +protocol=MariaDBClient +port=4006 +#socket=/var/lib/maxscale/rwsplit.sock + +[MaxAdminListener] +type=listener +service=MaxAdminService +protocol=maxscaled +socket=/tmp/maxadmin.sock diff --git a/.travis/script.sh b/.travis/script.sh new file mode 100644 index 0000000..5f5a313 --- /dev/null +++ b/.travis/script.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +set -x +set -e + +################################################################################################################### +# test different type of configuration +################################################################################################################### +mysql=( mysql --protocol=tcp -ubob -h127.0.0.1 --port=3305 ) + +if [ "$DB" = "build" ] ; then + .travis/build/build.sh + docker build -t build:latest --label build .travis/build/ +fi + +export ENTRYPOINT=$PROJ_PATH/.travis/entrypoint +if [ -n "$MAXSCALE_VERSION" ] ; then + ################################################################################################################### + # launch Maxscale with one server + ################################################################################################################### + export COMPOSE_FILE=.travis/maxscale-compose.yml + export ENTRYPOINT=$PROJ_PATH/.travis/sql + docker-compose -f ${COMPOSE_FILE} build + docker-compose -f ${COMPOSE_FILE} up -d + mysql=( mysql --protocol=tcp -ubob -h127.0.0.1 --port=4007 ) +else + docker-compose -f .travis/docker-compose.yml up -d +fi + +for i in {60..0}; do + if echo 'SELECT 1' | "${mysql[@]}" &> /dev/null; then + break + fi + echo 'data server still not active' + sleep 1 +done + +if [ -z "$MAXSCALE_VERSION" ] ; then + docker-compose -f .travis/docker-compose.yml exec -u root db bash /pam/pam.sh + sleep 1 + docker-compose -f .travis/docker-compose.yml stop db + sleep 1 + docker-compose -f .travis/docker-compose.yml up -d + docker-compose -f .travis/docker-compose.yml logs db + + for i in {60..0}; do + if echo 'SELECT 1' | "${mysql[@]}" &> /dev/null; then + break + fi + echo 'data server still not active' + sleep 1 + done + +fi + +python -m unittest discover -v + diff --git a/.travis/sql/dbinit.sql b/.travis/sql/dbinit.sql new file mode 100644 index 0000000..8108417 --- /dev/null +++ b/.travis/sql/dbinit.sql @@ -0,0 +1,9 @@ +CREATE USER 'bob'@'%'; +GRANT ALL ON *.* TO 'bob'@'%' with grant option; + +CREATE USER 'boby'@'%' identified by 'heyPassw0@rd'; +GRANT ALL ON *.* TO 'boby'@'%' with grant option; + +FLUSH PRIVILEGES; + +CREATE DATABASE test2; \ No newline at end of file diff --git a/README.md b/README.md index d7446c6..c50d403 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,23 @@ -# mariadb-connector-python -MariaDB Connector/Python +

+ + + +

+ +# MariaDB python connector + +[![License (LGPL version 2.1)][licence-image]][licence-url] +[![Python 3.6][python-image]][python-url] + +This package contains a MariaDB client library for connecting to MariaDB and MySQL +database servers based on PEP-249. + +## Prerequisites + +* Python 3.6 (or newer). Older Python 3 versions might work, but aren't tested. +* MariaDB Connector/C, minimum required version is 3.1 + +[licence-image]:https://img.shields.io/badge/license-GNU%20LGPL%20version%202.1-green.svg?style=flat-square +[licence-url]:http://opensource.org/licenses/LGPL-2.1 +[python-image]:https://img.shields.io/badge/python-3.6-blue.svg +[python-url]:https://www.python.org/downloads/release/python-360/ \ No newline at end of file diff --git a/include/mariadb_python.h b/include/mariadb_python.h index a123bb2..55c8bb2 100644 --- a/include/mariadb_python.h +++ b/include/mariadb_python.h @@ -1,559 +1,587 @@ -/************************************************************************************ - Copyright (C) 2018 Georg Richter and MariaDB Corporation AB - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Library General Public - License as published by the Free Software Foundation; either - version 2 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Library General Public License for more details. - - You should have received a copy of the GNU Library General Public - License along with this library; if not see - or write to the Free Software Foundation, Inc., - 51 Franklin St., Fifth Floor, Boston, MA 02110, USA -*************************************************************************************/ -#include "Python.h" -#include "bytesobject.h" -#include "structmember.h" -#include "structseq.h" -#include -#include -#include -#include -#include -#include -#include - -#if defined(_WIN32) && defined(_MSVC) -#ifndef L64 -#define L64(x) x##i64 -#endif -#else -#ifndef L64 -#define L64(x) x##LL -#endif /* L64 */ -#endif /* _WIN32 */ - -#define MAX_TPC_XID_SIZE 65 - -/* Magic constant for checking dynamic columns */ -#define PYTHON_DYNCOL_VALUE 0xA378BD8E - -enum enum_dataapi_groups -{ - DBAPI_NUMBER= 1, - DBAPI_STRING, - DBAPI_DATETIME, - DBAPI_BINARY, - DBAPI_ROWID -}; - -enum enum_dyncol_type -{ - DYNCOL_LIST= 1, - DYNCOL_TUPLE, - DYNCOL_SET, - DYNCOL_DICT, - DYNCOL_ODICT, - DYNCOL_LAST -}; - -enum enum_tpc_state -{ - TPC_STATE_NONE= 0, - TPC_STATE_XID, - TPC_STATE_PREPARE -}; - -/* PEP-249: Connection object */ -typedef struct { - PyObject_HEAD - MYSQL *mysql; - int open; - uint8_t is_buffered; - uint8_t is_closed; - enum enum_tpc_state tpc_state; - char xid[MAX_TPC_XID_SIZE]; - PyObject *dsn; /* always null */ - PyObject *tls_cipher; - PyObject *tls_version; - PyObject *host; - PyObject *unix_socket; - int port; - PyObject *charset; - PyObject *collation; -} MrdbConnection; - -typedef struct { - enum enum_field_types type; - PyObject *Value; - char indicator; -} Mariadb_Value; - -/* Parameter info for cursor.executemany() - operations */ -typedef struct { - enum enum_field_types type; - size_t bits; /* for PyLong Object */ - PyTypeObject *ob_type; - uint8_t is_negative; - uint8_t has_indicator; -} MrdbParamInfo; - -typedef struct { - PyObject *value; - char indicator; - enum enum_field_types type; - size_t length; - uint8_t free_me; - void *buffer; - unsigned char num[8]; - MYSQL_TIME tm; -} MrdbParamValue; - -typedef struct { - PyObject_HEAD - enum enum_indicator_type indicator; -} MrdbIndicator; - -/* PEP-249: Cursor object */ -typedef struct { - PyObject_HEAD - MrdbConnection *connection; - MYSQL_STMT *stmt; - MYSQL_RES *result; - PyObject *data; - uint32_t array_size; - uint32_t param_count; - uint32_t row_array_size; /* for fetch many */ - MrdbParamInfo *paraminfo; - MrdbParamValue *value; - MYSQL_BIND *params; - MYSQL_BIND *bind; - MYSQL_FIELD *fields; - char *statement; - unsigned long statement_len; - PyObject **values; - PyStructSequence_Desc sequence_desc; - PyStructSequence_Field *sequence_fields; - PyTypeObject *sequence_type; - unsigned long prefetch_rows; - unsigned long cursor_type; - int64_t affected_rows; - int64_t row_count; - unsigned long row_number; - uint8_t is_prepared; - uint8_t is_buffered; - uint8_t is_named_tuple; - uint8_t is_closed; - uint8_t is_text; -} MrdbCursor; - -typedef struct -{ - PyObject_HEAD -} Mariadb_Fieldinfo; - -typedef struct -{ - PyObject_HEAD - int32_t *types; -} Mariadb_DBAPIType; - -typedef struct { - ps_field_fetch_func func; - int pack_len; - unsigned long max_len; -} Mariadb_Conversion; - - -/* Exceptions */ -PyObject *Mariadb_InterfaceError; -PyObject *Mariadb_Error; -PyObject *Mariadb_DatabaseError; -PyObject *Mariadb_DataError; -PyObject *Mariadb_OperationalError; -PyObject *Mariadb_IntegrityError; -PyObject *Mariadb_InternalError; -PyObject *Mariadb_ProgrammingError; -PyObject *Mariadb_NotSupportedError; -PyObject *Mariadb_Warning; - -PyObject *Mrdb_Pickle; - -/* Object types */ -PyTypeObject Mariadb_Fieldinfo_Type; -PyTypeObject MrdbIndicator_Type; -PyTypeObject MrdbConnection_Type; -PyTypeObject MrdbCursor_Type; -PyTypeObject Mariadb_DBAPIType_Type; - -int Mariadb_traverse(PyObject *self, - visitproc visit, - void *arg); - -/* Function prototypes */ -void mariadb_throw_exception(void *handle, - PyObject *execption_type, - unsigned char is_statement, - const char *message, - ...); - -PyObject *MrdbIndicator_Object(uint32_t type); -long MrdbIndicator_AsLong(PyObject *v); -PyObject *Mariadb_DBAPIType_Object(uint32_t type); -PyObject *MrdbConnection_affected_rows(MrdbConnection *self); -PyObject *MrdbConnection_ping(MrdbConnection *self, PyObject *args); -PyObject *MrdbConnection_kill(MrdbConnection *self, PyObject *args); -PyObject *MrdbConnection_reconnect(MrdbConnection *self); -PyObject *MrdbConnection_reset(MrdbConnection *self); -PyObject *MrdbConnection_autocommit(MrdbConnection *self, - PyObject *args); -PyObject *MrdbConnection_change_user(MrdbConnection *self, - PyObject *args); -PyObject *MrdbConnection_rollback(MrdbConnection *self); -PyObject *MrdbConnection_commit(MrdbConnection *self); -PyObject *MrdbConnection_close(MrdbConnection *self); -PyObject *MrdbConnection_connect( PyObject *self,PyObject *args, PyObject *kwargs); -void MrdbConnection_SetAttributes(MrdbConnection *self); - -/* TPC methods */ -PyObject *MrdbConnection_xid(MrdbConnection *self, PyObject *args); -PyObject *MrdbConnection_tpc_begin(MrdbConnection *self, PyObject *args); -PyObject *MrdbConnection_tpc_commit(MrdbConnection *self, PyObject *args); -PyObject *MrdbConnection_tpc_rollback(MrdbConnection *self, PyObject *args); -PyObject *MrdbConnection_tpc_prepare(MrdbConnection *self); -PyObject *MrdbConnection_tpc_recover(MrdbConnection *self); - -/* codecs prototypes */ -uint8_t mariadb_check_bulk_parameters(MrdbCursor *self, - PyObject *data); -uint8_t mariadb_check_execute_parameters(MrdbCursor *self, - PyObject *data); -uint8_t mariadb_param_update(void *data, MYSQL_BIND *bind, uint32_t row_nr); -/* Global defines */ - - -#define MARIADB_PY_APILEVEL "2.0" -#define MARIADB_PY_PARAMSTYLE "qmark" -#define MARIADB_PY_THREADSAFETY 1 - - -/* Helper macros */ - -#define MrdbIndicator_Check(a)\ -(Py_TYPE((a)) == &MrdbIndicator_Type) - -#define MARIADB_FEATURE_SUPPORTED(mysql,version)\ -(mysql_get_server_version((mysql)) >= (version)) - -#define MARIADB_CHECK_CONNECTION(connection, ret)\ -if (!connection || !connection->mysql) {\ - mariadb_throw_exception(connection->mysql, Mariadb_Error, 0,\ - "Invalid connection or not connected");\ - return (ret);\ -} - -#define MARIADB_CHECK_TPC(connection)\ -if (connection->tpc_state == TPC_STATE_NONE) {\ - mariadb_throw_exception(connection->mysql, Mariadb_ProgrammingError, 0,\ - "Transaction not started");\ - return NULL;\ -} - -#define MARIADB_FREE_MEM(a)\ -if (a) {\ - PyMem_RawFree((a));\ - (a)= NULL;\ -} - -#define MARIADB_CHECK_STMT(cursor)\ -if (!cursor->stmt || !cursor->stmt->mysql || cursor->is_closed)\ -{\ - (cursor)->is_closed= 1;\ - mariadb_throw_exception(cursor->stmt, Mariadb_ProgrammingError, 1,\ - "Invalid cursor or not connected");\ -} - -/* MariaDB protocol macros */ -#define int1store(T,A) *((int8_t*) (T)) = (A) -#define uint1korr(A) (*(((uint8_t*)(A)))) -#if defined(__i386__) || defined(_WIN32) -#define sint2korr(A) (*((int16_t *) (A))) -#define sint3korr(A) ((int32_t) ((((unsigned char) (A)[2]) & 128) ? \ - (((uint32_t) 255L << 24) | \ - (((uint32_t) (unsigned char) (A)[2]) << 16) |\ - (((uint32_t) (unsigned char) (A)[1]) << 8) | \ - ((uint32_t) (unsigned char) (A)[0])) : \ - (((uint32_t) (unsigned char) (A)[2]) << 16) |\ - (((uint32_t) (unsigned char) (A)[1]) << 8) | \ - ((uint32_t) (unsigned char) (A)[0]))) -#define sint4korr(A) (*((long *) (A))) -#define uint2korr(A) (*((uint16_t *) (A))) -#if defined(HAVE_purify) && !defined(_WIN32) -#define uint3korr(A) (uint32_t) (((uint32_t) ((unsigned char) (A)[0])) +\ - (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ - (((uint32_t) ((unsigned char) (A)[2])) << 16)) -#else -/* - ATTENTION ! - - Please, note, uint3korr reads 4 bytes (not 3) ! - It means, that you have to provide enough allocated space ! -*/ -#define uint3korr(A) (long) (*((unsigned int *) (A)) & 0xFFFFFF) -#endif /* HAVE_purify && !_WIN32 */ -#define uint4korr(A) (*((uint32_t *) (A))) -#define uint5korr(A) ((unsigned long long)(((uint32_t) ((unsigned char) (A)[0])) +\ - (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ - (((uint32_t) ((unsigned char) (A)[2])) << 16) +\ - (((uint32_t) ((unsigned char) (A)[3])) << 24)) +\ - (((unsigned long long) ((unsigned char) (A)[4])) << 32)) -#define uint6korr(A) ((unsigned long long)(((uint32_t) ((unsigned char) (A)[0])) + \ - (((uint32_t) ((unsigned char) (A)[1])) << 8) + \ - (((uint32_t) ((unsigned char) (A)[2])) << 16) + \ - (((uint32_t) ((unsigned char) (A)[3])) << 24)) + \ - (((unsigned long long) ((unsigned char) (A)[4])) << 32) + \ - (((unsigned long long) ((unsigned char) (A)[5])) << 40)) -#define uint8_tkorr(A) (*((unsigned long long *) (A))) -#define sint8korr(A) (*((long long *) (A))) -#define int2store(T,A) *((uint16_t*) (T))= (uint16_t) (A) -#define int3store(T,A) do { *(T)= (unsigned char) ((A));\ - *(T+1)=(unsigned char) (((uint) (A) >> 8));\ - *(T+2)=(unsigned char) (((A) >> 16)); } while (0) -#define int4store(T,A) *((long *) (T))= (long) (A) -#define int5store(T,A) do { *(T)= (unsigned char)((A));\ - *((T)+1)=(unsigned char) (((A) >> 8));\ - *((T)+2)=(unsigned char) (((A) >> 16));\ - *((T)+3)=(unsigned char) (((A) >> 24)); \ - *((T)+4)=(unsigned char) (((A) >> 32)); } while(0) -#define int6store(T,A) do { *(T)= (unsigned char)((A)); \ - *((T)+1)=(unsigned char) (((A) >> 8)); \ - *((T)+2)=(unsigned char) (((A) >> 16)); \ - *((T)+3)=(unsigned char) (((A) >> 24)); \ - *((T)+4)=(unsigned char) (((A) >> 32)); \ - *((T)+5)=(unsigned char) (((A) >> 40)); } while(0) -#define int8store(T,A) *((unsigned long long *) (T))= (unsigned long long) (A) - -typedef union { - double v; - long m[2]; -} doubleget_union; -#define doubleget(V,M) \ -do { doubleget_union _tmp; \ - _tmp.m[0] = *((long*)(M)); \ - _tmp.m[1] = *(((long*) (M))+1); \ - (V) = _tmp.v; } while(0) -#define doublestore(T,V) do { *((long *) T) = ((doubleget_union *)&V)->m[0]; \ - *(((long *) T)+1) = ((doubleget_union *)&V)->m[1]; \ - } while (0) -#define float4get(V,M) do { *((float *) &(V)) = *((float*) (M)); } while(0) -#define float8get(V,M) doubleget((V),(M)) -#define float4store(V,M) memcpy((unsigned char*) V,(unsigned char*) (&M),sizeof(float)) -#define floatstore(T,V) memcpy((unsigned char*)(T), (unsigned char*)(&V),sizeof(float)) -#define floatget(V,M) memcpy((unsigned char*) &V,(unsigned char*) (M),sizeof(float)) -#define float8store(V,M) doublestore((V),(M)) -#else - -/* - We're here if it's not a IA-32 architecture (Win32 and UNIX IA-32 defines - were done before) -*/ -#define sint2korr(A) (int16_t) (((int16_t) ((unsigned char) (A)[0])) +\ - ((int16_t) ((int16_t) (A)[1]) << 8)) -#define sint3korr(A) ((int32_t) ((((unsigned char) (A)[2]) & 128) ? \ - (((uint32_t) 255L << 24) | \ - (((uint32_t) (unsigned char) (A)[2]) << 16) |\ - (((uint32_t) (unsigned char) (A)[1]) << 8) | \ - ((uint32_t) (unsigned char) (A)[0])) : \ - (((uint32_t) (unsigned char) (A)[2]) << 16) |\ - (((uint32_t) (unsigned char) (A)[1]) << 8) | \ - ((uint32_t) (unsigned char) (A)[0]))) -#define sint4korr(A) (int32_t) (((int32_t) ((unsigned char) (A)[0])) +\ - (((int32_t) ((unsigned char) (A)[1]) << 8)) +\ - (((int32_t) ((unsigned char) (A)[2]) << 16)) +\ - (((int32_t) ((int16_t) (A)[3]) << 24))) -#define sint8korr(A) (long long) uint8korr(A) -#define uint2korr(A) (uint16_t) (((uint16_t) ((unsigned char) (A)[0])) +\ - ((uint16_t) ((unsigned char) (A)[1]) << 8)) -#define uint3korr(A) (uint32_t) (((uint32_t) ((unsigned char) (A)[0])) +\ - (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ - (((uint32_t) ((unsigned char) (A)[2])) << 16)) -#define uint4korr(A) (uint32_t) (((uint32_t) ((unsigned char) (A)[0])) +\ - (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ - (((uint32_t) ((unsigned char) (A)[2])) << 16) +\ - (((uint32_t) ((unsigned char) (A)[3])) << 24)) -#define uint5korr(A) ((unsigned long long)(((uint32_t) ((unsigned char) (A)[0])) +\ - (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ - (((uint32_t) ((unsigned char) (A)[2])) << 16) +\ - (((uint32_t) ((unsigned char) (A)[3])) << 24)) +\ - (((unsigned long long) ((unsigned char) (A)[4])) << 32)) -#define uint6korr(A) ((unsigned long long)(((uint32_t) ((unsigned char) (A)[0])) + \ - (((uint32_t) ((unsigned char) (A)[1])) << 8) + \ - (((uint32_t) ((unsigned char) (A)[2])) << 16) + \ - (((uint32_t) ((unsigned char) (A)[3])) << 24)) + \ - (((unsigned long long) ((unsigned char) (A)[4])) << 32) + \ - (((unsigned long long) ((unsigned char) (A)[5])) << 40)) -#define uint8korr(A) ((unsigned long long)(((uint32_t) ((unsigned char) (A)[0])) +\ - (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ - (((uint32_t) ((unsigned char) (A)[2])) << 16) +\ - (((uint32_t) ((unsigned char) (A)[3])) << 24)) +\ - (((unsigned long long) (((uint32_t) ((unsigned char) (A)[4])) +\ - (((uint32_t) ((unsigned char) (A)[5])) << 8) +\ - (((uint32_t) ((unsigned char) (A)[6])) << 16) +\ - (((uint32_t) ((unsigned char) (A)[7])) << 24))) <<\ - 32)) -#define int2store(T,A) do { uint def_temp= (uint) (A) ;\ - *((unsigned char*) (T))= (unsigned char)(def_temp); \ - *((unsigned char*) (T)+1)=(unsigned char)((def_temp >> 8)); \ - } while(0) -#define int3store(T,A) do { /*lint -save -e734 */\ - *((unsigned char*)(T))=(unsigned char) ((A));\ - *((unsigned char*) (T)+1)=(unsigned char) (((A) >> 8));\ - *((unsigned char*)(T)+2)=(unsigned char) (((A) >> 16)); \ - /*lint -restore */} while(0) -#define int4store(T,A) do { *((char *)(T))=(char) ((A));\ - *(((char *)(T))+1)=(char) (((A) >> 8));\ - *(((char *)(T))+2)=(char) (((A) >> 16));\ - *(((char *)(T))+3)=(char) (((A) >> 24)); } while(0) -#define int5store(T,A) do { *((char *)(T))= (char)((A)); \ - *(((char *)(T))+1)= (char)(((A) >> 8)); \ - *(((char *)(T))+2)= (char)(((A) >> 16)); \ - *(((char *)(T))+3)= (char)(((A) >> 24)); \ - *(((char *)(T))+4)= (char)(((A) >> 32)); \ - } while(0) -#define int6store(T,A) do { *((char *)(T))= (char)((A)); \ - *(((char *)(T))+1)= (char)(((A) >> 8)); \ - *(((char *)(T))+2)= (char)(((A) >> 16)); \ - *(((char *)(T))+3)= (char)(((A) >> 24)); \ - *(((char *)(T))+4)= (char)(((A) >> 32)); \ - *(((char *)(T))+5)= (char)(((A) >> 40)); \ - } while(0) -#define int8store(T,A) do { uint def_temp= (uint) (A), def_temp2= (uint) ((A) >> 32); \ - int4store((T),def_temp); \ - int4store((T+4),def_temp2); } while(0) -#ifdef WORDS_BIGENDIAN -#define float4store(T,A) do { *(T)= ((unsigned char *) &A)[3];\ - *((T)+1)=(char) ((unsigned char *) &A)[2];\ - *((T)+2)=(char) ((unsigned char *) &A)[1];\ - *((T)+3)=(char) ((unsigned char *) &A)[0]; } while(0) - -#define float4get(V,M) do { float def_temp;\ - ((unsigned char*) &def_temp)[0]=(M)[3];\ - ((unsigned char*) &def_temp)[1]=(M)[2];\ - ((unsigned char*) &def_temp)[2]=(M)[1];\ - ((unsigned char*) &def_temp)[3]=(M)[0];\ - (V)=def_temp; } while(0) -#define float8store(T,V) do { *(T)= ((unsigned char *) &V)[7];\ - *((T)+1)=(char) ((unsigned char *) &V)[6];\ - *((T)+2)=(char) ((unsigned char *) &V)[5];\ - *((T)+3)=(char) ((unsigned char *) &V)[4];\ - *((T)+4)=(char) ((unsigned char *) &V)[3];\ - *((T)+5)=(char) ((unsigned char *) &V)[2];\ - *((T)+6)=(char) ((unsigned char *) &V)[1];\ - *((T)+7)=(char) ((unsigned char *) &V)[0]; } while(0) - -#define float8get(V,M) do { double def_temp;\ - ((unsigned char*) &def_temp)[0]=(M)[7];\ - ((unsigned char*) &def_temp)[1]=(M)[6];\ - ((unsigned char*) &def_temp)[2]=(M)[5];\ - ((unsigned char*) &def_temp)[3]=(M)[4];\ - ((unsigned char*) &def_temp)[4]=(M)[3];\ - ((unsigned char*) &def_temp)[5]=(M)[2];\ - ((unsigned char*) &def_temp)[6]=(M)[1];\ - ((unsigned char*) &def_temp)[7]=(M)[0];\ - (V) = def_temp; } while(0) -#else -#define float4get(V,M) memcpy(&V, (M), sizeof(float)) -#define float4store(V,M) memcpy(V, (&M), sizeof(float)) - -#if defined(__FLOAT_WORD_ORDER) && (__FLOAT_WORD_ORDER == __BIG_ENDIAN) -#define doublestore(T,V) do { *(((char*)T)+0)=(char) ((unsigned char *) &V)[4];\ - *(((char*)T)+1)=(char) ((unsigned char *) &V)[5];\ - *(((char*)T)+2)=(char) ((unsigned char *) &V)[6];\ - *(((char*)T)+3)=(char) ((unsigned char *) &V)[7];\ - *(((char*)T)+4)=(char) ((unsigned char *) &V)[0];\ - *(((char*)T)+5)=(char) ((unsigned char *) &V)[1];\ - *(((char*)T)+6)=(char) ((unsigned char *) &V)[2];\ - *(((char*)T)+7)=(char) ((unsigned char *) &V)[3]; }\ - while(0) -#define doubleget(V,M) do { double def_temp;\ - ((unsigned char*) &def_temp)[0]=(M)[4];\ - ((unsigned char*) &def_temp)[1]=(M)[5];\ - ((unsigned char*) &def_temp)[2]=(M)[6];\ - ((unsigned char*) &def_temp)[3]=(M)[7];\ - ((unsigned char*) &def_temp)[4]=(M)[0];\ - ((unsigned char*) &def_temp)[5]=(M)[1];\ - ((unsigned char*) &def_temp)[6]=(M)[2];\ - ((unsigned char*) &def_temp)[7]=(M)[3];\ - (V) = def_temp; } while(0) -#endif /* __FLOAT_WORD_ORDER */ - -#define float8get(V,M) doubleget((V),(M)) -#define float8store(V,M) doublestore((V),(M)) -#endif /* WORDS_BIGENDIAN */ - -#ifdef HAVE_BIGENDIAN - -#define ushortget(V,M) do { V = (uint16_t) (((uint16_t) ((unsigned char) (M)[1]))+\ - ((uint16_t) ((uint16_t) (M)[0]) << 8)); } while(0) -#define shortget(V,M) do { V = (short) (((short) ((unsigned char) (M)[1]))+\ - ((short) ((short) (M)[0]) << 8)); } while(0) -#define longget(V,M) do { int32 def_temp;\ - ((unsigned char*) &def_temp)[0]=(M)[0];\ - ((unsigned char*) &def_temp)[1]=(M)[1];\ - ((unsigned char*) &def_temp)[2]=(M)[2];\ - ((unsigned char*) &def_temp)[3]=(M)[3];\ - (V)=def_temp; } while(0) -#define ulongget(V,M) do { uint32 def_temp;\ - ((unsigned char*) &def_temp)[0]=(M)[0];\ - ((unsigned char*) &def_temp)[1]=(M)[1];\ - ((unsigned char*) &def_temp)[2]=(M)[2];\ - ((unsigned char*) &def_temp)[3]=(M)[3];\ - (V)=def_temp; } while(0) -#define shortstore(T,A) do { uint def_temp=(uint) (A) ;\ - *(((char*)T)+1)=(char)(def_temp); \ - *(((char*)T)+0)=(char)(def_temp >> 8); } while(0) -#define longstore(T,A) do { *(((char*)T)+3)=((A));\ - *(((char*)T)+2)=(((A) >> 8));\ - *(((char*)T)+1)=(((A) >> 16));\ - *(((char*)T)+0)=(((A) >> 24)); } while(0) - -#define floatget(V,M) memcpy(&V, (M), sizeof(float)) -#define floatstore(T,V) memcpy((T), (void*) (&V), sizeof(float)) -#define doubleget(V,M) memcpy(&V, (M), sizeof(double)) -#define doublestore(T,V) memcpy((T), (void *) &V, sizeof(double)) -#define longlongget(V,M) memcpy(&V, (M), sizeof(unsigned long long)) -#define longlongstore(T,V) memcpy((T), &V, sizeof(unsigned long long)) - -#else - -#define ushortget(V,M) do { V = uint2korr(M); } while(0) -#define shortget(V,M) do { V = sint2korr(M); } while(0) -#define longget(V,M) do { V = sint4korr(M); } while(0) -#define ulongget(V,M) do { V = uint4korr(M); } while(0) -#define shortstore(T,V) int2store(T,V) -#define longstore(T,V) int4store(T,V) -#ifndef floatstore -#define floatstore(T,V) memcpy((T), (void *) (&V), sizeof(float)) -#define floatget(V,M) memcpy(&V, (M), sizeof(float)) -#endif -#ifndef doubleget -#define doubleget(V,M) memcpy(&V, (M), sizeof(double)) -#define doublestore(T,V) memcpy((T), (void *) &V, sizeof(double)) -#endif /* doubleget */ -#define longlongget(V,M) memcpy(&V, (M), sizeof(unsigned long long)) -#define longlongstore(T,V) memcpy((T), &V, sizeof(unsigned long long)) - -#endif /* WORDS_BIGENDIAN */ - - -#endif /* __i386__ OR _WIN32 */ - -#ifdef _WIN32 -//#define alloca _malloca -#endif +/************************************************************************************ + Copyright (C) 2018 Georg Richter and MariaDB Corporation AB + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Library General Public + License as published by the Free Software Foundation; either + version 2 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Library General Public License for more details. + + You should have received a copy of the GNU Library General Public + License along with this library; if not see + or write to the Free Software Foundation, Inc., + 51 Franklin St., Fifth Floor, Boston, MA 02110, USA +*************************************************************************************/ +#define PY_SSIZE_T_CLEAN +#include "Python.h" +#include "bytesobject.h" +#include "structmember.h" +#include "structseq.h" +#include +#include +#include +#include +#include +#include +#include + +#if defined(_WIN32) && defined(_MSVC) +#ifndef L64 +#define L64(x) x##i64 +#endif +#else +#ifndef L64 +#define L64(x) x##LL +#endif /* L64 */ +#endif /* _WIN32 */ + +#define MAX_TPC_XID_SIZE 65 + +/* Magic constant for checking dynamic columns */ +#define PYTHON_DYNCOL_VALUE 0xA378BD8E + +enum enum_dataapi_groups +{ + DBAPI_NUMBER= 1, + DBAPI_STRING, + DBAPI_DATETIME, + DBAPI_BINARY, + DBAPI_ROWID +}; + +enum enum_dyncol_type +{ + DYNCOL_LIST= 1, + DYNCOL_TUPLE, + DYNCOL_SET, + DYNCOL_DICT, + DYNCOL_ODICT, + DYNCOL_LAST +}; + +enum enum_tpc_state +{ + TPC_STATE_NONE= 0, + TPC_STATE_XID, + TPC_STATE_PREPARE +}; + + +typedef struct st_lex_str { + char *str; + size_t length; +} MrdbString; + +typedef struct st_parser { + MrdbString statement; + uint8_t in_literal[3]; + uint8_t in_comment; + uint8_t in_values; + uint8_t is_insert; + uint8_t comment_eol; + uint32_t param_count; + uint32_t key_count; + char* value_ofs; + MrdbString *keys; +} Mrdb_Parser; + +/* PEP-249: Connection object */ +typedef struct { + PyObject_HEAD + MYSQL *mysql; + int open; + uint8_t is_buffered; + uint8_t is_closed; + enum enum_tpc_state tpc_state; + char xid[MAX_TPC_XID_SIZE]; + PyObject *dsn; /* always null */ + PyObject *tls_cipher; + PyObject *tls_version; + PyObject *host; + PyObject *unix_socket; + int port; + PyObject *charset; + PyObject *collation; +} MrdbConnection; + +typedef struct { + enum enum_field_types type; + PyObject *Value; + char indicator; +} Mariadb_Value; + +/* Parameter info for cursor.executemany() + operations */ +typedef struct { + enum enum_field_types type; + size_t bits; /* for PyLong Object */ + PyTypeObject *ob_type; + uint8_t is_negative; + uint8_t has_indicator; +} MrdbParamInfo; + +typedef struct { + PyObject *value; + char indicator; + enum enum_field_types type; + size_t length; + uint8_t free_me; + void *buffer; + unsigned char num[8]; + MYSQL_TIME tm; +} MrdbParamValue; + +typedef struct { + PyObject_HEAD + enum enum_indicator_type indicator; +} MrdbIndicator; + +/* PEP-249: Cursor object */ +typedef struct { + PyObject_HEAD + MrdbConnection *connection; + MYSQL_STMT *stmt; + MYSQL_RES *result; + PyObject *data; + uint32_t array_size; + uint32_t param_count; + uint32_t row_array_size; /* for fetch many */ + MrdbParamInfo *paraminfo; + MrdbParamValue *value; + MYSQL_BIND *params; + MYSQL_BIND *bind; + MYSQL_FIELD *fields; + char *statement; + unsigned long statement_len; + PyObject **values; + PyStructSequence_Desc sequence_desc; + PyStructSequence_Field *sequence_fields; + PyTypeObject *sequence_type; + unsigned long prefetch_rows; + unsigned long cursor_type; + int64_t affected_rows; + int64_t row_count; + unsigned long row_number; + uint8_t is_prepared; + uint8_t is_buffered; + uint8_t is_named_tuple; + uint8_t is_closed; + uint8_t is_text; + Mrdb_Parser *parser; +} MrdbCursor; + +typedef struct +{ + PyObject_HEAD +} Mariadb_Fieldinfo; + +typedef struct +{ + PyObject_HEAD + int32_t *types; +} Mariadb_DBAPIType; + +typedef struct { + ps_field_fetch_func func; + int pack_len; + unsigned long max_len; +} Mariadb_Conversion; + + +/* Exceptions */ +PyObject *Mariadb_InterfaceError; +PyObject *Mariadb_Error; +PyObject *Mariadb_DatabaseError; +PyObject *Mariadb_DataError; +PyObject *Mariadb_OperationalError; +PyObject *Mariadb_IntegrityError; +PyObject *Mariadb_InternalError; +PyObject *Mariadb_ProgrammingError; +PyObject *Mariadb_NotSupportedError; +PyObject *Mariadb_Warning; + +PyObject *Mrdb_Pickle; + +/* Object types */ +PyTypeObject Mariadb_Fieldinfo_Type; +PyTypeObject MrdbIndicator_Type; +PyTypeObject MrdbConnection_Type; +PyTypeObject MrdbCursor_Type; +PyTypeObject Mariadb_DBAPIType_Type; + +int Mariadb_traverse(PyObject *self, + visitproc visit, + void *arg); + +/* Function prototypes */ +void mariadb_throw_exception(void *handle, + PyObject *execption_type, + unsigned char is_statement, + const char *message, + ...); + +PyObject *MrdbIndicator_Object(uint32_t type); +long MrdbIndicator_AsLong(PyObject *v); +PyObject *Mariadb_DBAPIType_Object(uint32_t type); +PyObject *MrdbConnection_affected_rows(MrdbConnection *self); +PyObject *MrdbConnection_ping(MrdbConnection *self); +PyObject *MrdbConnection_kill(MrdbConnection *self, PyObject *args); +PyObject *MrdbConnection_reconnect(MrdbConnection *self); +PyObject *MrdbConnection_reset(MrdbConnection *self); +PyObject *MrdbConnection_autocommit(MrdbConnection *self, + PyObject *args); +PyObject *MrdbConnection_change_user(MrdbConnection *self, + PyObject *args); +PyObject *MrdbConnection_rollback(MrdbConnection *self); +PyObject *MrdbConnection_commit(MrdbConnection *self); +PyObject *MrdbConnection_close(MrdbConnection *self); +PyObject *MrdbConnection_connect( PyObject *self,PyObject *args, PyObject *kwargs); +void MrdbConnection_SetAttributes(MrdbConnection *self); + +/* TPC methods */ +PyObject *MrdbConnection_xid(MrdbConnection *self, PyObject *args); +PyObject *MrdbConnection_tpc_begin(MrdbConnection *self, PyObject *args); +PyObject *MrdbConnection_tpc_commit(MrdbConnection *self, PyObject *args); +PyObject *MrdbConnection_tpc_rollback(MrdbConnection *self, PyObject *args); +PyObject *MrdbConnection_tpc_prepare(MrdbConnection *self); +PyObject *MrdbConnection_tpc_recover(MrdbConnection *self); + +/* codecs prototypes */ +uint8_t mariadb_check_bulk_parameters(MrdbCursor *self, + PyObject *data); +uint8_t mariadb_check_execute_parameters(MrdbCursor *self, + PyObject *data); +uint8_t mariadb_param_update(void *data, MYSQL_BIND *bind, uint32_t row_nr); + +/* parser prototypes */ +Mrdb_Parser *Mrdb_Parser_init(const char *statement, size_t length); +void Mrdb_Parser_end(Mrdb_Parser *p); +void Mrdb_Parser_parse(Mrdb_Parser *p, uint8_t is_batch); + + +/* Global defines */ + + +#define MARIADB_PY_APILEVEL "2.0" +#define MARIADB_PY_PARAMSTYLE "qmark" +#define MARIADB_PY_THREADSAFETY 1 + + +/* Helper macros */ + +#define MrdbIndicator_Check(a)\ +(Py_TYPE((a)) == &MrdbIndicator_Type) + +#define MARIADB_FEATURE_SUPPORTED(mysql,version)\ +(mysql_get_server_version((mysql)) >= (version)) + +#define MARIADB_CHECK_CONNECTION(connection, ret)\ +if (!connection || !connection->mysql) {\ + mariadb_throw_exception(connection->mysql, Mariadb_Error, 0,\ + "Invalid connection or not connected");\ + return (ret);\ +} + +#define MARIADB_CHECK_TPC(connection)\ +if (connection->tpc_state == TPC_STATE_NONE) {\ + mariadb_throw_exception(connection->mysql, Mariadb_ProgrammingError, 0,\ + "Transaction not started");\ + return NULL;\ +} + +#define MARIADB_FREE_MEM(a)\ +if (a) {\ + PyMem_RawFree((a));\ + (a)= NULL;\ +} + +#define MARIADB_CHECK_STMT(cursor)\ +if (!cursor->stmt || !cursor->stmt->mysql || cursor->is_closed)\ +{\ + (cursor)->is_closed= 1;\ + mariadb_throw_exception(cursor->stmt, Mariadb_ProgrammingError, 1,\ + "Invalid cursor or not connected");\ +} + +/* MariaDB protocol macros */ +#define int1store(T,A) *((int8_t*) (T)) = (A) +#define uint1korr(A) (*(((uint8_t*)(A)))) +#if defined(__i386__) || defined(_WIN32) +#define sint2korr(A) (*((int16_t *) (A))) +#define sint3korr(A) ((int32_t) ((((unsigned char) (A)[2]) & 128) ? \ + (((uint32_t) 255L << 24) | \ + (((uint32_t) (unsigned char) (A)[2]) << 16) |\ + (((uint32_t) (unsigned char) (A)[1]) << 8) | \ + ((uint32_t) (unsigned char) (A)[0])) : \ + (((uint32_t) (unsigned char) (A)[2]) << 16) |\ + (((uint32_t) (unsigned char) (A)[1]) << 8) | \ + ((uint32_t) (unsigned char) (A)[0]))) +#define sint4korr(A) (*((long *) (A))) +#define uint2korr(A) (*((uint16_t *) (A))) +#if defined(HAVE_purify) && !defined(_WIN32) +#define uint3korr(A) (uint32_t) (((uint32_t) ((unsigned char) (A)[0])) +\ + (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ + (((uint32_t) ((unsigned char) (A)[2])) << 16)) +#else +/* + ATTENTION ! + + Please, note, uint3korr reads 4 bytes (not 3) ! + It means, that you have to provide enough allocated space ! +*/ +#define uint3korr(A) (long) (*((unsigned int *) (A)) & 0xFFFFFF) +#endif /* HAVE_purify && !_WIN32 */ +#define uint4korr(A) (*((uint32_t *) (A))) +#define uint5korr(A) ((unsigned long long)(((uint32_t) ((unsigned char) (A)[0])) +\ + (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ + (((uint32_t) ((unsigned char) (A)[2])) << 16) +\ + (((uint32_t) ((unsigned char) (A)[3])) << 24)) +\ + (((unsigned long long) ((unsigned char) (A)[4])) << 32)) +#define uint6korr(A) ((unsigned long long)(((uint32_t) ((unsigned char) (A)[0])) + \ + (((uint32_t) ((unsigned char) (A)[1])) << 8) + \ + (((uint32_t) ((unsigned char) (A)[2])) << 16) + \ + (((uint32_t) ((unsigned char) (A)[3])) << 24)) + \ + (((unsigned long long) ((unsigned char) (A)[4])) << 32) + \ + (((unsigned long long) ((unsigned char) (A)[5])) << 40)) +#define uint8_tkorr(A) (*((unsigned long long *) (A))) +#define sint8korr(A) (*((long long *) (A))) +#define int2store(T,A) *((uint16_t*) (T))= (uint16_t) (A) +#define int3store(T,A) do { *(T)= (unsigned char) ((A));\ + *(T+1)=(unsigned char) (((uint) (A) >> 8));\ + *(T+2)=(unsigned char) (((A) >> 16)); } while (0) +#define int4store(T,A) *((long *) (T))= (long) (A) +#define int5store(T,A) do { *(T)= (unsigned char)((A));\ + *((T)+1)=(unsigned char) (((A) >> 8));\ + *((T)+2)=(unsigned char) (((A) >> 16));\ + *((T)+3)=(unsigned char) (((A) >> 24)); \ + *((T)+4)=(unsigned char) (((A) >> 32)); } while(0) +#define int6store(T,A) do { *(T)= (unsigned char)((A)); \ + *((T)+1)=(unsigned char) (((A) >> 8)); \ + *((T)+2)=(unsigned char) (((A) >> 16)); \ + *((T)+3)=(unsigned char) (((A) >> 24)); \ + *((T)+4)=(unsigned char) (((A) >> 32)); \ + *((T)+5)=(unsigned char) (((A) >> 40)); } while(0) +#define int8store(T,A) *((unsigned long long *) (T))= (unsigned long long) (A) + +typedef union { + double v; + long m[2]; +} doubleget_union; +#define doubleget(V,M) \ +do { doubleget_union _tmp; \ + _tmp.m[0] = *((long*)(M)); \ + _tmp.m[1] = *(((long*) (M))+1); \ + (V) = _tmp.v; } while(0) +#define doublestore(T,V) do { *((long *) T) = ((doubleget_union *)&V)->m[0]; \ + *(((long *) T)+1) = ((doubleget_union *)&V)->m[1]; \ + } while (0) +#define float4get(V,M) do { *((float *) &(V)) = *((float*) (M)); } while(0) +#define float8get(V,M) doubleget((V),(M)) +#define float4store(V,M) memcpy((unsigned char*) V,(unsigned char*) (&M),sizeof(float)) +#define floatstore(T,V) memcpy((unsigned char*)(T), (unsigned char*)(&V),sizeof(float)) +#define floatget(V,M) memcpy((unsigned char*) &V,(unsigned char*) (M),sizeof(float)) +#define float8store(V,M) doublestore((V),(M)) +#else + +/* + We're here if it's not a IA-32 architecture (Win32 and UNIX IA-32 defines + were done before) +*/ +#define sint2korr(A) (int16_t) (((int16_t) ((unsigned char) (A)[0])) +\ + ((int16_t) ((int16_t) (A)[1]) << 8)) +#define sint3korr(A) ((int32_t) ((((unsigned char) (A)[2]) & 128) ? \ + (((uint32_t) 255L << 24) | \ + (((uint32_t) (unsigned char) (A)[2]) << 16) |\ + (((uint32_t) (unsigned char) (A)[1]) << 8) | \ + ((uint32_t) (unsigned char) (A)[0])) : \ + (((uint32_t) (unsigned char) (A)[2]) << 16) |\ + (((uint32_t) (unsigned char) (A)[1]) << 8) | \ + ((uint32_t) (unsigned char) (A)[0]))) +#define sint4korr(A) (int32_t) (((int32_t) ((unsigned char) (A)[0])) +\ + (((int32_t) ((unsigned char) (A)[1]) << 8)) +\ + (((int32_t) ((unsigned char) (A)[2]) << 16)) +\ + (((int32_t) ((int16_t) (A)[3]) << 24))) +#define sint8korr(A) (long long) uint8korr(A) +#define uint2korr(A) (uint16_t) (((uint16_t) ((unsigned char) (A)[0])) +\ + ((uint16_t) ((unsigned char) (A)[1]) << 8)) +#define uint3korr(A) (uint32_t) (((uint32_t) ((unsigned char) (A)[0])) +\ + (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ + (((uint32_t) ((unsigned char) (A)[2])) << 16)) +#define uint4korr(A) (uint32_t) (((uint32_t) ((unsigned char) (A)[0])) +\ + (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ + (((uint32_t) ((unsigned char) (A)[2])) << 16) +\ + (((uint32_t) ((unsigned char) (A)[3])) << 24)) +#define uint5korr(A) ((unsigned long long)(((uint32_t) ((unsigned char) (A)[0])) +\ + (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ + (((uint32_t) ((unsigned char) (A)[2])) << 16) +\ + (((uint32_t) ((unsigned char) (A)[3])) << 24)) +\ + (((unsigned long long) ((unsigned char) (A)[4])) << 32)) +#define uint6korr(A) ((unsigned long long)(((uint32_t) ((unsigned char) (A)[0])) + \ + (((uint32_t) ((unsigned char) (A)[1])) << 8) + \ + (((uint32_t) ((unsigned char) (A)[2])) << 16) + \ + (((uint32_t) ((unsigned char) (A)[3])) << 24)) + \ + (((unsigned long long) ((unsigned char) (A)[4])) << 32) + \ + (((unsigned long long) ((unsigned char) (A)[5])) << 40)) +#define uint8korr(A) ((unsigned long long)(((uint32_t) ((unsigned char) (A)[0])) +\ + (((uint32_t) ((unsigned char) (A)[1])) << 8) +\ + (((uint32_t) ((unsigned char) (A)[2])) << 16) +\ + (((uint32_t) ((unsigned char) (A)[3])) << 24)) +\ + (((unsigned long long) (((uint32_t) ((unsigned char) (A)[4])) +\ + (((uint32_t) ((unsigned char) (A)[5])) << 8) +\ + (((uint32_t) ((unsigned char) (A)[6])) << 16) +\ + (((uint32_t) ((unsigned char) (A)[7])) << 24))) <<\ + 32)) +#define int2store(T,A) do { uint def_temp= (uint) (A) ;\ + *((unsigned char*) (T))= (unsigned char)(def_temp); \ + *((unsigned char*) (T)+1)=(unsigned char)((def_temp >> 8)); \ + } while(0) +#define int3store(T,A) do { /*lint -save -e734 */\ + *((unsigned char*)(T))=(unsigned char) ((A));\ + *((unsigned char*) (T)+1)=(unsigned char) (((A) >> 8));\ + *((unsigned char*)(T)+2)=(unsigned char) (((A) >> 16)); \ + /*lint -restore */} while(0) +#define int4store(T,A) do { *((char *)(T))=(char) ((A));\ + *(((char *)(T))+1)=(char) (((A) >> 8));\ + *(((char *)(T))+2)=(char) (((A) >> 16));\ + *(((char *)(T))+3)=(char) (((A) >> 24)); } while(0) +#define int5store(T,A) do { *((char *)(T))= (char)((A)); \ + *(((char *)(T))+1)= (char)(((A) >> 8)); \ + *(((char *)(T))+2)= (char)(((A) >> 16)); \ + *(((char *)(T))+3)= (char)(((A) >> 24)); \ + *(((char *)(T))+4)= (char)(((A) >> 32)); \ + } while(0) +#define int6store(T,A) do { *((char *)(T))= (char)((A)); \ + *(((char *)(T))+1)= (char)(((A) >> 8)); \ + *(((char *)(T))+2)= (char)(((A) >> 16)); \ + *(((char *)(T))+3)= (char)(((A) >> 24)); \ + *(((char *)(T))+4)= (char)(((A) >> 32)); \ + *(((char *)(T))+5)= (char)(((A) >> 40)); \ + } while(0) +#define int8store(T,A) do { uint def_temp= (uint) (A), def_temp2= (uint) ((A) >> 32); \ + int4store((T),def_temp); \ + int4store((T+4),def_temp2); } while(0) +#ifdef WORDS_BIGENDIAN +#define float4store(T,A) do { *(T)= ((unsigned char *) &A)[3];\ + *((T)+1)=(char) ((unsigned char *) &A)[2];\ + *((T)+2)=(char) ((unsigned char *) &A)[1];\ + *((T)+3)=(char) ((unsigned char *) &A)[0]; } while(0) + +#define float4get(V,M) do { float def_temp;\ + ((unsigned char*) &def_temp)[0]=(M)[3];\ + ((unsigned char*) &def_temp)[1]=(M)[2];\ + ((unsigned char*) &def_temp)[2]=(M)[1];\ + ((unsigned char*) &def_temp)[3]=(M)[0];\ + (V)=def_temp; } while(0) +#define float8store(T,V) do { *(T)= ((unsigned char *) &V)[7];\ + *((T)+1)=(char) ((unsigned char *) &V)[6];\ + *((T)+2)=(char) ((unsigned char *) &V)[5];\ + *((T)+3)=(char) ((unsigned char *) &V)[4];\ + *((T)+4)=(char) ((unsigned char *) &V)[3];\ + *((T)+5)=(char) ((unsigned char *) &V)[2];\ + *((T)+6)=(char) ((unsigned char *) &V)[1];\ + *((T)+7)=(char) ((unsigned char *) &V)[0]; } while(0) + +#define float8get(V,M) do { double def_temp;\ + ((unsigned char*) &def_temp)[0]=(M)[7];\ + ((unsigned char*) &def_temp)[1]=(M)[6];\ + ((unsigned char*) &def_temp)[2]=(M)[5];\ + ((unsigned char*) &def_temp)[3]=(M)[4];\ + ((unsigned char*) &def_temp)[4]=(M)[3];\ + ((unsigned char*) &def_temp)[5]=(M)[2];\ + ((unsigned char*) &def_temp)[6]=(M)[1];\ + ((unsigned char*) &def_temp)[7]=(M)[0];\ + (V) = def_temp; } while(0) +#else +#define float4get(V,M) memcpy(&V, (M), sizeof(float)) +#define float4store(V,M) memcpy(V, (&M), sizeof(float)) + +#if defined(__FLOAT_WORD_ORDER) && (__FLOAT_WORD_ORDER == __BIG_ENDIAN) +#define doublestore(T,V) do { *(((char*)T)+0)=(char) ((unsigned char *) &V)[4];\ + *(((char*)T)+1)=(char) ((unsigned char *) &V)[5];\ + *(((char*)T)+2)=(char) ((unsigned char *) &V)[6];\ + *(((char*)T)+3)=(char) ((unsigned char *) &V)[7];\ + *(((char*)T)+4)=(char) ((unsigned char *) &V)[0];\ + *(((char*)T)+5)=(char) ((unsigned char *) &V)[1];\ + *(((char*)T)+6)=(char) ((unsigned char *) &V)[2];\ + *(((char*)T)+7)=(char) ((unsigned char *) &V)[3]; }\ + while(0) +#define doubleget(V,M) do { double def_temp;\ + ((unsigned char*) &def_temp)[0]=(M)[4];\ + ((unsigned char*) &def_temp)[1]=(M)[5];\ + ((unsigned char*) &def_temp)[2]=(M)[6];\ + ((unsigned char*) &def_temp)[3]=(M)[7];\ + ((unsigned char*) &def_temp)[4]=(M)[0];\ + ((unsigned char*) &def_temp)[5]=(M)[1];\ + ((unsigned char*) &def_temp)[6]=(M)[2];\ + ((unsigned char*) &def_temp)[7]=(M)[3];\ + (V) = def_temp; } while(0) +#endif /* __FLOAT_WORD_ORDER */ + +#define float8get(V,M) doubleget((V),(M)) +#define float8store(V,M) doublestore((V),(M)) +#endif /* WORDS_BIGENDIAN */ + +#ifdef HAVE_BIGENDIAN + +#define ushortget(V,M) do { V = (uint16_t) (((uint16_t) ((unsigned char) (M)[1]))+\ + ((uint16_t) ((uint16_t) (M)[0]) << 8)); } while(0) +#define shortget(V,M) do { V = (short) (((short) ((unsigned char) (M)[1]))+\ + ((short) ((short) (M)[0]) << 8)); } while(0) +#define longget(V,M) do { int32 def_temp;\ + ((unsigned char*) &def_temp)[0]=(M)[0];\ + ((unsigned char*) &def_temp)[1]=(M)[1];\ + ((unsigned char*) &def_temp)[2]=(M)[2];\ + ((unsigned char*) &def_temp)[3]=(M)[3];\ + (V)=def_temp; } while(0) +#define ulongget(V,M) do { uint32 def_temp;\ + ((unsigned char*) &def_temp)[0]=(M)[0];\ + ((unsigned char*) &def_temp)[1]=(M)[1];\ + ((unsigned char*) &def_temp)[2]=(M)[2];\ + ((unsigned char*) &def_temp)[3]=(M)[3];\ + (V)=def_temp; } while(0) +#define shortstore(T,A) do { uint def_temp=(uint) (A) ;\ + *(((char*)T)+1)=(char)(def_temp); \ + *(((char*)T)+0)=(char)(def_temp >> 8); } while(0) +#define longstore(T,A) do { *(((char*)T)+3)=((A));\ + *(((char*)T)+2)=(((A) >> 8));\ + *(((char*)T)+1)=(((A) >> 16));\ + *(((char*)T)+0)=(((A) >> 24)); } while(0) + +#define floatget(V,M) memcpy(&V, (M), sizeof(float)) +#define floatstore(T,V) memcpy((T), (void*) (&V), sizeof(float)) +#define doubleget(V,M) memcpy(&V, (M), sizeof(double)) +#define doublestore(T,V) memcpy((T), (void *) &V, sizeof(double)) +#define longlongget(V,M) memcpy(&V, (M), sizeof(unsigned long long)) +#define longlongstore(T,V) memcpy((T), &V, sizeof(unsigned long long)) + +#else + +#define ushortget(V,M) do { V = uint2korr(M); } while(0) +#define shortget(V,M) do { V = sint2korr(M); } while(0) +#define longget(V,M) do { V = sint4korr(M); } while(0) +#define ulongget(V,M) do { V = uint4korr(M); } while(0) +#define shortstore(T,V) int2store(T,V) +#define longstore(T,V) int4store(T,V) +#ifndef floatstore +#define floatstore(T,V) memcpy((T), (void *) (&V), sizeof(float)) +#define floatget(V,M) memcpy(&V, (M), sizeof(float)) +#endif +#ifndef doubleget +#define doubleget(V,M) memcpy(&V, (M), sizeof(double)) +#define doublestore(T,V) memcpy((T), (void *) &V, sizeof(double)) +#endif /* doubleget */ +#define longlongget(V,M) memcpy(&V, (M), sizeof(unsigned long long)) +#define longlongstore(T,V) memcpy((T), &V, sizeof(unsigned long long)) + +#endif /* WORDS_BIGENDIAN */ + + +#endif /* __i386__ OR _WIN32 */ + +#ifdef _WIN32 +//#define alloca _malloca +#endif diff --git a/mariadb_posix.py b/mariadb_posix.py index 8a94d14..71aa787 100644 --- a/mariadb_posix.py +++ b/mariadb_posix.py @@ -1,54 +1,58 @@ #!/usr/bin/env python -import sys import os -import string +import sys + class MariaDBConfiguration(): - lib_dirs= "" - libs= "" - version= "" - includes= "" + lib_dirs = "" + libs = "" + version = "" + includes = "" + def mariadb_config(config, option): - from os import popen - file= popen("%s --%s" % (config, option)) - data= file.read().strip().split() - rc= file.close() - if rc: - if rc/256: - data= [] - if rc/256 > 1: - raise EnvironmentError("mariadb_config not found.\nMake sure that MariaDB Connector/C is installed (e.g. on Debian or Ubuntu 'sudo apt-get install libmariadb-dev'\nIf mariadb_config is not installed in a default path, please set the environment variable MARIADB_CONFIG which points to the location of mariadb_config utility, e.g. MARIADB_CONFIG=/opt/mariadb/bin/mariadb_config") - return data + from os import popen + file = popen("%s --%s" % (config, option)) + data = file.read().strip().split() + rc = file.close() + if rc: + if rc / 256: + data = [] + if rc / 256 > 1: + raise EnvironmentError( + "mariadb_config not found.\nMake sure that MariaDB Connector/C is installed (e.g. on Debian or Ubuntu 'sudo apt-get install libmariadb-dev'\nIf mariadb_config is not installed in a default path, please set the environment variable MARIADB_CONFIG which points to the location of mariadb_config utility, e.g. MARIADB_CONFIG=/opt/mariadb/bin/mariadb_config") + return data def dequote(s): - if s[0] in "\"'" and s[0] == s[-1]: - s = s[1:-1] - return s + if s[0] in "\"'" and s[0] == s[-1]: + s = s[1:-1] + return s + def get_config(): - required_version="3.1.0" - no_env= 0 + required_version = "3.1.0" + no_env = 0 - try: - config_prg= os.environ["MARIADB_CONFIG"] - except KeyError: - config_prg= 'mariadb_config' + try: + config_prg = os.environ["MARIADB_CONFIG"] + except KeyError: + config_prg = 'mariadb_config' - cc_version= mariadb_config(config_prg, "cc_version") - if cc_version[0] < required_version: - print ('MariaDB Connector/Python requires MariaDB Connector/C >= %s, found version %s' % (required_version, cc_version[0])) - sys.exit(2) - cfg= MariaDBConfiguration() - cfg.version= cc_version[0] + cc_version = mariadb_config(config_prg, "cc_version") + if cc_version[0] < required_version: + print ('MariaDB Connector/Python requires MariaDB Connector/C >= %s, found version %s' % ( + required_version, cc_version[0])) + sys.exit(2) + cfg = MariaDBConfiguration() + cfg.version = cc_version[0] - libs= mariadb_config(config_prg, "libs") - cfg.lib_dirs = [ dequote(i[2:]) for i in libs if i.startswith("-L") ] - cfg.libs = [ dequote(i[2:]) for i in libs if i.startswith("-l") ] - includes= mariadb_config(config_prg, "include") - mariadb_includes = [ dequote(i[2:]) for i in includes if i.startswith("-I") ] - mariadb_includes.extend(["./include"]) - cfg.includes= mariadb_includes - return cfg + libs = mariadb_config(config_prg, "libs") + cfg.lib_dirs = [dequote(i[2:]) for i in libs if i.startswith("-L")] + cfg.libs = [dequote(i[2:]) for i in libs if i.startswith("-l")] + includes = mariadb_config(config_prg, "include") + mariadb_includes = [dequote(i[2:]) for i in includes if i.startswith("-I")] + mariadb_includes.extend(["./include"]) + cfg.includes = mariadb_includes + return cfg diff --git a/mariadb_windows.py b/mariadb_windows.py index 9d433fc..37417f4 100644 --- a/mariadb_windows.py +++ b/mariadb_windows.py @@ -1,40 +1,53 @@ -import sys import os -import string +import platform +import sys + from winreg import * + class MariaDBConfiguration(): - lib_dirs= "" - libs= "" - version= "" - includes= "" + lib_dirs = "" + libs = "" + version = "" + includes = "" + def get_config(): - required_version="3.1.0" + required_version = "3.1.0" - try: - config_prg= os.environ["MARIADB_CC_DIR"] - cc_version= ["",""] - cc_instdir= [config_prg, ""] - print("using environment configuration " + config_prg) - except KeyError: - Registry= ConnectRegistry(None, HKEY_LOCAL_MACHINE) - Key= OpenKey(Registry, "SOFTWARE\MariaDB Corporation\MariaDB Connector C 64-bit") - if Key: - cc_version= QueryValueEx(Key, "Version") - if cc_version[0] < required_version: - print("MariaDB Connector/Python requires MariaDB Connector/C >= %s (found version: %s") % (required_version, cc_version[0]) + try: + config_prg = os.environ["MARIADB_CC_INSTALL_DIR"] + cc_version = ["", ""] + cc_instdir = [config_prg, ""] + print("using environment configuration " + config_prg) + except KeyError: - sys.exit(2) - cc_instdir= QueryValueEx(Key, "InstallDir") - if cc_instdir is None: - print("Could not find InstallationDir of MariaDB Connector/C. Please make sure MariaDB Connector/C is installed or specify the InstallationDir of MariaDB Connector/C by setting the environment variable MARIADB_CC_INSTALL_DIR.") - sys.exit(3) + try: + local_reg = ConnectRegistry(None, HKEY_LOCAL_MACHINE) + if platform.architecture()[0] == '32bit': + connector_key = OpenKey(local_reg, + 'SOFTWARE\\MariaDB Corporation\\MariaDB Connector C') + else: + connector_key = OpenKey(local_reg, + 'SOFTWARE\\MariaDB Corporation\\MariaDB Connector C 64-bit', + access=KEY_READ | KEY_WOW64_64KEY) + cc_version = QueryValueEx(connector_key, "Version") + if cc_version[0] < required_version: + print( + "MariaDB Connector/Python requires MariaDB Connector/C >= %s (found version: %s") \ + % (required_version, cc_version[0]) + sys.exit(2) + cc_instdir = QueryValueEx(connector_key, "InstallDir") - - cfg= MariaDBConfiguration() - cfg.version= cc_version[0] - cfg.includes= [".\\include", cc_instdir[0] + "\\include", cc_instdir[0] + "\\include\\mysql"] - cfg.lib_dirs= [cc_instdir[0] + "\\lib"] - cfg.libs= ["mariadbclient", "ws2_32", "advapi32", "kernel32", "shlwapi", "crypt32"] - return cfg + except: + print("Could not find InstallationDir of MariaDB Connector/C. " + "Please make sure MariaDB Connector/C is installed or specify the InstallationDir of " + "MariaDB Connector/C by setting the environment variable MARIADB_CC_INSTALL_DIR.") + sys.exit(3) + + cfg = MariaDBConfiguration() + cfg.version = cc_version[0] + cfg.includes = [".\\include", cc_instdir[0] + "\\include", cc_instdir[0] + "\\include\\mysql"] + cfg.lib_dirs = [cc_instdir[0] + "\\lib"] + cfg.libs = ["mariadbclient", "ws2_32", "advapi32", "kernel32", "shlwapi", "crypt32"] + return cfg diff --git a/setup.py b/setup.py index 2013fbd..430623d 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,49 @@ -#!/usr/bin/env python - -import os -import sys -import subprocess -import string - -from distutils.core import setup, Extension - -if os.name == "posix": - from mariadb_posix import get_config -if os.name == "nt": - from mariadb_windows import get_config - -cfg= get_config() - -setup(name='mariadb', - version='0.9.1', - description='Python MariaDB extension', - author='Georg Richter', - license='LGPL 2.1', - url='http://www.mariadb.com', - ext_modules=[Extension('mariadb', ['src/mariadb.c', 'src/mariadb_connection.c', 'src/mariadb_exception.c', 'src/mariadb_cursor.c', 'src/mariadb_codecs.c', 'src/mariadb_field.c', 'src/mariadb_dbapitype.c', 'src/mariadb_indicator.c'], - include_dirs=cfg.includes, - library_dirs= cfg.lib_dirs, - libraries= cfg.libs - )], - ) +#!/usr/bin/env python + +import os + +from distutils.core import setup, Extension + +if os.name == "posix": + from mariadb_posix import get_config +if os.name == "nt": + from mariadb_windows import get_config + +cfg = get_config() + +setup(name='mariadb', + version='0.9.1', + classifiers = [ + 'Development Status :: 3 - Alpha', + 'Environment :: Console', + 'Environment :: MacOS X', + 'Environment :: Win32 (MS Windows)', + 'Environment :: Posix', + 'License :: OSI Approved :: GNU Lesser General Public License v2 or later (LGPLv2+)', + 'Programming Language :: C', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Operating System :: Microsoft :: Windows', + 'Operating System :: MacOS', + 'Operating System :: POSIX', + 'Intended Audience :: End Users/Desktop', + 'Intended Audience :: Developers', + 'Intended Audience :: System Administrators', + 'Topic :: Database' + ], + description='Python MariaDB extension', + author='Georg Richter', + license='LGPL 2.1', + url='https://www.github.com/MariaDB/mariadb-connector-python', + ext_modules=[Extension('mariadb', ['src/mariadb.c', 'src/mariadb_connection.c', + 'src/mariadb_exception.c', 'src/mariadb_cursor.c', + 'src/mariadb_codecs.c', 'src/mariadb_field.c', + 'src/mariadb_parser.c', + 'src/mariadb_dbapitype.c', 'src/mariadb_indicator.c'], + include_dirs=cfg.includes, + library_dirs=cfg.lib_dirs, + libraries=cfg.libs + )], + ) diff --git a/src/mariadb_codecs.c b/src/mariadb_codecs.c index d7e9ac2..a3a97c9 100644 --- a/src/mariadb_codecs.c +++ b/src/mariadb_codecs.c @@ -75,7 +75,7 @@ static PyObject *mariadb_get_pickled(unsigned char *data, size_t length) PyObject *obj= NULL; if (length < 3) return NULL; - if (*data == 0x80 && *(data +1) == 0x03 && *(data + length - 1) == 0x2E) + if (*data == 0x80 && *(data +1) <= 0x04 && *(data + length - 1) == 0x2E) { PyObject *byte= PyBytes_FromStringAndSize((char *)data, length); obj= PyObject_CallMethod(Mrdb_Pickle, "loads", "O", byte); @@ -169,7 +169,7 @@ void field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column) unsigned long utf8len; self->values[column]= PyUnicode_FromStringAndSize((const char *)data, (Py_ssize_t)length[column]); - utf8len= PyUnicode_GET_LENGTH(self->values[column]); + utf8len= (unsigned long)PyUnicode_GET_LENGTH(self->values[column]); if (utf8len > self->fields[column].max_length) self->fields[column].max_length= utf8len; break; @@ -237,7 +237,7 @@ void field_fetch_callback(void *data, unsigned int column, unsigned char **row) long long l= sint8korr(*row); self->values[column]= (self->fields[column].flags & UNSIGNED_FLAG) ? PyLong_FromUnsignedLongLong((unsigned long long)l) : - PyLong_FromLong(l); + PyLong_FromLongLong(l); *row+= 8; break; } @@ -368,7 +368,7 @@ void field_fetch_callback(void *data, unsigned int column, unsigned char **row) length= mysql_net_field_length(row); self->values[column]= PyUnicode_FromStringAndSize((const char *)*row, (Py_ssize_t)length); - utf8len= PyUnicode_GET_LENGTH(self->values[column]); + utf8len= (unsigned long)PyUnicode_GET_LENGTH(self->values[column]); if (utf8len > self->fields[column].max_length) self->fields[column].max_length= utf8len; *row+= length; @@ -509,7 +509,7 @@ static uint8_t mariadb_get_parameter(MrdbCursor *self, "MariaDB %s doesn't support indicator variables. Required version is 10.2.6 or newer", mysql_get_server_info(self->stmt->mysql)); return 1; } - param->indicator= MrdbIndicator_AsLong(column); + param->indicator= (char)MrdbIndicator_AsLong(column); param->value= NULL; /* you can't have both indicator and value */ } else if (column == Py_None) { @@ -559,7 +559,7 @@ static uint8_t mariadb_get_parameter_info(MrdbCursor *self, return 1; } param->buffer_type= pinfo.type; - bits= pinfo.bits; + bits= (uint32_t)pinfo.bits; } for (i=0; i < self->array_size; i++) @@ -580,7 +580,7 @@ static uint8_t mariadb_get_parameter_info(MrdbCursor *self, if (pinfo.type == MYSQL_TYPE_LONGLONG) { if (pinfo.bits > bits) - bits= pinfo.bits; + bits= (uint32_t)pinfo.bits; } @@ -631,7 +631,7 @@ uint8_t mariadb_check_bulk_parameters(MrdbCursor *self, { uint32_t i; - if (!(self->array_size= PyList_Size(data))) + if (!(self->array_size= (uint32_t)PyList_Size(data))) { mariadb_throw_exception(self->stmt, Mariadb_InterfaceError, 1, "Empty parameter list. At least one row must be specified"); @@ -649,7 +649,7 @@ uint8_t mariadb_check_bulk_parameters(MrdbCursor *self, } if (!self->param_count && !self->is_prepared) - self->param_count= PyTuple_Size(obj); + self->param_count= (uint32_t)PyTuple_Size(obj); if (!self->param_count || self->param_count != PyTuple_Size(obj)) { @@ -695,7 +695,7 @@ uint8_t mariadb_check_execute_parameters(MrdbCursor *self, { uint32_t i; if (!self->is_prepared) - self->param_count= PyTuple_Size(data); + self->param_count= (uint32_t)PyTuple_Size(data); if (!self->param_count) { diff --git a/src/mariadb_connection.c b/src/mariadb_connection.c index 4856213..9f4e913 100644 --- a/src/mariadb_connection.c +++ b/src/mariadb_connection.c @@ -265,7 +265,7 @@ MrdbConnection_Initialize(MrdbConnection *self, }; if (!PyArg_ParseTupleAndKeywords(args, dsnargs, - "|sssssisiiipissssssssssipi:connect", + "|sssssisiiipissssssssssipis:connect", dsn_keys, &dsn, &host, &user, &password, &schema, &port, &socket, &connect_timeout, &read_timeout, &write_timeout, @@ -798,7 +798,7 @@ end: /* }}} */ /* {{{ MrdbConnection_ping */ -PyObject *MrdbConnection_ping(MrdbConnection *self, PyObject *args) +PyObject *MrdbConnection_ping(MrdbConnection *self) { int rc; diff --git a/src/mariadb_cursor.c b/src/mariadb_cursor.c index 1af1d87..dde4ed9 100644 --- a/src/mariadb_cursor.c +++ b/src/mariadb_cursor.c @@ -1,1276 +1,1337 @@ -/************************************************************************************ - Copyright (C) 2018 Georg Richter and MariaDB Corporation AB - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Library General Public - License as published by the Free Software Foundation; either - version 2 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Library General Public License for more details. - - You should have received a copy of the GNU Library General Public - License along with this library; if not see - or write to the Free Software Foundation, Inc., - 51 Franklin St., Fifth Floor, Boston, MA 02110, USA -*************************************************************************************/ - -#include - -static void MrdbCursor_dealloc(MrdbCursor *self); -static PyObject *MrdbCursor_close(MrdbCursor *self); -static PyObject *MrdbCursor_execute(MrdbCursor *self, - PyObject *args, PyObject *kwargs); -static PyObject *MrdbCursor_nextset(MrdbCursor *self); -static PyObject *MrdbCursor_executemany(MrdbCursor *self, - PyObject *args); -static PyObject *MrdbCursor_description(MrdbCursor *self); -static PyObject *MrdbCursor_fetchall(MrdbCursor *self); -static PyObject *MrdbCursor_fetchone(MrdbCursor *self); -static PyObject *MrdbCursor_fetchmany(MrdbCursor *self, - PyObject *args, - PyObject *kwargs); -static PyObject *MrdbCursor_scroll(MrdbCursor *self, - PyObject *args, - PyObject *kwargs); -static PyObject *MrdbCursor_fieldcount(MrdbCursor *self); -void field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column); -void field_fetch_callback(void *data, unsigned int column, unsigned char **row); -static PyObject *mariadb_get_sequence_or_tuple(MrdbCursor *self); -static PyObject * MrdbCursor_iter(PyObject *self); -static PyObject * MrdbCursor_iternext(PyObject *self); - -/* todo: write more documentation, this is just a placeholder */ -static char mariadb_cursor_documentation[] = -"Returns a MariaDB cursor object"; - -#define CURSOR_SET_STATEMENT(a,s,l)\ -MARIADB_FREE_MEM((a)->statement);\ -(a)->statement= PyMem_RawMalloc((l)+ 1);\ -strncpy((a)->statement, (s), (l));\ -(a)->statement_len= (unsigned long)(l);\ -(a)->statement[(l)]= 0; - -#define CURSOR_FIELD_COUNT(a)\ -((a)->is_text ? mysql_field_count((a)->connection->mysql) : (a)->stmt ? mysql_stmt_field_count((a)->stmt) : 0) - -#define CURSOR_AFFECTED_ROWS(a)\ -((a)->is_text ? mysql_affected_rows((a)->connection->mysql) : (a)->stmt ? mysql_stmt_affected_rows((a)->stmt) : 0) - -#define CURSOR_INSERT_ID(a)\ -((a)->is_text ? mysql_insert_id((a)->connection->mysql) : (a)->stmt ? mysql_stmt_insert_id((a)->stmt) : 0) - -#define CURSOR_NUM_ROWS(a)\ -((a)->is_text ? mysql_num_rows((a)->result) : (a)->stmt ? mysql_stmt_num_rows((a)->stmt) : 0) - -#define MARIADB_SET_SEQUENCE_OR_TUPLE_ITEM(self, row, column)\ -if ((self)->is_named_tuple)\ - PyStructSequence_SET_ITEM((row), (column), (self)->values[(column)]);\ -else\ - PyTuple_SET_ITEM((row), (column), (self)->values[(column)]);\ - - -static char *mariadb_named_tuple_name= "Row"; -static char *mariadb_named_tuple_desc= "Named tupled row"; -static PyObject *Mariadb_no_operation(MrdbCursor *, - PyObject *); -static PyObject *Mariadb_row_count(MrdbCursor *self); -static PyObject *MrdbCursor_warnings(MrdbCursor *self); -static PyObject *MrdbCursor_getbuffered(MrdbCursor *self); -static int MrdbCursor_setbuffered(MrdbCursor *self, PyObject *arg); -static PyObject *MrdbCursor_lastrowid(MrdbCursor *self); -static PyObject *MrdbCursor_closed(MrdbCursor *self); - - -static PyGetSetDef MrdbCursor_sets[]= -{ - {"lastrowid", (getter)MrdbCursor_lastrowid, NULL, - "row id of the last modified (inserted) row"}, - {"description", (getter)MrdbCursor_description, NULL, - "This read-only attribute is a sequence of 8-item sequences. " - "Each of these sequences contains information describing one result column", - NULL}, - {"rowcount", (getter)Mariadb_row_count, NULL, "doc", NULL}, - {"warnings", (getter)MrdbCursor_warnings, NULL, - "Number of warnings which were produced from last execute() call", NULL}, - {"closed", (getter)MrdbCursor_closed, NULL, - "Indicates if the cursor is closed and can't be reused", NULL}, - {"buffered", (getter)MrdbCursor_getbuffered, (setter)MrdbCursor_setbuffered, - "When True all result sets are immediately transferred and the connection " - "between client and server is no longer blocked. Default value is False."}, - {NULL} -}; - -static PyMethodDef MrdbCursor_Methods[] = -{ - /* PEP-249 methods */ - {"close", (PyCFunction)MrdbCursor_close, - METH_NOARGS, - "Closes an open Cursor"}, - {"execute", (PyCFunction)MrdbCursor_execute, - METH_VARARGS | METH_KEYWORDS, - "Executes a SQL statement"}, - {"executemany", (PyCFunction)MrdbCursor_executemany, - METH_VARARGS, - "Executes a SQL statement by passing a list of values"}, - {"fetchall", (PyCFunction)MrdbCursor_fetchall, - METH_NOARGS, - "Fetches all rows of a result set"}, - {"fetchone", (PyCFunction)MrdbCursor_fetchone, - METH_NOARGS, - "Fetches the next row of a result set"}, - {"fetchmany", (PyCFunction)MrdbCursor_fetchmany, - METH_VARARGS | METH_KEYWORDS, - "Fetches multiple rows of a result set"}, - {"fieldcount", (PyCFunction)MrdbCursor_fieldcount, - METH_NOARGS, - "Returns number of columns in current result set"}, - {"nextset", (PyCFunction)MrdbCursor_nextset, - METH_NOARGS, - "Will make the cursor skip to the next available result set, discarding any remaining rows from the current set."}, - {"setinputsizes", (PyCFunction)Mariadb_no_operation, - METH_VARARGS, - "Required by PEP-249. Does nothing in MariaDB Connector/Python"}, - {"setoutputsize", (PyCFunction)Mariadb_no_operation, - METH_VARARGS, - "Required by PEP-249. Does nothing in MariaDB Connector/Python"}, - {"callproc", (PyCFunction)Mariadb_no_operation, - METH_VARARGS, - "Required by PEP-249. Does nothing in MariaDB Connector/Python, use the execute method with syntax 'CALL {procedurename}' instead"}, - {"next", (PyCFunction)MrdbCursor_fetchone, - METH_NOARGS, - "Return the next row from the currently executing SQL statement using the same semantics as .fetchone()."}, - {"scroll", (PyCFunction)MrdbCursor_scroll, - METH_VARARGS | METH_KEYWORDS, - "Scroll the cursor in the result set to a new position according to mode"}, - {NULL} /* always last */ -}; - -static struct PyMemberDef MrdbCursor_Members[] = -{ - {"connection", - T_OBJECT, - offsetof(MrdbCursor, connection), - READONLY, - "Reference to the connection object on which the cursor was created"}, - {"statement", - T_STRING, - offsetof(MrdbCursor, statement), - READONLY, - "The last executed statement"}, - {"buffered", - T_BYTE, - offsetof(MrdbCursor, is_buffered), - 0, - "Stores the entire result set in memory"}, - {"rownumber", - T_LONG, - offsetof(MrdbCursor, row_number), - READONLY, - "Current row number in result set"}, - {"arraysize", - T_LONG, - offsetof(MrdbCursor, row_array_size), - 0, - "the number of rows to fetch"}, - {NULL} -}; - -/* {{{ MrdbCursor_initialize - Cursor initialization - - Optional keywprds: - named_tuple (Boolean): return rows as named tuple instead of tuple - prefetch_size: Prefetch size for readonly cursors - cursor_type: Type of cursor: CURSOR_TYPE_READONLY or CURSOR_TYPE_NONE (default) - buffered: buffered or unbuffered result sets -*/ -static int MrdbCursor_initialize(MrdbCursor *self, PyObject *args, - PyObject *kwargs) -{ - char *key_words[]= {"", "named_tuple", "prefetch_size", "cursor_type", - "buffered", NULL}; - PyObject *connection; - uint8_t is_named_tuple= 0; - unsigned long cursor_type= 0, - prefetch_rows= 0; - uint8_t is_buffered= 0; - - if (!self) - return -1; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, - "O!|bkkb", key_words, &MrdbConnection_Type, &connection, - &is_named_tuple, &prefetch_rows, &cursor_type, &is_buffered)) - - if (cursor_type != CURSOR_TYPE_READ_ONLY && - cursor_type != CURSOR_TYPE_NO_CURSOR) - { - mariadb_throw_exception(NULL, Mariadb_DataError, 0, - "Invalid value %ld for cursor_type", cursor_type); - return -1; - } - - Py_INCREF(connection); - self->connection= (MrdbConnection *)connection; - self->is_buffered= self->connection->is_buffered; - - if (!(self->stmt= mysql_stmt_init(self->connection->mysql))) - { - mariadb_throw_exception(self->connection->mysql, NULL, 0, NULL); - return -1; - } - - self->cursor_type= cursor_type; - self->prefetch_rows= prefetch_rows; - self->is_named_tuple= is_named_tuple; - self->row_array_size= 1; - - if (self->cursor_type || self->prefetch_rows) - { - if (!(self->stmt = mysql_stmt_init(self->connection->mysql))) - { - mariadb_throw_exception(self->connection->mysql, Mariadb_OperationalError, 0, NULL); - return -1; - } - } - else - return 0; - - mysql_stmt_attr_set(self->stmt, STMT_ATTR_CURSOR_TYPE, &self->cursor_type); - mysql_stmt_attr_set(self->stmt, STMT_ATTR_PREFETCH_ROWS, &self->prefetch_rows); - return 0; -} -/* }}} */ - -static int MrdbCursor_traverse( - MrdbCursor *self, - visitproc visit, - void *arg) -{ - return 0; -} - -PyTypeObject MrdbCursor_Type = -{ - PyVarObject_HEAD_INIT(NULL, 0) - "mariadb.cursor", - sizeof(MrdbCursor), - 0, - (destructor)MrdbCursor_dealloc, /* tp_dealloc */ - 0, /*tp_print*/ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* PyAsyncMethods * */ - 0, /* tp_repr */ - - /* Method suites for standard classes */ - - 0, /* (PyNumberMethods *) tp_as_number */ - 0, /* (PySequenceMethods *) tp_as_sequence */ - 0, /* (PyMappingMethods *) tp_as_mapping */ - - /* More standard operations (here for binary compatibility) */ - - 0, /* (hashfunc) tp_hash */ - 0, /* (ternaryfunc) tp_call */ - 0, /* (reprfunc) tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - - /* Functions to access object as input/output buffer */ - 0, /* (PyBufferProcs *) tp_as_buffer */ - - /* (tp_flags) Flags to define presence of optional/expanded features */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE, - mariadb_cursor_documentation, /* tp_doc Documentation string */ - - /* call function for all accessible objects */ - (traverseproc)MrdbCursor_traverse,/* tp_traverse */ - - /* delete references to contained objects */ - 0, /* tp_clear */ - - /* rich comparisons */ - 0, /* (richcmpfunc) tp_richcompare */ - - /* weak reference enabler */ - 0, /* (long) tp_weaklistoffset */ - - /* Iterators */ - (getiterfunc)MrdbCursor_iter, - (iternextfunc)MrdbCursor_iternext, - - /* Attribute descriptor and subclassing stuff */ - (struct PyMethodDef *)MrdbCursor_Methods, /* tp_methods */ - (struct PyMemberDef *)MrdbCursor_Members, /* tp_members */ - MrdbCursor_sets, - 0, /* (struct _typeobject *) tp_base; */ - 0, /* (PyObject *) tp_dict */ - 0, /* (descrgetfunc) tp_descr_get */ - 0, /* (descrsetfunc) tp_descr_set */ - 0, /* (long) tp_dictoffset */ - (initproc)MrdbCursor_initialize, /* tp_init */ - PyType_GenericAlloc, //NULL, /* tp_alloc */ - PyType_GenericNew, //NULL, /* tp_new */ - NULL, /* tp_free Low-level free-memory routine */ - 0, /* (PyObject *) tp_bases */ - 0, /* (PyObject *) tp_mro method resolution order */ - 0, /* (PyObject *) tp_defined */ -}; - -/* {{{ Mariadb_no_operation - This function is a stub and just returns Py_None -*/ -static PyObject *Mariadb_no_operation(MrdbCursor *self, - PyObject *args) -{ - Py_INCREF(Py_None); - return Py_None; -} -/* }}} */ - -/* {{{ MrdbCursor_isprepared - If the same statement was executed before, we don't need to - reprepare it and can just execute it. -*/ -static uint8_t MrdbCursor_isprepared(MrdbCursor *self, - const char *statement, - size_t statement_len) -{ - if (self->statement) - { - if (self->statement_len == statement_len && - !memcmp(statement, self->statement, statement_len)) - { - enum mysql_stmt_state state; - mysql_stmt_attr_get(self->stmt, STMT_ATTR_STATE, &state); - if (state >= MYSQL_STMT_PREPARED) - return 1; - } - } - return 0; -} -/* }}} */ - -/* {{{ MrdbCursor_clear - Resets statement attributes and frees - associated memory -*/ -static -void MrdbCursor_clear(MrdbCursor *self) -{ - if (!self->is_text && self->stmt) { - uint32_t val= 0; - mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_USER_DATA, 0); - mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_PARAM, 0); - mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_RESULT, 0); - mysql_stmt_attr_set(self->stmt, STMT_ATTR_ARRAY_SIZE, &val); - mysql_stmt_attr_set(self->stmt, STMT_ATTR_PREBIND_PARAMS, &val); - } - - if (self->is_text) - { - if (self->result) - { - mysql_free_result(self->result); - self->result= 0; - self->is_text= 0; - } - /* clear also pending result sets */ - if (self->connection->mysql) - while (!mysql_next_result(self->connection->mysql)); - } - - MARIADB_FREE_MEM(self->sequence_fields); - self->fields= NULL; - self->row_count= 0; - self->affected_rows= 0; - self->param_count= 0; - MARIADB_FREE_MEM(self->values); - MARIADB_FREE_MEM(self->bind); - MARIADB_FREE_MEM(self->statement); - MARIADB_FREE_MEM(self->value); - MARIADB_FREE_MEM(self->params); - -} -/* }}} */ - -/* {{{ ma_cursor_close - closes the statement handle of current cursor. After call to - cursor_close the cursor can't be reused anymore -*/ -static -void ma_cursor_close(MrdbCursor *self) -{ - if (!self->is_text && self->stmt) - { - /* Todo: check if all the cursor stuff is deleted (when using prepared - statemnts this should be handled in mysql_stmt_close) */ - Py_BEGIN_ALLOW_THREADS - mysql_stmt_close(self->stmt); - Py_END_ALLOW_THREADS - self->stmt= NULL; - } - MrdbCursor_clear(self); - self->is_closed= 1; -} - -static -PyObject * MrdbCursor_close(MrdbCursor *self) -{ - ma_cursor_close(self); - self->is_closed= 1; - Py_INCREF(Py_None); - return Py_None; -} -/* }}} */ - -/*{{{ MrDBCursor_dealloc */ -void MrdbCursor_dealloc(MrdbCursor *self) -{ - ma_cursor_close(self); - Py_TYPE(self)->tp_free((PyObject*)self); -} -/* }}} */ - -static int Mrdb_GetFieldInfo(MrdbCursor *self) -{ - unsigned int field_count= CURSOR_FIELD_COUNT(self); - - self->row_number= 0; - - self->row_count= CURSOR_AFFECTED_ROWS(self); - - if (field_count) - { - if (self->is_text) - { - self->result= (self->is_buffered) ? mysql_store_result(self->connection->mysql) : - mysql_use_result(self->connection->mysql); - if (!self->result) - { - mariadb_throw_exception(self->connection->mysql, NULL, 0, NULL); - return 1; - } - } - else if (self->is_buffered) - { - if (mysql_stmt_store_result(self->stmt)) - { - mariadb_throw_exception(self->stmt, NULL, 1, NULL); - return 1; - } - } - - self->affected_rows= CURSOR_AFFECTED_ROWS(self); - - self->fields= (self->is_text) ? mysql_fetch_fields(self->result) : - mariadb_stmt_fetch_fields(self->stmt); - - if (self->is_named_tuple) { - int i; - if (!(self->sequence_fields= (PyStructSequence_Field *) - PyMem_RawCalloc(field_count + 1, - sizeof(PyStructSequence_Field)))) - return 1; - self->sequence_desc.name= mariadb_named_tuple_name; - self->sequence_desc.doc= mariadb_named_tuple_desc; - self->sequence_desc.fields= self->sequence_fields; - self->sequence_desc.n_in_sequence= field_count; - - - for (i=0; i < field_count; i++) - { - self->sequence_fields[i].name= self->fields[i].name; - } - self->sequence_type= PyMem_RawCalloc(1,sizeof(PyTypeObject)); - PyStructSequence_InitType(self->sequence_type, &self->sequence_desc); - } - } - return 0; -} - -static int MrdbCursor_InitResultSet(MrdbCursor *self) -{ - unsigned int field_count= CURSOR_FIELD_COUNT(self); - - MARIADB_FREE_MEM(self->sequence_fields); - MARIADB_FREE_MEM(self->values); - - if (self->result) - mysql_free_result(self->result); - - if (Mrdb_GetFieldInfo(self)) - return 1; - - if (!(self->values= (PyObject**)PyMem_RawCalloc(field_count, sizeof(PyObject *)))) - return 1; - if (!self->is_text) - mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_RESULT, field_fetch_callback); - return 0; -} - -/* {{{ MrdbCursor_execute - PEP-249 execute() method -*/ -static -PyObject *MrdbCursor_execute(MrdbCursor *self, - PyObject *args, - PyObject *kwargs) -{ - PyObject *Data= NULL; - const char *statement= NULL; - int statement_len= 0; - int rc= 0; - uint8_t is_buffered= 0; - static char *key_words[]= {"", "", "buffered", NULL}; - - MARIADB_CHECK_STMT(self); - if (PyErr_Occurred()) - return NULL; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, - "s#|O!$b", key_words, &statement, &statement_len, &PyTuple_Type, &Data, - &is_buffered)) - return NULL; - - /* defaukt was set to 0 before */ - self->is_buffered= is_buffered; - - /* if there are no parameters specified, we execute the statement in text protocol */ - if (!Data && !self->cursor_type) - { - /* in case statement was executed before, we need to clear, since we don't use - binary protocol */ - MrdbCursor_clear(self); - Py_BEGIN_ALLOW_THREADS; - rc= mysql_real_query(self->connection->mysql, statement, statement_len); - Py_END_ALLOW_THREADS; - if (rc) - { - mariadb_throw_exception(self->connection->mysql, NULL, 0, NULL); - goto error; - } - self->is_text= 1; - CURSOR_SET_STATEMENT(self, statement, statement_len); - } - else - { - self->is_text= 0; - if (!(self->is_prepared= MrdbCursor_isprepared(self, statement, statement_len))) - { - MrdbCursor_clear(self); - CURSOR_SET_STATEMENT(self, statement, statement_len); - } - - if (Data) - { - self->array_size= 0; - self->data= Data; - if (mariadb_check_execute_parameters(self, Data)) - goto error; - - self->data= Data; - - /* Load values */ - if (mariadb_param_update(self, self->params, 0)) - goto error; - } - if (!self->is_prepared) - { - mysql_stmt_attr_set(self->stmt, STMT_ATTR_PREBIND_PARAMS, &self->param_count); - mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_USER_DATA, (void *)self); - mysql_stmt_bind_param(self->stmt, self->params); - - Py_BEGIN_ALLOW_THREADS; - if (!MARIADB_FEATURE_SUPPORTED(self->stmt->mysql, 100206)) - { - rc= mysql_stmt_prepare(self->stmt, statement, statement_len); - if (!rc) - rc= mysql_stmt_execute(self->stmt); - } - else - rc= mariadb_stmt_execute_direct(self->stmt, statement, statement_len); - Py_END_ALLOW_THREADS; - - if (rc) - { - /* in case statement is not supported via binary protocol, we try - to run the statement with text protocol */ - if (mysql_stmt_errno(self->stmt) == ER_UNSUPPORTED_PS) - { - Py_BEGIN_ALLOW_THREADS; - self->is_text= 0; - rc= mysql_real_query(self->connection->mysql, statement, statement_len); - Py_END_ALLOW_THREADS; - - if (rc) - { - mariadb_throw_exception(self->stmt->mysql, NULL, 0, NULL); - goto error; - } - /* if we have a result set, we can't process it - so we will return - an error. (XA RECOVER is the only command which returns a result set - and can't be prepared) */ - if (mysql_field_count(self->stmt->mysql)) - { - MYSQL_RES *result; - - /* we need to clear the result first, otherwise the cursor remains - in usuable state (query out of order) */ - if ((result= mysql_store_result(self->stmt->mysql))) - mysql_free_result(result); - - mariadb_throw_exception(NULL, Mariadb_NotSupportedError, 0, "This command is not supported by MariaDB Connector/Python"); - goto error; - } - goto end; - } - /* throw exception from statement handle */ - mariadb_throw_exception(self->stmt, NULL, 1, NULL); - goto error; - } - } else { - /* We are already prepared, so just reexecute statement */ - mysql_stmt_bind_param(self->stmt, self->params); - Py_BEGIN_ALLOW_THREADS; - rc= mysql_stmt_execute(self->stmt); - Py_END_ALLOW_THREADS; - if (rc) - { - mariadb_throw_exception(self->stmt, NULL, 1, NULL); - goto error; - } - } - } - - if (MrdbCursor_InitResultSet(self)) - goto error; -end: - self->is_prepared= 1; - MARIADB_FREE_MEM(self->value); - Py_RETURN_NONE; -error: - MrdbCursor_clear(self); - return NULL; -} -/* }}} */ - -/* {{{ MrdbCursor_fieldcount() */ -PyObject *MrdbCursor_fieldcount(MrdbCursor *self) -{ - MARIADB_CHECK_STMT(self); - if (PyErr_Occurred()) - return NULL; - - return PyLong_FromLong((long)CURSOR_FIELD_COUNT(self)); -} -/* }}} */ - -/* {{{ MrdbCursor_description - PEP-249 description method() - - Please note that the returned tuple contains eight (instead of - seven items, since we need the field flag -*/ -static -PyObject *MrdbCursor_description(MrdbCursor *self) -{ - PyObject *obj= NULL; - unsigned int field_count= CURSOR_FIELD_COUNT(self); - - MARIADB_CHECK_STMT(self); - if (PyErr_Occurred()) - return NULL; - - - if (self->fields && field_count) - { - uint32_t i; - - if (!(obj= PyTuple_New(field_count))) - return NULL; - - for (i=0; i < field_count; i++) - { - uint32_t precision= 0; - uint32_t decimals= 0; - unsigned long display_length= self->fields[i].max_length; - long packed_len= mysql_ps_fetch_functions[self->fields[i].type].pack_len; - - if (self->fields[i].decimals) - { - if (self->fields[i].decimals < 31) - { - decimals= self->fields[i].decimals; - precision= self->fields[i].length; - display_length= precision + 1; - } - } - - PyObject *desc; - if (!(desc= Py_BuildValue("(sIIiIIOI)", - self->fields[i].name, - self->fields[i].type, - display_length, - packed_len >= 0 ? packed_len : -1, - precision, - decimals, - PyBool_FromLong(!IS_NOT_NULL(self->fields[i].flags)), - self->fields[i].flags))) - { - Py_XDECREF(obj); - mariadb_throw_exception(NULL, Mariadb_OperationalError, 0, - "Can't build descriptor record"); - return NULL; - } - PyTuple_SetItem(obj, i, desc); - } - Py_INCREF(obj); - return obj; - } - Py_INCREF(Py_None); - return Py_None; -} -/* }}} */ - -static int MrdbCursor_fetchinternal(MrdbCursor *self) -{ - unsigned int field_count= CURSOR_FIELD_COUNT(self); - MYSQL_ROW row; - int rc; - unsigned int i; - - if (!self->is_text) - { - rc= mysql_stmt_fetch(self->stmt); - if (rc == MYSQL_NO_DATA) - return 1; - return 0; - } - - if (!(row= mysql_fetch_row(self->result))) - return 1; - - for (i= 0; i < field_count; i++) - { - field_fetch_fromtext(self, row[i], i); - } - return 0; -} - -/* {{{ MrdbCursor_fetchone - PEP-249 fetchone() method -*/ -static -PyObject *MrdbCursor_fetchone(MrdbCursor *self) -{ - PyObject *row; - uint32_t i; - unsigned int field_count= CURSOR_FIELD_COUNT(self); - - MARIADB_CHECK_STMT(self); - if (PyErr_Occurred()) - return NULL; - - if (!field_count) - { - mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, - "Cursor doesn't have a result set"); - return NULL; - } - - if (MrdbCursor_fetchinternal(self)) - { - Py_INCREF(Py_None); - return Py_None; - } - - self->row_number++; - if (!(row= mariadb_get_sequence_or_tuple(self))) - return NULL; - for (i= 0; i < field_count; i++) - { - MARIADB_SET_SEQUENCE_OR_TUPLE_ITEM(self, row, i); - } - return row; -} -/* }}} */ - -/* {{{ MrdbCursor_scroll - PEP-249: (optional) scroll() method - - Parameter: value - mode=[relative(default),absolute] - - Todo: support for forward only cursor -*/ -static -PyObject *MrdbCursor_scroll(MrdbCursor *self, PyObject *args, - PyObject *kwargs) -{ - char *modestr= NULL; - PyObject *Pos; - long position= 0; - unsigned long long new_position= 0; - uint8_t mode= 0; /* default: relative */ - char *kw_list[]= {"", "mode", NULL}; - const char *scroll_modes[]= {"relative", "absolute", NULL}; - - - MARIADB_CHECK_STMT(self); - if (PyErr_Occurred()) - return NULL; - - if (!CURSOR_FIELD_COUNT(self)) - { - mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, - "Cursor doesn't have a result set"); - return NULL; - } - - if (!self->is_buffered) - { - mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, - "This method is available only for cursors with buffered result set " - "or a read only cursor type"); - return NULL; - } - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, - "O!|s", kw_list, &PyLong_Type, &Pos, &modestr)) - return NULL; - - if (!(position= PyLong_AsLong(Pos))) - { - mariadb_throw_exception(NULL, Mariadb_DataError, 0, - "Invalid position value 0"); - return NULL; - } - - if (modestr != NULL) - { - while (scroll_modes[mode]) { - if (!strcmp(scroll_modes[mode], modestr)) - break; - mode++; - }; - } else - mode= 0; - - if (!scroll_modes[mode]) { - mariadb_throw_exception(NULL, Mariadb_DataError, 0, - "Invalid mode '%s'", modestr); - return NULL; - } - - if (!mode) { - new_position= self->row_number + position; - if (new_position < 0 || new_position > CURSOR_NUM_ROWS(self)) - { - mariadb_throw_exception(NULL, Mariadb_DataError, 0, - "Position value is out of range"); - return NULL; - } - } else - new_position= position; /* absolute */ - - if (!self->is_text) - mysql_stmt_data_seek(self->stmt, new_position); - else - mysql_data_seek(self->result, new_position); - self->row_number= new_position; - Py_INCREF(Py_None); - return Py_None; -} -/*}}}*/ - -/* {{{ MrdbCursor_fetchmany - PEP-249 fetchmany() method - - Optional parameters: size -*/ -static -PyObject *MrdbCursor_fetchmany(MrdbCursor *self, PyObject *args, - PyObject *kwargs) -{ - PyObject *List= NULL; - uint32_t i; - unsigned long rows= 0; - static char *kw_list[]= {"size", NULL}; - unsigned int field_count= CURSOR_FIELD_COUNT(self); - - MARIADB_CHECK_STMT(self); - if (PyErr_Occurred()) - return NULL; - - if (!field_count) - { - mariadb_throw_exception(0, Mariadb_ProgrammingError, 0, - "Cursor doesn't have a result set"); - return NULL; - } - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, - "|l:fetchmany", kw_list, &rows)) - return NULL; - - if (!rows) - rows= self->row_array_size; - if (!(List= PyList_New(0))) - return NULL; - - /* if rows=0, return an empty list */ - if (!rows) - return List; - - for (i=0; i < rows; i++) - { - uint32_t j; - PyObject *Row; - if (MrdbCursor_fetchinternal(self)) - goto end; - self->affected_rows= CURSOR_NUM_ROWS(self); - if (!(Row= mariadb_get_sequence_or_tuple(self))) - return NULL; - for (j=0; j < field_count; j++) - MARIADB_SET_SEQUENCE_OR_TUPLE_ITEM(self, Row, j); - PyList_Append(List, Row); - } -end: - return List; -} - -static PyObject *mariadb_get_sequence_or_tuple(MrdbCursor *self) -{ - unsigned int field_count= CURSOR_FIELD_COUNT(self); - if (self->is_named_tuple) - return PyStructSequence_New(self->sequence_type); - else - return PyTuple_New(field_count); -} -/* }}} */ - -/* {{{ MrdbCursor_fetchall() - PEP-249 fetchall() method */ -static -PyObject *MrdbCursor_fetchall(MrdbCursor *self) -{ - PyObject *List; - unsigned int field_count= CURSOR_FIELD_COUNT(self); - MARIADB_CHECK_STMT(self); - if (PyErr_Occurred()) - return NULL; - - if (!field_count) - { - mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, - "Cursor doesn't have a result set"); - return NULL; - } - - if (!(List= PyList_New(0))) - return NULL; - - while (!MrdbCursor_fetchinternal(self)) - { - uint32_t j; - PyObject *Row; - - self->row_number++; - - if (!(Row= mariadb_get_sequence_or_tuple(self))) - return NULL; - - for (j=0; j < field_count; j++) - { - MARIADB_SET_SEQUENCE_OR_TUPLE_ITEM(self, Row, j) - } - PyList_Append(List, Row); - } - self->row_count= (self->is_text) ? mysql_num_rows(self->result) : - mysql_stmt_num_rows(self->stmt); - return List; -} -/* }}} */ - -/* {{{ MrdbCursor_executemany_fallback - bulk execution for server < 10.2.6 -*/ -static -uint8_t MrdbCursor_executemany_fallback(MrdbCursor *self, - const char *statement, - size_t len) -{ - uint32_t i; - - if (mysql_stmt_attr_set(self->stmt, STMT_ATTR_PREBIND_PARAMS, &self->param_count)) - goto error; - - self->row_count= 0; - - for (i=0; i < self->array_size; i++) - { - int rc= 0; - /* Load values */ - if (mariadb_param_update(self, self->params, i)) - return 1; - if (mysql_stmt_bind_param(self->stmt, self->params)) - goto error; - Py_BEGIN_ALLOW_THREADS; - if (i==0) - rc= mysql_stmt_prepare(self->stmt, statement, len); - if (!rc) - rc= mysql_stmt_execute(self->stmt); - Py_END_ALLOW_THREADS; - if (rc) - goto error; - self->row_count+= mysql_stmt_affected_rows(self->stmt); - } - return 0; -error: - mariadb_throw_exception(self->stmt, NULL, 1, NULL); - return 1; -} -/* }}} */ - -/* {{{ MrdbCursor_executemany - PEP-249 executemany() method - - Paramter: A List of one or more tuples - - Note: When conecting to a server < 10.2.6 this command will be emulated - by executing preparing and executing statement n times (where n is - the number of tuples in list) -*/ -PyObject *MrdbCursor_executemany(MrdbCursor *self, - PyObject *Args) -{ - const char *statement= NULL; - int statement_len= 0; - int rc; - - MARIADB_CHECK_STMT(self); - if (PyErr_Occurred()) - return NULL; - - self->data= NULL; - - if (!PyArg_ParseTuple(Args, "s#O!", &statement, &statement_len, - &PyList_Type, &self->data)) - return NULL; - - if (!self->data) - { - PyErr_SetString(PyExc_TypeError, "No data provided"); - return NULL; - } - - if (!(self->is_prepared= MrdbCursor_isprepared(self, statement, statement_len))) - { - MrdbCursor_clear(self); - CURSOR_SET_STATEMENT(self, statement, statement_len); - } - self->is_text= 0; - - if (mariadb_check_bulk_parameters(self, self->data)) - goto error; - - - /* If the server doesn't support bulk execution (< 10.2.6), - we need to call a fallback routine */ - if (!MARIADB_FEATURE_SUPPORTED(self->stmt->mysql, 100206)) - { - if (MrdbCursor_executemany_fallback(self, statement, statement_len)) - goto error; - goto end; - } - - mysql_stmt_attr_set(self->stmt, STMT_ATTR_ARRAY_SIZE, &self->array_size); - if (!self->is_prepared) - { - mysql_stmt_attr_set(self->stmt, STMT_ATTR_PREBIND_PARAMS, &self->param_count); - mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_USER_DATA, (void *)self); - mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_PARAM, mariadb_param_update); - - mysql_stmt_bind_param(self->stmt, self->params); - - Py_BEGIN_ALLOW_THREADS; - rc= mariadb_stmt_execute_direct(self->stmt, statement, statement_len); - Py_END_ALLOW_THREADS; - if (rc) - { - mariadb_throw_exception(self->stmt, NULL, 1, NULL); - goto error; - } - } else { - Py_BEGIN_ALLOW_THREADS; - rc= mysql_stmt_execute(self->stmt); - Py_END_ALLOW_THREADS; - if (rc) - { - mariadb_throw_exception(self->stmt, NULL, 1, NULL); - goto error; - } - } -end: - MARIADB_FREE_MEM(self->values); - Py_RETURN_NONE; -error: - MrdbCursor_clear(self); - return NULL; -} -/* }}} */ - -/* {{{ MrdbCursor_nextset - PEP-249: Optional nextset() method -*/ -PyObject *MrdbCursor_nextset(MrdbCursor *self) -{ - MARIADB_CHECK_STMT(self); - int rc; - if (PyErr_Occurred()) - return NULL; - - if (!CURSOR_FIELD_COUNT(self)) - { - mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, - "Cursor doesn't have a result set"); - return NULL; - } - - Py_BEGIN_ALLOW_THREADS; - if (!self->is_text) - rc= mysql_stmt_next_result(self->stmt); - else - { - if (self->result) - { - mysql_free_result(self->result); - self->result= NULL; - } - rc= mysql_next_result(self->connection->mysql); - } - Py_END_ALLOW_THREADS; - if (rc) - { - Py_INCREF(Py_None); - return Py_None; - } - if (CURSOR_FIELD_COUNT(self)) - { - if (MrdbCursor_InitResultSet(self)) - return NULL; - } - else - self->fields= 0; - Py_RETURN_TRUE; -} -/* }}} */ - -/* {{{ Mariadb_row_count - PEP-249: rowcount attribute -*/ -static PyObject *Mariadb_row_count(MrdbCursor *self) -{ - int64_t row_count= 0; - - MARIADB_CHECK_STMT(self); - if (PyErr_Occurred()) - return NULL; - - /* PEP-249 requires to return -1 if the cursor was not executed before */ - if (!self->statement) - return PyLong_FromLongLong(-1); - - if (CURSOR_FIELD_COUNT(self)) - row_count= CURSOR_NUM_ROWS(self); - else - { - row_count= CURSOR_AFFECTED_ROWS(self); - if (!row_count) - row_count= -1; - } - return PyLong_FromLongLong(row_count); -} -/* }}} */ - -static PyObject *MrdbCursor_warnings(MrdbCursor *self) -{ - MARIADB_CHECK_STMT(self); - - return PyLong_FromLong((long)mysql_stmt_warning_count(self->stmt)); -} - -/* {{{ MrdbCursor_getbuffered */ -static PyObject *MrdbCursor_getbuffered(MrdbCursor *self) -{ - if (self->is_buffered) - Py_RETURN_TRUE; - Py_RETURN_FALSE; -} -/* }}} */ - -/* {{{ MrdbCursor_setbuffered */ -static int MrdbCursor_setbuffered(MrdbCursor *self, PyObject *arg) -{ - if (!arg || Py_TYPE(arg) != &PyBool_Type) - { - PyErr_SetString(PyExc_TypeError, "Argument must be boolean"); - return -1; - } - - self->is_buffered= PyObject_IsTrue(arg); - return 0; -} -/* }}} */ - -/* {{{ MrdbCursor_lastrowid */ -static PyObject *MrdbCursor_lastrowid(MrdbCursor *self) -{ - MARIADB_CHECK_STMT(self); - return PyLong_FromUnsignedLongLong(CURSOR_INSERT_ID(self)); -} -/* }}} */ - -/* iterator protocol */ - -/* {{{ MrdbCursor_iter */ -static PyObject * -MrdbCursor_iter(PyObject *self) -{ - MARIADB_CHECK_STMT(((MrdbCursor *)self)); - Py_INCREF(self); - return self; -} -/* }}} */ - -/* {{{ MrdbCursor_iternext */ -static PyObject * -MrdbCursor_iternext(PyObject *self) -{ - PyObject *res; - - res= MrdbCursor_fetchone((MrdbCursor *)self); - - if (res && res == Py_None) - { - Py_DECREF(res); - res= NULL; - } - return res; -} -/* }}} */ - -/* {{{ MrdbCursor_closed */ -static PyObject *MrdbCursor_closed(MrdbCursor *self) -{ - if (self->is_closed || !self->stmt || self->stmt->mysql == NULL) - Py_RETURN_TRUE; - Py_RETURN_FALSE; -} -/* }}} */ - +/************************************************************************************ + Copyright (C) 2018 Georg Richter and MariaDB Corporation AB + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Library General Public + License as published by the Free Software Foundation; either + version 2 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Library General Public License for more details. + + You should have received a copy of the GNU Library General Public + License along with this library; if not see + or write to the Free Software Foundation, Inc., + 51 Franklin St., Fifth Floor, Boston, MA 02110, USA +*************************************************************************************/ + +#include + +static void MrdbCursor_dealloc(MrdbCursor *self); +static PyObject *MrdbCursor_close(MrdbCursor *self); +static PyObject *MrdbCursor_execute(MrdbCursor *self, + PyObject *args, PyObject *kwargs); +static PyObject *MrdbCursor_nextset(MrdbCursor *self); +static PyObject *MrdbCursor_executemany(MrdbCursor *self, + PyObject *args); +static PyObject *MrdbCursor_description(MrdbCursor *self); +static PyObject *MrdbCursor_fetchall(MrdbCursor *self); +static PyObject *MrdbCursor_fetchone(MrdbCursor *self); +static PyObject *MrdbCursor_fetchmany(MrdbCursor *self, + PyObject *args, + PyObject *kwargs); +static PyObject *MrdbCursor_scroll(MrdbCursor *self, + PyObject *args, + PyObject *kwargs); +static PyObject *MrdbCursor_fieldcount(MrdbCursor *self); +void field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column); +void field_fetch_callback(void *data, unsigned int column, unsigned char **row); +static PyObject *mariadb_get_sequence_or_tuple(MrdbCursor *self); +static PyObject * MrdbCursor_iter(PyObject *self); +static PyObject * MrdbCursor_iternext(PyObject *self); + +/* todo: write more documentation, this is just a placeholder */ +static char mariadb_cursor_documentation[] = +"Returns a MariaDB cursor object"; + +#define CURSOR_SET_STATEMENT(a,s,l)\ +MARIADB_FREE_MEM((a)->statement);\ +(a)->statement= PyMem_RawMalloc((l)+ 1);\ +strncpy((a)->statement, (s), (l));\ +(a)->statement_len= (unsigned long)(l);\ +(a)->statement[(l)]= 0; + +#define CURSOR_FIELD_COUNT(a)\ +((a)->is_text ? mysql_field_count((a)->connection->mysql) : (a)->stmt ? mysql_stmt_field_count((a)->stmt) : 0) + +#define CURSOR_WARNING_COUNT(a)\ +((a)->is_text ? mysql_warning_count((a)->connection->mysql) : (a)->stmt ? mysql_stmt_warning_count((a)->stmt) : 0) + +#define CURSOR_AFFECTED_ROWS(a)\ +((a)->is_text ? mysql_affected_rows((a)->connection->mysql) : (a)->stmt ? mysql_stmt_affected_rows((a)->stmt) : 0) + +#define CURSOR_INSERT_ID(a)\ +((a)->is_text ? mysql_insert_id((a)->connection->mysql) : (a)->stmt ? mysql_stmt_insert_id((a)->stmt) : 0) + +#define CURSOR_NUM_ROWS(a)\ +((a)->is_text ? mysql_num_rows((a)->result) : (a)->stmt ? mysql_stmt_num_rows((a)->stmt) : 0) + +#define MARIADB_SET_SEQUENCE_OR_TUPLE_ITEM(self, row, column)\ +if ((self)->is_named_tuple)\ + PyStructSequence_SET_ITEM((row), (column), (self)->values[(column)]);\ +else\ + PyTuple_SET_ITEM((row), (column), (self)->values[(column)]);\ + + +static char *mariadb_named_tuple_name= "Row"; +static char *mariadb_named_tuple_desc= "Named tupled row"; +static PyObject *Mariadb_no_operation(MrdbCursor *, + PyObject *); +static PyObject *Mariadb_row_count(MrdbCursor *self); +static PyObject *MrdbCursor_warnings(MrdbCursor *self); +static PyObject *MrdbCursor_getbuffered(MrdbCursor *self); +static int MrdbCursor_setbuffered(MrdbCursor *self, PyObject *arg); +static PyObject *MrdbCursor_lastrowid(MrdbCursor *self); +static PyObject *MrdbCursor_closed(MrdbCursor *self); + + +static PyGetSetDef MrdbCursor_sets[]= +{ + {"lastrowid", (getter)MrdbCursor_lastrowid, NULL, + "row id of the last modified (inserted) row"}, + {"description", (getter)MrdbCursor_description, NULL, + "This read-only attribute is a sequence of 8-item sequences. " + "Each of these sequences contains information describing one result column", + NULL}, + {"rowcount", (getter)Mariadb_row_count, NULL, "doc", NULL}, + {"warnings", (getter)MrdbCursor_warnings, NULL, + "Number of warnings which were produced from last execute() call", NULL}, + {"closed", (getter)MrdbCursor_closed, NULL, + "Indicates if the cursor is closed and can't be reused", NULL}, + {"buffered", (getter)MrdbCursor_getbuffered, (setter)MrdbCursor_setbuffered, + "When True all result sets are immediately transferred and the connection " + "between client and server is no longer blocked. Default value is False."}, + {NULL} +}; + +static PyMethodDef MrdbCursor_Methods[] = +{ + /* PEP-249 methods */ + {"close", (PyCFunction)MrdbCursor_close, + METH_NOARGS, + "Closes an open Cursor"}, + {"execute", (PyCFunction)MrdbCursor_execute, + METH_VARARGS | METH_KEYWORDS, + "Executes a SQL statement"}, + {"executemany", (PyCFunction)MrdbCursor_executemany, + METH_VARARGS, + "Executes a SQL statement by passing a list of values"}, + {"fetchall", (PyCFunction)MrdbCursor_fetchall, + METH_NOARGS, + "Fetches all rows of a result set"}, + {"fetchone", (PyCFunction)MrdbCursor_fetchone, + METH_NOARGS, + "Fetches the next row of a result set"}, + {"fetchmany", (PyCFunction)MrdbCursor_fetchmany, + METH_VARARGS | METH_KEYWORDS, + "Fetches multiple rows of a result set"}, + {"fieldcount", (PyCFunction)MrdbCursor_fieldcount, + METH_NOARGS, + "Returns number of columns in current result set"}, + {"nextset", (PyCFunction)MrdbCursor_nextset, + METH_NOARGS, + "Will make the cursor skip to the next available result set, discarding any remaining rows from the current set."}, + {"setinputsizes", (PyCFunction)Mariadb_no_operation, + METH_VARARGS, + "Required by PEP-249. Does nothing in MariaDB Connector/Python"}, + {"setoutputsize", (PyCFunction)Mariadb_no_operation, + METH_VARARGS, + "Required by PEP-249. Does nothing in MariaDB Connector/Python"}, + {"callproc", (PyCFunction)Mariadb_no_operation, + METH_VARARGS, + "Required by PEP-249. Does nothing in MariaDB Connector/Python, use the execute method with syntax 'CALL {procedurename}' instead"}, + {"next", (PyCFunction)MrdbCursor_fetchone, + METH_NOARGS, + "Return the next row from the currently executing SQL statement using the same semantics as .fetchone()."}, + {"scroll", (PyCFunction)MrdbCursor_scroll, + METH_VARARGS | METH_KEYWORDS, + "Scroll the cursor in the result set to a new position according to mode"}, + {NULL} /* always last */ +}; + +static struct PyMemberDef MrdbCursor_Members[] = +{ + {"connection", + T_OBJECT, + offsetof(MrdbCursor, connection), + READONLY, + "Reference to the connection object on which the cursor was created"}, + {"statement", + T_STRING, + offsetof(MrdbCursor, statement), + READONLY, + "The last executed statement"}, + {"buffered", + T_BYTE, + offsetof(MrdbCursor, is_buffered), + 0, + "Stores the entire result set in memory"}, + {"rownumber", + T_LONG, + offsetof(MrdbCursor, row_number), + READONLY, + "Current row number in result set"}, + {"arraysize", + T_LONG, + offsetof(MrdbCursor, row_array_size), + 0, + "the number of rows to fetch"}, + {NULL} +}; + +/* {{{ MrdbCursor_initialize + Cursor initialization + + Optional keywprds: + named_tuple (Boolean): return rows as named tuple instead of tuple + prefetch_size: Prefetch size for readonly cursors + cursor_type: Type of cursor: CURSOR_TYPE_READONLY or CURSOR_TYPE_NONE (default) + buffered: buffered or unbuffered result sets +*/ +static int MrdbCursor_initialize(MrdbCursor *self, PyObject *args, + PyObject *kwargs) +{ + char *key_words[]= {"", "named_tuple", "prefetch_size", "cursor_type", + "buffered", "prepared", NULL}; + PyObject *connection; + uint8_t is_named_tuple= 0; + unsigned long cursor_type= 0, + prefetch_rows= 0; + uint8_t is_buffered= 0; + uint8_t is_prepared= 0; + + if (!self) + return -1; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, + "O!|bkkbb", key_words, &MrdbConnection_Type, &connection, + &is_named_tuple, &prefetch_rows, &cursor_type, &is_buffered, + &is_prepared)) + + if (cursor_type != CURSOR_TYPE_READ_ONLY && + cursor_type != CURSOR_TYPE_NO_CURSOR) + { + mariadb_throw_exception(NULL, Mariadb_DataError, 0, + "Invalid value %ld for cursor_type", cursor_type); + return -1; + } + + Py_INCREF(connection); + self->connection= (MrdbConnection *)connection; + self->is_buffered= is_buffered ? is_buffered : self->connection->is_buffered; + + self->is_prepared= is_prepared; + + if (!(self->stmt= mysql_stmt_init(self->connection->mysql))) + { + mariadb_throw_exception(self->connection->mysql, NULL, 0, NULL); + return -1; + } + + self->cursor_type= cursor_type; + self->prefetch_rows= prefetch_rows; + self->is_named_tuple= is_named_tuple; + self->row_array_size= 1; + + if (self->cursor_type || self->prefetch_rows) + { + if (!(self->stmt = mysql_stmt_init(self->connection->mysql))) + { + mariadb_throw_exception(self->connection->mysql, Mariadb_OperationalError, 0, NULL); + return -1; + } + } + else + return 0; + + mysql_stmt_attr_set(self->stmt, STMT_ATTR_CURSOR_TYPE, &self->cursor_type); + mysql_stmt_attr_set(self->stmt, STMT_ATTR_PREFETCH_ROWS, &self->prefetch_rows); + return 0; +} +/* }}} */ + +static int MrdbCursor_traverse( + MrdbCursor *self, + visitproc visit, + void *arg) +{ + return 0; +} + +PyTypeObject MrdbCursor_Type = +{ + PyVarObject_HEAD_INIT(NULL, 0) + "mariadb.cursor", + sizeof(MrdbCursor), + 0, + (destructor)MrdbCursor_dealloc, /* tp_dealloc */ + 0, /*tp_print*/ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* PyAsyncMethods * */ + 0, /* tp_repr */ + + /* Method suites for standard classes */ + + 0, /* (PyNumberMethods *) tp_as_number */ + 0, /* (PySequenceMethods *) tp_as_sequence */ + 0, /* (PyMappingMethods *) tp_as_mapping */ + + /* More standard operations (here for binary compatibility) */ + + 0, /* (hashfunc) tp_hash */ + 0, /* (ternaryfunc) tp_call */ + 0, /* (reprfunc) tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + + /* Functions to access object as input/output buffer */ + 0, /* (PyBufferProcs *) tp_as_buffer */ + + /* (tp_flags) Flags to define presence of optional/expanded features */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE, + mariadb_cursor_documentation, /* tp_doc Documentation string */ + + /* call function for all accessible objects */ + (traverseproc)MrdbCursor_traverse,/* tp_traverse */ + + /* delete references to contained objects */ + 0, /* tp_clear */ + + /* rich comparisons */ + 0, /* (richcmpfunc) tp_richcompare */ + + /* weak reference enabler */ + 0, /* (long) tp_weaklistoffset */ + + /* Iterators */ + (getiterfunc)MrdbCursor_iter, + (iternextfunc)MrdbCursor_iternext, + + /* Attribute descriptor and subclassing stuff */ + (struct PyMethodDef *)MrdbCursor_Methods, /* tp_methods */ + (struct PyMemberDef *)MrdbCursor_Members, /* tp_members */ + MrdbCursor_sets, + 0, /* (struct _typeobject *) tp_base; */ + 0, /* (PyObject *) tp_dict */ + 0, /* (descrgetfunc) tp_descr_get */ + 0, /* (descrsetfunc) tp_descr_set */ + 0, /* (long) tp_dictoffset */ + (initproc)MrdbCursor_initialize, /* tp_init */ + PyType_GenericAlloc, //NULL, /* tp_alloc */ + PyType_GenericNew, //NULL, /* tp_new */ + NULL, /* tp_free Low-level free-memory routine */ + 0, /* (PyObject *) tp_bases */ + 0, /* (PyObject *) tp_mro method resolution order */ + 0, /* (PyObject *) tp_defined */ +}; + +/* {{{ Mariadb_no_operation + This function is a stub and just returns Py_None +*/ +static PyObject *Mariadb_no_operation(MrdbCursor *self, + PyObject *args) +{ + Py_INCREF(Py_None); + return Py_None; +} +/* }}} */ + +/* {{{ MrdbCursor_isprepared + If the same statement was executed before, we don't need to + reprepare it and can just execute it. +*/ +static uint8_t MrdbCursor_isprepared(MrdbCursor *self, + const char *statement, + size_t statement_len) +{ + if (self->statement) + { + if (self->statement_len == statement_len && + !memcmp(statement, self->statement, statement_len)) + { + enum mysql_stmt_state state; + mysql_stmt_attr_get(self->stmt, STMT_ATTR_STATE, &state); + if (state >= MYSQL_STMT_PREPARED) + return 1; + } + } + return 0; +} +/* }}} */ + +/* {{{ MrdbCursor_clear + Resets statement attributes and frees + associated memory +*/ +static +void MrdbCursor_clear(MrdbCursor *self) +{ + if (!self->is_text && self->stmt) { + uint32_t val= 0; + mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_USER_DATA, 0); + mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_PARAM, 0); + mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_RESULT, 0); + mysql_stmt_attr_set(self->stmt, STMT_ATTR_ARRAY_SIZE, &val); + mysql_stmt_attr_set(self->stmt, STMT_ATTR_PREBIND_PARAMS, &val); + + mysql_stmt_free_result(self->stmt); + while (!mysql_stmt_next_result(self->stmt)) + mysql_stmt_free_result(self->stmt); + + } + + if (self->is_text) + { + if (self->result) + { + mysql_free_result(self->result); + self->result= 0; + self->is_text= 0; + } + /* clear also pending result sets */ + if (self->connection->mysql) + while (!mysql_next_result(self->connection->mysql)); + } + + MARIADB_FREE_MEM(self->sequence_fields); + self->fields= NULL; + self->row_count= 0; + self->affected_rows= 0; + self->param_count= 0; + MARIADB_FREE_MEM(self->values); + MARIADB_FREE_MEM(self->bind); + MARIADB_FREE_MEM(self->statement); + MARIADB_FREE_MEM(self->value); + MARIADB_FREE_MEM(self->params); + +} +/* }}} */ + +/* {{{ ma_cursor_close + closes the statement handle of current cursor. After call to + cursor_close the cursor can't be reused anymore +*/ +static +void ma_cursor_close(MrdbCursor *self) +{ + if (!self->is_text && self->stmt) + { + /* Todo: check if all the cursor stuff is deleted (when using prepared + statemnts this should be handled in mysql_stmt_close) */ + Py_BEGIN_ALLOW_THREADS + mysql_stmt_close(self->stmt); + Py_END_ALLOW_THREADS + self->stmt= NULL; + } + MrdbCursor_clear(self); + Mrdb_Parser_end(self->parser); + self->is_closed= 1; +} + +static +PyObject * MrdbCursor_close(MrdbCursor *self) +{ + ma_cursor_close(self); + self->is_closed= 1; + Py_INCREF(Py_None); + return Py_None; +} +/* }}} */ + +/*{{{ MrDBCursor_dealloc */ +void MrdbCursor_dealloc(MrdbCursor *self) +{ + ma_cursor_close(self); + Py_TYPE(self)->tp_free((PyObject*)self); +} +/* }}} */ + +static int Mrdb_GetFieldInfo(MrdbCursor *self) +{ + unsigned int field_count= CURSOR_FIELD_COUNT(self); + + self->row_number= 0; + + self->row_count= CURSOR_AFFECTED_ROWS(self); + + if (field_count) + { + if (self->is_text) + { + self->result= (self->is_buffered) ? mysql_store_result(self->connection->mysql) : + mysql_use_result(self->connection->mysql); + if (!self->result) + { + mariadb_throw_exception(self->connection->mysql, NULL, 0, NULL); + return 1; + } + } + else if (self->is_buffered) + { + if (mysql_stmt_store_result(self->stmt)) + { + mariadb_throw_exception(self->stmt, NULL, 1, NULL); + return 1; + } + } + + self->affected_rows= CURSOR_AFFECTED_ROWS(self); + + self->fields= (self->is_text) ? mysql_fetch_fields(self->result) : + mariadb_stmt_fetch_fields(self->stmt); + + if (self->is_named_tuple) { + unsigned int i; + if (!(self->sequence_fields= (PyStructSequence_Field *) + PyMem_RawCalloc(field_count + 1, + sizeof(PyStructSequence_Field)))) + return 1; + self->sequence_desc.name= mariadb_named_tuple_name; + self->sequence_desc.doc= mariadb_named_tuple_desc; + self->sequence_desc.fields= self->sequence_fields; + self->sequence_desc.n_in_sequence= field_count; + + + for (i=0; i < field_count; i++) + { + self->sequence_fields[i].name= self->fields[i].name; + } + self->sequence_type= PyMem_RawCalloc(1,sizeof(PyTypeObject)); + PyStructSequence_InitType(self->sequence_type, &self->sequence_desc); + } + } + return 0; +} + +static int MrdbCursor_InitResultSet(MrdbCursor *self) +{ + unsigned int field_count= CURSOR_FIELD_COUNT(self); + + MARIADB_FREE_MEM(self->sequence_fields); + MARIADB_FREE_MEM(self->values); + + if (self->result) + mysql_free_result(self->result); + + if (Mrdb_GetFieldInfo(self)) + return 1; + + if (!(self->values= (PyObject**)PyMem_RawCalloc(field_count, sizeof(PyObject *)))) + return 1; + if (!self->is_text) + mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_RESULT, field_fetch_callback); + return 0; +} + +/* {{{ MrdbCursor_execute + PEP-249 execute() method +*/ +static +PyObject *MrdbCursor_execute(MrdbCursor *self, + PyObject *args, + PyObject *kwargs) +{ + PyObject *Data= NULL; + const char *statement= NULL; + int statement_len= 0; + int rc= 0; + uint8_t is_buffered= 0; + static char *key_words[]= {"", "", "buffered", NULL}; + + MARIADB_CHECK_STMT(self); + if (PyErr_Occurred()) + return NULL; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, + "s#|O!$b", key_words, &statement, &statement_len, &PyTuple_Type, &Data, + &is_buffered)) + return NULL; + + /* defaukt was set to 0 before */ + self->is_buffered= is_buffered; + + /* If we don't have a prepared cursor, we need to end/free parser */ + if (!self->is_prepared && self->parser) + { + Mrdb_Parser_end(self->parser); + self->parser= NULL; + } + + /* if there are no parameters specified, we execute the statement in text protocol */ + if (!Data && !self->cursor_type) + { + /* in case statement was executed before, we need to clear, since we don't use + binary protocol */ + Mrdb_Parser_end(self->parser); + self->parser= NULL; + MrdbCursor_clear(self); + Py_BEGIN_ALLOW_THREADS; + rc= mysql_real_query(self->connection->mysql, statement, statement_len); + Py_END_ALLOW_THREADS; + if (rc) + { + mariadb_throw_exception(self->connection->mysql, NULL, 0, NULL); + goto error; + } + self->is_text= 1; + CURSOR_SET_STATEMENT(self, statement, statement_len); + } + else + { + uint8_t do_prepare= 1; + + self->is_text= 0; + + if (self->is_prepared && self->statement) + do_prepare= 0; + + /* if cursor type is not prepared, we need to clear the cursor first */ + if (!self->is_prepared && self->statement) + { + MrdbCursor_clear(self); + Mrdb_Parser_end(self->parser); + self->parser= NULL; + } + + if (!self->parser) + { + self->parser= Mrdb_Parser_init(statement, statement_len); + Mrdb_Parser_parse(self->parser, 0); + CURSOR_SET_STATEMENT(self, statement, statement_len); + } + + if (Data) + { + self->array_size= 0; + self->data= Data; + if (mariadb_check_execute_parameters(self, Data)) + goto error; + + self->data= Data; + + /* Load values */ + if (mariadb_param_update(self, self->params, 0)) + goto error; + } + if (do_prepare) + { + mysql_stmt_attr_set(self->stmt, STMT_ATTR_PREBIND_PARAMS, &self->param_count); + mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_USER_DATA, (void *)self); + mysql_stmt_bind_param(self->stmt, self->params); + + Py_BEGIN_ALLOW_THREADS; + if (!MARIADB_FEATURE_SUPPORTED(self->stmt->mysql, 100206)) + { + rc= mysql_stmt_prepare(self->stmt, self->parser->statement.str, + (unsigned long)self->parser->statement.length); + if (!rc) + rc= mysql_stmt_execute(self->stmt); + } + else + rc= mariadb_stmt_execute_direct(self->stmt, self->parser->statement.str, + self->parser->statement.length); + Py_END_ALLOW_THREADS; + + if (rc) + { + /* in case statement is not supported via binary protocol, we try + to run the statement with text protocol */ + if (mysql_stmt_errno(self->stmt) == ER_UNSUPPORTED_PS) + { + Py_BEGIN_ALLOW_THREADS; + self->is_text= 0; + rc= mysql_real_query(self->connection->mysql, statement, statement_len); + Py_END_ALLOW_THREADS; + + if (rc) + { + mariadb_throw_exception(self->stmt->mysql, NULL, 0, NULL); + goto error; + } + /* if we have a result set, we can't process it - so we will return + an error. (XA RECOVER is the only command which returns a result set + and can't be prepared) */ + if (mysql_field_count(self->stmt->mysql)) + { + MYSQL_RES *result; + + /* we need to clear the result first, otherwise the cursor remains + in usuable state (query out of order) */ + if ((result= mysql_store_result(self->stmt->mysql))) + mysql_free_result(result); + + mariadb_throw_exception(NULL, Mariadb_NotSupportedError, 0, "This command is not supported by MariaDB Connector/Python"); + goto error; + } + goto end; + } + /* throw exception from statement handle */ + mariadb_throw_exception(self->stmt, NULL, 1, NULL); + goto error; + } + } else { + /* We are already prepared, so just reexecute statement */ + mysql_stmt_bind_param(self->stmt, self->params); + Py_BEGIN_ALLOW_THREADS; + rc= mysql_stmt_execute(self->stmt); + Py_END_ALLOW_THREADS; + if (rc) + { + mariadb_throw_exception(self->stmt, NULL, 1, NULL); + goto error; + } + } + } + + if (MrdbCursor_InitResultSet(self)) + goto error; +end: + MARIADB_FREE_MEM(self->value); + Py_RETURN_NONE; +error: + Mrdb_Parser_end(self->parser); + self->parser= NULL; + MrdbCursor_clear(self); + return NULL; +} +/* }}} */ + +/* {{{ MrdbCursor_fieldcount() */ +PyObject *MrdbCursor_fieldcount(MrdbCursor *self) +{ + MARIADB_CHECK_STMT(self); + if (PyErr_Occurred()) + return NULL; + + return PyLong_FromLong((long)CURSOR_FIELD_COUNT(self)); +} +/* }}} */ + +/* {{{ MrdbCursor_description + PEP-249 description method() + + Please note that the returned tuple contains eight (instead of + seven items, since we need the field flag +*/ +static +PyObject *MrdbCursor_description(MrdbCursor *self) +{ + PyObject *obj= NULL; + unsigned int field_count= CURSOR_FIELD_COUNT(self); + + MARIADB_CHECK_STMT(self); + if (PyErr_Occurred()) + return NULL; + + + if (self->fields && field_count) + { + uint32_t i; + + if (!(obj= PyTuple_New(field_count))) + return NULL; + + for (i=0; i < field_count; i++) + { + uint32_t precision= 0; + uint32_t decimals= 0; + unsigned long display_length= self->fields[i].max_length; + long packed_len= mysql_ps_fetch_functions[self->fields[i].type].pack_len; + + if (self->fields[i].decimals) + { + if (self->fields[i].decimals < 31) + { + decimals= self->fields[i].decimals; + precision= self->fields[i].length; + display_length= precision + 1; + } + } + + PyObject *desc; + if (!(desc= Py_BuildValue("(sIIiIIOI)", + self->fields[i].name, + self->fields[i].type, + display_length, + packed_len >= 0 ? packed_len : -1, + precision, + decimals, + PyBool_FromLong(!IS_NOT_NULL(self->fields[i].flags)), + self->fields[i].flags))) + { + Py_XDECREF(obj); + mariadb_throw_exception(NULL, Mariadb_OperationalError, 0, + "Can't build descriptor record"); + return NULL; + } + PyTuple_SetItem(obj, i, desc); + } + Py_INCREF(obj); + return obj; + } + Py_INCREF(Py_None); + return Py_None; +} +/* }}} */ + +static int MrdbCursor_fetchinternal(MrdbCursor *self) +{ + unsigned int field_count= CURSOR_FIELD_COUNT(self); + MYSQL_ROW row; + int rc; + unsigned int i; + + if (!self->is_text) + { + rc= mysql_stmt_fetch(self->stmt); + if (rc == MYSQL_NO_DATA) + return 1; + return 0; + } + + if (!(row= mysql_fetch_row(self->result))) + return 1; + + for (i= 0; i < field_count; i++) + { + field_fetch_fromtext(self, row[i], i); + } + return 0; +} + +/* {{{ MrdbCursor_fetchone + PEP-249 fetchone() method +*/ +static +PyObject *MrdbCursor_fetchone(MrdbCursor *self) +{ + PyObject *row; + uint32_t i; + unsigned int field_count= CURSOR_FIELD_COUNT(self); + + MARIADB_CHECK_STMT(self); + if (PyErr_Occurred()) + return NULL; + + if (!field_count) + { + mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, + "Cursor doesn't have a result set"); + return NULL; + } + + if (MrdbCursor_fetchinternal(self)) + { + Py_INCREF(Py_None); + return Py_None; + } + + self->row_number++; + if (!(row= mariadb_get_sequence_or_tuple(self))) + return NULL; + for (i= 0; i < field_count; i++) + { + MARIADB_SET_SEQUENCE_OR_TUPLE_ITEM(self, row, i); + } + return row; +} +/* }}} */ + +/* {{{ MrdbCursor_scroll + PEP-249: (optional) scroll() method + + Parameter: value + mode=[relative(default),absolute] + + Todo: support for forward only cursor +*/ +static +PyObject *MrdbCursor_scroll(MrdbCursor *self, PyObject *args, + PyObject *kwargs) +{ + char *modestr= NULL; + PyObject *Pos; + long position= 0; + unsigned long long new_position= 0; + uint8_t mode= 0; /* default: relative */ + char *kw_list[]= {"", "mode", NULL}; + const char *scroll_modes[]= {"relative", "absolute", NULL}; + + + MARIADB_CHECK_STMT(self); + if (PyErr_Occurred()) + return NULL; + + if (!CURSOR_FIELD_COUNT(self)) + { + mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, + "Cursor doesn't have a result set"); + return NULL; + } + + if (!self->is_buffered) + { + mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, + "This method is available only for cursors with buffered result set " + "or a read only cursor type"); + return NULL; + } + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, + "O!|s", kw_list, &PyLong_Type, &Pos, &modestr)) + return NULL; + + if (!(position= PyLong_AsLong(Pos))) + { + mariadb_throw_exception(NULL, Mariadb_DataError, 0, + "Invalid position value 0"); + return NULL; + } + + if (modestr != NULL) + { + while (scroll_modes[mode]) { + if (!strcmp(scroll_modes[mode], modestr)) + break; + mode++; + }; + } else + mode= 0; + + if (!scroll_modes[mode]) { + mariadb_throw_exception(NULL, Mariadb_DataError, 0, + "Invalid mode '%s'", modestr); + return NULL; + } + + if (!mode) { + new_position= self->row_number + position; + if (new_position < 0 || new_position > CURSOR_NUM_ROWS(self)) + { + mariadb_throw_exception(NULL, Mariadb_DataError, 0, + "Position value is out of range"); + return NULL; + } + } else + new_position= position; /* absolute */ + + if (!self->is_text) + mysql_stmt_data_seek(self->stmt, new_position); + else + mysql_data_seek(self->result, new_position); + self->row_number= (unsigned long)new_position; + Py_INCREF(Py_None); + return Py_None; +} +/*}}}*/ + +/* {{{ MrdbCursor_fetchmany + PEP-249 fetchmany() method + + Optional parameters: size +*/ +static +PyObject *MrdbCursor_fetchmany(MrdbCursor *self, PyObject *args, + PyObject *kwargs) +{ + PyObject *List= NULL; + uint32_t i; + unsigned long rows= 0; + static char *kw_list[]= {"size", NULL}; + unsigned int field_count= CURSOR_FIELD_COUNT(self); + + MARIADB_CHECK_STMT(self); + if (PyErr_Occurred()) + return NULL; + + if (!field_count) + { + mariadb_throw_exception(0, Mariadb_ProgrammingError, 0, + "Cursor doesn't have a result set"); + return NULL; + } + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, + "|l:fetchmany", kw_list, &rows)) + return NULL; + + if (!rows) + rows= self->row_array_size; + if (!(List= PyList_New(0))) + return NULL; + + /* if rows=0, return an empty list */ + if (!rows) + return List; + + for (i=0; i < rows; i++) + { + uint32_t j; + PyObject *Row; + if (MrdbCursor_fetchinternal(self)) + goto end; + self->affected_rows= CURSOR_NUM_ROWS(self); + if (!(Row= mariadb_get_sequence_or_tuple(self))) + return NULL; + for (j=0; j < field_count; j++) + MARIADB_SET_SEQUENCE_OR_TUPLE_ITEM(self, Row, j); + PyList_Append(List, Row); + } +end: + return List; +} + +static PyObject *mariadb_get_sequence_or_tuple(MrdbCursor *self) +{ + unsigned int field_count= CURSOR_FIELD_COUNT(self); + if (self->is_named_tuple) + return PyStructSequence_New(self->sequence_type); + else + return PyTuple_New(field_count); +} +/* }}} */ + +/* {{{ MrdbCursor_fetchall() + PEP-249 fetchall() method */ +static +PyObject *MrdbCursor_fetchall(MrdbCursor *self) +{ + PyObject *List; + unsigned int field_count= CURSOR_FIELD_COUNT(self); + MARIADB_CHECK_STMT(self); + if (PyErr_Occurred()) + return NULL; + + if (!field_count) + { + mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, + "Cursor doesn't have a result set"); + return NULL; + } + + if (!(List= PyList_New(0))) + return NULL; + + while (!MrdbCursor_fetchinternal(self)) + { + uint32_t j; + PyObject *Row; + + self->row_number++; + + if (!(Row= mariadb_get_sequence_or_tuple(self))) + return NULL; + + for (j=0; j < field_count; j++) + { + MARIADB_SET_SEQUENCE_OR_TUPLE_ITEM(self, Row, j) + } + PyList_Append(List, Row); + } + self->row_count= (self->is_text) ? mysql_num_rows(self->result) : + mysql_stmt_num_rows(self->stmt); + return List; +} +/* }}} */ + +/* {{{ MrdbCursor_executemany_fallback + bulk execution for server < 10.2.6 +*/ +static +uint8_t MrdbCursor_executemany_fallback(MrdbCursor *self, + const char *statement, + size_t len) +{ + uint32_t i; + + if (mysql_stmt_attr_set(self->stmt, STMT_ATTR_PREBIND_PARAMS, &self->param_count)) + goto error; + + self->row_count= 0; + + for (i=0; i < self->array_size; i++) + { + int rc= 0; + /* Load values */ + if (mariadb_param_update(self, self->params, i)) + return 1; + if (mysql_stmt_bind_param(self->stmt, self->params)) + goto error; + Py_BEGIN_ALLOW_THREADS; + if (i==0) + rc= mysql_stmt_prepare(self->stmt, statement, (unsigned long)len); + if (!rc) + rc= mysql_stmt_execute(self->stmt); + Py_END_ALLOW_THREADS; + if (rc) + goto error; + self->row_count+= mysql_stmt_affected_rows(self->stmt); + } + return 0; +error: + mariadb_throw_exception(self->stmt, NULL, 1, NULL); + return 1; +} +/* }}} */ + +/* {{{ MrdbCursor_executemany + PEP-249 executemany() method + + Paramter: A List of one or more tuples + + Note: When conecting to a server < 10.2.6 this command will be emulated + by executing preparing and executing statement n times (where n is + the number of tuples in list) +*/ +PyObject *MrdbCursor_executemany(MrdbCursor *self, + PyObject *Args) +{ + char *statement= NULL; + Py_ssize_t statement_len= 0; + int rc; + uint8_t do_prepare= 1; + + MARIADB_CHECK_STMT(self); + if (PyErr_Occurred()) + return NULL; + + self->data= NULL; + + if (!PyArg_ParseTuple(Args, "s#O!", &statement, &statement_len, + &PyList_Type, &self->data)) + return NULL; + + + if (!self->data) + { + PyErr_SetString(PyExc_TypeError, "No data provided"); + return NULL; + } + + if (self->is_prepared && self->statement) + do_prepare= 0; + + /* if cursor type is not prepared, we need to clear the cursor first */ + if (!self->is_prepared && self->statement) + { + MrdbCursor_clear(self); + Mrdb_Parser_end(self->parser); + self->parser= NULL; + } + self->is_text= 0; + + if (!self->parser) + { + if (!(self->parser= Mrdb_Parser_init(statement, (size_t)statement_len))) + { + exit(-1); + } + Mrdb_Parser_parse(self->parser, 0); + CURSOR_SET_STATEMENT(self, statement, statement_len); + } + + if (mariadb_check_bulk_parameters(self, self->data)) + goto error; + + + /* If the server doesn't support bulk execution (< 10.2.6), + we need to call a fallback routine */ + if (!MARIADB_FEATURE_SUPPORTED(self->stmt->mysql, 100206)) + { + if (MrdbCursor_executemany_fallback(self, self->parser->statement.str, + self->parser->statement.length)) + goto error; + goto end; + } + + mysql_stmt_attr_set(self->stmt, STMT_ATTR_ARRAY_SIZE, &self->array_size); + if (do_prepare) + { + mysql_stmt_attr_set(self->stmt, STMT_ATTR_PREBIND_PARAMS, &self->param_count); + mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_USER_DATA, (void *)self); + mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_PARAM, mariadb_param_update); + + mysql_stmt_bind_param(self->stmt, self->params); + + Py_BEGIN_ALLOW_THREADS; + rc= mariadb_stmt_execute_direct(self->stmt, self->parser->statement.str, + (unsigned long)self->parser->statement.length); + Py_END_ALLOW_THREADS; + if (rc) + { + mariadb_throw_exception(self->stmt, NULL, 1, NULL); + goto error; + } + } else { + Py_BEGIN_ALLOW_THREADS; + rc= mysql_stmt_execute(self->stmt); + Py_END_ALLOW_THREADS; + if (rc) + { + mariadb_throw_exception(self->stmt, NULL, 1, NULL); + goto error; + } + } +end: + MARIADB_FREE_MEM(self->values); + Py_RETURN_NONE; +error: + MrdbCursor_clear(self); + Mrdb_Parser_end(self->parser); + self->parser= NULL; + return NULL; +} +/* }}} */ + +/* {{{ MrdbCursor_nextset + PEP-249: Optional nextset() method +*/ +PyObject *MrdbCursor_nextset(MrdbCursor *self) +{ + MARIADB_CHECK_STMT(self); + int rc; + if (PyErr_Occurred()) + return NULL; + + if (!CURSOR_FIELD_COUNT(self)) + { + mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, + "Cursor doesn't have a result set"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS; + if (!self->is_text) + rc= mysql_stmt_next_result(self->stmt); + else + { + if (self->result) + { + mysql_free_result(self->result); + self->result= NULL; + } + rc= mysql_next_result(self->connection->mysql); + } + Py_END_ALLOW_THREADS; + if (rc) + { + Py_INCREF(Py_None); + return Py_None; + } + if (CURSOR_FIELD_COUNT(self)) + { + if (MrdbCursor_InitResultSet(self)) + return NULL; + } + else + self->fields= 0; + Py_RETURN_TRUE; +} +/* }}} */ + +/* {{{ Mariadb_row_count + PEP-249: rowcount attribute +*/ +static PyObject *Mariadb_row_count(MrdbCursor *self) +{ + int64_t row_count= 0; + + MARIADB_CHECK_STMT(self); + if (PyErr_Occurred()) + return NULL; + + /* PEP-249 requires to return -1 if the cursor was not executed before */ + if (!self->statement) + return PyLong_FromLongLong(-1); + + if (CURSOR_FIELD_COUNT(self)) + row_count= CURSOR_NUM_ROWS(self); + else + { + row_count= CURSOR_AFFECTED_ROWS(self); + if (!row_count) + row_count= -1; + } + return PyLong_FromLongLong(row_count); +} +/* }}} */ + +static PyObject *MrdbCursor_warnings(MrdbCursor *self) +{ + MARIADB_CHECK_STMT(self); + + return PyLong_FromLong((long)CURSOR_WARNING_COUNT(self)); +} + +/* {{{ MrdbCursor_getbuffered */ +static PyObject *MrdbCursor_getbuffered(MrdbCursor *self) +{ + if (self->is_buffered) + Py_RETURN_TRUE; + Py_RETURN_FALSE; +} +/* }}} */ + +/* {{{ MrdbCursor_setbuffered */ +static int MrdbCursor_setbuffered(MrdbCursor *self, PyObject *arg) +{ + if (!arg || Py_TYPE(arg) != &PyBool_Type) + { + PyErr_SetString(PyExc_TypeError, "Argument must be boolean"); + return -1; + } + + self->is_buffered= PyObject_IsTrue(arg); + return 0; +} +/* }}} */ + +/* {{{ MrdbCursor_lastrowid */ +static PyObject *MrdbCursor_lastrowid(MrdbCursor *self) +{ + MARIADB_CHECK_STMT(self); + return PyLong_FromUnsignedLongLong(CURSOR_INSERT_ID(self)); +} +/* }}} */ + +/* iterator protocol */ + +/* {{{ MrdbCursor_iter */ +static PyObject * +MrdbCursor_iter(PyObject *self) +{ + MARIADB_CHECK_STMT(((MrdbCursor *)self)); + Py_INCREF(self); + return self; +} +/* }}} */ + +/* {{{ MrdbCursor_iternext */ +static PyObject * +MrdbCursor_iternext(PyObject *self) +{ + PyObject *res; + + res= MrdbCursor_fetchone((MrdbCursor *)self); + + if (res && res == Py_None) + { + Py_DECREF(res); + res= NULL; + } + return res; +} +/* }}} */ + +/* {{{ MrdbCursor_closed */ +static PyObject *MrdbCursor_closed(MrdbCursor *self) +{ + if (self->is_closed || !self->stmt || self->stmt->mysql == NULL) + Py_RETURN_TRUE; + Py_RETURN_FALSE; +} +/* }}} */ + diff --git a/src/mariadb_parser.c b/src/mariadb_parser.c new file mode 100755 index 0000000..3129073 --- /dev/null +++ b/src/mariadb_parser.c @@ -0,0 +1,211 @@ +/************************************************************************************ + Copyright (C) 2019 Georg Richter and MariaDB Corporation AB + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Library General Public + License as published by the Free Software Foundation; either + version 2 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Library General Public License for more details. + + You should have received a copy of the GNU Library General Public + License along with this library; if not see + or write to the Free Software Foundation, Inc., + 51 Franklin St., Fifth Floor, Boston, MA 02110, USA +*************************************************************************************/ + +#include + +#define IS_WHITESPACE(a) (a==32 || a==9 || a==10 || a==13) +#define IN_LITERAL(p) ((p)->in_literal[0] || (p)->in_literal[1] || (p)->in_literal[2]) + +const char *comment_start= "/*"; +const char *comment_end= "*/"; +const char literals[3]= {'\'', '\"', '`'}; + +static uint8_t check_keyword(char* ofs, char* end, char* keyword, size_t keylen) +{ + int i; + + if (end - ofs < keylen + 1) + return 0; + for (i = 0; i < keylen; i++) + if (toupper(*(ofs + i)) != keyword[i]) + return 0; + if (!IS_WHITESPACE(*(ofs + keylen))) + return 0; + return 1; +} + +void Mrdb_Parser_end(Mrdb_Parser* p) +{ + if (p) + { + MARIADB_FREE_MEM(p->statement.str); + MARIADB_FREE_MEM(p); + } +} + +Mrdb_Parser *Mrdb_Parser_init(const char *statement, size_t length) +{ + Mrdb_Parser *p; + + if (!statement || !length) + return NULL; + + if ((p= PyMem_RawCalloc(1, sizeof(Mrdb_Parser)))) + { + if (!(p->statement.str = PyMem_RawCalloc(1, length + 1))) + { + MARIADB_FREE_MEM(p); + return NULL; + } + memcpy(p->statement.str, statement, length); + p->statement.length= length; + } + return p; +} + + +void Mrdb_Parser_parse(Mrdb_Parser *p, uint8_t is_batch) +{ + char *a, *end; + char lastchar= 0; + uint8_t i; + + if (!p || !p->statement.str) + return; + + a= p->statement.str; + end= a + p->statement.length - 1; + + while (a <= end) + { + /* check literals */ + for (i=0; i < 3; i++) + { + if (*a == literals[i]) + { + p->in_literal[i]= !(p->in_literal[i]); + a++; + continue; + } + } + /* nothing to do, if we are inside a comment or literal */ + if (IN_LITERAL(p)) + { + a++; + continue; + } + /* check comment */ + if (!p->in_comment) + { + /* Style 1 */ + if (*a == '/' && *(a + 1) == '*') + { + a+= 2; + p->in_comment= 1; + continue; + } + /* Style 2 */ + if (*a == '#') + { + a++; + p->comment_eol= 1; + } + /* Style 3 */ + if (*a == '-' && *(a+1) == '-') + { + if (((a+2) < end) && *(a+2) == ' ') + { + a+= 3; + p->comment_eol= 1; + } + } + } else + { + if (*a == '*' && *(a + 1) == '/') + { + a+= 2; + p->in_comment= 0; + continue; + } else { + a++; + continue; + } + } + if (p->comment_eol) { + if (*a == '\0' || *a == '\n') + { + a++; + p->comment_eol= 0; + continue; + } + a++; + continue; + } + /* checking for different paramstyles */ + /* parmastyle = qmark */ + if (*a == '?') + { + p->param_count++; + a++; + continue; + } + + /* paramstype = pyformat */ + if (*a == '%' && lastchar != '\\') + { + if (*(a+1) == 's' || *(a+1) == 'd') + { + *a= '?'; + memmove(a+1, a+2, end - a); + end--; + a++; + p->param_count++; + continue; + } + if (*(a+1) == '(') + { + char *val_end= strstr(a+1, ")s"); + if (val_end) + { + int keylen= val_end - a + 1; + *a= '?'; + p->param_count++; + memmove(a+1, val_end+2, end - a - keylen); + end -= keylen; + continue; + } + } + } + + if (is_batch) + { + /* Do we have an insert statement ? */ + if (!p->is_insert && check_keyword(a, end, "INSERT", 6)) + { + if (lastchar == 0 || (IS_WHITESPACE(lastchar)) || lastchar == '/') + { + p->is_insert = 1; + a += 7; + } + } + + if (p->is_insert && check_keyword(a, end, "VALUES", 6)) + { + p->value_ofs = a + 7; + a += 7; + continue; + } + } + + lastchar= *a; + a++; + } + /* Update length */ + p->statement.length= end - p->statement.str + 1; +} diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/base_test.py b/test/base_test.py new file mode 100644 index 0000000..7fc36da --- /dev/null +++ b/test/base_test.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python -O +# -*- coding: utf-8 -*- + +import mariadb + +from .conf_test import conf + + +def create_connection(additional_conf=None): + default_conf = conf() + if additional_conf is None: + c = {key: value for (key, value) in (default_conf.items())} + else: + c = {key: value for (key, value) in (list(default_conf.items()) + list( + additional_conf.items()))} + return mariadb.connect(**c) diff --git a/test/conf_test.py b/test/conf_test.py new file mode 100644 index 0000000..3b6b0ed --- /dev/null +++ b/test/conf_test.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python -O +# -*- coding: utf-8 -*- + +import os + + +def conf(): + d = { + "user": os.environ.get('TEST_USER', 'root'), + "host": os.environ.get('TEST_HOST', 'localhost'), + "database": os.environ.get('TEST_DATABASE', 'testp'), + "port": int(os.environ.get('TEST_PORT', '3306')) + } + if os.environ.get('TEST_PASSWORD'): + d["password"] = os.environ.get('TEST_PASSWORD') + return d diff --git a/test/cursor.py b/test/cursor.py deleted file mode 100644 index 85f6453..0000000 --- a/test/cursor.py +++ /dev/null @@ -1,498 +0,0 @@ -#!/usr/bin/env python -O - -import mariadb -import datetime -import unittest -import collections - -class CursorTest(unittest.TestCase): - def setUp(self): - self.connection= mariadb.connection(default_file='default.cnf') - self.connection.autocommit= False - - def tearDown(self): - self.connection.rollback() - del self.connection - - def test_date(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1(c1 TIMESTAMP(6), c2 TIME(6), c3 DATETIME(6), c4 DATE)") - t= datetime.datetime(2018,6,20,12,22,31,123456) - c1= t - c2= t.time() - c3= t - c4= t.date() - cursor.execute("INSERT INTO t1 VALUES (?,?,?,?)", (c1, c2, c3, c4)) - - cursor.execute("SELECT c1,c2,c3,c4 FROM t1") - row= cursor.fetchone() - self.assertEqual(row[0],c1) - self.assertEqual(row[1],c2) - self.assertEqual(row[2],c3) - self.assertEqual(row[3],c4) - cursor.close() - - def test_numbers(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1 (a tinyint unsigned, b smallint unsigned, c mediumint unsigned, d int unsigned, e bigint unsigned, f double)") - c1= 4 - c2= 200 - c3= 167557 - c4= 28688817 - c5= 7330133222578 - c6= 3.1415925 - - cursor.execute("insert into t1 values (?,?,?,?,?,?)", (c1,c2,c3,c4,c5,c6)) - - cursor.execute("select * from t1") - row= cursor.fetchone() - self.assertEqual(row[0],c1) - self.assertEqual(row[1],c2) - self.assertEqual(row[2],c3) - self.assertEqual(row[3],c4) - self.assertEqual(row[4],c5) - self.assertEqual(row[5],c6) - del cursor - - def test_string(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1 (a char(5), b varchar(100), c tinytext, d mediumtext, e text, f longtext)"); - - c1= "12345"; - c2= "The length of this text is < 100 characters" - c3= "This should also fit into tinytext which has a maximum of 255 characters" - c4= 'a' * 1000; - c5= 'b' * 6000; - c6= 'c' * 67000; - - cursor.execute("INSERT INTO t1 VALUES (?,?,?,?,?,?)", (c1,c2,c3,c4,c5,c6)) - - cursor.execute("SELECT * from t1") - row= cursor.fetchone() - self.assertEqual(row[0],c1) - self.assertEqual(row[1],c2) - self.assertEqual(row[2],c3) - self.assertEqual(row[3],c4) - self.assertEqual(row[4],c5) - self.assertEqual(row[5],c6) - del cursor - - def test_blob(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1 (a tinyblob, b mediumblob, c blob, d longblob)") - - c1= b'a' * 100; - c2= b'b' * 1000; - c3= b'c' * 10000; - c4= b'd' * 100000; - - a= (None, None, None, None) - cursor.execute("INSERT INTO t1 VALUES (?,?,?,?)", (c1, c2, c3, c4)) - - cursor.execute("SELECT * FROM t1") - row= cursor.fetchone() - self.assertEqual(row[0],c1) - self.assertEqual(row[1],c2) - self.assertEqual(row[2],c3) - self.assertEqual(row[3],c4) - del cursor - - def test_fetchmany(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t01 (id int, name varchar(64), city varchar(64))"); - params= [(1, u"Jack", u"Boston"), - (2, u"Martin", u"Ohio"), - (3, u"James", u"Washington"), - (4, u"Rasmus", u"Helsinki"), - (5, u"Andrey", u"Sofia")] - cursor.executemany("INSERT INTO t01 VALUES (?,?,?)", params); - - #test Errors - # a) if no select was executed - self.assertRaises(mariadb.Error, cursor.fetchall) - #b ) if cursor was not executed - del cursor - cursor= self.connection.cursor() - self.assertRaises(mariadb.Error, cursor.fetchall) - - cursor.execute("SELECT id, name, city FROM t01 ORDER BY id") - self.assertEqual(0, cursor.rowcount) - row = cursor.fetchall() - self.assertEqual(row, params) - self.assertEqual(5, cursor.rowcount) - - cursor.execute("SELECT id, name, city FROM t01 ORDER BY id") - self.assertEqual(0, cursor.rowcount) - - row= cursor.fetchmany(1) - self.assertEqual(row,[params[0]]) - self.assertEqual(1, cursor.rowcount) - - row= cursor.fetchmany(2) - self.assertEqual(row,([params[1], params[2]])) - self.assertEqual(3, cursor.rowcount) - - cursor.arraysize= 1 - row= cursor.fetchmany() - self.assertEqual(row,[params[3]]) - self.assertEqual(4, cursor.rowcount) - - cursor.arraysize= 2 - row= cursor.fetchmany() - self.assertEqual(row,[params[4]]) - self.assertEqual(5, cursor.rowcount) - del cursor - - def test1_multi_result(self): - cursor= self.connection.cursor() - sql= """ - CREATE OR REPLACE PROCEDURE p1() - BEGIN - SELECT 1 FROM DUAL; - SELECT 2 FROM DUAL; - END - """ - cursor.execute(sql) - cursor.execute("call p1()") - row= cursor.fetchone() - self.assertEqual(row[0], 1) - cursor.nextset() - row= cursor.fetchone() - self.assertEqual(row[0], 2) - del cursor - - def test_buffered(self): - cursor= self.connection.cursor() - cursor.execute("SELECT 1 UNION SELECT 2 UNION SELECT 3", buffered=True) - self.assertEqual(cursor.rowcount, 3) - cursor.scroll(1) - row= cursor.fetchone() - self.assertEqual(row[0],2) - del cursor - - def test_xfield_types(self): - cursor= self.connection.cursor() - fieldinfo= mariadb.fieldinfo() - cursor.execute("CREATE OR REPLACE TABLE t1 (a tinyint not null auto_increment primary key, b smallint, c int, d bigint, e float, f decimal, g double, h char(10), i varchar(255), j blob, index(b))"); - info= cursor.description - self.assertEqual(info, None) - cursor.execute("SELECT * FROM t1") - info= cursor.description - self.assertEqual(fieldinfo.type(info[0]), "TINY") - self.assertEqual(fieldinfo.type(info[1]), "SHORT") - self.assertEqual(fieldinfo.type(info[2]), "LONG") - self.assertEqual(fieldinfo.type(info[3]), "LONGLONG") - self.assertEqual(fieldinfo.type(info[4]), "FLOAT") - self.assertEqual(fieldinfo.type(info[5]), "NEWDECIMAL") - self.assertEqual(fieldinfo.type(info[6]), "DOUBLE") - self.assertEqual(fieldinfo.type(info[7]), "STRING") - self.assertEqual(fieldinfo.type(info[8]), "VAR_STRING") - self.assertEqual(fieldinfo.type(info[9]), "BLOB") - self.assertEqual(fieldinfo.flag(info[0]), "NOT_NULL | PRIMARY_KEY | AUTO_INCREMENT | NUMERIC") - self.assertEqual(fieldinfo.flag(info[1]), "PART_KEY | NUMERIC") - self.assertEqual(fieldinfo.flag(info[9]), "BLOB | BINARY") - del cursor - - def test_bulk_delete(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE bulk_delete (id int, name varchar(64), city varchar(64))"); - params= [(1, u"Jack", u"Boston"), - (2, u"Martin", u"Ohio"), - (3, u"James", u"Washington"), - (4, u"Rasmus", u"Helsinki"), - (5, u"Andrey", u"Sofia")] - cursor.executemany("INSERT INTO bulk_delete VALUES (?,?,?)", params) - self.assertEqual(cursor.rowcount, 5) - params= [(1,2)] - cursor.executemany("DELETE FROM bulk_delete WHERE id=?", params) - self.assertEqual(cursor.rowcount, 2) - - - def test_named_tuple(self): - cursor= self.connection.cursor(named_tuple=1) - cursor.execute("CREATE OR REPLACE TABLE t1 (id int, name varchar(64), city varchar(64))"); - params= [(1, u"Jack", u"Boston"), - (2, u"Martin", u"Ohio"), - (3, u"James", u"Washington"), - (4, u"Rasmus", u"Helsinki"), - (5, u"Andrey", u"Sofia")] - cursor.executemany("INSERT INTO t1 VALUES (?,?,?)", params); - cursor.execute("SELECT * FROM t1 ORDER BY id") - row= cursor.fetchone() - - self.assertEqual(cursor.statement, "SELECT * FROM t1 ORDER BY id") - self.assertEqual(row.id, 1) - self.assertEqual(row.name, "Jack") - self.assertEqual(row.city, "Boston") - del cursor - - def test_laststatement(self): - cursor= self.connection.cursor(named_tuple=1) - cursor.execute("CREATE OR REPLACE TABLE t1 (id int, name varchar(64), city varchar(64))"); - self.assertEqual(cursor.statement, "CREATE OR REPLACE TABLE t1 (id int, name varchar(64), city varchar(64))") - - params= [(1, u"Jack", u"Boston"), - (2, u"Martin", u"Ohio"), - (3, u"James", u"Washington"), - (4, u"Rasmus", u"Helsinki"), - (5, u"Andrey", u"Sofia")] - cursor.executemany("INSERT INTO t1 VALUES (?,?,?)", params); - cursor.execute("SELECT * FROM t1 ORDER BY id") - self.assertEqual(cursor.statement, "SELECT * FROM t1 ORDER BY id") - del cursor - - def test_multi_cursor(self): - cursor= self.connection.cursor() - cursor1= self.connection.cursor(cursor_type=1) - cursor2= self.connection.cursor(cursor_type=1) - - cursor.execute("CREATE OR REPLACE TABLE t1 (a int)") - cursor.execute("INSERT INTO t1 VALUES (1),(2),(3),(4),(5),(6),(7),(8)") - del cursor - - cursor1.execute("SELECT a FROM t1 ORDER BY a") - cursor2.execute("SELECT a FROM t1 ORDER BY a DESC") - - for i in range (0,8): - self.assertEqual(cursor1.rownumber, i) - row1= cursor1.fetchone() - row2= cursor2.fetchone() - self.assertEqual(cursor1.rownumber, cursor2.rownumber) - self.assertEqual(row1[0]+row2[0], 9) - - del cursor1 - del cursor2 - - def test_connection_attr(self): - cursor= self.connection.cursor() - self.assertEqual(cursor.connection, self.connection) - del cursor - - def test_dbapi_type(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1 (a int, b varchar(20), c blob, d datetime, e decimal)") - cursor.execute("INSERT INTO t1 VALUES (1, 'foo', 'blabla', now(), 10.2)"); - cursor.execute("SELECT * FROM t1 ORDER BY a") - expected_typecodes= [ - mariadb.NUMBER, - mariadb.STRING, - mariadb.BINARY, - mariadb.DATETIME, - mariadb.NUMBER - ] - row= cursor.fetchone() - typecodes= [row[1] for row in cursor.description] - self.assertEqual(expected_typecodes, typecodes) - del cursor - - def test_tuple(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE dyncol1 (a blob)"); - tpl=(1,2,3) - cursor.execute("INSERT INTO dyncol1 VALUES (?)", tpl); - del cursor - - def test_indicator(self): - if self.connection.server_version < 100206: - self.skipTest("Requires server version >= 10.2.6") - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE ind1 (a int, b int default 2,c int)"); - vals= (mariadb.indicator_null, mariadb.indicator_default, 3) - cursor.executemany("INSERT INTO ind1 VALUES (?,?,?)", [vals]) - cursor.execute("SELECT a, b, c FROM ind1") - row= cursor.fetchone() - self.assertEqual(row[0], None) - self.assertEqual(row[1], 2) - self.assertEqual(row[2], 3) - - def test_tuple(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE dyncol1 (a blob)"); - t= datetime.datetime(2018,6,20,12,22,31,123456) - val=([1,t,3,(1,2,3)],) - cursor.execute("INSERT INTO dyncol1 VALUES (?)", val); - cursor.execute("SELECT a FROM dyncol1") - row= cursor.fetchone() - self.assertEqual(row,val); - del cursor - - def test_set(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE dyncol1 (a blob)"); - t= datetime.datetime(2018,6,20,12,22,31,123456) - a= collections.OrderedDict([('apple', 4), ('banana', 3), ('orange', 2), ('pear', 1), ('4',3), (4,4)]) - val=([1,t,3,(1,2,3), {1,2,3},a],) - cursor.execute("INSERT INTO dyncol1 VALUES (?)", val); - cursor.execute("SELECT a FROM dyncol1") - row= cursor.fetchone() - self.assertEqual(row,val); - del cursor - - def test_reset(self): - cursor= self.connection.cursor() - cursor.execute("SELECT 1 UNION SELECT 2", buffered=False) - cursor.execute("SELECT 1 UNION SELECT 2") - del cursor - - def test_fake_pickle(self): - cursor= self.connection.cursor() - cursor.execute("create or replace table t1 (a blob)") - k=bytes([0x80,0x03,0x00,0x2E]) - cursor.execute("insert into t1 values (?)", (k,)) - cursor.execute("select * from t1"); - row= cursor.fetchone() - self.assertEqual(row[0],k) - del cursor - - def test_no_result(self): - cursor= self.connection.cursor() - cursor.execute("set @a:=1") - try: - row= cursor.fetchone() - except mariadb.ProgrammingError: - pass - del cursor - - def test_collate(self): - cursor= self.connection.cursor() - cursor.execute("CREATE TABLE IF NOT EXISTS `tt` (`test` varchar(500) COLLATE utf8mb4_unicode_ci NOT NULL) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci") - cursor.execute("SET NAMES utf8mb4") - cursor.execute("SELECT * FROM `tt` WHERE `test` LIKE 'jj' COLLATE utf8mb4_unicode_ci") - del cursor - - def test_conpy_8(self): - cursor=self.connection.cursor() - sql= """ - CREATE OR REPLACE PROCEDURE p1() - BEGIN - SELECT 1 FROM DUAL UNION SELECT 0 FROM DUAL; - SELECT 2 FROM DUAL; - END - """ - cursor.execute(sql) - cursor.execute("call p1()") - - cursor.nextset() - row= cursor.fetchone() - self.assertEqual(row[0],2); - del cursor - - def test_conpy_7(self): - cursor=self.connection.cursor() - stmt= "SELECT 1 UNION SELECT 2 UNION SELECT 3 UNION SELECT 4" - cursor.execute(stmt, buffered=True) - cursor.scroll(2, mode='relative') - row= cursor.fetchone() - self.assertEqual(row[0],3) - cursor.scroll(-2, mode='relative') - row= cursor.fetchone() - del cursor - - def test_compy_9(self): - cursor=self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1 (a varchar(20), b double(6,3), c double)"); - cursor.execute("INSERT INTO t1 VALUES ('€uro', 123.345, 12345.678)") - cursor.execute("SELECT a,b,c FROM t1") - cursor.fetchone() - d= cursor.description; - self.assertEqual(d[0][2], 4); # 4 code points only - self.assertEqual(d[0][3], -1); # variable length - self.assertEqual(d[1][2], 7); # length=precision + 1 - self.assertEqual(d[1][4], 6); # precision - self.assertEqual(d[1][5], 3); # decimals - del cursor - - def test_conpy_15(self): - cursor=self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1 (a int not null auto_increment primary key, b varchar(20))"); - self.assertEqual(cursor.lastrowid, 0) - cursor.execute("INSERT INTO t1 VALUES (null, 'foo')") - self.assertEqual(cursor.lastrowid, 1) - cursor.execute("SELECT LAST_INSERT_ID()") - row= cursor.fetchone() - self.assertEqual(row[0], 1) - vals= [(3, "bar"), (4, "this")] - cursor.executemany("INSERT INTO t1 VALUES (?,?)", vals) - self.assertEqual(cursor.lastrowid, 4) - # Bug MDEV-16847 - # cursor.execute("SELECT LAST_INSERT_ID()") - # row= cursor.fetchone() - # self.assertEqual(row[0], 4) - - # Bug MDEV-16593 - # vals= [(None, "bar"), (None, "foo")] - # cursor.executemany("INSERT INTO t1 VALUES (?,?)", vals) - # self.assertEqual(cursor.lastrowid, 6) - del cursor - - def test_conpy_14(self): - cursor=self.connection.cursor() - self.assertEqual(cursor.rowcount, -1) - cursor.execute("CREATE OR REPLACE TABLE t1 (a int not null auto_increment primary key, b varchar(20))"); - self.assertEqual(cursor.rowcount, -1) - cursor.execute("INSERT INTO t1 VALUES (null, 'foo')") - self.assertEqual(cursor.rowcount, 1) - vals= [(3, "bar"), (4, "this")] - cursor.executemany("INSERT INTO t1 VALUES (?,?)", vals) - self.assertEqual(cursor.rowcount, 2) - del cursor - - def test_closed(self): - cursor= self.connection.cursor() - cursor.close() - cursor.close() - self.assertEqual(cursor.closed, True) - try: - cursor.execute("set @a:=1") - except mariadb.ProgrammingError: - pass - del cursor - - def test_emptycursor(self): - cursor= self.connection.cursor() - try: - cursor.execute("") - except mariadb.DatabaseError: - pass - del cursor - - def test_iterator(self): - cursor= self.connection.cursor() - cursor.execute("select 1 union select 2 union select 3 union select 4 union select 5") - for i, row in enumerate(cursor): - self.assertEqual(i+1, cursor.rownumber) - self.assertEqual(i+1, row[0]) - - def test_update_bulk(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1 (a int primary key, b int)") - vals= [(i,) for i in range(1000)] - cursor.executemany("INSERT INTO t1 VALUES (?, NULL)", vals); - self.assertEqual(cursor.rowcount, 1000) - self.connection.autocommit= False - cursor.executemany("UPDATE t1 SET b=2 WHERE a=?", vals); - self.connection.commit() - self.assertEqual(cursor.rowcount, 1000) - self.connection.autocommit= True - del cursor - - def test_multi_execute(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1 (a int auto_increment primary key, b int)") - self.connection.autocommit= False - for i in range(1,1000): - cursor.execute("INSERT INTO t1 VALUES (?,1)", (i,)) - self.connection.autocommit= True - del cursor - - - def test_conpy21(self): - conn= mariadb.connection(default_file='default.cnf') - cursor=conn.cursor() - self.assertFalse(cursor.closed) - conn.close() - self.assertTrue(cursor.closed) - del cursor, conn - - - diff --git a/test/cursor_mariadb.py b/test/cursor_mariadb.py deleted file mode 100644 index 72763bd..0000000 --- a/test/cursor_mariadb.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python -O - -import mariadb -import datetime -import unittest - -class CursorTest(unittest.TestCase): - - def setUp(self): - self.connection= mariadb.connection(default_file='default.cnf') - - def tearDown(self): - self.connection.close() - del self.connection - - def test_insert_parameter(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1(a int not null auto_increment primary key, b int, c int, d varchar(20),e date)") -# cursor.execute("set @@autocommit=0"); - list_in= [] - for i in range(1, 300001): - row= (i,i,i,"bar", datetime.date(2019,1,1)) - list_in.append(row) - cursor.executemany("INSERT INTO t1 VALUES (?,?,?,?,?)", list_in) - self.assertEqual(len(list_in), cursor.rowcount) - self.connection.commit() - cursor.execute("SELECT * FROM t1 order by a") - list_out= cursor.fetchall() - self.assertEqual(len(list_in), cursor.rowcount); - self.assertEqual(list_in,list_out) - cursor.close() - - def test_update_parameter(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1(a int not null auto_increment primary key, b int, c int, d varchar(20),e date)") - cursor.execute("set @@autocommit=0"); - list_in= [] - for i in range(1, 300001): - row= (i,i,i,"bar", datetime.date(2019,1,1)) - list_in.append(row) - cursor.executemany("INSERT INTO t1 VALUES (?,?,?,?,?)", list_in) - self.assertEqual(len(list_in), cursor.rowcount) - self.connection.commit() - cursor.close() - list_update= []; - - cursor= self.connection.cursor() - cursor.execute("set @@autocommit=0"); - for i in range(1, 300001): - row= (i+1, i); - list_update.append(row); - - cursor.executemany("UPDATE t1 SET b=? WHERE a=?", list_update); - self.assertEqual(cursor.rowcount, 300000) - self.connection.commit(); - cursor.close() - - def test_delete_parameter(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1(a int not null auto_increment primary key, b int, c int, d varchar(20),e date)") - cursor.execute("set @@autocommit=0"); - list_in= [] - for i in range(1, 300001): - row= (i,i,i,"bar", datetime.date(2019,1,1)) - list_in.append(row) - cursor.executemany("INSERT INTO t1 VALUES (?,?,?,?,?)", list_in) - self.assertEqual(len(list_in), cursor.rowcount) - self.connection.commit() - cursor.close() - list_delete= []; - - cursor= self.connection.cursor() - cursor.execute("set @@autocommit=0"); - for i in range(1, 300001): - list_delete.append((i,)); - - cursor.executemany("DELETE FROM t1 WHERE a=?", list_delete); - self.assertEqual(cursor.rowcount, 300000) - self.connection.commit(); - cursor.close() diff --git a/test/cursor_mysql.py b/test/cursor_mysql.py deleted file mode 100644 index d1c1c4c..0000000 --- a/test/cursor_mysql.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python -O - -import mysql.connector -import datetime -import unittest - -class CursorTest(unittest.TestCase): - - def setUp(self): - self.connection= mysql.connector.connect(user='root', database='test') - - def tearDown(self): - self.connection.close() - del self.connection - - def test_parameter(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1(a int auto_increment primary key not null, b int, c int, d varchar(20),e date)") - cursor.execute("SET @@autocommit=0"); - c = (1,2,3, "bar", datetime.date(2018,11,11)) - list_in= [] - for i in range(1,300001): - row= (i,i,i,"bar", datetime.date(2019,1,1)) - list_in.append(row) - cursor.executemany("INSERT INTO t1 VALUES (%s,%s,%s,%s,%s)", list_in) - print("rows inserted:", len(list_in)) - self.connection.commit() - cursor.execute("SELECT * FROM t1 order by a") - list_out= cursor.fetchall() - print("rows fetched: ", len(list_out)) - self.assertEqual(list_in,list_out) - - cursor.close() - diff --git a/test/dbapi20.py b/test/dbapi20.py deleted file mode 100644 index 6b006bf..0000000 --- a/test/dbapi20.py +++ /dev/null @@ -1,787 +0,0 @@ -#!/usr/bin/env python -''' Python DB API 2.0 driver compliance unit test suite. - - This software is Public Domain and may be used without restrictions. - - "Now we have booze and barflies entering the discussion, plus rumours of - DBAs on drugs... and I won't tell you what flashes through my mind each - time I read the subject line with 'Anal Compliance' in it. All around - this is turning out to be a thoroughly unwholesome unit test." - - -- Ian Bicking -''' - -__rcs_id__ = '$Id$' -__version__ = '$Revision$'[11:-2] -__author__ = 'Stuart Bishop ' - -import unittest -import time -import mariadb - -# $Log$ -# Revision 1.1.2.1 2006/02/25 03:44:32 adustman -# Generic DB-API unit test module -# -# Revision 1.10 2003/10/09 03:14:14 zenzen -# Add test for DB API 2.0 optional extension, where database exceptions -# are exposed as attributes on the Connection object. -# -# Revision 1.9 2003/08/13 01:16:36 zenzen -# Minor tweak from Stefan Fleiter -# -# Revision 1.8 2003/04/10 00:13:25 zenzen -# Changes, as per suggestions by M.-A. Lemburg -# - Add a table prefix, to ensure namespace collisions can always be avoided -# -# Revision 1.7 2003/02/26 23:33:37 zenzen -# Break out DDL into helper functions, as per request by David Rushby -# -# Revision 1.6 2003/02/21 03:04:33 zenzen -# Stuff from Henrik Ekelund: -# added test_None -# added test_nextset & hooks -# -# Revision 1.5 2003/02/17 22:08:43 zenzen -# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize -# defaults to 1 & generic cursor.callproc test added -# -# Revision 1.4 2003/02/15 00:16:33 zenzen -# Changes, as per suggestions and bug reports by M.-A. Lemburg, -# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar -# - Class renamed -# - Now a subclass of TestCase, to avoid requiring the driver stub -# to use multiple inheritance -# - Reversed the polarity of buggy test in test_description -# - Test exception hierarchy correctly -# - self.populate is now self._populate(), so if a driver stub -# overrides self.ddl1 this change propagates -# - VARCHAR columns now have a width, which will hopefully make the -# DDL even more portible (this will be reversed if it causes more problems) -# - cursor.rowcount being checked after various execute and fetchXXX methods -# - Check for fetchall and fetchmany returning empty lists after results -# are exhausted (already checking for empty lists if select retrieved -# nothing -# - Fix bugs in test_setoutputsize_basic and test_setinputsizes -# - -class DatabaseAPI20Test(unittest.TestCase): - ''' Test a database self.driver for DB API 2.0 compatibility. - This implementation tests Gadfly, but the TestCase - is structured so that other self.drivers can subclass this - test case to ensure compiliance with the DB-API. It is - expected that this TestCase may be expanded in the future - if ambiguities or edge conditions are discovered. - - The 'Optional Extensions' are not yet being tested. - - self.drivers should subclass this test, overriding setUp, tearDown, - self.driver, connect_args and connect_kw_args. Class specification - should be as follows: - - import dbapi20 - class mytest(dbapi20.DatabaseAPI20Test): - [...] - - Don't 'import DatabaseAPI20Test from dbapi20', or you will - confuse the unit tester - just 'import dbapi20'. - ''' - - # The self.driver module. This should be the module where the 'connect' - # method is to be found - driver = mariadb - connect_args = () # List of arguments to pass to connect - connect_kw_args = {"user=root", "database=test"} # Keyword arguments for connect - table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables - - ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix - ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix - xddl1 = 'drop table %sbooze' % table_prefix - xddl2 = 'drop table %sbarflys' % table_prefix - - lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase - - # Some drivers may need to override these helpers, for example adding - # a 'commit' after the execute. - def executeDDL1(self,cursor): - cursor.execute(self.ddl1) - - def executeDDL2(self,cursor): - cursor.execute(self.ddl2) - - def setUp(self): - ''' self.drivers should override this method to perform required setup - if any is necessary, such as creating the database. - ''' - pass - - def tearDown(self): - ''' self.drivers should override this method to perform required cleanup - if any is necessary, such as deleting the test database. - The default drops the tables that may be created. - ''' - con = self._connect() - try: - cur = con.cursor() - for ddl in (self.xddl1,self.xddl2): - try: - cur.execute(ddl) - con.commit() - except self.driver.Error: - # Assume table didn't exist. Other tests will check if - # execute is busted. - pass - finally: - con.close() - - def _connect(self): - try: - return self.driver.connect(default_file='default.cnf') - except AttributeError: - self.fail("No connect method found in self.driver module") - - def test_connect(self): - con = self._connect() - con.close() - - def test_apilevel(self): - try: - # Must exist - apilevel = self.driver.apilevel - # Must equal 2.0 - self.assertEqual(apilevel,'2.0') - except AttributeError: - self.fail("Driver doesn't define apilevel") - - def test_threadsafety(self): - try: - # Must exist - threadsafety = self.driver.threadsafety - # Must be a valid value - self.assertTrue(threadsafety in (0,1,2,3)) - except AttributeError: - self.fail("Driver doesn't define threadsafety") - - def test_paramstyle(self): - try: - # Must exist - paramstyle = self.driver.paramstyle - # Must be a valid value - self.assertTrue(paramstyle in ( - 'qmark','numeric','named','format','pyformat' - )) - except AttributeError: - self.fail("Driver doesn't define paramstyle") - - def test_Exceptions(self): - # Make sure required exceptions exist, and are in the - # defined hierarchy. - self.assertTrue(issubclass(self.driver.Warning,Exception)) - self.assertTrue(issubclass(self.driver.Error,Exception)) - self.assertTrue( - issubclass(self.driver.InterfaceError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.DatabaseError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.OperationalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.IntegrityError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.InternalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.ProgrammingError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.NotSupportedError,self.driver.Error) - ) - - def test_ExceptionsAsConnectionAttributes(self): - # OPTIONAL EXTENSION - # Test for the optional DB API 2.0 extension, where the exceptions - # are exposed as attributes on the Connection object - # I figure this optional extension will be implemented by any - # driver author who is using this test suite, so it is enabled - # by default. - con = self._connect() - drv = self.driver - self.assertTrue(con.Warning is drv.Warning) - self.assertTrue(con.Error is drv.Error) - self.assertTrue(con.InterfaceError is drv.InterfaceError) - self.assertTrue(con.DatabaseError is drv.DatabaseError) - self.assertTrue(con.OperationalError is drv.OperationalError) - self.assertTrue(con.IntegrityError is drv.IntegrityError) - self.assertTrue(con.InternalError is drv.InternalError) - self.assertTrue(con.ProgrammingError is drv.ProgrammingError) - self.assertTrue(con.NotSupportedError is drv.NotSupportedError) - - - def test_commit(self): - con = self._connect() - try: - # Commit must work, even if it doesn't do anything - con.commit() - finally: - con.close() - - def test_rollback(self): - con = self._connect() - # If rollback is defined, it should either work or throw - # the documented exception - if hasattr(con,'rollback'): - try: - con.rollback() - except self.driver.NotSupportedError: - pass - - def test_cursor(self): - con = self._connect() - try: - cur = con.cursor() - finally: - con.close() - - def test_cursor_isolation(self): - con = self._connect() - try: - # Make sure cursors created from the same connection have - # the documented transaction isolation level - cur1 = con.cursor() - cur2 = con.cursor() - self.executeDDL1(cur1) - cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - cur2.execute("select name from %sbooze" % self.table_prefix) - booze = cur2.fetchall() - self.assertEqual(len(booze),1) - self.assertEqual(len(booze[0]),1) - self.assertEqual(booze[0][0],'Victoria Bitter') - finally: - con.close() - - def test_description(self): - con = self._connect() - try: - cur = con.cursor() - self.executeDDL1(cur) - self.assertEqual(cur.description,None, - 'cursor.description should be none after executing a ' - 'statement that can return no rows (such as DDL)' - ) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(len(cur.description),1, - 'cursor.description describes too many columns' - ) - self.assertEqual(len(cur.description[0]),8, - 'cursor.description[x] tuples must have 8 elements' - ) - self.assertEqual(cur.description[0][0].lower(),'name', - 'cursor.description[x][0] must return column name' - ) - self.assertEqual(cur.description[0][1],self.driver.STRING, - 'cursor.description[x][1] must return column type. Got %r' - % cur.description[0][1] - ) - - # Make sure self.description gets reset - self.executeDDL2(cur) - self.assertEqual(cur.description,None, - 'cursor.description not being set to None when executing ' - 'no-result statements (eg. DDL)' - ) - finally: - con.close() - - def test_rowcount(self): - con = self._connect() - try: - cur = con.cursor() - self.executeDDL1(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount should be -1 after executing no-result ' - 'statements' - ) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number or rows inserted, or ' - 'set to -1 after executing an insert statement' - ) - cur.execute("select name from %sbooze" % self.table_prefix, buffered=True) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) - self.executeDDL2(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount not being reset to -1 after executing ' - 'no-result statements' - ) - finally: - con.close() - - lower_func = 'lower' - - def test_close(self): - con = self._connect() - try: - cur = con.cursor() - finally: - con.close() - - # cursor.execute should raise an Error if called after connection - # closed - self.assertRaises(self.driver.Error,self.executeDDL1,cur) - - # connection.commit should raise an Error if called after connection' - # closed.' - self.assertRaises(self.driver.Error,con.commit) - - # connection.close should raise an Error if called more than once - self.assertRaises(self.driver.Error,con.close) - - def test_execute(self): - con = self._connect() - try: - cur = con.cursor() - self._paraminsert(cur) - finally: - con.close() - - def _paraminsert(self,cur): - self.executeDDL1(cur) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1)) - - if self.driver.paramstyle == 'qmark': - cur.execute( - 'insert into %sbooze values (?)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'numeric': - cur.execute( - 'insert into %sbooze values (:1)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'named': - cur.execute( - 'insert into %sbooze values (:beer)' % self.table_prefix, - {'beer':"Cooper's"} - ) - elif self.driver.paramstyle == 'format': - cur.execute( - 'insert into %sbooze values (%%s)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'pyformat': - cur.execute( - 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, - {'beer':"Cooper's"} - ) - else: - self.fail('Invalid paramstyle') - self.assertTrue(cur.rowcount in (-1,1)) - - cur.execute('select name from %sbooze' % self.table_prefix) - res = cur.fetchall() - self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') - beers = [res[0][0],res[1][0]] - beers.sort() - self.assertEqual(beers[0],"Cooper's", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) - self.assertEqual(beers[1],"Victoria Bitter", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) - - def test_executemany(self): - con = self._connect() - try: - cur = con.cursor() - self.executeDDL1(cur) - largs = [ ("Cooper's",) , ("Boag's",) ] - margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] - if self.driver.paramstyle == 'qmark': - cur.executemany( - 'insert into %sbooze values (?)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'numeric': - cur.executemany( - 'insert into %sbooze values (:1)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'named': - cur.executemany( - 'insert into %sbooze values (:beer)' % self.table_prefix, - margs - ) - elif self.driver.paramstyle == 'format': - cur.executemany( - 'insert into %sbooze values (%%s)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'pyformat': - cur.executemany( - 'insert into %sbooze values (%%(beer)s)' % ( - self.table_prefix - ), - margs - ) - else: - self.fail('Unknown paramstyle') - self.assertTrue(cur.rowcount in (-1,2), - 'insert using cursor.executemany set cursor.rowcount to ' - 'incorrect value %r' % cur.rowcount - ) - cur.execute('select name from %sbooze' % self.table_prefix) - res = cur.fetchall() - self.assertEqual(len(res),2, - 'cursor.fetchall retrieved incorrect number of rows' - ) - beers = [res[0][0],res[1][0]] - beers.sort() - self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') - self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') - finally: - con.close() - - def test_fetchone(self): - con = self._connect() - try: - cur = con.cursor() - - # cursor.fetchone should raise an Error if called before - # executing a select-type query - self.assertRaises(self.driver.Error,cur.fetchone) - - # cursor.fetchone should raise an Error if called after - # executing a query that cannot return rows - self.executeDDL1(cur) - self.assertRaises(self.driver.Error,cur.fetchone) - - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) - - # cursor.fetchone should raise an Error if called after - # executing a query that cannot return rows - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertRaises(self.driver.Error,cur.fetchone) - - cur.execute('select name from %sbooze' % self.table_prefix, buffered=True) - r = cur.fetchone() - self.assertEqual(len(r),1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0],'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if no more rows available' - ) - self.assertTrue(cur.rowcount in (-1,1)) - finally: - con.close() - - samples = [ - 'Carlton Cold', - 'Carlton Draft', - 'Mountain Goat', - 'Redback', - 'Victoria Bitter', - 'XXXX' - ] - - def _populate(self): - ''' Return a list of sql commands to setup the DB for the fetch - tests. - ''' - populate = [ - "insert into %sbooze values ('%s')" % (self.table_prefix,s) - for s in self.samples - ] - return populate - - def test_fetchmany(self): - con = self._connect() - try: - cur = con.cursor() - - # cursor.fetchmany should raise an Error if called without - #issuing a query - self.assertRaises(self.driver.Error,cur.fetchmany,4) - - self.executeDDL1(cur) - for sql in self._populate(): - cur.execute(sql) - - cur.execute('select name from %sbooze' % self.table_prefix) - r = cur.fetchmany() - self.assertEqual(len(r),1, - 'cursor.fetchmany retrieved incorrect number of rows, ' - 'default of arraysize is one.' - ) - cur.arraysize=10 - r = cur.fetchmany(3) # Should get 3 rows - self.assertEqual(len(r),3, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should get 2 more - self.assertEqual(len(r),2, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should be an empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence after ' - 'results are exhausted' - ) - self.assertTrue(cur.rowcount in (-1,6)) - - # Same as above, using cursor.arraysize - cur.arraysize=4 - cur.execute('select name from %sbooze' % self.table_prefix) - r = cur.fetchmany() # Should get 4 rows - self.assertEqual(len(r),4, - 'cursor.arraysize not being honoured by fetchmany' - ) - r = cur.fetchmany() # Should get 2 more - self.assertEqual(len(r),2) - r = cur.fetchmany() # Should be an empty sequence - self.assertEqual(len(r),0) - self.assertTrue(cur.rowcount in (-1,6)) - - cur.arraysize=6 - cur.execute('select name from %sbooze' % self.table_prefix) - rows = cur.fetchmany() # Should get all rows - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows),6) - self.assertEqual(len(rows),6) - rows = [r[0] for r in rows] - rows.sort() - - # Make sure we get the right data back out - for i in range(0,6): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved by cursor.fetchmany' - ) - - rows = cur.fetchmany() # Should return an empty list - self.assertEqual(len(rows),0, - 'cursor.fetchmany should return an empty sequence if ' - 'called after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,6)) - - self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) - r = cur.fetchmany() # Should get empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence if ' - 'query retrieved no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) - - finally: - con.close() - - def test_fetchall(self): - con = self._connect() - try: - cur = con.cursor() - # cursor.fetchall should raise an Error if called - # without executing a query that may return rows (such - # as a select) - self.assertRaises(self.driver.Error, cur.fetchall) - - self.executeDDL1(cur) - for sql in self._populate(): - cur.execute(sql) - - # cursor.fetchall should raise an Error if called - # after executing a a statement that cannot return rows - self.assertRaises(self.driver.Error,cur.fetchall) - - cur.execute('select name from %sbooze' % self.table_prefix) - rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - self.assertEqual(len(rows),len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) - rows = [r[0] for r in rows] - rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'cursor.fetchall retrieved incorrect rows' - ) - rows = cur.fetchall() - self.assertEqual( - len(rows),0, - 'cursor.fetchall should return an empty list if called ' - 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - - self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) - rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,0)) - self.assertEqual(len(rows),0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) - - finally: - con.close() - - def test_mixedfetch(self): - con = self._connect() - try: - cur = con.cursor() - self.executeDDL1(cur) - for sql in self._populate(): - cur.execute(sql) - - cur.execute('select name from %sbooze' % self.table_prefix) - rows1 = cur.fetchone() - rows23 = cur.fetchmany(2) - rows4 = cur.fetchone() - rows56 = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows23),2, - 'fetchmany returned incorrect number of rows' - ) - self.assertEqual(len(rows56),2, - 'fetchall returned incorrect number of rows' - ) - - rows = [rows1[0]] - rows.extend([rows23[0][0],rows23[1][0]]) - rows.append(rows4[0]) - rows.extend([rows56[0][0],rows56[1][0]]) - rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved or inserted' - ) - finally: - con.close() - - def help_nextset_setUp(self,cur): - ''' Should create a procedure called deleteme - that returns two result sets, first the - number of rows in booze then "name from booze" - ''' - raise NotImplementedError('Helper not implemented') - #sql=""" - # create procedure deleteme as - # begin - # select count(*) from booze - # select name from booze - # end - #""" - #cur.execute(sql) - - def help_nextset_tearDown(self,cur): - 'If cleaning up is needed after nextSetTest' - raise NotImplementedError('Helper not implemented') - #cur.execute("drop procedure deleteme") - - def test_arraysize(self): - # Not much here - rest of the tests for this are in test_fetchmany - con = self._connect() - try: - cur = con.cursor() - self.assertTrue(hasattr(cur,'arraysize'), - 'cursor.arraysize must be defined' - ) - finally: - con.close() - - def test_setinputsizes(self): - con = self._connect() - try: - cur = con.cursor() - cur.setinputsizes( (25,) ) - self._paraminsert(cur) # Make sure cursor still works - finally: - con.close() - - def test_None(self): - con = self._connect() - try: - cur = con.cursor() - self.executeDDL1(cur) - cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) - cur.execute('select name from %sbooze' % self.table_prefix) - r = cur.fetchall() - self.assertEqual(len(r),1) - self.assertEqual(len(r[0]),1) - self.assertEqual(r[0][0],None,'NULL value not returned as None') - finally: - con.close() - - def test_Date(self): - d1 = self.driver.Date(2002,12,25) - d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) - # Can we assume this? API doesn't specify, but it seems implied - # self.assertEqual(str(d1),str(d2)) - - def test_Time(self): - t1 = self.driver.Time(13,45,30) - t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) - # Can we assume this? API doesn't specify, but it seems implied - # self.assertEqual(str(t1),str(t2)) - - def test_Timestamp(self): - t1 = self.driver.Timestamp(2002,12,25,13,45,30) - t2 = self.driver.TimestampFromTicks( - time.mktime((2002,12,25,13,45,30,0,0,0)) - ) - # Can we assume this? API doesn't specify, but it seems implied - # self.assertEqual(str(t1),str(t2)) - - def test_Binary(self): - b = self.driver.Binary(b'Something') - b = self.driver.Binary(b'') - - def test_STRING(self): - self.assertTrue(hasattr(self.driver,'STRING'), - 'module.STRING must be defined' - ) - - def test_BINARY(self): - self.assertTrue(hasattr(self.driver,'BINARY'), - 'module.BINARY must be defined.' - ) - - def test_NUMBER(self): - self.assertTrue(hasattr(self.driver,'NUMBER'), - 'module.NUMBER must be defined.' - ) - - def test_DATETIME(self): - self.assertTrue(hasattr(self.driver,'DATETIME'), - 'module.DATETIME must be defined.' - ) - - def test_ROWID(self): - self.assertTrue(hasattr(self.driver,'ROWID'), - 'module.ROWID must be defined.' - ) - diff --git a/test/default.cnf b/test/default.cnf deleted file mode 100644 index c3eb0a0..0000000 --- a/test/default.cnf +++ /dev/null @@ -1,5 +0,0 @@ -[client] -host=127.0.0.1 -port=3306 -user=root -database=test diff --git a/test/integration/__init__.py b/test/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/integration/test_connection.py b/test/integration/test_connection.py new file mode 100644 index 0000000..1fbf371 --- /dev/null +++ b/test/integration/test_connection.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python -O +# -*- coding: utf-8 -*- + +import os +import unittest + +import mariadb + +from test.base_test import create_connection +from test.conf_test import conf + + +class TestConnection(unittest.TestCase): + + def setUp(self): + self.connection = create_connection() + + def tearDown(self): + del self.connection + + def test_connection_default_file(self): + if os.path.exists("client.cnf"): + os.remove("client.cnf") + default_conf = conf() + f = open("client.cnf", "w+") + f.write("[client]\n") + f.write("host=%s\n" % default_conf["host"]) + f.write("port=%i\n" % default_conf["port"]) + f.write("database=%s\n" % default_conf["database"]) + f.close() + + new_conn = mariadb.connect(default_file="./client.cnf") + self.assertEqual(new_conn.database, default_conf["database"]) + del new_conn + + def test_autocommit(self): + conn = self.connection + cursor = conn.cursor() + self.assertEqual(conn.autocommit, True) + conn.autocommit = False + self.assertEqual(conn.autocommit, False) + conn.reset() + + def test_schema(self): + default_conf = conf() + conn = self.connection + self.assertEqual(conn.database, default_conf["database"]) + cursor = conn.cursor() + cursor.execute("CREATE OR REPLACE SCHEMA test1") + cursor.execute("USE test1") + self.assertEqual(conn.database, "test1") + conn.database = default_conf["database"] + self.assertEqual(conn.database, default_conf["database"]) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/integration/test_cursor.py b/test/integration/test_cursor.py new file mode 100644 index 0000000..5a17a9e --- /dev/null +++ b/test/integration/test_cursor.py @@ -0,0 +1,558 @@ +#!/usr/bin/env python -O +# -*- coding: utf-8 -*- + +import collections +import datetime +import unittest + +import mariadb + +from test.base_test import create_connection + + +class TestCursor(unittest.TestCase): + + def setUp(self): + self.connection = create_connection() + self.connection.autocommit = False + + def tearDown(self): + del self.connection + + def test_date(self): + if self.connection.server_version < 50500: + self.skipTest("microsecond not supported") + + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE test_date(c1 TIMESTAMP(6), c2 TIME(6), c3 DATETIME(6), c4 DATE)") + t = datetime.datetime(2018, 6, 20, 12, 22, 31, 123456) + c1 = t + c2 = t.time() + c3 = t + c4 = t.date() + cursor.execute("INSERT INTO test_date VALUES (?,?,?,?)", (c1, c2, c3, c4)) + + cursor.execute("SELECT c1,c2,c3,c4 FROM test_date") + row = cursor.fetchone() + self.assertEqual(row[0], c1) + self.assertEqual(row[1], c2) + self.assertEqual(row[2], c3) + self.assertEqual(row[3], c4) + cursor.close() + + def test_numbers(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE test_numbers (a tinyint unsigned, b smallint unsigned, c mediumint " + "unsigned, d int unsigned, e bigint unsigned, f double)") + c1 = 4 + c2 = 200 + c3 = 167557 + c4 = 28688817 + c5 = 7330133222578 + c6 = 3.1415925 + + cursor.execute("insert into test_numbers values (?,?,?,?,?,?)", (c1, c2, c3, c4, c5, c6)) + + cursor.execute("select * from test_numbers") + row = cursor.fetchone() + self.assertEqual(row[0], c1) + self.assertEqual(row[1], c2) + self.assertEqual(row[2], c3) + self.assertEqual(row[3], c4) + self.assertEqual(row[4], c5) + self.assertEqual(row[5], c6) + del cursor + + def test_string(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE test_string (a char(5), b varchar(100), c tinytext, " + "d mediumtext, e text, f longtext)"); + + c1 = "12345"; + c2 = "The length of this text is < 100 characters" + c3 = "This should also fit into tinytext which has a maximum of 255 characters" + c4 = 'a' * 1000; + c5 = 'b' * 6000; + c6 = 'c' * 67000; + + cursor.execute("INSERT INTO test_string VALUES (?,?,?,?,?,?)", (c1, c2, c3, c4, c5, c6)) + + cursor.execute("SELECT * from test_string") + row = cursor.fetchone() + self.assertEqual(row[0], c1) + self.assertEqual(row[1], c2) + self.assertEqual(row[2], c3) + self.assertEqual(row[3], c4) + self.assertEqual(row[4], c5) + self.assertEqual(row[5], c6) + del cursor + + def test_blob(self): + cursor = self.connection.cursor() + cursor.execute("CREATE TEMPORARY TABLE test_blob (a tinyblob, b mediumblob, c blob, " + "d longblob)") + + c1 = b'a' * 100; + c2 = b'b' * 1000; + c3 = b'c' * 10000; + c4 = b'd' * 100000; + + a = (None, None, None, None) + cursor.execute("INSERT INTO test_blob VALUES (?,?,?,?)", (c1, c2, c3, c4)) + + cursor.execute("SELECT * FROM test_blob") + row = cursor.fetchone() + self.assertEqual(row[0], c1) + self.assertEqual(row[1], c2) + self.assertEqual(row[2], c3) + self.assertEqual(row[3], c4) + del cursor + + def test_fetchmany(self): + cursor = self.connection.cursor() + cursor.execute("CREATE TEMPORARY TABLE test_fetchmany (id int, name varchar(64), " + "city varchar(64))"); + params = [(1, u"Jack", u"Boston"), + (2, u"Martin", u"Ohio"), + (3, u"James", u"Washington"), + (4, u"Rasmus", u"Helsinki"), + (5, u"Andrey", u"Sofia")] + cursor.executemany("INSERT INTO test_fetchmany VALUES (?,?,?)", params); + + # test Errors + # a) if no select was executed + self.assertRaises(mariadb.Error, cursor.fetchall) + # b ) if cursor was not executed + del cursor + cursor = self.connection.cursor() + self.assertRaises(mariadb.Error, cursor.fetchall) + + cursor.execute("SELECT id, name, city FROM test_fetchmany ORDER BY id") + self.assertEqual(0, cursor.rowcount) + row = cursor.fetchall() + self.assertEqual(row, params) + self.assertEqual(5, cursor.rowcount) + + cursor.execute("SELECT id, name, city FROM test_fetchmany ORDER BY id") + self.assertEqual(0, cursor.rowcount) + + row = cursor.fetchmany(1) + self.assertEqual(row, [params[0]]) + self.assertEqual(1, cursor.rowcount) + + row = cursor.fetchmany(2) + self.assertEqual(row, ([params[1], params[2]])) + self.assertEqual(3, cursor.rowcount) + + cursor.arraysize = 1 + row = cursor.fetchmany() + self.assertEqual(row, [params[3]]) + self.assertEqual(4, cursor.rowcount) + + cursor.arraysize = 2 + row = cursor.fetchmany() + self.assertEqual(row, [params[4]]) + self.assertEqual(5, cursor.rowcount) + del cursor + + def test1_multi_result(self): + cursor = self.connection.cursor() + sql = """ + CREATE OR REPLACE PROCEDURE p1() + BEGIN + SELECT 1 FROM DUAL; + SELECT 2 FROM DUAL; + END + """ + cursor.execute(sql) + cursor.execute("call p1()") + row = cursor.fetchone() + self.assertEqual(row[0], 1) + cursor.nextset() + row = cursor.fetchone() + self.assertEqual(row[0], 2) + del cursor + + def test_buffered(self): + cursor = self.connection.cursor() + cursor.execute("SELECT 1 UNION SELECT 2 UNION SELECT 3", buffered=True) + self.assertEqual(cursor.rowcount, 3) + cursor.scroll(1) + row = cursor.fetchone() + self.assertEqual(row[0], 2) + del cursor + + def test_xfield_types(self): + cursor = self.connection.cursor() + fieldinfo = mariadb.fieldinfo() + cursor.execute( + "CREATE TEMPORARY TABLE test_xfield_types (a tinyint not null auto_increment primary " + "key, b smallint, c int, d bigint, e float, f decimal, g double, h char(10), i varchar(255), j blob, index(b))"); + info = cursor.description + self.assertEqual(info, None) + cursor.execute("SELECT * FROM test_xfield_types") + info = cursor.description + self.assertEqual(fieldinfo.type(info[0]), "TINY") + self.assertEqual(fieldinfo.type(info[1]), "SHORT") + self.assertEqual(fieldinfo.type(info[2]), "LONG") + self.assertEqual(fieldinfo.type(info[3]), "LONGLONG") + self.assertEqual(fieldinfo.type(info[4]), "FLOAT") + self.assertEqual(fieldinfo.type(info[5]), "NEWDECIMAL") + self.assertEqual(fieldinfo.type(info[6]), "DOUBLE") + self.assertEqual(fieldinfo.type(info[7]), "STRING") + self.assertEqual(fieldinfo.type(info[8]), "VAR_STRING") + self.assertEqual(fieldinfo.type(info[9]), "BLOB") + self.assertEqual(fieldinfo.flag(info[0]), + "NOT_NULL | PRIMARY_KEY | AUTO_INCREMENT | NUMERIC") + self.assertEqual(fieldinfo.flag(info[1]), "PART_KEY | NUMERIC") + self.assertEqual(fieldinfo.flag(info[9]), "BLOB | BINARY") + del cursor + + def test_bulk_delete(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE bulk_delete (id int, name varchar(64), city varchar(64))"); + params = [(1, u"Jack", u"Boston"), + (2, u"Martin", u"Ohio"), + (3, u"James", u"Washington"), + (4, u"Rasmus", u"Helsinki"), + (5, u"Andrey", u"Sofia")] + cursor.executemany("INSERT INTO bulk_delete VALUES (?,?,?)", params) + self.assertEqual(cursor.rowcount, 5) + params = [(1, 2)] + cursor.executemany("DELETE FROM bulk_delete WHERE id=?", params) + self.assertEqual(cursor.rowcount, 2) + + def test_named_tuple(self): + cursor = self.connection.cursor(named_tuple=1) + cursor.execute( + "CREATE TEMPORARY TABLE test_named_tuple (id int, name varchar(64), city varchar(64))"); + params = [(1, u"Jack", u"Boston"), + (2, u"Martin", u"Ohio"), + (3, u"James", u"Washington"), + (4, u"Rasmus", u"Helsinki"), + (5, u"Andrey", u"Sofia")] + cursor.executemany("INSERT INTO test_named_tuple VALUES (?,?,?)", params); + cursor.execute("SELECT * FROM test_named_tuple ORDER BY id") + row = cursor.fetchone() + + self.assertEqual(cursor.statement, "SELECT * FROM test_named_tuple ORDER BY id") + self.assertEqual(row.id, 1) + self.assertEqual(row.name, "Jack") + self.assertEqual(row.city, "Boston") + del cursor + + def test_laststatement(self): + cursor = self.connection.cursor(named_tuple=1) + cursor.execute("CREATE TEMPORARY TABLE test_laststatement (id int, name varchar(64), " + "city varchar(64))"); + self.assertEqual(cursor.statement, + "CREATE TEMPORARY TABLE test_laststatement (id int, name varchar(64), city varchar(64))") + + params = [(1, u"Jack", u"Boston"), + (2, u"Martin", u"Ohio"), + (3, u"James", u"Washington"), + (4, u"Rasmus", u"Helsinki"), + (5, u"Andrey", u"Sofia")] + cursor.executemany("INSERT INTO test_laststatement VALUES (?,?,?)", params); + cursor.execute("SELECT * FROM test_laststatement ORDER BY id") + self.assertEqual(cursor.statement, "SELECT * FROM test_laststatement ORDER BY id") + del cursor + + def test_multi_cursor(self): + cursor = self.connection.cursor() + cursor1 = self.connection.cursor(cursor_type=1) + cursor2 = self.connection.cursor(cursor_type=1) + + cursor.execute("CREATE TEMPORARY TABLE test_multi_cursor (a int)") + cursor.execute("INSERT INTO test_multi_cursor VALUES (1),(2),(3),(4),(5),(6),(7),(8)") + del cursor + + cursor1.execute("SELECT a FROM test_multi_cursor ORDER BY a") + cursor2.execute("SELECT a FROM test_multi_cursor ORDER BY a DESC") + + for i in range(0, 8): + self.assertEqual(cursor1.rownumber, i) + row1 = cursor1.fetchone() + row2 = cursor2.fetchone() + self.assertEqual(cursor1.rownumber, cursor2.rownumber) + self.assertEqual(row1[0] + row2[0], 9) + + del cursor1 + del cursor2 + + def test_connection_attr(self): + cursor = self.connection.cursor() + self.assertEqual(cursor.connection, self.connection) + del cursor + + def test_dbapi_type(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE test_dbapi_type (a int, b varchar(20), c blob, d datetime, e decimal)") + cursor.execute("INSERT INTO test_dbapi_type VALUES (1, 'foo', 'blabla', now(), 10.2)"); + cursor.execute("SELECT * FROM test_dbapi_type ORDER BY a") + expected_typecodes = [ + mariadb.NUMBER, + mariadb.STRING, + mariadb.BINARY, + mariadb.DATETIME, + mariadb.NUMBER + ] + row = cursor.fetchone() + typecodes = [row[1] for row in cursor.description] + self.assertEqual(expected_typecodes, typecodes) + del cursor + + def test_tuple(self): + cursor = self.connection.cursor() + cursor.execute("CREATE TEMPORARY TABLE dyncol1 (a blob)") + tpl = (1, 2, 3) + cursor.execute("INSERT INTO dyncol1 VALUES (?)", tpl) + del cursor + + def test_indicator(self): + if self.connection.server_version < 100206: + self.skipTest("Requires server version >= 10.2.6") + cursor = self.connection.cursor() + cursor.execute("CREATE TEMPORARY TABLE ind1 (a int, b int default 2,c int)") + vals = (mariadb.indicator_null, mariadb.indicator_default, 3) + cursor.executemany("INSERT INTO ind1 VALUES (?,?,?)", [vals]) + cursor.execute("SELECT a, b, c FROM ind1") + row = cursor.fetchone() + self.assertEqual(row[0], None) + self.assertEqual(row[1], 2) + self.assertEqual(row[2], 3) + + def test_tuple2(self): + cursor = self.connection.cursor() + cursor.execute("CREATE TEMPORARY TABLE dyncol1 (a blob)"); + t = datetime.datetime(2018, 6, 20, 12, 22, 31, 123456) + val = ([1, t, 3, (1, 2, 3)],) + cursor.execute("INSERT INTO dyncol1 VALUES (?)", val); + cursor.execute("SELECT a FROM dyncol1") + row = cursor.fetchone() + self.assertEqual(row, val); + del cursor + + def test_set(self): + cursor = self.connection.cursor() + cursor.execute("CREATE TEMPORARY TABLE dyncol1 (a blob)") + t = datetime.datetime(2018, 6, 20, 12, 22, 31, 123456) + a = collections.OrderedDict( + [('apple', 4), ('banana', 3), ('orange', 2), ('pear', 1), ('4', 3), (4, 4)]) + val = ([1, t, 3, (1, 2, 3), {1, 2, 3}, a],) + cursor.execute("INSERT INTO dyncol1 VALUES (?)", val) + cursor.execute("SELECT a FROM dyncol1") + row = cursor.fetchone() + self.assertEqual(row, val) + del cursor + + def test_reset(self): + cursor = self.connection.cursor() + cursor.execute("SELECT 1 UNION SELECT 2", buffered=False) + cursor.execute("SELECT 1 UNION SELECT 2") + del cursor + + def test_fake_pickle(self): + cursor = self.connection.cursor() + cursor.execute("CREATE TEMPORARY TABLE test_fake_pickle (a blob)") + k = bytes([0x80, 0x03, 0x00, 0x2E]) + cursor.execute("insert into test_fake_pickle values (?)", (k,)) + cursor.execute("select * from test_fake_pickle"); + row = cursor.fetchone() + self.assertEqual(row[0], k) + del cursor + + def test_no_result(self): + cursor = self.connection.cursor() + cursor.execute("set @a:=1") + try: + row = cursor.fetchone() + except mariadb.ProgrammingError: + pass + del cursor + + def test_collate(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE `test_collate` (`test` varchar(500) COLLATE " + "utf8mb4_unicode_ci NOT NULL) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci") + cursor.execute("SET NAMES utf8mb4") + cursor.execute( + "SELECT * FROM `test_collate` WHERE `test` LIKE 'jj' COLLATE utf8mb4_unicode_ci") + del cursor + + def test_conpy_8(self): + cursor = self.connection.cursor() + sql = """ + CREATE OR REPLACE PROCEDURE p1() + BEGIN + SELECT 1 FROM DUAL UNION SELECT 0 FROM DUAL; + SELECT 2 FROM DUAL; + END + """ + cursor.execute(sql) + cursor.execute("call p1()") + + cursor.nextset() + row = cursor.fetchone() + self.assertEqual(row[0], 2); + del cursor + + def test_conpy_7(self): + cursor = self.connection.cursor() + stmt = "SELECT 1 UNION SELECT 2 UNION SELECT 3 UNION SELECT 4" + cursor.execute(stmt, buffered=True) + cursor.scroll(2, mode='relative') + row = cursor.fetchone() + self.assertEqual(row[0], 3) + cursor.scroll(-2, mode='relative') + row = cursor.fetchone() + del cursor + + def test_compy_9(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE test_compy_9 (a varchar(20), b double(6,3), c double)"); + cursor.execute("INSERT INTO test_compy_9 VALUES ('€uro', 123.345, 12345.678)") + cursor.execute("SELECT a,b,c FROM test_compy_9") + cursor.fetchone() + d = cursor.description; + self.assertEqual(d[0][2], 4); # 4 code points only + self.assertEqual(d[0][3], -1); # variable length + self.assertEqual(d[1][2], 7); # length=precision + 1 + self.assertEqual(d[1][4], 6); # precision + self.assertEqual(d[1][5], 3); # decimals + del cursor + + def test_conpy_15(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE test_conpy_15 (a int not null auto_increment primary key, b varchar(20))"); + self.assertEqual(cursor.lastrowid, 0) + cursor.execute("INSERT INTO test_conpy_15 VALUES (null, 'foo')") + self.assertEqual(cursor.lastrowid, 1) + cursor.execute("SELECT LAST_INSERT_ID()") + row = cursor.fetchone() + self.assertEqual(row[0], 1) +# del cursor +# cursor= self.connection.cursor() + vals = [(3, "bar"), (4, "this")] + cursor.executemany("INSERT INTO test_conpy_15 VALUES (?,?)", vals) + self.assertEqual(cursor.lastrowid, 4) + # Bug MDEV-16847 + # cursor.execute("SELECT LAST_INSERT_ID()") + # row= cursor.fetchone() + # self.assertEqual(row[0], 4) + + # Bug MDEV-16593 + # vals= [(None, "bar"), (None, "foo")] + # cursor.executemany("INSERT INTO t1 VALUES (?,?)", vals) + # self.assertEqual(cursor.lastrowid, 6) + del cursor + + def test_conpy_14(self): + cursor = self.connection.cursor() + self.assertEqual(cursor.rowcount, -1) + cursor.execute( + "CREATE TEMPORARY TABLE test_conpy_14 (a int not null auto_increment primary key, b varchar(20))"); + self.assertEqual(cursor.rowcount, -1) + cursor.execute("INSERT INTO test_conpy_14 VALUES (null, 'foo')") + self.assertEqual(cursor.rowcount, 1) + vals = [(3, "bar"), (4, "this")] + cursor.executemany("INSERT INTO test_conpy_14 VALUES (?,?)", vals) + self.assertEqual(cursor.rowcount, 2) + del cursor + + def test_closed(self): + cursor = self.connection.cursor() + cursor.close() + cursor.close() + self.assertEqual(cursor.closed, True) + try: + cursor.execute("set @a:=1") + except mariadb.ProgrammingError: + pass + del cursor + + def test_emptycursor(self): + cursor = self.connection.cursor() + try: + cursor.execute("") + except mariadb.DatabaseError: + pass + del cursor + + def test_iterator(self): + cursor = self.connection.cursor() + cursor.execute("select 1 union select 2 union select 3 union select 4 union select 5") + for i, row in enumerate(cursor): + self.assertEqual(i + 1, cursor.rownumber) + self.assertEqual(i + 1, row[0]) + + def test_update_bulk(self): + cursor = self.connection.cursor() + cursor.execute("CREATE TEMPORARY TABLE test_update_bulk (a int primary key, b int)") + vals = [(i,) for i in range(1000)] + cursor.executemany("INSERT INTO test_update_bulk VALUES (?, NULL)", vals); + self.assertEqual(cursor.rowcount, 1000) + self.connection.autocommit = False + cursor.executemany("UPDATE test_update_bulk SET b=2 WHERE a=?", vals); + self.connection.commit() + self.assertEqual(cursor.rowcount, 1000) + self.connection.autocommit = True + del cursor + + def test_multi_execute(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE test_multi_execute (a int auto_increment primary key, b int)") + self.connection.autocommit = False + for i in range(1, 1000): + cursor.execute("INSERT INTO test_multi_execute VALUES (?,1)", (i,)) + self.connection.autocommit = True + del cursor + + def test_conpy21(self): + conn = self.connection + cursor = conn.cursor() + self.assertFalse(cursor.closed) + conn.close() + self.assertTrue(cursor.closed) + del cursor, conn + + def test_utf8(self): + # F0 9F 98 8E 😎 unicode 6 smiling face with sunglasses + # F0 9F 8C B6 🌢 unicode 7 hot pepper + # F0 9F 8E A4 🎀 unicode 8 no microphones + # F0 9F A5 82 πŸ₯‚ unicode 9 champagne glass + con = create_connection({"charset": "utf8mb4"}) + cursor = con.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE `test_utf8` (`test` blob)") + cursor.execute("INSERT INTO test_utf8 VALUES (?)", ("😎🌢🎀πŸ₯‚",)) + cursor.execute("SELECT * FROM test_utf8") + row = cursor.fetchone() + self.assertEqual(row[0], b"\xf0\x9f\x98\x8e\xf0\x9f\x8c\xb6\xf0\x9f\x8e\xa4\xf0\x9f\xa5\x82") + del cursor, con + + def test_latin2(self): + con = create_connection({"charset": "cp1251"}) + print(con.character_set) + cursor = con.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE `test_latin2` (`test` blob)") + cursor.execute("INSERT INTO test_latin2 VALUES (?)", (b"\xA9\xB0",)) + cursor.execute("SELECT * FROM test_latin2") + row = cursor.fetchone() + self.assertEqual(row[0], b"\xA9\xB0") + del cursor, con + + + +if __name__ == '__main__': + unittest.main() diff --git a/test/integration/test_cursor_mariadb.py b/test/integration/test_cursor_mariadb.py new file mode 100644 index 0000000..7eebad5 --- /dev/null +++ b/test/integration/test_cursor_mariadb.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python -O +# -*- coding: utf-8 -*- + +import datetime +import unittest + +from test.base_test import create_connection + + +class CursorMariaDBTest(unittest.TestCase): + + def setUp(self): + self.connection = create_connection() + + def tearDown(self): + del self.connection + + def test_insert_parameter(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE test_insert_parameter(a int not null auto_increment primary key, b int, c int, d varchar(20),e date)") + # cursor.execute("set @@autocommit=0"); + list_in = [] + for i in range(1, 300001): + row = (i, i, i, "bar", datetime.date(2019, 1, 1)) + list_in.append(row) + cursor.executemany("INSERT INTO test_insert_parameter VALUES (?,?,?,?,?)", list_in) + self.assertEqual(len(list_in), cursor.rowcount) + self.connection.commit() + cursor.execute("SELECT * FROM test_insert_parameter order by a") + list_out = cursor.fetchall() + self.assertEqual(len(list_in), cursor.rowcount); + self.assertEqual(list_in, list_out) + cursor.close() + + def test_update_parameter(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE test_update_parameter(a int not null auto_increment " + "primary key, b int, c int, d varchar(20),e date)") + cursor.execute("set @@autocommit=0"); + list_in = [] + for i in range(1, 300001): + row = (i, i, i, "bar", datetime.date(2019, 1, 1)) + list_in.append(row) + cursor.executemany("INSERT INTO test_update_parameter VALUES (?,?,?,?,?)", list_in) + self.assertEqual(len(list_in), cursor.rowcount) + self.connection.commit() + cursor.close() + list_update = []; + + cursor = self.connection.cursor() + cursor.execute("set @@autocommit=0"); + for i in range(1, 300001): + row = (i + 1, i); + list_update.append(row); + + cursor.executemany("UPDATE test_update_parameter SET b=? WHERE a=?", list_update); + self.assertEqual(cursor.rowcount, 300000) + self.connection.commit(); + cursor.close() + + def test_delete_parameter(self): + cursor = self.connection.cursor() + cursor.execute( + "CREATE TEMPORARY TABLE test_delete_parameter(a int not null auto_increment " + "primary key, b int, c int, d varchar(20),e date)") + cursor.execute("set @@autocommit=0"); + list_in = [] + for i in range(1, 300001): + row = (i, i, i, "bar", datetime.date(2019, 1, 1)) + list_in.append(row) + cursor.executemany("INSERT INTO test_delete_parameter VALUES (?,?,?,?,?)", list_in) + self.assertEqual(len(list_in), cursor.rowcount) + self.connection.commit() + cursor.close() + list_delete = []; + + cursor = self.connection.cursor() + cursor.execute("set @@autocommit=0"); + for i in range(1, 300001): + list_delete.append((i,)); + + cursor.executemany("DELETE FROM test_delete_parameter WHERE a=?", list_delete); + self.assertEqual(cursor.rowcount, 300000) + self.connection.commit(); + cursor.close() + + +if __name__ == '__main__': + unittest.main() diff --git a/test/integration/test_cursor_mysql.py b/test/integration/test_cursor_mysql.py new file mode 100644 index 0000000..cdcd84e --- /dev/null +++ b/test/integration/test_cursor_mysql.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python -O +# -*- coding: utf-8 -*- + +import datetime +import unittest + +from test.base_test import create_connection + +class CursorMySQLTest(unittest.TestCase): + + def setUp(self): + self.connection = create_connection() + + def tearDown(self): + del self.connection + + def test_parameter(self): + cursor = self.connection.cursor() + cursor.execute("CREATE TEMPORARY TABLE test_parameter(a int auto_increment primary key not " + "null, b int, c int, d varchar(20),e date)") + cursor.execute("SET @@autocommit=0") + c = (1, 2, 3, "bar", datetime.date(2018, 11, 11)) + list_in = [] + for i in range(1, 30000): + row = (i, i, i, "bar", datetime.date(2019, 1, 1)) + list_in.append(row) + cursor.executemany("INSERT INTO test_parameter VALUES (%s,%s,%s,%s,%s)", list_in) + print("rows inserted:", len(list_in)) + self.connection.commit() + cursor.execute("SELECT * FROM test_parameter order by a") + list_out = cursor.fetchall() + print("rows fetched: ", len(list_out)) + self.assertEqual(list_in, list_out) + + cursor.close() + + +if __name__ == '__main__': + unittest.main() diff --git a/test/integration/test_dbapi20.py b/test/integration/test_dbapi20.py new file mode 100644 index 0000000..28902dc --- /dev/null +++ b/test/integration/test_dbapi20.py @@ -0,0 +1,794 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +''' Python DB API 2.0 driver compliance unit test suite. + + This software is Public Domain and may be used without restrictions. + + "Now we have booze and barflies entering the discussion, plus rumours of + DBAs on drugs... and I won't tell you what flashes through my mind each + time I read the subject line with 'Anal Compliance' in it. All around + this is turning out to be a thoroughly unwholesome unit test." + + -- Ian Bicking +''' + +__rcs_id__ = '$Id$' +__version__ = '$Revision$'[11:-2] +__author__ = 'Stuart Bishop ' + +import time +import unittest + +import mariadb as mariadb + +# $Log$ +# Revision 1.1.2.1 2006/02/25 03:44:32 adustman +# Generic DB-API unit test module +# +# Revision 1.10 2003/10/09 03:14:14 zenzen +# Add test for DB API 2.0 optional extension, where database exceptions +# are exposed as attributes on the Connection object. +# +# Revision 1.9 2003/08/13 01:16:36 zenzen +# Minor tweak from Stefan Fleiter +# +# Revision 1.8 2003/04/10 00:13:25 zenzen +# Changes, as per suggestions by M.-A. Lemburg +# - Add a table prefix, to ensure namespace collisions can always be avoided +# +# Revision 1.7 2003/02/26 23:33:37 zenzen +# Break out DDL into helper functions, as per request by David Rushby +# +# Revision 1.6 2003/02/21 03:04:33 zenzen +# Stuff from Henrik Ekelund: +# added test_None +# added test_nextset & hooks +# +# Revision 1.5 2003/02/17 22:08:43 zenzen +# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize +# defaults to 1 & generic cursor.callproc test added +# +# Revision 1.4 2003/02/15 00:16:33 zenzen +# Changes, as per suggestions and bug reports by M.-A. Lemburg, +# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar +# - Class renamed +# - Now a subclass of TestCase, to avoid requiring the driver stub +# to use multiple inheritance +# - Reversed the polarity of buggy test in test_description +# - Test exception hierarchy correctly +# - self.populate is now self._populate(), so if a driver stub +# overrides self.ddl1 this change propagates +# - VARCHAR columns now have a width, which will hopefully make the +# DDL even more portible (this will be reversed if it causes more problems) +# - cursor.rowcount being checked after various execute and fetchXXX methods +# - Check for fetchall and fetchmany returning empty lists after results +# are exhausted (already checking for empty lists if select retrieved +# nothing +# - Fix bugs in test_setoutputsize_basic and test_setinputsizes +# +from test.conf_test import conf + + +class DatabaseAPI20Test(unittest.TestCase): + ''' Test a database self.driver for DB API 2.0 compatibility. + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compiliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. + + The 'Optional Extensions' are not yet being tested. + + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: + + import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] + + Don't 'import DatabaseAPI20Test from dbapi20', or you will + confuse the unit tester - just 'import dbapi20'. + ''' + + # The self.driver module. This should be the module where the 'connect' + # method is to be found + driver = mariadb + connect_args = () # List of arguments to pass to connect + connect_kw_args = conf() # Keyword arguments for connect + table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + + ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix + ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix + xddl1 = 'drop table %sbooze' % table_prefix + xddl2 = 'drop table %sbarflys' % table_prefix + + lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + + # Some drivers may need to override these helpers, for example adding + # a 'commit' after the execute. + def executeDDL1(self, cursor): + cursor.execute(self.ddl1) + + def executeDDL2(self, cursor): + cursor.execute(self.ddl2) + + def setUp(self): + ''' self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. + ''' + pass + + def tearDown(self): + ''' self.drivers should override this method to perform required cleanup + if any is necessary, such as deleting the test database. + The default drops the tables that may be created. + ''' + con = self._connect() + try: + cur = con.cursor() + for ddl in (self.xddl1, self.xddl2): + try: + cur.execute(ddl) + con.commit() + except self.driver.Error: + # Assume table didn't exist. Other tests will check if + # execute is busted. + pass + finally: + con.close() + + def _connect(self): + try: + return self.driver.connect(**self.connect_kw_args) + except AttributeError: + self.fail("No connect method found in self.driver module") + + def test_connect(self): + con = self._connect() + con.close() + + def test_apilevel(self): + try: + # Must exist + apilevel = self.driver.apilevel + # Must equal 2.0 + self.assertEqual(apilevel, '2.0') + except AttributeError: + self.fail("Driver doesn't define apilevel") + +# def test_threadsafety(self): +# try: +# # Must exist +# threadsafety = self.driver.threadsafety +# # Must be a valid value +# self.assertTrue(threadsafety in (0, 1, 2, 3)) +# except AttributeError: +# self.fail("Driver doesn't define threadsafety") + +# def test_paramstyle(self): +# try: +# # Must exist +# paramstyle = self.driver.paramstyle +# # Must be a valid value +# self.assertTrue(paramstyle in ( +# 'qmark', 'numeric', 'named', 'format', 'pyformat' +# )) +# except AttributeError: +# self.fail("Driver doesn't define paramstyle") + +# def test_Exceptions(self): +# # Make sure required exceptions exist, and are in the +# # defined hierarchy. +# self.assertTrue(issubclass(self.driver.Warning, Exception)) +# self.assertTrue(issubclass(self.driver.Error, Exception)) +# self.assertTrue( +# issubclass(self.driver.InterfaceError, self.driver.Error) +# ) +# self.assertTrue( +# issubclass(self.driver.DatabaseError, self.driver.Error) +# ) +# self.assertTrue( +# issubclass(self.driver.OperationalError, self.driver.Error) +# ) +# self.assertTrue( +# issubclass(self.driver.IntegrityError, self.driver.Error) +# ) +# self.assertTrue( +# issubclass(self.driver.InternalError, self.driver.Error) +# ) +# self.assertTrue( +# issubclass(self.driver.ProgrammingError, self.driver.Error) +# ) +# self.assertTrue( +# issubclass(self.driver.NotSupportedError, self.driver.Error) +# ) + +# def test_ExceptionsAsConnectionAttributes(self): +# # OPTIONAL EXTENSION +# # Test for the optional DB API 2.0 extension, where the exceptions +# # are exposed as attributes on the Connection object +# # I figure this optional extension will be implemented by any +# # driver author who is using this test suite, so it is enabled +# # by default. +# con = self._connect() +# drv = self.driver +# self.assertTrue(con.Warning is drv.Warning) +# self.assertTrue(con.Error is drv.Error) +# self.assertTrue(con.InterfaceError is drv.InterfaceError) +# self.assertTrue(con.DatabaseError is drv.DatabaseError) +# self.assertTrue(con.OperationalError is drv.OperationalError) +# self.assertTrue(con.IntegrityError is drv.IntegrityError) +# self.assertTrue(con.InternalError is drv.InternalError) +# self.assertTrue(con.ProgrammingError is drv.ProgrammingError) +# self.assertTrue(con.NotSupportedError is drv.NotSupportedError) + +# def test_commit(self): +# con = self._connect() +# try: +# # Commit must work, even if it doesn't do anything +# con.commit() +# finally: +# con.close() + +# def test_rollback(self): +# con = self._connect() +# # If rollback is defined, it should either work or throw +# # the documented exception +# if hasattr(con, 'rollback'): +# try: +# con.rollback() +# except self.driver.NotSupportedError: +# pass + +# def test_cursor(self): +# con = self._connect() +# try: +# cur = con.cursor() +# finally: +# con.close() + +# def test_cursor_isolation(self): +# con = self._connect() +# try: +# # Make sure cursors created from the same connection have +# # the documented transaction isolation level +# cur1 = con.cursor() +# cur2 = con.cursor() +# self.executeDDL1(cur1) +# cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( +# self.table_prefix +# )) +# cur2.execute("select name from %sbooze" % self.table_prefix) +# booze = cur2.fetchall() +# self.assertEqual(len(booze), 1) +# self.assertEqual(len(booze[0]), 1) +# self.assertEqual(booze[0][0], 'Victoria Bitter') +# finally: +# con.close() + +# def test_description(self): +# con = self._connect() +# try: +# cur = con.cursor() +# self.executeDDL1(cur) +# self.assertEqual(cur.description, None, +# 'cursor.description should be none after executing a ' +# 'statement that can return no rows (such as DDL)' +# ) +# cur.execute('select name from %sbooze' % self.table_prefix) +# self.assertEqual(len(cur.description), 1, +# 'cursor.description describes too many columns' +# ) +# self.assertEqual(len(cur.description[0]), 8, +# 'cursor.description[x] tuples must have 8 elements' +# ) +# self.assertEqual(cur.description[0][0].lower(), 'name', +# 'cursor.description[x][0] must return column name' +# ) +# self.assertEqual(cur.description[0][1], self.driver.STRING, +# 'cursor.description[x][1] must return column type. Got %r' +# % cur.description[0][1] +# ) + +# # Make sure self.description gets reset +# self.executeDDL2(cur) +# self.assertEqual(cur.description, None, +# 'cursor.description not being set to None when executing ' +# 'no-result statements (eg. DDL)' +# ) +# finally: +# con.close() + +# def test_rowcount(self): +# con = self._connect() +# try: +# cur = con.cursor() +# self.executeDDL1(cur) +# self.assertEqual(cur.rowcount, -1, +# 'cursor.rowcount should be -1 after executing no-result ' +# 'statements' +# ) +# cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( +# self.table_prefix +# )) +# self.assertTrue(cur.rowcount in (-1, 1), +# 'cursor.rowcount should == number or rows inserted, or ' +# 'set to -1 after executing an insert statement' +# ) +# cur.execute("select name from %sbooze" % self.table_prefix, buffered=True) +# self.assertTrue(cur.rowcount in (-1, 1), +# 'cursor.rowcount should == number of rows returned, or ' +# 'set to -1 after executing a select statement' +# ) +# self.executeDDL2(cur) +# self.assertEqual(cur.rowcount, -1, +# 'cursor.rowcount not being reset to -1 after executing ' +# 'no-result statements' +# ) +# finally: +# con.close() + +# lower_func = 'lower' + +# def test_close(self): +# con = self._connect() +# try: +# cur = con.cursor() +# finally: +# con.close() + +# # cursor.execute should raise an Error if called after connection +# # closed +# self.assertRaises(self.driver.Error, self.executeDDL1, cur) + +# # connection.commit should raise an Error if called after connection' +# # closed.' +# self.assertRaises(self.driver.Error, con.commit) + +# # connection.close should raise an Error if called more than once +# self.assertRaises(self.driver.Error, con.close) + +# def test_execute(self): +# con = self._connect() +# try: +# cur = con.cursor() +# self._paraminsert(cur) +# finally: +# con.close() + +# def _paraminsert(self, cur): +# self.executeDDL1(cur) +# cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( +# self.table_prefix +# )) +# self.assertTrue(cur.rowcount in (-1, 1)) + +# if self.driver.paramstyle == 'qmark': +# cur.execute( +# 'insert into %sbooze values (?)' % self.table_prefix, +# ("Cooper's",) +# ) +# elif self.driver.paramstyle == 'numeric': +# cur.execute( +# 'insert into %sbooze values (:1)' % self.table_prefix, +# ("Cooper's",) +# ) +# elif self.driver.paramstyle == 'named': +# cur.execute( +# 'insert into %sbooze values (:beer)' % self.table_prefix, +# {'beer': "Cooper's"} +# ) +# elif self.driver.paramstyle == 'format': +# cur.execute( +# 'insert into %sbooze values (%%s)' % self.table_prefix, +# ("Cooper's",) +# ) +# elif self.driver.paramstyle == 'pyformat': +# cur.execute( +# 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, +# {'beer': "Cooper's"} +# ) +# else: +# self.fail('Invalid paramstyle') +# self.assertTrue(cur.rowcount in (-1, 1)) + +# cur.execute('select name from %sbooze' % self.table_prefix) +# res = cur.fetchall() +# self.assertEqual(len(res), 2, 'cursor.fetchall returned too few rows') +# beers = [res[0][0], res[1][0]] +# beers.sort() +# self.assertEqual(beers[0], "Cooper's", +# 'cursor.fetchall retrieved incorrect data, or data inserted ' +# 'incorrectly' +# ) +# self.assertEqual(beers[1], "Victoria Bitter", +# 'cursor.fetchall retrieved incorrect data, or data inserted ' +# 'incorrectly' +# ) + + def test_executemany(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + largs = [("Cooper's",), ("Boag's",)] + margs = [{'beer': "Cooper's"}, {'beer': "Boag's"}] + if self.driver.paramstyle == 'qmark': + cur.executemany( + 'insert into %sbooze values (?)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'numeric': + cur.executemany( + 'insert into %sbooze values (:1)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'named': + cur.executemany( + 'insert into %sbooze values (:beer)' % self.table_prefix, + margs + ) + elif self.driver.paramstyle == 'format': + cur.executemany( + 'insert into %sbooze values (%%s)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'pyformat': + cur.executemany( + 'insert into %sbooze values (%%(beer)s)' % ( + self.table_prefix + ), + margs + ) + else: + self.fail('Unknown paramstyle') + self.assertTrue(cur.rowcount in (-1, 2), + 'insert using cursor.executemany set cursor.rowcount to ' + 'incorrect value %r' % cur.rowcount + ) + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res), 2, + 'cursor.fetchall retrieved incorrect number of rows' + ) + beers = [res[0][0], res[1][0]] + beers.sort() + self.assertEqual(beers[0], "Boag's", 'incorrect data retrieved') + self.assertEqual(beers[1], "Cooper's", 'incorrect data retrieved') + finally: + con.close() + +# def test_fetchone(self): +# con = self._connect() +# try: +# cur = con.cursor() + +# # cursor.fetchone should raise an Error if called before +# # executing a select-type query +# self.assertRaises(self.driver.Error, cur.fetchone) + +# # cursor.fetchone should raise an Error if called after +# # executing a query that cannot return rows +# self.executeDDL1(cur) +# self.assertRaises(self.driver.Error, cur.fetchone) + +# cur.execute('select name from %sbooze' % self.table_prefix) +# self.assertEqual(cur.fetchone(), None, +# 'cursor.fetchone should return None if a query retrieves ' +# 'no rows' +# ) +# self.assertTrue(cur.rowcount in (-1, 0)) + +# # cursor.fetchone should raise an Error if called after +# # executing a query that cannot return rows +# cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( +# self.table_prefix +# )) +# self.assertRaises(self.driver.Error, cur.fetchone) + +# cur.execute('select name from %sbooze' % self.table_prefix, buffered=True) +# r = cur.fetchone() +# self.assertEqual(len(r), 1, +# 'cursor.fetchone should have retrieved a single row' +# ) +# self.assertEqual(r[0], 'Victoria Bitter', +# 'cursor.fetchone retrieved incorrect data' +# ) +# self.assertEqual(cur.fetchone(), None, +# 'cursor.fetchone should return None if no more rows available' +# ) +# self.assertTrue(cur.rowcount in (-1, 1)) +# finally: +# con.close() + +# samples = [ +# 'Carlton Cold', +# 'Carlton Draft', +# 'Mountain Goat', +# 'Redback', +# 'Victoria Bitter', +# 'XXXX' +# ] + +# def _populate(self): +# ''' Return a list of sql commands to setup the DB for the fetch +# tests. +# ''' +# populate = [ +# "insert into %sbooze values ('%s')" % (self.table_prefix, s) +# for s in self.samples +# ] +# return populate + +# def test_fetchmany(self): +# con = self._connect() +# try: +# cur = con.cursor() + +# # cursor.fetchmany should raise an Error if called without +# # issuing a query +# self.assertRaises(self.driver.Error, cur.fetchmany, 4) + +# self.executeDDL1(cur) +# for sql in self._populate(): +# cur.execute(sql) + +# cur.execute('select name from %sbooze' % self.table_prefix) +# r = cur.fetchmany() +# self.assertEqual(len(r), 1, +# 'cursor.fetchmany retrieved incorrect number of rows, ' +# 'default of arraysize is one.' +# ) +# cur.arraysize = 10 +# r = cur.fetchmany(3) # Should get 3 rows +# self.assertEqual(len(r), 3, +# 'cursor.fetchmany retrieved incorrect number of rows' +# ) +# r = cur.fetchmany(4) # Should get 2 more +# self.assertEqual(len(r), 2, +# 'cursor.fetchmany retrieved incorrect number of rows' +# ) +# r = cur.fetchmany(4) # Should be an empty sequence +# self.assertEqual(len(r), 0, +# 'cursor.fetchmany should return an empty sequence after ' +# 'results are exhausted' +# ) +# self.assertTrue(cur.rowcount in (-1, 6)) + +# # Same as above, using cursor.arraysize +# cur.arraysize = 4 +# cur.execute('select name from %sbooze' % self.table_prefix) +# r = cur.fetchmany() # Should get 4 rows +# self.assertEqual(len(r), 4, +# 'cursor.arraysize not being honoured by fetchmany' +# ) +# r = cur.fetchmany() # Should get 2 more +# self.assertEqual(len(r), 2) +# r = cur.fetchmany() # Should be an empty sequence +# self.assertEqual(len(r), 0) +# self.assertTrue(cur.rowcount in (-1, 6)) + +# cur.arraysize = 6 +# cur.execute('select name from %sbooze' % self.table_prefix) +# rows = cur.fetchmany() # Should get all rows +# self.assertTrue(cur.rowcount in (-1, 6)) +# self.assertEqual(len(rows), 6) +# self.assertEqual(len(rows), 6) +# rows = [r[0] for r in rows] +# rows.sort() + +# # Make sure we get the right data back out +# for i in range(0, 6): +# self.assertEqual(rows[i], self.samples[i], +# 'incorrect data retrieved by cursor.fetchmany' +# ) + +# rows = cur.fetchmany() # Should return an empty list +# self.assertEqual(len(rows), 0, +# 'cursor.fetchmany should return an empty sequence if ' +# 'called after the whole result set has been fetched' +# ) +# self.assertTrue(cur.rowcount in (-1, 6)) + +# self.executeDDL2(cur) +# cur.execute('select name from %sbarflys' % self.table_prefix) +# r = cur.fetchmany() # Should get empty sequence +# self.assertEqual(len(r), 0, +# 'cursor.fetchmany should return an empty sequence if ' +# 'query retrieved no rows' +# ) +# self.assertTrue(cur.rowcount in (-1, 0)) + +# finally: +# con.close() + +# def test_fetchall(self): +# con = self._connect() +# try: +# cur = con.cursor() +# # cursor.fetchall should raise an Error if called +# # without executing a query that may return rows (such +# # as a select) +# self.assertRaises(self.driver.Error, cur.fetchall) + +# self.executeDDL1(cur) +# for sql in self._populate(): +# cur.execute(sql) + +# # cursor.fetchall should raise an Error if called +# # after executing a a statement that cannot return rows +# self.assertRaises(self.driver.Error, cur.fetchall) + +# cur.execute('select name from %sbooze' % self.table_prefix) +# rows = cur.fetchall() +# self.assertTrue(cur.rowcount in (-1, len(self.samples))) +# self.assertEqual(len(rows), len(self.samples), +# 'cursor.fetchall did not retrieve all rows' +# ) +# rows = [r[0] for r in rows] +# rows.sort() +# for i in range(0, len(self.samples)): +# self.assertEqual(rows[i], self.samples[i], +# 'cursor.fetchall retrieved incorrect rows' +# ) +# rows = cur.fetchall() +# self.assertEqual( +# len(rows), 0, +# 'cursor.fetchall should return an empty list if called ' +# 'after the whole result set has been fetched' +# ) +# self.assertTrue(cur.rowcount in (-1, len(self.samples))) + +# self.executeDDL2(cur) +# cur.execute('select name from %sbarflys' % self.table_prefix) +# rows = cur.fetchall() +# self.assertTrue(cur.rowcount in (-1, 0)) +# self.assertEqual(len(rows), 0, +# 'cursor.fetchall should return an empty list if ' +# 'a select query returns no rows' +# ) + +# finally: +# con.close() + +# def test_mixedfetch(self): +# con = self._connect() +# try: +# cur = con.cursor() +# self.executeDDL1(cur) +# for sql in self._populate(): +# cur.execute(sql) + +# cur.execute('select name from %sbooze' % self.table_prefix) +# rows1 = cur.fetchone() +# rows23 = cur.fetchmany(2) +# rows4 = cur.fetchone() +# rows56 = cur.fetchall() +# self.assertTrue(cur.rowcount in (-1, 6)) +# self.assertEqual(len(rows23), 2, +# 'fetchmany returned incorrect number of rows' +# ) +# self.assertEqual(len(rows56), 2, +# 'fetchall returned incorrect number of rows' +# ) + +# rows = [rows1[0]] +# rows.extend([rows23[0][0], rows23[1][0]]) +# rows.append(rows4[0]) +# rows.extend([rows56[0][0], rows56[1][0]]) +# rows.sort() +# for i in range(0, len(self.samples)): +# self.assertEqual(rows[i], self.samples[i], +# 'incorrect data retrieved or inserted' +# ) +# finally: +# con.close() + +# def help_nextset_setUp(self, cur): +# ''' Should create a procedure called deleteme +# that returns two result sets, first the +# number of rows in booze then "name from booze" +# ''' +# raise NotImplementedError('Helper not implemented') +# # sql=""" +# # create procedure deleteme as +# # begin +# # select count(*) from booze +# # select name from booze +# # end +# # """ +# # cur.execute(sql) + +# def help_nextset_tearDown(self, cur): +# 'If cleaning up is needed after nextSetTest' +# raise NotImplementedError('Helper not implemented') +# # cur.execute("drop procedure deleteme") + +# def test_arraysize(self): +# # Not much here - rest of the tests for this are in test_fetchmany +# con = self._connect() +# try: +# cur = con.cursor() +# self.assertTrue(hasattr(cur, 'arraysize'), +# 'cursor.arraysize must be defined' +# ) +# finally: +# con.close() + +# def test_setinputsizes(self): +# con = self._connect() +# try: +# cur = con.cursor() +# cur.setinputsizes((25,)) +# self._paraminsert(cur) # Make sure cursor still works +# finally: +# con.close() + +# def test_None(self): +# con = self._connect() +# try: +# cur = con.cursor() +# self.executeDDL1(cur) +# cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) +# cur.execute('select name from %sbooze' % self.table_prefix) +# r = cur.fetchall() +# self.assertEqual(len(r), 1) +# self.assertEqual(len(r[0]), 1) +# self.assertEqual(r[0][0], None, 'NULL value not returned as None') +# finally: +# con.close() + +# def test_Date(self): +# d1 = self.driver.Date(2002, 12, 25) +# d2 = self.driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) +# # Can we assume this? API doesn't specify, but it seems implied +# # self.assertEqual(str(d1),str(d2)) + +# def test_Time(self): +# t1 = self.driver.Time(13, 45, 30) +# t2 = self.driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) +# # Can we assume this? API doesn't specify, but it seems implied +# # self.assertEqual(str(t1),str(t2)) + +# def test_Timestamp(self): +# t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) +# t2 = self.driver.TimestampFromTicks( +# time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) +# ) +# # Can we assume this? API doesn't specify, but it seems implied +# # self.assertEqual(str(t1),str(t2)) + +# def test_Binary(self): +# b = self.driver.Binary(b'Something') +# b = self.driver.Binary(b'') + +# def test_STRING(self): +# self.assertTrue(hasattr(self.driver, 'STRING'), +# 'module.STRING must be defined' +# ) + +# def test_BINARY(self): +# self.assertTrue(hasattr(self.driver, 'BINARY'), +# 'module.BINARY must be defined.' +# ) + +# def test_NUMBER(self): +# self.assertTrue(hasattr(self.driver, 'NUMBER'), +# 'module.NUMBER must be defined.' +# ) + +# def test_DATETIME(self): +# self.assertTrue(hasattr(self.driver, 'DATETIME'), +# 'module.DATETIME must be defined.' +# ) + +# def test_ROWID(self): +# self.assertTrue(hasattr(self.driver, 'ROWID'), +# 'module.ROWID must be defined.' +# ) + + +#f __name__ == '__main__': +# unittest.main() diff --git a/test/integration/test_nondbapi.py b/test/integration/test_nondbapi.py new file mode 100644 index 0000000..9425c09 --- /dev/null +++ b/test/integration/test_nondbapi.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python -O + +# -*- coding: utf-8 -*- + +import unittest + +import mariadb + +from test.base_test import create_connection +from test.conf_test import conf + + +class CursorTest(unittest.TestCase): + + def setUp(self): + self.connection = create_connection() + + def tearDown(self): + del self.connection + + def test_ping(self): + new_conn = create_connection() + id = new_conn.connection_id; + self.connection.kill(id) + try: + new_conn.ping() + except mariadb.DatabaseError: + pass + del new_conn + new_conn = create_connection() + new_conn.auto_reconnect = True + id = new_conn.connection_id; + self.connection.kill(id) + new_conn.ping() + new_id = new_conn.connection_id + self.assertTrue(id != new_id) + del new_conn + + def test_change_user(self): + cursor = self.connection.cursor() + cursor.execute("create or replace user foo@localhost") + cursor.execute("GRANT ALL on testp.* TO foo@localhost") + new_conn = create_connection() + default_conf = conf() + new_conn.change_user("foo", "", default_conf["database"]) + self.assertEqual("foo", new_conn.user) + del new_conn + del cursor + + def test_reconnect(self): + new_conn = create_connection() + conn1_id = new_conn.connection_id + self.connection.kill(conn1_id) + new_conn.reconnect() + conn2_id = new_conn.connection_id + self.assertFalse(conn1_id == conn2_id) + del new_conn + + def test_reset(self): + cursor = self.connection.cursor() + cursor.execute("SELECT 1 UNION SELECT 2") + try: + self.connection.ping() + except mariadb.DatabaseError: + pass + + self.connection.reset() + self.connection.ping() + del cursor + + def test_warnings(self): + conn = self.connection + cursor = conn.cursor() + + cursor.execute("SET session sql_mode=''") + cursor.execute("CREATE TEMPORARY TABLE test_warnings (a tinyint)") + cursor.execute("INSERT INTO test_warnings VALUES (300)") + + self.assertEqual(conn.warnings, 1) + self.assertEqual(conn.warnings, cursor.warnings) + del cursor + + def test_server_infos(self): + self.assertTrue(self.connection.server_info) + self.assertTrue(self.connection.server_version > 0); + + def test_escape(self): + cursor = self.connection.cursor() + cursor.execute("CREATE TEMPORARY TABLE test_escape (a varchar(100))") + str = 'This is a \ and a "' + cmd = "INSERT INTO test_escape VALUES('%s')" % str + + try: + cursor.execute(cmd) + except mariadb.DatabaseError: + pass + + str = self.connection.escape_string(str) + cmd = "INSERT INTO test_escape VALUES('%s')" % str + cursor.execute(cmd) + del cursor + + +if __name__ == '__main__': + unittest.main() diff --git a/test/nondbapi.py b/test/nondbapi.py deleted file mode 100644 index 5ec8fac..0000000 --- a/test/nondbapi.py +++ /dev/null @@ -1,110 +0,0 @@ -import mariadb -import datetime -import unittest -import collections -import time - -class CursorTest(unittest.TestCase): - - def setUp(self): - self.connection= mariadb.connection(default_file='default.cnf') - - def tearDown(self): - del self.connection - - def test_ping(self): - new_conn= mariadb.connection(default_file='default.cnf') - id= new_conn.connection_id; - self.connection.kill(id) - try: - new_conn.ping() - except mariadb.DatabaseError: - pass - del new_conn - new_conn= mariadb.connection(default_file='default.cnf') - new_conn.auto_reconnect= True - id= new_conn.connection_id; - self.connection.kill(id) - new_conn.ping() - new_id= new_conn.connection_id - self.assertTrue(id != new_id) - del new_conn - - def test_change_user(self): - cursor= self.connection.cursor() - cursor.execute("create or replace user foo@localhost") - cursor.execute("GRANT ALL on test.* TO 'foo'@'localhost'") - new_conn= mariadb.connection(default_file='default.cnf') - new_conn.change_user("foo", "", "test") - self.assertEqual("foo", new_conn.user) - del new_conn - cursor.execute("drop user foo@localhost") - del cursor - - def test_db(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE schema test1") - self.assertEqual(self.connection.database, "test") - self.connection.database= "test1" - self.assertEqual(self.connection.database, "test1") - self.connection.database= "test" - self.assertEqual(self.connection.database, "test") - cursor.execute("USE test1") - self.assertEqual(self.connection.database, "test1") - cursor.execute("USE test") - cursor.execute("DROP SCHEMA test1") - del cursor - - def test_reconnect(self): - new_conn= mariadb.connection(default_file='default.cnf') - conn1_id= new_conn.connection_id - self.connection.kill(conn1_id) - new_conn.reconnect() - conn2_id= new_conn.connection_id - self.assertFalse(conn1_id == conn2_id) - del new_conn - - def test_reset(self): - cursor= self.connection.cursor() - cursor.execute("SELECT 1 UNION SELECT 2") - try: - self.connection.ping() - except mariadb.DatabaseError: - pass - - self.connection.reset() - self.connection.ping() - del cursor - - def test_warnings(self): - conn= self.connection - cursor= conn.cursor() - - cursor.execute("SET session sql_mode=''") - cursor.execute("CREATE OR REPLACE TABLE t1 (a tinyint)") - cursor.execute("INSERT INTO t1 VALUES (300)") - - self.assertEqual(conn.warnings,1) - self.assertEqual(conn.warnings, cursor.warnings) - del cursor - - def test_server_infos(self): - self.assertTrue(self.connection.server_info) - self.assertTrue(self.connection.server_version > 0); - - def test_escape(self): - cursor= self.connection.cursor() - cursor.execute("CREATE OR REPLACE TABLE t1 (a varchar(100))") - str= 'This is a \ and a "' - cmd= "INSERT INTO t1 VALUES('%s')" % str - - try: - cursor.execute(cmd) - except mariadb.DatabaseError: - pass - - str= self.connection.escape_string(str) - cmd= "INSERT INTO t1 VALUES('%s')" % str - cursor.execute(cmd) - del cursor -