mirror of
https://github.com/mariadb-corporation/mariadb-connector-python.git
synced 2025-08-13 15:21:02 +00:00
CONPY-299: support for VECTOR data type
Implemented support for VECTOR paramter. Beside text representation (Vec_FromText) or array to byte conversion (array.tobytes) it is now possible to specify an float array as parameter. Example: import mariadb, array ... sql = """ CREATE OR REPLACE TABLE test( id INT PRIMARY KEY, v VECTOR(3) NOT NULL, VECTOR INDEX (v))")""" cursor.execute(sql) vector= array.array('f', [123.1, 230.9, 981.7]) cursor.execute("INSERT INTO test VALUES (?,?)", (1, vector)) Please note that the opposite way, retrieving a float array is not supported yet. This will require either a new or extended field type for vector data type.
This commit is contained in:
@ -858,6 +858,34 @@ field_fetch_callback(void *data, unsigned int column, unsigned char **row)
|
||||
self->values[column]= val;
|
||||
}
|
||||
}
|
||||
|
||||
static uint8_t ma_is_vector(PyObject *obj)
|
||||
{
|
||||
PyObject *TypeCodeObj= NULL;
|
||||
const char *typecode;
|
||||
uint8_t rc= 0;
|
||||
|
||||
if (!obj)
|
||||
return 0;
|
||||
|
||||
if (strcmp(Py_TYPE(obj)->tp_name, "array.array") &&
|
||||
strcmp(Py_TYPE(obj)->tp_name, "array"))
|
||||
return 0;
|
||||
|
||||
if (!(TypeCodeObj= PyObject_GetAttrString(obj, "typecode")) ||
|
||||
!PyUnicode_Check(TypeCodeObj))
|
||||
goto end;
|
||||
|
||||
if ((typecode = PyUnicode_AsUTF8(TypeCodeObj)) &&
|
||||
!strcmp(typecode, "f"))
|
||||
rc= 1;
|
||||
|
||||
end:
|
||||
if (TypeCodeObj)
|
||||
Py_DECREF(TypeCodeObj);
|
||||
return rc;
|
||||
}
|
||||
|
||||
/*
|
||||
mariadb_get_column_info
|
||||
This function analyzes the Python object and calculates the corresponding
|
||||
@ -910,6 +938,10 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)
|
||||
so we need to convert decimal.Decimal Object to string during callback */
|
||||
paraminfo->type= MYSQL_TYPE_NEWDECIMAL;
|
||||
return 0;
|
||||
} else if (ma_is_vector(obj)) {
|
||||
/* CONPY-299: Vectors are defined as array of floats */
|
||||
paraminfo->type= MYSQL_TYPE_LONG_BLOB;
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
/* If Object has string representation, we will use string representation */
|
||||
@ -1396,8 +1428,31 @@ mariadb_param_to_bind(MrdbCursor *self,
|
||||
*(double *)value->num= (double)PyFloat_AsDouble(value->value);
|
||||
break;
|
||||
case MYSQL_TYPE_LONG_BLOB:
|
||||
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(value->value);
|
||||
bind->buffer= (void *) PyBytes_AS_STRING(value->value);
|
||||
if (!strcmp(Py_TYPE(value->value)->tp_name, "array.array") ||
|
||||
!strcmp(Py_TYPE(value->value)->tp_name, "array"))
|
||||
{
|
||||
Py_ssize_t size= PySequence_Length(value->value);
|
||||
PyObject *byte_array= NULL;
|
||||
|
||||
bind->buffer= NULL;
|
||||
|
||||
if (!size)
|
||||
goto end;
|
||||
|
||||
if (!PyObject_HasAttrString(value->value, "tobytes"))
|
||||
goto end;
|
||||
|
||||
if (!(byte_array= PyObject_CallMethod(value->value, "tobytes", NULL)))
|
||||
goto end;
|
||||
|
||||
bind->buffer= (void *)PyBytes_AS_STRING(byte_array);
|
||||
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(byte_array);
|
||||
|
||||
Py_DECREF(byte_array);
|
||||
} else {
|
||||
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(value->value);
|
||||
bind->buffer= (void *) PyBytes_AS_STRING(value->value);
|
||||
}
|
||||
break;
|
||||
case MYSQL_TYPE_DATE:
|
||||
case MYSQL_TYPE_TIME:
|
||||
|
@ -7,6 +7,7 @@ import os
|
||||
import decimal
|
||||
import json
|
||||
from decimal import Decimal
|
||||
import array
|
||||
|
||||
import mariadb
|
||||
from mariadb.constants import FIELD_TYPE, EXT_FIELD_TYPE, ERR, CURSOR, INDICATOR, CLIENT
|
||||
@ -58,6 +59,39 @@ class TestCursor(unittest.TestCase):
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
def test_conpy299(self):
|
||||
if is_mysql():
|
||||
self.skipTest("Skip (MySQL)")
|
||||
if self.connection.server_version < 110702:
|
||||
self.skipTest("Requires server version >= 11.7.2")
|
||||
|
||||
cursor= self.connection.cursor()
|
||||
cursor.execute("DROP TABLE IF EXISTS t_vector")
|
||||
cursor.execute("CREATE TABLE t_vector (id int not null, v VECTOR(3) NOT NULL, VECTOR INDEX(v))")
|
||||
|
||||
# Vector can't be empty
|
||||
empty= array.array('f', [])
|
||||
try:
|
||||
cursor.execute("INSERT INTO t_vector VALUES (?,?)", (1, empty))
|
||||
except mariadb.IntegrityError:
|
||||
pass
|
||||
|
||||
# Valid vector
|
||||
data= array.array('f', [201.1, 302.2, 403.3])
|
||||
|
||||
cursor.execute("INSERT INTO t_vector VALUES (?,?)", (1, data))
|
||||
cursor.execute("SELECT id, v, Vec_ToText(v) FROM t_vector")
|
||||
row= cursor.fetchone()
|
||||
|
||||
check_data= [row[1], array.array('f', eval(row[2]))]
|
||||
|
||||
cursor.execute("DROP TABLE t_vector")
|
||||
cursor.close()
|
||||
|
||||
self.assertEqual(check_data[0], data.tobytes())
|
||||
self.assertEqual(check_data[1], data)
|
||||
|
||||
|
||||
def test_date(self):
|
||||
v = self.connection.server_version
|
||||
i = self.connection.server_info.lower()
|
||||
|
Reference in New Issue
Block a user