cocluto v1.0.27 - fixes bug in returned table of SshAccessedMysqlDb.query('select...') and SshAccessedMysqlDb.get_last_insert_id()

- splitted  ISqlDatabaseBackend.query() into ISqlDatabaseBackend.query() and the more specialized SqlDatabaseBackend.query_select ()
- this way the return type of query_select is explicit (a Table).
- it also allows robust sql query parsing (usin json).
- replaced calls to ISqlDatabaseBackend.query('select ...') with calls to ISqlDatabaseBackend.query_select()

work related to [https://bugzilla.ipr.univ-rennes.fr/show_bug.cgi?id=3093]
This commit is contained in:
Guillaume Raffy 2026-05-26 16:39:55 +02:00
parent 5856ac0951
commit d2c973c7ed
3 changed files with 113 additions and 14 deletions

View File

@ -1,6 +1,7 @@
from typing import Union, List from typing import Union, List, Optional, Tuple
from pathlib import Path from pathlib import Path
from enum import Enum from enum import Enum
import json
import logging import logging
import MySQLdb # sudo port install py-mysql; sudo apt install python-mysqldb or pip install mysqlclient import MySQLdb # sudo port install py-mysql; sudo apt install python-mysqldb or pip install mysqlclient
import time import time
@ -43,6 +44,7 @@ def is_machine_responding(machineName):
SqlQuery = str SqlQuery = str
Table = List[Tuple]
class SqlTableField(): class SqlTableField():
@ -68,6 +70,10 @@ class SqlTableField():
self.is_autoinc_index = is_autoinc_index self.is_autoinc_index = is_autoinc_index
ColumnId = str # eg 'matrix_size'
TableId = str # eg 'benchmark_results'
class ISqlDatabaseBackend(object): class ISqlDatabaseBackend(object):
def __init__(self): def __init__(self):
pass pass
@ -79,6 +85,19 @@ class ISqlDatabaseBackend(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def query_select(self, columns: List[ColumnId], table: TableId, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
"""
performs a select query on the sql database and returns the results in the form of a list of tuples (one tuple per row, tuple values are the column values)
:param List[ColumnId] columns: the columns to select
:param TableId table: the name of the table to query
:param Optional[str] where_clause: the where clause for the query, eg "matrix_size > 100"
:param Optional[str] join_clause: the join clause for the query, eg "disables ON log.id = disables.disable_request_id"
:param bool distinct: whether to return distinct rows
:return: the results of the query
"""
raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def get_last_insert_id(self) -> int: def get_last_insert_id(self) -> int:
""" """
@ -142,10 +161,20 @@ class RemoteMysqlDb(ISqlDatabaseBackend):
""" """
:param str sql_query: the sql query to perform :param str sql_query: the sql query to perform
""" """
self._conn.query(sql_query) cursor = self._conn.cursor()
rows = self._conn.store_result() cursor.execute(sql_query)
rows = cursor.fetchall()
logging.debug("RemoteMysqlDb.query:: results of query using cursor '%s': %s", sql_query, rows)
return rows return rows
def query_select(self, columns: List[str], table: str, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
sql_query = f"SELECT {('DISTINCT ' if distinct else '')}{', '.join(columns)} FROM {table}"
if join_clause:
sql_query += f" JOIN {join_clause}"
if where_clause:
sql_query += f" WHERE {where_clause}"
return self.query(sql_query)
def get_last_insert_id(self) -> int: def get_last_insert_id(self) -> int:
""" """
:return: the id of the last inserted row :return: the id of the last inserted row
@ -167,6 +196,26 @@ class RemoteMysqlDb(ISqlDatabaseBackend):
raise NotImplementedError() raise NotImplementedError()
def json_to_table(json_str: str) -> Table:
# logging.debug("json_to_table:: json_str = '%s'", json_str)
if json_str == 'NULL':
return []
try:
json_data = json.loads(json_str)
except json.JSONDecodeError:
logging.error("json_to_table:: invalid json string: '%s'", json_str)
raise
assert isinstance(json_data, list), f'Expected a list of rows in the json string but got {type(json_data)}'
table = []
for row in json_data:
values = []
for column_value in row.values():
assert isinstance(column_value, (str, int, float)), f'Expected a string, int or float value for each column value in the json string but got {type(column_value)}'
values.append(column_value)
table.append(tuple(values))
return table
class SshAccessedMysqlDb(ISqlDatabaseBackend): class SshAccessedMysqlDb(ISqlDatabaseBackend):
"""a mysql database server accessed using ssh instead of a remote mysql client """a mysql database server accessed using ssh instead of a remote mysql client
@ -200,14 +249,58 @@ class SshAccessedMysqlDb(ISqlDatabaseBackend):
if completed_process.returncode != 0: if completed_process.returncode != 0:
logging.error(completed_process.stderr.decode(encoding='utf-8')) logging.error(completed_process.stderr.decode(encoding='utf-8'))
assert False assert False
rows = completed_process.stdout.decode('utf-8').split('\n') stdout = completed_process.stdout.decode('utf-8')
logging.debug("SshAccessedMysqlDb.query:: results of query '%s': %s", sql_query, stdout)
rows = stdout.split('\n') if stdout != '' else []
logging.debug("SshAccessedMysqlDb.query:: split results of query '%s' by new lines: %s", sql_query, rows)
return rows return rows
def query_select(self, columns: List[str], table: str, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
# MariaDB [quman]> SELECT JSON_ARRAYAGG(JSON_OBJECT('toto', timestamp, 'iddd', user_id)) FROM log;
# +-------------------------------------------------------------------------------------------------------+
# | JSON_ARRAYAGG(JSON_OBJECT('toto', timestamp, 'iddd', user_id)) |
# +-------------------------------------------------------------------------------------------------------+
# | [{"toto": "2026-05-21 15:11:58", "iddd": "graffy"},{"toto": "2026-05-22 15:55:39", "iddd": "graffy"}] |
# +-------------------------------------------------------------------------------------------------------+
# 1 row in set (0,006 sec)
# columns_list_str = ', '.join([f'"{col}"' for col in columns])
columns_list_str = ', '.join([f"'{col}', {col}" for col in columns])
columns_statement = f'JSON_ARRAYAGG(JSON_OBJECT({columns_list_str}))'
sql_query = f"SELECT {('DISTINCT ' if distinct else '')}{columns_statement} FROM {table}"
if join_clause:
sql_query += f" JOIN {join_clause}"
if where_clause:
sql_query += f" WHERE {where_clause}"
stdout = self.query(sql_query)
json_str = stdout[-2] # eg '[{"queue_machine": "gpuonly.q@alambix104.ipr.univ-rennes1.fr"}]'
assert isinstance(json_str, str), f'Expected a string as data line in the query output but got {type(json_str)}'
logging.debug("SshAccessedMysqlDb.query_select:: data line of query '%s': '%s'", sql_query, json_str)
# match = re.match(r'^\s*(?P<json_data>{[^}]*})\s*$', data_line)
# assert match, 'Unexpected output format for query "%s" : %s' % (sql_query, stdout)
# json_data = json.loads(match.group('json_data'))
table = json_to_table(json_str)
return table
def get_last_insert_id(self) -> int: def get_last_insert_id(self) -> int:
""" """
:return: the id of the last inserted row :return: the id of the last inserted row
""" """
return int(self.query("SELECT last_insert_id();")[0][0]) # MariaDB [quman]> SELECT JSON_OBJECT('toto', last_insert_id());
# +---------------------------------------+
# | JSON_OBJECT('toto', last_insert_id()) |
# +---------------------------------------+
# | {"toto": 0} |
# +---------------------------------------+
# 1 row in set (0,002 sec)
stdout = self.query("SELECT JSON_OBJECT('toto', last_insert_id());")
data_line = stdout[-2]
match = re.match(r'^\s*(?P<json_data>{\s*"toto"\s*:\s*(\d+)\s*})\s*$', data_line)
assert match
json_data = json.loads(match.group('json_data'))
last_insert_id = int(json_data['toto'])
assert last_insert_id > 0, f'Unexpected last insert id : {last_insert_id}'
return last_insert_id
def table_exists(self, table_name: str) -> bool: def table_exists(self, table_name: str) -> bool:
rows = self.query(f"SHOW TABLES LIKE '{table_name}';") rows = self.query(f"SHOW TABLES LIKE '{table_name}';")
@ -264,7 +357,7 @@ class SqliteDb(ISqlDatabaseBackend):
:param str database_name: the name of the database withing the sqlite database (eg "iprbench") :param str database_name: the name of the database withing the sqlite database (eg "iprbench")
""" """
self.sqlite_db_path = sqlite_db_path self.sqlite_db_path = sqlite_db_path
self._cur = None self._cur = None
check_same_thread = False check_same_thread = False
# this is to prevent the following error when run from apache/django : SQLite objects created in a thread can only be used in that same thread.The object was created in thread id 139672342353664 and this is thread id 139672333960960 # this is to prevent the following error when run from apache/django : SQLite objects created in a thread can only be used in that same thread.The object was created in thread id 139672342353664 and this is thread id 139672333960960
@ -293,6 +386,14 @@ class SqliteDb(ISqlDatabaseBackend):
self._con.commit() self._con.commit()
return rows return rows
def query_select(self, columns: List[str], table: str, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
sql_query = f"SELECT {('DISTINCT ' if distinct else '')}{', '.join(columns)} FROM {table}"
if join_clause:
sql_query += f" JOIN {join_clause}"
if where_clause:
sql_query += f" WHERE {where_clause}"
return self.query(sql_query)
def get_last_insert_id(self) -> int: def get_last_insert_id(self) -> int:
""" """
:return: the id of the last inserted row :return: the id of the last inserted row

View File

@ -263,8 +263,7 @@ class QueueManager:
return log_id return log_id
def get_disable_requests(self, queue_machine: QueueMachineId) -> Dict[int, DisableRequest]: def get_disable_requests(self, queue_machine: QueueMachineId) -> Dict[int, DisableRequest]:
sql_query = f"SELECT log.id, log.user_id, log.host_fqdn, log.queue_machines, log.reason, log.disable_id, log.timestamp FROM log JOIN disables ON log.id = disables.disable_request_id WHERE disables.queue_machine = '{queue_machine}' AND log.action = 'disable';" results = self.db_backend.query_select(['log.id', 'log.user_id', 'log.host_fqdn', 'log.queue_machines', 'log.reason', 'log.disable_id', 'log.timestamp'], join_clause="disables ON log.id = disables.disable_request_id", table="log", where_clause=f"disables.queue_machine = '{queue_machine}' AND log.action = 'disable'")
results = self.db_backend.query(sql_query)
disable_requests = [] disable_requests = []
for row in results: for row in results:
log_id = row[0] log_id = row[0]
@ -324,10 +323,10 @@ class QueueManager:
qs = self.grid_engine.get_status() qs = self.grid_engine.get_status()
db_queues = set() db_queues = set()
sql_query = "SELECT queue_machine FROM queues;" results = self.db_backend.query_select(columns=["queue_machine"], table="queues")
results = self.db_backend.query(sql_query) logging.debug("synchronize_with_grid_engine: results of query': %s", results)
for row in results: for row in results:
assert len(row) == 1, "Each row should have only one column (queue_machine)" assert len(row) == 1, "Each row should have only one column (queue_machine) but got row='%s' (len=%d)" % (str(row), len(row))
db_queues.add(row[0]) db_queues.add(row[0])
for queue_machine, is_enabled in qs.is_enabled.items(): for queue_machine, is_enabled in qs.is_enabled.items():
@ -355,8 +354,7 @@ class QueueManager:
def get_state(self) -> QueueDisableState: def get_state(self) -> QueueDisableState:
"""returns the state of the queues.""" """returns the state of the queues."""
# get the list of queue names from the disables table in the database # get the list of queue names from the disables table in the database
sql_query = "SELECT DISTINCT queue_machine FROM disables;" results = self.db_backend.query_select(columns=["queue_machine"], table="disables", distinct=True)
results = self.db_backend.query(sql_query)
for row in results: for row in results:
assert len(row) == 1, "Each row should have only one column (queue_machine)" assert len(row) == 1, "Each row should have only one column (queue_machine)"
queue_machines = [row[0] for row in results] queue_machines = [row[0] for row in results]

View File

@ -1,4 +1,4 @@
__version__ = '1.0.26' __version__ = '1.0.27'
class Version(object): class Version(object):