CONPY-299: support for VECTOR data type

Implemented support for VECTOR paramter. Beside text representation
(Vec_FromText) or array to byte conversion (array.tobytes) it is now
possible to specify an float array as parameter.

Example:

import mariadb, array

...

sql = """
  CREATE OR REPLACE TABLE test(
     id INT PRIMARY KEY,
     v VECTOR(3) NOT NULL,
      VECTOR INDEX (v))")"""

cursor.execute(sql)

vector= array.array('f', [123.1, 230.9, 981.7])

cursor.execute("INSERT INTO test VALUES (?,?)", (1, vector))

Please note that the opposite way, retrieving a float array is not
supported yet. This will require either a new or extended field
type for vector data type.
This commit is contained in:
Georg Richter
2025-01-30 09:02:07 +01:00
parent abd17eb95d
commit 138a02238e
2 changed files with 91 additions and 2 deletions

View File

@ -858,6 +858,34 @@ field_fetch_callback(void *data, unsigned int column, unsigned char **row)
self->values[column]= val;
}
}
static uint8_t ma_is_vector(PyObject *obj)
{
PyObject *TypeCodeObj= NULL;
const char *typecode;
uint8_t rc= 0;
if (!obj)
return 0;
if (strcmp(Py_TYPE(obj)->tp_name, "array.array") &&
strcmp(Py_TYPE(obj)->tp_name, "array"))
return 0;
if (!(TypeCodeObj= PyObject_GetAttrString(obj, "typecode")) ||
!PyUnicode_Check(TypeCodeObj))
goto end;
if ((typecode = PyUnicode_AsUTF8(TypeCodeObj)) &&
!strcmp(typecode, "f"))
rc= 1;
end:
if (TypeCodeObj)
Py_DECREF(TypeCodeObj);
return rc;
}
/*
mariadb_get_column_info
This function analyzes the Python object and calculates the corresponding
@ -910,6 +938,10 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)
so we need to convert decimal.Decimal Object to string during callback */
paraminfo->type= MYSQL_TYPE_NEWDECIMAL;
return 0;
} else if (ma_is_vector(obj)) {
/* CONPY-299: Vectors are defined as array of floats */
paraminfo->type= MYSQL_TYPE_LONG_BLOB;
return 0;
}
else {
/* If Object has string representation, we will use string representation */
@ -1396,8 +1428,31 @@ mariadb_param_to_bind(MrdbCursor *self,
*(double *)value->num= (double)PyFloat_AsDouble(value->value);
break;
case MYSQL_TYPE_LONG_BLOB:
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(value->value);
bind->buffer= (void *) PyBytes_AS_STRING(value->value);
if (!strcmp(Py_TYPE(value->value)->tp_name, "array.array") ||
!strcmp(Py_TYPE(value->value)->tp_name, "array"))
{
Py_ssize_t size= PySequence_Length(value->value);
PyObject *byte_array= NULL;
bind->buffer= NULL;
if (!size)
goto end;
if (!PyObject_HasAttrString(value->value, "tobytes"))
goto end;
if (!(byte_array= PyObject_CallMethod(value->value, "tobytes", NULL)))
goto end;
bind->buffer= (void *)PyBytes_AS_STRING(byte_array);
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(byte_array);
Py_DECREF(byte_array);
} else {
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(value->value);
bind->buffer= (void *) PyBytes_AS_STRING(value->value);
}
break;
case MYSQL_TYPE_DATE:
case MYSQL_TYPE_TIME: