cocluto v1.0.27 - fixes bug in returned table of SshAccessedMysqlDb.query('select...') and SshAccessedMysqlDb.get_last_insert_id()
- split ISqlDatabaseBackend.query() into ISqlDatabaseBackend.query() and the more specialized ISqlDatabaseBackend.query_select()
- this way the return type of query_select is explicit (a Table).
- it also allows robust parsing of the sql query results (using JSON).
- replaced calls to ISqlDatabaseBackend.query('select ...') with calls to ISqlDatabaseBackend.query_select()
work related to [https://bugzilla.ipr.univ-rennes.fr/show_bug.cgi?id=3093]
This commit is contained in:
parent
5856ac0951
commit
d2c973c7ed
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Union, List
|
from typing import Union, List, Optional, Tuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import MySQLdb # sudo port install py-mysql; sudo apt install python-mysqldb or pip install mysqlclient
|
import MySQLdb # sudo port install py-mysql; sudo apt install python-mysqldb or pip install mysqlclient
|
||||||
import time
|
import time
|
||||||
|
|
@ -43,6 +44,7 @@ def is_machine_responding(machineName):
|
||||||
|
|
||||||
|
|
||||||
SqlQuery = str
|
SqlQuery = str
|
||||||
|
Table = List[Tuple]
|
||||||
|
|
||||||
|
|
||||||
class SqlTableField():
|
class SqlTableField():
|
||||||
|
|
@ -68,6 +70,10 @@ class SqlTableField():
|
||||||
self.is_autoinc_index = is_autoinc_index
|
self.is_autoinc_index = is_autoinc_index
|
||||||
|
|
||||||
|
|
||||||
|
ColumnId = str # eg 'matrix_size'
|
||||||
|
TableId = str # eg 'benchmark_results'
|
||||||
|
|
||||||
|
|
||||||
class ISqlDatabaseBackend(object):
|
class ISqlDatabaseBackend(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
@ -79,6 +85,19 @@ class ISqlDatabaseBackend(object):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
def query_select(self, columns: List[ColumnId], table: TableId, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
    """
    performs a select query on the sql database and returns the results in the form of a list of tuples (one tuple per row, tuple values are the column values, in the order given by `columns`)

    Backends are expected to override this method; this base implementation only raises.

    The clauses are sql fragments, not values: the backends in this module insert them verbatim into the statement, so they must never be built from untrusted text.

    :param List[ColumnId] columns: the columns to select, eg ['log.id', 'log.user_id']
    :param TableId table: the name of the table to query
    :param Optional[str] where_clause: the where clause for the query (without the WHERE keyword), eg "matrix_size > 100"
    :param Optional[str] join_clause: the join clause for the query (without the JOIN keyword), eg "disables ON log.id = disables.disable_request_id"
    :param bool distinct: whether to return distinct rows
    :return: the results of the query
    :raises NotImplementedError: always, as this is the interface declaration
    """
    raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_last_insert_id(self) -> int:
|
def get_last_insert_id(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
@ -142,10 +161,20 @@ class RemoteMysqlDb(ISqlDatabaseBackend):
|
||||||
"""
|
"""
|
||||||
:param str sql_query: the sql query to perform
|
:param str sql_query: the sql query to perform
|
||||||
"""
|
"""
|
||||||
self._conn.query(sql_query)
|
cursor = self._conn.cursor()
|
||||||
rows = self._conn.store_result()
|
cursor.execute(sql_query)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
logging.debug("RemoteMysqlDb.query:: results of query using cursor '%s': %s", sql_query, rows)
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
|
def query_select(self, columns: List[str], table: str, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
    """
    performs a select query on the remote mysql database and returns the matching rows

    The statement is assembled as SELECT [DISTINCT] <columns> FROM <table> [JOIN <join_clause>] [WHERE <where_clause>].
    The clauses are inserted verbatim, so they must not be built from untrusted text.

    :param List[str] columns: the columns to select
    :param str table: the name of the table to query
    :param Optional[str] where_clause: the where clause for the query (without the WHERE keyword)
    :param Optional[str] join_clause: the join clause for the query (without the JOIN keyword)
    :param bool distinct: whether to return distinct rows
    :return: the results of the query, as a list of tuples (one tuple per row)
    """
    sql_query = f"SELECT {('DISTINCT ' if distinct else '')}{', '.join(columns)} FROM {table}"
    if join_clause:
        sql_query += f" JOIN {join_clause}"
    if where_clause:
        sql_query += f" WHERE {where_clause}"
    # query() hands back whatever the cursor fetched, which is a tuple of rows with the default MySQLdb cursor;
    # convert it so that the declared return type (a list of tuples) actually holds
    return list(self.query(sql_query))
|
||||||
|
|
||||||
def get_last_insert_id(self) -> int:
|
def get_last_insert_id(self) -> int:
|
||||||
"""
|
"""
|
||||||
:return: the id of the last inserted row
|
:return: the id of the last inserted row
|
||||||
|
|
@ -167,6 +196,26 @@ class RemoteMysqlDb(ISqlDatabaseBackend):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
def json_to_table(json_str: str) -> Table:
    """
    converts the json text produced by a JSON_ARRAYAGG(JSON_OBJECT(...)) query into a table

    The expected text is a json array with one object per row, eg '[{"queue_machine": "a.q@host1"}, {"queue_machine": "b.q@host2"}]'.
    Each object becomes one tuple, holding the object's values in the order they appear in the text.

    :param str json_str: the json text to convert, or the literal 'NULL' (which the sql aggregate yields when there is no row to collect)
    :return: the rows, as a list of tuples (an empty list for 'NULL')
    :raises json.JSONDecodeError: if json_str is not valid json (the offending text is logged first)
    :raises TypeError: if the json text is not an array of objects holding only scalar values
    """
    if json_str.strip() == 'NULL':
        return []
    try:
        json_data = json.loads(json_str)
    except json.JSONDecodeError:
        logging.error("json_to_table:: invalid json string: '%s'", json_str)
        raise
    # explicit exceptions rather than asserts: asserts vanish when python runs with -O
    if not isinstance(json_data, list):
        raise TypeError(f'Expected a list of rows in the json string but got {type(json_data)}')
    table = []
    for row in json_data:
        if not isinstance(row, dict):
            raise TypeError(f'Expected a json object for each row in the json string but got {type(row)}')
        for column_value in row.values():
            # None is what a json null decodes to, ie a sql NULL column value: it is a legitimate cell
            if column_value is not None and not isinstance(column_value, (str, int, float)):
                raise TypeError(f'Expected a string, int, float or null value for each column value in the json string but got {type(column_value)}')
        table.append(tuple(row.values()))
    return table
|
||||||
|
|
||||||
|
|
||||||
class SshAccessedMysqlDb(ISqlDatabaseBackend):
|
class SshAccessedMysqlDb(ISqlDatabaseBackend):
|
||||||
|
|
||||||
"""a mysql database server accessed using ssh instead of a remote mysql client
|
"""a mysql database server accessed using ssh instead of a remote mysql client
|
||||||
|
|
@ -200,14 +249,58 @@ class SshAccessedMysqlDb(ISqlDatabaseBackend):
|
||||||
if completed_process.returncode != 0:
|
if completed_process.returncode != 0:
|
||||||
logging.error(completed_process.stderr.decode(encoding='utf-8'))
|
logging.error(completed_process.stderr.decode(encoding='utf-8'))
|
||||||
assert False
|
assert False
|
||||||
rows = completed_process.stdout.decode('utf-8').split('\n')
|
stdout = completed_process.stdout.decode('utf-8')
|
||||||
|
logging.debug("SshAccessedMysqlDb.query:: results of query '%s': %s", sql_query, stdout)
|
||||||
|
rows = stdout.split('\n') if stdout != '' else []
|
||||||
|
logging.debug("SshAccessedMysqlDb.query:: split results of query '%s' by new lines: %s", sql_query, rows)
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
|
def query_select(self, columns: List[str], table: str, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
    """
    performs a select query through the ssh-accessed mysql client and returns the matching rows

    As the results come back as text, the rows are requested as a single json document, which can be parsed reliably whatever the column values contain:

    MariaDB [quman]> SELECT JSON_ARRAYAGG(JSON_OBJECT('toto', timestamp, 'iddd', user_id)) FROM log;
    +-------------------------------------------------------------------------------------------------------+
    | JSON_ARRAYAGG(JSON_OBJECT('toto', timestamp, 'iddd', user_id))                                        |
    +-------------------------------------------------------------------------------------------------------+
    | [{"toto": "2026-05-21 15:11:58", "iddd": "graffy"},{"toto": "2026-05-22 15:55:39", "iddd": "graffy"}] |
    +-------------------------------------------------------------------------------------------------------+
    1 row in set (0,006 sec)

    The clauses are inserted verbatim in the statement, so they must not be built from untrusted text.

    :param List[str] columns: the columns to select
    :param str table: the name of the table to query
    :param Optional[str] where_clause: the where clause for the query (without the WHERE keyword)
    :param Optional[str] join_clause: the join clause for the query (without the JOIN keyword)
    :param bool distinct: whether to return distinct rows
    :return: the results of the query, as a list of tuples (one tuple per row)
    :raises RuntimeError: if the output of the query is too short to contain a data line
    """
    # each column is emitted as a json key (the column name) followed by its value
    json_fields = ', '.join(f"'{col}', {col}" for col in columns)
    sql_query = f"SELECT JSON_ARRAYAGG(JSON_OBJECT({json_fields})) FROM {table}"
    if join_clause:
        sql_query += f" JOIN {join_clause}"
    if where_clause:
        sql_query += f" WHERE {where_clause}"

    output_lines = self.query(sql_query)
    # the json document is read from the last but one line
    # NOTE(review): this presumes the client output ends with a newline (hence a trailing empty element) - confirm against the command used by query()
    if len(output_lines) < 2:
        raise RuntimeError(f"Unexpected output for query '{sql_query}' : {output_lines}")
    json_str = output_lines[-2]  # eg '[{"queue_machine": "gpuonly.q@alambix104.ipr.univ-rennes1.fr"}]'
    logging.debug("SshAccessedMysqlDb.query_select:: data line of query '%s': '%s'", sql_query, json_str)

    table_rows = json_to_table(json_str)
    if distinct:
        # a DISTINCT keyword in front of the aggregate would only apply to the single aggregated row, leaving the duplicates inside the json array;
        # removing the duplicates here (keeping the first occurrence of each row) works whatever the sql server flavour
        table_rows = list(dict.fromkeys(table_rows))
    return table_rows
|
||||||
|
|
||||||
def get_last_insert_id(self) -> int:
|
def get_last_insert_id(self) -> int:
|
||||||
"""
|
"""
|
||||||
:return: the id of the last inserted row
|
:return: the id of the last inserted row
|
||||||
"""
|
"""
|
||||||
return int(self.query("SELECT last_insert_id();")[0][0])
|
# MariaDB [quman]> SELECT JSON_OBJECT('toto', last_insert_id());
|
||||||
|
# +---------------------------------------+
|
||||||
|
# | JSON_OBJECT('toto', last_insert_id()) |
|
||||||
|
# +---------------------------------------+
|
||||||
|
# | {"toto": 0} |
|
||||||
|
# +---------------------------------------+
|
||||||
|
# 1 row in set (0,002 sec)
|
||||||
|
stdout = self.query("SELECT JSON_OBJECT('toto', last_insert_id());")
|
||||||
|
data_line = stdout[-2]
|
||||||
|
match = re.match(r'^\s*(?P<json_data>{\s*"toto"\s*:\s*(\d+)\s*})\s*$', data_line)
|
||||||
|
assert match
|
||||||
|
json_data = json.loads(match.group('json_data'))
|
||||||
|
last_insert_id = int(json_data['toto'])
|
||||||
|
assert last_insert_id > 0, f'Unexpected last insert id : {last_insert_id}'
|
||||||
|
return last_insert_id
|
||||||
|
|
||||||
def table_exists(self, table_name: str) -> bool:
|
def table_exists(self, table_name: str) -> bool:
|
||||||
rows = self.query(f"SHOW TABLES LIKE '{table_name}';")
|
rows = self.query(f"SHOW TABLES LIKE '{table_name}';")
|
||||||
|
|
@ -293,6 +386,14 @@ class SqliteDb(ISqlDatabaseBackend):
|
||||||
self._con.commit()
|
self._con.commit()
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
|
def query_select(self, columns: List[str], table: str, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
    """
    performs a select query on the sqlite database and returns what query() returns for it

    The statement is made of the following parts, separated by a space: SELECT [DISTINCT] <columns>, FROM <table>, then the optional JOIN and WHERE parts.
    The clauses are inserted verbatim, so they must not be built from untrusted text.

    :param List[str] columns: the columns to select
    :param str table: the name of the table to query
    :param Optional[str] where_clause: the where clause for the query (without the WHERE keyword)
    :param Optional[str] join_clause: the join clause for the query (without the JOIN keyword)
    :param bool distinct: whether to return distinct rows
    :return: the results of the query
    """
    select_keyword = "SELECT DISTINCT " if distinct else "SELECT "
    statement_parts = [select_keyword + ", ".join(columns), f"FROM {table}"]
    # optional parts: an empty or missing clause is simply left out
    if join_clause:
        statement_parts.append(f"JOIN {join_clause}")
    if where_clause:
        statement_parts.append(f"WHERE {where_clause}")
    return self.query(" ".join(statement_parts))
|
||||||
|
|
||||||
def get_last_insert_id(self) -> int:
|
def get_last_insert_id(self) -> int:
|
||||||
"""
|
"""
|
||||||
:return: the id of the last inserted row
|
:return: the id of the last inserted row
|
||||||
|
|
|
||||||
|
|
@ -263,8 +263,7 @@ class QueueManager:
|
||||||
return log_id
|
return log_id
|
||||||
|
|
||||||
def get_disable_requests(self, queue_machine: QueueMachineId) -> Dict[int, DisableRequest]:
|
def get_disable_requests(self, queue_machine: QueueMachineId) -> Dict[int, DisableRequest]:
|
||||||
sql_query = f"SELECT log.id, log.user_id, log.host_fqdn, log.queue_machines, log.reason, log.disable_id, log.timestamp FROM log JOIN disables ON log.id = disables.disable_request_id WHERE disables.queue_machine = '{queue_machine}' AND log.action = 'disable';"
|
results = self.db_backend.query_select(['log.id', 'log.user_id', 'log.host_fqdn', 'log.queue_machines', 'log.reason', 'log.disable_id', 'log.timestamp'], join_clause="disables ON log.id = disables.disable_request_id", table="log", where_clause=f"disables.queue_machine = '{queue_machine}' AND log.action = 'disable'")
|
||||||
results = self.db_backend.query(sql_query)
|
|
||||||
disable_requests = []
|
disable_requests = []
|
||||||
for row in results:
|
for row in results:
|
||||||
log_id = row[0]
|
log_id = row[0]
|
||||||
|
|
@ -324,10 +323,10 @@ class QueueManager:
|
||||||
qs = self.grid_engine.get_status()
|
qs = self.grid_engine.get_status()
|
||||||
|
|
||||||
db_queues = set()
|
db_queues = set()
|
||||||
sql_query = "SELECT queue_machine FROM queues;"
|
results = self.db_backend.query_select(columns=["queue_machine"], table="queues")
|
||||||
results = self.db_backend.query(sql_query)
|
logging.debug("synchronize_with_grid_engine: results of query': %s", results)
|
||||||
for row in results:
|
for row in results:
|
||||||
assert len(row) == 1, "Each row should have only one column (queue_machine)"
|
assert len(row) == 1, "Each row should have only one column (queue_machine) but got row='%s' (len=%d)" % (str(row), len(row))
|
||||||
db_queues.add(row[0])
|
db_queues.add(row[0])
|
||||||
|
|
||||||
for queue_machine, is_enabled in qs.is_enabled.items():
|
for queue_machine, is_enabled in qs.is_enabled.items():
|
||||||
|
|
@ -355,8 +354,7 @@ class QueueManager:
|
||||||
def get_state(self) -> QueueDisableState:
|
def get_state(self) -> QueueDisableState:
|
||||||
"""returns the state of the queues."""
|
"""returns the state of the queues."""
|
||||||
# get the list of queue names from the disables table in the database
|
# get the list of queue names from the disables table in the database
|
||||||
sql_query = "SELECT DISTINCT queue_machine FROM disables;"
|
results = self.db_backend.query_select(columns=["queue_machine"], table="disables", distinct=True)
|
||||||
results = self.db_backend.query(sql_query)
|
|
||||||
for row in results:
|
for row in results:
|
||||||
assert len(row) == 1, "Each row should have only one column (queue_machine)"
|
assert len(row) == 1, "Each row should have only one column (queue_machine)"
|
||||||
queue_machines = [row[0] for row in results]
|
queue_machines = [row[0] for row in results]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
__version__ = '1.0.26'
|
__version__ = '1.0.27'
|
||||||
|
|
||||||
|
|
||||||
class Version(object):
|
class Version(object):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue