import re import json import sys from pathlib import Path from datetime import datetime, timezone import yaml import importlib.util from typing import List, Tuple def extract_tables_with_aliases(select_sql: str) -> dict[str, str]: """ Returns mapping alias_or_table -> real_table Example: FROM messages m JOIN contacts c -> {"m":"messages","messages":"messages","c":"contacts","contacts":"contacts"} """ TABLE_TOKEN = re.compile( r'\b(?:FROM|JOIN)\s+("?[A-Za-z_][A-Za-z0-9_]*"?)' r'(?:\s+(?:AS\s+)?("?[A-Za-z_][A-Za-z0-9_]*"?))?', re.IGNORECASE ) m = {} for tbl, alias in TABLE_TOKEN.findall(select_sql): tbl = tbl.strip('"') if alias: alias = alias.strip('"') m[alias] = tbl m[tbl] = tbl return m def extract_single_table(select_sql: str) -> str | None: m = extract_tables_with_aliases(select_sql) # dict alias->table and table->table tables = sorted(set(m.values())) return tables[0] if len(tables) == 1 else None def _bytes_to_display(b: bytes, max_len: int) -> str: # Try UTF-8 first (common for text stored as BLOB) _PRINTABLE_RE = re.compile(r"^[\x09\x0a\x0d\x20-\x7e]+$") # tabs/newlines/spaces + printable ASCII try: s = b.decode("utf-8", errors="replace") s = s.strip() # If it is mostly printable, keep it if s and _PRINTABLE_RE.match(s[:min(len(s), 200)]): return s[:max_len] + ("..." if len(s) > max_len else "") except Exception: pass # Otherwise show hex preview (compact, honest) hx = b.hex() if len(hx) > max_len: return hx[:max_len] + "..." return hx def rows_to_text(rows, limit=None, max_chars=500000, cell_max=700): """ Converts SQL rows to text with safety limits for LLM context. - limit: Max number of rows to process. - max_chars: Hard limit for the total string length. - cell_max: Max length for any single column value. """ if not rows: return "" out = [] target_rows = rows[:limit] if limit else rows for r in target_rows: if r is None: continue # Handle tuples/rows cell-by-cell so bytes do not become "b'...'" if isinstance(r, (tuple, list)): cells = [] for v in r: if isinstance(v, bytes): cells.append(_bytes_to_display(v, cell_max)) else: sv = "" if v is None else str(v).strip() if len(sv) > cell_max: sv = sv[:cell_max] + "..." cells.append(sv) s = "(" + ", ".join(cells) + ")" else: # Non-tuple row if isinstance(r, bytes): s = _bytes_to_display(r, cell_max) else: s = str(r).strip() if len(s) > cell_max: s = s[:cell_max] + "..." if s: out.append(s) final_text = "\n".join(out) if len(final_text) > max_chars: return final_text[:max_chars] + "\n... [DATA TRUNCATED] ..." return final_text def regexp(expr, item): """ Safe regular expression matcher for SQLite REGEXP queries. This function allows SQLite to apply regex matching on arbitrary column values without raising exceptions. It safely handles NULL values, bytes or BLOB data, and malformed inputs. The match is case-insensitive and always fails gracefully instead of crashing the query engine. Example: # SQL: # SELECT * FROM users WHERE email REGEXP '[a-z0-9._%+-]+@[a-z0-9.-]+'; regexp("[a-z0-9._%+-]+@[a-z0-9.-]+", "john.doe@example.com") → True regexp("[a-z0-9._%+-]+@[a-z0-9.-]+", None) → False """ _BIDI_CTRL_RE = re.compile(r"[\u200e\u200f\u202a-\u202e\u2066-\u2069]") # 1. Handle NULLs (None in Python) if item is None: return False try: # 2. Ensure item is a string (handles BLOBs/Bytes) if isinstance(item, bytes): item = item.decode('utf-8', errors='ignore') else: item = str(item) # Clean invisible marks + whitespace item = _BIDI_CTRL_RE.sub("", item) item = item.replace("\u00a0", " ") item = re.sub(r"\s+", " ", item).strip() # 3. Compile and search return re.search(expr, item, re.IGNORECASE) is not None except Exception as e: # Log error but don't crash SQLite preview = repr(item)[:200] # avoid huge spam expr_preview = repr(expr)[:200] print(f"[REGEXP ERROR] {type(e).__name__}: {e} | expr={expr_preview} | item={preview}", file=sys.stderr) return False def normalize_sql(sql: str) -> str: """ Normalize LLM-generated SQL into a clean, executable SQL string. Input: sql (str): A raw SQL string that may include Markdown code fences (``` or ```sql), leading language tokens (e.g. "sql"), or extra whitespace. Output: str: A cleaned SQL string with all formatting artifacts removed, ready to be executed directly by SQLite. Example: Input: ```sql SELECT * FROM users; ``` Output: SELECT * FROM users; """ if not sql: return sql sql = sql.strip() # Remove ```sql or ``` fences sql = re.sub(r"^```(?:sql)?", "", sql, flags=re.IGNORECASE).strip() sql = re.sub(r"```$", "", sql).strip() # Remove leading 'sql' token if present if sql.lower().startswith("sql"): sql = sql[3:].strip() return sql def upgrade_sql_remove_limit(sql: str) -> str: _LIMIT_RE = re.compile(r"\s+LIMIT\s+\d+\s*;?\s*$", re.IGNORECASE) _LIMIT_ANYWHERE_RE = re.compile(r"\s+LIMIT\s+\d+\s*(?=($|\n|UNION|ORDER|GROUP|HAVING))", re.IGNORECASE) # Remove LIMIT clauses robustly (including UNION queries) upgraded = re.sub(r"\bLIMIT\s+\d+\b", "", sql, flags=re.IGNORECASE) # Clean up extra whitespace upgraded = re.sub(r"\s+\n", "\n", upgraded) upgraded = re.sub(r"\n\s+\n", "\n", upgraded) upgraded = re.sub(r"\s{2,}", " ", upgraded).strip() return upgraded def safe_json_loads(text: str, default): """ Safely parse JSON from LLM-generated text. Input: text (str): A raw string that may contain JSON wrapped in Markdown code fences (```), prefixed with a language token (e.g. "json"), or include extra whitespace. default: A fallback value to return if JSON parsing fails. Output: Any: The parsed Python object if valid JSON is found; otherwise the provided default value. Example: Input: ```json { "found": true, "confidence": 0.85 } ``` Output: { "found": True, "confidence": 0.85 } """ if not text: return default text = text.strip() # Remove markdown fences if text.startswith("```"): parts = text.split("```") if len(parts) >= 2: text = parts[1].strip() # Remove leading 'json' token if text.lower().startswith("json"): text = text[4:].strip() try: return json.loads(text) except Exception as e: print("[JSON PARSE ERROR]") print("RAW:", repr(text)) print("ERROR:", e) return default def split_union_selects(sql: str) -> list[str]: """ Split a SQL query into individual SELECT statements joined by UNION or UNION ALL. Input: sql (str): A single SQL query string that may contain multiple SELECT statements combined using UNION or UNION ALL. Output: list[str]: A list of individual SELECT statement strings, with UNION keywords removed and whitespace normalized. Example: Input: SELECT email FROM users UNION ALL SELECT handle FROM accounts Output: [ "SELECT email FROM users", "SELECT handle FROM accounts" ] """ # Normalize spacing sql = re.sub(r"\s+", " ", sql.strip()) # Split on UNION or UNION ALL, case-insensitive parts = re.split(r"\bUNION(?:\s+ALL)?\b", sql, flags=re.IGNORECASE) return [p.strip() for p in parts if p.strip()] import re from typing import List def extract_select_columns(select_sql: str) -> List[str]: """ Extract raw column names from a simple SELECT statement: - No SELECT * - No functions (COUNT, LOWER, etc.) - No expressions (a+b) - No aliases (AS or implicit) - Comma-separated columns only Returns column names in order; strips any table prefix (e.g., u.email -> email). """ m = re.search(r"\bSELECT\s+(.*?)\s+\bFROM\b", select_sql, flags=re.IGNORECASE | re.DOTALL) if not m: return [] select_list = m.group(1).strip() if not select_list or select_list == "*": return [] cols: List[str] = [] for item in select_list.split(","): item = item.strip() # remove backticks/quotes around identifiers if present item = item.strip("`").strip('"') # strip table prefix if any (table.col -> col) if "." in item: item = item.split(".")[-1] # basic validation: only simple identifiers if re.fullmatch(r"[A-Za-z_]\w*", item): cols.append(item) else: # For "simple SQL" this shouldn't happen; ignore or raise # raise ValueError(f"Non-simple select item: {item}") cols.append(item) return cols # def extract_select_columns(select_sql: str) -> list[str]: # """ # Extract column names or column aliases from a single SELECT statement. # Input: # select_sql (str): A SQL SELECT statement containing an explicit # projection list (no SELECT *), such as: # "SELECT col, col2 AS alias FROM table". # Output: # list[str]: A list of column names or aliases in the order they appear # in the SELECT clause. # Example: # Input: # SELECT email, username AS user FROM users # Output: # ["email", "user"] # """ # m = re.search( # r"SELECT\s+(.*?)\s+FROM\s", # select_sql, # flags=re.IGNORECASE | re.DOTALL # ) # if not m: # return [] # select_list = m.group(1) # columns = [] # for item in select_list.split(","): # item = item.strip() # # Handle aliases: col AS alias or col alias # alias_match = re.search(r"\bAS\s+(\w+)$", item, re.IGNORECASE) # if alias_match: # columns.append(alias_match.group(1)) # else: # # Take the final identifier # columns.append(item.split()[-1]) # return columns def is_sqlite_file(p: Path) -> bool: try: with p.open("rb") as f: return f.read(16) == b"SQLite format 3\x00" except Exception: return False def build_db_paths( db_dir: Path, db_files: List[str], is_sqlite_fn, ) -> Tuple[List[Path], List[str], List[str]]: """ Build ordered paths from filenames, skipping missing and non-sqlite. Returns (db_paths, missing, not_sqlite). """ db_paths: List[Path] = [] missing: List[str] = [] not_sqlite: List[str] = [] for name in db_files: p = db_dir / name if not p.exists(): missing.append(str(p)) continue if not is_sqlite_fn(p): not_sqlite.append(str(p)) continue db_paths.append(p) return db_paths, missing, not_sqlite def print_db_path_report(db_paths: List[Path], missing: List[str], not_sqlite: List[str]) -> None: print(f"Will process {len(db_paths)} databases (from db_files list).") if missing: print("Missing files:") for x in missing: print(" -", x) if not_sqlite: print("Not SQLite (bad header):") for x in not_sqlite: print(" -", x) def save_jsonl(results, out_dir: Path, db_path: str) -> Path: """ Save one JSONL file per database. Filename includes database stem + UTC timestamp. Converts bytes/BLOBs to JSON-safe base64. """ out_dir.mkdir(exist_ok=True) ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") db_stem = Path(db_path).stem out_path = out_dir / f"PII_{db_stem}_{ts}.jsonl" with out_path.open("w", encoding="utf-8") as f: for r in results: f.write(json.dumps(json_safe(r), ensure_ascii=False) + "\n") print(f"Wrote: {out_path.resolve()}") return out_path def load_config_yaml(path: Path) -> dict: return yaml.safe_load(path.read_text(encoding="utf-8")) def load_vars_from_py(py_path: Path, *var_names: str): spec = importlib.util.spec_from_file_location(py_path.stem, py_path) if spec is None or spec.loader is None: raise ValueError(f"Cannot load module from {py_path}") mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) # type: ignore out = {} for name in var_names: if not hasattr(mod, name): raise AttributeError(f"{py_path} does not define `{name}`") out[name] = getattr(mod, name) return out import base64 # sanitize each result dict before writing JSONL def json_safe(obj): if isinstance(obj, bytes): # base64 keeps it compact and reversible return {"__bytes_b64__": base64.b64encode(obj).decode("ascii")} # or use hex: # return {"__bytes_hex__": obj.hex()} if isinstance(obj, tuple): return [json_safe(x) for x in obj] if isinstance(obj, list): return [json_safe(x) for x in obj] if isinstance(obj, dict): return {k: json_safe(v) for k, v in obj.items()} return obj