cocluto v1.0.27 - fixes bug in returned table of SshAccessedMysqlDb.query('select...') and SshAccessedMysqlDb.get_last_insert_id()
- splitted ISqlDatabaseBackend.query() into ISqlDatabaseBackend.query() and the more specialized SqlDatabaseBackend.query_select ()
- this way the return type of query_select is explicit (a Table).
- it also allows robust sql query parsing (usin json).
- replaced calls to ISqlDatabaseBackend.query('select ...') with calls to ISqlDatabaseBackend.query_select()
work related to [https://bugzilla.ipr.univ-rennes.fr/show_bug.cgi?id=3093]
This commit is contained in:
parent
5856ac0951
commit
d2c973c7ed
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Union, List
|
||||
from typing import Union, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
import json
|
||||
import logging
|
||||
import MySQLdb # sudo port install py-mysql; sudo apt install python-mysqldb or pip install mysqlclient
|
||||
import time
|
||||
|
|
@ -43,6 +44,7 @@ def is_machine_responding(machineName):
|
|||
|
||||
|
||||
SqlQuery = str
|
||||
Table = List[Tuple]
|
||||
|
||||
|
||||
class SqlTableField():
|
||||
|
|
@ -68,6 +70,10 @@ class SqlTableField():
|
|||
self.is_autoinc_index = is_autoinc_index
|
||||
|
||||
|
||||
ColumnId = str # eg 'matrix_size'
|
||||
TableId = str # eg 'benchmark_results'
|
||||
|
||||
|
||||
class ISqlDatabaseBackend(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
|
@ -79,6 +85,19 @@ class ISqlDatabaseBackend(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def query_select(self, columns: List[ColumnId], table: TableId, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
|
||||
"""
|
||||
performs a select query on the sql database and returns the results in the form of a list of tuples (one tuple per row, tuple values are the column values)
|
||||
|
||||
:param List[ColumnId] columns: the columns to select
|
||||
:param TableId table: the name of the table to query
|
||||
:param Optional[str] where_clause: the where clause for the query, eg "matrix_size > 100"
|
||||
:param Optional[str] join_clause: the join clause for the query, eg "disables ON log.id = disables.disable_request_id"
|
||||
:param bool distinct: whether to return distinct rows
|
||||
:return: the results of the query
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_last_insert_id(self) -> int:
|
||||
"""
|
||||
|
|
@ -142,10 +161,20 @@ class RemoteMysqlDb(ISqlDatabaseBackend):
|
|||
"""
|
||||
:param str sql_query: the sql query to perform
|
||||
"""
|
||||
self._conn.query(sql_query)
|
||||
rows = self._conn.store_result()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(sql_query)
|
||||
rows = cursor.fetchall()
|
||||
logging.debug("RemoteMysqlDb.query:: results of query using cursor '%s': %s", sql_query, rows)
|
||||
return rows
|
||||
|
||||
def query_select(self, columns: List[str], table: str, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
|
||||
sql_query = f"SELECT {('DISTINCT ' if distinct else '')}{', '.join(columns)} FROM {table}"
|
||||
if join_clause:
|
||||
sql_query += f" JOIN {join_clause}"
|
||||
if where_clause:
|
||||
sql_query += f" WHERE {where_clause}"
|
||||
return self.query(sql_query)
|
||||
|
||||
def get_last_insert_id(self) -> int:
|
||||
"""
|
||||
:return: the id of the last inserted row
|
||||
|
|
@ -167,6 +196,26 @@ class RemoteMysqlDb(ISqlDatabaseBackend):
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
def json_to_table(json_str: str) -> Table:
|
||||
# logging.debug("json_to_table:: json_str = '%s'", json_str)
|
||||
if json_str == 'NULL':
|
||||
return []
|
||||
try:
|
||||
json_data = json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
logging.error("json_to_table:: invalid json string: '%s'", json_str)
|
||||
raise
|
||||
assert isinstance(json_data, list), f'Expected a list of rows in the json string but got {type(json_data)}'
|
||||
table = []
|
||||
for row in json_data:
|
||||
values = []
|
||||
for column_value in row.values():
|
||||
assert isinstance(column_value, (str, int, float)), f'Expected a string, int or float value for each column value in the json string but got {type(column_value)}'
|
||||
values.append(column_value)
|
||||
table.append(tuple(values))
|
||||
return table
|
||||
|
||||
|
||||
class SshAccessedMysqlDb(ISqlDatabaseBackend):
|
||||
|
||||
"""a mysql database server accessed using ssh instead of a remote mysql client
|
||||
|
|
@ -200,14 +249,58 @@ class SshAccessedMysqlDb(ISqlDatabaseBackend):
|
|||
if completed_process.returncode != 0:
|
||||
logging.error(completed_process.stderr.decode(encoding='utf-8'))
|
||||
assert False
|
||||
rows = completed_process.stdout.decode('utf-8').split('\n')
|
||||
stdout = completed_process.stdout.decode('utf-8')
|
||||
logging.debug("SshAccessedMysqlDb.query:: results of query '%s': %s", sql_query, stdout)
|
||||
rows = stdout.split('\n') if stdout != '' else []
|
||||
logging.debug("SshAccessedMysqlDb.query:: split results of query '%s' by new lines: %s", sql_query, rows)
|
||||
return rows
|
||||
|
||||
def query_select(self, columns: List[str], table: str, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
|
||||
# MariaDB [quman]> SELECT JSON_ARRAYAGG(JSON_OBJECT('toto', timestamp, 'iddd', user_id)) FROM log;
|
||||
# +-------------------------------------------------------------------------------------------------------+
|
||||
# | JSON_ARRAYAGG(JSON_OBJECT('toto', timestamp, 'iddd', user_id)) |
|
||||
# +-------------------------------------------------------------------------------------------------------+
|
||||
# | [{"toto": "2026-05-21 15:11:58", "iddd": "graffy"},{"toto": "2026-05-22 15:55:39", "iddd": "graffy"}] |
|
||||
# +-------------------------------------------------------------------------------------------------------+
|
||||
# 1 row in set (0,006 sec)
|
||||
# columns_list_str = ', '.join([f'"{col}"' for col in columns])
|
||||
columns_list_str = ', '.join([f"'{col}', {col}" for col in columns])
|
||||
columns_statement = f'JSON_ARRAYAGG(JSON_OBJECT({columns_list_str}))'
|
||||
sql_query = f"SELECT {('DISTINCT ' if distinct else '')}{columns_statement} FROM {table}"
|
||||
if join_clause:
|
||||
sql_query += f" JOIN {join_clause}"
|
||||
if where_clause:
|
||||
sql_query += f" WHERE {where_clause}"
|
||||
stdout = self.query(sql_query)
|
||||
json_str = stdout[-2] # eg '[{"queue_machine": "gpuonly.q@alambix104.ipr.univ-rennes1.fr"}]'
|
||||
assert isinstance(json_str, str), f'Expected a string as data line in the query output but got {type(json_str)}'
|
||||
logging.debug("SshAccessedMysqlDb.query_select:: data line of query '%s': '%s'", sql_query, json_str)
|
||||
# match = re.match(r'^\s*(?P<json_data>{[^}]*})\s*$', data_line)
|
||||
# assert match, 'Unexpected output format for query "%s" : %s' % (sql_query, stdout)
|
||||
# json_data = json.loads(match.group('json_data'))
|
||||
|
||||
table = json_to_table(json_str)
|
||||
return table
|
||||
|
||||
def get_last_insert_id(self) -> int:
|
||||
"""
|
||||
:return: the id of the last inserted row
|
||||
"""
|
||||
return int(self.query("SELECT last_insert_id();")[0][0])
|
||||
# MariaDB [quman]> SELECT JSON_OBJECT('toto', last_insert_id());
|
||||
# +---------------------------------------+
|
||||
# | JSON_OBJECT('toto', last_insert_id()) |
|
||||
# +---------------------------------------+
|
||||
# | {"toto": 0} |
|
||||
# +---------------------------------------+
|
||||
# 1 row in set (0,002 sec)
|
||||
stdout = self.query("SELECT JSON_OBJECT('toto', last_insert_id());")
|
||||
data_line = stdout[-2]
|
||||
match = re.match(r'^\s*(?P<json_data>{\s*"toto"\s*:\s*(\d+)\s*})\s*$', data_line)
|
||||
assert match
|
||||
json_data = json.loads(match.group('json_data'))
|
||||
last_insert_id = int(json_data['toto'])
|
||||
assert last_insert_id > 0, f'Unexpected last insert id : {last_insert_id}'
|
||||
return last_insert_id
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
|
||||
rows = self.query(f"SHOW TABLES LIKE '{table_name}';")
|
||||
|
|
@ -293,6 +386,14 @@ class SqliteDb(ISqlDatabaseBackend):
|
|||
self._con.commit()
|
||||
return rows
|
||||
|
||||
def query_select(self, columns: List[str], table: str, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table:
|
||||
sql_query = f"SELECT {('DISTINCT ' if distinct else '')}{', '.join(columns)} FROM {table}"
|
||||
if join_clause:
|
||||
sql_query += f" JOIN {join_clause}"
|
||||
if where_clause:
|
||||
sql_query += f" WHERE {where_clause}"
|
||||
return self.query(sql_query)
|
||||
|
||||
def get_last_insert_id(self) -> int:
|
||||
"""
|
||||
:return: the id of the last inserted row
|
||||
|
|
|
|||
|
|
@ -263,8 +263,7 @@ class QueueManager:
|
|||
return log_id
|
||||
|
||||
def get_disable_requests(self, queue_machine: QueueMachineId) -> Dict[int, DisableRequest]:
|
||||
sql_query = f"SELECT log.id, log.user_id, log.host_fqdn, log.queue_machines, log.reason, log.disable_id, log.timestamp FROM log JOIN disables ON log.id = disables.disable_request_id WHERE disables.queue_machine = '{queue_machine}' AND log.action = 'disable';"
|
||||
results = self.db_backend.query(sql_query)
|
||||
results = self.db_backend.query_select(['log.id', 'log.user_id', 'log.host_fqdn', 'log.queue_machines', 'log.reason', 'log.disable_id', 'log.timestamp'], join_clause="disables ON log.id = disables.disable_request_id", table="log", where_clause=f"disables.queue_machine = '{queue_machine}' AND log.action = 'disable'")
|
||||
disable_requests = []
|
||||
for row in results:
|
||||
log_id = row[0]
|
||||
|
|
@ -324,10 +323,10 @@ class QueueManager:
|
|||
qs = self.grid_engine.get_status()
|
||||
|
||||
db_queues = set()
|
||||
sql_query = "SELECT queue_machine FROM queues;"
|
||||
results = self.db_backend.query(sql_query)
|
||||
results = self.db_backend.query_select(columns=["queue_machine"], table="queues")
|
||||
logging.debug("synchronize_with_grid_engine: results of query': %s", results)
|
||||
for row in results:
|
||||
assert len(row) == 1, "Each row should have only one column (queue_machine)"
|
||||
assert len(row) == 1, "Each row should have only one column (queue_machine) but got row='%s' (len=%d)" % (str(row), len(row))
|
||||
db_queues.add(row[0])
|
||||
|
||||
for queue_machine, is_enabled in qs.is_enabled.items():
|
||||
|
|
@ -355,8 +354,7 @@ class QueueManager:
|
|||
def get_state(self) -> QueueDisableState:
|
||||
"""returns the state of the queues."""
|
||||
# get the list of queue names from the disables table in the database
|
||||
sql_query = "SELECT DISTINCT queue_machine FROM disables;"
|
||||
results = self.db_backend.query(sql_query)
|
||||
results = self.db_backend.query_select(columns=["queue_machine"], table="disables", distinct=True)
|
||||
for row in results:
|
||||
assert len(row) == 1, "Each row should have only one column (queue_machine)"
|
||||
queue_machines = [row[0] for row in results]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
__version__ = '1.0.26'
|
||||
__version__ = '1.0.27'
|
||||
|
||||
|
||||
class Version(object):
|
||||
|
|
|
|||
Loading…
Reference in New Issue