Follow up for CONPY-299

- Added byteswap function for vector (bigendian)
- Use buffer protocol instead of calling array's
  methods
This commit is contained in:
Georg Richter
2025-02-10 13:59:43 +01:00
parent 9508904911
commit befe7000c9
2 changed files with 183 additions and 84 deletions

View File

@ -24,6 +24,53 @@
#define IS_DECIMAL_TYPE(type) \
((type) == MYSQL_TYPE_NEWDECIMAL || (type) == MYSQL_TYPE_DOUBLE || (type) == MYSQL_TYPE_FLOAT)
static char *ma_byteswap(char *buf, size_t itemsize, size_t len)
{
char *p;
Py_ssize_t i;
switch (itemsize) {
case 1:
return buf;
case 2:
for (p = buf, i = len; --i >= 0; p += itemsize) {
char p0 = p[0];
p[0] = p[1];
p[1] = p0;
}
break;
case 4:
for (p = buf, i = len; --i >= 0; p += itemsize) {
char p0 = p[0];
char p1 = p[1];
p[0] = p[3];
p[1] = p[2];
p[2] = p1;
p[3] = p0;
}
break;
case 8:
for (p = buf, i = len; --i >= 0; p += itemsize) {
char p0 = p[0];
char p1 = p[1];
char p2 = p[2];
char p3 = p[3];
p[0] = p[7];
p[1] = p[6];
p[2] = p[5];
p[3] = p[4];
p[4] = p3;
p[5] = p2;
p[6] = p1;
p[7] = p0;
}
break;
default:
return 0;
}
return buf;
}
long MrdbIndicator_AsLong(PyObject *column)
{
PyObject *pyLong= PyObject_GetAttrString(column, "indicator");
@ -299,7 +346,7 @@ static uint8_t check_time(MYSQL_TIME *tm)
Year must be < 10000, month < 12, day < 32
Years with 2 digits, are converted to values 1970-2069 according to
Years with 2 digits, are converted to values 1970-2069 according to
usual rules:
00-69 is converted to 2000-2069.
@ -382,7 +429,7 @@ int Py_str_to_TIME(const char *str, size_t length, MYSQL_TIME *tm)
{
if (parse_time(p, end - p, &p, tm))
goto error;
tm->year = tm->month = tm->day = 0;
tm->time_type = MYSQL_TIMESTAMP_TIME;
return 0;
@ -419,7 +466,7 @@ error:
static PyObject *Mrdb_GetTimeDelta(MYSQL_TIME *tm)
{
int days, hour, minute, second, second_part;
hour= (tm->neg) ? -1 * tm->hour : tm->hour;
minute= (tm->neg) ? -1 * tm->minute : tm->minute;
second= (tm->neg) ? -1 * tm->second : tm->second;
@ -428,7 +475,7 @@ static PyObject *Mrdb_GetTimeDelta(MYSQL_TIME *tm)
days= hour / 24;
hour= hour % 24;
second= hour * 3600 + minute * 60 + second;
return PyDelta_FromDSU(days, second, second_part);
}
@ -452,7 +499,7 @@ static PyObject *ma_convert_value(MrdbCursor *self,
}
return new_value;
}
void
field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column)
{
@ -492,8 +539,8 @@ field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column)
self->values[column]= PyLong_FromString(p, NULL, 0);
break;
}
case MYSQL_TYPE_FLOAT:
case MYSQL_TYPE_DOUBLE:
case MYSQL_TYPE_FLOAT:
case MYSQL_TYPE_DOUBLE:
{
double d= atof(data);
self->values[column]= PyFloat_FromDouble(d);
@ -525,7 +572,7 @@ field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column)
Py_INCREF(Py_None);
self->values[column]= Py_None;
}
} else
} else
{
if (check_date(tm.year, tm.month, tm.day) &&
check_time(&tm))
@ -551,13 +598,13 @@ field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column)
}
if (self->fields[column].charsetnr== CHARSET_BINARY)
{
self->values[column]=
self->values[column]=
PyBytes_FromStringAndSize((const char *)data,
(Py_ssize_t)length[column]);
}
else {
self->values[column]=
PyUnicode_FromStringAndSize((const char *)data,
self->values[column]=
PyUnicode_FromStringAndSize((const char *)data,
(Py_ssize_t)length[column]);
}
break;
@ -612,11 +659,11 @@ field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column)
if ((val= ma_convert_value(self, type, self->values[column])))
self->values[column]= val;
}
}
}
/* field_fetch_callback
This function was previously registered with mysql_stmt_attr_set and
STMT_ATTR_FIELD_FETCH_CALLBACK parameter. Instead of filling a bind
STMT_ATTR_FIELD_FETCH_CALLBACK parameter. Instead of filling a bind
buffer MariaDB Connector/C sends raw data in row for the specified column.
In case of a NULL value row ptr will be NULL.
@ -659,7 +706,7 @@ field_fetch_callback(void *data, unsigned int column, unsigned char **row)
*row+= 2;
break;
case MYSQL_TYPE_INT24:
self->values[column]=
self->values[column]=
(self->fields[column].flags & UNSIGNED_FLAG) ?
PyLong_FromUnsignedLong((unsigned long)uint3korr(*row)) :
PyLong_FromLong((long)sint3korr(*row));
@ -722,7 +769,7 @@ field_fetch_callback(void *data, unsigned int column, unsigned char **row)
}
if (len == 11)
second_part= uint4korr(*row + 7);
self->values[column]= PyDateTime_FromDateAndTime(year, month,
self->values[column]= PyDateTime_FromDateAndTime(year, month,
day, hour, minute, second, second_part);
*row+= len;
break;
@ -782,13 +829,13 @@ field_fetch_callback(void *data, unsigned int column, unsigned char **row)
self->fields[column].max_length= length;
if (self->fields[column].charsetnr== CHARSET_BINARY)
{
self->values[column]=
PyBytes_FromStringAndSize((const char *)*row,
self->values[column]=
PyBytes_FromStringAndSize((const char *)*row,
(Py_ssize_t)length);
}
else {
self->values[column]=
PyUnicode_FromStringAndSize((const char *)*row,
PyUnicode_FromStringAndSize((const char *)*row,
(Py_ssize_t)length);
}
*row+= length;
@ -831,7 +878,7 @@ field_fetch_callback(void *data, unsigned int column, unsigned char **row)
if (length > self->fields[column].max_length)
self->fields[column].max_length= length;
} else {
self->values[column]=
self->values[column]=
PyUnicode_FromStringAndSize((const char *)*row,
(Py_ssize_t)length);
utf8len= (unsigned long)PyUnicode_GET_LENGTH(self->values[column]);
@ -886,13 +933,13 @@ end:
return rc;
}
/*
/*
mariadb_get_column_info
This function analyzes the Python object and calculates the corresponding
MYSQL_TYPE, unsigned flag or NULL values and stores the information in
MrdbParamInfo pointer.
*/
static uint8_t
static uint8_t
mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)
{
if (obj == NULL)
@ -943,8 +990,11 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)
paraminfo->type= MYSQL_TYPE_LONG_BLOB;
return 0;
}
else {
/* If Object has string representation, we will use string representation */
else if (Py_TYPE(obj)->tp_str) {
/* If Object has string representation, we will use string representation */
paraminfo->type= MYSQL_TYPE_VAR_STRING;
return 0;
} else {
/* no corresponding object, return error */
return 2;
}
@ -965,7 +1015,7 @@ PyObject *ListOrTuple_GetItem(PyObject *obj, Py_ssize_t index)
return NULL;
}
/*
/*
mariadb_get_parameter()
@brief Returns a bulk parameter which was passed to
@ -979,7 +1029,7 @@ PyObject *ListOrTuple_GetItem(PyObject *obj, Py_ssize_t index)
@return 0 on success, 1 on error
*/
static uint8_t
static uint8_t
mariadb_get_parameter(MrdbCursor *self,
uint8_t is_bulk,
uint32_t row_nr,
@ -1060,7 +1110,7 @@ mariadb_get_parameter(MrdbCursor *self,
{
param->indicator= STMT_INDICATOR_NULL;
}
}
}
else {
param->value= column;
param->indicator= STMT_INDICATOR_NONE;
@ -1070,7 +1120,7 @@ end:
return rc;
}
/*
/*
mariadb_get_parameter_info
mariadb_get_parameter_info fills the MYSQL_BIND structure
with correct field_types for the Python objects.
@ -1079,7 +1129,7 @@ end:
the field type (e.g. by checking maxbit size for a PyLong).
If the types in this column differ we will return an error.
*/
static uint8_t
static uint8_t
mariadb_get_parameter_info(MrdbCursor *self,
MYSQL_BIND *param,
uint32_t column_nr)
@ -1090,10 +1140,10 @@ mariadb_get_parameter_info(MrdbCursor *self,
param->is_unsigned= 0;
paramvalue.indicator= 0;
uint8_t rc;
if (!self->array_size)
{
uint8_t rc;
memset(&pinfo, 0, sizeof(MrdbParamInfo));
if (mariadb_get_parameter(self, 0, 0, column_nr, &paramvalue))
return 1;
@ -1103,15 +1153,14 @@ mariadb_get_parameter_info(MrdbCursor *self,
{
mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0,
"Can't retrieve column information for parameter %d",
column_nr);
column_nr + 1);
}
if (rc == 2)
{
mariadb_throw_exception(NULL, Mariadb_NotSupportedError, 0,
"Data type '%s' in column %d not supported in MariaDB Connector/Python",
Py_TYPE(paramvalue.value)->tp_name, column_nr);
Py_TYPE(paramvalue.value)->tp_name, column_nr + 1);
}
return 1;
}
param->buffer_type= pinfo.type;
@ -1122,11 +1171,20 @@ mariadb_get_parameter_info(MrdbCursor *self,
if (mariadb_get_parameter(self, 1, i, column_nr, &paramvalue))
return 1;
memset(&pinfo, 0, sizeof(MrdbParamInfo));
if (mariadb_get_column_info(paramvalue.value, &pinfo) && !paramvalue.indicator)
if ((rc= mariadb_get_column_info(paramvalue.value, &pinfo) && !paramvalue.indicator))
{
mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 1,
"Invalid parameter type at row %d, column %d",
i+1, column_nr + 1);
if (rc == 1)
{
mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0,
"Can't retrieve column information for parameter %d at row %d.",
column_nr + 1, i + 1);
}
if (rc == 2)
{
mariadb_throw_exception(NULL, Mariadb_NotSupportedError, 0,
"Data type '%s' in column %d at row %d not supported in MariaDB Connector/Python",
Py_TYPE(paramvalue.value)->tp_name, column_nr + 1, i+1);
}
return 1;
}
@ -1213,7 +1271,7 @@ static Py_ssize_t ListOrTuple_Size(PyObject *obj)
This function validates the specified bulk parameters and
translates the field types to MYSQL_TYPE_*.
*/
uint8_t
uint8_t
mariadb_check_bulk_parameters(MrdbCursor *self,
PyObject *data)
{
@ -1222,14 +1280,14 @@ mariadb_check_bulk_parameters(MrdbCursor *self,
if (!CHECK_TYPE((data), &PyList_Type) &&
!CHECK_TYPE(data, &PyTuple_Type))
{
mariadb_throw_exception(self->stmt, Mariadb_InterfaceError, 1,
mariadb_throw_exception(self->stmt, Mariadb_InterfaceError, 1,
"Data must be passed as sequence (Tuple or List)");
return 1;
}
if (!(self->array_size= (uint32_t)ListOrTuple_Size(data)))
{
mariadb_throw_exception(self->stmt, Mariadb_InterfaceError, 1,
mariadb_throw_exception(self->stmt, Mariadb_InterfaceError, 1,
"Empty parameter list. At least one row must be specified");
return 1;
}
@ -1256,10 +1314,10 @@ mariadb_check_bulk_parameters(MrdbCursor *self,
}
if (!self->parseinfo.paramcount ||
(self->parseinfo.paramstyle != PYFORMAT &&
(self->parseinfo.paramstyle != PYFORMAT &&
self->parseinfo.paramcount != ListOrTuple_Size(obj)))
{
mariadb_throw_exception(self->stmt, Mariadb_ProgrammingError, 1,
mariadb_throw_exception(self->stmt, Mariadb_ProgrammingError, 1,
"Invalid number of parameters in row %d", i+1);
return 1;
}
@ -1275,7 +1333,7 @@ mariadb_check_bulk_parameters(MrdbCursor *self,
goto error;
}
if (!(self->value= PyMem_RawCalloc(self->parseinfo.paramcount,
if (!(self->value= PyMem_RawCalloc(self->parseinfo.paramcount,
sizeof(MrdbParamValue))))
{
mariadb_throw_exception(NULL, Mariadb_InterfaceError, 0,
@ -1341,7 +1399,7 @@ error:
return 1;
}
/*
/*
mariadb_param_to_bind()
@brief Set the current value for the specified bind buffer
@ -1352,7 +1410,7 @@ error:
@return 0 on success, otherwise error
*/
static uint8_t
static uint8_t
mariadb_param_to_bind(MrdbCursor *self,
MYSQL_BIND *bind,
MrdbParamValue *value)
@ -1428,27 +1486,41 @@ mariadb_param_to_bind(MrdbCursor *self,
*(double *)value->num= (double)PyFloat_AsDouble(value->value);
break;
case MYSQL_TYPE_LONG_BLOB:
if (value->free_me)
{
MARIADB_FREE_MEM(value->buffer);
value->free_me= 0;
}
if (!strcmp(Py_TYPE(value->value)->tp_name, "array.array") ||
!strcmp(Py_TYPE(value->value)->tp_name, "array"))
{
Py_ssize_t size= PySequence_Length(value->value);
PyObject *byte_array= NULL;
Py_buffer v;
bind->buffer= NULL;
if (!size)
if (PyObject_GetBuffer(value->value, &v, PyBUF_CONTIG_RO) < 0)
goto end;
if (!PyObject_HasAttrString(value->value, "tobytes"))
if (!v.len)
goto end;
if (!(byte_array= PyObject_CallMethod(value->value, "tobytes", NULL)))
goto end;
bind->buffer_length= (unsigned long)v.len;
#if PY_BIG_ENDIAN == 0
bind->buffer= (void *)v.buf;
#else
bind->buffer= value->buffer= PyMem_RawCalloc(1, v.len);
if (!bind->buffer)
{
mariadb_throw_exception(NULL, Mariadb_InterfaceError, 0,
"Not enough memory (tried to allocated %lld bytes)", v.len);
return 1;
}
value->free_me= 1;
memcpy(bind->buffer, v.buf, v.len);
bind->buffer= (void *)PyBytes_AS_STRING(byte_array);
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(byte_array);
Py_DECREF(byte_array);
bind->buffer= ma_byteswap((char *)bind->buffer, v.itemsize, v.len);
#endif
PyBuffer_Release(&v);
} else {
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(value->value);
bind->buffer= (void *) PyBytes_AS_STRING(value->value);
@ -1464,37 +1536,43 @@ mariadb_param_to_bind(MrdbCursor *self,
mariadb_pydate_to_tm(bind->buffer_type, value->value, &value->tm);
break;
case MYSQL_TYPE_NEWDECIMAL:
{
Py_ssize_t len;
PyObject *obj= NULL;
char *p;
if (value->free_me)
MARIADB_FREE_MEM(value->buffer);
if (!strcmp(Py_TYPE(value->value)->tp_name, "decimal.Decimal") ||
!strcmp(Py_TYPE(value->value)->tp_name, "Decimal"))
{
obj= PyObject_Str(value->value);
p= (void *)PyUnicode_AsUTF8AndSize(obj, &len);
}
else
{
obj= PyObject_Str(value->value);
p= (void *)PyUnicode_AsUTF8AndSize(obj, &len);
}
bind->buffer= value->buffer= PyMem_RawCalloc(1, len);
memcpy(value->buffer, p, len);
value->free_me= 1;
bind->buffer_length= (unsigned long)len;
Py_DECREF(obj);
}
break;
case MYSQL_TYPE_VAR_STRING:
{
Py_ssize_t len;
bind->buffer= (void *)PyUnicode_AsUTF8AndSize(value->value, &len);
bind->buffer_length= (unsigned long)len;
if (value->free_me)
{
MARIADB_FREE_MEM(value->buffer);
value->free_me= 0;
}
if (CHECK_TYPE(value->value, &PyUnicode_Type)) {
bind->buffer= (void *)PyUnicode_AsUTF8AndSize(value->value, &len);
bind->buffer_length= (unsigned long)len;
} else {
PyObject *obj= PyObject_Str(value->value);
char *p;
if (!obj) {
mariadb_throw_exception(self->stmt, Mariadb_ProgrammingError, 0,
"Python type %s has no string representation",
Py_TYPE(value->value)->tp_name);
return 1;
}
p= (void *)PyUnicode_AsUTF8AndSize(obj, &len);
if (!(bind->buffer= value->buffer= PyMem_RawCalloc(1, len)))
{
mariadb_throw_exception(NULL, Mariadb_InterfaceError, 0,
"Not enough memory (tried to allocated %lld bytes)", len);
return 1;
}
value->free_me= 1;
memcpy(bind->buffer, p, len);
bind->buffer_length= (unsigned long)len;
Py_DECREF(obj);
}
break;
}
case MYSQL_TYPE_NULL:
@ -1506,7 +1584,7 @@ end:
return rc;
}
/*
/*
mariadb_param_update()
@brief Callback function which updates the bind structure's buffer and
length with data from the specified row number. This callback function
@ -1529,7 +1607,7 @@ mariadb_param_update(void *data, MYSQL_BIND *bind, uint32_t row_nr)
for (i=0; i < self->parseinfo.paramcount; i++)
{
if (mariadb_get_parameter(self, (self->array_size > 0),
if (mariadb_get_parameter(self, (self->array_size > 0),
row_nr, i, &self->value[i]))
{
goto end;