diff --git a/cocluto/SimpaDbUtil.py b/cocluto/SimpaDbUtil.py index 079de5d..3649cbb 100644 --- a/cocluto/SimpaDbUtil.py +++ b/cocluto/SimpaDbUtil.py @@ -98,10 +98,14 @@ class ISqlDatabaseBackend(object): """ raise NotImplementedError() - @abc.abstractmethod - def get_last_insert_id(self) -> int: + def query_insert(self, table_id: str, fields: List[str], values: List[tuple]) -> int: """ - :return: the id of the last inserted row + performs an insert query on the sql database and returns the id of the inserted row + + :param str table_id: the name of the table to insert into + :param List[str] fields: the list of fields to insert values into + :param List[tuple] values: the list of values to insert (one tuple per row, tuple values are the column values) + :return: the id of the inserted row """ raise NotImplementedError() @@ -141,6 +145,21 @@ class ISqlDatabaseBackend(object): raise NotImplementedError() +def values_to_sql_string(values: List[tuple]) -> str: + '''converts a list of tuples of values into a string that can be used in a sql query, with proper escaping of string values + + eg the list of tuples + [('alambix42', 'disabled to move the alambix42 to another rack'), ('alambix42', 'because I want to test quman')] + will be converted into the following string that can be used in a sql query : + "('alambix42', 'disabled to move the alambix42 to another rack'), ('alambix42', 'because I want to test quman')" + ''' + sql_values = [] + for value_tuple in values: + escaped_value_tuple = tuple(["'%s'" % str(v).replace("'", r"\'") for v in value_tuple]) + sql_values.append(f'({", ".join(escaped_value_tuple)})') + return ', '.join(sql_values) + + class RemoteMysqlDb(ISqlDatabaseBackend): def __init__(self, db_server_fqdn, db_user, db_name): """ @@ -175,11 +194,14 @@ class RemoteMysqlDb(ISqlDatabaseBackend): sql_query += f" WHERE {where_clause}" return self.query(sql_query) - def get_last_insert_id(self) -> int: - """ - :return: the id of the last inserted row - """ - return int(self.query("SELECT last_insert_id();")[0][0]) + def query_insert(self, table_id: str, fields: List[str], values: List[tuple]) -> int: + sql_query = f"INSERT INTO {table_id} ({', '.join(fields)}) VALUES {values_to_sql_string(values)}" + sql_query = sql_query + ";SELECT last_insert_id();" + logging.debug("RemoteMysqlDb.query_insert:: sql_query = '%s'", sql_query) + rows = self.query(sql_query) + last_insert_id = int(rows[0][0]) + assert last_insert_id > 0, f'Unexpected last insert id : {last_insert_id}' + return last_insert_id def table_exists(self, table_name: str) -> bool: rows = self.query(f"SHOW TABLES LIKE '{table_name}';") @@ -252,7 +274,6 @@ class SshAccessedMysqlDb(ISqlDatabaseBackend): stdout = completed_process.stdout.decode('utf-8') logging.debug("SshAccessedMysqlDb.query:: results of query '%s': %s", sql_query, stdout) rows = stdout.split('\n') if stdout != '' else [] - logging.debug("SshAccessedMysqlDb.query:: split results of query '%s' by new lines: %s", sql_query, rows) return rows def query_select(self, columns: List[str], table: str, where_clause: Optional[str] = None, join_clause: Optional[str] = None, distinct: bool = False) -> Table: @@ -282,18 +303,21 @@ class SshAccessedMysqlDb(ISqlDatabaseBackend): table = json_to_table(json_str) return table - def get_last_insert_id(self) -> int: - """ - :return: the id of the last inserted row - """ - # MariaDB [quman]> SELECT JSON_OBJECT('toto', last_insert_id()); + def query_insert(self, table_id: str, fields: List[str], values: List[tuple]) -> int: + logging.debug("SshAccessedMysqlDb.query_insert:: values = %s", values) + # INSERT INTO log (timestamp, user_id, host_fqdn, queue_machines, action, disable_id, reason) VALUES ('2026-05-26T15:31:53.564287', 'graffy', 'alambix50.ipr.univ-rennes.fr', 'gpuonly.q@alambix104.ipr.univ-rennes1.fr', 'disable', 'quman-sync', 'synchronized with grid engine');SELECT JSON_OBJECT('toto', last_insert_id()); + # Query OK, 1 row affected (0,046 sec) + # +---------------------------------------+ # | JSON_OBJECT('toto', last_insert_id()) | # +---------------------------------------+ - # | {"toto": 0} | + # | {"toto": 9} | # +---------------------------------------+ # 1 row in set (0,002 sec) - stdout = self.query("SELECT JSON_OBJECT('toto', last_insert_id());") + sql_query = f"INSERT INTO {table_id} ({', '.join(fields)}) VALUES {values_to_sql_string(values)}" + sql_query = sql_query + ";SELECT JSON_OBJECT('toto', last_insert_id());" + logging.debug("SshAccessedMysqlDb.query_insert:: sql_query = '%s'", sql_query) + stdout = self.query(sql_query) data_line = stdout[-2] match = re.match(r'^\s*(?P{\s*"toto"\s*:\s*(\d+)\s*})\s*$', data_line) assert match @@ -394,11 +418,26 @@ class SqliteDb(ISqlDatabaseBackend): sql_query += f" WHERE {where_clause}" return self.query(sql_query) - def get_last_insert_id(self) -> int: + def query_insert(self, table_id: str, fields: List[str], values: List[tuple]) -> int: """ - :return: the id of the last inserted row + performs an insert query on the sql database and returns the id of the inserted row + + :param str table_id: the name of the table to insert into + :param List[str] fields: the list of fields to insert values into + :param List[tuple] values: the list of values to insert (one tuple per row, tuple values are the column values) + :return: the id of the inserted row """ - return int(self.query("SELECT last_insert_rowid();")[0][0]) + sql_query = f"INSERT INTO {table_id} ({', '.join(fields)}) VALUES {values_to_sql_string(values)}" + logging.debug("SqliteDb.query_insert:: sql_query = '%s'", sql_query) + logging.debug("SqliteDb.query_insert:: values = %s", values) + self._cur.execute(sql_query) + self._cur.execute('SELECT last_insert_rowid();') + rows = self._cur.fetchall() + self._con.commit() + assert len(rows) == 1, f'Unexpected number of rows ({len(rows)}).' + last_insert_id = int(rows[0][0]) + assert last_insert_id > 0, f'Unexpected last insert id : {last_insert_id}' + return last_insert_id def table_exists(self, table_name: str) -> bool: rows = self.query(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';") diff --git a/cocluto/quman.py b/cocluto/quman.py index e0701f2..d15890f 100755 --- a/cocluto/quman.py +++ b/cocluto/quman.py @@ -256,10 +256,7 @@ class QueueManager: userid = subprocess.check_output(['whoami']).decode().strip() host_fqdn = subprocess.check_output(['hostname', '-f']).decode().strip() timestamp = datetime.now().isoformat() - sql_query = f"INSERT INTO log (timestamp, user_id, host_fqdn, queue_machines, action, disable_id, reason) VALUES ('{timestamp}', '{userid}', '{host_fqdn}', '{','.join(queue_machines)}', '{action}', '{disable_id}', '{reason}');" - self.db_backend.query(sql_query) - # get the log id of the disable action that was just inserted - log_id = self.db_backend.get_last_insert_id() + log_id = self.db_backend.query_insert(table_id="log", fields=['timestamp', 'user_id', 'host_fqdn', 'queue_machines', 'action', 'disable_id', 'reason'], values=[(timestamp, userid, host_fqdn, ','.join(queue_machines), action, disable_id, reason)]) return log_id def get_disable_requests(self, queue_machine: QueueMachineId) -> Dict[int, DisableRequest]: diff --git a/cocluto/version.py b/cocluto/version.py index 3706116..607cc93 100644 --- a/cocluto/version.py +++ b/cocluto/version.py @@ -1,4 +1,4 @@ -__version__ = '1.0.27' +__version__ = '1.0.28' class Version(object):