CONPY-299: support for VECTOR data type

Implemented support for VECTOR paramter. Beside text representation
(Vec_FromText) or array to byte conversion (array.tobytes) it is now
possible to specify an float array as parameter.

Example:

import mariadb, array

...

sql = """
  CREATE OR REPLACE TABLE test(
     id INT PRIMARY KEY,
     v VECTOR(3) NOT NULL,
      VECTOR INDEX (v))")"""

cursor.execute(sql)

vector= array.array('f', [123.1, 230.9, 981.7])

cursor.execute("INSERT INTO test VALUES (?,?)", (1, vector))

Please note that the opposite way, retrieving a float array is not
supported yet. This will require either a new or extended field
type for vector data type.
This commit is contained in:
Georg Richter
2025-01-30 09:02:07 +01:00
parent abd17eb95d
commit 138a02238e
2 changed files with 91 additions and 2 deletions

View File

@ -858,6 +858,34 @@ field_fetch_callback(void *data, unsigned int column, unsigned char **row)
self->values[column]= val;
}
}
static uint8_t ma_is_vector(PyObject *obj)
{
PyObject *TypeCodeObj= NULL;
const char *typecode;
uint8_t rc= 0;
if (!obj)
return 0;
if (strcmp(Py_TYPE(obj)->tp_name, "array.array") &&
strcmp(Py_TYPE(obj)->tp_name, "array"))
return 0;
if (!(TypeCodeObj= PyObject_GetAttrString(obj, "typecode")) ||
!PyUnicode_Check(TypeCodeObj))
goto end;
if ((typecode = PyUnicode_AsUTF8(TypeCodeObj)) &&
!strcmp(typecode, "f"))
rc= 1;
end:
if (TypeCodeObj)
Py_DECREF(TypeCodeObj);
return rc;
}
/*
mariadb_get_column_info
This function analyzes the Python object and calculates the corresponding
@ -910,6 +938,10 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)
so we need to convert decimal.Decimal Object to string during callback */
paraminfo->type= MYSQL_TYPE_NEWDECIMAL;
return 0;
} else if (ma_is_vector(obj)) {
/* CONPY-299: Vectors are defined as array of floats */
paraminfo->type= MYSQL_TYPE_LONG_BLOB;
return 0;
}
else {
/* If Object has string representation, we will use string representation */
@ -1396,8 +1428,31 @@ mariadb_param_to_bind(MrdbCursor *self,
*(double *)value->num= (double)PyFloat_AsDouble(value->value);
break;
case MYSQL_TYPE_LONG_BLOB:
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(value->value);
bind->buffer= (void *) PyBytes_AS_STRING(value->value);
if (!strcmp(Py_TYPE(value->value)->tp_name, "array.array") ||
!strcmp(Py_TYPE(value->value)->tp_name, "array"))
{
Py_ssize_t size= PySequence_Length(value->value);
PyObject *byte_array= NULL;
bind->buffer= NULL;
if (!size)
goto end;
if (!PyObject_HasAttrString(value->value, "tobytes"))
goto end;
if (!(byte_array= PyObject_CallMethod(value->value, "tobytes", NULL)))
goto end;
bind->buffer= (void *)PyBytes_AS_STRING(byte_array);
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(byte_array);
Py_DECREF(byte_array);
} else {
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(value->value);
bind->buffer= (void *) PyBytes_AS_STRING(value->value);
}
break;
case MYSQL_TYPE_DATE:
case MYSQL_TYPE_TIME:

View File

@ -7,6 +7,7 @@ import os
import decimal
import json
from decimal import Decimal
import array
import mariadb
from mariadb.constants import FIELD_TYPE, EXT_FIELD_TYPE, ERR, CURSOR, INDICATOR, CLIENT
@ -58,6 +59,39 @@ class TestCursor(unittest.TestCase):
cursor.close()
conn.close()
def test_conpy299(self):
if is_mysql():
self.skipTest("Skip (MySQL)")
if self.connection.server_version < 110702:
self.skipTest("Requires server version >= 11.7.2")
cursor= self.connection.cursor()
cursor.execute("DROP TABLE IF EXISTS t_vector")
cursor.execute("CREATE TABLE t_vector (id int not null, v VECTOR(3) NOT NULL, VECTOR INDEX(v))")
# Vector can't be empty
empty= array.array('f', [])
try:
cursor.execute("INSERT INTO t_vector VALUES (?,?)", (1, empty))
except mariadb.IntegrityError:
pass
# Valid vector
data= array.array('f', [201.1, 302.2, 403.3])
cursor.execute("INSERT INTO t_vector VALUES (?,?)", (1, data))
cursor.execute("SELECT id, v, Vec_ToText(v) FROM t_vector")
row= cursor.fetchone()
check_data= [row[1], array.array('f', eval(row[2]))]
cursor.execute("DROP TABLE t_vector")
cursor.close()
self.assertEqual(check_data[0], data.tobytes())
self.assertEqual(check_data[1], data)
def test_date(self):
v = self.connection.server_version
i = self.connection.server_info.lower()