diff --git a/include/mariadb_python.h b/include/mariadb_python.h index 9fa3b8e..91bd06d 100755 --- a/include/mariadb_python.h +++ b/include/mariadb_python.h @@ -253,8 +253,7 @@ typedef struct { enum enum_result_format result_format; uint8_t is_prepared; uint8_t is_buffered; -/* uint8_t is_named_tuple; - uint8_t is_dictionary; */ + uint8_t fetched; uint8_t is_closed; uint8_t is_text; MrdbParser *parser; diff --git a/src/mariadb_cursor.c b/src/mariadb_cursor.c index d6e526d..62bd1b3 100644 --- a/src/mariadb_cursor.c +++ b/src/mariadb_cursor.c @@ -429,6 +429,7 @@ void MrdbCursor_clear(MrdbCursor *self, uint8_t new_stmt) } } + self->fetched= 0; if (self->is_text) { @@ -902,6 +903,8 @@ static int MrdbCursor_fetchinternal(MrdbCursor *self) int rc; unsigned int i; + self->fetched= 1; + if (!self->is_text) { rc= mysql_stmt_fetch(self->stmt); @@ -1418,7 +1421,13 @@ Mariadb_row_count(MrdbCursor *self) if (self->field_count) { - row_count= CURSOR_NUM_ROWS(self); + if (!self->is_buffered && !self->fetched) + { + row_count= -1; + } else + { + row_count= CURSOR_NUM_ROWS(self); + } } else { row_count= self->row_count ? self->row_count : CURSOR_AFFECTED_ROWS(self); diff --git a/test/integration/test_cursor.py b/test/integration/test_cursor.py index ca75348..d08bc08 100644 --- a/test/integration/test_cursor.py +++ b/test/integration/test_cursor.py @@ -141,13 +141,13 @@ class TestCursor(unittest.TestCase): self.assertRaises(mariadb.Error, cursor.fetchall) cursor.execute("SELECT id, name, city FROM test_fetchmany ORDER BY id") - self.assertEqual(0, cursor.rowcount) + self.assertEqual(-1, cursor.rowcount) row = cursor.fetchall() self.assertEqual(row, params) self.assertEqual(5, cursor.rowcount) cursor.execute("SELECT id, name, city FROM test_fetchmany ORDER BY id") - self.assertEqual(0, cursor.rowcount) + self.assertEqual(-1, cursor.rowcount) row = cursor.fetchmany(1) self.assertEqual(row, [params[0]]) @@ -927,5 +927,17 @@ class TestCursor(unittest.TestCase): row= cur.fetchone() self.assertEqual(row[0], Decimal(1.25)) + def test_conpy67(self): + con= create_connection() + cur = con.cursor() + cur.execute("SELECT 1") + self.assertEqual(cur.rowcount, -1) + del cur + cur = con.cursor() + cur.execute("SELECT 1 WHERE 1=2") + self.assertEqual(cur.rowcount, -1) + cur.fetchall() + self.assertEqual(cur.rowcount, 0) + if __name__ == '__main__': unittest.main()