Cat Feeder  1.0.0
The Cat feeder project
Loading...
Searching...
No Matches
sql_sanitisation_functions.py
Go to the documentation of this file.
1r"""
2# +==== BEGIN CatFeeder =================+
3# LOGO:
4# ..............(..../\
5# ...............)..(.')
6# ..............(../..)
7# ...............\‍(__)|
8# Inspired by Joan Stark
9# source https://www.asciiart.eu/
10# animals/cats
11# /STOP
12# PROJECT: CatFeeder
13# FILE: sql_sanitisation_functions.py
14# CREATION DATE: 11-10-2025
15# LAST Modified: 1:50:47 06-02-2026
16# DESCRIPTION:
17# This is the backend server in charge of making the actual website work.
18# /STOP
19# COPYRIGHT: (c) Cat Feeder
20# PURPOSE: File in charge of cleaning and sanitising sql queries before they are submitted to the database.
21# // AR
22# +==== END CatFeeder =================+
23"""
24
25import re
26from typing import List, Dict, Any, Union, Optional, Tuple
27
28from display_tty import Disp, initialise_logger
29
30from . import sql_constants as SCONST
31from .sql_time_manipulation import SQLTimeManipulation
32
33
35 """Provide functions to sanitize SQL queries before execution.
36
37 This class contains methods to clean and escape SQL queries, ensuring
38 they are safe to execute and free from injection vulnerabilities.
39
40 Attributes:
41 disp (Disp): Logger instance for debugging and error reporting.
42 risky_keywords (List[str]): List of risky SQL keywords to sanitize.
43 keyword_logic_gates (List[str]): List of logical operators to handle.
44 none_value (str): Default value for NULL representation.
45 sql_time_manipulation (SQLTimeManipulation): Handles time-related SQL operations.
46 """
47
48 disp: Disp = initialise_logger(__qualname__, False)
49
def __init__(self, success: int = 0, error: int = 84, debug: bool = False) -> None:
    """Store the status codes, keyword tables and the time helper.

    Args:
        success (int, optional): Value stored as ``self.success``. Defaults to 0.
        error (int, optional): Value stored as ``self.error``. Defaults to 84.
        debug (bool, optional): Passed to the logger's debug switch and to
            the SQLTimeManipulation helper. Defaults to False.
    """
    # ------------------------ The logging function ------------------------
    self.disp.update_disp_debug(debug)
    self.disp.log_debug("Initialising...")
    # -------------------------- Inherited values --------------------------
    self.success: int = success
    self.error: int = error
    self.debugdebug: bool = debug
    # ----------------- Database risky keyword sanitising -----------------
    # Both lists come straight from the sql_constants module.
    self.keyword_logic_gates: List[str] = SCONST.KEYWORD_LOGIC_GATES
    self.risky_keywordsrisky_keywords: List[str] = SCONST.RISKY_KEYWORDS
    # ---------------------- Time manipulation class ----------------------
    time_helper: SQLTimeManipulation = SQLTimeManipulation(self.debugdebug)
    self.sql_time_manipulation: SQLTimeManipulation = time_helper
    # Text returned in place of a Python None by the escaping helpers.
    self.none_value: str = "NULL"
    self.disp.log_debug("Initialised")
74
def protect_sql_cell(self, cell: Optional[str]) -> str:
    """Backslash-escape characters that could break an SQL string literal.

    The characters ``'``, ``"``, ``\\``, NUL and carriage return are each
    prefixed with a backslash; every other character is copied unchanged.
    Each escaped character is reported through ``self.disp.log_info``.

    Args:
        cell (Optional[str]): The cell to sanitize.

    Returns:
        str: The escaped text, or ``self.none_value`` when ``cell`` is None.
    """
    if cell is None:
        return self.none_value
    # Hoisted so the tuple is not rebuilt for every character.
    escaped_chars = ("'", '"', "\\", "\0", "\r")
    # Collect the pieces and join once instead of growing a string with +=.
    pieces: List[str] = []
    for char in cell:
        if char in escaped_chars:
            self.disp.log_info(
                f"Escaped character '{char}' in '{cell}'.",
                "protect_sql_cell"
            )
            pieces.append("\\" + char)
        else:
            pieces.append(char)
    return "".join(pieces)
97
def escape_risky_column_names(self, columns: Union[List[str], str]) -> Union[List[str], str]:
    """Wrap column names found in the risky keyword list in backticks.

    Entries of the form ``name=value`` are split on the first ``=`` and only
    the name part is checked. When a list is given it is updated in place
    and the same list object is returned.

    Args:
        columns (Union[List[str], str]): Column names to sanitize.

    Returns:
        Union[List[str], str]: Sanitized column names.
    """
    title = "_escape_risky_column_names"
    self.disp.log_debug("Escaping risky column names.", title)
    single_name = isinstance(columns, str)
    entries = [columns] if single_name else columns
    for position, entry in enumerate(entries):
        if "=" in entry:
            name, assigned = entry.split("=", maxsplit=1)
            self.disp.log_debug(f"key = {name}, value = {assigned}", title)
            if name.lower() in self.risky_keywordsrisky_keywords:
                self.disp.log_warning(
                    f"Escaping risky column name '{name}'.",
                    "_escape_risky_column_names"
                )
                entries[position] = f"`{name}`={assigned}"
            continue
        if entry.lower() in self.risky_keywordsrisky_keywords:
            self.disp.log_warning(
                f"Escaping risky column name '{entry}'.",
                "_escape_risky_column_names"
            )
            entries[position] = f"`{entry}`"
    self.disp.log_debug("Escaped risky column names.", title)
    return entries[0] if single_name else columns
135
def _protect_value(self, value: Optional[str]) -> str:
    """Wrap a value in single quotes so it can be embedded in an SQL query.

    ``None`` becomes ``self.none_value``, non-string values are converted
    with ``str()`` and an empty string becomes ``''``. A value enclosed in
    backticks is returned unchanged. Otherwise one leading and one trailing
    single quote are removed (when present), every remaining single quote is
    doubled, and the result is wrapped in single quotes.

    Args:
        value (Optional[str]): The value to protect.

    Returns:
        str: Protected value safe for SQL queries.
    """
    title = "_protect_value"
    self.disp.log_debug(f"protecting value: {value}", title)
    if value is None:
        self.disp.log_debug("Value is none, thus returning NULL", title)
        return self.none_value

    if isinstance(value, str) is False:
        self.disp.log_debug("Value is not a string, converting", title)
        value = str(value)

    if len(value) == 0:
        self.disp.log_debug("Value is empty, returning ''", title)
        return "''"

    # Require two characters so that a lone "`" is not mistaken for a
    # backtick-enclosed identifier and returned unquoted.
    # NOTE(review): backtick-enclosed values are passed through without any
    # escaping of their content - confirm callers never feed user input here.
    if len(value) >= 2 and value[0] == '`' and value[-1] == '`':
        self.disp.log_debug(
            "string has special backtics, skipping.", title
        )
        return value

    if value.startswith("'"):
        self.disp.log_debug(
            "Value already has a single quote at the start, removing", title
        )
        value = value[1:]
    # endswith() is safe on the empty string left behind when the input was
    # a single quote; indexing value[-1] raised IndexError in that case.
    if value.endswith("'"):
        self.disp.log_debug(
            "Value already has a single quote at the end, removing", title
        )
        value = value[:-1]

    self.disp.log_debug(
        f"Value before quote escaping: {value}", title
    )
    protected_value = value.replace("'", "''")
    self.disp.log_debug(
        f"Value after quote escaping: {protected_value}", title
    )

    protected_value = f"'{protected_value}'"
    self.disp.log_debug(
        f"Value after being converted to a string: {protected_value}.",
        title
    )
    return protected_value
190
def escape_risky_column_names_where_mode(self, columns: Union[List[str], str]) -> Union[List[str], str]:
    """Sanitize WHERE-clause fragments.

    Entries of the form ``name=value`` always have their value passed
    through ``self._protect_value``; the name is additionally wrapped in
    backticks when it is a risky keyword that is not a logic gate. A bare
    entry that is a risky keyword (and not a logic gate) is replaced by the
    result of ``self._protect_value``. A list argument is updated in place.

    Args:
        columns (Union[List[str], str]): Column names to sanitize.

    Returns:
        Union[List[str], str]: Sanitized column names.
    """
    title = "_escape_risky_column_names_where_mode"
    self.disp.log_debug(
        "Escaping risky column names in where mode.", title
    )

    single_name = isinstance(columns, str)
    entries = [columns] if single_name else columns

    for position, entry in enumerate(entries):
        if "=" in entry:
            name, assigned = entry.split("=", maxsplit=1)
            self.disp.log_debug(f"key = {name}, value = {assigned}", title)

            quoted = self._protect_value(assigned)
            lowered = name.lower()
            needs_backticks = (
                lowered not in self.keyword_logic_gates
                and lowered in self.risky_keywordsrisky_keywords
            )
            if needs_backticks:
                self.disp.log_warning(
                    f"Escaping risky column name '{name}'.", title
                )
                entries[position] = f"`{name}`={quoted}"
            else:
                entries[position] = f"{name}={quoted}"
            continue

        lowered = entry.lower()
        if lowered not in self.keyword_logic_gates and lowered in self.risky_keywordsrisky_keywords:
            self.disp.log_warning(
                f"Escaping risky column name '{entry}'.",
                title
            )
            entries[position] = self._protect_value(entry)

    self.disp.log_debug("Escaped risky column names in where mode.", title)

    return entries[0] if single_name else entries
237
def check_sql_cell(self, cell: Union[str, int, float, None], raw: bool = True) -> Union[str, Union[str, int, float, None]]:
    """Check and sanitize a SQL cell value.

    In raw mode numbers and ``None`` are returned untouched. Any other
    non-string input is logged as an error and returned as ``str(cell)``.
    Strings are escaped with ``self.protect_sql_cell``; the literals
    ``now``/``now()`` and ``current_date``/``current_date()`` (matched
    case-insensitively) are replaced by values from
    ``self.sql_time_manipulation``.

    Args:
        cell (Union[str, int, float, None]): The cell value to check.
        raw (bool, optional): Whether to process raw values. Defaults to True.

    Returns:
        Union[str, Union[str, int, float, None]]: Sanitized cell value.
    """
    title: str = "check_sql_cell"
    if raw and (cell is None or isinstance(cell, (float, int))):
        return cell
    if not isinstance(cell, str):
        msg = "The expected type of the input is a string," + f"but got {type(cell)}"
        self.disp.log_error(msg, title)
        return str(cell)

    escaped = self.protect_sql_cell(str(cell))
    lowered = escaped.lower()
    if lowered in ("now", "now()"):
        result = self.sql_time_manipulation.get_correct_now_value()
    elif lowered in ("current_date", "current_date()"):
        result = self.sql_time_manipulation.get_correct_current_date_value()
    else:
        result = str(escaped)

    # Values containing ";base" are not written to the debug log.
    if ";base" not in result:
        self.disp.log_debug(f"result = {result}", title)
    # Return raw value for parameterized queries (%s placeholders)
    # The MySQL driver handles escaping automatically
    # NOTE(review): string cells were already backslash-escaped above, so
    # the value handed to the driver is not strictly raw - confirm intended.
    return result
274
def beautify_table(self, column_names: List[str], table_content: List[List[Any]]) -> Union[List[Dict[str, Any]], int]:
    """Turn raw rows into dictionaries keyed by column name.

    When the first column descriptor is a tuple, the first item of every
    descriptor is used as the key; otherwise the descriptors themselves are
    the keys. Rows whose length differs from the column count trigger a
    warning and are paired up to the shorter of the two lengths.

    Args:
        column_names (List[str]): Column descriptors (name as first item).
        table_content (List[List[Any]]): Raw rows as sequences.

    Returns:
        Union[List[Dict[str, Any]], int]: Beautified table or error code.
    """
    self.disp.log_debug("Beautifying table.")
    if len(column_names) == 0:
        self.disp.log_error("There are no provided table column names.")
        return self.error
    if len(table_content) == 0:
        self.disp.log_warning("There is no table content.")
        return []

    expected_width = len(column_names)

    # Decide once how keys are obtained, based on the first descriptor.
    if isinstance(column_names[0], tuple):
        keys = [descriptor[0] for descriptor in column_names]
    else:
        keys = column_names

    beautified: List[Dict[str, Any]] = []
    for row in table_content:
        if len(row) != expected_width:
            self.disp.log_warning(
                "Table content and column lengths do not correspond."
            )
        # zip stops at the shorter sequence, like the min() bound would.
        beautified.append(dict(zip(keys, row)))

    self.disp.log_debug(f"beautified_table = {beautified}")
    return beautified
323
def compile_update_line(self, line: List, column: List, column_length: int) -> str:
    """Compile the ``col = value`` list required for an SQL update.

    Each of the first ``column_length`` cells is passed through
    ``self.check_sql_cell`` and paired with the column at the same index;
    the pairs are joined with ``", "``.

    Args:
        line (List): Data line to compile.
        column (List): Column names.
        column_length (int): Number of columns.

    Returns:
        str: Compiled SQL update line.

    Raises:
        IndexError: If ``line`` or ``column`` has fewer than
            ``column_length`` items.
    """
    title = "compile_update_line"
    self.disp.log_debug("Compiling update line.", title)
    # NOTE(review): the checked value is interpolated as-is (no surrounding
    # quotes are added here) - confirm callers quote or bind values.
    assignments = [
        f"{column[index]} = {self.check_sql_cell(line[index])}"
        for index in range(column_length)
    ]
    # join() places the separators; the former "i == column_length" break
    # could never trigger inside range(0, column_length) and was dropped.
    final_line = ", ".join(assignments)
    self.disp.log_debug(f"line = {final_line}", title)
    return final_line
347
def _process_single_sql_line(self, line: List[Union[str, int, float, None]], column_length: int) -> Tuple[str, List[Union[str, int, float, None]]]:
    """Build the ``(%s, ...)`` placeholder group and value list for one row.

    At most ``column_length`` cells are used. A warning is logged when the
    row is shorter than ``column_length`` or when extra cells are dropped.

    Args:
        line (List[Union[str, int, float, None]]): Data line to process.
        column_length (int): Number of columns.

    Returns:
        Tuple[str, List[Union[str, int, float, None]]]: Placeholder string and values.
    """
    title: str = "_process_single_sql_line"
    row = line if isinstance(line, list) else [line]
    row_size = len(row)

    # Rows containing ";base" are kept out of the debug log.
    if self.debugdebug and ";base" not in str(row):
        self.disp.log_debug(f"line = {row}", title)

    # Number of cells actually consumed (never negative).
    kept = max(0, min(column_length, row_size))
    values: List[Union[str, int, float, None]] = [
        self.check_sql_cell(cell, raw=True) for cell in row[:kept]
    ]
    placeholders: List[str] = ["%s"] * kept

    if row_size < column_length:
        msg = (
            f"Line shorter than expected (columns={column_length}, data={row_size}). "
            f"Missing columns will not be inserted beyond index {row_size}."
        )
        self.disp.log_warning(msg, title)
    elif column_length >= 1 and row_size > column_length:
        msg = (
            f"The line is longer than the number of columns ({row_size} > {column_length}), "
            f"truncating excess values."
        )
        self.disp.log_warning(msg, title)

    line_placeholder = "(" + ", ".join(placeholders) + ")"

    if self.debugdebug:
        msg = f"line_placeholder = '{line_placeholder}', type = {type(line_placeholder)}"
        self.disp.log_debug(msg, title)
        self.disp.log_debug(f"values = {values}", title)

    return (line_placeholder, values)
399
def process_sql_line(self, line: Union[str, int, float, List[Union[str, int, float, None]], List[List[Union[str, int, float, None]]], None], column: List[str], column_length: int = -1) -> Tuple[str, List[Union[str, int, float, None]]]:
    """Convert a dataset to ``%s`` placeholder groups plus a flat value list.

    A list whose first item is a list is treated as multi-row data and each
    row is delegated to ``self._process_single_sql_line``. Anything else is
    treated as a single row (a scalar is wrapped in a list first). An empty
    list yields ``("", [])``.

    Args:
        line (Union[str, int, float, List, None]): Data to process.
        column (List[str]): Column names.
        column_length (int, optional): Number of columns. Defaults to -1,
            meaning ``len(column)``.

    Returns:
        Tuple[str, List[Union[str, int, float, None]]]: Placeholder string and values.

    Raises:
        RuntimeError: If rows and scalar cells are mixed in ``line``.
    """
    title: str = "process_sql_line"

    if column_length == -1:
        column_length = len(column)

    if not isinstance(line, list):
        line = [line]

    results: List[str] = []
    all_values: List[Union[str, int, float, None]] = []

    # Case 1: multi-row data (list of lists)
    if len(line) > 0 and isinstance(line[0], list):
        for row in line:
            if not isinstance(row, list):
                raise RuntimeError(
                    "Incorrect data format, aborting process")
            placeholders, vals = self._process_single_sql_line(
                row, column_length
            )
            results.append(placeholders)
            all_values.extend(vals)
        line_final: str = ", ".join(results)
        if self.debugdebug:
            self.disp.log_debug(
                f"Final placeholder string = '{line_final}'", title)
            self.disp.log_debug(f"Total values = {len(all_values)}", title)

        return line_final, all_values

    # Case 2: single-row data. Reached for every remaining input, including
    # an empty list (indexing line[0] here used to raise IndexError).
    buffer: str = "("
    line_length = len(line)
    for index, row in enumerate(line):
        if self.debugdebug and ";base" not in str(row):
            self.disp.log_debug(f"row = {row}", title)

        if isinstance(row, list):
            raise RuntimeError(
                "Incorrect data format, aborting process"
            )
        all_values.append(self.check_sql_cell(row, raw=True))
        buffer += "%s"

        # Only add a comma if there are more provided values to come
        # AND we haven't reached the column limit. This prevents a
        # trailing comma when the provided data has fewer items than
        # the number of table columns.
        if index < line_length - 1 and index < column_length - 1:
            buffer += ", "

        if index == column_length - 1:
            if index < line_length - 1:
                msg = (
                    "The line is longer than the number of columns, truncating."
                )
                self.disp.log_warning(msg, title)
            break

    buffer += ")"
    # An empty row produces "()" which is not a usable placeholder group.
    if buffer not in ("()", ""):
        results.append(buffer)

    line_final = ", ".join(results)

    if self.debugdebug:
        self.disp.log_debug(
            f"Final placeholder string = '{line_final}'", title
        )
        self.disp.log_debug(f"Total values = {len(all_values)}", title)

    return line_final, all_values
490
def _check_for_double_query_in_trigger(self, sql: str, table_name: str) -> Union[int, str]:
    """Run basic safety checks on a CREATE TRIGGER statement.

    The statement is rejected (``self.error`` is returned) when it contains
    more than one ``CREATE TRIGGER``, any entry of
    ``SCONST.SQL_RISKY_DDL_TRIGGER_KEYWORDS``, a table name prefixed by a
    system schema, unbalanced ``BEGIN``/``END`` keywords, or does not start
    with ``CREATE TRIGGER <name>``.

    Args:
        sql (str): SQL trigger statement.
        table_name (str): Name of the table.

    Returns:
        Union[int, str]: The unchanged ``sql`` when every check passes,
        otherwise ``self.error``.
    """
    title: str = "_check_for_double_query"
    # --- Safety validation layer ---
    normalized_lower = sql.lower()

    # 1. Only one CREATE TRIGGER allowed
    if len(re.findall(r"\bcreate\s+trigger\b", normalized_lower)) > 1:
        self.disp.log_error(
            "Multiple CREATE TRIGGER statements detected.", title
        )
        return self.error

    # 2. Disallow dangerous DDL keywords (basic deny-list, substring match)
    for keyword in SCONST.SQL_RISKY_DDL_TRIGGER_KEYWORDS:
        if keyword in normalized_lower:
            self.disp.log_error(
                f"Unsafe keyword '{keyword.strip()}' detected in trigger SQL.", title
            )
            return self.error

    # 3. Ensure trigger table isn't a system schema
    if re.match(r"(?i)^(mysql|information_schema|performance_schema|sys)\.", table_name):
        self.disp.log_error(
            "Trigger cannot be created on system schema tables.", title)
        return self.error

    # 4. Check BEGIN/END pairing (simple count balance).
    # Whole-word matching keeps identifiers such as "pending" or
    # "begin_date" out of the count, and END IF / END LOOP / END WHILE /
    # END REPEAT / END CASE are skipped because they close flow-control
    # statements rather than a BEGIN block.
    # NOTE(review): keywords inside string literals or comments, and the
    # plain END of a CASE expression, are still counted.
    begin_count = len(re.findall(r"\bbegin\b", normalized_lower))
    end_count = len(
        re.findall(
            r"\bend\b(?!\s+(?:if|loop|while|repeat|case)\b)",
            normalized_lower
        )
    )
    if begin_count != end_count:
        self.disp.log_error(
            f"Unbalanced BEGIN/END block ({begin_count} BEGIN vs {end_count} END).", title
        )
        return self.error

    # 5. Warn if multiple statements outside BEGIN...END
    if begin_count == 0 and sql.count(";") > 1:
        self.disp.log_warning(
            "Multiple SQL statements found outside BEGIN/END. "
            "MySQL triggers only support one statement unless wrapped.",
            title
        )

    # --- Final shape sanity check ---
    if not re.match(r"(?i)^CREATE\s+TRIGGER\s+[`\"\w]+", sql):
        self.disp.log_error("Malformed CREATE TRIGGER statement.", title)
        return self.error
    return sql
548
def clean_trigger_creation(self, trigger_name: str, table_name: str, timing_event: str, body: str) -> Union[str, int]:
    """Clean and validate SQL trigger creation.

    When ``body`` does not already start with ``CREATE TRIGGER`` it is
    wrapped in a ``CREATE TRIGGER ... FOR EACH ROW`` template built from the
    other arguments. The statement is then normalised (``IF NOT EXISTS``
    removed, ``DELIMITER`` lines removed, trailing ``END$$``-style
    terminators rewritten to ``END;``, runs of spaces/tabs collapsed) and
    handed to ``self._check_for_double_query_in_trigger``.

    Args:
        trigger_name (str): Name of the trigger.
        table_name (str): Name of the table.
        timing_event (str): Timing event for the trigger.
        body (str): Trigger body.

    Returns:
        Union[str, int]: Validated SQL trigger or error code.
    """
    title = "clean_trigger_creation"

    if not all([trigger_name, table_name, timing_event, body]):
        self.disp.log_error("All parameters must be provided.", title)
        return self.error

    sql = body.strip()
    self.disp.log_debug(f"Raw trigger SQL received: {sql[:200]}...", title)

    # --- Detect if user already passed a full CREATE TRIGGER ---
    if not re.match(r"(?i)^\s*CREATE\s+TRIGGER\b", sql):
        # Double any backtick inside the identifiers so a name cannot close
        # its own quoting; names without backticks are unaffected.
        safe_trigger = trigger_name.replace("`", "``")
        safe_table = table_name.replace("`", "``")
        # Auto-wrap body inside CREATE TRIGGER template
        # NOTE(review): timing_event is inserted verbatim - confirm callers
        # restrict it to values such as "BEFORE INSERT".
        sql = (
            f"CREATE TRIGGER `{safe_trigger}` "
            f"{timing_event} ON `{safe_table}` "
            f"FOR EACH ROW {sql}"
        )
        self.disp.log_debug(
            f"Wrapped raw body into CREATE TRIGGER template:\n{sql}", title
        )

    # --- Normalize and clean syntax ---
    # 1. Remove unsupported IF NOT EXISTS
    sql = re.sub(
        r"(?i)\bCREATE\s+TRIGGER\s+IF\s+NOT\s+EXISTS\b",
        "CREATE TRIGGER",
        sql
    )

    # 2. Remove MySQL CLI delimiters like DELIMITER // or DELIMITER ;;
    sql = re.sub(r"(?im)^\s*DELIMITER\s+\S+\s*$", "", sql)

    # 3. Normalize END delimiters like END$$ or END// to END;
    # Case-insensitive so "end$$" is handled too, and anchored on a word
    # boundary so the whitespace separating END from the previous statement
    # is kept.
    sql = re.sub(r"(?i)\bEND\s*[\$;/]+\s*$", "END;", sql)

    # 4. Collapse multiple spaces
    sql = re.sub(r"[ \t]+", " ", sql).strip()

    # 5. Sanity check final structure
    if not re.match(r"(?i)^CREATE\s+TRIGGER\s+[`\"\w]+", sql):
        self.disp.log_error(
            f"Malformed trigger SQL after cleaning → {sql[:80]}", title
        )
        return self.error

    self.disp.log_debug(f"Normalized trigger SQL:\n{sql}", title)

    return self._check_for_double_query_in_trigger(sql, table_name)
Union[List[Dict[str, Any]], int] beautify_table(self, List[str] column_names, List[List[Any]] table_content)
Tuple[str, List[Union[str, int, float, None]]] process_sql_line(self, Union[str, int, float, List[Union[str, int, float, None]], List[List[Union[str, int, float, None]]], None] line, List[str] column, int column_length=-1)
Union[List[str], str] escape_risky_column_names(self, Union[List[str], str] columns)
Union[int, str] _check_for_double_query_in_trigger(self, str sql, str table_name)
Union[List[str], str] escape_risky_column_names_where_mode(self, Union[List[str], str] columns)
str compile_update_line(self, List line, List column, int column_length)
Union[str, Union[str, int, float, None]] check_sql_cell(self, Union[str, int, float, None] cell, bool raw=True)
None __init__(self, int success=0, int error=84, bool debug=False)
Union[str, int] clean_trigger_creation(self, str trigger_name, str table_name, str timing_event, str body)
Tuple[str, List[Union[str, int, float, None]]] _process_single_sql_line(self, List[Union[str, int, float, None]] line, int column_length)